import torch
import numpy as np
import copy
import random 
import argparse
import os
from datetime import datetime

# Import custom utility functions and models
from utils_baseline_norm import *
from utils_intermediate import *
from utils.misc import *
from utils.misc_cifar import *
from utils.models import *


######################################################################################################################################
current_time = datetime.now().strftime("%y%m%d_%H%M%S")
parser = argparse.ArgumentParser()

# NOTE: argparse requires `help` to be a string; the previous list values
# (e.g. help=['cifar10', ...]) made `--help` crash while formatting.

# --- benchmark --- #
parser.add_argument('--benchmark', type=str, default='cifar100',
                    help='benchmark to evaluate: cifar10 | cifar100')

# --- long-tailed test sets --- #
parser.add_argument('--rho', type=float, default=0.01,
                    help='passed to make_LT_datasets to build the long-tailed test sets; '
                         'typical values: 0.01, 0.1, 1.0')

# --- Training details for TTA --- #
parser.add_argument('--batch_size', type=int, default=200)
parser.add_argument('--num_epochs', type=int, default=1)
parser.add_argument('--pl_threshold', type=float, default=0.9)
parser.add_argument('--optim_type', type=str, default='adam', help='sgd | adam')
parser.add_argument('--ft_layers', type=str, default='bn', help='bn | tent | noadapt')

# --- reproducibility --- #
parser.add_argument('--seed', type=int, default=1)

args = parser.parse_args()
print ( args )

######################################################################################################################################
# Set random seeds for reproducibility across runs
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

#########################################################################################
# Load the training split, the pre-trained classifier and the trained g_phi for the
# selected benchmark. Only CIFAR-10 and CIFAR-100 are wired up here. Any other value
# used to print a message and then crash later with a NameError, so fail fast instead.
if args.benchmark == 'cifar10':
    x_train, y_train = load_cifar10_train_dataset()
    num_classes = 10
elif args.benchmark == 'cifar100':
    x_train, y_train = load_cifar100_train_dataset()
    num_classes = 100
else:
    raise ValueError(
        f"there is no pre-trained model for the benchmark {args.benchmark!r} "
        "(supported: 'cifar10', 'cifar100')"
    )

# make balanced intermediate dataset
x_train, y_train = make_balanced_intermediate_dataset(x_train, y_train, num_classes)
args.num_classes = num_classes


# load pre-trained model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
ckpt_path = "path of trained classifier"
if args.benchmark == 'cifar10':
    net_init = Normalized_ResNet(depth=26)
else:
    net_init = Normalized_ResNet_CIFAR100()
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
checkpoint = torch.load(ckpt_path, map_location=device)
net_init.to(device)
# The checkpoint keys are loaded into the DataParallel wrapper, as in both original
# per-benchmark code paths.
net_init = torch.nn.DataParallel(net_init)
net_init.load_state_dict(checkpoint['net'])

# load trained g_phi (one file per benchmark under ./eval_results/<benchmark>/)
num_hidden = 1000
g_phi = MLP(num_classes, num_hidden)
g_phi.load_state_dict(
    torch.load(f"./eval_results/{args.benchmark}/trained_gphi.pt", map_location=device)
)
g_phi.to(device)

# configure model. The sanity run below needs a configured net_adapt, so an unknown
# --ft_layers is rejected here instead of crashing later with a NameError.
if args.ft_layers == 'bn':
    net_adapt = configure_model_bn(copy.deepcopy(net_init))
elif args.ft_layers == 'tent':
    net_adapt = configure_model_tent(copy.deepcopy(net_init))
elif args.ft_layers == 'noadapt':
    net_adapt = configure_model_noadapt(copy.deepcopy(net_init))
    print ('line100')
else:
    raise ValueError(
        f"there is no ft layers named {args.ft_layers!r} (supported: 'bn', 'tent', 'noadapt')"
    )

#### temp test datasets: load all test sets, then run a one-off BN-adaptation sanity
#### check on a long-tailed version of the test set at index 1.
if args.benchmark == 'cifar10':
    x_tests, y_tests, test_datasets = load_cifar10_test_datasets()
else:
    x_tests, y_tests, test_datasets = load_cifar100_test_datasets()
x_test_bal, y_test_bal = x_tests[1], y_tests[1]
x_test, y_test, _ = make_LT_datasets(x_test_bal, y_test_bal, args.rho, num_classes)
if args.benchmark == 'cifar10':
    print ('temp test dataset', x_test.size())
else:
    print ('line113// temp x_test.size() is ', x_test.size())
acc_bnadapt_te, logits_te, labels_te = tta_bnadapt(x_test ,y_test, net_adapt, args)


def compute_mask():
    """Return a 0/1 class-confusion mask estimated on corrupted training images.

    A fresh copy of ``net_init`` is BN-adapted with ``tta_bnadapt`` on the
    severity-1 speckle-noise training images. Entry ``(i, j)`` of the returned
    mask is 1 when some image of class ``i`` was predicted as class ``j``, or some
    image of class ``j`` was predicted as class ``i``. The diagonal is always 1.

    Relies on the module-level ``args`` (``seed``, ``num_classes``) and ``net_init``.

    Returns:
        torch.Tensor: float tensor of shape ``(args.num_classes, args.num_classes)``.

    Raises:
        ValueError: if a ground-truth label lies outside ``[0, args.num_classes)``,
            e.g. when CIFAR-100 images are paired with a 10-class model.
    """
    def load_corrupted_cifar100_train_dataset():
        """Load the severity-1 slice of the speckle-noise CIFAR-100-C training arrays."""
        severity = 1
        n_total_cifar = 50000

        images_all = np.load("path of speckle noised CIFAR-100C training image numpy file")
        labels_all = np.load("path of speckle noised CIFAR-100C training label numpy file")

        # Presumably the arrays stack all severities back to back, n_total_cifar
        # images each -- TODO confirm against the data files.
        start, stop = (severity - 1) * n_total_cifar, severity * n_total_cifar
        images = images_all[start:stop]
        labels = labels_all[start:stop]
        # channels-last -> channels-first (PyTorch layout), scaled by 1/255
        x_train = torch.tensor(np.transpose(images, (0, 3, 1, 2)).astype(np.float32) / 255)
        y_train = torch.tensor(labels).long()
        return x_train, y_train

    x_train, y_train = load_corrupted_cifar100_train_dataset()
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    net_adapt = configure_model_bn(copy.deepcopy(net_init))
    _, logits_tr, labels_tr = tta_bnadapt(x_train, y_train, net_adapt, args)
    del net_adapt

    num_classes = args.num_classes
    y_true = torch.as_tensor(labels_tr).reshape(-1).long().cpu()
    y_pred = logits_tr.argmax(-1).reshape(-1).long().cpu()
    # Out-of-range labels used to shrink/misalign the sklearn confusion matrix and
    # silently yield a wrong mask; surface the mismatch instead.
    if y_true.numel() > 0 and (int(y_true.min()) < 0 or int(y_true.max()) >= num_classes):
        raise ValueError(
            f"labels must lie in [0, {num_classes}); got range "
            f"[{int(y_true.min())}, {int(y_true.max())}]"
        )

    # confused[i, j] is True iff some sample of class i was predicted as class j.
    # This is exactly the "row-normalised confusion entry > 0" test used before,
    # but always num_classes x num_classes and free of 0/0 for unseen classes.
    confused = torch.zeros(num_classes, num_classes, dtype=torch.bool)
    confused[y_true, y_pred] = True
    # Symmetrise and always keep the diagonal.
    mask = (confused | confused.T).float()
    mask.fill_diagonal_(1.0)
    print ("line175", mask.float().mean())
    return mask

# Module-level alias of the CLI batch size.
batch_size = args.batch_size

# Build an adaptable copy of the pre-trained network according to --ft_layers.
# NOTE(review): the evaluation loop below creates its own net_adapt for every
# method, so this copy appears unused -- confirm before removing it.
net_adapt = copy.deepcopy(net_init)
if args.ft_layers not in ('tent', 'bn', 'noadapt'):
    # Unrecognised --ft_layers: drop the unconfigured copy again.
    del net_adapt
elif args.ft_layers == 'noadapt':
    net_adapt = configure_model_noadapt(net_adapt)
    print ('line131')
elif args.ft_layers == 'bn':
    net_adapt = configure_model_bn(net_adapt)
else:
    net_adapt = configure_model_tent(net_adapt)


# TTA methods to evaluate: the no-adaptation reference first, then for every
# baseline its plain form followed by its '_ours', '_ours_diag' and
# '_ours_online' variants (in that order).
baselines = ['bnadapt','delta','tent', 'pseudolabel', 'note', 'lame',  'odsN']
tta_methods = ['noadapt'] + [
    method + variant
    for method in baselines
    for variant in ('', '_ours', '_ours_diag', '_ours_online')
]
print ('tta_methods are' , tta_methods)

def _reset_seeds(seed):
    """Re-seed Python, NumPy and torch (CPU + all GPUs) RNGs and force deterministic cuDNN."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _format_accs(accs):
    """Format one accuracy, or a list of them, with four decimals for printing."""
    if isinstance(accs, list):
        return ["%.4f" % item for item in accs]
    return "%.4f" % accs


def _run_bnadapt(x, y, net, run_args, T, g):
    """Adapter: tta_bnadapt also returns logits and labels; keep only the accuracy."""
    accs, _, _ = tta_bnadapt(x, y, net, run_args, T, g)
    return accs


def _select_tta(tta_method):
    """Return ``(configure_fn, run_fn)`` for a TTA method name, matched by substring.

    The substring checks run in the original priority order.

    Raises:
        ValueError: if no known baseline name occurs in ``tta_method``
            (previously a stale ``accs`` from the prior method was silently reused).
    """
    if 'bnadapt' in tta_method:
        return configure_model_bn, _run_bnadapt
    if 'tent' in tta_method:
        return configure_model_tent, tta_tent
    if 'pseudolabel' in tta_method:
        return configure_model_bn, tta_pseudolabel
    if 'lame' in tta_method:
        return configure_model_bn, tta_lame
    if 'delta' in tta_method:
        return configure_model_bn, tta_delta
    if 'note' in tta_method:
        return configure_model_bn, tta_note
    if 'odsN' in tta_method:
        return configure_model_bn, tta_ods_note
    raise ValueError(f"unknown TTA method {tta_method!r}")


# The mask depends only on the corrupted *training* data and net_init, not on the
# test dataset, so compute it once instead of re-running BN adaptation over the whole
# training set for every test dataset. compute_mask seeds itself, and every stage
# below re-seeds first, so hoisting it leaves all RNG-dependent results unchanged.
mask = compute_mask()

# Loop through the test datasets
for i, test_dataset in enumerate(test_datasets):
    x_test, y_test = x_tests[i], y_tests[i]

    # construct long-tailed test datasets
    _reset_seeds(args.seed)
    x_test_bal, y_test_bal = x_test, y_test
    x_test, y_test, classwise_num_ex = make_LT_datasets(x_test_bal, y_test_bal, args.rho, args.num_classes)

    # One result file per (benchmark, test dataset, rho, seed). Previously a single
    # hard-coded ".../cifar100/tslog" was shared by every test dataset, so every
    # dataset after the first reported the first dataset's cached accuracies.
    # Presumably test_dataset is a short name string (it is also printed below).
    save_dir = os.path.join("./eval_results/tta", args.benchmark)
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f"tslog_{test_dataset}_rho{args.rho}_seed{args.seed}")

    if os.path.exists(save_path):
        # NOTE(review): torch>=2.6 defaults torch.load(weights_only=True), which
        # refuses to unpickle the stored argparse.Namespace; pass weights_only=False
        # here if that bites on your torch version.
        accs_storage = torch.load(save_path)
        print ('acc storage is loaded')
    else:
        accs_storage = {'args': args}
        print ('acc storage is initialized')

    # Calculate accuracy for baseline1 (noadapt)
    _reset_seeds(args.seed)
    acc_noadapt, _, _ = tta_noadapt(x_test ,y_test, net_init, args.batch_size, num_classes)
    accs_storage["noadapt"] = acc_noadapt

    # Calculate accuracy for baseline2 (bnadapt); its logits also feed g_phi below.
    _reset_seeds(args.seed)
    net_adapt = configure_model_bn(copy.deepcopy(net_init))
    acc_bnadapt, logits_te, labels_te = tta_bnadapt(x_test ,y_test, net_adapt, args)
    accs_storage["bnadapt"] = acc_bnadapt
    del net_adapt

    # Calculate the T_te tensor using g_phi on the mean softmax output, restricted
    # by the confusion mask. `device` (not a hard-coded .cuda()) keeps CPU runs working.
    with torch.no_grad():
        T_te = g_phi(torch.softmax(logits_te, -1).mean(0).to(device)).cpu()
        T_te = T_te * mask

    # Loop through different TTA methods
    for tta_method in tta_methods:
        if tta_method in accs_storage:
            print (test_dataset, tta_method, _format_accs(accs_storage[tta_method]), 'pre-cal')
            continue

        _reset_seeds(args.seed)

        # '_ours' variants: plain suffix -> full T, '_diag' -> diagonal of T,
        # '_online' -> no fixed T, g_phi is passed instead.
        suffix = tta_method.split('_')[-1]
        if 'ours' not in tta_method:
            T, g = None, None
        elif suffix == 'online':
            T, g = None, g_phi
        elif suffix == 'diag':
            T, g = torch.diag(torch.diag(T_te.clone())), None
        else:
            T, g = T_te.clone(), None

        configure_fn, tta_fn = _select_tta(tta_method)
        net_adapt = configure_fn(copy.deepcopy(net_init))
        accs = tta_fn(x_test, y_test, net_adapt, args, T, g)
        del net_adapt

        # Store the result and checkpoint after every method so a crash loses little.
        accs_storage[tta_method] = accs
        print (test_dataset, tta_method, _format_accs(accs))
        torch.save(accs_storage, save_path)

    # Save again so refreshed noadapt/bnadapt results persist even if every method was cached.
    torch.save(accs_storage, save_path)
    print (test_dataset, accs_storage)