import torch
import numpy as np
import copy
import random 
import argparse
import os
from datetime import datetime

# Import custom utility functions and models
from utils_baseline import *
from utils_intermediate import *
from utils.misc import *
from utils.misc_cifar import *
from utils.models import *


######################################################################################################################################
# Timestamp taken once at start-up (yymmdd_HHMMSS).
current_time = datetime.now().strftime("%y%m%d_%H%M%S")
parser = argparse.ArgumentParser()

# NOTE: argparse renders help text with `help % params`, so `help` must be a
# string. Passing a list makes `--help` fail with a TypeError.

# --- Benchmark --- #
# Only 'cifar10' and 'cifar100' have a model/dataset branch in this script.
parser.add_argument('--benchmark', type=str, default='cifar10',
                    help="benchmark name: cifar10 | cifar100 | digit")

# --- Long-tailed test set --- #
# rho is forwarded to make_LT_datasets when the long-tailed test set is built.
parser.add_argument('--rho', type=float, default=0.01,
                    help="rho used to build the long-tailed test set, e.g. 0.01, 0.1, 1.0")

# --- Training details for TTA --- #
parser.add_argument('--batch_size', type=int, default=200)
parser.add_argument('--num_epochs', type=int, default=1)
parser.add_argument('--pl_threshold', type=float, default=0.9)
parser.add_argument('--optim_type', type=str, default='adam',
                    help="optimizer type: sgd | adam")
parser.add_argument('--ft_layers', type=str, default='bn',
                    help="layers to fine-tune: bn | tent | whole")

# --- Reproducibility --- #
parser.add_argument('--seed', type=int, default=1)

args = parser.parse_args()
print(args)

######################################################################################################################################
# Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) generators with
# the same user-supplied seed so that repeated runs start from the same state.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(args.seed)

# Turn off cuDNN autotuning and request deterministic cuDNN kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

#########################################################################################
# Resolve the class count from the benchmark name before anything reads it.
# It is needed here for `args.num_classes` and for building g_phi, so it must
# not depend on an assignment that only happens further down the script.
# Failing fast on an unknown benchmark also avoids continuing without a model.
num_classes_by_benchmark = {'cifar10': 10, 'cifar100': 100}
if args.benchmark not in num_classes_by_benchmark:
    raise ValueError('there is no pre-trained model for the benchmark: %r' % (args.benchmark,))
num_classes = num_classes_by_benchmark[args.benchmark]
args.num_classes = num_classes

# Load a pre-trained model based on the benchmark
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if args.benchmark == 'cifar10':
    ckpt_path = "path of trained classifier"

    net_init = Normalized_ResNet(depth=26)
    # map_location keeps the load working on CPU-only machines even when the
    # checkpoint was saved from a CUDA device.
    checkpoint = torch.load(ckpt_path, map_location=device)
    checkpoint = checkpoint['net']

    net_init.to(device)
    # The state dict is loaded into the DataParallel wrapper
    # (presumably the checkpoint keys carry the 'module.' prefix -- TODO confirm).
    net_init = torch.nn.DataParallel(net_init)
    net_init.load_state_dict(checkpoint)

elif args.benchmark == 'cifar100':
    ckpt_path = "path of trained classifier"
    net_init = Normalized_ResNet_CIFAR100()
    net_init = torch.nn.DataParallel(net_init)

    checkpoint = torch.load(ckpt_path, map_location=device)
    net_init.load_state_dict(checkpoint["net"])

    net_init.to(device)

# load trained g_phi (stored per benchmark under ./eval_results/<benchmark>/)
num_hidden = 1000
g_phi = MLP(num_classes, num_hidden)
g_phi.load_state_dict(
    torch.load("./eval_results/%s/trained_gphi.pt" % args.benchmark, map_location=device)
)
g_phi.to(device)


# Load the test datasets for the selected benchmark and record its class count.
if args.benchmark == 'cifar10':
    x_tests, y_tests, test_datasets = load_cifar10_test_datasets()
    num_classes = 10
elif args.benchmark == 'cifar100':
    x_tests, y_tests, test_datasets = load_cifar100_test_datasets()
    num_classes = 100
else:
    # Stop here: without test data the evaluation loop below cannot run, and
    # continuing would only surface later as an unrelated NameError.
    raise ValueError('there is no benchmark: %r' % (args.benchmark,))

# Short module-level alias for the configured batch size.
batch_size = args.batch_size

# Build an independent copy of the pre-trained network and configure it
# according to --ft_layers. Modes without a configuration function simply
# drop the copy again.
net_adapt = copy.deepcopy(net_init)
if args.ft_layers not in ('tent', 'bn', 'noadapt'):
    del net_adapt
elif args.ft_layers == 'noadapt':
    net_adapt = configure_model_noadapt(net_adapt)
elif args.ft_layers == 'bn':
    net_adapt = configure_model_bn(net_adapt)
else:
    net_adapt = configure_model_tent(net_adapt)

# TTA methods to evaluate: the un-adapted model first, then every baseline in
# four flavours (plain, '_ours', '_ours_diag', '_ours_online').
baselines = ['bnadapt', 'delta', 'tent', 'pseudolabel', 'note', 'lame', 'odsN']
tta_methods = ['noadapt']
for baseline in baselines:
    tta_methods.extend(
        baseline + suffix for suffix in ('', '_ours', '_ours_diag', '_ours_online')
    )
print('tta_methods are', tta_methods)

def _seed_everything(seed):
    """Reset the Python, NumPy and PyTorch RNGs and re-apply the cuDNN flags."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _format_accs(accs):
    """Render one accuracy, or a list of accuracies, with four decimals."""
    if isinstance(accs, list):
        return ["%.4f" % item for item in accs]
    return "%.4f" % accs


# Loop through the test datasets
for i, test_dataset in enumerate(test_datasets):
    x_test, y_test = x_tests[i], y_tests[i]

    _seed_everything(args.seed)

    # construct long-tailed test datasets
    x_test_bal, y_test_bal = x_test, y_test
    x_test, y_test, classwise_num_ex = make_LT_datasets(x_test_bal, y_test_bal, args.rho, args.num_classes)

    # Accuracy cache, stored per benchmark so that a cifar100 run neither reads
    # nor overwrites cifar10 results.
    save_path = "./eval_results/tta/%s/tslog" % args.benchmark
    # Create the directory up front so the first torch.save cannot fail after
    # an expensive evaluation has already finished.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    # NOTE(review): the cache is keyed by method name only. Once it has been
    # filled for one test dataset, every later dataset finds its methods
    # already present and reports them as 'pre-cal' -- confirm this is intended.
    if not os.path.exists(save_path):
        accs_storage = {}
        print('acc storage is initialized')
    else:
        accs_storage = torch.load(save_path)
        print('acc storage is loaded')

    _seed_everything(args.seed)

    # Calculate accuracy for baseline1 (noadapt)
    acc_noadapt, _, _ = tta_noadapt(x_test, y_test, net_init, args.batch_size, num_classes)
    accs_storage["noadapt"] = acc_noadapt

    _seed_everything(args.seed)

    # Calculate accuracy for baseline2 (bnadapt); its logits feed g_phi below.
    net_adapt = copy.deepcopy(net_init)
    net_adapt = configure_model_bn(net_adapt)
    acc_bnadapt, logits_te, labels_te = tta_bnadapt(x_test, y_test, net_adapt, args)
    accs_storage["bnadapt"] = acc_bnadapt
    del net_adapt

    # Calculate the T_te tensor by applying g_phi to the mean softmax output.
    # Use the selected device rather than a hard-coded .cuda() so the script
    # also runs on CPU-only machines.
    with torch.no_grad():
        T_te = g_phi(torch.softmax(logits_te, -1).mean(0).to(device)).cpu()

    # Loop through different TTA methods
    for tta_method in tta_methods:
        # Skip methods whose result is already in the cache.
        if tta_method in accs_storage:
            print(test_dataset, tta_method, _format_accs(accs_storage[tta_method]), 'pre-cal')
            continue

        _seed_everything(args.seed)

        # The last '_'-separated token selects the variant.
        variant = tta_method.split('_')[-1]
        is_online = variant == 'online'
        is_diag = variant == 'diag'

        # Determine T and g based on the TTA method
        if 'ours' in tta_method:
            if is_online:
                # online variant: pass g_phi itself and no fixed T
                T, g = None, g_phi
            elif is_diag:
                # diagonal variant: keep only the diagonal of T_te
                T, g = torch.diag(torch.diag(T_te.clone())), None
            else:
                T, g = T_te.clone(), None
        else:
            T, g = None, None

        # Every method adapts its own fresh copy of the pre-trained network.
        net_adapt = copy.deepcopy(net_init)
        if 'bnadapt' in tta_method:
            net_adapt = configure_model_bn(net_adapt)
            accs, _, _ = tta_bnadapt(x_test, y_test, net_adapt, args, T, g)
        elif 'tent' in tta_method:
            net_adapt = configure_model_tent(net_adapt)
            accs = tta_tent(x_test, y_test, net_adapt, args, T, g)
        elif 'pseudolabel' in tta_method:
            net_adapt = configure_model_bn(net_adapt)
            accs = tta_pseudolabel(x_test, y_test, net_adapt, args, T, g)
        elif 'lame' in tta_method:
            net_adapt = configure_model_bn(net_adapt)
            accs = tta_lame(x_test, y_test, net_adapt, args, T, g)
        elif 'delta' in tta_method:
            net_adapt = configure_model_bn(net_adapt)
            accs = tta_delta(x_test, y_test, net_adapt, args, T, g)
        elif 'note' in tta_method:
            net_adapt = configure_model_bn(net_adapt)
            accs = tta_note(x_test, y_test, net_adapt, args, T, g)
        elif 'odsN' in tta_method:
            net_adapt = configure_model_bn(net_adapt)
            accs = tta_ods_note(x_test, y_test, net_adapt, args, T, g)
        else:
            # Without this guard an unmatched name would silently store the
            # previous method's accuracy under the wrong key.
            raise ValueError('unknown TTA method: %r' % (tta_method,))
        # Release the adapted copy for every method, not just some of them.
        del net_adapt

        # Store the calculated accuracy in the storage dictionary
        accs_storage[tta_method] = accs
        print(test_dataset, tta_method, _format_accs(accs))

        # Save after each method so partial progress survives an interruption.
        torch.save(accs_storage, save_path)

    # Final save: also persists the refreshed noadapt/bnadapt entries when
    # every other method was served from the cache.
    torch.save(accs_storage, save_path)
    print(test_dataset, accs_storage)