from models.utils import get_prompts
from bayesvlm.hessians import compute_covariances, optimize_prior_precision
import open_clip 
from bayesvlm.vlm import CLIP
import types
import warnings

from torch.nn import Linear
import torch

def _projection_as_linear(projection: torch.Tensor) -> Linear:
    """Wrap a right-multiplied projection matrix as a bias-free ``Linear``.

    The CLIP projections are applied as ``x @ projection``, so ``projection``
    has shape ``(in_features, out_features)``. ``Linear.weight`` is stored as
    ``(out_features, in_features)``, so the layer receives a detached copy of
    the transpose.
    """
    in_features, out_features = projection.shape
    layer = Linear(in_features=in_features, out_features=out_features, bias=False)
    # Rebinding ``.data`` (rather than copying into the freshly allocated
    # weight) makes the parameter take the source tensor's device and dtype.
    layer.weight.data = projection.data.T.clone()
    return layer


def get_covariances(
    clip_model,
    A_img,
    B_img,
    A_txt,
    B_txt,
    lambda_init: int = 1500,
    lr: float = 1e-2,
    num_steps: int = 300,
    pseudo_data_count: int = 10,
    device: str = "cuda",
    verbose: bool = False,
):
    """Fit prior precisions for both CLIP projection heads and return their covariances.

    For the image and the text projection separately, the prior precision
    (lambda) is tuned with ``optimize_prior_precision``. The fitted values,
    together with the pseudo-data counts, are then passed to
    ``compute_covariances`` along with the factor matrices.

    Args:
        clip_model: open_clip-style model exposing ``visual.proj`` and
            ``text_projection`` as 2-D ``(in_features, out_features)``
            matrices (applied as ``x @ proj``).
        A_img, B_img: image-side factors forwarded to
            ``optimize_prior_precision`` and ``compute_covariances``
            (presumably the KFAC input/output Hessian factors; confirm in
            ``bayesvlm.hessians``).
        A_txt, B_txt: text-side counterparts of ``A_img`` / ``B_img``.
        lambda_init: initial prior precision for both optimizations.
        lr: learning rate forwarded to ``optimize_prior_precision``.
        num_steps: number of optimization steps per projection head.
        pseudo_data_count: used as both ``n_img`` and ``n_txt`` (passed as
            ``n``); presumably the pseudo dataset size.
        device: device forwarded to ``optimize_prior_precision``.
        verbose: verbosity flag forwarded to ``optimize_prior_precision``.

    Returns:
        ``(cov_img, cov_txt)`` exactly as returned by ``compute_covariances``.
    """
    info = {'n_img': pseudo_data_count, 'n_txt': pseudo_data_count}

    # Layer shapes are derived from each projection rather than hard-coded:
    # the two towers need not share a width (e.g. open_clip ViT-B-32 has a
    # 768-wide vision tower but a 512-wide text tower).
    info['lambda_img'] = optimize_prior_precision(
        _projection_as_linear(clip_model.visual.proj),
        A=A_img,
        B=B_img,
        lmbda_init=lambda_init,
        n=info['n_img'],
        lr=lr,
        num_steps=num_steps,
        device=device,
        verbose=verbose,
    ).item()

    info['lambda_txt'] = optimize_prior_precision(
        _projection_as_linear(clip_model.text_projection),
        A=A_txt,
        B=B_txt,
        lmbda_init=lambda_init,
        n=info['n_txt'],
        lr=lr,
        num_steps=num_steps,
        device=device,
        verbose=verbose,
    ).item()

    cov_img, cov_txt = compute_covariances(A_img, B_img, A_txt, B_txt, info)

    return cov_img, cov_txt

def _visual_forward_with_activations(self, x: torch.Tensor):
    """Replacement ``forward`` for an open_clip vision tower.

    Returns ``(activations @ self.proj, activations)``, i.e. the image
    embedding together with the pooled pre-projection activations.

    NOTE(review): relies on the private ``_embeds`` / ``_pool`` helpers of
    open_clip's ``VisionTransformer``; verify against the installed version.
    """
    x = self._embeds(x)
    x = self.transformer(x)
    activations, _ = self._pool(x)
    pooled = activations @ self.proj

    return pooled, activations


def _encode_text_with_activations(self, text):
    """Replacement ``encode_text`` for an open_clip ``CLIP`` model.

    Returns ``(embeds, activations)``. ``activations`` are the ``ln_final``
    outputs taken at one token position per sequence. ``embeds`` is
    ``activations @ self.text_projection``.
    """
    cast_dtype = self.transformer.get_cast_dtype()

    x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]

    x = x + self.positional_embedding.to(cast_dtype)
    # NOTE(review): assumes a batch-first transformer, as in recent open_clip.
    x = self.transformer(x, attn_mask=self.attn_mask)
    x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
    # Pool at each sequence's largest token id, presumably the end-of-text
    # token (it has the highest id in CLIP's BPE vocabulary).
    activations = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
    embeds = activations @ self.text_projection

    return embeds, activations


def get_bayes_vlm_model(
    dataset_name: str,
    model_name: str = "ViT-B-32",
    device: str = "cuda",
):
    """Load a bayesvlm ``CLIP`` plus a patched open_clip twin and cache prompt embeddings.

    Steps:

    1. The matching open_clip model is created, switched to eval mode and
       attached as ``vlm.open_clip_model``.
    2. Its vision ``forward`` and ``encode_text`` are patched to also return
       pre-projection activations.
    3. The prompts from ``get_prompts(dataset=dataset_name)`` are encoded
       without gradients and stored as ``vlm.text_embeds`` and
       ``vlm.text_activations``.

    Warnings raised while loading and encoding are suppressed for the
    duration of this call only.

    Args:
        dataset_name: dataset whose prompts are encoded.
        model_name: key into the supported-model registry below.
        device: device the models and text tokens are moved to.

    Returns:
        The bayesvlm ``CLIP`` instance with the attributes described above.

    Raises:
        ValueError: if ``model_name`` is not a supported model.
    """
    # model_name -> (Hugging Face repo id for bayesvlm, open_clip pretrained tag)
    avail_models = {
        "ViT-B-32": ("laion/CLIP-ViT-B-32-laion2B-s34B-b79K", "laion2B-s34B-b79K"),
    }
    if model_name not in avail_models:
        raise ValueError(f"Model '{model_name}' not found.")
    model_url, pretrained_tag = avail_models[model_name]

    # Scope the suppression to this call instead of disabling warnings for
    # the whole process.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        clip_open, _, _ = open_clip.create_model_and_transforms(model_name, pretrained_tag)

        vlm = CLIP.from_huggingface(model_url, device=device).eval().to(device)
        # open_clip returns models in train mode, and this one is attached
        # after vlm.eval() ran, so switch it to eval mode explicitly.
        vlm.open_clip_model = clip_open.eval().to(device)

        tokenizer = open_clip.get_tokenizer(model_name)

        open_clip_model = vlm.open_clip_model
        open_clip_model.visual.forward = types.MethodType(
            _visual_forward_with_activations, open_clip_model.visual
        )
        open_clip_model.encode_text = types.MethodType(
            _encode_text_with_activations, open_clip_model
        )

        prompts = get_prompts(dataset=dataset_name)
        text_tokens = tokenizer(prompts).to(device)
        with torch.no_grad():
            text_embeds, text_activations = open_clip_model.encode_text(text_tokens)
        vlm.text_embeds, vlm.text_activations = text_embeds, text_activations

    return vlm