import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForImageClassification, AutoImageProcessor


def get_mask(p, max_mult=1):
    """Turn gating logits into a multiplicative mask.

    Args:
        p: tensor of logits.
        max_mult: scalar multiplier applied after the sigmoid.

    Returns:
        ``sigmoid(p)`` cast to bfloat16, multiplied by ``max_mult``.
    """
    gate = torch.sigmoid(p)
    gate_bf16 = gate.to(torch.bfloat16)
    return gate_bf16 * max_mult


device = torch.device("cuda")
image_processor = AutoImageProcessor.from_pretrained(
    "google/vit-base-patch16-224",
    use_fast=True,
)

# Load the weights directly in float32. Loading as float16 and upcasting
# afterwards can only lose precision (the fp16 rounding is not undone by the
# cast), and every weight is SVD-decomposed below, so that loss would be
# baked into the factors.
model = AutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    torch_dtype=torch.float32,
    device_map="cuda",
    attn_implementation="sdpa",
)

# Snapshot of the pretrained weights. model.state_dict() returns tensors that
# share storage with the live parameters, and model.load_state_dict() later
# copies composed weights into those parameters in place; cloning keeps this
# dict holding the original pretrained values.
base_params = {k: v.detach().clone() for k, v in model.state_dict().items()}

# CPU copy of the original classifier head weights/bias.
original_model_params = {
    k: v.clone().detach().cpu() for k, v in base_params.items() if "classifier" in k
}

# SVD factors of each adapted weight, keyed "<name>.U", "<name>.S", "<name>.V".
decomposed_params = {}

class Policy(nn.Module):
    """Learnable per-singular-value gating logits for each adapted weight.

    For every entry of ``base_params`` whose name does not contain
    ``'layernorm'``, ``'bias'`` or ``'embeddings'``, one bfloat16 vector of
    length ``min(weight.shape)`` is created. That is one logit per singular
    value of the matrix. The vectors live in ``self.learnable_params`` (name ->
    parameter). They are also registered on the module through an
    ``nn.ParameterList``, so ``self.parameters()`` yields them.

    Args:
        base_params: mapping of parameter name -> tensor (e.g. a state dict);
            only the ``.shape`` of kept entries is read.
        gpu: device on which the learnable vectors are allocated.
        init_val: mean of the initial logits.
        max_mult: scale applied to the sigmoid in :meth:`get_mask`.
        **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, base_params, gpu, init_val=0.1, max_mult=1, **kwargs):
        super().__init__()
        self.learnable_params = {}
        self.num_params = 0
        self.max_mult = max_mult
        for k, v in base_params.items():
            # Layer norms, biases and embedding tensors are not adapted.
            if 'layernorm' in k or 'bias' in k or 'embeddings' in k:
                continue
            # init_val plus small Gaussian noise (std 0.01), in bfloat16.
            init = (
                torch.randn(min(v.shape), device=gpu, dtype=torch.bfloat16) * 0.01
                + init_val
            )
            self.learnable_params[k] = torch.nn.Parameter(data=init, requires_grad=True)
            self.num_params += self.learnable_params[k].numel()
        print(f"#params={self.num_params}")
        self.learnable_params_list = list(self.learnable_params.values())
        self.trainable_params = self.learnable_params_list
        # Registers the parameters with nn.Module (a plain dict would not).
        self.learnable_params_module_list = nn.ParameterList(self.learnable_params_list)

    def get_learnable_params(self, detach=False):
        """Return the name -> parameter mapping.

        Args:
            detach: if True, return a new dict of detached tensors. The tensors
                share storage with the parameters but carry no autograd history.
                Otherwise the live dict itself is returned.
        """
        if detach:
            return {k: v.detach() for k, v in self.learnable_params.items()}
        return self.learnable_params

    def set_trainable_params_values(self, new_values):
        """Copy ``new_values``, in order, into the trainable parameters in place.

        Raises:
            ValueError: if the number of values differs from the number of
                trainable parameters. Without this check, ``zip`` would
                silently drop the extras.
        """
        new_values = list(new_values)
        if len(new_values) != len(self.trainable_params):
            raise ValueError(
                f"expected {len(self.trainable_params)} values, got {len(new_values)}"
            )
        with torch.no_grad():
            for p, v in zip(self.trainable_params, new_values):
                p.copy_(v)

    def get_mask(self, p):
        """Return ``sigmoid(p)`` cast to bfloat16 and scaled by ``max_mult``."""
        return torch.sigmoid(p).to(torch.bfloat16) * self.max_mult

# Factor every adapted weight once with a reduced SVD, W = U diag(S) V^T.
# torch.svd is deprecated. torch.linalg.svd returns V^T (Vh) rather than V,
# so transpose it back: the "<name>.V" entries keep the layout that the
# composition step expects (it multiplies by V.T).
for k, v in base_params.items():
    # Layer norms, biases and embedding tensors are left untouched.
    if 'layernorm' in k or 'bias' in k or 'embeddings' in k:
        continue

    U, S, Vh = torch.linalg.svd(v, full_matrices=False)
    decomposed_params[f"{k}.U"] = U
    decomposed_params[f"{k}.S"] = S
    decomposed_params[f"{k}.V"] = Vh.T

# NOTE: Policy.__init__ only reads base_params, gpu, init_val and max_mult.
# decomposed_params and mode are absorbed by its **kwargs and unused.
policy = Policy(base_params, gpu="cuda", decomposed_params=decomposed_params, mode=1)
learnable_params = policy.get_learnable_params()

def compose_new_params(
    policy,
    param_name,
    decomposed_params,
    learnable_params,
):
    """Compose an adapted weight matrix from its SVD factors and policy logits.

    The factors are read from ``decomposed_params`` under ``"<param_name>.U"``,
    ``".S"`` and ``".V"``, with ``W = U diag(S) V^T``. Let ``m`` be the mask
    obtained from ``learnable_params[param_name]``. The function returns::

        U diag(S * m) V^T * (sum(S) / sum(S * m))

    Each singular value is rescaled by its mask entry, and the result is
    renormalised so that the singular values keep their original total. Two
    consequences of this renormalisation:

    * a mask that is constant across singular values reproduces ``W`` (up to
      rounding);
    * any global scale on the mask (e.g. ``Policy.max_mult``) cancels out.

    Args:
        policy: object providing ``get_mask(logits)`` (normally the Policy).
            If ``None``, the module-level ``get_mask`` is used.
        param_name: name of the weight to compose.
        decomposed_params: mapping holding the SVD factors.
        learnable_params: mapping of parameter name -> gating logits.

    Returns:
        Tensor with the original weight's shape, differentiable w.r.t.
        ``learnable_params[param_name]``.
    """
    logits = learnable_params[param_name]
    mm = get_mask(logits) if policy is None else policy.get_mask(logits)
    U = decomposed_params[f"{param_name}.U"]
    S = decomposed_params[f"{param_name}.S"]
    V = decomposed_params[f"{param_name}.V"]
    scaled_S = S * mm
    # (U * s) @ V.T equals U @ diag(s) @ V.T without materialising the
    # r x r diagonal matrix.
    return ((U * scaled_S) @ V.T) * (S.sum() / scaled_S.sum())

def backward(
    policy,
    model,
    base_params,
    decomposed_params,
    learnable_params,
):
    """Push the model's weight gradients back onto the policy logits.

    The adapted weights reach ``model`` through ``load_state_dict``, which
    copies values and leaves no autograd link to the policy. Call this after
    ``loss.backward()`` has populated the model parameters' ``.grad``. For
    each adapted weight, it re-composes the weight from its SVD factors and
    the current logits. It then back-propagates the stored weight gradient
    through that composition, which accumulates into the logits' ``.grad``.

    Args:
        policy: forwarded to ``compose_new_params``.
        model: module whose parameters already hold gradients.
        base_params: kept for call-site compatibility. The adapted names are
            taken from ``learnable_params``, which Policy builds from
            ``base_params`` using the same skip rule.
        decomposed_params: SVD factors keyed "<name>.U" / ".S" / ".V".
        learnable_params: mapping of adapted parameter name -> gating logits.
    """
    for k in learnable_params:
        grad = model.get_parameter(k).grad
        if grad is None:
            # The weight received no gradient, so there is nothing to
            # propagate. backward(None) on a non-scalar tensor would raise.
            continue
        # Each call builds its own independent graph, so it can be freed
        # immediately (no retain_graph). Composing one weight at a time avoids
        # holding every composed weight in memory at once.
        composed = compose_new_params(policy, k, decomposed_params, learnable_params)
        composed.backward(grad)

# Compose the initial adapted weights and copy them into the model.
# load_state_dict only copies values, so no autograd graph is needed here.
# Building one per weight would just hold memory.
with torch.no_grad():
    new_params = {}
    for k in base_params:
        if 'layernorm' in k or 'bias' in k or 'embeddings' in k:
            # Non-adapted tensors are loaded back unchanged.
            new_params[k] = base_params[k]
            continue
        new_params[k] = compose_new_params(
            policy, k, decomposed_params, learnable_params
        )

    # loading the new parameters
    model.load_state_dict(new_params)

# backward() reads model.get_parameter(k).grad for every adapted weight, so
# these parameters must record gradients.
for k in learnable_params:
    model.get_parameter(k).requires_grad_(True)

# Demo: optimise the policy so the model predicts ImageNet class index 1 for a
# single COCO image.
import requests
from PIL import Image
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

loss_fn = torch.nn.CrossEntropyLoss()
# Only the policy logits are optimised; the model weights are rewritten from
# them after every step.
optimizer = torch.optim.Adam(policy.trainable_params, lr=1e-2)

# The input and target never change, so prepare them once.
target = torch.tensor([1]).cuda()
inputs = image_processor(images=image, return_tensors="pt").to("cuda")

for _ in range(100):
    outputs = model(**inputs)
    logits = outputs.logits

    # CrossEntropyLoss applies log-softmax itself and expects raw logits.
    # Passing softmax probabilities would apply softmax twice and flatten the
    # gradients.
    loss = loss_fn(logits, target)

    optimizer.zero_grad()
    # The model's parameters are not in the optimizer, so nothing else clears
    # their .grad. Without this, gradients would accumulate across iterations
    # and backward() would propagate ever-growing values.
    model.zero_grad(set_to_none=True)
    loss.backward()
    backward(policy, model, base_params, decomposed_params, learnable_params)
    optimizer.step()

    # Re-compose the adapted weights from the updated policy and copy them
    # into the model. load_state_dict copies values, so no graph is needed.
    learnable_params = policy.get_learnable_params()
    with torch.no_grad():
        new_params = {}
        for k in base_params:
            if 'layernorm' in k or 'bias' in k or 'embeddings' in k:
                new_params[k] = base_params[k]
                continue
            new_params[k] = compose_new_params(
                policy, k, decomposed_params, learnable_params
            )
        model.load_state_dict(new_params)

    # Prediction from this step's forward pass, i.e. before the update above.
    predicted_class_idx = logits.argmax(-1).item()
    print("loss", loss.item(), "Predicted class:", model.config.id2label[predicted_class_idx])