import torch
import numpy as np
import copy
import random 
import argparse
import os
from datetime import datetime

from utils_baseline import *
from utils_intermediate import *
from utils.misc_pacs import *
from utils.misc import *
from utils.models_pacs import *


######################################################################################################################################
current_time = datetime.now().strftime("%y%m%d_%H%M%S")
parser = argparse.ArgumentParser()

# --- benchmark --- #
parser.add_argument('--benchmark', type=str, default='pacs')

# --- Training details for TTA --- #
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--num_epochs', type=int, default=1)
parser.add_argument('--pl_threshold', type=float, default=0.9)
# argparse requires `help` to be a string (it is %-formatted when printing --help);
# the accepted values are enforced with `choices` instead of being listed in `help`.
parser.add_argument('--optim_type', type=str, default='adam', choices=['sgd', 'adam'],
                    help="one of 'sgd', 'adam'")
parser.add_argument('--ft_layers', type=str, default='bn', choices=['bn', 'tent', 'whole'],
                    help="one of 'bn', 'tent', 'whole'")

# --- reproducibility --- #
parser.add_argument('--seed', type=int, default=1)

args = parser.parse_args()
print(args)

######################################################################################################################################
# Seed every RNG source (Python, NumPy, torch CPU and all CUDA devices) with the
# same value, then make cuDNN choose deterministic kernels.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(args.seed)
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False
######################################################################################################################################

# Load the test sets; x_tests / y_tests are indexed in the same order as `domains`
# below (the third returned value is not used here).
x_tests, y_tests, _ = load_pacs_test_datasets()
num_classes = 7
domains = ["art", "cartoon", "photo", "sketch"]
# Expose the class count and domain names to the TTA routines through `args`.
args.num_classes, args.domains = num_classes, domains
print('size of the test datasets is ', [batch.size() for batch in x_tests])


# Methods to evaluate: 'noadapt' first, then for every baseline the plain method
# followed by its '_ours', '_ours_diag' and '_ours_online' variants, in that order.
baselines = ['bnadapt', 'delta', 'tent', 'pseudolabel', 'note', 'lame', 'odsN']
_variant_suffixes = ('', '_ours', '_ours_diag', '_ours_online')
tta_methods = ['noadapt'] + [base + suffix for base in baselines for suffix in _variant_suffixes]
print('tta_methods are', tta_methods)


def _reset_seeds(seed):
    """Re-seed Python, NumPy and torch (CPU + all GPUs) and force deterministic cuDNN.

    Called before every evaluation so that each method starts from the same RNG state.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _format_accs(accs):
    """Format an accuracy value, or a list of them, with four decimals for printing."""
    if isinstance(accs, list):
        return ["%.4f" % item for item in accs]
    return "%.4f" % accs


def _select_correction(tta_method, T_te, g_phi):
    """Return the ``(T, g)`` pair handed to the TTA routine for ``tta_method``.

    - names without ``ours``: ``(None, None)``;
    - ``*_online``: no fixed matrix, the ``g_phi`` network itself is passed as ``g``;
    - ``*_diag``: a diagonal matrix built from the diagonal of ``T_te``;
    - any other ``*_ours*`` name: a copy of the full ``T_te``.
    """
    if 'ours' not in tta_method:
        return None, None
    suffix = tta_method.split('_')[-1]
    if suffix == 'online':
        return None, g_phi
    if suffix == 'diag':
        return torch.diag(torch.diag(T_te.clone())), None
    return T_te.clone(), None


def _run_tta(base_method, x_test, y_test, net_init, args, T, g):
    """Adapt a fresh deep copy of ``net_init`` with ``base_method`` and return its accuracies.

    The adapted copy only lives inside this call, so it is released on return.

    Raises:
        ValueError: if ``base_method`` is not a known TTA method.
    """
    runners = {
        # tta_bnadapt also returns logits and labels; only the accuracies are kept
        'bnadapt': lambda net: tta_bnadapt(x_test, y_test, net, args, T, g)[0],
        'tent': lambda net: tta_tent(x_test, y_test, net, args, T, g),
        'pseudolabel': lambda net: tta_pseudolabel(x_test, y_test, net, args, T, g),
        'lame': lambda net: tta_lame(x_test, y_test, net, args, T, g),
        'delta': lambda net: tta_delta(x_test, y_test, net, args, T, g),
        'note': lambda net: tta_note(x_test, y_test, net, args, T, g),
        'odsN': lambda net: tta_ods_note(x_test, y_test, net, args, T, g),
    }
    if base_method not in runners:
        raise ValueError("unknown TTA method: %r" % base_method)
    # tent is prepared with configure_model_tent; every other method with configure_model_bn
    configure = configure_model_tent if base_method == 'tent' else configure_model_bn
    net_adapt = configure(copy.deepcopy(net_init))
    return runners[base_method](net_adapt)


for i, src_domain in enumerate(domains):
    # load pre-trained model
    ckpt_path = '../trained_models/pacs/resnet50_bn_ssh_%s.pth' % src_domain
    depth = int("resnet50".replace("resnet", ""))
    model = resnet(
        'pacs',
        depth,
        split_point=None,
        group_norm_num_groups=None,
        grad_checkpoint=False,
    )
    ckpt = torch.load(ckpt_path)
    model.load_state_dict(ckpt["model"])
    net_init = model.cuda()

    # load pre-trained g_phi. load_state_dict needs a mapping, not a path, so the
    # file is deserialized first (assumes it was saved as a state_dict — confirm).
    num_hidden = 1000
    g_phi = MLP(num_classes, num_hidden)
    g_phi.load_state_dict(torch.load("./eval_results/pacs/trained_gphi_%s.pt" % src_domain))
    g_phi.cuda()

    for j, tgt_domain in enumerate(domains):
        if i == j:
            continue
        print(src_domain, tgt_domain)
        x_test, y_test = x_tests[j], y_tests[j]

        # One results file per (source, target) pair: stored keys are method names
        # only, so a shared file would make every later pair skip its evaluation and
        # report the first pair's accuracies as "pre-cal".
        save_path = "./eval_results/tta/pacs/tslog_%s_%s" % (src_domain, tgt_domain)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        if not os.path.exists(save_path):
            accs_storage = {'args': args}
            print('acc storage is initialized')
        else:
            accs_storage = torch.load(save_path)
            print('acc storage is loaded')

        # baseline1 // noadapt
        _reset_seeds(args.seed)
        acc_noadapt, _, _ = tta_noadapt(x_test, y_test, net_init, args.batch_size, num_classes)
        accs_storage["noadapt"] = acc_noadapt

        # baseline2 // bnadapt; its logits are also fed to g_phi below to build T_te
        _reset_seeds(args.seed)
        net_adapt = configure_model_bn(copy.deepcopy(net_init))
        acc_bnadapt, logits_te, labels_te = tta_bnadapt(x_test, y_test, net_adapt, args)
        accs_storage["bnadapt"] = acc_bnadapt
        del net_adapt

        with torch.no_grad():
            # g_phi takes the mean softmax prediction over the target set
            T_te = g_phi(torch.softmax(logits_te, -1).mean(0).cuda()).cpu()
        accs_storage["bnadapt_ours"] = ((logits_te @ T_te).argmax(-1) == labels_te).float().mean()

        for tta_method in tta_methods:
            if tta_method in accs_storage:
                print(src_domain, tgt_domain, tta_method,
                      _format_accs(accs_storage[tta_method]), 'pre-cal')
                continue

            _reset_seeds(args.seed)
            T, g = _select_correction(tta_method, T_te, g_phi)
            # method names are '<base>' or '<base>_ours[_diag|_online]'
            accs = _run_tta(tta_method.split('_')[0], x_test, y_test, net_init, args, T, g)

            accs_storage[tta_method] = accs
            print(src_domain, tgt_domain, tta_method, _format_accs(accs))

            # checkpoint after every method so an interrupted run can resume
            torch.save(accs_storage, save_path)
        torch.save(accs_storage, save_path)
