emnenmf.m 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  function [ T , RMSE ] = emnenmf( W , X , G , F , Omega_G, Omega_F, Phi_G, Phi_F , InnerMinIter , InnerMaxIter , Tmax , v, F_theo, delta_measure)
    %EMNENMF Weighted NMF by Expectation-Maximization with fixed entries in
    %both factors, sampling the error on F at a fixed time period.
    %
    %   [T, RMSE] = EMNENMF(W, X, G, F, Omega_G, Omega_F, Phi_G, Phi_F, ...
    %                       InnerMinIter, InnerMaxIter, Tmax, v, F_theo, delta_measure)
    %
    %   Approximates the data X0 (the input X) by G*F. Each E-step completes
    %   the data as
    %       X = G*F + W.*(X0 - G*F)
    %   so entries with W == 1 keep the observed value and entries with
    %   W == 0 take the current model value (W is presumably a 0/1 or [0,1]
    %   observation weight matrix -- confirm with callers). Each E-step is
    %   followed by v alternating updates of G then F, delegated to
    %   MaJ_G_EM_NeNMF / MaJ_F_EM_NeNMF (defined elsewhere).
    %
    %   Inputs:
    %     W             weights, combined elementwise with (X0 - G*F)
    %     X             observed data matrix
    %     G, F          initial factors; G*F has the size of X
    %     Omega_G       mask over G; entries equal to 1 are held at Phi_G
    %     Omega_F       mask over F; entries equal to 1 are held at Phi_F
    %     Phi_G, Phi_F  values of the fixed entries (same sizes as G and F)
    %     InnerMinIter, InnerMaxIter
    %                   iteration bounds passed to the inner solvers; when a
    %                   solver stops after <= InnerMinIter iterations, its
    %                   stopping tolerance is divided by 10
    %     Tmax          time budget in seconds (measured with tic/toc)
    %     v             number of G/F update pairs per E-step
    %     F_theo        reference F (same size as F), used only for RMSE
    %     delta_measure sampling period in seconds for T and RMSE
    %
    %   Outputs:
    %     T     1-by-round(Tmax/delta_measure) vector of sampling times in
    %           seconds; unused slots stay NaN
    %     RMSE  size(F,1)-by-round(Tmax/delta_measure) matrix. Column k holds,
    %           for each row r of F,
    %           norm(F(r,1:end-1) - F_theo(r,1:end-1)) / sqrt(size(F,2)-1)
    %           at time T(k). The last column of F is excluded from the error.
    %           Unused columns stay NaN.
    %
    %   Side effects: displays the total inner-iteration count (niter) and
    %   the median durations of the E-step and of one G/F update pair.

    X0 = X;
    % Logical masks are faster than index lists in MATLAB.
    Omega_G = (Omega_G == 1);
    Omega_F = (Omega_F == 1);
    nOmega_G = ~Omega_G;
    nOmega_F = ~Omega_F;

    % The last column of F is left out of the error metric.
    num_sensor = size(F, 2) - 1;

    em_iter_max = round(Tmax / delta_measure);
    T = nan(1, em_iter_max);
    % One error row per row of F. A hard-coded 2 here made the RMSE(:,i)
    % assignments fail for any other factorization rank.
    RMSE = nan(size(F, 1), em_iter_max);

    % Initial E-step (G*F evaluated once instead of twice).
    GF = G*F;
    X = GF + W.*(X0 - GF);

    % Initial gradients. The inner stopping tolerances start at 1e-3 times
    % Grad_P of these gradients (Grad_P is defined elsewhere; presumably a
    % projected-gradient norm -- confirm).
    GG = G'*G;
    GX = G'*X;
    GradF = GG*F - GX;
    FF = F*F';
    XF = X*F';
    GradG = G*FF - XF;
    d = Grad_P([GradG', GradF], [G', F]);
    StoppingCritF = 1.e-3*d;
    StoppingCritG = StoppingCritF;

    tic
    i = 1;
    T(i) = toc;
    RMSE(:, i) = vecnorm(F(:, 1:end-1) - F_theo(:, 1:end-1), 2, 2) / sqrt(num_sensor);
    niter = 0;
    T_E = [];
    T_M = [];
    while toc < Tmax
        % ---- E-step ----
        t_e = toc;
        GF = G*F;
        X = GF + W.*(X0 - GF);
        T_E(end+1, 1) = toc - t_e;

        % ---- M-step: v alternating updates of G then F ----
        for j = 1:v
            t_m = toc;

            % Update G. The fixed entries are zeroed so the solver works on
            % Delta G; (X - Phi_G*F)*F' = X*F' - Phi_G*FF is the matching
            % target. The fixed entries are restored afterwards.
            FF = F*F';
            XF = X*F' - Phi_G*FF;
            G(Omega_G) = 0;
            [G, iterG] = MaJ_G_EM_NeNMF(FF, XF, G, InnerMinIter, InnerMaxIter, StoppingCritG, nOmega_G);
            G(Omega_G) = Phi_G(Omega_G);
            niter = niter + iterG;
            if iterG <= InnerMinIter
                % The solver converged too quickly: tighten its tolerance.
                StoppingCritG = 1.e-1*StoppingCritG;
            end

            % Update F, symmetrically: G'*(X - G*Phi_F) = G'*X - GG*Phi_F.
            GG = G'*G;
            GX = G'*X - GG*Phi_F;
            F(Omega_F) = 0;
            [F, iterF] = MaJ_F_EM_NeNMF(GG, GX, F, InnerMinIter, InnerMaxIter, StoppingCritF, nOmega_F);
            F(Omega_F) = Phi_F(Omega_F);
            niter = niter + iterF;
            if iterF <= InnerMinIter
                StoppingCritF = 1.e-1*StoppingCritF;
            end

            % Record the error once the next sampling instant is reached.
            if toc - i*delta_measure >= delta_measure
                i = i + 1;
                if i > em_iter_max
                    % Leaves only this for-loop. Reaching it requires
                    % toc >= (em_iter_max+1)*delta_measure > Tmax, so the
                    % while condition fails right after.
                    break
                end
                T(i) = toc;
                RMSE(:, i) = vecnorm(F(:, 1:end-1) - F_theo(:, 1:end-1), 2, 2) / sqrt(num_sensor);
            end
            T_M(end+1, 1) = toc - t_m;
        end
    end
    niter   % displays the total number of inner-solver iterations
    disp(['em E step : ',num2str(median(T_E))])
    disp(['em M step : ',num2str(median(T_M))])
  end