import torch
import numpy as np
import copy
import random
import argparse
import os
from datetime import datetime
import torch.optim as optim

# Import custom utility functions and models
from utils_baseline import *
from utils_intermediate import *
from utils.misc import *
from utils.misc_digit import *
from utils.models import *



######################################################################################################################################
current_time = datetime.now().strftime("%y%m%d_%H%M%S")
parser = argparse.ArgumentParser()

# NOTE: argparse expects ``help`` to be a string. Passing a list makes
# ``--help`` raise while formatting, so accepted values are expressed through
# ``choices`` and/or spelled out in the help text instead.

# --- Benchmark selection --- #
# Any other value has no loader/model branch in this script, so reject it here.
parser.add_argument('--benchmark', type=str, default='digit',
                    choices=['cifar10', 'cifar100', 'digit'],
                    help='benchmark to run')

# --- Imbalance ratio (third argument of make_LT_datasets for the CIFAR test sets) --- #
parser.add_argument('--rho', type=float, default=0.01,
                    help='imbalance ratio, e.g. 0.01, 0.1 or 1.')

# --- Training details for TTA --- #
parser.add_argument('--batch_size', type=int, default=200)
parser.add_argument('--num_epochs', type=int, default=100)
# Not read directly in this script; presumably consumed by helpers that receive ``args``.
parser.add_argument('--pl_threshold', type=float, default=0.9)
# Only printed in the per-epoch log line of this script.
parser.add_argument('--optim_type', type=str, default='adam',
                    help='optimizer name, e.g. sgd or adam')
# These are the values the model-configuration branches of this script handle.
parser.add_argument('--ft_layers', type=str, default='bn',
                    choices=['bn', 'tent', 'noadapt'],
                    help='which layers of the adapted model are fine-tuned')

# --- Reproducibility --- #
parser.add_argument('--seed', type=int, default=1)

args = parser.parse_args()
print(args)
######################################################################################################################################
# Set random seeds for reproducibility (Python, NumPy and PyTorch RNGs).
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# Seed every CUDA device explicitly; this call is a no-op when CUDA is unavailable.
torch.cuda.manual_seed_all(args.seed)
# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

#########################################################################################
# Load the training dataset based on the selected benchmark and record the
# number of classes that goes with it.
if args.benchmark == 'cifar10':
    x_train, y_train = load_cifar10_train_dataset()
    num_classes = 10
elif args.benchmark == 'cifar100':
    x_train, y_train = load_cifar100_train_dataset()
    num_classes = 100
elif args.benchmark == 'digit':
    x_train, y_train = load_digit_train_dataset()
    num_classes = 10
else:
    # Fail here with a clear message instead of a NameError on x_train below.
    raise ValueError("unsupported benchmark: %r" % (args.benchmark,))

# Make a balanced intermediate dataset
x_train, y_train = make_balanced_intermediate_dataset(x_train, y_train, num_classes)
# Expose the class count on args so helpers that receive ``args`` can read it.
args.num_classes = num_classes

# Load a pre-trained model based on the selected benchmark.
# ``ckpt_path`` values are placeholders and must point at real checkpoint files.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if args.benchmark == 'cifar10':
    ckpt_path = "path/pre-trained model"
    net_init = Normalized_ResNet(depth=26)
    ckpt_key = 'net'   # weights are stored under checkpoint['net']
elif args.benchmark == 'cifar100':
    ckpt_path = "path/pre-trained model"
    net_init = Normalized_ResNet_CIFAR100()
    ckpt_key = 'net'   # weights are stored under checkpoint['net']
elif args.benchmark == 'digit':
    ckpt_path = "path/pre-trained model"
    net_init = ResNet18()
    ckpt_key = None    # the checkpoint itself is the state dict
else:
    # Fail here instead of continuing with net_init undefined.
    raise ValueError(
        "there is no pre-trained model for the benchmark: %r" % (args.benchmark,))

# map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
checkpoint = torch.load(ckpt_path, map_location=device)
if ckpt_key is not None:
    checkpoint = checkpoint[ckpt_key]

net_init.to(device)
# The state dict is loaded into the DataParallel wrapper, as in the original
# branches (presumably the saved keys carry the wrapper's prefix -- confirm).
net_init = torch.nn.DataParallel(net_init)
net_init.load_state_dict(checkpoint)

# Configure the model: work on a deep copy so net_init keeps the pre-trained
# state, then let the matching helper set up the copy for adaptation.
if args.ft_layers == 'bn':
    net_adapt = configure_model_bn(copy.deepcopy(net_init))
elif args.ft_layers == 'tent':
    net_adapt = configure_model_tent(copy.deepcopy(net_init))
elif args.ft_layers == 'noadapt':
    net_adapt = configure_model_noadapt(copy.deepcopy(net_init))
else:
    # Fail here instead of continuing with net_adapt undefined.
    raise ValueError("there is no ft layers option %r" % (args.ft_layers,))

# Temporary test datasets for training log: the logits/labels produced here
# are reused to report test accuracy after every training epoch.
if args.benchmark in ('cifar10', 'cifar100'):
    # The two CIFAR benchmarks differ only in the loader and the class count.
    load_test_fn = (load_cifar10_test_datasets if args.benchmark == 'cifar10'
                    else load_cifar100_test_datasets)
    x_tests, y_tests, test_datasets = load_test_fn()
    # Entry 1 of the returned collections is used as the balanced test set
    # (meaning of the index is defined by the loader -- TODO confirm).
    x_test_bal, y_test_bal = x_tests[1], y_tests[1]
    # Re-sample the balanced set with imbalance ratio args.rho.
    x_test, y_test, _ = make_LT_datasets(x_test_bal, y_test_bal, args.rho, num_classes)
    acc_bnadapt_te, logits_te, labels_te = tta_bnadapt(x_test, y_test, net_adapt, args)
    num_classes = 10 if args.benchmark == 'cifar10' else 100
elif args.benchmark == 'digit':
    x_tests, y_tests = load_digit_test_datasets()
    x_test, y_test = x_tests[1], y_tests[1]
    # With need_feats=True the helper returns a fourth value (features).
    acc_bnadapt_te, logits_te, labels_te, feats_te = tta_bnadapt(
        x_test, y_test, net_adapt, args, need_feats=True)
    num_classes = 10
else:
    # Fail here instead of continuing with logits_te/labels_te undefined.
    raise ValueError("there is no benchmark %r" % (args.benchmark,))

# Define paths for saving models and storage
# NOTE(review): both paths point at the "digit" folder regardless of
# --benchmark, so runs on different benchmarks share (and skip on) the same
# file -- confirm this is intended.
save_path = "./eval_results/digit/trained_gphi.pt"
storage_save_path = "./eval_results/digit/trlog"

batch_size = args.batch_size

# Rebuild net_adapt from the untouched pre-trained weights (presumably because
# the test-time pass above may have altered the previous copy -- confirm).
net_adapt = copy.deepcopy(net_init)
if args.ft_layers == 'tent':
    net_adapt = configure_model_tent(net_adapt)
elif args.ft_layers == 'bn':
    net_adapt = configure_model_bn(net_adapt)
elif args.ft_layers == 'noadapt':
    net_adapt = configure_model_noadapt(net_adapt)
else:
    # Previously net_adapt was deleted here, which only deferred the failure
    # to a NameError inside the training loop.
    raise ValueError("there is no ft layers option %r" % (args.ft_layers,))

# Train g_phi only when no trained checkpoint exists at save_path yet.
if not os.path.exists(save_path):
    hiddendim = 1000
    # Create an MLP (Multi-Layer Perceptron) model for g_\phi.
    # Use the device selected earlier instead of a hard-coded .cuda(), and
    # torch.nn explicitly rather than relying on a star-imported ``nn``.
    netW = MLP(num_classes, hiddendim=hiddendim).to(device)
    loss_fn = torch.nn.CrossEntropyLoss().to(device)

    # Per-epoch training log; saved next to the model.
    storage = {
        'loss_pre': [],
        'loss_post': [],
        'acc_pre': [],
        'acc_post': [],
    }

    # optimizer (only the parameters of g_phi are trained)
    optimizer = torch.optim.Adam(netW.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_epochs)

    # torch.save does not create missing directories.
    for out_path in (save_path, storage_save_path):
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    # Ceiling division without depending on a star-imported ``math``.
    n_batches = (x_train.size(0) + batch_size - 1) // batch_size

    # The averaged test prediction does not change across epochs; compute it once.
    with torch.no_grad():
        test_prior = torch.softmax(logits_te, 1).mean(0).to(device)

    for epoch in range(args.num_epochs):
        loss_avg = 0
        loss_compare_avg = 0

        # Generate new indices for a Dirichlet distribution-based sampling
        dirichlet_numchunks = x_train.size(0) // args.batch_size
        new_indices = dirichlet_indices(x_train, y_train, net_adapt, num_classes, dirichlet_numchunks=dirichlet_numchunks, non_iid_ness=10.)
        # Gather the re-ordered data once per epoch instead of once per batch.
        x_epoch = x_train[new_indices]
        y_epoch = y_train[new_indices]

        if args.ft_layers == 'noadapt':
            net_adapt.eval()
        else:
            net_adapt.train()

        for counter in range(n_batches):
            start = counter * batch_size
            stop = start + batch_size
            x_curr = x_epoch[start:stop].to(device)
            y_curr = y_epoch[start:stop].to(device)

            # Compute model outputs with no gradient computation for the pre-trained classifiers
            with torch.no_grad():
                outputs, _ = net_adapt(x_curr, True)

            # g_phi maps the batch-averaged softmax prediction to T, which is
            # right-multiplied onto the logits (presumably a
            # num_classes x num_classes matrix -- confirm against MLP).
            T = netW(torch.softmax(outputs, dim=-1).mean(0))
            loss = loss_fn(outputs @ T, y_curr)

            # Loss of the unmodified logits, for comparison only
            with torch.no_grad():
                loss_compare = loss_fn(outputs, y_curr)

            # Backpropagation and optimization
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Accumulate losses for the epoch average
            loss_avg += loss.item()
            loss_compare_avg += loss_compare.item()

        # Evaluate the current g_phi on the cached test logits (computed once
        # and reused for both the log line and the storage record).
        with torch.no_grad():
            T_curr = netW(test_prior).cpu()
        acc_post = ((logits_te @ T_curr).argmax(1) == labels_te).float().mean()

        # Print and record metrics for the current epoch
        print ("|%s\t|seed%d\t|epoch\t|%d\t|loss_old\t|%.4f\t|loss_new\t|%.4f\t|acc_old\t|%.4f\t|acc_new\t|%.4f\t|"%(
            args.optim_type, args.seed,
            epoch, loss_compare_avg/n_batches, loss_avg/n_batches,
            acc_bnadapt_te, acc_post
        ))

        # Record loss and accuracy metrics for plotting or analysis
        storage['loss_pre'].append(loss_compare_avg/n_batches)
        storage['loss_post'].append(loss_avg/n_batches)
        storage['acc_pre'].append(acc_bnadapt_te)
        storage['acc_post'].append(acc_post)

        # Save the model and storage dictionary every 10th epoch and at the end
        if epoch % 10 == 9 or epoch == args.num_epochs - 1:
            torch.save(netW.state_dict(), save_path)
            torch.save(storage, storage_save_path)
            print ("model and storage are saved")

        # Adjust the learning rate based on the scheduler
        scheduler.step()
