import torch
import numpy as np
import copy
import random
import argparse
import os
from datetime import datetime
import torch.optim as optim

# Import custom utility functions and models
from utils.misc import *
from utils.misc_cifar import *
from utils.models import *

from misc import *
from utils_baseline_dart_split import *

np.set_printoptions(precision=2, suppress=True)

######################################################################################################################################
# Timestamp taken at start-up (format yymmdd_HHMMSS).
current_time = datetime.now().strftime("%y%m%d_%H%M%S")
parser = argparse.ArgumentParser()

# --- Benchmark / data options --- #
# NOTE: `help` must be a string. argparse %-formats it when rendering
# `--help`, so passing a list makes `--help` raise a TypeError.
parser.add_argument('--benchmark', type=str, default='cifar10',
                    help="benchmark to use: 'cifar10' or 'cifar100'")

parser.add_argument('--int_dataset', type=str, default='clean')
parser.add_argument('--int_batch_size', type=int, default=64)
parser.add_argument('--test_batch_size', type=int, default=200)
parser.add_argument('--optim_type', type=str, default='adam',
                    help="optimizer type: 'sgd' or 'adam'")
# The values listed here are the ones this script actually dispatches on.
parser.add_argument('--ft_layers', type=str, default='tent',
                    help="layers to adapt: 'bn', 'tent' or 'noadapt'")

# --- Optimisation options for g_phi --- #
parser.add_argument('--lr', type=float, default=1e-3)

parser.add_argument('--num_epochs', type=int, default=50)
parser.add_argument('--print_freq', type=int, default=3000)
parser.add_argument('--save_freq', type=int, default=50)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--hiddendim', type=int, default=1000)

args = parser.parse_args()
print(args)
######################################################################################################################################
# Seed every RNG used by the script (Python, NumPy, PyTorch) for reproducibility.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# Force deterministic cuDNN kernels and disable the auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

#########################################################################################
# Resolve the compute device once, before anything is moved to it, so the data
# and the model follow the same cuda/cpu fallback.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
args.device = device

# Load the intermediate dataset for the selected benchmark.
if args.benchmark == 'cifar10':
    x_train, y_train = load_cifar10_train_dataset()
    args.num_classes = 10
elif args.benchmark == 'cifar100':
    x_train, y_train = load_cifar100_train_dataset()
    args.num_classes = 100
else:
    # Fail here with a clear message instead of a NameError further down.
    raise ValueError(
        "unsupported benchmark {!r}; expected 'cifar10' or 'cifar100'".format(args.benchmark)
    )

x_train, y_train = x_train.to(device), y_train.to(device)

# Make a balanced intermediate dataset
x_train, y_train = make_balanced_intermediate_dataset(x_train, y_train, args.num_classes)

# Build the pre-trained architecture for the selected benchmark.
# The checkpoint paths are placeholders and are kept separate per benchmark.
if args.benchmark == 'cifar10':
    ckpt_path = "path/pre-trained model"
    net_init = Normalized_ResNet(depth=26)
else:  # 'cifar100' (any other value was rejected above)
    ckpt_path = "path/pre-trained model"
    net_init = Normalized_ResNet_CIFAR100()

# The checkpoint is loaded into the DataParallel wrapper, so wrap before loading.
net_init.to(device)
net_init = torch.nn.DataParallel(net_init)

# map_location keeps a checkpoint saved on GPU loadable on a CPU-only machine.
checkpoint = torch.load(ckpt_path, map_location=device)
net_init.load_state_dict(checkpoint['net'])
# Configure a copy of the pre-trained model according to --ft_layers.
# The configure function is looked up lazily inside the selected branch only.
if args.ft_layers == 'bn':
    configure_fn = configure_model_bn
elif args.ft_layers == 'tent':
    configure_fn = configure_model_tent
elif args.ft_layers == 'noadapt':
    configure_fn = configure_model_noadapt
else:
    # Previously this only printed a message and left `net_adapt` undefined,
    # which surfaced later as a NameError.
    raise ValueError(
        "unsupported ft_layers {!r}; expected 'bn', 'tent' or 'noadapt'".format(args.ft_layers)
    )

# net_init is kept untouched; only the deep copy is configured.
net_adapt = configure_fn(copy.deepcopy(net_init))

# Build the prediction refinement module g_phi and place it on the compute device.
g_phi = MLP_dart_split(args.num_classes, hiddendim=args.hiddendim)
g_phi = g_phi.to(device)

# Adam over g_phi's parameters, with a cosine learning-rate schedule whose
# period is the number of training epochs.
optimizer_g_phi = torch.optim.Adam(params=g_phi.parameters(), lr=args.lr)
scheduler_g_phi = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer=optimizer_g_phi,
    T_max=args.num_epochs,
)

# Temporary test datasets for the training log (gaussian_noise, severity 5).
if args.benchmark == "cifar10":
    x_test, y_test = load_cifar10c(n_examples=10000, severity=5, data_dir="CIFAR-10C PATH",
                                   shuffle=False, corruptions=["gaussian_noise"])
    imb_ratios = [1, 20, 50, 5000]
elif args.benchmark == "cifar100":
    x_test, y_test = load_cifar100c(n_examples=10000, severity=5, data_dir="CIFAR-100C PATH",
                                    shuffle=False, corruptions=["gaussian_noise"])
    imb_ratios = [1, 200, 500, 50000]
else:
    # Fail here instead of hitting a NameError on x_test / imb_ratios below.
    raise ValueError(
        "unsupported benchmark {!r}; expected 'cifar10' or 'cifar100'".format(args.benchmark)
    )

# Sort the test samples by class: class 0 first, then class 1, and so on.
# Each class mask is computed once and reused for both inputs and labels.
class_masks = [y_test == c for c in range(args.num_classes)]
x_mod = [x_test[mask] for mask in class_masks]
y_mod = [y_test[mask] for mask in class_masks]

x_test, y_test = torch.cat(x_mod, 0), torch.cat(y_mod, 0)

# Load one pre-computed index array per imbalance ratio.
# The template has five placeholders: benchmark, seed, total sample count,
# imbalance ratio and the class-order-shuffle flag. The original call passed a
# sixth, unused argument (`True`) that str.format silently ignored; the
# resulting file names are unchanged.
idx_path_template = './eval_results/idx/{}/imb/seed{}_total_{}_ir_{}_class_order_shuffle_{}.npy'
imb_indices = [
    np.load(idx_path_template.format(
        args.benchmark,
        args.seed,
        10000,  # presumably matches n_examples of the test set -- TODO confirm
        imb_ratio,
        "yes",
    ))
    for imb_ratio in imb_ratios
]

# Output locations for the trained g_phi weights and the training-log storage.
# They are derived from the benchmark so a cifar10 run and a cifar100 run do
# not share (and block) each other's files; for cifar10 the paths are unchanged.
save_path = "./eval_results/{}/trained_gphi.pt".format(args.benchmark)
storage_save_path = "./eval_results/{}/trlog".format(args.benchmark)

# Baseline: naive BNAdapt accuracy (no g_phi) for each imbalance setting,
# evaluated with the pre-trained model net_init.
storage = {}
accs = [
    tta_bnadapt(x_test, y_test, indices, net_init, args, g_phi=None)
    for indices, _ratio in zip(imb_indices, imb_ratios)
]

# Keep the baseline in the log storage and show it.
storage["acc_original"] = accs
print(storage["acc_original"])

# Prepare intermediate-time training.
# Number of full batches per epoch; a trailing partial batch is dropped.
num_batches_train = x_train.size(0) // args.int_batch_size

# Benchmark-specific arguments passed to dirichlet_indices.
if args.benchmark == "cifar10":
    dirichlet_numchunks = 250
    non_iid_ness = 10
elif args.benchmark == "cifar100":
    dirichlet_numchunks = 1000
    non_iid_ness = 0.1
else:
    # Fail here instead of hitting a NameError inside the training loop.
    raise ValueError(
        "unsupported benchmark {!r}; expected 'cifar10' or 'cifar100'".format(args.benchmark)
    )

# Loss modules. They are referenced through `torch.nn` explicitly so they do
# not depend on a wildcard import exposing `nn`.
loss_fn = torch.nn.CrossEntropyLoss().to(device)
mseloss = torch.nn.MSELoss().to(device)  # not used in the visible training loop
bceloss = torch.nn.BCELoss().to(device)


# Intermediate-time training of g_phi.
# net_adapt is only run forward (under no_grad); g_phi is the only module optimised.
net_adapt.train()
if not os.path.exists(save_path):  # prevent over-write
    eps = 1e-12  # keeps log() finite when a predicted probability is 0

    # Make sure the output directories exist before spending time on training.
    for out_path in (save_path, storage_save_path):
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    # Loop-invariant tensors, built once instead of on every step.
    # One-element tensor holding 1 / num_classes (same expression as before, hoisted).
    uniform_weight = (torch.ones(args.num_classes) / args.num_classes).mean(0, keepdim=True).to(device)
    # BCE targets: 0 for the balanced batch, 1 for the imbalanced batch.
    reg_target = torch.tensor([0, 1.]).to(device)

    for epoch in range(args.num_epochs):
        # Uniformly shuffled order, used for the "balanced" batches.
        perm_rand = torch.randperm(x_train.size(0))

        # Order produced by dirichlet_indices, used for the "imbalanced" batches
        # (presumably a non-i.i.d. class ordering -- see the helper for details).
        perm_dir = dirichlet_indices(x_train, y_train, net_adapt, args.num_classes,
                                     dirichlet_numchunks=dirichlet_numchunks, non_iid_ness=non_iid_ness)

        for counter in range(num_batches_train):
            batch_idx = num_batches_train * epoch + counter

            start = counter * args.int_batch_size
            stop = (counter + 1) * args.int_batch_size
            idx_imb = perm_dir[start:stop]
            idx_bal = perm_rand[start:stop]

            x_bal, y_bal = x_train[idx_bal].to(device), y_train[idx_bal].to(device)
            x_imb, y_imb = x_train[idx_imb].to(device), y_train[idx_imb].to(device)

            # No gradient is needed through the classifier.
            with torch.no_grad():
                logits_bal = net_adapt(x_bal)
                logits_imb = net_adapt(x_imb)

            # prediction / pseudo label
            preds_bal = torch.softmax(logits_bal, dim=1)
            preds_imb = torch.softmax(logits_imb, dim=1)

            # averaged pseudo label distribution over the batch
            py_bal, py_imb = preds_bal.mean(0), preds_imb.mean(0)

            # Batch mean of the negative log-probabilities, summed over classes
            # and scaled by 1 / num_classes.
            pred_dev_bal = (-(uniform_weight * torch.log(preds_bal + eps))).sum(dim=1).mean()
            pred_dev_imb = (-(uniform_weight * torch.log(preds_imb + eps))).sum(dim=1).mean()

            W_bal, b_bal, inp_bal = g_phi(py_bal, pred_dev_bal)
            W_imb, b_imb, inp_imb = g_phi(py_imb, pred_dev_imb)

            # Cross-entropy of the refined logits on the imbalanced batch.
            ce_loss = loss_fn(logits_imb @ W_imb + b_imb, y_imb)
            # BCE pushing g_phi's third output towards 0 on the balanced batch and
            # 1 on the imbalanced batch (assumes each is a 1-element tensor -- TODO confirm).
            reg_loss = bceloss(torch.cat((inp_bal, inp_imb)), reg_target)

            total_loss = ce_loss + reg_loss

            optimizer_g_phi.zero_grad()
            total_loss.backward()
            optimizer_g_phi.step()

            # Periodic evaluation on the temporary test set.
            if batch_idx % args.print_freq == args.print_freq - 1:
                # Same call signature as the baseline evaluation; the previous
                # call passed an undefined `opt` argument.
                accs = [
                    tta_bnadapt(x_test, y_test, indices, net_init, args, g_phi=g_phi)
                    for indices in imb_indices
                ]
                storage["acc;g_phi;intb%d" % batch_idx] = torch.stack(accs)
                print (batch_idx, "g_phi", torch.stack(accs))

        scheduler_g_phi.step()
        is_last_epoch = epoch == args.num_epochs - 1
        if epoch % args.save_freq == args.save_freq - 1 or is_last_epoch:
            # Per-epoch checkpoints, based on the paths defined earlier in the file
            # (the previous `storage_path` / `g_phi_path` names were never defined).
            torch.save(storage, storage_save_path + ";epoch%d" % epoch + ".ckpt")
            torch.save(g_phi.state_dict(), save_path + ";epoch%d" % epoch + ".ckpt")
        if is_last_epoch:
            # Also write the final weights to save_path itself, so the
            # "prevent over-write" guard above detects a finished run.
            torch.save(g_phi.state_dict(), save_path)

    print('Intermediate training is finished')
else:
    print ("Trained file exists")


