ATTPCROOT  0.3.0-alpha
A ROOT-based framework for analyzing data from active target detectors
AtMCFitter.cxx
Go to the documentation of this file.
1 #include "AtMCFitter.h"
2 // IWYU pragma: no_include <ext/alloc_traits.h>
3 
4 #include "AtClusterize.h" // for AtClusterize
5 #include "AtDigiPar.h" // for AtDigiPar
6 #include "AtEvent.h" // for AtEvent
7 #include "AtMCResult.h"
8 #include "AtPSA.h" // for AtPSA
10 #include "AtPatternEvent.h" // for AtPatternEvent
11 #include "AtPulse.h" // for AtPulse
12 #include "AtRawEvent.h" // for AtRawEvent
13 #include "AtSimpleSimulation.h" // for AtSimpleSimulation
14 #include "AtSimulatedPoint.h" // IWYU pragma: keep
15 #include "AtSpaceChargeModel.h"
16 
17 #include <FairLogger.h> // for LOG, Logger
18 #include <FairParSet.h> // for FairParSet
19 #include <FairRunAna.h> // for FairRunAna
20 #include <FairRuntimeDb.h> // for FairRuntimeDb
21 
22 #include <TROOT.h>
23 
24 #include <algorithm> // for max
25 #include <chrono>
26 #include <mutex>
27 #include <thread>
28 using std::move;
29 namespace MCFitter {
30 
32  : fMap(pulse->GetMap()), fSim(move(sim)), fClusterize(move(cluster)), fPulse(move(pulse)),
33  fResults([](const AtMCResult &a, const AtMCResult &b) { return a.fObjective < b.fObjective; })
34 {
35 }
36 
37 AtMCFitter::ParamPtr AtMCFitter::GetParameter(const std::string &name) const
38 {
39  if (fParameters.find(name) != fParameters.end()) {
40  return fParameters.at(name);
41  }
42  return nullptr;
43 }
44 
46 {
47  if (num > 1)
48  ROOT::EnableThreadSafety();
49  fNumThreads = num;
50 }
51 
53 {
   // Initialise the fitter from the run's parameter database.
   //
   // Fetches the "AtDigiPar" container from the FairRuntimeDb of the current FairRunAna and passes it to:
   //  - the pulse generator (SetParameters),
   //  - the clusterizer (GetParameters),
   //  - the simulation's space-charge model, when one is set (LoadParameters).
   // It also initialises the PSA if one is set, and fills fThPulse with fNumThreads clones of fPulse.
   // RunRound() hands one clone to each worker thread.
   //
   // NOTE(review): neither FairRunAna::Instance() nor the dynamic_cast result is checked for nullptr.
   // A missing or mistyped "AtDigiPar" container would pass nullptr on to the calls below.
   // NOTE(review): fThPulse is sized from fNumThreads here, so SetNumThreads() must be called before Init().
   // NOTE(review): Clone() is assumed to return an independent (deep) copy of the pulse; confirm.
55 
56  FairRunAna *ana = FairRunAna::Instance();
57  FairRuntimeDb *rtdb = ana->GetRuntimeDb();
58  fPar = dynamic_cast<AtDigiPar *>(rtdb->getContainer("AtDigiPar"));
59 
60  fPulse->SetParameters(fPar);
61  fClusterize->GetParameters(fPar);
62  if (fPSA)
63  fPSA->Init();
64  if (fSim->GetSpaceChargeModel())
65  fSim->GetSpaceChargeModel()->LoadParameters(fPar);
66 
   // One pulse generator per worker thread (indexed by thread number in RunRound()).
67  fThPulse.resize(fNumThreads);
68  for (int i = 0; i < fNumThreads; ++i)
69  fThPulse[i] = fPulse->Clone();
70 }
71 
72 void AtMCFitter::RunIterRange(int startIter, int numIter, AtPulse *pulse)
73 {
74  // Here we should copy each thread their own version of the clusterize, pulse, and simulation
75  // objects (only if the number of threads is greater than 1). Needs to be deep copies
76 
77  for (int i = 0; i < numIter; ++i) {
78 
79  int idx = startIter + i;
80  auto result = DefineEvent();
81  auto mcPoints = SimulateEvent(result);
82 
83  DigitizeEvent(mcPoints, idx, pulse);
84  double obj = ObjectiveFunction(*fCurrentEvent, idx, result);
85 
86  result.fIterNum = idx;
87  result.fObjective = obj;
88  // result.Print();
89  {
90  std::lock_guard<std::mutex> lk(fResultMutex);
91  fResults.insert(result);
92  }
93  }
94  LOG(debug) << "Done with run iter range";
95 }
96 
// Fit one experimental pattern event with the Monte Carlo method.
//
// Steps:
//  - clear the raw-event, event and result containers;
//  - let the derived class configure the parameter distributions from `event` (SetParamDistributions);
//  - store the address of `event` in fCurrentEvent, which ObjectiveFunction() compares against;
//  - size the event arrays to fNumIter;
//  - run fNumRounds rounds of simulation.
//
// NOTE(review): fCurrentEvent keeps the address of `event` after Exec() returns. In this file it is only
// dereferenced in RunIterRange() during the rounds run here; confirm derived classes do not use it later.
// NOTE(review): fResults is cleared only here, but every round reuses event-array slots 0..fNumIter-1.
// With fNumRounds > 1, a result from an earlier round can therefore point (via fIterNum) at a slot that a
// later round has overwritten. Confirm this is intended before relying on the saved simulated events.
97 void AtMCFitter::Exec(const AtPatternEvent &event)
98 {
99  fRawEventArray.clear();
100  fEventArray.clear();
101  fResults.clear();
102 
103  SetParamDistributions(event);
104 
105  // Set the conditions for simulating the event
106  fCurrentEvent = &event;
107 
108  // Make sure the event arrays are large enough so no resizing will happen
   // (worker threads write to these slots concurrently, so they must not reallocate).
109  fRawEventArray.resize(fNumIter);
110  fEventArray.resize(fNumIter);
111 
   // Each round reuses the same event-array slots; see the NOTE above.
112  for (int i = 0; i < fNumRounds; ++i) {
113  RunRound();
115  }
116 }
118 {
119  // Begining of round
120  auto start = std::chrono::high_resolution_clock::now();
121 
122  // Get what iterations to do on what thread.
123  std::vector<std::pair<int, int>> threadParam;
124  int iterPerTh = fNumIter / fNumThreads;
125  for (int i = 0; i < fNumThreads; ++i)
126  threadParam.emplace_back(0, iterPerTh);
127  for (int i = 0; i < fNumIter % fNumThreads; ++i)
128  threadParam[i].second++;
129  for (int i = 1; i < fNumThreads; ++i)
130  threadParam[i].first = threadParam[i - 1].first + threadParam[i - 1].second;
131 
132  for (int i = 0; i < threadParam.size(); ++i) {
133  LOG(info) << i << ": " << threadParam[i].first << " " << threadParam[i].second;
134  }
135 
136  std::vector<std::thread> threads;
137  for (int i = 0; i < fNumThreads; ++i) {
138  LOG(debug) << "Creating thread " << i << " with " << threadParam[i].first << " " << threadParam[i].second
139  << " and " << fPulse.get();
140 
141  // Spawn a thread to call RunIterRange.
142  threads.emplace_back(
143  [this](std::pair<int, int> param, AtPulse *pulse) { this->RunIterRange(param.first, param.second, pulse); },
144  threadParam[i], fThPulse[i].get());
145  }
146 
147  // Wait for all threads to finish
148  for (auto &th : threads)
149  th.join();
150 
151  auto stop = std::chrono::high_resolution_clock::now();
152 
153  if (fTimeEvent)
154  LOG(info) << "Simulation of " << fNumIter << " events took "
155  << std::chrono::duration_cast<std::chrono::milliseconds>(stop - start).count() << " ms.";
156 }
157 
158 int AtMCFitter::DigitizeEvent(const TClonesArray &points, int idx, AtPulse *pulse)
159 {
160  // Event has been simulated and is sitting in the fSim
161  auto vec = fClusterize->ProcessEvent(points);
162  LOG(debug) << "Digitizing event at " << idx;
163 
164  fRawEventArray[idx] = pulse->GenerateEvent(vec);
165 
166  if (fPSA) {
167  LOG(debug) << "Running PSA at " << idx;
168  fEventArray[idx] = fPSA->Analyze(fRawEventArray[idx]);
169  }
170  LOG(debug) << "Done digitizing event at " << idx;
171  return idx;
172 }
173 
177 void AtMCFitter::FillResultArrays(TClonesArray &resultArray, TClonesArray &simEvent, TClonesArray &simRawEvent)
178 {
179  resultArray.Delete();
180  simEvent.Delete();
181  simRawEvent.Delete();
182 
183  for (auto &res : fResults) {
184 
185  int clonesIdx = resultArray.GetEntries();
186  int eventIdx = res.fIterNum;
187  LOG(debug) << "Filling iteration " << eventIdx << " at index " << resultArray.GetEntries();
188 
189  new (resultArray[clonesIdx]) AtMCResult(std::move(res));
190  if (clonesIdx < fNumEventsToSave) {
191  new (simEvent[clonesIdx]) AtEvent(std::move(fEventArray[eventIdx]));
192  new (simRawEvent[clonesIdx]) AtRawEvent(std::move(fRawEventArray[eventIdx]));
193  }
194  }
195 
196  fEventArray.clear();
197  fRawEventArray.clear();
198 }
199 
201 {
202  AtMCResult result;
203  for (auto &[name, distro] : fParameters)
204  result.fParameters[name] = distro->Sample();
205  return result;
206 }
208 {
209  for (auto &[name, distro] : fParameters) {
210  AtMCResult result = *fResults.begin();
211  distro->SetMean(result.fParameters[name]);
212  distro->TruncateSpace();
213  }
214 }
215 
216 } // namespace MCFitter
AtParameterDistribution.h
AtRawEvent.h
AtPatternEvent
Definition: AtPatternEvent.h:19
AtEvent.h
MCFitter::AtMCFitter::AtMCFitter
AtMCFitter(SimPtr sim, ClusterPtr cluster, PulsePtr pulse)
Definition: AtMCFitter.cxx:31
MCFitter::AtMCFitter::fPSA
PsaPtr fPSA
Definition: AtMCFitter.h:48
AtPulse::GenerateEvent
AtRawEvent GenerateEvent(std::vector< SimPointPtr > &vec)
Definition: AtPulse.cxx:56
AtSimpleSimulation.h
MCFitter::AtMCFitter::fResultMutex
std::mutex fResultMutex
Store the iteration number sorted by lowest objective function.
Definition: AtMCFitter.h:69
AtClusterize.h
AtMCFitter.h
AtMCResult.h
MCFitter::AtMCFitter::FillResultArrays
void FillResultArrays(TClonesArray &resultArray, TClonesArray &simEvent, TClonesArray &simRawEvent)
Definition: AtMCFitter.cxx:177
MCFitter::AtMCFitter::fPar
const AtDigiPar * fPar
Definition: AtMCFitter.h:60
MCFitter::AtMCFitter::fNumRounds
int fNumRounds
Definition: AtMCFitter.h:51
AtEvent
Definition: AtEvent.h:22
MCFitter::AtMCFitter::fNumEventsToSave
int fNumEventsToSave
Definition: AtMCFitter.h:52
MCFitter::AtMCResult::fParameters
ParamMap fParameters
Definition: AtMCResult.h:23
AtRawEvent
Definition: AtRawEvent.h:34
MCFitter::AtMCFitter::Init
void Init()
Definition: AtMCFitter.cxx:52
MCFitter::AtMCFitter::ObjectiveFunction
virtual double ObjectiveFunction(const AtBaseEvent &expEvent, int SimEventID, AtMCResult &definition)=0
This is the thing we are minimizing between events (SimEventID is index in TClonesArray)
AtSimulatedPoint.h
hc::cluster
std::vector< size_t > cluster
Definition: hc.h:25
MCFitter::AtMCFitter::DigitizeEvent
int DigitizeEvent(const TClonesArray &points, int idx, AtPulse *pulse)
Definition: AtMCFitter.cxx:158
MCFitter::AtMCFitter::PulsePtr
std::shared_ptr< AtPulse > PulsePtr
Definition: AtMCFitter.h:37
MCFitter::AtMCFitter::DefineEvent
virtual AtMCResult DefineEvent()
Definition: AtMCFitter.cxx:200
MCFitter::AtMCFitter::RunRound
void RunRound()
Definition: AtMCFitter.cxx:117
MCFitter::AtMCFitter::fRawEventArray
std::vector< AtRawEvent > fRawEventArray
Definition: AtMCFitter.h:64
AtSpaceChargeModel.h
MCFitter::AtMCFitter::fClusterize
ClusterPtr fClusterize
Definition: AtMCFitter.h:46
MCFitter::AtMCFitter::ClusterPtr
std::shared_ptr< AtClusterize > ClusterPtr
Definition: AtMCFitter.h:36
MCFitter::AtMCFitter::fSim
SimPtr fSim
Definition: AtMCFitter.h:45
MCFitter::AtMCResult
Definition: AtMCResult.h:18
MCFitter::AtMCFitter::fNumIter
int fNumIter
Definition: AtMCFitter.h:50
MCFitter::AtMCFitter::CreateParamDistros
virtual void CreateParamDistros()=0
Create the parameter distributions to use for the fit.
AtDigiPar.h
AtDigiPar
Definition: AtDigiPar.h:14
AtPatternEvent.h
MCFitter::AtMCFitter::ParamPtr
std::shared_ptr< AtParameterDistribution > ParamPtr
Definition: AtMCFitter.h:33
MCFitter::AtMCFitter::SimulateEvent
virtual TClonesArray SimulateEvent(AtMCResult &definition)=0
MCFitter::AtMCFitter::fPulse
PulsePtr fPulse
Definition: AtMCFitter.h:47
MCFitter::AtMCFitter::SetNumThreads
void SetNumThreads(int num)
Definition: AtMCFitter.cxx:45
MCFitter::AtMCFitter::SetParamDistributions
virtual void SetParamDistributions(const AtPatternEvent &event)=0
Set parameter distributions (mean/spread) from the event.
MCFitter::AtMCFitter::GetParameter
ParamPtr GetParameter(const std::string &name) const
Definition: AtMCFitter.cxx:37
MCFitter::AtMCFitter::fTimeEvent
bool fTimeEvent
Definition: AtMCFitter.h:53
MCFitter::AtMCFitter::Exec
void Exec(const AtPatternEvent &event)
Definition: AtMCFitter.cxx:97
MCFitter::AtMCFitter::fParameters
std::map< std::string, ParamPtr > fParameters
Definition: AtMCFitter.h:42
MCFitter::AtMCFitter::SimPtr
std::shared_ptr< AtSimpleSimulation > SimPtr
Definition: AtMCFitter.h:34
AtPSA.h
MCFitter::AtMCFitter::RecenterParamDistributions
virtual void RecenterParamDistributions()
Definition: AtMCFitter.cxx:207
MCFitter::AtMCFitter::fEventArray
std::vector< AtEvent > fEventArray
Definition: AtMCFitter.h:65
AtPulse
Definition: AtPulse.h:22
MCFitter::AtMCFitter::fThPulse
std::vector< PulsePtr > fThPulse
Definition: AtMCFitter.h:59
MCFitter::AtMCFitter::fCurrentEvent
const AtPatternEvent * fCurrentEvent
Definition: AtMCFitter.h:58
MCFitter::AtMCFitter::fNumThreads
int fNumThreads
Definition: AtMCFitter.h:54
MCFitter::AtMCFitter::RunIterRange
void RunIterRange(int startIter, int numIter, AtPulse *pulse)
Definition: AtMCFitter.cxx:72
MCFitter::AtMCFitter::fResults
std::set< AtMCResult, std::function< bool(AtMCResult, AtMCResult)> > fResults
Definition: AtMCFitter.h:70
MCFitter
Definition: AtMCResult.cxx:5
AtPulse.h