Technology | November 2, 2023

GPT-4V with Context: Using Retrieval Augmented Generation with Multimodal Models

Ryan Smith, Machine Learning Engineer
!mkdir multimodal_data

import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download

REPO_ID = "laion/laion2b-en-vit-l-14-embeddings"

for filename in [
    "img_emb/img_emb_0000.npy",
    "text_emb/text_emb_0000.npy",
    "metadata/metadata_0000.parquet",
]:
    hf_hub_download(
        repo_id=REPO_ID,
        filename=filename,
        repo_type="dataset",
        local_dir="multimodal_data",
    )

img_embs = np.load("multimodal_data/img_emb/img_emb_0000.npy")
text_embs = np.load("multimodal_data/text_emb/text_emb_0000.npy")
metadata_df = pd.read_parquet("multimodal_data/metadata/metadata_0000.parquet")
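
The embedding arrays and the metadata are expected to line up row for row, with each CLIP ViT-L/14 embedding being 768-dimensional. A quick sanity check (not part of the original post):

# Verify the embeddings and metadata describe the same rows and dimension
print(img_embs.shape, text_embs.shape, len(metadata_df))
assert img_embs.shape[0] == text_embs.shape[0] == len(metadata_df)
assert img_embs.shape[1] == text_embs.shape[1] == 768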
import cassio
from cassio.table import MetadataVectorCassandraTable

cassio.init(
    token="YOUR TOKEN HERE",
    database_id="YOUR DB ID HERE",
)

mm_table = MetadataVectorCassandraTable(
    table="multimodal_demo_vs",
    vector_dimension=768,  # CLIP ViT-L/14 embedding dimension
)
import numpy as np

def merge_mm_embeddings(img_emb=None, text_emb=None):
    if text_emb is not None and img_emb is not None:
        return np.mean([img_emb, text_emb], axis=0)
    elif text_emb is not None:
        return text_emb
    elif img_emb is not None:
        return img_emb
    else:
        raise ValueError("Must specify one of `img_emb` or `text_emb`")
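
The ingestion helper below refers to `df`, `tmp_img`, `tmp_text`, and `table`, which are not defined in the snippets shown here. A minimal set of aliases, assuming they simply point at the metadata DataFrame, the two embedding arrays, and the CassIO table created above:

# Assumed aliases for the names used in add_row_to_db below
df = metadata_df        # LAION metadata rows
tmp_img = img_embs      # image embeddings
tmp_text = text_embs    # text embeddings
table = mm_table        # CassIO vector table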
from tqdm import tqdm

def add_row_to_db(ndx: int):
    # NOTE: These vectors have already been normalized
    row = df.iloc[ndx]
    img_emb = tmp_img[ndx]
    text_emb = tmp_text[ndx]

    emb = merge_mm_embeddings(img_emb, text_emb)

    return table.put_async(
        row_id=row["key"],
        body_blob=row["caption"],
        vector=emb,
        metadata={
            key: row[key]
            for key in df.columns
            if key not in ["key", "caption"]
        },
    )

all_futures = []
# Starting from len(all_futures) lets this cell be re-run to resume a partial ingest
for ndx in tqdm(range(len(all_futures), len(df))):
    all_futures.append(add_row_to_db(ndx))

for future in tqdm(all_futures):
    future.result()
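
Once the futures have resolved, a quick similarity search against the table confirms the rows are queryable. This is a hedged sketch; the `metric="cos"` argument and dict-style results mirror how `metric_ann_search` is used in `query_vectorstore` further down:

# Sanity check: the nearest neighbors of a stored embedding should include that row itself
probe_vector = merge_mm_embeddings(tmp_img[0], tmp_text[0])
for hit in mm_table.metric_ann_search(vector=probe_vector, n=3, metric="cos"):
    print(hit["row_id"], hit["distance"], hit["body_blob"][:60])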
import os
import sys
from huggingface_hub import hf_hub_download

if not os.path.exists("MiniGPT-4-RAG-Demo"):
    !git clone https://github.com/rsmith49/MiniGPT-4-RAG-Demo.git

sys.path.append("MiniGPT-4-RAG-Demo/")

# Download minigpt-4 weights for Vicuna-7B
hf_hub_download(
    repo_id="Vision-CAIR/minigpt4",
    filename="prerained_minigpt4_7b.pth",
    repo_type="space",
    local_dir="./",
)
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import (
    Chat as MiniGPT4Chat,
    CONV_VISION_Vicuna0,
    CONV_VISION_LLama2,
)


class TmpArgs:
    options = None
    cfg_path = "MiniGPT-4-RAG-Demo/eval_configs/minigpt4_eval.yaml"

CONV_VISION = None


def init_chat() -> MiniGPT4Chat:
    """Initialize a basic chat session with MiniGPT-4

    We make some quality-of-life changes and fit into the expected
    infrastructure of the MiniGPT-4 repo in order to load the model for
    inference locally.
    """
    global CONV_VISION

    conv_dict = {'pretrain_vicuna0': CONV_VISION_Vicuna0,
                 'pretrain_llama2': CONV_VISION_LLama2}

    print('Initializing Chat')
    cfg = Config(TmpArgs())
    model_config = cfg.model_cfg

    # Config adjustments
    # This enables loading the non-quantized model for better performance
    model_config.low_resource = False

    model_cls = registry.get_model_class(model_config.arch)
    model = model_cls.from_config(model_config).to('cuda:0')

    CONV_VISION = conv_dict[model_config.model_type]

    vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
    vis_processor = registry.get_processor_class(
        vis_processor_cfg.name
    ).from_config(vis_processor_cfg)
    chat = MiniGPT4Chat(model, vis_processor, device='cuda:0')
    print('Initialization Finished')

    return chat
chat = init_chat()

chat.upload_image("http://images.cocodataset.org/val2017/000000039769.jpg")
chat.query("What is this a picture of?")
chat.reset()
chat.query("What does the Apache Cassandra logo look like?")
 
import torch
from transformers import CLIPProcessor, CLIPModel

clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda:0")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
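
The `_embed` helper below relies on an `img_from_url` function that isn't defined in these snippets (it may come from the demo repo). A minimal sketch, assuming a plain HTTP fetch with `requests` decoded into a PIL image:

import requests
from io import BytesIO
from PIL import Image

def img_from_url(url: str) -> Image.Image:
    # Fetch the image bytes and decode them into an RGB PIL image
    resp = requests.get(url, timeout=30)
    resp.raise_for_status()
    return Image.open(BytesIO(resp.content)).convert("RGB")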

def _embed(img_url: str | None, text: str | None) -> np.ndarray:
    """Get an embedding based on an image URL and/or text"""
    if img_url is None and text is None:
        raise ValueError("Must specify one of img_url or text")

    img_emb = None
    text_emb = None

    if img_url is not None:
        image = img_from_url(img_url)
        img_emb = clip_model.get_image_features(
            clip_processor(
                text=None, images=image, return_tensors="pt"
            )["pixel_values"].to("cuda:0")
        )[0].to("cpu").detach().numpy()
        assert img_emb.shape == torch.Size([768])

    if text is not None:
        text_emb = clip_model.get_text_features(
            clip_processor(
                text=[text], images=None, return_tensors="pt", padding=False
            )["input_ids"].to("cuda:0")
        )[0].to("cpu").detach().numpy()
        assert text_emb.shape == torch.Size([768])

    emb = merge_mm_embeddings(img_emb, text_emb)
    return emb
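
As a quick check of the helper (a hypothetical example, not from the original post), a text-only call should produce a single 768-dimensional vector, matching the embeddings stored in Astra DB:

# Text-only embedding; shape should match the 768-dim vectors in the table
emb = _embed(img_url=None, text="the Apache Cassandra logo")
print(emb.shape)  # expected: (768,)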
from typing import Any, Dict, List
from uuid import uuid4

class RAGChat(MiniGPT4Chat):
    def __init__(
        self,
        *args: Any,
        table: MetadataVectorCassandraTable | None = None,
        **kwargs: Any,
    ):
        super().__init__(*args, **kwargs)

        if table is None:
            table = mm_table
        self.table = table

    def query_vectorstore(
        self,
        text: str | None = None,
        img_url: str | None = None,
        **ann_kwargs,
    ) -> List[Dict[str, Any]]:
        ...

    def embed_and_store_image(self, url: str, caption: str | None = None) -> str:
        ...

    def query(self, text: str, debug: bool = True, **generate_kwargs) -> str:
        ...
    def query_vectorstore(
        self,
        text: str | None = None,
        img_url: str | None = None,
        **ann_kwargs,
    ) -> List[Dict[str, Any]]:

        emb = _embed(img_url=img_url, text=text)
        results_gen = self.table.metric_ann_search(
            vector=emb,
            metric="cos",
            **ann_kwargs,
        )
        return list(results_gen)
    def embed_and_store_image(self, url: str, caption: str | None = None) -> str:
        row_kwargs = dict(
            row_id=str(uuid4()),
            body_blob=caption,
            metadata={"url": url},
            vector=_embed(img_url=url, text=caption),
        )

        self.table.put(**row_kwargs)

        return row_kwargs["row_id"]
    def query(self, text: str, debug=True, **generate_kwargs) -> str:
        results = self.query_vectorstore(text=text, n=3)
        if debug:
            print("Search Results:", [
                {
                    key: result[key]
                    for key in ["row_id", "body_blob", "distance"]
                }
                for result in results
            ])
            print("-" * 80)

        # Try all results in case links are broken or there are tokenization issues
        for ndx, result in enumerate(results):
            img_url = result["metadata"]["url"]

            try:
                self.upload_image(img_url)
                if debug:
                    print(f"Successfully tokenized IMG {ndx}: {img_url}")
                    print("-" * 80)
                break
            except Exception as e:
                print(f"Could not tokenize IMG {ndx}: {img_url}, got the
following error:")
               print(e)
                print("-" * 80)

        return super().query(text, **generate_kwargs)
rag_chat = RAGChat()

rag_chat.embed_and_store_image(
    url="https://i0.wp.com/blog.knoldus.com/wp-content/uploads/2018/08/cassandra.pn
g"
,
    caption="Cassandra logo",

)
rag_chat.query("What does the Apache Cassandra logo look like?")
rag_chat.reset()
rag_chat.query("What are some outfits from Pulp Fiction?")