import torch
import numpy as np
import copy
import random 
import argparse
import os
from datetime import datetime

# Import custom utility functions and models

from utils.misc import *
from utils.misc_cifar import *
from utils.models import *

from misc import *
from utils_baseline_dart import *

######################################################################################################################################
current_time = datetime.now().strftime("%y%m%d_%H%M%S")
parser = argparse.ArgumentParser()

# --- Benchmark --- #
# argparse expects `help` to be a string; valid values go in `choices`, which
# also rejects unsupported benchmarks at startup instead of failing later.
parser.add_argument('--benchmark', type=str, default='cifar10', choices=['cifar10', 'cifar100'],
                    help="source benchmark to evaluate")

# --- Training details for TTA --- #
parser.add_argument('--test_batch_size', type=int, default=200)
parser.add_argument('--optim_type', type=str, default='adam',
                    help="optimizer used during TTA, e.g. 'sgd' or 'adam'")
parser.add_argument('--ft_layers', type=str, default='bn',
                    help="layer selection for the adapted model: 'bn', 'tent', 'whole' or 'noadapt'")
parser.add_argument('--tta_method', type=str, default='bnadapt',
                    choices=["noadapt", "bnadapt", "tent", "pl", "ods", "sar", "lame", "delta", "note"],
                    help="test-time adaptation method to evaluate")

# --- Misc --- #
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--smax_temp', type=float, default=1)

args = parser.parse_args()

# Hidden width handed to MLP_dart when g_phi is constructed.
args.num_hidden = 1000

print(args)

######################################################################################################################################
# Seed Python's, NumPy's and torch's RNGs (CPU and every CUDA device) with the
# same value, then pin cuDNN to deterministic kernels for reproducible runs.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(args.seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

#########################################################################################


# Load the pre-trained source classifier (net_init) and the trained g_phi
# network for the selected benchmark; also records args.device/num_classes.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
args.device = device
if args.benchmark == 'cifar10':
    ckpt_path = "path of trained classifier"
    net_init = Normalized_ResNet(depth=26)
    args.num_classes = 10
elif args.benchmark == 'cifar100':
    ckpt_path = "path of trained classifier"
    net_init = Normalized_ResNet_CIFAR100()
    args.num_classes = 100
else:
    # Everything after this needs net_init, g_phi and args.num_classes, so
    # stop here with a clear message instead of a NameError further down.
    raise ValueError("there is no pre-trained model for the benchmark %r" % args.benchmark)

# Wrap in DataParallel before loading: the checkpoint's state dict is loaded
# into the wrapped module (presumably 'module.'-prefixed keys - TODO confirm).
net_init = torch.nn.DataParallel(net_init.to(device))
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
checkpoint = torch.load(ckpt_path, map_location=device)
net_init.load_state_dict(checkpoint['net'])

# load trained g_phi
g_phi = MLP_dart(args.num_classes, args.num_hidden)
g_phi.load_state_dict(torch.load("./eval_results/%s/trained_gphi.pt" % args.benchmark,
                                 map_location=device))
g_phi.to(device)


# Load the clean test datasets for the selected benchmark.
# NOTE(review): x_tests / y_tests / test_datasets / num_classes are not
# referenced later in this script; the corruption loop reloads its own data.
if args.benchmark == 'cifar10':
    x_tests, y_tests, test_datasets = load_cifar10_test_datasets()
    num_classes = 10
elif args.benchmark == 'cifar100':
    x_tests, y_tests, test_datasets = load_cifar100_test_datasets()
    num_classes = 100
else:
    # Fail fast: printing and continuing only defers the crash.
    raise ValueError("there is no benchmark %r" % args.benchmark)

# The parser only defines --test_batch_size, so reading args.batch_size raised
# AttributeError. Use the defined option and mirror it onto args in case a
# downstream helper reads args.batch_size (TODO confirm whether any does).
batch_size = args.test_batch_size
args.batch_size = batch_size

# Build a separate copy of the source model and prepare it according to
# --ft_layers; net_init itself is left untouched.
net_adapt = copy.deepcopy(net_init)
ft_mode = args.ft_layers
if ft_mode == 'bn':
    net_adapt = configure_model_bn(net_adapt)
elif ft_mode == 'tent':
    net_adapt = configure_model_tent(net_adapt)
elif ft_mode == 'noadapt':
    net_adapt = configure_model_noadapt(net_adapt)
else:
    # NOTE(review): any other value (including 'whole') discards the copy.
    del net_adapt

# Imbalance ratios evaluated per benchmark, and the precomputed sample indices
# for each ratio (one .npy file per ratio under ./eval_results/idx/).
if args.benchmark == "cifar10":
    imb_ratios = [1, 20, 50, 5000]
elif args.benchmark == "cifar100":
    imb_ratios = [1, 200, 500, 50000]
else:
    raise ValueError("no imbalance ratios defined for benchmark %r" % args.benchmark)

# The template has exactly five placeholders: benchmark, seed, total sample
# count, imbalance ratio and the class-order-shuffle flag.
imb_indices = [
    np.load('./eval_results/idx/{}/imb/seed{}_total_{}_ir_{}_class_order_shuffle_{}.npy'.format(
        args.benchmark,
        args.seed,
        10000,
        imb_ratio,
        "yes",
    ))
    for imb_ratio in imb_ratios
]

# Corruption types evaluated one at a time by the loop below (each is passed
# to load_cifar10c / load_cifar100c as a single-element list).
corruptions = [
    'gaussian_noise',
    'shot_noise',
    'impulse_noise',
    'defocus_blur',
    'glass_blur',
    'motion_blur',
    'zoom_blur',
    'snow',
    'frost',
    'fog',
    'brightness',
    'contrast',
    'elastic_transform',
    'pixelate',
    'jpeg_compression',
]


# Runner function name for each --tta_method. Only the selected name is looked
# up in globals(), so a runner missing from the star imports fails only when
# that method is actually requested.
tta_runner_names = {
    "noadapt": "tta_noadapt",
    "bnadapt": "tta_bnadapt",
    "tent": "tta_tent",
    "sar": "tta_sar",
    "ods": "tta_ods",
    "delta": "tta_delta",
    "note": "tta_note",
    "lame": "tta_lame",
    "pl": "tta_pl",
}
if args.tta_method not in tta_runner_names:
    # Without this check an unknown method either hit a NameError or silently
    # reported the previous ratio's accuracies.
    raise ValueError("unknown --tta_method %r; expected one of %s"
                     % (args.tta_method, sorted(tta_runner_names)))
tta_fn = globals()[tta_runner_names[args.tta_method]]

# tslog/ is a directory holding one results file per (method, corruption,
# seed), so cached accuracy keys never collide across corruptions or methods.
# NOTE(review): other hyperparameters (optim_type, ft_layers, smax_temp,
# test_batch_size) are not part of the file name; clear the cache when
# changing them.
storage_dir = os.path.join("./eval_results/tta", args.benchmark, "tslog")
os.makedirs(storage_dir, exist_ok=True)

for corruption in corruptions:
    storage_path = os.path.join(
        storage_dir, "%s_%s_seed%d.pt" % (args.tta_method, corruption, args.seed))

    if args.benchmark == "cifar100":
        x_test, y_test = load_cifar100c(n_examples=10000, severity=5, data_dir="cifar100c path",
                                        shuffle=False, corruptions=[corruption])
    else:
        x_test, y_test = load_cifar10c(n_examples=10000, severity=5, data_dir="cifar10c path",
                                       shuffle=False, corruptions=[corruption])

    # Sort samples by class; the precomputed imbalance indices presumably
    # address this class-sorted order (TODO confirm against the index generator).
    x_mod = [x_test[y_test == c] for c in range(args.num_classes)]
    y_mod = [y_test[y_test == c] for c in range(args.num_classes)]
    x_test, y_test = torch.cat(x_mod, 0), torch.cat(y_mod, 0)

    # Reuse previously computed accuracies for this corruption if present.
    accs_storage = torch.load(storage_path) if os.path.isfile(storage_path) else {}

    for imb_ratio, imb_idx in zip(imb_ratios, imb_indices):
        naive_key = "naive_ir%d" % imb_ratio
        ours_key = "ours_ir%d" % imb_ratio

        if naive_key not in accs_storage or ours_key not in accs_storage:
            # Same runner twice: baseline (g_phi=None) vs. with the trained g_phi.
            acc_naive, _, _ = tta_fn(x_test, y_test, imb_idx, net_init, args, g_phi=None)
            acc_ours, _, _ = tta_fn(x_test, y_test, imb_idx, net_init, args, g_phi=g_phi)
            accs_storage[naive_key] = acc_naive
            accs_storage[ours_key] = acc_ours
            torch.save(accs_storage, storage_path)

        print("|%s\t|%s\t|ir%d\t|naive\t|%.4f\t|ours\t|%.4f\t|" % (
            args.tta_method,
            corruption,
            imb_ratio,
            accs_storage[naive_key],
            accs_storage[ours_key],
        ))
