from transformers import ViTImageProcessor, ViTModel
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

## GLOBAL

T = 20 # Number of tasks (passed to CIFAR100Dataset as num_tasks; the training loop runs once per task)

## LOADING THE MODEL

# Image processor loaded from the same checkpoint name as the ViT backbone;
# collate_wrapper uses it to turn PIL images into "pixel_values" tensors.
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')

class Model(nn.Module):
    """Pretrained ViT backbone followed by a linear classification head.

    Args:
        num_classes: Number of logits produced by the head.
    """

    def __init__(self, num_classes: int):
        super().__init__()
        self.model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        # Size the head from the backbone's own config instead of hard-coding
        # 768, so the class keeps working if the checkpoint is swapped for a
        # ViT variant with a different hidden width.
        self.fc = nn.Linear(self.model.config.hidden_size, num_classes)

    def forward(self, **inputs) -> torch.Tensor:
        """Run the backbone on ``inputs`` and classify its pooled output.

        Args:
            **inputs: Keyword arguments forwarded unchanged to the ViT
                backbone (the script passes ``pixel_values``).

        Returns:
            The head's raw (unnormalized) logits computed from
            ``pooler_output``.
        """
        outputs = self.model(**inputs)
        return self.fc(outputs.pooler_output)

# Device string shared by the model and, further down, the Policy helper.
device = "cuda:1"

# Build the 5-way classifier, force float32 weights, then move it to `device`.
# NOTE(review): 5 presumably equals 100 CIFAR classes / T tasks - confirm
# against CIFAR100Dataset's label mapping.
model = Model(5).to(torch.float32).to(device)

##### POLICY PATCH

# Snapshot of the model weights handed to the policy helpers.
# NOTE(review): state_dict() returns tensors that share storage with the live
# parameters, so later in-place loads into the model also change these values -
# confirm that this aliasing is intended.
base_params = model.state_dict()

############ DECOMPOSE
# Parameters whose name contains any of these markers are left undecomposed.
_SKIP_MARKERS = ('layernorm', 'bias', 'embeddings', 'layrnorm', 'layer_norm')

# Maps "<param name>.U" / ".S" / ".V" to the reduced SVD factors of that weight.
decomposed_params = {}
for k, v in base_params.items():
    if any(marker in k for marker in _SKIP_MARKERS):
        continue  # skip this param

    # torch.svd is deprecated; this is the replacement given in the PyTorch
    # docs for torch.svd(v) with its default some=True: reduced factors, with
    # V recovered as the conjugate transpose of Vh.
    U, S, Vh = torch.linalg.svd(v, full_matrices=False)
    V = Vh.mH
    decomposed_params[f"{k}.U"] = U
    decomposed_params[f"{k}.S"] = S
    decomposed_params[f"{k}.V"] = V

######### INIT POLICY
from utils import Policy, apply_policy_to_model, backward, compose_new_params

# Build the initial policy and let the project helper patch it into the model.
policy = Policy(base_params, gpu=device, decomposed_params=decomposed_params, mode=1)
learnable_params = policy.get_learnable_params()

model = apply_policy_to_model(model, policy, base_params, decomposed_params, learnable_params)

# Freeze every parameter of the model ...
model.requires_grad_(False)

# ... then switch gradients back on only for the names the policy reports as learnable.
for k in learnable_params:
    model.get_parameter(k).requires_grad_(True)

# Print the resulting trainable/frozen flag of every parameter.
for k, p in model.named_parameters():
    print(k, p.requires_grad)

## LOADING THE DATASET

from dataset_cifar import CIFAR100Dataset
# Both splits are read from the same directory and configured with T tasks.
CIFAR_ROOT = "Dataset/cifar-100-python"
train_dataset = CIFAR100Dataset(CIFAR_ROOT, train=True, num_tasks=T)
test_dataset = CIFAR100Dataset(CIFAR_ROOT, train=False, num_tasks=T)

def collate_wrapper(batch):
    """Collate ``(data, label)`` samples into model inputs and a label tensor.

    Args:
        batch: Iterable of ``(data, label)`` pairs, where ``data`` is an array
            accepted by ``PIL.Image.fromarray``.

    Returns:
        A tuple ``(batch_inputs, batch_labels)``: ``batch_inputs`` is a dict
        whose ``"pixel_values"`` entry holds the processed images (one row per
        sample), and ``batch_labels`` is ``torch.tensor`` of the labels in the
        same order.
    """
    samples = list(batch)
    images = [Image.fromarray(data) for data, _ in samples]
    batch_labels = torch.tensor([label for _, label in samples])

    # The processor accepts a list of images, so the whole batch is
    # preprocessed in one call instead of one call per image followed by a
    # concatenation; the per-image transform is the same.
    pixel_values = processor(images=images, return_tensors="pt")["pixel_values"]

    # Return a plain dict so callers can keep reassigning its entries.
    return {"pixel_values": pixel_values}, batch_labels

# train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_wrapper)
# test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=True, collate_fn=collate_wrapper)

# Number of training epochs run for each task (the epoch loop sits inside the task loop).
num_epochs = 30

def accuracy_fn(logits, labels):
    """Return the fraction of rows whose arg-max class equals the label.

    Args:
        logits: Score tensor; the predicted class is the arg-max over the
            last dimension.
        labels: Tensor of target class indices compared element-wise with
            the predictions.

    Returns:
        A scalar float32 tensor holding the mean of the correct/incorrect
        indicators.
    """
    hits = logits.argmax(dim=-1) == labels
    return hits.to(torch.float32).mean()

# CrossEntropyLoss applies log-softmax itself, so it is fed the raw logits
# below; passing softmax probabilities would normalize twice and flatten the
# gradients.
loss_fn = torch.nn.CrossEntropyLoss()

# Sum over tasks of the test accuracy measured after each task's final epoch.
total_avg_acc = 0
for t in range(T):
    train_dataset.set_task(t)
    test_dataset.set_task(t)
    # Loaders are rebuilt for every task, after set_task(t) was called on both datasets.
    train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True, collate_fn=collate_wrapper)
    test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True, collate_fn=collate_wrapper)

    ## INIT POLICY
    # A fresh policy is created and applied to the model for every task.
    policy = Policy(base_params, gpu=device, decomposed_params=decomposed_params, mode=1)
    learnable_params = policy.get_learnable_params()

    model = apply_policy_to_model(model, policy, base_params, decomposed_params, learnable_params)

    # Freeze everything, then re-enable gradients only for the learnable names.
    model.requires_grad_(False)
    for k in learnable_params:
        model.get_parameter(k).requires_grad_(True)

    # The optimizer must hold THIS task's policy parameters. Building it once
    # before the loop left it bound to a policy object that is replaced above,
    # so its steps never touched the parameters actually being trained.
    optimizer = torch.optim.Adam(policy.trainable_params, lr=1e-3)

    # Test accuracy of the most recently evaluated epoch of this task.
    task_acc = 0.0

    for epoch in range(num_epochs):
        avg_acc = 0
        avg_loss = 0
        num = 0
        for data, labels in train_dataloader:
            labels = labels.to(device)
            for key in data:
                data[key] = data[key].to(device)

            logits = model(**data)
            loss = loss_fn(logits, labels)
            acc = accuracy_fn(logits, labels)

            # Backward
            optimizer.zero_grad()
            loss.backward()
            # Project helper called between loss.backward() and the optimizer
            # step; presumably routes the model gradients to the policy
            # parameters - TODO confirm in utils.
            backward(policy, model, base_params, decomposed_params, learnable_params)
            optimizer.step()

            ## UPDATE THE WEIGHTS
            ## need to reset the policy everytime
            learnable_params = policy.get_learnable_params()
            new_params = {}
            for k in base_params:
                # Only weights that were SVD-decomposed can be recomposed.
                # Keying on decomposed_params keeps this in step with whatever
                # filter the decomposition used, instead of repeating a
                # (shorter) list of name markers here.
                if f"{k}.U" not in decomposed_params:
                    new_params[k] = base_params[k]
                    continue  # skip this param

                new_params[k] = compose_new_params(
                    policy, k, decomposed_params, learnable_params
                )

            # loading the new parameters
            model.load_state_dict(new_params)

            num += 1
            avg_acc += acc.item()
            avg_loss += loss.item()

            print(f"(TRAINING) task: {t}, epoch: {epoch}/{num_epochs}, loss: {avg_loss/num:.4f}, acc: {avg_acc/num:.4f}")

        with torch.no_grad():
            avg_acc = 0
            avg_loss = 0
            num = 0
            for data, labels in test_dataloader:
                labels = labels.to(device)
                for key in data:
                    data[key] = data[key].to(device)

                logits = model(**data)
                loss = loss_fn(logits, labels)
                acc = accuracy_fn(logits, labels)

                num += 1
                avg_acc += acc.item()
                avg_loss += loss.item()

                print(f"(TESTING) task: {t}, epoch: {epoch}/{num_epochs}, loss: {avg_loss/num:.4f}, acc: {avg_acc/num:.4f}")

        # Remember this epoch's mean test accuracy; skip the division when the
        # test loader yielded no batches.
        if num:
            task_acc = avg_acc / num

    # Count each task once, using its final-epoch test accuracy, rather than
    # adding every epoch of every task into the total.
    total_avg_acc += task_acc

# Report the mean over tasks so the value matches its label.
print("total average accuracy", total_avg_acc / T if T else 0.0)