# Train and evaluate a multi-task CelebA attribute classifier built on an
# (optionally ImageNet-pretrained) ResNet-18 encoder, using PyTorch.
import os
from tqdm import tqdm
from dataset.CelebADataset import CelebADataset
import torch
import torchvision.models as models
from torch.utils.data import DataLoader
from utilities.utils import fix_seed
import pandas as pd
import copy
from utilities.utils_transform import resnet_transform as default_transform
from utilities.utils_transform import custom_transform
from argparser import get_args, get_models, get_models_MTL
import os
import torchvision
import numpy as np
from models.PCGradMTL import PCGrad

if __name__ == "__main__":
    # Script entry point.
    #
    # For every seed: build a multi-task model on top of a ResNet-18 encoder,
    # train it on the CelebA train split, keep the checkpoint with the lowest
    # summed validation loss, evaluate that checkpoint on the test split, and
    # finally write one CSV row per seed.
    results = []
    seeds = [42, 43, 44, 45, 12312]  # 46, 47, 48, 49, 50, 51, 52]

    args = get_args()

    epochs = args['epochs']
    batch_size = args['batch_size']
    pretrain = args['pretrain']
    augment = args['augment']
    cls_output_dim = args['cls_output_dim']
    train_subsample_ratio = 1

    # Without at least one epoch there is no loss to report and no checkpoint
    # to evaluate; fail early with a clear message instead of a NameError.
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    device = torch.device("cuda")  # cuda

    prediction_targets = ["Wearing_Lipstick",
                          "Smiling",
                          "Mouth_Slightly_Open",
                          "High_Cheekbones",
                          "Attractive",
                          "Heavy_Makeup",
                          "Male",
                          "Young",
                          "Wavy_Hair",
                          "Straight_Hair"]

    # for CelebA, each task is binary classification
    task_num_class_dict = {task_name: 2 for task_name in prediction_targets}

    # Output locations, shared by every seed of this run.
    run_tag = f"sub_{train_subsample_ratio:.2f}/{args['model']}"
    checkpoint_dir = f"./models/MTL/celebA/{run_tag}"
    results_dir = f"./experiment_results/MTL/celebA/{run_tag}"

    # Load the (optionally ImageNet-pretrained) backbone.
    if pretrain:
        resnet_model = models.resnet18(
            weights=models.ResNet18_Weights.IMAGENET1K_V1).to(device)
    else:
        resnet_model = models.resnet18(weights=None).to(device)

    if augment == 'default':
        resnet_transform = default_transform
    elif augment == 'none':
        # Compose lives in torchvision.transforms, not in the torchvision root.
        resnet_transform = torchvision.transforms.Compose([])
    else:
        raise ValueError("Augment not found")

    # Encoder = ResNet-18 without its last child module, followed by Flatten.
    encoder_past = torch.nn.Sequential(
        *(list(resnet_model.children())[:-1]), torch.nn.Flatten()).to(device)
    criterion = torch.nn.CrossEntropyLoss()

    # for each seed
    for seed in seeds:
        fix_seed(seed)

        train_dataset = CelebADataset(split="train", transform=resnet_transform, seed=seed, num_tasks=None)
        # The split is kept even at ratio 1 so the random state consumed here
        # (and therefore the run) stays the same as before.
        train_dataset, _ = torch.utils.data.random_split(train_dataset,
                                                        [train_subsample_ratio, (1 - train_subsample_ratio)])
        valid_dataset = CelebADataset(split="valid", transform=resnet_transform)
        test_dataset = CelebADataset(split="test", transform=resnet_transform)

        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
        test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

        # Every seed starts from its own copy of the same initial encoder.
        cur_encoder = copy.deepcopy(encoder_past).to(device)

        model = get_models_MTL(args,
                               encoder=cur_encoder,
                               tasks_name_to_cls_num=task_num_class_dict,
                               cls_output_dim=cls_output_dim).to(device)

        params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Trainable parameters: {params}")

        checkpoint_path = os.path.join(checkpoint_dir, f"{seed}_best_enc.pt")
        best_loss = float("inf")
        # Tracks whether THIS run wrote a checkpoint, so a stale file left by
        # an earlier run is never loaded by mistake.
        checkpoint_saved = False

        for epoch in range(epochs):
            with tqdm(total=len(train_dataloader), desc=f"Epoch {epoch+1}/{epochs}") as pbar:
                # training loop
                model.train()
                train_total_loss = 0
                for sample in train_dataloader:
                    image = sample["image"].to(device)
                    tasks_y = {task_name: sample[task_name].type(torch.long).to(
                        device) for task_name in prediction_targets}
                    # NOTE(review): no optimizer is created in this script, so
                    # compute_loss presumably performs backward + the parameter
                    # update itself -- confirm in the model class.
                    train_loss = model.compute_loss(image, tasks_y, criterion)
                    tqdm_loss = train_loss.item()
                    train_total_loss += tqdm_loss
                    pbar.set_postfix(train_loss=tqdm_loss)
                    pbar.update(1)

                # validation loop
                model.eval()
                valid_total_loss = 0
                for sample in valid_dataloader:
                    image = sample["image"].to(device)
                    tasks_y = {task_name: sample[task_name].type(torch.long).to(
                        device) for task_name in prediction_targets}
                    val_loss = model.compute_loss_nograd(image, tasks_y, criterion)
                    valid_total_loss += val_loss.item()

                # Keep the weights with the lowest summed validation loss;
                # accuracy is not computed until training has finished.
                if valid_total_loss < best_loss:
                    best_loss = valid_total_loss
                    os.makedirs(checkpoint_dir, exist_ok=True)
                    torch.save(model.state_dict(), checkpoint_path)
                    checkpoint_saved = True

        # Compute accuracy on the test set using the best checkpoint.
        if checkpoint_saved:
            model.load_state_dict(
                torch.load(checkpoint_path, map_location=device))
        else:
            # Happens when the validation loss never compared below the initial
            # best (e.g. it was NaN in every epoch).
            print(f"Warning: no checkpoint was saved for seed {seed}; "
                  "evaluating the final-epoch weights instead.")
        # NOTE(review): 'calculate_accuraciess' is spelled as the project model
        # is called in the original script -- confirm before renaming.
        temp_result = model.calculate_accuraciess(
            test_dataloader, prediction_targets, device)

        # NOTE(review): 'epoch' and both losses describe the LAST epoch, while
        # the accuracies come from the best-validation checkpoint.
        cur_result = {
            'seed': seed,
            'epoch': epoch,
            'train loss': train_total_loss/len(train_dataloader),
            'valid loss': valid_total_loss/len(valid_dataloader)}

        print(f"[Epoch {epoch+1}, train loss: {cur_result['train loss']:.4f}, valid loss: {cur_result['valid loss']:.4f}][TASKS:", end=' ')

        for each_task in prediction_targets:
            if each_task in temp_result:
                print(f'{each_task}: {temp_result[each_task]:.4f}', end=' ')
                cur_result[each_task+'_accuracy'] = temp_result[each_task]
            else:
                # tasks without a result still get a column, filled with None
                cur_result[each_task+'_accuracy'] = None
        print(']')

        results.append(cur_result)

    # Save one row per seed to a CSV file.
    df = pd.DataFrame(results)
    os.makedirs(results_dir, exist_ok=True)
    df.to_csv(
        os.path.join(results_dir, f"{len(prediction_targets)}tasks.csv"),
        index=False)
