import torch
import torch.nn as nn
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import requests

# Bring up the CLIP checkpoint in float32 on the second CUDA device, together
# with the processor that prepares its image/text inputs.
device = torch.device("cuda:1")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model = model.to("cuda:1")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = model.to(torch.float32)

# State dict of the vision tower; its keys drive every per-parameter loop below.
base_params = model.vision_model.state_dict()

# CPU copies of any entries whose name mentions "classifier".
# NOTE(review): looks like no vision_model key contains "classifier", so this
# dict may always be empty - confirm whether a different filter was intended.
original_model_params = {
    name: tensor.clone().detach().cpu()
    for name, tensor in base_params.items()
    if "classifier" in name
}

# Filled later with "<param name>.U" / ".S" / ".V" SVD factors.
decomposed_params = {}

class Policy(nn.Module):
    """Container for learnable mask logits, one vector per maskable weight.

    For every entry of ``base_params`` whose name contains none of the skip
    markers (layer norms, biases, embeddings), each mask owns a vector of
    ``min(v.shape)`` logits. ``get_mask`` maps logits to multipliers with
    ``sigmoid(p) * max_mult``.

    The parameters are kept in plain dictionaries rather than as registered
    module attributes, so hand ``trainable_params[mask_idx]`` to the optimizer.

    Attributes are documented inline in ``__init__``.
    """

    # Name fragments identifying parameters that never receive a mask.
    _SKIP_MARKERS = ('layernorm', 'bias', 'embeddings', 'layrnorm', 'layer_norm')

    def __init__(self, base_params, gpu, num_masks=1, vertical_mask=True, enable_mask=None, init_val=0.1, max_mult=1, **kwargs):
        """Create the mask logits.

        Args:
            base_params: mapping of parameter name -> tensor; only names and
                shapes are used.
            gpu: device on which the logits are allocated.
            num_masks: number of independent masks to create.
            vertical_mask: must be truthy; horizontal masking is unsupported.
            enable_mask: optional per-mask on/off flags (length ``num_masks``);
                defaults to every mask enabled.
            init_val: scale applied to the standard-normal initial logits.
            max_mult: multiplier applied to the sigmoid in ``get_mask``.
            **kwargs: accepted and ignored.

        Raises:
            NotImplementedError: if ``vertical_mask`` is falsy.
            ValueError: if ``enable_mask`` does not have ``num_masks`` entries.
        """
        super().__init__()
        if not vertical_mask:
            # The original built this exception without raising it, which
            # silently produced a policy with no parameters.
            raise NotImplementedError("Havent Implemented Horizontal Masking Yet")

        if enable_mask is None:
            flags = [True] * num_masks
        else:
            flags = [bool(flag) for flag in enable_mask]
            if len(flags) != num_masks:
                raise ValueError(
                    f"enable_mask has {len(flags)} entries, expected {num_masks}"
                )

        # mask_idx -> {param name -> logits}
        self.learnable_params = {}
        # mask_idx -> list of the same logits, in base_params order
        self.trainable_params = {}
        # total number of logits across all masks
        self.num_params = 0
        self.max_mult = max_mult
        # per-mask switch consulted by get_mask
        self.enable_mask = flags

        for mask_idx in range(num_masks):
            mask_params = {}
            for k, v in base_params.items():
                if any(marker in k for marker in self._SKIP_MARKERS):
                    continue
                # One logit per singular value, initialised with small gaussian noise.
                mask_params[k] = nn.Parameter(
                    torch.randn(min(v.shape), device=gpu, dtype=torch.bfloat16) * init_val,
                    requires_grad=True
                )
                self.num_params += mask_params[k].numel()
            self.learnable_params[mask_idx] = mask_params
            # Holds the list for the most recently built mask.
            self.learnable_params_list = list(mask_params.values())
            self.trainable_params[mask_idx] = self.learnable_params_list
        print(f"#params={self.num_params}")

    def get_learnable_params(self, detach=False):
        """Return the ``mask_idx -> {name -> logits}`` mapping.

        With ``detach=True`` a new mapping of detached tensors is returned;
        otherwise the live dictionary itself is returned.
        """
        if not detach:
            return self.learnable_params
        return {
            idx: {name: p.detach() for name, p in params.items()}
            for idx, params in self.learnable_params.items()
        }

    def set_trainable_params_values(self, new_values):
        """Overwrite every logit vector in place.

        Args:
            new_values: one value per trainable tensor, ordered mask by mask
                and, within a mask, in the order of ``trainable_params``.

        Raises:
            ValueError: if the number of values differs from the number of
                trainable tensors.
        """
        # trainable_params is a dict of lists; flatten it so tensors (not the
        # integer mask indices) are paired with the incoming values.
        targets = [p for idx in self.trainable_params for p in self.trainable_params[idx]]
        values = list(new_values)
        if len(values) != len(targets):
            raise ValueError(
                f"expected {len(targets)} tensors, got {len(values)}"
            )
        with torch.no_grad():
            for p, v in zip(targets, values):
                p.data.copy_(torch.as_tensor(v))

    def get_mask(self, p, mask_idx):
        """Return the bfloat16 multipliers for logits ``p`` of mask ``mask_idx``.

        A disabled mask yields all ones.
        """
        if not self.enable_mask[mask_idx]:
            return torch.ones_like(p).to(torch.bfloat16)
        return torch.sigmoid(p).to(torch.bfloat16) * self.max_mult
        
# Factor every maskable weight once so later steps only rescale singular values.
for k, v in base_params.items():
    if 'layernorm' in k or 'bias' in k or 'embeddings' in k or 'layrnorm' in k or 'layer_norm' in k:
        continue  # skip this param

    # torch.svd is deprecated; the reduced linalg.svd yields the same U and S
    # and returns V already transposed (Vh), so transpose back to keep the
    # stored ".V" layout that compose_new_params expects.
    U, S, Vh = torch.linalg.svd(v, full_matrices=False)
    V = Vh.transpose(-2, -1)
    decomposed_params[f"{k}.U"] = U
    decomposed_params[f"{k}.S"] = S
    decomposed_params[f"{k}.V"] = V

# decomposed_params and mode are absorbed by Policy's **kwargs.
policy = Policy(base_params, gpu="cuda:1", decomposed_params=decomposed_params, mode=1)
learnable_params = policy.get_learnable_params()
mask_idx = 0

def compose_new_params(
    policy,
    param_name,
    decomposed_params,
    learnable_params,
    mask_idx,
):
    """Compose new parameters from decomposed parameters.

    Rebuilds ``U @ diag(S * mask) @ V.T`` for ``param_name`` and rescales the
    result by ``S.sum() / (S * mask).sum()``.

    Args:
        policy: object exposing ``get_mask(logits, mask_idx)``.
        param_name: key prefix of the ``.U`` / ``.S`` / ``.V`` entries.
        decomposed_params: mapping holding the SVD factors.
        learnable_params: ``mask_idx -> {param name -> logits}``.
        mask_idx: which mask's logits to use.

    Returns:
        The recomposed weight tensor; it stays connected to the mask logits
        for autograd when they require grad.
    """
    left = decomposed_params[f"{param_name}.U"]
    sing = decomposed_params[f"{param_name}.S"]
    right = decomposed_params[f"{param_name}.V"]

    mm = policy.get_mask(learnable_params[mask_idx][param_name], mask_idx)
    masked_sing = sing * mm

    # Scaling the columns of U equals multiplying by diag(masked_sing) but
    # avoids building a dense diagonal matrix and one extra matmul.
    recomposed = (left * masked_sing) @ right.T
    return recomposed * (sing.sum() / masked_sing.sum())

def backward(
    policy,
    model,
    base_params,
    decomposed_params,
    learnable_params,
    mask_idx,
):
    """Backward pass.

    For every maskable key of ``base_params``, recomposes the weight from the
    mask logits and back-propagates the gradient currently stored on the
    matching ``model.vision_model`` parameter into those logits.

    Args:
        policy: mask policy passed through to ``compose_new_params``.
        model: model whose ``vision_model`` parameters hold the gradients.
        base_params: mapping whose keys name the vision parameters.
        decomposed_params: SVD factors used by ``compose_new_params``.
        learnable_params: ``mask_idx -> {param name -> logits}``.
        mask_idx: which mask's logits receive gradients.
    """
    skip_markers = ('layernorm', 'bias', 'embeddings', 'layrnorm', 'layer_norm')
    for k in base_params:
        if any(marker in k for marker in skip_markers):
            continue
        upstream = model.vision_model.get_parameter(k).grad
        if upstream is None:
            # No gradient reached this weight; calling backward(None) on a
            # non-scalar tensor would raise.
            continue
        # Each call to compose_new_params builds its own graph, so nothing
        # needs to be retained between keys.
        compose_new_params(policy, k, decomposed_params, learnable_params, mask_idx).backward(
            upstream
        )

# Build the initial masked weights and load them into the vision tower.
new_params = {}
# Only the values are copied into the model, so compose without autograd;
# otherwise this module-level dict would keep one graph per layer alive.
with torch.no_grad():
    for k in base_params:
        if 'layernorm' in k or 'bias' in k or 'embeddings' in k or 'layrnorm' in k or 'layer_norm' in k:
            new_params[k] = base_params[k]
            continue  # skip this param

        new_params[k] = compose_new_params(
            policy, k, decomposed_params, learnable_params, mask_idx
        )

model.vision_model.load_state_dict(new_params)

# set the learnable params to training
for k in learnable_params[mask_idx]:
    model.vision_model.get_parameter(k).requires_grad_(True)
    
def apply_policy_to_model(policy, base_params, decomposed_params, learnable_params, mask_idx):
    """Recompose every maskable weight and load the result into the model.

    Skipped entries (layer norms, biases, embeddings) are passed through from
    ``base_params`` unchanged. The weights are written into the module-level
    ``model.vision_model`` with ``strict=False``.

    Args:
        policy: mask policy passed through to ``compose_new_params``.
        base_params: mapping whose keys name the vision parameters.
        decomposed_params: SVD factors used by ``compose_new_params``.
        learnable_params: ``mask_idx -> {param name -> logits}``.
        mask_idx: which mask to apply.
    """
    skip_markers = ('layernorm', 'bias', 'embeddings', 'layer_norm', 'layrnorm')
    updated_params = {}
    # Only values are copied into the model, so skip building autograd graphs.
    with torch.no_grad():
        for k in base_params:
            if any(skip in k for skip in skip_markers):
                updated_params[k] = base_params[k]
            else:
                updated_params[k] = compose_new_params(
                    policy, k, decomposed_params, learnable_params, mask_idx
                )
    model.vision_model.load_state_dict(updated_params, strict=False)

import requests
from PIL import Image
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# Bound the wait and surface HTTP errors instead of handing an error body to PIL.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

loss_fn = torch.nn.CrossEntropyLoss()
# Only the mask logits are optimised; the model weights are recomposed from them.
optimizer = torch.optim.Adam(policy.trainable_params[mask_idx], lr=1e-2)

labels = ["cat", "dog", "bicycle"]
correct_label_idx = 2  # suppose bicycle is correct
# padding=True keeps the batch rectangular if prompts tokenize to different lengths.
text_inputs = clip_processor(text=labels, return_tensors="pt", padding=True).to("cuda:1")
target = torch.tensor([correct_label_idx], device="cuda:1")

# The image never changes, so preprocess it once rather than on every step.
image_inputs = clip_processor(images=image, return_tensors="pt").to("cuda:1")

for i in range(100):
    # Forward
    outputs = model(**image_inputs, **text_inputs)
    logits = outputs.logits_per_image  # [1, num_labels]

    # Loss
    loss = loss_fn(logits, target)

    # Backward
    optimizer.zero_grad()
    # The optimizer does not own the model weights, so clear their gradients
    # explicitly; otherwise the gradients fed into backward() below would
    # accumulate across steps.
    model.zero_grad()
    loss.backward()
    backward(policy, model, base_params, decomposed_params, learnable_params, mask_idx)
    optimizer.step()

    # Compose current masked weights
    apply_policy_to_model(policy, base_params, decomposed_params, learnable_params, mask_idx)

    # Report (logits come from the forward pass made before this step's update)
    predicted_class_idx = logits.argmax(-1).item()
    print(f"Step {i:03d} | Loss: {loss.item():.4f} | Prediction: {labels[predicted_class_idx]}")


# # Image input
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# Bound the wait and surface HTTP errors instead of handing an error body to PIL.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)
image_inputs = clip_processor(images=image, return_tensors="pt").to("cuda:1")

# # ---- ORIGINAL EMBEDDING ----
import time
# perf_counter is the monotonic clock intended for measuring intervals.
start_time = time.perf_counter()
# With the mask disabled get_mask returns ones, so the recomposed weights are
# the unscaled SVD reconstruction.
policy.enable_mask[mask_idx] = False
apply_policy_to_model(policy, base_params, decomposed_params, learnable_params, mask_idx)
with torch.no_grad():
    f_orig = model.get_image_features(**image_inputs)
    f_orig = f_orig / f_orig.norm(dim=-1, keepdim=True)

# # ---- ADAPTED EMBEDDING ----
policy.enable_mask[mask_idx] = True
apply_policy_to_model(policy, base_params, decomposed_params, learnable_params, mask_idx)
with torch.no_grad():
    f_adapted = model.get_image_features(**image_inputs)
    f_adapted = f_adapted / f_adapted.norm(dim=-1, keepdim=True)

# # ROUTING SCORE
routing_score = torch.nn.functional.cosine_similarity(f_orig, f_adapted, dim=-1)
print("\n[ROUTING CHECK]")
print("Cosine similarity:", routing_score.item())

# TEXT EMBEDDINGS AND SIMILARITY
labels = ["cat", "dog", "bicycle"]
# padding=True keeps the batch rectangular if prompts tokenize to different lengths.
text_inputs = clip_processor(text=labels, return_tensors="pt", padding=True).to("cuda:1")

with torch.no_grad():
    text_embeds = model.get_text_features(**text_inputs)
    text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

    sims_orig = torch.matmul(f_orig, text_embeds.T)
    sims_adapted = torch.matmul(f_adapted, text_embeds.T)

# CUDA kernels run asynchronously; wait for them so the elapsed time covers
# the work that was launched above.
torch.cuda.synchronize(torch.device("cuda:1"))
print("Total Time Taken", time.perf_counter() - start_time)
print("\n[TEXT SIMILARITY COMPARISON]")
print("Text prompts:", labels)
print("Similarity with original vision weights:", sims_orig.squeeze().tolist())
print("Similarity with adapted vision weights:", sims_adapted.squeeze().tolist())
print("Routing delta per class:", (sims_adapted - sims_orig).squeeze().tolist())
