 import time
 from multiprocessing import cpu_count
 
-
 from ldm.util import instantiate_from_config, parallel_data_prefetch
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
 from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder
 
-
 DATABASES = [
     "openimages",
     "artbench-art_nouveau",
@@ -69,8 +67,8 @@ def __init__(self, database, retriever_version='ViT-L/14'):
         self.database_path = f'data/retrieval_databases/{self.database_name}'
         self.retriever = self.load_retriever(version=retriever_version)
         self.database = {'embedding': [],
-            'img_id': [],
-            'patch_coords': []}
+                         'img_id': [],
+                         'patch_coords': []}
         self.load_database()
         self.load_searcher()
@@ -105,10 +103,8 @@ def load_multi_files(self, data_archive):
 
     def load_database(self):
 
-
         print(f'Load saved patch embedding from "{self.database_path}"')
-        file_content = glob.glob(os.path.join(self.database_path,'*.npz'))
-
+        file_content = glob.glob(os.path.join(self.database_path, '*.npz'))
 
         if len(file_content) == 1:
             self.load_single_file(file_content[0])
@@ -117,15 +113,14 @@ def load_database(self):
             prefetched_data = parallel_data_prefetch(self.load_multi_files, data,
                                                      n_proc=min(len(data), cpu_count()), target_data_type='dict')
 
-            self.database = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in self.database}
+            self.database = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in
+                             self.database}
         else:
             raise ValueError(f'No npz-files in specified path "{self.database_path}" is this directory existing?')
 
         print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.')
 
-
-    def load_retriever(self,version='ViT-L/14',):
-
+    def load_retriever(self, version='ViT-L/14', ):
         model = FrozenClipImageEmbedder(model=version)
         if torch.cuda.is_available():
             model.cuda()
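
For context on the rewrapped dict comprehension above: parallel_data_prefetch hands back one partial result per worker process, and load_database stitches each key of the database dict back together with np.concatenate (the axis=1 followed by [0] suggests each worker's arrays arrive wrapped in an extra leading axis, though that detail lives in ldm.util). A minimal sketch of the same per-key merge on made-up worker outputs, using a plain axis-0 concatenation:

    import numpy as np

    # hypothetical per-worker outputs: dicts sharing the keys of self.database
    proc_results = [
        {'embedding': np.random.rand(100, 768), 'img_id': np.arange(100)},
        {'embedding': np.random.rand(50, 768), 'img_id': np.arange(50)},
    ]
    # stitch every key back into one flat array across workers
    merged = {key: np.concatenate([od[key] for od in proc_results], axis=0)
              for key in proc_results[0]}
    print(merged['embedding'].shape)  # (150, 768)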
@@ -137,15 +132,14 @@ def load_searcher(self):
         self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir)
         print('Finished loading searcher.')
 
-
     def search(self, x, k):
         if self.searcher is None and self.database['embedding'].shape[0] < 2e4:
-            self.train_searcher(k)
+            self.train_searcher(k)  # quickly fit searcher on the fly for small databases
         assert self.searcher is not None, 'Cannot search with uninitialized searcher'
-        if isinstance(x,torch.Tensor):
+        if isinstance(x, torch.Tensor):
             x = x.detach().cpu().numpy()
         if len(x.shape) == 3:
-            x = x[:,0]
+            x = x[:, 0]
         query_embeddings = x / np.linalg.norm(x, axis=1)[:, np.newaxis]
 
         start = time.time()
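
As a usage note on search(): it accepts either a torch tensor or a numpy array, drops the token axis of a (b, 1, d) CLIP embedding, and L2-normalizes every query row before handing it to scann. A minimal sketch of querying the class, assuming one of the DATABASES has been fetched into data/retrieval_databases/ and its scann searcher either exists on disk or the database is small enough to be fitted on the fly (prompt and k are arbitrary):

    import torch
    from ldm.modules.encoders.modules import FrozenCLIPTextEmbedder

    clip_text_encoder = FrozenCLIPTextEmbedder()  # assumes the default CLIP ViT-L/14 weights are available
    searcher = Searcher('artbench-art_nouveau')

    with torch.no_grad():
        c = clip_text_encoder.encode(['an art nouveau poster of a peacock'])
        nn_dict = searcher.search(c, 10)
        print(nn_dict['nn_embeddings'].shape)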
@@ -169,9 +163,11 @@ def search(self, x, k):
     def __call__(self, x, n):
         return self.search(x, n)
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     # TODO: add n_neighbors and modes (text-only, text-image-retrieval, image-image retrieval etc)
+    # TODO: add 'image variation' mode when knn=0 but a single image is given instead of a text prompt?
     parser.add_argument(
         "--prompt",
         type=str,
@@ -293,7 +289,7 @@ def __call__(self, x, n):
         type=str,
         default=DATABASES[0],
         choices=DATABASES,
-        help="The database used for the search",
+        help="The database used for the search, only applied when --use_neighbors=True",
     )
     parser.add_argument(
         "--use_neighbors",
@@ -308,7 +304,6 @@ def __call__(self, x, n):
         help="The number of included neighbors, only applied when --use_neighbors=True",
     )
 
-
     opt = parser.parse_args()
 
     config = OmegaConf.load(f"{opt.config}")
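
For orientation between the argparse block and the sampling loop: the imports at the top of the file show how the loaded yaml becomes a model, namely through instantiate_from_config. A minimal sketch of that pattern; the config path here is hypothetical, the script takes the real one from --config:

    from omegaconf import OmegaConf
    from ldm.util import instantiate_from_config

    config = OmegaConf.load('configs/retrieval-augmented-diffusion/768x768.yaml')  # hypothetical path
    model = instantiate_from_config(config.model)  # builds the class named in config.model.target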
@@ -344,7 +339,7 @@ def __call__(self, x, n):
     os.makedirs(sample_path, exist_ok=True)
     base_count = len(os.listdir(sample_path))
     grid_count = len(os.listdir(outpath)) - 1
-    
+
     print(f"sampling scale for cfg is {opt.scale:.2f}")
 
     searcher = None
@@ -360,15 +355,15 @@ def __call__(self, x, n):
                     if isinstance(prompts, tuple):
                         prompts = list(prompts)
                     c = clip_text_encoder.encode(prompts)
+                    uc = None
                     if searcher is not None:
-                        nn_dict = searcher(c,opt.knn)
-                        c = torch.cat([c,torch.from_numpy(nn_dict['nn_embeddings']).cuda()],dim=1)
-                    uc = None
+                        nn_dict = searcher(c, opt.knn)
+                        c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1)
                     if opt.scale != 1.0:
                         uc = torch.zeros_like(c)
                     if isinstance(prompts, tuple):
                         prompts = list(prompts)
-                    shape = [16, opt.H // 16, opt.W // 16] # note: currently hardcoded for f16 model
+                    shape = [16, opt.H // 16, opt.W // 16]  # note: currently hardcoded for f16 model
                     samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                                      conditioning=c,
                                                      batch_size=c.shape[0],
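
The one behavioral nuance in this hunk: uc is now initialized to None before the retrieval block rather than after it, and the classifier-free-guidance tensor uc = torch.zeros_like(c) is built only after the retrieved neighbors are concatenated onto c, so the unconditional input always matches the shape of the augmented conditioning. A minimal shape sketch with made-up dimensions:

    import torch

    b, knn, d = 2, 10, 768                    # batch, neighbors, CLIP width (illustrative)
    c = torch.randn(b, 1, d)                  # text embedding
    nn_embeddings = torch.randn(b, knn, d)    # retrieved patch embeddings
    c = torch.cat([c, nn_embeddings], dim=1)  # (b, 1 + knn, d) retrieval-augmented conditioning
    uc = torch.zeros_like(c)                  # cfg unconditional input, same shape by construction
    print(c.shape, uc.shape)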
@@ -380,11 +375,12 @@ def __call__(self, x, n):
                                                      )
 
                     x_samples_ddim = model.decode_first_stage(samples_ddim)
-                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0)/2.0, min=0.0, max=1.0)
+                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
                     for x_sample in x_samples_ddim:
                         x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                        Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(sample_path, f"{base_count:05}.png"))
+                        Image.fromarray(x_sample.astype(np.uint8)).save(
+                            os.path.join(sample_path, f"{base_count:05}.png"))
                         base_count += 1
                     all_samples.append(x_samples_ddim)
 
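The final hunk only rewraps the save call, but the conversion it touches is the standard one: decoded samples come out of the first stage in [-1, 1], get clamped into [0, 1], scaled to 8-bit, and rearranged from CHW to HWC before PIL writes them. A self-contained sketch on dummy data:

    import numpy as np
    import torch
    from einops import rearrange
    from PIL import Image

    x_sample = torch.rand(3, 64, 64) * 2.0 - 1.0                      # stand-in for a decoded sample in [-1, 1]
    x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)  # map into [0, 1]
    arr = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')  # CHW float -> HWC in [0, 255]
    Image.fromarray(arr.astype(np.uint8)).save('sample.png')          # writes a PNG next to the script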