from copy import deepcopy
from datetime import timedelta
from typing import Optional, Union, List, Sequence, Dict
import os
import json

from .BaseEmbedder import BaseEmbedder

from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs
from accelerate.state import AcceleratorState
from accelerate.utils import gather_object

import numpy as np
from transformers import CLIPModel, CLIPProcessor
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import torch
from PIL import Image


class ClipBgeEmbedder(BaseEmbedder):
    """Embed task requests as ``[CLIP image features ; sentence-transformer text features]``.

    For every request the image side is the mean of the CLIP image features of
    all visuals attached to the document (or a zero vector of size
    ``projection_dim`` when the document has no visuals), and the text side is
    the flattened output of ``SentenceTransformer.encode`` on the request
    context. The two 1-D vectors are concatenated, image part first.

    Work is sharded across processes with ``accelerate``; the main process
    gathers every shard, stacks the embeddings and writes them to disk.
    """

    def __init__(
        self,
        name: str,
        output_path: str,
        mm_pretrained: str = "openai/clip-vit-large-patch14",
        txt_pretrained: str = "BAAI/bge-m3",
        device: str = "cuda",
        device_map: str = "",
    ) -> None:
        """Load both encoders and move them to this process's device.

        Args:
            name: Forwarded unchanged to ``BaseEmbedder.__init__``.
            output_path: Forwarded unchanged to ``BaseEmbedder.__init__``
                (presumably stored as ``self.output_path``, which
                ``embed_task`` reads — confirm in the base class).
            mm_pretrained: Checkpoint name/path for the CLIP model and processor.
            txt_pretrained: Checkpoint name/path for the sentence-transformer.
            device: Device string used when running a single process, or when
                ``device_map`` is given explicitly.
            device_map: When empty and more than one process is running, each
                process is pinned to ``cuda:<local_process_index>``; otherwise
                the value is stored as-is on ``self.device_map``.
        """
        super().__init__(name, output_path)
        self.model = CLIPModel.from_pretrained(mm_pretrained)
        self.processor = CLIPProcessor.from_pretrained(mm_pretrained)
        self.text_model = SentenceTransformer(txt_pretrained)

        # 52-week process-group timeout; presumably so that slow ranks never
        # trip the collective timeout during long embedding runs.
        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
        self.accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])

        if self.accelerator.num_processes > 1 and device_map == "":
            # Multi-process run without an explicit map: one GPU per local rank.
            self.device = torch.device(f"cuda:{self.accelerator.local_process_index}")
            self.device_map = f"cuda:{self.accelerator.local_process_index}"
        else:
            self.device = torch.device(device)
            self.device_map = device_map

        self.model.to(self.device)
        self.text_model.to(self.device)

    def embed_task(self, task: str, ignored_ids: Optional[Union[set, List]] = None):
        """Embed every request of ``task`` and save the result as ``<task>_embed.npy``.

        Args:
            task: Task identifier passed to ``BaseEmbedder.init_task``; also
                used to build the output file name.
            ignored_ids: Passed through to ``BaseEmbedder.init_task``.

        Returns:
            On the main process, a tensor with one row per request (stacked
            with ``torch.stack``), which is also written to
            ``<self.output_path>/<task>_embed.npy``. On every other process,
            ``None``.
        """
        # Only the requests and the docs are used here; the remaining values
        # returned by init_task are discarded.
        _, _, _, requests, _, self.docs = BaseEmbedder.init_task(task, ignored_ids)

        self.accelerator.wait_for_everyone()
        with self.accelerator.split_between_processes(requests, apply_padding=False) as requests_split:
            embeddings = []
            for req in tqdm(requests_split, disable=not self.accelerator.is_main_process):
                # The per-request task/split are bound to throwaway names so
                # they cannot overwrite the ``task`` argument that names the
                # output file below.
                contexts, _gen_kwargs, doc_to_visual, doc_id, _req_task, _req_split = req.args
                visuals = self.flatten([doc_to_visual(self.docs[doc_id])])

                text_embedding = torch.from_numpy(self.text_model.encode([contexts])).flatten()

                if len(visuals) > 0:
                    img_inputs = self.processor(images=visuals, return_tensors="pt")
                    img_inputs = {k: v.to(self.device) for k, v in img_inputs.items()}

                    # Inference only: skip autograd bookkeeping. For multiple
                    # images, the mean of their features is used.
                    with torch.no_grad():
                        image_embedding = self.model.get_image_features(**img_inputs).mean(dim=0).cpu()
                else:
                    # No visuals: substitute a zero vector of the image feature size.
                    image_embedding = torch.zeros(self.model.config.projection_dim)

                embeddings.append(torch.concat([image_embedding, text_embedding]))

            # gather_object concatenates the per-process lists, so each process
            # contributes exactly one dict.
            results = [{"outputs": embeddings}]

        self.accelerator.wait_for_everyone()
        results_gathered = gather_object(results)
        if self.accelerator.is_main_process:
            # Shards are joined in the order gather_object returns them
            # (presumably process-rank order — confirm against accelerate docs).
            outputs = [emb for shard in results_gathered for emb in shard["outputs"]]
            results_gathered = torch.stack(outputs)
            # Context manager guarantees the file is flushed and closed.
            with open(os.path.join(self.output_path, f"{task}_embed.npy"), "wb") as f:
                np.save(f, results_gathered)
            return results_gathered
        return None


if __name__ == "__main__":
    # Manual smoke check: encode two captions with the CLIP text tower and
    # print the shape of their mean feature vector.
    checkpoint = "openai/clip-vit-large-patch14"
    target_device = "cuda"

    clip = CLIPModel.from_pretrained(checkpoint).to(target_device)
    clip_processor = CLIPProcessor.from_pretrained(checkpoint)

    captions = ["a photo of a cat", "a photo of a dog"]
    encoded = clip_processor(text=captions, return_tensors="pt")

    # Move every tensor produced by the processor onto the model's device.
    on_device = {}
    for key, tensor in encoded.items():
        on_device[key] = tensor.to(target_device)

    text_features = clip.get_text_features(**on_device)
    pooled = text_features.mean(dim=0)
    print(pooled.shape)
