@@ -83,4 +83,113 @@ def get_obj_from_str(string, reload=False):
     if reload:
         module_imp = importlib.import_module(module)
         importlib.reload(module_imp)
-    return getattr(importlib.import_module(module, package=None), cls)
+    return getattr(importlib.import_module(module, package=None), cls)
+
+def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
+    # create dummy dataset instance
+
+    # run prefetching
+    if idx_to_fn:
+        res = func(data, worker_id=idx)
+    else:
+        res = func(data)
+    Q.put([idx, res])
+    Q.put("Done")
+
+
+def parallel_data_prefetch(
+    func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
+):
+    # if target_data_type not in ["ndarray", "list"]:
+    #     raise ValueError(
+    #         "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
+    #     )
+    if isinstance(data, np.ndarray) and target_data_type == "list":
+        raise ValueError("list expected but function got ndarray.")
+    elif isinstance(data, abc.Iterable):
+        if isinstance(data, dict):
+            print(
+                'WARNING: "data" argument passed to parallel_data_prefetch is a dict: using only its values and disregarding keys.'
+            )
+            data = list(data.values())
+        if target_data_type == "ndarray":
+            data = np.asarray(data)
+        else:
+            data = list(data)
+    else:
+        raise TypeError(
+            f"The data that shall be processed in parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
+        )
+
+    if cpu_intensive:
+        Q = mp.Queue(1000)
+        proc = mp.Process
+    else:
+        Q = Queue(1000)
+        proc = Thread
+    # spawn processes
+    if target_data_type == "ndarray":
+        arguments = [
+            [func, Q, part, i, use_worker_id]
+            for i, part in enumerate(np.array_split(data, n_proc))
+        ]
+    else:
+        step = (
+            int(len(data) / n_proc + 1)
+            if len(data) % n_proc != 0
+            else int(len(data) / n_proc)
+        )
+        arguments = [
+            [func, Q, part, i, use_worker_id]
+            for i, part in enumerate(
+                [data[i : i + step] for i in range(0, len(data), step)]
+            )
+        ]
+    processes = []
+    for i in range(n_proc):
+        p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
+        processes += [p]
+
+    # start processes
+    print("Start prefetching...")
+    import time
+
+    start = time.time()
+    gather_res = [[] for _ in range(n_proc)]
+    try:
+        for p in processes:
+            p.start()
+
+        k = 0
+        while k < n_proc:
+            # get result
+            res = Q.get()
+            if res == "Done":
+                k += 1
+            else:
+                gather_res[res[0]] = res[1]
+
+    except Exception as e:
+        print("Exception: ", e)
+        for p in processes:
+            p.terminate()
+
+        raise e
+    finally:
+        for p in processes:
+            p.join()
+        print(f"Prefetching complete. [{time.time() - start} sec.]")
+
+    if target_data_type == "ndarray":
+        if not isinstance(gather_res[0], np.ndarray):
+            return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
+
+        # order outputs
+        return np.concatenate(gather_res, axis=0)
+    elif target_data_type == "list":
+        out = []
+        for r in gather_res:
+            out.extend(r)
+        return out
+    else:
+        return gather_res
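
The hunk begins mid-file, so the imports the new code depends on (numpy as np, multiprocessing as mp, Queue, Thread, and abc) are not visible above. Below is a minimal usage sketch of parallel_data_prefetch; the module path ldm.util and the worker function square are assumptions for illustration, and the __main__ guard matters because mp.Process re-imports the calling module under the spawn start method.

    from ldm.util import parallel_data_prefetch  # assumed module path

    def square(chunk):
        # each worker receives one contiguous chunk of the input list
        return [x * x for x in chunk]

    if __name__ == "__main__":
        out = parallel_data_prefetch(square, list(range(100)), n_proc=4, target_data_type="list")
        print(out[:5])  # [0, 1, 4, 9, 16]

Because each result is tagged with its worker index (Q.put([idx, res])) and slotted into gather_res[idx], the final output preserves input order even when workers finish out of order. With use_worker_id=True the function must also accept a worker_id keyword argument, and with cpu_intensive=False the same code runs on threads instead of processes, the better fit for I/O-bound prefetching.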