from transformers import ViTImageProcessor, ViTModel
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

## GLOBAL

# Number of sequential tasks; passed as `num_tasks` to both dataset splits and
# used as the bound of the outer per-task training loop.
T = 40 # Number of tasks
# Prefix of the checkpoint filenames written under weights/ during training.
task_name = "imagenet_a"

## LOADING THE MODEL

# Hugging Face image processor for the ViT checkpoint. Not referenced by the
# code shown here: collate_wrapper preprocesses with torchvision transforms and
# its `processor(...)` call is commented out.
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')

class Model(nn.Module):
    """ViT encoder with a linear classification head on its pooled output.

    Args:
        num_classes: Number of logits produced by the classification head.
        backbone_name: Hugging Face checkpoint to load for the ViT encoder.
            Defaults to the ImageNet-21k ViT-B/16 checkpoint used by this script.
    """

    def __init__(self, num_classes, backbone_name='google/vit-base-patch16-224-in21k'):
        super().__init__()
        self.model = ViTModel.from_pretrained(backbone_name)
        # Size the head from the backbone config rather than hard-coding 768, so
        # other ViT variants (e.g. ViT-L with hidden size 1024) work unchanged.
        self.fc = nn.Linear(self.model.config.hidden_size, num_classes)

    def forward(self, **inputs):
        """Run the backbone on keyword inputs (e.g. `pixel_values=...`) and return class logits."""
        outputs = self.model(**inputs)
        return self.fc(outputs.pooler_output)

device = "cuda:1"
# Build the 5-way classifier and move it to the target GPU as float32 in one call.
model = Model(5).to(device=device, dtype=torch.float32)

##### POLICY PATCH

# Snapshot of the freshly built model's weights. The save/load round trip
# through "base.pt" appears intended to give a copy independent of the live
# parameters (state_dict() tensors share storage with them); it also leaves
# "base.pt" in the working directory.
base_params = model.state_dict()
torch.save(base_params, "base.pt")
base_params = torch.load("base.pt")

############ DECOMPOSE
# Parameter-name substrings that are NOT decomposed (norms, biases, embeddings).
svd_skip_markers = ("layernorm", "bias", "embeddings", "layrnorm", "layer_norm")

# For every remaining weight store its SVD factors as "<name>.U/.S/.V".
decomposed_params = {}
for param_name, weight in base_params.items():
    if any(marker in param_name for marker in svd_skip_markers):
        continue
    u_mat, sing_vals, v_mat = torch.svd(weight)
    decomposed_params[f"{param_name}.U"] = u_mat
    decomposed_params[f"{param_name}.S"] = sing_vals
    decomposed_params[f"{param_name}.V"] = v_mat

######### INIT POLICY
from utils import Policy, apply_policy_to_model, backward, compose_new_params

# Initial policy over the SVD factors; the training loop rebuilds it per task.
policy = Policy(base_params, gpu=device, decomposed_params=decomposed_params, mode=1)
learnable_params = policy.get_learnable_params()
model = apply_policy_to_model(model, policy, base_params, decomposed_params, learnable_params)

# Freeze the whole model, then re-enable gradients only on the entries the
# policy reports as learnable.
for param in model.parameters():
    param.requires_grad_(False)
for param_name in learnable_params:
    model.get_parameter(param_name).requires_grad_(True)

## LOADING THE DATASET

from dataset_imagenet_r import ImageNetRDataset

# Both splits read the same root directory, each configured with num_tasks=T.
# NOTE(review): ImageNet-A data is loaded through the ImageNet-R dataset class —
# confirm the on-disk layout is compatible.
_imagenet_a_root = "Dataset/imagenet-a"
train_dataset = ImageNetRDataset(_imagenet_a_root, train=True, num_tasks=T)
test_dataset = ImageNetRDataset(_imagenet_a_root, train=False, num_tasks=T)

from torchvision import datasets, transforms

# Training-time augmentation: random resized crop to `input_size`, random
# horizontal flip, then conversion to a [0, 1] float tensor.
input_size = 224
scale = (0.05, 1.0)
ratio = (3. / 4., 4. / 3.)
transform = transforms.Compose([
    transforms.RandomResizedCrop(input_size, scale=scale, ratio=ratio),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
])

# Evaluation-time preprocessing: deterministic bicubic resize, center crop to
# `input_size`, then conversion to a [0, 1] float tensor. The named enum
# replaces the former magic integer 3 (PIL's BICUBIC constant).
# NOTE(review): neither pipeline normalizes; the checkpoint's
# ViTImageProcessor presumably applies its own mean/std — confirm the backbone
# is meant to see raw [0, 1] inputs.
test_transform = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(input_size),
    transforms.ToTensor(),
])

def collate_wrapper(batch):
    """Collate `(data, label, is_train)` samples into model inputs and labels.

    Each `data` array is converted with `PIL.Image.fromarray`, then
    preprocessed with the random-augmentation pipeline `transform` when the
    sample is flagged as training data, or with the deterministic
    `test_transform` otherwise.

    Args:
        batch: Iterable of `(data, label, is_train)` tuples from the dataset.

    Returns:
        `(inputs, labels)`: `inputs` is `{"pixel_values": Tensor}` with the
        per-sample tensors concatenated along dim 0, and `labels` is
        `torch.tensor` of the collected labels.
    """
    pixel_values = []
    labels = []
    for data, label, is_train in batch:
        image = Image.fromarray(data)
        # Training samples get augmentation; evaluation samples stay deterministic.
        preprocess = transform if is_train else test_transform
        pixel_values.append(preprocess(image).unsqueeze(0))
        labels.append(label)
    return {"pixel_values": torch.cat(pixel_values, dim=0)}, torch.tensor(labels)

num_epochs = 125


def accuracy_fn(logits, labels):
    """Return the fraction of rows whose arg-max logit equals the label (float32 scalar tensor)."""
    correct = logits.argmax(dim=-1).eq(labels)
    return correct.float().mean()


loss_fn = torch.nn.CrossEntropyLoss()

torch.manual_seed(42)

import os

# Parameter-name substrings that were NOT SVD-decomposed when
# `decomposed_params` was built; weights whose names contain one of these are
# copied verbatim from `base_params` on every update. This must match the
# decomposition filter exactly: a key it misses has no `.U/.S/.V` entries in
# `decomposed_params`, so passing it to compose_new_params would presumably fail.
_SVD_SKIP_MARKERS = ("layernorm", "bias", "embeddings", "layrnorm", "layer_norm")


def _is_svd_skipped(param_name):
    """Return True if `param_name` was excluded from SVD decomposition."""
    return any(marker in param_name for marker in _SVD_SKIP_MARKERS)


def _build_task_model():
    """Return a fresh `(model, policy, learnable_params)` triple for one task.

    Re-instantiates the 5-way classifier as float32 on `device`, wraps it with a
    new `Policy` built from the shared `base_params` / `decomposed_params`,
    freezes every parameter, and re-enables gradients only for the names the
    policy reports as learnable.
    """
    task_model = Model(5)
    task_model = task_model.to(torch.float32)
    task_model = task_model.to(device)

    task_policy = Policy(base_params, gpu=device, decomposed_params=decomposed_params, mode=1)
    task_learnable = task_policy.get_learnable_params()
    task_model = apply_policy_to_model(
        task_model, task_policy, base_params, decomposed_params, task_learnable
    )

    for name, _ in task_model.named_parameters():
        task_model.get_parameter(name).requires_grad_(False)
    for name in task_learnable:
        task_model.get_parameter(name).requires_grad_(True)
    return task_model, task_policy, task_learnable


def _compose_params(task_policy, task_learnable):
    """Build a full state dict from the current policy.

    Non-decomposed weights are taken unchanged from `base_params`; every other
    weight is rebuilt by `compose_new_params` from its SVD factors.
    """
    new_params = {}
    for name in base_params:
        if _is_svd_skipped(name):
            new_params[name] = base_params[name]
        else:
            new_params[name] = compose_new_params(
                task_policy, name, decomposed_params, task_learnable
            )
    return new_params


def _batch_to_device(data, labels):
    """Move a collated `(inputs_dict, labels)` batch onto `device`."""
    return {key: value.to(device) for key, value in data.items()}, labels.to(device)


def _evaluate(task_model, loader):
    """Return `(mean loss, mean accuracy)` of `task_model` over `loader`.

    Means are weighted by batch size so a smaller final batch does not count as
    much as a full one. Returns zeros for an empty loader.
    """
    task_model.eval()
    total_loss = 0.0
    total_acc = 0.0
    seen = 0
    with torch.no_grad():
        for data, labels in loader:
            data, labels = _batch_to_device(data, labels)
            logits = task_model(**data)
            batch_size = labels.shape[0]
            # CrossEntropyLoss expects raw logits (it applies log-softmax itself).
            total_loss += loss_fn(logits, labels).item() * batch_size
            total_acc += accuracy_fn(logits, labels).item() * batch_size
            seen += batch_size
    seen = max(seen, 1)
    return total_loss / seen, total_acc / seen


# torch.save below fails if the checkpoint directory does not exist.
os.makedirs("weights", exist_ok=True)

for t in range(T):
    train_dataset.set_task(t)
    test_dataset.set_task(t)
    train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True, collate_fn=collate_wrapper)
    # Evaluation does not need shuffling; a fixed order keeps metrics reproducible.
    test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=collate_wrapper)

    ## RESET OF THE POLICY: each task starts from the base weights with a new policy.
    model, policy, learnable_params = _build_task_model()
    optimizer = torch.optim.Adam(policy.trainable_params, lr=5e-3)

    # Start the task with no stale gradients on the policy or the model.
    optimizer.zero_grad()
    model.zero_grad(set_to_none=True)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        total_acc = 0.0
        seen = 0
        for data, labels in train_dataloader:
            data, labels = _batch_to_device(data, labels)

            logits = model(**data)
            # CrossEntropyLoss applies log-softmax internally, so it must be fed
            # raw logits; feeding softmax probabilities would apply softmax twice.
            loss = loss_fn(logits, labels)
            acc = accuracy_fn(logits, labels)

            # Backward. `backward` (from utils) presumably transfers the model's
            # weight gradients onto the policy parameters before the step.
            optimizer.zero_grad()
            loss.backward()
            backward(policy, model, base_params, decomposed_params, learnable_params)
            optimizer.step()

            ## UPDATE THE WEIGHTS: recompose every weight from the updated policy
            ## and load it into the model (needed after every optimizer step).
            learnable_params = policy.get_learnable_params()
            model.load_state_dict(_compose_params(policy, learnable_params))

            batch_size = labels.shape[0]
            total_loss += loss.item() * batch_size
            total_acc += acc.item() * batch_size
            seen += batch_size

        train_loss = total_loss / max(seen, 1)
        train_acc = total_acc / max(seen, 1)
        print(f"(TRAINING) task: {t}, epoch: {epoch}/{num_epochs}, loss: {train_loss:.4f}, acc: {train_acc:.4f}")
        if epoch % 5 == 0:
            torch.save(policy, f"weights/{task_name}_task_num_{t}_{epoch}_{train_loss:.4f}_{train_acc:.4f}_training")

        test_loss, test_acc = _evaluate(model, test_dataloader)
        print(f"(TESTING) task: {t}, epoch: {epoch}/{num_epochs}, loss: {test_loss:.4f}, acc: {test_acc:.4f}")
        if epoch % 5 == 0:
            torch.save(policy, f"weights/{task_name}_task_num_{t}_{epoch}_{test_loss:.4f}_{test_acc:.4f}_testing")