import torch
import numpy as np
import copy
import random
import argparse
import os
from datetime import datetime
import torch.optim as optim

# Import custom utility functions and models from your module files
from utils_baseline_norm import *
from utils_intermediate import *
from utils.misc import *
from utils.misc_cifar import *
from utils.models import *

######################################################################################################################################
current_time = datetime.now().strftime("%y%m%d_%H%M%S")
parser = argparse.ArgumentParser()

# --- Benchmark --- #
# Only these two benchmarks have data/model loading code in this script.
parser.add_argument('--benchmark', type=str, default='cifar100', choices=['cifar10', 'cifar100'],
                    help="dataset and matching pre-trained model")

# --- Test-set imbalance --- #
# Passed to make_LT_datasets when building the long-tailed test split.
parser.add_argument('--rho', type=float, default=0.01,
                    help="long-tail ratio of the test split (e.g. 0.01, 0.1, 1.0)")

# --- Training details for TTA --- #
parser.add_argument('--batch_size', type=int, default=200)
parser.add_argument('--num_epochs', type=int, default=100)
parser.add_argument('--pl_threshold', type=float, default=0.9)
parser.add_argument('--optim_type', type=str, default='adam', choices=['sgd', 'adam'],
                    help="optimizer type")
# Only these values are handled when configuring the adapted model.
parser.add_argument('--ft_layers', type=str, default='bn', choices=['bn', 'tent', 'noadapt'],
                    help="which layers of the pre-trained model are adapted")

# --- Reproducibility --- #
parser.add_argument('--seed', type=int, default=1)

args = parser.parse_args()
print (args)
######################################################################################################################################
# Set random seeds for reproducibility (torch.manual_seed also seeds all CUDA devices).
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

#########################################################################################
# Load the training dataset for the selected benchmark.
if args.benchmark == 'cifar10':
    x_train, y_train = load_cifar10_train_dataset()
    num_classes = 10
elif args.benchmark == 'cifar100':
    x_train, y_train = load_cifar100_train_dataset()
    num_classes = 100
else:
    # Fail fast instead of crashing later with a NameError on x_train.
    raise ValueError('there is no training dataset for the benchmark: %s' % args.benchmark)

# Make a balanced intermediate dataset from the training split.
x_train, y_train = make_balanced_intermediate_dataset(x_train, y_train, num_classes)
# Stored on args because code that only receives `args` (e.g. compute_mask) reads it.
args.num_classes = num_classes

# Load the pre-trained classifier for the selected benchmark.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if args.benchmark == 'cifar10':
    ckpt_path = "path/pre-trained model"
    net_init = Normalized_ResNet(depth=26)
else:  # 'cifar100' -- any other value was rejected above
    ckpt_path = "path/pre-trained model"
    net_init = Normalized_ResNet_CIFAR100()

# The state dict is loaded into the DataParallel wrapper, so its keys are
# expected to carry the wrapper's "module." prefix.
net_init = torch.nn.DataParallel(net_init)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
checkpoint = torch.load(ckpt_path, map_location=device)
net_init.load_state_dict(checkpoint['net'])
net_init.to(device)

# Configure a copy of the pre-trained model according to --ft_layers.
net_adapt = copy.deepcopy(net_init)
if args.ft_layers == 'bn':
    net_adapt = configure_model_bn(net_adapt)
elif args.ft_layers == 'tent':
    net_adapt = configure_model_tent(net_adapt)
elif args.ft_layers == 'noadapt':
    net_adapt = configure_model_noadapt(net_adapt)
else:
    # Fail fast instead of crashing later with a misconfigured net_adapt.
    raise ValueError('there is no ft layers: %s' % args.ft_layers)

# Test data used only for the per-epoch training log: build a long-tailed
# split and run the adapted model on it once to get baseline accuracy/logits.
if args.benchmark == 'cifar10':
    x_tests, y_tests, test_datasets = load_cifar10_test_datasets()
elif args.benchmark == 'cifar100':
    x_tests, y_tests, test_datasets = load_cifar100_test_datasets()
else:
    raise ValueError('there is no benchmark: %s' % args.benchmark)

# NOTE(review): index 1 of the loaded test sets is used as the balanced
# source -- confirm which split/corruption this corresponds to.
x_test_bal, y_test_bal = x_tests[1], y_tests[1]
x_test, y_test, _ = make_LT_datasets(x_test_bal, y_test_bal, args.rho, num_classes)
acc_bnadapt_te, logits_te, labels_te = tta_bnadapt(x_test, y_test, net_adapt, args)

# Output paths are per benchmark, so runs on different benchmarks neither
# overwrite each other nor skip training because another benchmark's model exists.
save_dir = os.path.join("./eval_results", args.benchmark)
save_path = os.path.join(save_dir, "trained_gphi.pt")
storage_save_path = os.path.join(save_dir, "trlog")
batch_size = args.batch_size

def _cooccurrence_mask(labels, preds, num_classes):
    """Return a symmetric 0/1 mask of class pairs that the classifier confuses.

    ``mask[i, j] == 1`` when at least one sample of true class ``i`` was
    predicted as ``j``, or one of class ``j`` was predicted as ``i``. The
    diagonal is always 1.

    Args:
        labels: true class indices (a CPU tensor or anything ``np.asarray`` accepts).
        preds: predicted class indices, same length as ``labels``.
        num_classes: side length of the returned square mask.

    Returns:
        float32 ``torch.Tensor`` of shape ``(num_classes, num_classes)``.
    """
    labels = np.asarray(labels, dtype=np.int64).reshape(-1)
    preds = np.asarray(preds, dtype=np.int64).reshape(-1)
    # Fixed-size confusion matrix (rows = true class, cols = predicted class).
    # Sizing it by num_classes keeps row/column indices equal to class ids even
    # when some class never occurs.
    counts = np.bincount(labels * num_classes + preds,
                         minlength=num_classes * num_classes).reshape(num_classes, num_classes)
    hit = counts > 0
    mask = hit | hit.T
    np.fill_diagonal(mask, True)
    return torch.from_numpy(mask.astype(np.float32))


def compute_mask():
    """Build the class-pair mask from predictions on corrupted CIFAR-100 training data.

    A fresh copy of ``net_init`` is configured with ``configure_model_bn`` and
    run through ``tta_bnadapt`` on the corrupted training set. The pairs of
    (true, predicted) classes it produces define the mask via ``_cooccurrence_mask``.

    Side effects: re-seeds the ``numpy``, ``torch`` and ``random`` global RNGs
    with ``args.seed``.

    Returns:
        float32 ``torch.Tensor`` of shape ``(args.num_classes, args.num_classes)``.
    """
    def load_corrupted_cifar100_train_dataset():
        # Slice out the block of images for the chosen severity level.
        severity = 1
        n_total_cifar = 50000

        images_all = np.load("path of speckle noised CIFAR-100C training image numpy file")
        labels_all = np.load("path of speckle noised CIFAR-100C training label numpy file")

        start, stop = (severity - 1) * n_total_cifar, severity * n_total_cifar
        images = images_all[start:stop]
        labels = labels_all[start:stop]
        # NHWC uint8 -> NCHW float32 in [0, 1].
        x = torch.tensor(np.transpose(images, (0, 3, 1, 2)).astype(np.float32) / 255)
        y = torch.tensor(labels).long()
        return x, y

    x_train, y_train = load_corrupted_cifar100_train_dataset()
    # Re-seed so this adaptation pass is reproducible regardless of how much
    # randomness earlier code consumed.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    net_adapt = copy.deepcopy(net_init)
    net_adapt = configure_model_bn(net_adapt)
    _, logits_tr, labels_tr = tta_bnadapt(x_train, y_train, net_adapt, args)
    del net_adapt

    preds = logits_tr.argmax(-1).reshape(-1).cpu()
    labels = torch.as_tensor(labels_tr).reshape(-1).cpu()
    return _cooccurrence_mask(labels, preds, args.num_classes)


if not os.path.exists(save_path):
    hiddendim = 1000

    # g_phi: an MLP fed a mean softmax vector over num_classes. Its output T is
    # multiplied into the logits (`outputs @ T`) and masked elementwise, so it is
    # presumably (num_classes, num_classes).
    netW = MLP(num_classes, hiddendim=hiddendim).to(device)
    loss_fn = torch.nn.CrossEntropyLoss()

    # Per-epoch training log; saved alongside the g_phi checkpoint.
    storage = {'loss_pre': [], 'loss_post': [], 'acc_pre': [], 'acc_post': []}

    # Class-pair mask limiting which entries of T are kept. It is constant for
    # the whole run, so it is moved to the device once.
    mask = compute_mask().to(device)

    optimizer = torch.optim.Adam(netW.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_epochs)

    # torch.save does not create missing directories.
    for _out_path in (save_path, storage_save_path):
        os.makedirs(os.path.dirname(_out_path) or '.', exist_ok=True)

    # Loop-invariant quantities.
    n_samples = x_train.size(0)
    n_batches = -(-n_samples // batch_size)  # ceil(n_samples / batch_size)
    # Mean softmax of the long-tailed test logits, fed to g_phi for logging.
    te_prior = torch.softmax(logits_te, 1).mean(0).to(device)

    for epoch in range(args.num_epochs):
        loss_avg = 0.
        loss_compare_avg = 0.

        # Re-draw this epoch's sample order with dirichlet_indices (about one
        # chunk per batch, non_iid_ness=1.).
        dirichlet_numchunks = n_samples // args.batch_size
        new_indices = dirichlet_indices(x_train, y_train, net_adapt, num_classes,
                                        dirichlet_numchunks=dirichlet_numchunks, non_iid_ness=1.)
        # Reorder once per epoch. Indexing inside the batch loop would copy the
        # whole dataset for every batch.
        x_epoch = x_train[new_indices]
        y_epoch = y_train[new_indices]

        if args.ft_layers == 'noadapt':
            net_adapt.eval()
        else:
            net_adapt.train()

        for counter in range(n_batches):
            batch = slice(counter * batch_size, (counter + 1) * batch_size)
            x_curr = x_epoch[batch].to(device)
            y_curr = y_epoch[batch].to(device)

            # The adapted classifier receives no gradients from this optimizer.
            with torch.no_grad():
                outputs, _ = net_adapt(x_curr, True)

            # Loss after applying the masked g_phi transform to the logits.
            T = netW(torch.softmax(outputs, dim=-1).mean(0)) * mask
            loss = loss_fn(outputs @ T, y_curr)

            # Reference loss on the untransformed logits, for comparison.
            with torch.no_grad():
                loss_compare = loss_fn(outputs, y_curr)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_avg += loss.item()
            loss_compare_avg += loss_compare.item()

        # Evaluate the current g_phi on the long-tailed test logits.
        with torch.no_grad():
            T_curr = netW(te_prior).cpu()
        acc_post = ((logits_te @ T_curr).argmax(1) == labels_te).float().mean().item()
        loss_pre = loss_compare_avg / n_batches
        loss_post = loss_avg / n_batches
        print ("|%s\t|seed%d\t|epoch\t|%d\t|loss_old\t|%.4f\t|loss_new\t|%.4f\t|acc_old\t|%.4f\t|acc_new\t|%.4f\t|" % (
            args.optim_type, args.seed,
            epoch, loss_pre, loss_post,
            acc_bnadapt_te, acc_post,
        ))

        # Record metrics for later plotting/analysis.
        storage['loss_pre'].append(loss_pre)
        storage['loss_post'].append(loss_post)
        storage['acc_pre'].append(acc_bnadapt_te)
        storage['acc_post'].append(acc_post)

        # Checkpoint every 10 epochs and after the final epoch.
        if epoch % 10 == 9 or epoch == args.num_epochs - 1:
            torch.save(netW.state_dict(), save_path)
            torch.save(storage, storage_save_path)
            print ("model and storage are saved")

        scheduler.step()
else:
    print("%s already exists; skipping g_phi training" % save_path)
