Technology • November 2, 2023
GPT-4V with Context: Using Retrieval Augmented Generation with Multimodal Models

!mkdir multimodal_data
import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download
REPO_ID = "laion/laion2b-en-vit-l-14-embeddings"
for filename in [
    "img_emb/img_emb_0000.npy",
    "text_emb/text_emb_0000.npy",
    "metadata/metadata_0000.parquet",
]:
    hf_hub_download(repo_id=REPO_ID, filename=filename, repo_type="dataset",
                    local_dir="multimodal_data")
img_embs = np.load("multimodal_data/img_emb/img_emb_0000.npy")
text_embs = np.load("multimodal_data/text_emb/text_emb_0000.npy")
metadata_df = pd.read_parquet("multimodal_data/metadata/metadata_0000.parquet")
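As a quick, optional sanity check before loading anything into the database, the shapes should line up: one 768-dimensional image embedding and one text embedding per metadata row (a minimal snippet, not part of the original pipeline):

print(img_embs.shape, text_embs.shape, len(metadata_df))
# Expect (N, 768), (N, 768), and N for the same row count N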
import cassio
from cassio.table import MetadataVectorCassandraTable
cassio.init(
    token="YOUR TOKEN HERE",
    database_id="YOUR DB ID HERE",
)
mm_table = MetadataVectorCassandraTable(
    table="multimodal_demo_vs",
    vector_dimension=768,  # CLIP ViT-L/14 embedding dimension
)
import numpy as np
def merge_mm_embeddings(img_emb=None, text_emb=None):
    if text_emb is not None and img_emb is not None:
        return np.mean([img_emb, text_emb], axis=0)
    elif text_emb is not None:
        return text_emb
    elif img_emb is not None:
        return img_emb
    else:
        raise ValueError("Must specify one of `img_emb` or `text_emb`")
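As a tiny illustration (not part of the original pipeline), merging the first image/text pair from the shard is just an element-wise average, so the result keeps the 768-dimensional shape:

combined = merge_mm_embeddings(img_embs[0], text_embs[0])
print(combined.shape)  # (768,)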
from tqdm import tqdm
def add_row_to_db(ndx: int):
    # NOTE: These vectors have already been normalized
    row = metadata_df.iloc[ndx]
    img_emb = img_embs[ndx]
    text_emb = text_embs[ndx]
    emb = merge_mm_embeddings(img_emb, text_emb)

    return mm_table.put_async(
        row_id=row["key"],
        body_blob=row["caption"],
        vector=emb,
        metadata={
            key: row[key]
            for key in metadata_df.columns
            if key not in ["key", "caption"]
        },
    )

all_futures = []
# Start from wherever we left off, in case the cell is re-run partway through
for ndx in tqdm(range(len(all_futures), len(metadata_df))):
    all_futures.append(add_row_to_db(ndx))

for future in tqdm(all_futures):
    future.result()
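Before moving on, a quick ANN query makes a handy smoke test. This optional sketch searches with one of the stored text embeddings; since the stored vectors are merged image+text embeddings, the corresponding caption should still come back near the top:

# Optional smoke test: retrieve the nearest stored rows for a raw text embedding
for hit in mm_table.metric_ann_search(vector=text_embs[0], n=3, metric="cos"):
    print(hit["row_id"], round(hit["distance"], 3), hit["body_blob"][:80])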
import os
import sys
from huggingface_hub import hf_hub_download
if not os.path.exists("MiniGPT-4-RAG-Demo"):
    !git clone https://github.com/rsmith49/MiniGPT-4-RAG-Demo.git
sys.path.append("MiniGPT-4-RAG-Demo/")
# Download minigpt-4 weights for Vicuna-7B
hf_hub_download(
    repo_id="Vision-CAIR/minigpt4",
    filename="prerained_minigpt4_7b.pth",
    repo_type="space",
    local_dir="./",
)
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import (
    Chat as MiniGPT4Chat,
    CONV_VISION_Vicuna0,
    CONV_VISION_LLama2,
)
class TmpArgs:
    options = None
    cfg_path = "MiniGPT-4-RAG-Demo/eval_configs/minigpt4_eval.yaml"

CONV_VISION = None

def init_chat() -> MiniGPT4Chat:
    """Initialize a basic chat session with MiniGPT-4.

    We make some quality-of-life changes and fit into the expected
    infrastructure of the MiniGPT-4 repo in order to load the model for
    inference locally.
    """
    global CONV_VISION
    conv_dict = {'pretrain_vicuna0': CONV_VISION_Vicuna0,
                 'pretrain_llama2': CONV_VISION_LLama2}

    print('Initializing Chat')
    cfg = Config(TmpArgs())
    model_config = cfg.model_cfg

    # Config adjustments
    # This enables loading the non-quantized model for better performance
    model_config.low_resource = False

    model_cls = registry.get_model_class(model_config.arch)
    model = model_cls.from_config(model_config).to('cuda:0')
    CONV_VISION = conv_dict[model_config.model_type]

    vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
    vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
    chat = MiniGPT4Chat(model, vis_processor, device='cuda:0')
    print('Initialization Finished')

    return chat
chat = init_chat()
chat.upload_image("http://images.cocodataset.org/val2017/000000039769.jpg")
chat.query("What is this a picture of?")
chat.reset()
chat.query("What does the Apache Cassandra logo look like?")
import torch
from transformers import CLIPProcessor, CLIPModel
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda:0")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
def _embed(img_url: str | None, text: str | None) -> np.ndarray:
    """Get an embedding based on an Img URL and/or Text"""
    if img_url is None and text is None:
        raise ValueError("Must specify one of img_url or text")

    img_emb = None
    text_emb = None
    if img_url is not None:
        image = img_from_url(img_url)
        img_emb = clip_model.get_image_features(
            clip_processor(text=None, images=image,
                           return_tensors="pt")["pixel_values"].to("cuda:0")
        )[0].to("cpu").detach().numpy()
        assert img_emb.shape == torch.Size([768])

    if text is not None:
        text_emb = clip_model.get_text_features(
            clip_processor(text=[text], images=None, return_tensors="pt",
                           padding=False)["input_ids"].to("cuda:0")
        )[0].to("cpu").detach().numpy()
        assert text_emb.shape == torch.Size([768])

    emb = merge_mm_embeddings(img_emb, text_emb)
    return emb
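Note that _embed calls an img_from_url helper that isn't shown in this snippet; a minimal sketch, assuming images are fetched over HTTP with requests and decoded with Pillow, could look like:

import requests
from io import BytesIO
from PIL import Image

def img_from_url(url: str) -> Image.Image:
    # Hypothetical helper: download the image bytes and decode to RGB for CLIP preprocessing
    resp = requests.get(url, timeout=10)
    resp.raise_for_status()
    return Image.open(BytesIO(resp.content)).convert("RGB")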
from typing import Any, Dict, List
from uuid import uuid4
class RAGChat(MiniGPT4Chat):
    def __init__(
        self,
        *args: Any,
        table: MetadataVectorCassandraTable | None = None,
        **kwargs: Any,
    ):
        super().__init__(*args, **kwargs)
        if table is None:
            table = mm_table
        self.table = table

    def query_vectorstore(
        self,
        text: str | None = None,
        img_url: str | None = None,
        **ann_kwargs,
    ) -> List[Dict[str, Any]]:
        ...

    def embed_and_store_image(self, url: str, caption: str | None = None) -> str:
        ...

    def query(self, text: str, debug: bool = True, **generate_kwargs) -> str:
        ...
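Each stub is then filled in, starting with query_vectorstore, which embeds the query with CLIP and runs an ANN search against the Cassandra table: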
def query_vectorstore(
    self,
    text: str | None = None,
    img_url: str | None = None,
    **ann_kwargs,
) -> List[Dict[str, Any]]:
    emb = _embed(img_url=img_url, text=text)
    results_gen = self.table.metric_ann_search(
        vector=emb,
        metric="cos",
        **ann_kwargs,
    )
    return list(results_gen)
def embed_and_store_image(self, url: str, caption: str | None = None) -> str:
    row_kwargs = dict(
        row_id=str(uuid4()),
        body_blob=caption,
        metadata={"url": url},
        vector=_embed(img_url=url, text=caption),
    )
    self.table.put(**row_kwargs)
    return row_kwargs["row_id"]
def query(self, text: str, debug=True, **generate_kwargs) -> str:
    results = self.query_vectorstore(text=text, n=3)
    if debug:
        print("Search Results:", [
            {
                key: result[key]
                for key in ["row_id", "body_blob", "distance"]
            }
            for result in results
        ])
        print("-" * 80)

    # Try all results in case some links are broken or cause tokenization issues
    for ndx, result in enumerate(results):
        img_url = result["metadata"]["url"]
        try:
            self.upload_image(img_url)
            if debug:
                print(f"Successfully tokenized IMG {ndx}: {img_url}")
                print("-" * 80)
            break
        except Exception as e:
            print(f"Could not tokenize IMG {ndx}: {img_url}, got the following error:")
            print(e)
            print("-" * 80)

    return super().query(text, **generate_kwargs)
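Putting it together: query() searches the vector store for the prompt, walks the results until it finds an image URL that MiniGPT-4 can actually download and tokenize, uploads that image into the conversation, and then runs the usual chat query on top of that visual context.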
rag_chat = RAGChat()
rag_chat.embed_and_store_image(
    url="https://i0.wp.com/blog.knoldus.com/wp-content/uploads/2018/08/cassandra.png",
    caption="Cassandra logo",
)
rag_chat.query("What does the Apache Cassandra logo look like?")
rag_chat.reset()
rag_chat.query("What are some outfits from Pulp Fiction?")