# Train continual-learning models on CelebA attribute tasks in PyTorch, starting from an (optionally pretrained) ResNet-18 encoder.
import os
from tqdm import tqdm
from dataset.CelebADataset import CelebADataset
import torch
import torchvision.models as models
from torch.utils.data import DataLoader
from utilities.utils import fix_seed 
import pandas as pd
import copy
from utilities.utils_transform import resnet_transform as default_transform
from utilities.utils_transform import custom_transform
from argparser import get_args, get_models
import os
import torchvision
def main():
    """Script entry point: parse the run configuration and start training."""
    train(get_args())

def train_one_task_cl(args: dict,
                      task_name: str,
                      task_index: int,
                      model: torch.nn.Module,
                      train_dataloader,
                      seed: int,
                      device: str,
                      criterion: torch.nn.Module,
                      custom_transform,) -> float:
    """Train ``model`` on a single task and checkpoint its best state.

    After every batch the loss returned by ``model.compute_loss`` is compared
    with the lowest loss seen so far; on improvement the model's
    ``state_dict`` is written to
    ``./models/CL/<dataset>/<model>/<seed>_<task_name>_best_enc.pt``.
    Training stops early once the number of non-improving batches reaches
    ``args['early_stopping_tolerance']``. ``model.end_task`` is called once
    training is over, whether it ran to completion or stopped early.

    No optimizer is used here, so ``compute_loss`` is presumably responsible
    for the backward pass and the parameter update -- TODO confirm against
    the model implementations.

    Args:
        args: Run configuration; the keys ``'epochs'``,
            ``'early_stopping_tolerance'``, ``'model'`` and ``'dataset'``
            are read.
        task_name: Key used to fetch the labels from each batch dict; also
            part of the checkpoint file name.
        task_index: Index of the task, used only in the progress-bar label.
        model: Module exposing ``compute_loss`` and ``end_task``.
        train_dataloader: Iterable of batch dicts holding an ``"image"``
            entry and a ``task_name`` entry.
        seed: Seed of the current run; part of the checkpoint file name.
        device: Device the batch tensors are moved to.
        criterion: Loss module forwarded to ``compute_loss`` (not used for
            the ``'co2l'`` model).
        custom_transform: Transform forwarded to ``compute_loss``.

    Returns:
        Mean per-batch loss over the batches processed in the last epoch
        that was run (``0.0`` if no batch was processed).

    Raises:
        ValueError: If ``args['model']`` has no known ``compute_loss`` call.
    """
    epochs = args['epochs']
    early_stopping_tolerance = args['early_stopping_tolerance']
    model_name = args['model']

    # Create the checkpoint directory once instead of re-checking it on every
    # improving batch.
    checkpoint_dir = f"./models/CL/{args['dataset']}/{model_name}/"
    checkpoint_path = f"{checkpoint_dir}{seed}_{task_name}_best_enc.pt"
    os.makedirs(checkpoint_dir, exist_ok=True)

    best_loss = float('inf')
    early_stopping_counter = 0
    # Initialised up front so the final average is defined even when
    # ``epochs`` is 0 or the dataloader is empty.
    total_loss = 0
    batches_seen = 0
    stopped_early = False

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        batches_seen = 0
        with tqdm(total=len(train_dataloader), desc=f"[Task {task_index}][Epoch {epoch+1}/{epochs}]") as pbar:
            for sample in train_dataloader:
                image = sample["image"].to(device)
                cur_task_y = sample[task_name].long().to(device)

                if model_name == 'co2l':
                    tqdm_loss = model.compute_loss(image, cur_task_y, custom_transform)
                elif model_name in ['der', 'derpp', 'er', 'fdr', 'gss']:
                    tqdm_loss = model.compute_loss(image, cur_task_y, image, criterion, custom_transform)
                else:
                    raise ValueError("Model compute loss not found")

                total_loss += tqdm_loss
                batches_seen += 1

                if tqdm_loss < best_loss:
                    best_loss = tqdm_loss
                    torch.save(model.state_dict(), checkpoint_path)
                    early_stopping_counter = 0
                elif epoch > 1:
                    # Epochs are 0-indexed, so non-improving batches only
                    # start counting from the third epoch on.
                    early_stopping_counter += 1

                if early_stopping_counter >= early_stopping_tolerance:
                    stopped_early = True
                    break

                pbar.set_postfix(Loss=tqdm_loss)  # accuracy not computed until the end
                pbar.update(1)
        if stopped_early:
            break

    # Run the end-of-task hook on both exit paths; the early-stopping path
    # used to return before reaching it.
    model.end_task(train_dataloader, task_name)

    # Average over the batches actually processed so that an early stop in
    # the middle of an epoch does not deflate the reported loss.
    return total_loss / batches_seen if batches_seen else 0.0

def train(args: dict) -> None:
    """Run the multi-seed continual-learning experiment and save a CSV.

    For every seed and every attribute in ``prediction_targets`` (in order),
    a fresh model is built around a deep copy of the ResNet-18 encoder,
    trained on that task's share of the training split with
    ``train_one_task_cl``, reloaded from its best checkpoint and evaluated on
    the test split for all tasks up to and including the current one. One
    result row per (seed, task) is written to
    ``./experiment_results/CL/<dataset>/<model>/<n>tasks.csv``.

    NOTE(review): ``encoder_past`` is never reassigned inside the task loop,
    so every task starts from a copy of the same initial encoder. The ResNet
    is also created before ``fix_seed`` is called, so with ``pretrain``
    disabled its random initialisation is shared by all seeds. Confirm that
    both are intended.

    Args:
        args: Run configuration; the keys ``'batch_size'``, ``'lr'``,
            ``'pretrain'``, ``'augment'``, ``'cls_output_dim'``,
            ``'dataset'`` and ``'model'`` are read here, and the whole dict
            is forwarded to ``get_models`` and ``train_one_task_cl``.

    Raises:
        ValueError: If ``args['augment']`` or ``args['dataset']`` is not a
            supported value.
    """
    results = []
    seeds = [42, 43, 44, 45, 12312]

    batch_size = args['batch_size']
    lr = args['lr']
    pretrain = args['pretrain']
    augment = args['augment']
    cls_output_dim = args['cls_output_dim']
    dataset = args['dataset']
    train_subsample_ratio = 1

    # Prefer the GPU, but fall back to the CPU instead of crashing on
    # machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    prediction_targets = ["Wearing_Lipstick",
                          "Smiling",
                          "Mouth_Slightly_Open",
                          "High_Cheekbones",
                          "Attractive",
                          "Heavy_Makeup",
                          "Male",
                          "Young",
                          "Wavy_Hair",
                          "Straight_Hair"]

    # Load the (optionally ImageNet-pretrained) backbone.
    if pretrain:
        resnet_model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1).to(device)
    else:
        resnet_model = models.resnet18(weights=None).to(device)

    # Augmentation applied by the datasets.
    if augment == 'default':
        resnet_transform = default_transform
    elif augment == 'none':
        # ``Compose`` lives in ``torchvision.transforms``; there is no
        # ``torchvision.Compose``.
        resnet_transform = torchvision.transforms.Compose([])
    else:
        raise ValueError("Augment not found")

    # Datasets.
    if dataset == 'celeba':
        test_dataset = CelebADataset(split="test", transform=resnet_transform)
        test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
        # The first seed is passed regardless of the run seed so that the
        # same dataset object is reused for every seed.
        train_dataset_overall = CelebADataset(split="train",
                                              transform=resnet_transform,
                                              num_tasks=len(prediction_targets), seed=seeds[0])
    else:
        raise ValueError("Dataset not found")

    checkpoint_dir = f"./models/CL/{args['dataset']}/{args['model']}/"

    for seed in seeds:
        fix_seed(seed)
        # ResNet without its final fully-connected layer, flattened.
        encoder_past = torch.nn.Sequential(*(list(resnet_model.children())[:-1]), torch.nn.Flatten()).to(device)

        criterion = torch.nn.CrossEntropyLoss()

        for task_index, task_name in enumerate(prediction_targets):
            # Training data for the current task.
            train_dataset = train_dataset_overall.split_data_by_task(task_index)
            train_dataset, _ = torch.utils.data.random_split(train_dataset,
                                                            [train_subsample_ratio, (1 - train_subsample_ratio)])
            train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

            cur_encoder = copy.deepcopy(encoder_past).to(device)

            model = get_models(args,
                               cur_encoder,
                               cls_output_dim=cls_output_dim,
                               lr=lr).to(device)  # a type of nn.Module/ContinualLearning

            # ``train_one_task_cl`` already returns a per-batch mean, so the
            # value must not be divided by the number of batches again.
            mean_loss = train_one_task_cl(args,
                                          task_name,
                                          task_index,
                                          model,
                                          train_dataloader,
                                          seed,
                                          device,
                                          criterion,
                                          custom_transform)

            # Evaluate the best checkpoint of this task on the test split.
            model.load_state_dict(torch.load(f"{checkpoint_dir}{seed}_{task_name}_best_enc.pt",
                                             map_location=device))

            # NOTE(review): the method name "calculate_accuraciess" looks like
            # a typo but is kept as-is; confirm against the model class.
            temp_result = model.calculate_accuraciess(test_dataloader, prediction_targets[:task_index+1], device)

            cur_result = {
                'seed': seed,
                'task': task_name,
                'loss': mean_loss,
            }
            print(f'[Task {task_index}][Loss: {mean_loss:.4f}, {task_name}: {temp_result[task_name]:.4f}][PAST TASK:', end=' ')
            # Report tasks in reverse order; tasks without a result still get
            # a column (set to None) so every row has the same schema.
            for each_task in reversed(prediction_targets):
                if each_task in temp_result:
                    print(f'{each_task}: {temp_result[each_task]:.4f}', end=' ')
                    cur_result[each_task+'_accuracy'] = temp_result[each_task]
                else:
                    cur_result[each_task+'_accuracy'] = None
            print(']')

            results.append(cur_result)

    # Save all rows to a CSV file.
    df = pd.DataFrame(results)
    results_dir = f"./experiment_results/CL/{args['dataset']}/{args['model']}/"
    os.makedirs(results_dir, exist_ok=True)
    df.to_csv(f"{results_dir}{len(prediction_targets)}tasks.csv", index=False)
    


# Run the experiment only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
           