import torch
import torch.nn as nn

import torchvision

import yaml
import sys, os, argparse
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.simclr import SimCLR

from utils.dataset_loader import get_dataset
from utils.metrics import KNN, NCCCEval, anisotropy
from utils.losses import NTXentLoss, WeakNTXentLoss
from utils.analysis import cal_cdnv, embedding_performance_nearest_mean_classifier
from utils.visualize_utils import line_plot
from utils.eval_utils import load_snapshot, get_ssl_minus_scl_loss

from collections import defaultdict
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def set_seed(seed=42):
    """Seed the Python, NumPy and torch RNGs and make cuDNN deterministic.

    Seeds Python's ``random`` module, NumPy's global RNG, torch's CPU RNG and the
    RNGs of all CUDA devices. It also forces deterministic cuDNN kernels with
    autotuning disabled.

    Args:
        seed: Integer seed. NumPy requires it to lie in ``[0, 2**32 - 1]``.
    """
    import random  # local import: the module header does not import ``random``

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning (benchmark=True) may pick different kernels from run to run.
    torch.backends.cudnn.benchmark = False


set_seed(42)


def _group_dir_name(dataset_name, classes, repeat):
    """Return the checkpoint/log name suffix for one class group.

    Args:
        dataset_name: ``'cifar10'``, ``'cifar100'`` or ``'imagenet'``.
        classes: The class ids of the group, as listed in the config.
        repeat: Running group index; only used for ``cifar100``/``imagenet``.

    Returns:
        For ``cifar10``, the class ids joined by ``_`` (``[0, 3]`` -> ``"0_3"``).
        Otherwise, ``"<number of classes>_<repeat>"``.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported datasets.
    """
    if dataset_name == 'cifar10':
        return "_".join(f"{c}" for c in classes)
    if dataset_name in ('cifar100', 'imagenet'):
        return f"{len(classes)}_{repeat}"
    raise ValueError(f"Unsupported dataset for experiment 2: {dataset_name!r}")


def _checkpoint_epoch(filename):
    """Return the epoch encoded in a checkpoint filename ending in ``_<epoch>.<ext>``.

    The epoch is the text after the last ``_``, up to the first ``.`` that follows it.
    For example, ``"snapshot_120.pth"`` -> ``120``. Raises ``ValueError`` if that
    text is not an integer.
    """
    return int(filename.split('_')[-1].split('.')[0])


def _append_row(df, row):
    """Return a new frame with ``row`` (a column -> value dict) appended to ``df``.

    When ``df`` has no rows, the row is built directly with ``df``'s columns instead of
    concatenating. This avoids pandas' FutureWarning about concatenating empty frames.
    """
    if df.empty:
        return pd.DataFrame([row], columns=list(df.columns) or None)
    return pd.concat([df, pd.DataFrame([row])], ignore_index=True)


def main():
    """Evaluate SSL vs. weak-SCL losses on every saved checkpoint of experiment 2.

    For each class group in ``config['classes_groups']``:
      1. Load the group's checkpoints in epoch order.
      2. Compute the ``diff`` / ``ssl_loss`` / ``scl_loss`` triple returned by
         ``get_ssl_minus_scl_loss``, on both the train and the test loader.
      3. Append one row per epoch to
         ``<root_dir>/<experiment_name>/logs/C<group>.csv``.

    Epochs already present in that CSV, and epochs above ``--max-epoch``, are skipped.
    Re-running the script therefore resumes an interrupted evaluation.
    """
    parser = argparse.ArgumentParser(description='Evaluate the models for experiment 2')
    parser.add_argument('--config', '-c', type=str, required=True, help='Path to the config file')
    parser.add_argument('--dataset-path', type=str, default="/home/understanding-ssl/data",
                        help='Root directory of the datasets')
    parser.add_argument('--root-dir', type=str, default="/home/understanding-ssl/experiments/",
                        help='Directory containing <experiment_name>/checkpoints and /logs')
    parser.add_argument('--device', type=str, default='cuda', help='Device to evaluate on')
    parser.add_argument('--max-epoch', type=int, default=300,
                        help='Checkpoints with a larger epoch number are skipped')
    args = parser.parse_args()

    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)

    dataset_name = config['dataset']['name']
    device = args.device

    # A single model instance is created once and handed to load_snapshot for every checkpoint.
    encoder = torchvision.models.resnet50(weights=None)
    ssl_model = SimCLR(model=encoder,
                       dataset=dataset_name,
                       width_multiplier=config['model']['width_multiplier'],
                       hidden_dim=config['model']['hidden_dim'],
                       projection_dim=config['model']['projection_dim'])

    # Dataloader settings shared by every class group.
    augment_both = True
    batch_size = 1024

    temperature = 1.0
    ssl_criterion = NTXentLoss(temperature, device=device)
    weak_scl_criterion = WeakNTXentLoss(temperature, device=device)

    exp2_dir = config['experiment_name']
    checkpoints_root = os.path.join(args.root_dir, exp2_dir, 'checkpoints')
    logs_dir = os.path.join(args.root_dir, exp2_dir, 'logs')
    os.makedirs(logs_dir, exist_ok=True)

    log_columns = ['epoch', 'diff', 'ssl_loss', 'scl_loss',
                   'diff_test', 'ssl_loss_test', 'scl_loss_test']

    # For cifar100/imagenet, groups are named "<C>_<repeat>".
    # ``repeat`` counts groups across the whole config; it is not reset per class count.
    repeat = 1
    for classes in config['classes_groups']:
        dir_name = _group_dir_name(dataset_name, classes, repeat)
        if dataset_name != 'cifar10':
            repeat += 1
        model_name = f"C{dir_name}"

        print(f"Calculating loss for model: {model_name}")
        snapshots_dir = os.path.join(checkpoints_root, model_name)
        output_logs_file = os.path.join(logs_dir, f"{model_name}.csv")

        # Resume from an existing log: epochs already recorded there are not recomputed.
        if os.path.exists(output_logs_file):
            output_df = pd.read_csv(output_logs_file)
        else:
            output_df = pd.DataFrame(columns=log_columns)
        done_epochs = set(output_df['epoch'].tolist())

        _, train_loader, _, test_loader, _, _ = get_dataset(dataset_name=dataset_name,
                                                            dataset_path=args.dataset_path,
                                                            augment_both_views=augment_both,
                                                            batch_size=batch_size, test=True,
                                                            classes=classes)

        for checkpoint in sorted(os.listdir(snapshots_dir), key=_checkpoint_epoch):
            print(f"Calculating loss for checkpoint: {checkpoint}")
            cur_epoch = _checkpoint_epoch(checkpoint)
            # Decide before loading, so skipped checkpoints are never read from disk.
            if cur_epoch in done_epochs or cur_epoch > args.max_epoch:
                continue
            ssl_model = load_snapshot(os.path.join(snapshots_dir, checkpoint), ssl_model, device)

            diff, ssl_loss, scl_loss = get_ssl_minus_scl_loss(
                ssl_model, train_loader, ssl_criterion, weak_scl_criterion,
                labels_for_mapping=classes, device=device)
            diff_test, ssl_loss_test, scl_loss_test = get_ssl_minus_scl_loss(
                ssl_model, test_loader, ssl_criterion, weak_scl_criterion,
                labels_for_mapping=classes, device=device)

            output_df = _append_row(output_df, {
                'epoch': cur_epoch, 'diff': diff, 'ssl_loss': ssl_loss, 'scl_loss': scl_loss,
                'diff_test': diff_test, 'ssl_loss_test': ssl_loss_test,
                'scl_loss_test': scl_loss_test,
            })
            done_epochs.add(cur_epoch)
            # Persist after each checkpoint so an interrupted run can resume from here.
            output_df.to_csv(output_logs_file, index=False)

        # Also written when nothing new was computed, so every group gets a log file.
        output_df.to_csv(output_logs_file, index=False)
        print(f"Output logs saved to: {output_logs_file}")


if __name__ == '__main__':
    main()