from torch import optim
import torch
import torch.utils.data
import argparse
import torch.backends.cudnn as cudnn
import random
import json
import sys
import swanlab
# Import dataloaders
import Data.cifar10_faster as cifar10  
import Data.cifar100_faster as cifar100
# import Data.tiny_imagenet as tiny_imagenet

# Import network models
from Net.resnet import resnet50, resnet110
# from Net.resnet_tiny_imagenet import resnet50 as resnet50_ti
from Net.wide_resnet import wide_resnet_cifar
from Net.densenet import densenet121

# Import train and validation utilities
from train_gard_utils import train_single_epoch, test_single_epoch

# Import validation metrics
from Metrics.metrics import test_classification_net
from evaluate import get_logits_labels
import torch.nn as nn
from Metrics.metrics import ECELoss, AdaptiveECELoss, ClasswiseECELoss


"""The training file of our FGR method"""

# SwanLab logging mode passed to swanlab.init(mode=...); the values this
# script switches between are 'cloud' and 'disabled'.
saveSwanLab = 'cloud'

# Output-class count for every dataset name the script recognises.
dataset_num_classes = dict(cifar10=10, cifar100=100, tiny_imagenet=200, imagenet=1000)

# Data-loader module for each dataset that can actually be trained here.
# (tiny_imagenet is absent because its import is commented out above.)
dataset_loader = dict(cifar10=cifar10, cifar100=cifar100)

# Network constructors keyed by the --model value; each is called with
# num_classes=... by the training script. The tiny-imagenet ResNet-50
# variant (resnet50_ti) is disabled together with its import.
models = dict(
    resnet50=resnet50,
    resnet110=resnet110,
    wide_resnet=wide_resnet_cifar,
    densenet121=densenet121,
)

def loss_function_save_name(args, gamma=1.0, gamma1=1.0, gamma2=1.0, gamma3=1.0, lamda=1.0):
    """Build the experiment / checkpoint name describing the training loss.

    Args:
        args: Parsed argument namespace. Reads ``loss_function``, ``remark``,
            ``use_corruption``, ``corruption_prob``, ``few_shot``,
            ``samples_per_class``, ``sample_ratio`` and ``load``; the gamma
            schedule is read from ``gamma_schedule`` (as defined by
            ``parseArgs``) or, if present, a boolean ``scheduled`` attribute.
        gamma: Focal gamma embedded in focal-style loss names.
        gamma1, gamma2, gamma3: The three gammas of a scheduled focal loss.
        lamda: Regularisation factor embedded in MMCE loss names.

    Returns:
        The name string: loss part + ``args.remark`` + optional corruption /
        few-shot suffixes, followed by ``"_noload"`` unless ``args.load``.

    Raises:
        KeyError: if ``args.loss_function`` is not a known loss name.
    """
    res_dict = {
        'cross_entropy': 'cross_entropy',
        'focal_loss': 'focal_loss_gamma_' + str(gamma),
        'focal_loss_adaptive': 'focal_loss_adaptive_gamma_' + str(gamma),
        'mmce': 'mmce_lamda_' + str(lamda),
        'mmce_weighted': 'mmce_weighted_lamda_' + str(lamda),
        'brier_score': 'brier_score',
        'dual_focal_loss': 'dual_focal_loss_gamma_' + str(gamma),
        'bsce_gra': 'bsce_gra_' + str(gamma),
    }

    # parseArgs() defines --gamma-schedule (0/1) but no `scheduled` attribute;
    # reading `args.scheduled` directly raised AttributeError for every
    # focal_loss run. Accept either spelling.
    scheduled = getattr(args, 'scheduled', False) or getattr(args, 'gamma_schedule', 0) == 1
    if args.loss_function == 'focal_loss' and scheduled:
        res_str = 'focal_loss_scheduled_gamma_' + str(gamma1) + '_' + str(gamma2) + '_' + str(gamma3)
    else:
        res_str = res_dict[args.loss_function]
    experiment_name = res_str + args.remark

    if args.use_corruption:
        experiment_name += f'_corrupted_{args.corruption_prob}'
    if args.few_shot:
        # samples_per_class takes precedence over sample_ratio when both are set.
        if args.samples_per_class is not None:
            experiment_name += f'_fewshot_{args.samples_per_class}per'
        elif args.sample_ratio is not None:
            experiment_name += f'_fewshot_{args.sample_ratio}ratio'

    return experiment_name if args.load else experiment_name + "_noload"


def parseArgs(argv=None):
    """Define and parse the command-line options for calibration training.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]`` exactly as before; passing a list lets
            tests and other scripts drive the parser directly.

    Returns:
        argparse.Namespace with one attribute per option below (see each
        ``dest``).
    """
    default_dataset = 'cifar10'
    dataset_root = "./dataset/tiny_imagenet/"
    train_batch_size = 128
    test_batch_size = 128
    learning_rate = 0.1
    momentum = 0.9
    optimiser = "sgd"
    loss = "cross_entropy"
    gamma = 1.0
    gamma2 = 1.0
    gamma3 = 1.0
    lamda = 1.0
    weight_decay = 5e-4
    log_interval = 50
    save_interval = 50
    save_loc = './checkpoints/'
    model_name = None
    saved_model_name = "resnet50_cross_entropy_350.model"
    load_loc = './'
    model = "resnet50"
    epoch = 350

    first_milestone = 150  # Milestone for change in lr
    second_milestone = 250  # Milestone for change in lr
    gamma_schedule_step1 = 100
    gamma_schedule_step2 = 250

    parser = argparse.ArgumentParser(description="Training for calibration.", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--seed", type=int, default=1, dest="seed", help='Random seed for reproducibility')
    parser.add_argument("--dataset", type=str, default=default_dataset, dest="dataset", help='dataset to train on')
    parser.add_argument("--dataset-root", type=str, default=dataset_root, dest="dataset_root", help='root path of the dataset (for tiny imagenet)')

    # --data-aug / -g default to True, so on their own they can never turn the
    # feature off; the --no-* flags provide the missing way to disable them.
    # set_defaults() runs after both flags so each shares the True default.
    parser.add_argument("--data-aug", action="store_true", dest="data_aug")
    parser.add_argument("--no-data-aug", action="store_false", dest="data_aug", help="Disable training-time data augmentation")
    parser.set_defaults(data_aug=True)
    parser.add_argument("-g", action="store_true", dest="gpu", help="Use GPU")
    parser.add_argument("--no-gpu", action="store_false", dest="gpu", help="Run on CPU even if CUDA is available (sets gpu to False)")
    parser.set_defaults(gpu=True)

    parser.add_argument("--load", action="store_true", default=False, dest="load", help="Load from pretrained model")
    parser.add_argument("-b", type=int, default=train_batch_size, dest="train_batch_size", help="Batch size")
    parser.add_argument("-tb", type=int, default=test_batch_size, dest="test_batch_size", help="Test Batch size")
    parser.add_argument("-e", type=int, default=epoch, dest="epoch", help='Number of training epochs')
    parser.add_argument("--lr", type=float, default=learning_rate, dest="learning_rate", help='Learning rate')
    parser.add_argument("--mom", type=float, default=momentum, dest="momentum", help='Momentum')
    parser.add_argument("--nesterov", action="store_true", dest="nesterov", help="Whether to use nesterov momentum in SGD")
    parser.set_defaults(nesterov=False)
    parser.add_argument("--decay", type=float, default=weight_decay, dest="weight_decay", help="Weight Decay")
    parser.add_argument("--opt", type=str, default=optimiser, dest="optimiser", help='Choice of optimisation algorithm')

    # Loss and focal-gamma schedule.
    parser.add_argument("--loss", type=str, default=loss, dest="loss_function", help="Loss function to be used for training")
    parser.add_argument("--loss-mean", action="store_true", dest="loss_mean", help="whether to take mean of loss instead of sum to train")
    parser.set_defaults(loss_mean=False)
    parser.add_argument("--gamma", type=float, default=gamma, dest="gamma", help="Gamma for focal components")
    parser.add_argument("--gamma2", type=float, default=gamma2, dest="gamma2", help="Gamma for different focal components")
    parser.add_argument("--gamma3", type=float, default=gamma3, dest="gamma3", help="Gamma for different focal components")
    parser.add_argument("--lamda", type=float, default=lamda, dest="lamda", help="Regularization factor")
    parser.add_argument("--gamma-schedule", type=int, default=0, dest="gamma_schedule", help="Schedule gamma or not")
    parser.add_argument("--gamma-schedule-step1", type=int, default=gamma_schedule_step1, dest="gamma_schedule_step1", help="1st step for gamma schedule")
    parser.add_argument("--gamma-schedule-step2", type=int, default=gamma_schedule_step2, dest="gamma_schedule_step2", help="2nd step for gamma schedule")

    # Logging, saving and loading.
    parser.add_argument("--log-interval", type=int, default=log_interval, dest="log_interval", help="Log Interval on Terminal")
    parser.add_argument("--save-interval", type=int, default=save_interval, dest="save_interval", help="Save Interval on Terminal")
    parser.add_argument("--saved_model_name", type=str, default=saved_model_name, dest="saved_model_name", help="file name of the pre-trained model")
    parser.add_argument("--save-path", type=str, default=save_loc, dest="save_loc", help='Path to export the model')
    # store_false: passing --save turns saving OFF (saving is on by default).
    parser.add_argument("--save", action="store_false", default=True, dest="save_model", help='Pass this flag to DISABLE saving the final model')
    parser.add_argument("--model-name", type=str, default=model_name, dest="model_name", help='name of the model')
    parser.add_argument("--load-path", type=str, default=load_loc, dest="load_loc", help='Path to load the model from')

    # Architecture and LR schedule.
    parser.add_argument("--model", type=str, default=model, dest="model", help='Model to train')
    parser.add_argument("--first-milestone", type=int, default=first_milestone, dest="first_milestone", help="First milestone to change lr")
    parser.add_argument("--second-milestone", type=int, default=second_milestone, dest="second_milestone", help="Second milestone to change lr")

    parser.add_argument("--freeze", action="store_true", default=False, dest="freeze", help='freeze backbone')
    parser.add_argument("--start-corrupt", type=int, default=0, dest="start_corrupt", help="Epoch from which corrupted training data is used (with --use-corruption)")
    parser.add_argument("--remark", type=str, default="", dest="remark", help='remark for the experiment')

    # Calibration loss monitoring.
    parser.add_argument("--cal-loss", type=str, default="ce_soft_ece", dest="cal_loss", help='Calibration loss function to monitor')
    parser.add_argument("--cal-gamma", type=float, default=1.0, dest="cal_gamma", help='Gamma for calibration loss')
    parser.add_argument("--cal-lamda", type=float, default=1.0, dest="cal_lamda", help='Lambda for calibration loss')

    # Corrupted-image training.
    parser.add_argument("--use-corruption", action="store_true", default=False, dest="use_corruption", help='use corrupted images for training')
    parser.add_argument("--corruption-prob", type=float, default=0.3, dest="corruption_prob", help='probability of applying corruption to images')
    parser.add_argument("--corruption-types", type=str, nargs='+', default=None, dest="corruption_types", help='specific corruption types to use (space-separated)')
    parser.add_argument("--corruption-severity-min", type=int, default=1, dest="corruption_severity_min", help='minimum corruption severity')
    parser.add_argument("--corruption-severity-max", type=int, default=3, dest="corruption_severity_max", help='maximum corruption severity')
    parser.add_argument("--corradd", action="store_true", default=False, dest="corradd", help='use corradd for corruption')

    # Few-shot sub-sampling.
    parser.add_argument("--few-shot", action="store_true", default=False, dest="few_shot", help='use few-shot learning')
    parser.add_argument("--samples-per-class", type=int, default=None, dest="samples_per_class", help='number of samples per class for few-shot learning')
    parser.add_argument("--sample-ratio", type=float, default=None, dest="sample_ratio", help='ratio of samples to keep per class (0-1)')
    parser.add_argument("--few-shot-seed", type=int, default=42, dest="few_shot_seed", help='random seed for few-shot sampling')

    # Gradient correction.
    parser.add_argument("--use-grad-correction", action="store_true", default=False, dest="use_gradient_correction", help='use gradient correction with calibration loss')
    parser.add_argument("--gradient-func", type=str, default='softece', dest="gradient_func", help='function to use for gradient correction, options: softece, softavuc, cece')
    parser.add_argument("--correction-alpha", type=float, default=1.0, dest="correction_alpha", help='weight for calibration loss in gradient correction')
    parser.add_argument("--correction-beta", type=float, default=0.5, dest="correction_beta", help='weight for cooperative enhancement in gradient correction')

    return parser.parse_args(argv)


if __name__ == "__main__":
    # Script flow: parse args -> seed -> resolve device -> build (optionally
    # load / freeze) the model -> optimiser + LR schedule -> data loaders
    # (optionally corrupted) -> evaluate a loaded model -> train / validate /
    # test loop with calibration metrics -> save the final weights.
    import os

    args = parseArgs()
    torch.manual_seed(args.seed)
    # `random` is imported at module level but was never seeded; seed it too
    # so Python-level randomness also follows --seed.
    random.seed(args.seed)

    # Fail fast (before a SwanLab run is created) on names this script cannot
    # build, e.g. tiny_imagenet, whose loader import is commented out.
    if args.dataset not in dataset_loader:
        raise ValueError(f"Unsupported dataset '{args.dataset}'; choose from {sorted(dataset_loader)}")
    if args.model not in models:
        raise ValueError(f"Unsupported model '{args.model}'; choose from {sorted(models)}")

    # Use CUDA only when it is both requested and actually available.
    cuda = torch.cuda.is_available() and args.gpu
    device = torch.device("cuda" if cuda else "cpu")

    experiment_name = args.model + '_' + loss_function_save_name(args, args.gamma, args.gamma, args.gamma2, args.gamma3, args.lamda)

    swanlab.init(
        project=args.dataset + '_baselines',
        experiment_name=experiment_name,
        config=args, mode=saveSwanLab
    )
    num_classes = dataset_num_classes[args.dataset]

    # Choosing the model to train
    net = models[args.model](num_classes=num_classes)

    # Setting model name
    if args.model_name is None:
        args.model_name = args.model

    # Previously `if args.gpu is True: net.cuda()`, which crashed on CPU-only
    # machines because --gpu defaults to True; follow the resolved device.
    net.to(device)
    if cuda:
        cudnn.benchmark = True

    start_epoch = 0
    num_epochs = args.epoch
    if args.load:
        # NOTE(review): the checkpoint path is built from --save-path
        # (args.save_loc); --load-path (args.load_loc) is parsed but never used.
        # Confirm which one is intended.
        checkpoint_path = args.save_loc + args.dataset + '/' + args.saved_model_name
        # map_location lets a checkpoint saved on GPU be loaded on CPU.
        net.load_state_dict(torch.load(checkpoint_path, map_location=device))
        # Epoch index base_epoch is spent evaluating the loaded model (loop
        # below); training then runs from base_epoch + 1.
        start_epoch = 1
        base_epoch = 100 if args.dataset == 'tiny_imagenet' else 350
    else:
        base_epoch = 0

    if args.freeze:
        # Train only the final `fc` layer; everything else is frozen.
        for param in net.parameters():
            param.requires_grad = False
        if hasattr(net, 'fc'):
            for param in net.fc.parameters():
                param.requires_grad = True
        else:
            raise ValueError("Model does not have a fc layer to train.")

    trainable_params = [p for p in net.parameters() if p.requires_grad]
    if args.optimiser == "sgd":
        optimizer = optim.SGD(trainable_params, lr=args.learning_rate, momentum=args.momentum,
                              weight_decay=args.weight_decay, nesterov=args.nesterov)
    elif args.optimiser == "adam":
        optimizer = optim.Adam(trainable_params, lr=args.learning_rate, weight_decay=args.weight_decay)
    else:
        # Previously this fell through and crashed later with a NameError.
        raise ValueError(f"Unsupported optimiser '{args.optimiser}'; choose 'sgd' or 'adam'")

    # LR is multiplied by 0.1 at each milestone (counted in scheduler.step()
    # calls, i.e. training epochs of this run).
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[args.first_milestone, args.second_milestone], gamma=0.1)

    loader_module = dataset_loader[args.dataset]
    if args.dataset == 'tiny_imagenet':
        train_loader = loader_module.get_data_loader(
            root=args.dataset_root,
            split='train',
            batch_size=args.train_batch_size,
            pin_memory=cuda)

        val_loader = loader_module.get_data_loader(
            root=args.dataset_root,
            split='val',
            batch_size=args.test_batch_size,
            pin_memory=cuda)

        test_loader = loader_module.get_data_loader(
            root=args.dataset_root,
            split='test',
            batch_size=args.test_batch_size,
            pin_memory=cuda)

        final_train_dataset = train_loader.dataset
    else:  # cifar 10/100
        # random_seed=1 fixes the train/val split independently of --seed.
        train_loader, val_loader, train_idx = loader_module.get_train_valid_loader(
            batch_size=args.train_batch_size,
            augment=args.data_aug,
            random_seed=1,
            pin_memory=cuda)

        test_loader = loader_module.get_test_loader(
            batch_size=args.test_batch_size,
            pin_memory=cuda)

        # Keep only the training-split indices returned by the loader helper.
        final_train_dataset = torch.utils.data.Subset(train_loader.dataset, train_idx)

    # NOTE(review): --few-shot / --samples-per-class / --sample-ratio only
    # affect the experiment name here; no sub-sampling is applied in this
    # script. Confirm whether that is intended.
    if args.use_corruption:
        from Data.corrupted_dataset import CorruptedDataset, AugmentedCorruptedDataset
        corruption_cls = AugmentedCorruptedDataset if args.corradd else CorruptedDataset
        corrupted_dataset = corruption_cls(
            final_train_dataset,
            corruption_prob=args.corruption_prob, corruption_types=args.corruption_types,
            severity_range=(args.corruption_severity_min, args.corruption_severity_max),
        )

        corrupted_train_loader = torch.utils.data.DataLoader(
            corrupted_dataset,
            batch_size=args.train_batch_size,
            shuffle=True,
            num_workers=train_loader.num_workers,
            pin_memory=cuda
        )
    elif args.few_shot or args.dataset != 'tiny_imagenet':
        # Rebuild a shuffled loader over exactly `final_train_dataset`.
        train_loader = torch.utils.data.DataLoader(
            final_train_dataset,
            batch_size=args.train_batch_size,
            shuffle=True,
            num_workers=train_loader.num_workers,
            pin_memory=cuda
        )

    training_set_loss = {}
    val_set_loss = {}
    test_set_loss = {}
    val_set_acc = {}
    # Hard-coded .cuda() here crashed on CPU-only runs; follow `device`.
    nll_criterion = nn.CrossEntropyLoss().to(device)
    ece_criterion = ECELoss().to(device)
    adaece_criterion = AdaptiveECELoss().to(device)
    cece_criterion = ClasswiseECELoss().to(device)

    def _calibration_metrics(loader):
        """Return (ECE, AdaECE, classwise ECE, NLL) of `net` on `loader` as floats."""
        logits, labels = get_logits_labels(loader, net)
        return (ece_criterion(logits, labels).item(),
                adaece_criterion(logits, labels).item(),
                cece_criterion(logits, labels).item(),
                nll_criterion(logits, labels).item())

    # Runs once (only when a checkpoint was loaded) to report the loaded model.
    for epoch in range(base_epoch, start_epoch + base_epoch):
        p_ece, p_adaece, p_cece, p_nll = _calibration_metrics(val_loader)
        _, val_acc, _, _, _ = test_classification_net(net, val_loader, device)
        print(f'Epoch: {epoch} val_acc: {val_acc:.4f}   ECE: {p_ece:.4f} AdaECE: {p_adaece:.4f} CECE: {p_cece:.4f} NLL: {p_nll:.4f}')

        p_ece, p_adaece, p_cece, p_nll = _calibration_metrics(test_loader)
        _, test_acc, _, _, _ = test_classification_net(net, test_loader, device)
        # Label fixed: this line reports test accuracy, not validation accuracy.
        print(f'Epoch: {epoch} test_acc: {test_acc:.4f}  ECE: {p_ece:.4f} AdaECE: {p_adaece:.4f} CECE: {p_cece:.4f} NLL: {p_nll:.4f}')

    best_val_acc = 0
    # Explicitly disabled on CPU (enabling it without CUDA is unsupported).
    scaler = torch.cuda.amp.GradScaler(enabled=cuda)
    for epoch in range(start_epoch + base_epoch, num_epochs + base_epoch):
        corrupt_now = args.use_corruption and epoch >= args.start_corrupt
        if corrupt_now and hasattr(corrupted_train_loader.dataset, 'reset_corruption_selection'):
            corrupted_train_loader.dataset.reset_corruption_selection()
        if args.corradd and corrupt_now and hasattr(corrupted_train_loader.dataset, 'regenerate_corruptions'):
            corrupted_train_loader.dataset.regenerate_corruptions()

        # Piecewise-constant focal gamma: gamma -> gamma2 -> gamma3 at the two steps.
        if args.loss_function == 'focal_loss' and args.gamma_schedule == 1:
            if epoch < args.gamma_schedule_step1:
                gamma = args.gamma
            elif epoch < args.gamma_schedule_step2:
                gamma = args.gamma2
            else:
                gamma = args.gamma3
        else:
            gamma = args.gamma

        if corrupt_now:
            train_loss = train_single_epoch(epoch, net, corrupted_train_loader, optimizer, device,
                                    loss_function=args.loss_function, gamma=gamma, lamda=args.lamda, loss_mean=args.loss_mean,
                                    use_gradient_correction=args.use_gradient_correction,
                                    gradient_func=args.gradient_func, num_classes=num_classes, scaler=scaler)
        else:
            # NOTE(review): gradient correction is only forwarded in the corrupted
            # branch above, and --cal-*/--correction-alpha/-beta are never passed
            # to train_single_epoch. Confirm this is intended.
            train_loss = train_single_epoch(epoch, net, train_loader, optimizer, device,
                                    loss_function=args.loss_function, gamma=gamma, lamda=args.lamda, loss_mean=args.loss_mean,
                                    num_classes=num_classes, scaler=scaler)
        val_loss = test_single_epoch(epoch, net, val_loader, device,
                                    loss_function=args.loss_function, gamma=gamma, lamda=args.lamda, num_classes=num_classes)
        test_loss = test_single_epoch(epoch, net, test_loader, device,
                                    loss_function=args.loss_function, gamma=gamma, lamda=args.lamda, num_classes=num_classes)
        _, val_acc, _, _, _ = test_classification_net(net, val_loader, device)
        _, test_acc, _, _, _ = test_classification_net(net, test_loader, device)

        scheduler.step()

        p_ece, p_adaece, p_cece, p_nll = _calibration_metrics(val_loader)

        training_set_loss[epoch] = train_loss
        val_set_loss[epoch] = val_loss
        test_set_loss[epoch] = test_loss
        val_set_acc[epoch] = val_acc
        swanlab.log({"train_loss": train_loss, "val_loss": val_loss, "test_loss": test_loss, "val_acc": val_acc,
                     "p_ece": p_ece, "p_adaece": p_adaece, "p_cece": p_cece, "p_nll": p_nll, "test_acc": test_acc}, step=epoch)
        print(f'Epoch: {epoch} Train Loss: {train_loss:.6f} Val  Loss: {val_loss:.6f} Val  Acc: {val_acc:.4f} ECE: {p_ece:.4f} AdaECE: {p_adaece:.4f} CECE: {p_cece:.4f} NLL: {p_nll:.4f}')

        p_ece, p_adaece, p_cece, p_nll = _calibration_metrics(test_loader)
        print(f'Epoch: {epoch} Train Loss: {train_loss:.6f} Test Loss: {test_loss:.6f} Test Acc: {test_acc:.4f} ECE: {p_ece:.4f} AdaECE: {p_adaece:.4f} CECE: {p_cece:.4f} NLL: {p_nll:.4f}')

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch
            # state_dict() returns references to the live parameters, so the
            # old `best_net = net.state_dict()` silently tracked the latest
            # weights; snapshot a detached CPU copy instead.
            best_net = {k: v.detach().cpu().clone() for k, v in net.state_dict().items()}

        if (epoch + 1) == num_epochs + base_epoch:
            save_dir = args.save_loc + args.dataset + '_gard/'
            save_name = save_dir + args.model_name + '_' + \
                        loss_function_save_name(args, gamma, args.gamma, args.gamma2, args.gamma3, args.lamda) + \
                        '_' + str(epoch + 1) + '.model'
            if args.save_model:
                # torch.save does not create directories; without this a
                # finished run could be lost at the very last step.
                os.makedirs(save_dir, exist_ok=True)
                torch.save(net.state_dict(), save_name)

    swanlab.finish()