Commit 360cd9f

merge and add links
2 parents 693cd2b + e09dbda commit 360cd9f

3 files changed: +200 −21 lines

README.md

Lines changed: 50 additions & 18 deletions
@@ -22,7 +22,7 @@
 ## News

 ### July 2022
-- Inference code and model weights to run our [retrieval-augmented diffusion models](https://arxiv.org/abs/2204.11824) are now available. See ##RDM.
+- Inference code and model weights to run our [retrieval-augmented diffusion models](https://arxiv.org/abs/2204.11824) are now available. See [this section](#rdm).
 ### April 2022
 - Thanks to [Katherine Crowson](https://github.com/crowsonkb), classifier-free guidance received a ~2x speedup and the [PLMS sampler](https://arxiv.org/abs/2202.09778) is available. See also [this PR](https://github.com/CompVis/latent-diffusion/pull/51).

@@ -49,15 +49,16 @@ If you use any of these models in your work, we are always happy to receive a [c
 ![rdm-figure](assets/rdm-preview.jpg)
 We include inference code to run our retrieval-augmented diffusion models (RDMs) as described in [https://arxiv.org/abs/2204.11824](https://arxiv.org/abs/2204.11824).

-To get started, install the following dependencies into the `ldm` conda environment,
-```bash
-pip install transformers==4.19.2 scann kornia
+
+To get started, install the additionally required Python packages into your `ldm` environment:
+```shell script
+pip install transformers==4.19.2 scann kornia==0.6.4
 ```
-and download the weights:
+and download the trained weights:
+
 ```bash
-mkdir -p models/rdm/rdm768x768
-wget -O models/rdm/rdm768x768/model.ckpt TODO
-wget -O models/rdm/rdm768x768/config.yaml TODO
+mkdir -p models/rdm/rdm768x768/
+wget -O models/rdm/rdm768x768/model.ckpt https://ommer-lab.com/files/rdm/model.ckpt
 ```
 As these models are conditioned on a set of CLIP image embeddings, our RDMs support different inference modes,
 which are described in the following.
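
Aside (not part of this commit): the CLIP image embeddings mentioned above can be computed with the OpenAI `clip` package, using the `ViT-L/14` model that `scripts/knn2img.py` defaults to via `retriever_version='ViT-L/14'`. A minimal sketch, with `example.jpg` as a placeholder input:

```python
# Hedged sketch: compute a CLIP ViT-L/14 image embedding of the kind the
# RDMs are conditioned on. Assumes `pip install git+https://github.com/openai/CLIP.git`.
import clip
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-L/14", device=device)

image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)  # placeholder path
with torch.no_grad():
    emb = model.encode_image(image)             # shape (1, 768)
    emb = emb / emb.norm(dim=-1, keepdim=True)  # unit-normalize, as the retrieval code does
```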
@@ -70,27 +71,45 @@ python scripts/knn2img.py --prompt "a happy bear reading a newspaper, oil on ca
 ```

 #### RDM with text-to-image retrieval
-Download the retrieval-databases which contain the retrieval-datasets (OpenImages and ArtBench) compressed into CLIP image embeddings:
+
+To run an RDM conditioned on a text prompt and, additionally, on images retrieved from this prompt, you will also need to download the corresponding retrieval database.
+We provide two distinct databases extracted from the [OpenImages](https://storage.googleapis.com/openimages/web/index.html) and [ArtBench](https://github.com/liaopeiyuan/artbench) datasets.
+Interchanging the databases results in different capabilities of the resulting semi-parametric model, as visualized below #TODO, although the learned weights are the same in both cases.
+
+Download the retrieval databases, which contain the retrieval datasets ([OpenImages](https://storage.googleapis.com/openimages/web/index.html) (~11GB) and [ArtBench](https://github.com/liaopeiyuan/artbench) (~82MB)) compressed into CLIP image embeddings:
 ```bash
-mkdir -p data/rdm/openimages
-mkdir -p data/rdm/artbench
-wget -O data/rdm/openimages/data.p TODO
-wget -O data/rdm/artbench/data.p TODO
+mkdir -p data/rdm/retrieval_databases
+wget -O data/rdm/retrieval_databases/artbench.zip https://ommer-lab.com/files/rdm/artbench_databases.zip
+wget -O data/rdm/retrieval_databases/openimages.zip https://ommer-lab.com/files/rdm/openimages_database.zip
+unzip data/rdm/retrieval_databases/artbench.zip -d data/rdm/retrieval_databases/
+unzip data/rdm/retrieval_databases/openimages.zip -d data/rdm/retrieval_databases/
 ```
-We also provide trained [ScaNN]()/[faiss]() search indices [here](TODO). Download via
+We also provide trained [ScaNN](https://github.com/google-research/google-research/tree/master/scann) search indices for ArtBench. Download and extract via
 ```bash
-wget -O data/rdm/openimages/searcher.p TODO
-wget -O data/rdm/artbench/searcher TODO
+mkdir -p data/rdm/searchers
+wget -O data/rdm/searchers/artbench.zip https://ommer-lab.com/files/rdm/artbench_searchers.zip
+unzip data/rdm/searchers/artbench.zip -d data/rdm/searchers
 ```

+Since the index for OpenImages is large (~21 GB), we provide a script to create and save it for usage during sampling. Note, however,
+that sampling with the OpenImages database will not be possible without this index. Run the script via
+```bash
+python scripts/train_searcher.py
+```
+
+After this, retrieval-based text-guided sampling with visual nearest neighbors can be started via
+```bash
+python scripts/knn2img.py --prompt "a happy bear reading a newspaper, oil on canvas" --use_neighbors --knn <number_of_neighbors>
+```
+Note that the maximum supported number of neighbors is 20. The database can be changed via the cmd parameter `--database`, which can be one of `[openimages, artbench-art_nouveau, artbench-baroque, artbench-expressionism, artbench-impressionism, artbench-post_impressionism, artbench-realism, artbench-renaissance, artbench-romanticism, artbench-surrealism, artbench-ukiyo_e]`.

-#### RDM with image-to-image retrieval (maybe?, TODO)
-- simple modification of above section, support image encoding

 #### Coming Soon
 - better models
 - more resolutions
+- image-to-image retrieval

 ## Text-to-Image
 ![text2img-figure](assets/txt2img-preview.png)
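
For intuition, the retrieval step behind `--use_neighbors` can be pictured as: embed the prompt with CLIP's text encoder, then query the serialized ScaNN index for the nearest CLIP image embeddings. A hedged sketch under the setup above (not the exact code path of `scripts/knn2img.py`), assuming the OpenAI `clip` package:

```python
# Sketch: retrieve nearest CLIP *image* embeddings for a CLIP *text*
# embedding of the prompt, using a searcher serialized as above.
import clip
import scann
import torch

searcher = scann.scann_ops_pybind.load_searcher("data/rdm/searchers/artbench-surrealism")

model, _ = clip.load("ViT-L/14", device="cpu")
with torch.no_grad():
    q = model.encode_text(clip.tokenize(["a happy bear reading a newspaper"]))
    q = q / q.norm(dim=-1, keepdim=True)

# the released models support at most 20 neighbors
neighbor_ids, distances = searcher.search_batched(q.numpy(), final_num_neighbors=20)
```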
@@ -323,6 +342,19 @@ Thanks for open-sourcing!
 archivePrefix={arXiv},
 primaryClass={cs.CV}
 }
+
+@misc{https://doi.org/10.48550/arxiv.2204.11824,
+  doi = {10.48550/ARXIV.2204.11824},
+  url = {https://arxiv.org/abs/2204.11824},
+  author = {Blattmann, Andreas and Rombach, Robin and Oktay, Kaan and Ommer, Björn},
+  keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences},
+  title = {Retrieval-Augmented Diffusion Models},
+  publisher = {arXiv},
+  year = {2022},
+  copyright = {arXiv.org perpetual, non-exclusive license}
+}
+
 ```

scripts/knn2img.py

Lines changed: 3 additions & 3 deletions
@@ -63,8 +63,8 @@ def __init__(self, database, retriever_version='ViT-L/14'):
         assert database in DATABASES
         # self.database = self.load_database(database)
         self.database_name = database
-        self.searcher_savedir = f'models/searchers/{self.database_name}'
-        self.database_path = f'data/retrieval_databases/{self.database_name}'
+        self.searcher_savedir = f'data/rdm/searchers/{self.database_name}'
+        self.database_path = f'data/rdm/retrieval_databases/{self.database_name}'
         self.retriever = self.load_retriever(version=retriever_version)
         self.database = {'embedding': [],
                          'img_id': [],
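
After this change, both paths live under `data/rdm/`. Assuming the downloaded zips extract one folder per database name (as the per-database f-strings suggest), the layout looks roughly like:

```
data/rdm/
├── retrieval_databases/
│   ├── artbench-surrealism/   # *.npz files with CLIP image embeddings
│   └── openimages/
└── searchers/
    ├── artbench-surrealism/   # serialized ScaNN index
    └── openimages/            # built via scripts/train_searcher.py
```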
@@ -287,7 +287,7 @@ def __call__(self, x, n):
 parser.add_argument(
     "--database",
     type=str,
-    default=DATABASES[0],
+    default='artbench-surrealism',
     choices=DATABASES,
     help="The database used for the search, only applied when --use_neighbors=True",
 )
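
With the new default, retrieval-augmented sampling works out of the box with the small ArtBench-surrealism index, while OpenImages must be selected explicitly (and requires the index built by `scripts/train_searcher.py`):

```bash
# uses the new default database, artbench-surrealism
python scripts/knn2img.py --prompt "a happy bear reading a newspaper, oil on canvas" --use_neighbors --knn 10

# explicitly select the large OpenImages database instead
python scripts/knn2img.py --prompt "a happy bear reading a newspaper, oil on canvas" --use_neighbors --knn 10 --database openimages
```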

scripts/train_searcher.py

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
import os, sys
import numpy as np
import scann
import argparse
import glob
from multiprocessing import cpu_count
from tqdm import tqdm

from ldm.util import parallel_data_prefetch


def search_bruteforce(searcher):
    return searcher.score_brute_force().build()


def search_partitioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k,
                          partitioning_trainsize, num_leaves, num_leaves_to_search):
    return searcher.tree(num_leaves=num_leaves,
                         num_leaves_to_search=num_leaves_to_search,
                         training_sample_size=partitioning_trainsize). \
        score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build()


def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k):
    return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(
        reorder_k).build()


def load_datapool(dpath):

    def load_single_file(saved_embeddings):
        compressed = np.load(saved_embeddings)
        database = {key: compressed[key] for key in compressed.files}
        return database

    def load_multi_files(data_archive):
        database = {key: [] for key in data_archive[0].files}
        for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
            for key in d.files:
                database[key].append(d[key])

        return database

    print(f'Load saved patch embedding from "{dpath}"')
    file_content = glob.glob(os.path.join(dpath, '*.npz'))

    if len(file_content) == 1:
        data_pool = load_single_file(file_content[0])
    elif len(file_content) > 1:
        data = [np.load(f) for f in file_content]
        prefetched_data = parallel_data_prefetch(load_multi_files, data,
                                                 n_proc=min(len(data), cpu_count()), target_data_type='dict')

        data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()}
    else:
        raise ValueError(f'No npz-files in specified path "{dpath}". Does this directory exist?')

    print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.')
    return data_pool


def train_searcher(opt,
                   metric='dot_product',
                   partitioning_trainsize=None,
                   reorder_k=None,
                   # todo tune
                   aiq_thld=0.2,
                   dims_per_block=2,
                   num_leaves=None,
                   num_leaves_to_search=None):

    data_pool = load_datapool(opt.database)
    k = opt.knn

    if not reorder_k:
        reorder_k = 2 * k

    # unit-normalize the embeddings, so that the dot_product metric amounts to cosine similarity
    searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric)
    pool_size = data_pool['embedding'].shape[0]

    print(*(['#'] * 100))
    print('Initializing scaNN searcher with the following values:')
    print(f'k: {k}')
    print(f'metric: {metric}')
    print(f'reorder_k: {reorder_k}')
    print(f'anisotropic_quantization_threshold: {aiq_thld}')
    print(f'dims_per_block: {dims_per_block}')
    print(*(['#'] * 100))
    print('Start training searcher....')
    print(f'N samples in pool is {pool_size}')

    # this reflects the recommended design choices proposed at
    # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md
    if pool_size < 2e4:
        print('Using brute force search.')
        searcher = search_bruteforce(searcher)
    elif 2e4 <= pool_size < 1e5:
        print('Using asymmetric hashing search and reordering.')
        searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
    else:
        print('Using partitioning, asymmetric hashing search and reordering.')

        if not partitioning_trainsize:
            partitioning_trainsize = data_pool['embedding'].shape[0] // 10
        if not num_leaves:
            num_leaves = int(np.sqrt(pool_size))

        if not num_leaves_to_search:
            num_leaves_to_search = max(num_leaves // 20, 1)

        print('Partitioning params:')
        print(f'num_leaves: {num_leaves}')
        print(f'num_leaves_to_search: {num_leaves_to_search}')
        searcher = search_partitioned_ah(searcher, dims_per_block, aiq_thld, reorder_k,
                                         partitioning_trainsize, num_leaves, num_leaves_to_search)

    print('Finish training searcher')
    searcher_savedir = opt.target_path
    os.makedirs(searcher_savedir, exist_ok=True)
    searcher.serialize(searcher_savedir)
    print(f'Saved trained searcher under "{searcher_savedir}"')


if __name__ == '__main__':
    sys.path.append(os.getcwd())
    parser = argparse.ArgumentParser()
    parser.add_argument('--database',
                        '-d',
                        default='data/rdm/retrieval_databases/openimages',
                        type=str,
                        help='path to folder containing the clip features of the database')
    parser.add_argument('--target_path',
                        '-t',
                        default='data/rdm/searchers/openimages',
                        type=str,
                        help='path to the target folder where the searcher shall be stored.')
    parser.add_argument('--knn',
                        '-k',
                        default=20,
                        type=int,
                        help='number of nearest neighbors for which the searcher shall be optimized')

    opt, _ = parser.parse_known_args()

    train_searcher(opt)
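
The defaults reproduce the OpenImages setup from the README; via the flags above, the same script can build a searcher for any other database folder, e.g. (paths following the layout expected by `scripts/knn2img.py`):

```bash
python scripts/train_searcher.py \
    --database data/rdm/retrieval_databases/artbench-surrealism \
    --target_path data/rdm/searchers/artbench-surrealism \
    --knn 20
```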
