from torch import optim
import torch
import torch.utils.data
import argparse
import torch.backends.cudnn as cudnn
import random
import json
import sys
import swanlab
# Import dataloaders
# import Data.cifar10 as cifar10  
import Data.cifar10_faster as cifar10
import Data.cifar100_faster as cifar100
# import Data.tiny_imagenet as tiny_imagenet
# import Data.imagenet as imagenet

# Import network models
from Net.resnet import resnet50, resnet110
# from Net.resnet_tiny_imagenet import resnet50 as resnet50_ti
# from Net.resmet_imagenet import resnet50_imagenet1k as resnet50_imagenet
from Net.wide_resnet import wide_resnet_cifar
from Net.densenet import densenet121

# Import train and validation utilities
from train_utils import train_single_epoch, test_single_epoch

# Import validation metrics
from Metrics.metrics import test_classification_net
from evaluate import get_logits_labels
import torch.nn as nn
from Metrics.metrics import ECELoss, AdaptiveECELoss, ClasswiseECELoss

'''
Training script for the baseline models.
'''

# SwanLab logging mode handed to swanlab.init(mode=...) in the main script.
# The original note lists 'disabled' and 'cloud' as the values to switch between.
saveSwanLab = 'cloud'

# Number of target classes for every dataset name accepted by --dataset.
dataset_num_classes = dict(
    cifar10=10,
    cifar100=100,
    tiny_imagenet=200,
    imagenet=1000,
)

# Data-loading module per dataset name.
# The tiny_imagenet / imagenet loaders are disabled because their imports are
# commented out at the top of the file.
dataset_loader = dict(
    cifar10=cifar10,
    cifar100=cifar100,
    # tiny_imagenet=tiny_imagenet,
    # imagenet=imagenet,
)

# Network constructor per --model name; each one is called with num_classes=...
# The resnet50_ti / resnet50_imagenet variants are disabled together with
# their imports.
models = dict(
    resnet50=resnet50,
    # resnet50_ti=resnet50_ti,
    # resnet50_imagenet=resnet50_imagenet,
    resnet110=resnet110,
    wide_resnet=wide_resnet_cifar,
    densenet121=densenet121,
)

def loss_function_save_name(loss_function, scheduled=False, gamma=1.0, gamma1=1.0, gamma2=1.0, gamma3=1.0, lamda=1.0):
    """Return the name fragment that identifies a loss configuration.

    The fragment is used for experiment names and checkpoint file names.

    Args:
        loss_function: Loss identifier, e.g. 'cross_entropy' or 'focal_loss'.
        scheduled: Gamma-schedule switch. It is compared with ``== True``, so
            ``True`` and ``1`` enable the scheduled name while other values
            (``0``, ``2``, ...) do not.
        gamma: Gamma appended to the focal-style loss names.
        gamma1, gamma2, gamma3: The three gammas appended to the scheduled
            focal-loss name.
        lamda: Regularisation factor appended to the mmce / soft-ece /
            soft-avuc names.

    Returns:
        The save-name string for the given loss.

    Raises:
        KeyError: If ``loss_function`` is not a known loss name (and the
            scheduled focal-loss case does not apply).
    """
    # Each loss name is extended by one suffix, depending on which
    # hyper-parameter (if any) identifies the run.
    suffix_groups = (
        ('', ('cross_entropy', 'brier_score')),
        ('_gamma_' + str(gamma),
         ('focal_loss', 'focal_loss_adaptive', 'dual_focal_loss')),
        ('_lamda_' + str(lamda),
         ('mmce', 'mmce_weighted', 'ce_soft_ece', 'focal_soft_ece',
          'ce_soft_avuc', 'focal_soft_avuc')),
    )
    save_names = {
        loss_name: loss_name + suffix
        for suffix, loss_names in suffix_groups
        for loss_name in loss_names
    }

    wants_schedule_name = loss_function == 'focal_loss' and scheduled == True
    if not wants_schedule_name:
        return save_names[loss_function]

    # Scheduled focal loss is labelled with all three gammas.
    return '_'.join(('focal_loss_scheduled_gamma', str(gamma1), str(gamma2), str(gamma3)))


def parseArgs(argv=None):
    """Build the command-line parser and parse the training options.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script does when run from the command line.

    Returns:
        argparse.Namespace with one attribute per option ``dest``.
    """
    # Default values, collected in one place.
    dataset_root = "./dataset/tiny_imagenet/"  # only for tiny imagenet
    train_batch_size = 128
    test_batch_size = 128
    learning_rate = 0.1
    momentum = 0.9
    optimiser = "sgd"
    loss = "cross_entropy"
    gamma = 1.0
    gamma2 = 1.0
    gamma3 = 1.0
    lamda = 1.0
    weight_decay = 5e-4
    log_interval = 50
    save_interval = 50
    save_loc = './checkpoints/'
    model_name = None
    load_loc = './'
    model = "resnet50"
    epoch = 350
    first_milestone = 150  # Milestone for change in lr
    second_milestone = 250  # Milestone for change in lr
    gamma_schedule_step1 = 100
    gamma_schedule_step2 = 250

    parser = argparse.ArgumentParser(description="Training for calibration.", formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # --- data -------------------------------------------------------------
    parser.add_argument("--dataset", type=str, default='cifar10', dest="dataset", help='dataset to train on')
    parser.add_argument("--dataset-root", type=str, default=dataset_root, dest="dataset_root", help='root path of the dataset (for tiny imagenet)')
    # Augmentation and GPU are ON by default, so the store_true flags below are
    # kept only for backward compatibility; the --no-* counterparts are the
    # way to switch them off.
    parser.add_argument("--data-aug", action="store_true", dest="data_aug")
    parser.set_defaults(data_aug=True)
    parser.add_argument("--no-data-aug", action="store_false", dest="data_aug", help="Disable data augmentation")
    parser.add_argument("-g", action="store_true", default=True, dest="gpu", help="Use GPU")
    parser.add_argument("--no-gpu", action="store_false", dest="gpu", help="Do not use the GPU even if one is available")
    parser.add_argument("--load", action="store_true", default=False, dest="load", help="Load from pretrained model")

    # --- optimisation -----------------------------------------------------
    parser.add_argument("-b", type=int, default=train_batch_size, dest="train_batch_size", help="Batch size")
    parser.add_argument("-tb", type=int, default=test_batch_size, dest="test_batch_size", help="Test Batch size")
    parser.add_argument("-e", type=int, default=epoch, dest="epoch", help='Number of training epochs')
    parser.add_argument("--lr", type=float, default=learning_rate, dest="learning_rate", help='Learning rate')
    parser.add_argument("--mom", type=float, default=momentum, dest="momentum", help='Momentum')
    parser.add_argument("--nesterov", action="store_true", dest="nesterov", help="Whether to use nesterov momentum in SGD")
    parser.set_defaults(nesterov=False)
    parser.add_argument("--decay", type=float, default=weight_decay, dest="weight_decay", help="Weight Decay")
    parser.add_argument("--opt", type=str, default=optimiser, dest="optimiser", help='Choice of optimisation algorithm')
    parser.add_argument("--num_workers", type=int, default=4, dest="num_workers")

    # --- loss -------------------------------------------------------------
    parser.add_argument("--loss", type=str, default=loss, dest="loss_function", help="Loss function to be used for training")
    parser.add_argument("--loss-mean", action="store_true", dest="loss_mean", help="whether to take mean of loss instead of sum to train")
    parser.set_defaults(loss_mean=False)
    parser.add_argument("--gamma", type=float, default=gamma, dest="gamma", help="Gamma for focal components")
    parser.add_argument("--gamma2", type=float, default=gamma2, dest="gamma2", help="Gamma for different focal components")
    parser.add_argument("--gamma3", type=float, default=gamma3, dest="gamma3", help="Gamma for different focal components")
    parser.add_argument("--lamda", type=float, default=lamda, dest="lamda", help="Regularization factor")
    parser.add_argument("--gamma-schedule", type=int, default=0, dest="gamma_schedule", help="Schedule gamma or not")
    parser.add_argument("--gamma-schedule-step1", type=int, default=gamma_schedule_step1, dest="gamma_schedule_step1", help="1st step for gamma schedule")
    parser.add_argument("--gamma-schedule-step2", type=int, default=gamma_schedule_step2, dest="gamma_schedule_step2", help="2nd step for gamma schedule")

    # --- logging / checkpoints ---------------------------------------------
    parser.add_argument("--log-interval", type=int, default=log_interval, dest="log_interval", help="Log Interval on Terminal")
    parser.add_argument("--save-interval", type=int, default=save_interval, dest="save_interval", help="Save Interval on Terminal")
    parser.add_argument("--saved_model_name", type=str, default="resnet50_cross_entropy_350.model", dest="saved_model_name", help="file name of the pre-trained model")
    parser.add_argument("--save-path", type=str, default=save_loc, dest="save_loc", help='Path to export the model')
    parser.add_argument("--model-name", type=str, default=model_name, dest="model_name", help='name of the model')
    parser.add_argument("--load-path", type=str, default=load_loc, dest="load_loc", help='Path to load the model from')

    # --- model / lr schedule -----------------------------------------------
    parser.add_argument("--model", type=str, default=model, dest="model", help='CNN Model to train')
    parser.add_argument("--first-milestone", type=int, default=first_milestone, dest="first_milestone", help="First milestone to change lr")
    parser.add_argument("--second-milestone", type=int, default=second_milestone, dest="second_milestone", help="Second milestone to change lr")

    # --- training-time corruption -------------------------------------------
    parser.add_argument("--use-corruption", action="store_true", default=False, dest="use_corruption", help='use corrupted images for training')
    parser.add_argument("--corruption-prob", type=float, default=0.3, dest="corruption_prob", help='probability of applying corruption to images')
    parser.add_argument("--corruption-types", type=str, nargs='+', default=None, dest="corruption_types",
                        help='specific corruption types to use (space-separated)')
    parser.add_argument("--corruption-severity-min", type=int, default=1, dest="corruption_severity_min", help='minimum corruption severity')
    parser.add_argument("--corruption-severity-max", type=int, default=3, dest="corruption_severity_max", help='maximum corruption severity')

    # --- few-shot sampling --------------------------------------------------
    parser.add_argument("--few-shot", action="store_true", default=False, dest="few_shot", help='use few-shot learning')
    parser.add_argument("--samples-per-class", type=int, default=None, dest="samples_per_class", help='number of samples per class for few-shot learning')
    parser.add_argument("--sample-ratio", type=float, default=None, dest="sample_ratio", help='ratio of samples to keep per class (0-1)')
    parser.add_argument("--few-shot-seed", type=int, default=42, dest="few_shot_seed", help='random seed for few-shot sampling')

    return parser.parse_args(argv)

def get_scheduler(optimizer, args):
    """Create the step-wise learning-rate scheduler for the chosen dataset.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        args: Parsed options. ``args.dataset`` selects the schedule;
            ``args.first_milestone`` / ``args.second_milestone`` are read for
            every dataset except 'imagenet'.

    Returns:
        A ``MultiStepLR`` scheduler that multiplies the lr by 0.1 at each
        milestone epoch.
    """
    if args.dataset == 'imagenet':
        # ImageNet uses a fixed schedule, independent of the CLI milestones.
        lr_drop_epochs = [30, 60, 80]
    else:
        # tiny_imagenet and cifar 10/100 share the two configurable milestones.
        lr_drop_epochs = [args.first_milestone, args.second_milestone]
    return optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_drop_epochs, gamma=0.1)
    
if __name__ == "__main__":
    # Script entry point: parse options, build model / optimizer / data
    # loaders, train for the requested number of epochs while logging loss,
    # accuracy and calibration metrics to SwanLab, and save the final weights.
    import os  # only needed by the script path (checkpoint locations)

    torch.manual_seed(1)
    args = parseArgs()

    # Use CUDA only when it is both requested and actually available.
    cuda = bool(torch.cuda.is_available() and args.gpu)
    device = torch.device("cuda" if cuda else "cpu")

    swanlab.init(
        # workspace="GardCalibration",
        project=args.dataset + '_baselines',
        # Label the run with the configured gamma (the same value the
        # checkpoint name uses), not a hard-coded 0.
        experiment_name=args.model + '_' + loss_function_save_name(
            args.loss_function, args.gamma_schedule, args.gamma,
            args.gamma, args.gamma2, args.gamma3, args.lamda),
        config=args,
        mode=saveSwanLab,
    )

    num_classes = dataset_num_classes[args.dataset]

    # Choosing the model to train
    net = models[args.model](num_classes=num_classes)

    # Setting model name
    if args.model_name is None:
        args.model_name = args.model

    # Move to the resolved device instead of calling .cuda() unconditionally,
    # which fails on machines without CUDA.
    net.to(device)
    if cuda:
        # net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
        cudnn.benchmark = True

    start_epoch = 0
    num_epochs = args.epoch
    if args.load:
        # NOTE(review): the checkpoint is read from --save-path; --load-path
        # (args.load_loc) is parsed but not used here — confirm the intent.
        checkpoint_path = os.path.join(args.save_loc, args.saved_model_name)
        net.load_state_dict(torch.load(checkpoint_path, map_location=device))
        # The resume epoch is the number between the last '_' and '.model'
        # in the checkpoint file name.
        ckpt_name = args.saved_model_name
        start_epoch = int(ckpt_name[ckpt_name.rfind('_') + 1:ckpt_name.rfind('.model')])

    if args.optimiser == "sgd":
        optimizer = optim.SGD(net.parameters(), lr=args.learning_rate, momentum=args.momentum,
                              weight_decay=args.weight_decay, nesterov=args.nesterov)
    elif args.optimiser == "adam":
        optimizer = optim.Adam(net.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    elif args.optimiser == "adamw":
        optimizer = optim.AdamW(net.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay,
                                betas=(0.9, 0.999), eps=1e-8)
    else:
        # Fail here with a clear message rather than with a NameError later.
        raise ValueError(
            "Unknown optimiser '{}'; expected one of: sgd, adam, adamw".format(args.optimiser))

    # Get data loaders. Only the cifar branch provides a separate test loader.
    test_loader = None
    if args.dataset == 'tiny_imagenet':
        train_loader = dataset_loader[args.dataset].get_data_loader(
            root=args.dataset_root,
            split='train',
            batch_size=args.train_batch_size,
            pin_memory=args.gpu)

        val_loader = dataset_loader[args.dataset].get_data_loader(
            root=args.dataset_root,
            split='val',
            batch_size=args.test_batch_size,
            pin_memory=args.gpu)

        # test_loader = dataset_loader[args.dataset].get_data_loader(
        #     root=args.dataset_root,
        #     split='val',
        #     batch_size=args.test_batch_size,
        #     pin_memory=args.gpu)
    elif args.dataset == 'imagenet':
        train_loader = dataset_loader[args.dataset].get_data_loader(
            batch_size=args.train_batch_size,
            root=args.dataset_root,
            split='train', shuffle=True, num_workers=args.num_workers,
            pin_memory=args.gpu)

        val_loader = dataset_loader[args.dataset].get_data_loader(
            batch_size=args.test_batch_size,
            root=args.dataset_root,
            split='val', shuffle=False, num_workers=args.num_workers,
            pin_memory=args.gpu)

    else:  # cifar 10/100
        train_loader, val_loader, _ = dataset_loader[args.dataset].get_train_valid_loader(
            batch_size=args.train_batch_size,
            augment=args.data_aug,
            random_seed=1,
            pin_memory=args.gpu)

        test_loader = dataset_loader[args.dataset].get_test_loader(
            batch_size=args.test_batch_size,
            pin_memory=args.gpu)

    if args.use_corruption:
        # Wrap the training set so images are corrupted on the fly, then
        # rebuild the training loader around the wrapped dataset.
        from Data.corrupted_dataset import CorruptedDataset
        corrupted_dataset = CorruptedDataset(
            train_loader.dataset,
            corruption_prob=args.corruption_prob,
            corruption_types=args.corruption_types,
            severity_range=(args.corruption_severity_min, args.corruption_severity_max)
        )
        train_loader = torch.utils.data.DataLoader(
            corrupted_dataset,
            batch_size=args.train_batch_size,
            shuffle=True,
            num_workers=train_loader.num_workers,
            pin_memory=args.gpu
        )

    # Get scheduler
    scheduler = get_scheduler(optimizer, args)

    # Per-epoch history, keyed by epoch index.
    training_set_loss = {}
    val_set_loss = {}
    test_set_loss = {}
    val_set_acc = {}

    # When resuming, advance the scheduler to the epoch we restart from.
    for epoch in range(0, start_epoch):
        scheduler.step()

    # Create the checkpoint directory up front so a missing folder is reported
    # before training starts, not when the final model is written.
    save_dir = os.path.join(args.save_loc, args.dataset)
    os.makedirs(save_dir, exist_ok=True)

    # Metric criteria do not depend on the epoch: build them once.
    nll_criterion = nn.CrossEntropyLoss().to(device)
    ece_criterion = ECELoss().to(device)
    adaece_criterion = AdaptiveECELoss().to(device)
    cece_criterion = ClasswiseECELoss().to(device)

    best_val_acc = 0
    best_epoch = None
    best_net = None
    # Gradient scaling only applies on CUDA; disable it explicitly otherwise.
    scaler = torch.cuda.amp.GradScaler(enabled=cuda)
    for epoch in range(start_epoch, num_epochs):
        # Pick the focal-loss gamma for this epoch (three-stage schedule when
        # --gamma-schedule is 1, constant otherwise).
        if (args.loss_function == 'focal_loss' and args.gamma_schedule == 1):
            if (epoch < args.gamma_schedule_step1):
                gamma = args.gamma
            elif (epoch >= args.gamma_schedule_step1 and epoch < args.gamma_schedule_step2):
                gamma = args.gamma2
            else:
                gamma = args.gamma3
        else:
            gamma = args.gamma

        train_loss = train_single_epoch(epoch, net, train_loader, optimizer, device,
                                        loss_function=args.loss_function, gamma=gamma, lamda=args.lamda,
                                        loss_mean=args.loss_mean, scaler=scaler)
        val_loss = test_single_epoch(epoch, net, val_loader, device,
                                     loss_function=args.loss_function, gamma=gamma, lamda=args.lamda)

        if test_loader is not None:
            test_loss = test_single_epoch(epoch, net, test_loader, device,
                                          loss_function=args.loss_function, gamma=gamma, lamda=args.lamda)
        else:
            # tiny_imagenet / imagenet have no separate test loader here.
            test_loss = val_loss

        _, val_acc, _, _, _ = test_classification_net(net, val_loader, device)

        scheduler.step()

        # Calibration metrics and NLL on the validation logits.
        logits, labels = get_logits_labels(val_loader, net)
        p_ece = ece_criterion(logits, labels).item()
        p_adaece = adaece_criterion(logits, labels).item()
        p_cece = cece_criterion(logits, labels).item()
        p_nll = nll_criterion(logits, labels).item()

        training_set_loss[epoch] = train_loss
        val_set_loss[epoch] = val_loss
        test_set_loss[epoch] = test_loss
        val_set_acc[epoch] = val_acc

        # Read after scheduler.step(), so this is the lr the next epoch uses.
        current_lr = optimizer.param_groups[0]['lr']
        swanlab.log({"train_loss": train_loss, "val_loss": val_loss, "test_loss": test_loss, "val_acc": val_acc,
                     "p_ece": p_ece, "p_adaece": p_adaece, "p_cece": p_cece, "p_nll": p_nll, "lr": current_lr}, step=epoch)
        print(f'Epoch: {epoch} Train Loss: {train_loss:.6f} Val Loss: {val_loss:.6f} Test Loss: {test_loss:.6f} Val Acc: {val_acc:.4f} ECE: {p_ece:.4f} AdaECE: {p_adaece:.4f} CECE: {p_cece:.4f} NLL: {p_nll:.4f} LR: {current_lr:.6f}')

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # print(f'New best Acc: {best_val_acc:.4f}')
            best_epoch = epoch
            # state_dict() returns references to the live parameters; copy
            # them so later training steps do not overwrite the snapshot.
            best_net = {name: tensor.detach().cpu().clone()
                        for name, tensor in net.state_dict().items()}

        # if (epoch + 1) % args.save_interval == 0:
        if (epoch + 1) == num_epochs:
            # Only the weights of the final epoch are written to disk.
            save_name = os.path.join(
                save_dir,
                args.model_name + '_' +
                loss_function_save_name(args.loss_function, args.gamma_schedule, gamma,
                                        args.gamma, args.gamma2, args.gamma3, args.lamda) +
                '_' + str(epoch + 1) + '.model')
            torch.save(net.state_dict(), save_name)

    # with open(save_name[:save_name.rfind('_')] + '_train_loss.json', 'a') as f:
    #     json.dump(training_set_loss, f)

    # with open(save_name[:save_name.rfind('_')] + '_val_loss.json', 'a') as fv:
    #     json.dump(val_set_loss, fv)

    # with open(save_name[:save_name.rfind('_')] + '_test_loss.json', 'a') as ft:
    #     json.dump(test_set_loss, ft)

    # with open(save_name[:save_name.rfind('_')] + '_val_error.json', 'a') as ft:
    #     json.dump(val_set_acc, ft)

    # save_name = args.save_loc + args.dataset + '/' +  args.model_name + '_' + \
    #             loss_function_save_name(args.loss_function, args.gamma_schedule, gamma, args.gamma, args.gamma2, args.gamma3, args.lamda) + \
    #             '_best_' + str(best_epoch + 1) + '.model'
    # torch.save(best_net, save_name)

    swanlab.finish()