import torch
import torch.nn as nn
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import requests

# Load the CLIP model and its processor.
# Fall back to the CPU when CUDA is unavailable so the script still runs there.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# One move/cast instead of three separate .to() calls; the result is the same
# float32 model on `device`.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(
    device=device, dtype=torch.float32
)
      
class Chicken(nn.Module):
    """Wrap a model and re-weight the singular values of its weight tensors.

    Every state-dict tensor whose key contains none of ``matching_texts`` is
    factorised once with an SVD.  Each call to :meth:`add_class` creates one
    learnable vector per factorised tensor; passed through a sigmoid it scales
    that tensor's singular values, and the rebuilt weights are loaded back into
    the wrapped model.

    The mask vectors are kept in plain dicts, so they are not returned by
    ``self.parameters()``; pass :meth:`get_trainable_parameters` to the
    optimizer instead.
    """

    # Substrings of state-dict keys that are neither decomposed nor masked.
    DEFAULT_MATCHING_TEXTS = ("layernorm", "bias", "embeddings", "layrnorm", "layer_norm")

    def __init__(self, model, device="cpu", init_val=0.1, max_mult=1, matching_texts=None):
        """Move ``model`` to ``device`` and decompose its maskable tensors.

        Args:
            model: module to wrap; ``forward`` delegates to it unchanged.
            device: device for the wrapped model and for new mask vectors.
            init_val: scale of the Gaussian noise used to initialise masks.
            max_mult: multiplier applied to the sigmoid of an enabled mask.
            matching_texts: key substrings to skip; ``None`` selects
                ``DEFAULT_MATCHING_TEXTS``.
        """
        super().__init__()
        if matching_texts is None:
            # Fresh list per instance: a list default would be shared by all
            # instances created without this argument.
            matching_texts = list(self.DEFAULT_MATCHING_TEXTS)

        self.init_val = init_val
        self.matching_texts = matching_texts
        self.device = device

        self.model = model
        self.model = self.model.to(device)

        # NOTE: state_dict() tensors share storage with the live parameters, so
        # once apply_policy_to_model() has run, the masked entries hold the
        # composed weights rather than the originals.  Only their keys and
        # shapes are used after the decomposition below.
        self.base_params = model.state_dict()
        self.decomposed_params = self.decompose(self.base_params, matching_texts)

        self.class_policy_map = {}  # class name -> mask index

        self.learnable_params = {}  # mask index -> {state-dict key: mask vector}
        self.trainable_params = {}  # mask index -> list of mask vectors
        self.num_params = 0  # total number of mask elements created so far
        self.max_mult = max_mult
        self.enable_mask = []  # enable_mask[i]: whether mask i is applied
        self.mask_idx = 0  # index the next add_class() call will use

    def _resolve_mask_idx(self, mask_idx):
        """Return ``mask_idx``, or the most recently added mask when it is ``None``."""
        return self.latest_mask_idx if mask_idx is None else mask_idx

    def add_class(self, class_names):
        """Register ``class_names`` under a new mask and create its mask vectors.

        One vector of length ``min(tensor.shape)`` is created for every
        non-skipped state-dict tensor, initialised with Gaussian noise scaled
        by ``init_val``.
        """
        for name in class_names:
            self.class_policy_map[name] = self.mask_idx

        mask_params = {}
        for k, v in self.base_params.items():
            if any(text in k for text in self.matching_texts):
                continue
            mask_params[k] = nn.Parameter(
                torch.randn(min(v.shape), device=self.device, dtype=torch.bfloat16) * self.init_val,
                requires_grad=True,
            )
            self.num_params += mask_params[k].numel()

        self.learnable_params[self.mask_idx] = mask_params
        self.trainable_params[self.mask_idx] = list(mask_params.values())
        self.enable_mask.append(True)

        self.mask_idx += 1

    def get_trainable_parameters(self, mask_idx=None):
        """Return the list of mask vectors for ``mask_idx`` (default: latest mask)."""
        return self.trainable_params[self._resolve_mask_idx(mask_idx)]

    @property
    def class_map(self):
        """Return a printable table mapping each mask index to its class names."""
        names_by_mask = {}
        for name, mask_idx in self.class_policy_map.items():
            names_by_mask.setdefault(mask_idx, []).append(name)

        rule = "------------------\n"
        rows = [f"{mask_idx}: {', '.join(names)}\n" for mask_idx, names in names_by_mask.items()]
        return "CLASS MAP\n" + rule + "".join(rows) + rule

    @property
    def latest_mask_idx(self):
        """Index of the most recently added mask (``-1`` before any ``add_class`` call)."""
        return self.mask_idx - 1

    @staticmethod
    def decompose(base_params, skip_match_texts=None):
        """Factorise every non-skipped tensor of ``base_params`` with a reduced SVD.

        Args:
            base_params: mapping of state-dict key to tensor.
            skip_match_texts: key substrings to skip; ``None`` skips nothing.

        Returns:
            Dict holding ``"<key>.U"``, ``"<key>.S"`` and ``"<key>.V"`` for each
            decomposed tensor, such that ``U @ diag(S) @ V.T`` rebuilds it.
        """
        if skip_match_texts is None:
            skip_match_texts = ()

        decomposed_params = {}
        for k, v in base_params.items():
            if any(text in k for text in skip_match_texts):
                continue  # skip this param

            # torch.svd is deprecated; torch.linalg.svd returns V^H, so convert
            # it back to V to keep the stored layout unchanged.
            U, S, Vh = torch.linalg.svd(v, full_matrices=False)
            decomposed_params[f"{k}.U"] = U
            decomposed_params[f"{k}.S"] = S
            decomposed_params[f"{k}.V"] = Vh.transpose(-2, -1).conj()
        return decomposed_params

    def get_mask(self, p, mask_idx):
        """Return ``sigmoid(p) * max_mult`` if mask ``mask_idx`` is enabled, else all ones."""
        if self.enable_mask[mask_idx]:
            return torch.sigmoid(p).to(torch.bfloat16) * self.max_mult
        return torch.ones_like(p).to(torch.bfloat16)

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped model."""
        return self.model(*args, **kwargs)

    def compose_new_params(
        self,
        param_name,
        mask_idx,
    ):
        """Rebuild ``param_name`` from its SVD factors with masked singular values.

        The result is rescaled by ``sum(S) / sum(S * mask)`` so the total of
        the singular values is unchanged by the mask.
        """
        U = self.decomposed_params[f"{param_name}.U"]
        S = self.decomposed_params[f"{param_name}.S"]
        V = self.decomposed_params[f"{param_name}.V"]

        mm = self.get_mask(self.learnable_params[mask_idx][param_name], mask_idx)
        masked_S = S * mm
        return (U @ torch.diag_embed(masked_S) @ V.T) * (S.sum() / masked_S.sum())

    def toggle_mask(self, mask_value=True, mask_idx=None):
        """Enable or disable a mask and reload the model weights accordingly."""
        mask_idx = self._resolve_mask_idx(mask_idx)

        self.enable_mask[mask_idx] = mask_value

        self.apply_policy_to_model(mask_idx)

    def update_backward(
        self,
        mask_idx=None,
    ):
        """Push the wrapped model's weight gradients onto the mask vectors.

        Call after ``loss.backward()``.  For each masked weight, the weight is
        recomposed from the mask and back-propagated with the gradient stored
        on the model parameter, which accumulates into the mask's ``.grad``.
        The model weights are then recomposed and reloaded.
        """
        mask_idx = self._resolve_mask_idx(mask_idx)

        for k in self.learnable_params[mask_idx]:
            param = self.model.get_parameter(k)
            grad = param.grad
            if grad is None:
                # This weight did not take part in the last backward pass.
                continue
            # Each recomposed weight has its own small graph, so nothing needs
            # to be retained between iterations.
            self.compose_new_params(k, mask_idx).backward(grad)
            # The optimizer only owns the mask vectors and never zeroes this
            # gradient; clear it so the next step does not reuse stale values.
            param.grad = None

        self.apply_policy_to_model(mask_idx)

    def train(self, mask_idx=None):
        """Make the weights covered by a mask require gradients.

        ``nn.Module.train(mode)`` / ``eval()`` reach this override with a
        ``bool`` (also when called on a parent module); those calls are
        forwarded to ``nn.Module.train`` so the mode propagates to the wrapped
        model instead of being misread as a mask index.

        Args:
            mask_idx: mask index (default: latest mask), or a ``bool`` mode.

        Returns:
            ``self``.
        """
        if isinstance(mask_idx, bool):
            return super().train(mask_idx)

        mask_idx = self._resolve_mask_idx(mask_idx)

        # Gradients on these weights are what update_backward() consumes.
        for k in self.learnable_params[mask_idx]:
            self.model.get_parameter(k).requires_grad_(True)
        return self

    def apply_policy_to_model(self, mask_idx=None):
        """Recompose all masked weights and load them into the wrapped model."""
        mask_idx = self._resolve_mask_idx(mask_idx)

        updated_params = {}
        # The composed tensors are only copied into the model, so building an
        # autograd graph for them would be wasted work and memory.
        with torch.no_grad():
            for k in self.base_params:
                if any(skip in k for skip in self.matching_texts):
                    updated_params[k] = self.base_params[k]
                else:
                    updated_params[k] = self.compose_new_params(k, mask_idx)
        self.model.load_state_dict(updated_params, strict=False)

import requests
from PIL import Image
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# Bound the request and fail on HTTP errors instead of hanging or handing an
# error page to PIL.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

labels = ["cat", "dog", "bicycle"]
correct_label_idx = labels.index("bicycle")  # suppose bicycle is correct
# padding=True lets prompts of different token lengths be stacked into one tensor.
text_inputs = clip_processor(text=labels, return_tensors="pt", padding=True).to(device)
target = torch.tensor([correct_label_idx], device=device)

# Patch the vision model: wrap it, create one mask for these labels, load the
# masked weights and make the masked weights require gradients.
vision_model = Chicken(model.vision_model, device=device)
vision_model.add_class(class_names=labels)
vision_model.apply_policy_to_model()
vision_model.train()
print(vision_model.class_map)

# Swap the wrapper into CLIP so its forward pass goes through it.
model.vision_model = vision_model

loss_fn = torch.nn.CrossEntropyLoss()
# Only the mask vectors are optimised; the wrapped weights are recomposed from them.
optimizer = torch.optim.Adam(vision_model.get_trainable_parameters(), lr=1e-2)

# The image never changes, so preprocess it once rather than on every step.
image_inputs = clip_processor(images=image, return_tensors="pt").to(device)

for i in range(100):
    # Forward
    outputs = model(**image_inputs, **text_inputs)
    logits = outputs.logits_per_image  # [1, num_labels]

    # Loss
    loss = loss_fn(logits, target)

    # Backward
    optimizer.zero_grad()
    # The optimizer only holds the mask vectors; the CLIP weights' gradients
    # must be cleared separately or they accumulate across steps.
    model.zero_grad()
    loss.backward()
    vision_model.update_backward()
    optimizer.step()
    # Recompose the weights from the just-updated masks so the next forward
    # pass does not run on masks that are one step old.
    vision_model.apply_policy_to_model()

    # Report
    predicted_class_idx = logits.argmax(-1).item()
    print(f"Step {i:03d} | Loss: {loss.item():.4f} | Prediction: {labels[predicted_class_idx]}")


# Image input: reuse the image fetched earlier in this script instead of
# downloading the same URL a second time.
image_inputs = clip_processor(images=image, return_tensors="pt").to(device)

# ---- ORIGINAL EMBEDDING ----
import time
# perf_counter is monotonic, which makes it the right clock for an interval.
start_time = time.perf_counter()
vision_model.toggle_mask(False)
with torch.no_grad():
    f_orig = model.get_image_features(**image_inputs)
    f_orig = f_orig / f_orig.norm(dim=-1, keepdim=True)

# ---- ADAPTED EMBEDDING ----
vision_model.toggle_mask(True)
with torch.no_grad():
    f_adapted = model.get_image_features(**image_inputs)
    f_adapted = f_adapted / f_adapted.norm(dim=-1, keepdim=True)

# ROUTING SCORE
routing_score = torch.nn.functional.cosine_similarity(f_orig, f_adapted, dim=-1)
print("\n[ROUTING CHECK]")
print("Cosine similarity:", routing_score.item())

# TEXT EMBEDDINGS AND SIMILARITY
labels = ["cat", "dog", "bicycle"]
# padding=True lets prompts of different token lengths be stacked into one tensor.
text_inputs = clip_processor(text=labels, return_tensors="pt", padding=True).to(device)

with torch.no_grad():
    text_embeds = model.get_text_features(**text_inputs)
    text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

    sims_orig = torch.matmul(f_orig, text_embeds.T)
    sims_adapted = torch.matmul(f_adapted, text_embeds.T)

# CUDA kernels run asynchronously; wait for them so the timing covers the
# similarity computation as well.
if device.type == "cuda":
    torch.cuda.synchronize()
print("Total Time Taken", time.perf_counter() - start_time)
print("\n[TEXT SIMILARITY COMPARISON]")
print("Text prompts:", labels)
print("Similarity with original vision weights:", sims_orig.squeeze().tolist())
print("Similarity with adapted vision weights:", sims_adapted.squeeze().tolist())
print("Routing delta per class:", (sims_adapted - sims_orig).squeeze().tolist())
