Technology | November 2, 2023

GPT-4V with Context: Using Retrieval Augmented Generation with Multimodal Models

Ryan Smith
Ryan Smith, Machine learning engineer
GPT-4V with Context: Using Retrieval Augmented Generation with Multimodal Models
!mkdir multimodal_data

import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download

REPO_ID = "laion/laion2b-en-vit-l-14-embeddings"

for filename in [
    hf_hub_download(repo_id=REPO_ID, filename=filename, repo_type="dataset",

img_embs = np.load("multimodal_data/img_emb/img_emb_0000.npy")
text_embs = np.load("multimodal_data/text_emb/text_emb_0000.npy")
metadata_df = pd.read_parquet("multimodal_data/metadata/metadata_0000.parquet")
import cassio
from cassio.table import MetadataVectorCassandraTable

    token="YOUR TOKEN HERE",
    database_id="YOUR DB ID HERE",

mm_table = MetadataVectorCassandraTable(
    vector_dimension=768# CLIP VIT-L/14 embedding dimension
import numpy as np

def merge_mm_embeddings(img_emb=None, text_emb=None):
    """Combine an image embedding and a text embedding into one vector.

    Args:
        img_emb: Image embedding, or ``None`` if there is no image.
        text_emb: Text embedding, or ``None`` if there is no text.

    Returns:
        The element-wise mean of the two embeddings when both are given;
        otherwise whichever single embedding was supplied, unchanged.

    Raises:
        ValueError: If both ``img_emb`` and ``text_emb`` are ``None``.
    """
    # Fail loudly up front: previously the raise was unreachable (it sat
    # after a ``return`` inside the last branch), so this case returned None.
    if img_emb is None and text_emb is None:
        raise ValueError("Must specify one of `img_emb` or `text_emb`")
    if img_emb is None:
        return text_emb
    if text_emb is None:
        return img_emb
    # Both modalities present: average them position by position.
    return np.mean([img_emb, text_emb], axis=0)
from tqdm import tqdm

def add_row_to_db(ndx: int):
    # NOTE: These vectors have already been normalized
   row = df.iloc[ndx]
   img_emb = tmp_img[ndx]
    text_emb = tmp_text[ndx]

   emb = merge_mm_embeddings(img_emb, text_emb)

    return table.put_async(
            key: row[key]
            for key in df.columns
            if key not in ["key", "caption"]

all_futures = []
for ndx in tqdm(range(len(all_futures), len(df))):

for future in tqdm(all_futures):
import os
import sys
from huggingface_hub import hf_hub_download

if not os.path.exists("MiniGPT-4-RAG-Demo"):
    !git clone


# Download minigpt-4 weights for Vicuna-7B
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import (
    Chat as MiniGPT4Chat,

class TmpArgs:
    """Bare-bones argument holder handed to ``Config`` in place of parsed CLI args.

    It defines only the two class attributes below.
    """

    # Location of the MiniGPT-4 evaluation config file.
    # NOTE(review): the directory is spelled "-DEMO" here, while the clone
    # check elsewhere in this file looks for "MiniGPT-4-RAG-Demo" -- confirm
    # which casing matches the checkout on a case-sensitive filesystem.
    cfg_path = "MiniGPT-4-RAG-DEMO/eval_configs/minigpt4_eval.yaml"

    # No option overrides are supplied.
    options = None


def init_chat() -> MiniGPT4Chat:
    """Initialize a basic chat session with MiniGPT-4

   We make some quality of life changes and fit into the expected
   infrastructure of the MiniGPT-4 repo in order to load the model for
   inference locally
    global CONV_VISION

    conv_dict = {'pretrain_vicuna0': CONV_VISION_Vicuna0,
                 'pretrain_llama2': CONV_VISION_LLama2}

    print('Initializing Chat')
    cfg = Config(TmpArgs())
    model_config = cfg.model_cfg

   # Config adjustments
    # This enables loading the non-quantized model for better performance
    model_config.low_resource = False

    model_cls = registry.get_model_class(model_config.arch)
    model = model_cls.from_config(model_config).to('cuda:0')

    CONV_VISION = conv_dict[model_config.model_type]

    vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
    vis_processor =
    chat = MiniGPT4Chat(model, vis_processor, device='cuda:0')
    print('Initialization Finished')

    return chat
chat = init_chat()

chat.query("What is this a picture of?")
chat.query("What does the Apache Cassandra logo look like?")
import torch
from transformers import CLIPProcessor, CLIPModel

clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda:0")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

def _embed(img_url: str | None, text: str | None) -> torch.Tensor:
   """Get an embedding based on an Img URL and/or Text"""
    if img_url is None and text is None:
        raise ValueError(f"Must specify one of img_url or text")

    img_emb = None
    text_emb = None

    if img_url is not None:
        image = img_from_url(img_url)
        img_emb = clip_model.get_image_features(
            clip_processor(text=None, images=image,
        assert img_emb.shape == torch.Size([768])

    if text is not None:
        text_emb = clip_model.get_text_features(
            clip_processor(text=[text], images=None, return_tensors="pt",
        assert text_emb.shape == torch.Size([768])

    emb = merge_mm_embeddings(img_emb, text_emb)
    return emb
from typing import Any, Dict, List
from uuid import uuid4

class RAGChat(Chat):
    def __init__(
        *args: Any,
        table: MetadataVectorCassandraTable | None = None,
        **kwargs: Any,
        super().__init__(*args, **kwargs)

        if table is None:
            table = mm_table
        self.table = table

    def query_vectorstore(
        text: str | None = None,
        img_url: str | None = None,
    ) -> List[Dict[str, Any]]:

    def embed_and_store_image(self, url: str, caption: str | None = None) -> int:


    def query(self, text: str, debug: bool = True, **generate_kwargs) -> str:
    def query_vectorstore(
        text: str | None = None,
        img_url: str | None = None,
    ) -> List[Dict[str, Any]]:

        emb = _embed(img_url=img_url, text=text)
        results_gen = self.table.metric_ann_search(
        return list(results_gen)
    def embed_and_store_image(self, url: str, caption: str | None = None) ->
        row_kwargs = dict(
            metadata={"url": url},
            vector=_embed(img_url=url, text=caption),


        return row_kwargs["row_id"]
    def query(self, text: str, debug=True, **generate_kwargs) -> str:
        results = self.query_vectorstore(text=text, n=3)
        if debug:
            print("Search Results:", [
                    key: result[key]
                    for key in ["row_id", "body_blob", "distance"]
                for result in results
            print("-" * 80)

       # Try all results in case links are broken or tokenization issues
        for ndx, result in enumerate(results):
            img_url = result["metadata"]["url"]

                if debug:
                    print(f"Successfully tokenized IMG {ndx}: {img_url}")
                    print("-" * 80)
            except Exception as e:
                print(f"Could not tokenize IMG {ndx}: {img_url}, got the
following error:")
                print("-" * 80)

        return super().query(text, **generate_kwargs)
# The class is declared as `RAGChat`; Python names are case-sensitive, so the
# previous spelling `RagChat` raised NameError.
rag_chat = RAGChat()

    caption="Cassandra logo",

rag_chat.query("What does the Apache Cassandra logo look like?")
rag_chat.query("What are some outfits from Pulp Fiction?")
Discover more
Retrieval-augmented generation

NoSQL and Vector DB
for Generative AI,
Instantly, At Scale

Vector search capabilities on Astra DB enable complex, context-sensitive searches across diverse data formats for use in Generative AI applications, powered by Apache Cassandra®.