import torch
import torch.nn as nn
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import requests

# Load the pretrained CLIP weights and the matching preprocessor from the same
# checkpoint; the model is moved to the second CUDA device right away.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
model = model.to("cuda:1")
clip_processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)


def get_mask(p, max_mult=1):
    """Turn raw mask values into a bounded multiplicative mask.

    Args:
        p: Tensor of unconstrained values.
        max_mult: Factor applied after the sigmoid.

    Returns:
        A bfloat16 tensor equal to ``sigmoid(p) * max_mult``.
    """
    gate = p.sigmoid()
    gate = gate.to(torch.bfloat16)
    return gate * max_mult

# Device handle for the second CUDA GPU.
device = torch.device("cuda:1")

# Run the whole model in full precision.
model = model.to(dtype=torch.float32)

# State dict of the vision tower only.
# NOTE(review): state_dict() tensors normally alias the live parameters, so
# these values follow later in-place weight loads - confirm before treating
# this as a snapshot of the pretrained weights.
base_params = model.vision_model.state_dict()

# Detached CPU copies of every vision-tower entry whose name contains
# "classifier".
original_model_params = {
    name: tensor.clone().detach().cpu()
    for name, tensor in base_params.items()
    if "classifier" in name
}


# Filled further down with the "<name>.U" / "<name>.S" / "<name>.V" factors.
decomposed_params = {}

class Policy(nn.Module):
    """Container for one learnable mask vector per maskable weight.

    For every entry of ``base_params`` whose name does not contain one of the
    skipped markers, a 1-D bfloat16 parameter of length ``min(v.shape)`` is
    created. ``get_mask`` maps such a vector to ``sigmoid(p) * max_mult``.

    Args:
        base_params: Mapping from parameter name to tensor (e.g. a state dict).
        gpu: Device on which the learnable vectors are allocated.
        init_val: Standard deviation of the Gaussian initialisation.
        max_mult: Upper bound of the mask produced by ``get_mask``.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    # Substrings identifying state-dict entries that never receive a mask.
    # "layrnorm" is listed separately from "layernorm" on purpose.
    _SKIPPED_KEY_MARKERS = (
        "layernorm",
        "bias",
        "embeddings",
        "layrnorm",
        "layer_norm",
    )

    def __init__(self, base_params, gpu, init_val=0.1, max_mult=1, **kwargs):
        super().__init__()
        self.learnable_params = {}
        self.num_params = 0
        self.max_mult = max_mult
        for k, v in base_params.items():
            if any(marker in k for marker in self._SKIPPED_KEY_MARKERS):
                continue
            # Each vector is initialised with small Gaussian noise.
            self.learnable_params[k] = torch.nn.Parameter(
                data=(
                    torch.randn(
                        min(v.shape),
                        device=gpu,
                        dtype=torch.bfloat16,
                    )
                    * init_val
                ),
                requires_grad=True,
            )
            self.num_params += self.learnable_params[k].numel()
        print(f"#params={self.num_params}")
        self.learnable_params_list = list(self.learnable_params.values())
        self.trainable_params = self.learnable_params_list
        # The plain dict above is invisible to nn.Module; the ParameterList is
        # what registers the vectors with the module.
        self.learnable_params_module_list = nn.ParameterList(self.learnable_params_list)

    def get_learnable_params(self, detach=False):
        """Return the name -> mask-vector mapping.

        Args:
            detach: When true, return a new dict of detached tensors instead
                of the live parameters.

        Returns:
            The internal dict of parameters, or a detached copy of it.
        """
        if detach:
            return {k: p.detach() for k, p in self.learnable_params.items()}
        return self.learnable_params

    def set_trainable_params_values(self, new_values):
        """Overwrite every trainable vector in place.

        Args:
            new_values: Iterable with one tensor per trainable parameter, in
                the order of ``self.trainable_params``.

        Raises:
            ValueError: If the number of values does not match the number of
                trainable parameters.
        """
        new_values = list(new_values)
        if len(new_values) != len(self.trainable_params):
            # zip() would otherwise silently leave some parameters untouched.
            raise ValueError(
                f"expected {len(self.trainable_params)} values, "
                f"got {len(new_values)}"
            )
        with torch.no_grad():
            for p, v in zip(self.trainable_params, new_values):
                p.data.copy_(v)

    def get_mask(self, p):
        """Return ``sigmoid(p) * self.max_mult`` as a bfloat16 tensor."""
        return torch.sigmoid(p).to(torch.bfloat16) * self.max_mult

# Factor every maskable vision-tower weight as U @ diag(S) @ V.T and store the
# factors under "<name>.U", "<name>.S" and "<name>.V".
for k, v in base_params.items():
    if any(
        marker in k
        for marker in ("layernorm", "bias", "embeddings", "layrnorm", "layer_norm")
    ):
        continue  # skip this param

    # torch.svd is deprecated in favour of torch.linalg.svd. The replacement
    # returns the right factor already transposed (Vh), so transpose it back:
    # compose_new_params expects V and applies ".T" itself.
    U, S, Vh = torch.linalg.svd(v, full_matrices=False)
    V = Vh.transpose(-2, -1)
    decomposed_params[f"{k}.U"] = U
    decomposed_params[f"{k}.S"] = S
    decomposed_params[f"{k}.V"] = V

# Policy.__init__ swallows the extra keyword arguments via **kwargs.
policy = Policy(base_params, gpu="cuda:1", decomposed_params=decomposed_params, mode=1)
learnable_params = policy.get_learnable_params()

def compose_new_params(
    policy,
    param_name,
    decomposed_params,
    learnable_params,
):
    """Compose a weight matrix from its SVD factors and a learned mask.

    The singular values are multiplied element-wise by the policy's mask and
    the result is rescaled so that the sum of the masked singular values
    equals the sum of the original ones.

    Args:
        policy: Object exposing ``get_mask(p)``; its mask (including its
            ``max_mult`` scaling) is the one applied here.
        param_name: Name of the weight to rebuild.
        decomposed_params: Mapping holding ``"<param_name>.U"``,
            ``"<param_name>.S"`` and ``"<param_name>.V"``.
        learnable_params: Mapping from weight name to its mask vector.

    Returns:
        The recomposed weight tensor, differentiable with respect to
        ``learnable_params[param_name]``.
    """
    u = decomposed_params[f"{param_name}.U"]
    s = decomposed_params[f"{param_name}.S"]
    v = decomposed_params[f"{param_name}.V"]

    mm = policy.get_mask(learnable_params[param_name])
    masked_s = s * mm

    # Scaling the columns of U is the same product as U @ diag(masked_s)
    # without materialising the dense diagonal matrix.
    recomposed = (u * masked_s) @ v.T
    return recomposed * (s.sum() / masked_s.sum())

def backward(
    policy,
    model,
    base_params,
    decomposed_params,
    learnable_params,
):
    """Push the model's weight gradients back onto the mask vectors.

    For every maskable weight, the weight is recomposed from its SVD factors
    and ``.backward`` is called on it with the gradient currently stored on
    the matching ``model.vision_model`` parameter, so the gradient accumulates
    on the corresponding entry of ``learnable_params``.

    Args:
        policy: Policy passed through to ``compose_new_params``.
        model: Model whose ``vision_model`` parameters already hold ``.grad``.
        base_params: Mapping whose keys name the vision-tower entries.
        decomposed_params: SVD factors keyed by ``"<name>.U/.S/.V"``.
        learnable_params: Mapping from weight name to its mask vector.

    Raises:
        RuntimeError: If a maskable weight has no gradient, i.e. the model's
            own backward pass has not populated it.
    """
    skipped_markers = ("layernorm", "bias", "embeddings", "layrnorm", "layer_norm")
    keys_to_backprop = [
        k for k in base_params if not any(marker in k for marker in skipped_markers)
    ]

    # Every compose_new_params call builds its own autograd graph, so no graph
    # has to be retained between keys, and an empty key list is simply a no-op.
    for k in keys_to_backprop:
        weight_grad = model.vision_model.get_parameter(k).grad
        if weight_grad is None:
            raise RuntimeError(
                f"parameter {k!r} has no gradient; run the model's backward "
                "pass before calling backward()"
            )
        compose_new_params(policy, k, decomposed_params, learnable_params).backward(
            weight_grad
        )

def _is_passthrough_key(name):
    """Return True for state-dict entries that are loaded back unmasked."""
    return any(
        marker in name
        for marker in ("layernorm", "bias", "embeddings", "layer_norm", "layrnorm")
    )


def _load_composed_weights():
    """Recompose every masked weight and load the result into the ViT.

    Returns:
        The state dict that was loaded into ``model.vision_model``.
    """
    # load_state_dict only copies values, so tracking gradients while
    # composing would build graphs that are thrown away immediately. The
    # policy gets its gradients through backward() instead.
    with torch.no_grad():
        composed = {
            k: (
                base_params[k]
                if _is_passthrough_key(k)
                else compose_new_params(policy, k, decomposed_params, learnable_params)
            )
            for k in base_params
        }
    # Every key of the vision tower's own state dict is present, so the
    # default strict check is safe for both the initial and the per-step load.
    model.vision_model.load_state_dict(composed)
    return composed


# Start from weights composed with the freshly initialised masks.
new_params = _load_composed_weights()

# Make sure the masked weights collect gradients during the model's backward.
for k in learnable_params:
    model.vision_model.get_parameter(k).requires_grad_(True)

# Redundant with the imports at the top of the file; kept so this section can
# still be run on its own.
import requests
from PIL import Image
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()  # fail here rather than inside the image decoder
image = Image.open(response.raw)

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(policy.trainable_params, lr=1e-2)

labels = ["cat", "dog", "bicycle"]
correct_label_idx = 2  # suppose bicycle is correct
# padding=True keeps tokenisation working when the labels differ in length.
text_inputs = clip_processor(text=labels, return_tensors="pt", padding=True).to(device)
target = torch.tensor([correct_label_idx], device=device)
# The image never changes, so it is preprocessed once instead of every step.
image_inputs = clip_processor(images=image, return_tensors="pt").to(device)

for i in range(100):
    # Forward
    outputs = model(**image_inputs, **text_inputs)
    logits = outputs.logits_per_image  # [1, num_labels]

    # Loss
    loss = loss_fn(logits, target)

    # Backward
    optimizer.zero_grad()
    # optimizer.zero_grad() only covers the policy's mask vectors. The model's
    # weight gradients must be cleared too, otherwise backward() below would
    # receive gradients accumulated over all previous steps.
    model.zero_grad()
    loss.backward()
    backward(policy, model, base_params, decomposed_params, learnable_params)
    optimizer.step()

    # Compose the weights for the updated masks and load them into the ViT
    # only. The mask vectors are updated in place by the optimizer, so
    # learnable_params needs no refresh.
    new_params = _load_composed_weights()

    # Report (logits and loss come from the weights used in this step's forward)
    predicted_class_idx = logits.argmax(-1).item()
    print(f"Step {i:03d} | Loss: {loss.item():.4f} | Prediction: {labels[predicted_class_idx]}")