# match_extracts_scene_mean.py

  1. # main imports
  2. import argparse
  3. import numpy as np
  4. import sys, os
  5. # mongo import
  6. from pymongo import MongoClient
  7. # modules imports
  8. sys.path.insert(0, '') # trick to enable import of main folder module
  9. # config imports
  10. import custom_config as cfg
  11. def main():
  12. parser = argparse.ArgumentParser(description="Get error during calibration experiment for each user")
  13. parser.add_argument('--expeId', type=str, help='Experiment identifier')
  14. parser.add_argument('--experiment', type=str, help='Experiment name', choices=cfg.experiment_list, required=True)
  15. parser.add_argument('--scene', type=str, help='Scene identifier to use', choices=cfg.scenes_indices, required=True)
  16. args = parser.parse_args()
  17. p_expe_id = args.expeId
  18. p_experiment = args.experiment
  19. p_scene = args.scene
  20. # connect to Mongo db and collect data
  21. client = MongoClient(cfg.default_host)
  22. db = client.sin3d
  23. query = {
  24. 'data.msg.experimentName': p_experiment,
  25. 'data.msgId': "EXPERIMENT_VALIDATED"
  26. }
  27. # add of expeid into query if exists
  28. if p_expe_id:
  29. print("Expe id used", p_expe_id)
  30. query['data.experimentId'] = p_expe_id
  31. index = cfg.scenes_indices.index(p_scene.strip())
  32. scene_name = cfg.scenes_names[index]
  33. # from dataset retrieve human thresholds for each zone
  34. zone_thresholds = []
  35. scene_folder = os.path.join(cfg.dataset_path, scene_name)
  36. zone_folders = sorted([zone for zone in os.listdir(scene_folder) if 'zone' in zone])
  37. for zone in zone_folders:
  38. threshold_file_path = os.path.join(scene_folder, zone, cfg.seuil_expe_filename)
  39. with open(threshold_file_path, 'r') as f:
  40. current_threshold = int(f.readline())
  41. zone_thresholds.append(current_threshold)
  42. print(zone_thresholds)
  43. print("Scene used", scene_name)
  44. query['data.msg.sceneName'] = scene_name
  45. print(query)
  46. res = db.datas.find(query)
  47. all_experiment_thresholds = []
  48. for cursor in res:
  49. user_data = cursor['data']
  50. user_id = user_data['userId']
  51. experiment_user_thresholds = []
  52. for id, val in enumerate(user_data['msg']['extracts']):
  53. experiment_user_thresholds.append(val['quality'])
  54. all_experiment_thresholds.append(experiment_user_thresholds)
  55. # TODO : Voir pour sauvegarde des fichiers et applications des stats sur les résultats obtenus
  56. print(np.mean(all_experiment_thresholds, axis=0))
  57. if __name__== "__main__":
  58. main()