import torch
import numpy as np
import copy
import random
import argparse
import os
from datetime import datetime

from utils_baseline import *
from utils_intermediate import *
from utils.misc_pacs import *
from utils.misc import *
from utils.models_pacs import *


######################################################################################################################################
# Timestamp taken once at start-up, formatted as yymmdd_HHMMSS.
current_time = datetime.now().strftime("%y%m%d_%H%M%S")
parser = argparse.ArgumentParser()

# --- Benchmark selection --- #
parser.add_argument('--benchmark', type=str, default='pacs')

# --- Training details for TTA --- #
parser.add_argument('--batch_size', type=int, default=200)
parser.add_argument('--num_epochs', type=int, default=100)
parser.add_argument('--pl_threshold', type=float, default=0.9)
# `help` must be a string: argparse %-formats it when rendering the usage
# message, so passing a list breaks `-h`.
parser.add_argument('--optim_type', type=str, default='adam',
                    help="optimizer type: 'sgd' or 'adam'")
parser.add_argument('--ft_layers', type=str, default='bn',
                    help="layers to adapt: 'bn' or 'tent' "
                         "('whole' is not handled by this script)")

# --- Reproducibility --- #
parser.add_argument('--seed', type=int, default=1)

# Parse the command line and echo the resolved configuration for the run log.
args = parser.parse_args()
print(args)
######################################################################################################################################
# Seed every random number generator the script relies on so that runs with
# the same --seed are repeatable.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# Explicit CUDA seeding (this slot previously held a duplicated
# random.seed call). Ignored when CUDA is unavailable.
torch.cuda.manual_seed_all(args.seed)
# Ask cuDNN for deterministic algorithms and turn off its benchmark mode.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# PACS configuration: the four domains and the number of classes. Both are
# mirrored onto `args` so they travel with the rest of the configuration.
domains = ["art", "cartoon", "photo", "sketch"]
num_classes = 7
batch_size = args.batch_size
args.num_classes, args.domains = num_classes, domains

# Load the train and test splits. The test loader returns a third value that
# this script does not use.
x_trains, y_trains = load_pacs_train_datasets()
x_tests, y_tests, _ = load_pacs_test_datasets()
print('size of the test datasets is ', [split.size() for split in x_tests])


# For every source domain: load the pre-trained classifier, configure it for
# adaptation, then train the MLP g_phi (netW) on a balanced, Dirichlet-ordered
# version of that domain's training data.
for i, src_domain in enumerate(domains):
    # NOTE(review): placeholder path -- replace with the location of an already
    # trained g_phi. The domain is skipped when this path exists.
    save_path = "load for trained_gphi directory"
    if os.path.exists(save_path):
        continue

    # Load the pre-trained source model.
    # NOTE(review): placeholder checkpoint path -- replace before running.
    ckpt_path = "path/pre-trained model of source domain"
    depth = int("resnet50".replace("resnet", ""))
    model = resnet(
        'pacs',
        depth,
        split_point=None,
        group_norm_num_groups=None,
        grad_checkpoint=False,
    )
    ckpt = torch.load(ckpt_path)
    model.load_state_dict(ckpt["model"])
    net_init = model.cuda()

    # Configure a copy of the model according to the selected ft_layers.
    if args.ft_layers == 'bn':
        net_adapt = configure_model_bn(copy.deepcopy(net_init))
    elif args.ft_layers == 'tent':
        net_adapt = configure_model_tent(copy.deepcopy(net_init))
    else:
        # Fail with a clear message instead of hitting a NameError on
        # `net_adapt` further down.
        raise ValueError(
            "unsupported --ft_layers value %r; expected 'bn' or 'tent'"
            % (args.ft_layers,)
        )

    # Make a balanced intermediate dataset from this domain's training split.
    x_train, y_train = x_trains[i], y_trains[i]
    x_train, y_train = make_balanced_intermediate_dataset(x_train, y_train, num_classes)

    # Temporary test dataset used only for the training log: the first test
    # set when the source is the last domain (i == 3), otherwise the last one.
    if i == 3:
        x_test, y_test = x_tests[0], y_tests[0]
    else:
        x_test, y_test = x_tests[-1], y_tests[-1]
    acc_bnadapt_te, logits_te, labels_te = tta_bnadapt(x_test, y_test, net_adapt, args)

    # Create an MLP (Multi-Layer Perceptron) model for g_\phi.
    hiddendim = 1000
    netW = MLP(num_classes, hiddendim=hiddendim).cuda()
    loss_fn = torch.nn.CrossEntropyLoss().cuda()

    # Only netW is optimized; the classifier is run under no_grad below.
    params = netW.parameters()
    optimizer = setup_optimizer(params, args.optim_type)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_epochs)

    # Output locations for the trained g_phi weights and the training log.
    save_path = "./eval_results/pacs/trained_src%s.pt"%src_domain
    storage_save_path = "./eval_results/pacs/trlog_src%s"%src_domain
    # Create the output directory up front so the first checkpoint cannot
    # fail on a missing folder after epochs of training.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    storage = {key: [] for key in ('loss_pre', 'loss_post', 'acc_pre', 'acc_post')}

    # Loop-invariant sizes. n_batches is the ceiling of n_train / batch_size,
    # so the final batch may be smaller than batch_size.
    n_train = x_train.size(0)
    dirichlet_numchunks = n_train // batch_size
    n_batches = (n_train + batch_size - 1) // batch_size

    for epoch in range(args.num_epochs):
        loss_avg = 0
        loss_compare_avg = 0

        # Generate a new sample ordering for this epoch.
        new_indices = dirichlet_indices(x_train, y_train, net_adapt, num_classes,
                                        dirichlet_numchunks=dirichlet_numchunks, non_iid_ness=1.)
        net_adapt.train()

        # Reorder once per epoch rather than re-gathering the whole dataset
        # for every batch.
        x_epoch = x_train[new_indices]
        y_epoch = y_train[new_indices]

        for counter in range(n_batches):
            start = counter * batch_size
            stop = start + batch_size
            x_curr = x_epoch[start:stop].cuda()
            y_curr = y_epoch[start:stop].cuda()

            # Classifier forward pass without gradients; only netW is trained.
            with torch.no_grad():
                outputs, _ = net_adapt(x_curr, True)

            # g_phi maps the batch-averaged softmax to T, which is applied to
            # the outputs by matrix multiplication before the loss.
            T = netW(torch.softmax(outputs, dim=-1).mean(0))
            loss = loss_fn(outputs @ T, y_curr)

            # Loss on the unmodified outputs, for comparison only.
            with torch.no_grad():
                loss_compare = loss_fn(outputs, y_curr)

            # Backpropagation and optimization.
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Accumulate per-batch losses.
            loss_avg += loss.item()
            loss_compare_avg += loss_compare.item()

        # Evaluate the current g_phi once on the cached test logits and reuse
        # the result for both the printout and the log.
        with torch.no_grad():
            T_curr = netW(torch.softmax(logits_te, 1).mean(0).cuda()).cpu()
            acc_post = ((logits_te @ T_curr).argmax(1) == labels_te).float().mean()
        loss_pre = loss_compare_avg / n_batches
        loss_post = loss_avg / n_batches

        print ("|epoch\t|%d\t|loss_old\t|%.4f\t|loss_new\t|%.4f\t|acc_old\t|%.4f\t|acc_new\t|%.4f\t|" % (
            epoch, loss_pre, loss_post,
            acc_bnadapt_te, acc_post
        ))

        # Record loss and accuracy metrics for plotting or analysis.
        storage['loss_pre'].append(loss_pre)
        storage['loss_post'].append(loss_post)
        storage['acc_pre'].append(acc_bnadapt_te)
        storage['acc_post'].append(acc_post)

        # Save the model and storage dictionary periodically (every 10th
        # epoch and after the final epoch).
        if epoch % 10 == 9 or epoch == args.num_epochs - 1:
            torch.save(netW.state_dict(), save_path)
            torch.save(storage, storage_save_path)
            print ("model and storage are saved")

        # Adjust the learning rate based on the scheduler.
        scheduler.step()