
Commit 78ef9ca

Author: ablattmann
Commit message: add databases to inference script
1 parent: c8a6b66

File tree: 3 files changed (+322, -50 lines)


ldm/modules/encoders/modules.py

Lines changed: 71 additions & 0 deletions
@@ -1,6 +1,10 @@
 import torch
 import torch.nn as nn
 from functools import partial
+import clip
+from einops import rearrange, repeat
+import kornia
+

 from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test

@@ -129,3 +133,70 @@ def forward(self,x):

     def encode(self, x):
         return self(x)
+
+
+class FrozenCLIPTextEmbedder(nn.Module):
+    """
+    Uses the CLIP transformer encoder for text.
+    """
+    def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
+        super().__init__()
+        self.model, _ = clip.load(version, jit=False, device="cpu")
+        self.device = device
+        self.max_length = max_length
+        self.n_repeat = n_repeat
+        self.normalize = normalize
+
+    def freeze(self):
+        self.model = self.model.eval()
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, text):
+        tokens = clip.tokenize(text).to(self.device)
+        z = self.model.encode_text(tokens)
+        if self.normalize:
+            z = z / torch.linalg.norm(z, dim=1, keepdim=True)
+        return z
+
+    def encode(self, text):
+        z = self(text)
+        if z.ndim==2:
+            z = z[:, None, :]
+        z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
+        return z
+
+
+class FrozenClipImageEmbedder(nn.Module):
+    """
+    Uses the CLIP image encoder.
+    """
+    def __init__(
+            self,
+            model,
+            jit=False,
+            device='cuda' if torch.cuda.is_available() else 'cpu',
+            antialias=False,
+    ):
+        super().__init__()
+        self.model, _ = clip.load(name=model, device=device, jit=jit)
+
+        self.antialias = antialias
+
+        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+
+    def preprocess(self, x):
+        # normalize to [0,1]
+        x = kornia.geometry.resize(x, (224, 224),
+                                   interpolation='bicubic',align_corners=True,
+                                   antialias=self.antialias)
+        x = (x + 1.) / 2.
+        # renormalize according to clip
+        x = kornia.enhance.normalize(x, self.mean, self.std)
+        return x
+
+    def forward(self, x):
+        # x is assumed to be in range [-1,1]
+        return self.model.encode_image(self.preprocess(x))
+
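Usage note (illustrative, not part of this commit): a minimal sketch of how the new FrozenCLIPTextEmbedder could be driven, assuming the openai/CLIP package is installed and the module is moved to the target device before use; the prompt string below is a made-up example.

# Sketch only -- not from the commit. Assumes openai/CLIP ("clip") is installed.
import torch
from ldm.modules.encoders.modules import FrozenCLIPTextEmbedder

device = "cuda" if torch.cuda.is_available() else "cpu"
embedder = FrozenCLIPTextEmbedder(device=device).to(device)  # move the CLIP weights onto `device`
embedder.freeze()  # eval mode, gradients disabled

with torch.no_grad():
    z = embedder.encode(["a painting of a virus monster playing guitar"])  # placeholder prompt
print(z.shape)  # (1, n_repeat, 768) for the default ViT-L/14; rows are L2-normalized when normalize=True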

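Similarly for FrozenClipImageEmbedder (illustrative sketch, not part of this commit): the model name, batch size, and random input below are placeholder choices; the class expects NCHW images in [-1, 1] and handles resizing and CLIP normalization in preprocess().

# Sketch only -- not from the commit. Model name and dummy input are placeholders.
import torch
from ldm.modules.encoders.modules import FrozenClipImageEmbedder

device = "cuda" if torch.cuda.is_available() else "cpu"
embedder = FrozenClipImageEmbedder(model="ViT-L/14", device=device).to(device)

x = torch.rand(4, 3, 256, 256, device=device) * 2. - 1.  # dummy batch scaled to [-1, 1]
with torch.no_grad():
    z = embedder(x)  # preprocess() resizes to 224x224 and applies the CLIP mean/std
print(z.shape)  # (4, 768) for ViT-L/14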
ldm/util.py

Lines changed: 110 additions & 1 deletion
@@ -83,4 +83,113 @@ def get_obj_from_str(string, reload=False):
     if reload:
         module_imp = importlib.import_module(module)
         importlib.reload(module_imp)
-    return getattr(importlib.import_module(module, package=None), cls)
+    return getattr(importlib.import_module(module, package=None), cls)
+
+def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
+    # create dummy dataset instance
+
+    # run prefetching
+    if idx_to_fn:
+        res = func(data,worker_id=idx)
+    else:
+        res = func(data)
+    Q.put([idx, res])
+    Q.put("Done")
+
+
+def parallel_data_prefetch(
+        func: callable, data, n_proc, target_data_type="ndarray",cpu_intensive=True,use_worker_id=False
+):
+    # if target_data_type not in ["ndarray", "list"]:
+    #     raise ValueError(
+    #         "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
+    #     )
+    if isinstance(data, np.ndarray) and target_data_type == "list":
+        raise ValueError("list expected but function got ndarray.")
+    elif isinstance(data, abc.Iterable):
+        if isinstance(data, dict):
+            print(
+                f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
+            )
+            data = list(data.values())
+        if target_data_type == "ndarray":
+            data = np.asarray(data)
+        else:
+            data = list(data)
+    else:
+        raise TypeError(
+            f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
+        )
+
+    if cpu_intensive:
+        Q = mp.Queue(1000)
+        proc = mp.Process
+    else:
+        Q = Queue(1000)
+        proc = Thread
+    # spawn processes
+    if target_data_type == "ndarray":
+        arguments = [
+            [func, Q, part, i, use_worker_id]
+            for i, part in enumerate(np.array_split(data, n_proc))
+        ]
+    else:
+        step = (
+            int(len(data) / n_proc + 1)
+            if len(data) % n_proc != 0
+            else int(len(data) / n_proc)
+        )
+        arguments = [
+            [func, Q, part, i, use_worker_id]
+            for i, part in enumerate(
+                [data[i : i + step] for i in range(0, len(data), step)]
+            )
+        ]
+    processes = []
+    for i in range(n_proc):
+        p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
+        processes += [p]
+
+    # start processes
+    print(f"Start prefetching...")
+    import time
+
+    start = time.time()
+    gather_res = [[] for _ in range(n_proc)]
+    try:
+        for p in processes:
+            p.start()
+
+        k = 0
+        while k < n_proc:
+            # get result
+            res = Q.get()
+            if res == "Done":
+                k += 1
+            else:
+                gather_res[res[0]] = res[1]
+
+    except Exception as e:
+        print("Exception: ", e)
+        for p in processes:
+            p.terminate()
+
+        raise e
+    finally:
+        for p in processes:
+            p.join()
+        print(f"Prefetching complete. [{time.time() - start} sec.]")
+
+    if target_data_type == 'ndarray':
+        if not isinstance(gather_res[0], np.ndarray):
+            return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
+
+        # order outputs
+        return np.concatenate(gather_res, axis=0)
+    elif target_data_type == 'list':
+        out = []
+        for r in gather_res:
+            out.extend(r)
+        return out
+    else:
+        return gather_res
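Usage note (illustrative, not part of this commit): parallel_data_prefetch splits `data` into n_proc chunks, runs `func` on each chunk in a separate process (or thread when cpu_intensive=False), and gathers the per-worker results in worker order; each worker signals completion by putting "Done" on the queue. A minimal sketch with a toy worker function and made-up data:

# Sketch only -- not from the commit. Toy worker and data for illustration.
import numpy as np
from ldm.util import parallel_data_prefetch

def square(chunk):
    # each worker receives one slice of `data`
    return np.asarray([v ** 2 for v in chunk])

if __name__ == "__main__":  # guard required for multiprocessing on spawn-based platforms
    out = parallel_data_prefetch(square, np.arange(100), n_proc=4, target_data_type="ndarray")
    print(out.shape)  # (100,) -- per-worker chunks concatenated in worker order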
