import os
import sys
import torch
import random
import argparse
from torch import nn
import matplotlib.pyplot as plt
import torch.backends.cudnn as cudnn
# Import dataloaders
import Data.cifar10_faster as cifar10
import Data.cifar100_faster as cifar100
import Data.cifar10_c as cifar10_c
import Data.cifar100_c as cifar100_c
import Data.tiny_imagenet_c as tiny_imagenet_c
import Data.tiny_imagenet as tiny_imagenet
import Data.imagenet as imagenet
import Data.imagenet_c as imagenet_c

# Import network architectures
from Net.resnet_tiny_imagenet import resnet50 as resnet50_ti
from Net.resmet_imagenet import resnet50_imagenet1k as resnet50_imagenet
from Net.resnet import resnet50, resnet110
from Net.wide_resnet import wide_resnet_cifar
from Net.densenet import densenet121
from Net.vit import vit_base_patch16_224, vit_small_patch16_224

# Import metrics to compute
from Metrics.metrics import test_classification_net_logits
from Metrics.metrics import ECELoss, AdaptiveECELoss, ClasswiseECELoss
# Import temperature scaling and NLL utilities
from temperature_scaling import ModelWithTemperature
from evaluate_scripts.change2args import update_args_from_shell
from Metrics.corruption_test import corruption_test


# Number of target classes for each supported dataset name
dataset_num_classes = dict(
    cifar10=10,
    cifar100=100,
    tiny_imagenet=200,
    imagenet=1000,
)

# Dataset name -> imported module that builds the data loaders for it
dataset_loader = dict(
    cifar10=cifar10,
    cifar100=cifar100,
    tiny_imagenet=tiny_imagenet,
    imagenet=imagenet,
)

# Model name (as given on the command line) -> factory that builds the network
models = dict(
    resnet50=resnet50,
    resnet50_ti=resnet50_ti,
    resnet50_imagenet=resnet50_imagenet,
    resnet110=resnet110,
    wide_resnet=wide_resnet_cifar,
    densenet121=densenet121,
    vit_base=vit_base_patch16_224,
    vit_small=vit_small_patch16_224,
)

def is_vit_model(model_name):
    """Return True when ``model_name`` begins with the ``vit_`` prefix."""
    vit_prefix = 'vit_'
    return model_name.startswith(vit_prefix)

def parseArgs(argv=None):
    """Build the evaluation script's argument parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse reads ``sys.argv[1:]``, so existing callers
            that use ``parseArgs()`` are unaffected.

    Returns:
        argparse.Namespace holding the parsed options.
    """

    def _str2bool(value):
        """Convert a command-line token such as 'true' or '0' to a bool."""
        if isinstance(value, bool):
            return value
        token = value.strip().lower()
        if token in ('true', 't', 'yes', 'y', '1'):
            return True
        if token in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(
        description="Evaluating a single model on calibration metrics.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Dataset / model / checkpoint selection
    parser.add_argument("--dataset", type=str, default='cifar10', dest="dataset",
                        help='dataset to test on')
    parser.add_argument("--dataset-root", type=str, default='./', dest="dataset_root",
                        help='root path of the dataset (for tiny imagenet)')
    parser.add_argument("--model-name", type=str, default=None, dest="model_name",
                        help='name of the model')
    parser.add_argument("--model", type=str, default='resnet50', dest="model",
                        help='Model to test')
    parser.add_argument("--save-path", type=str, default='./checkpoints/', dest="save_loc",
                        help='Path to import the model')
    parser.add_argument("--saved_model_name", type=str,
                        default='resnet50_cross_entropy_350.model', dest="saved_model_name",
                        help="file name of the pre-trained model")
    parser.add_argument("--num-bins", type=int, default=15, dest="num_bins",
                        help='Number of bins')

    # NOTE: -g, -da and -log are store_true flags whose default is already
    # True, so passing them changes nothing and they cannot be turned off
    # from the command line. Kept as-is for backward compatibility.
    parser.add_argument("-g", action="store_true", default=True, dest="gpu",
                        help="Use GPU")
    parser.add_argument("-da", action="store_true", default=True, dest="data_aug",
                        help="Using data augmentation")
    parser.add_argument("-b", type=int, default=512, dest="train_batch_size",
                        help="Batch size")
    parser.add_argument("-tb", type=int, default=512, dest="test_batch_size",
                        help="Test Batch size")
    parser.add_argument("--cverror", type=str, default='ece', dest="cross_validation_error",
                        help='Error function to do temp scaling')
    parser.add_argument("-log", action="store_true", default=True, dest="log",
                        help="whether to print log data")

    # Previously this option had no type, so "-prompted False" stored the
    # truthy string "False". Values are now parsed as booleans, and a bare
    # "-prompted" (no value) means True.
    parser.add_argument("-prompted", type=_str2bool, nargs='?', const=True, default=False,
                        dest="prompted", help="Whether to use prompted model")

    parser.add_argument("--use-pretrained", action="store_true", default=False,
                        dest="use_pretrained",
                        help="Use pretrained model directly without loading checkpoint")

    parser.add_argument("--freeze", action="store_true", default=False, dest="freeze",
                        help='freeze backbone')
    parser.add_argument("--freeze-fc", action="store_true", default=False, dest="freeze_fc",
                        help='freeze fc layer')
    parser.add_argument("--freeze-prompt", action="store_true", default=False,
                        dest="freeze_prompt", help='')
    parser.add_argument("--pixel-prompt", action="store_true", default=False,
                        dest="pixel_prompt", help='use pixel prompt')
    parser.add_argument("--dynamic-pixel", action="store_true", default=False,
                        dest="dynamic_pixel", help='use dynamic pixel prompt')
    parser.add_argument("--pixel-size", type=int, default=4, dest="pixel_size",
                        help='size of pixel prompt')
    parser.add_argument("--remark", type=str, default="", dest="remark",
                        help='remark for the experiment')

    parser.add_argument("--gard", action="store_true", default=False, dest="gard", help='')

    # ViT specific arguments
    parser.add_argument("--vit-input-size", type=int, default=224, dest="vit_input_size",
                        help="Input image size for ViT models")
    parser.add_argument("--use-prompt", action="store_true", default=False, dest="use_prompt",
                        help="Use prompt tuning for ViT models")
    parser.add_argument("--prompt-tokens", type=int, default=10, dest="prompt_tokens",
                        help="Number of prompt tokens")
    parser.add_argument("--prompt-deep", action="store_true", default=False,
                        dest="prompt_deep", help="Use deep prompting")

    # Gaussian Data Refinement arguments
    parser.add_argument("--gaussian-refinement", action="store_true", default=False,
                        dest="gaussian_refinement",
                        help="Apply Gaussian Data Refinement for robust temperature scaling")
    parser.add_argument("--gaussian-eps", type=float, nargs='+', default=[0.1, 0.2, 0.3],
                        dest="gaussian_eps",
                        help="Gaussian noise levels for data refinement (e.g., --gaussian-eps 0.1 0.2 0.3)")
    parser.add_argument("--adaptive-gaussian", action="store_true", default=False,
                        dest="adaptive_gaussian",
                        help="Use adaptive Gaussian Data Refinement to automatically optimize epsilon values")
    parser.add_argument("--adaptive-levels", type=int, default=6, dest="adaptive_levels",
                        help="Number of refinement levels for adaptive Gaussian refinement")

    return parser.parse_args(argv)

def get_logits_labels(data_loader, net):
    """Run ``net`` over ``data_loader`` and collect every logit and label.

    Args:
        data_loader: Iterable of batches. The first two items of each batch
            are taken as ``(data, label)``; any further items are ignored.
        net: Network to evaluate. It is switched to eval mode and called
            under ``torch.no_grad()``. When ``net.architecture == 'CNN'`` the
            forward pass is expected to return a pair whose first element is
            the logits; otherwise the return value is used as the logits.

    Returns:
        Tuple ``(logits, labels)`` with the per-batch tensors concatenated
        and moved to the CUDA device when one is available, else the CPU.
        ``torch.cat`` raises if the loader yields no batches.
    """
    # Use CUDA whenever it is available (the original behaviour) and fall
    # back to the CPU instead of crashing on hosts without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    returns_pair = getattr(net, 'architecture', None) == 'CNN'

    logits_list = []
    labels_list = []
    net.eval()
    with torch.no_grad():
        for batch_data in data_loader:
            # Batches may carry extra items after (data, label); drop them.
            data, label, *_ = batch_data
            data = data.to(device)
            if returns_pair:
                logits, _ = net(data)
            else:
                logits = net(data)
            logits_list.append(logits)
            labels_list.append(label)
        logits = torch.cat(logits_list).to(device)
        labels = torch.cat(labels_list).to(device)
    return logits, labels

# Module-level metric criteria shared by getMetrics(). Each one is built once
# with its default constructor arguments and then has .cuda() called on it.
# NOTE(review): --num-bins is parsed by parseArgs() but is not passed to these
# constructors -- confirm the metric defaults are the intended bin count.
nll_criterion = nn.CrossEntropyLoss().cuda()
ece_criterion = ECELoss().cuda()
adaece_criterion = AdaptiveECELoss().cuda()
cece_criterion = ClasswiseECELoss().cuda()

def getMetrics(logits, labels):
    """Evaluate the module-level criteria on ``logits`` and ``labels``.

    Returns:
        Tuple ``(ece, adaece, cece, nll)`` of Python scalars, each obtained
        by calling ``.item()`` on the corresponding criterion's output.
    """
    criteria = (ece_criterion, adaece_criterion, cece_criterion, nll_criterion)
    ece, adaece, cece, nll = (
        criterion(logits, labels).item() for criterion in criteria
    )
    return ece, adaece, cece, nll


if __name__ == "__main__":
    # Script flow: parse arguments, build the loaders and the network, report
    # test metrics for the raw model, fit temperature scaling on the
    # validation loader, report test metrics again, then run corruption_test.

    # Seed PyTorch's RNG for reproducibility.
    torch.manual_seed(1)

    args = parseArgs()

    # --model-name falls back to --model when it is not given.
    if args.model_name is None:
        args.model_name = args.model

    model_name = args.model_name
    cross_validation_error = args.cross_validation_error

    num_classes = dataset_num_classes[args.dataset]

    # Only ViT models are given an explicit input size; others receive None.
    input_size = args.vit_input_size if is_vit_model(args.model) else None

    loader_module = dataset_loader[args.dataset]
    if args.dataset in ('tiny_imagenet', 'imagenet'):
        # Both of these loader modules are called through the same
        # split-based get_data_loader() factory with identical arguments.
        val_loader = loader_module.get_data_loader(
            split='val',
            batch_size=args.test_batch_size,
            pin_memory=args.gpu,
            input_size=input_size)

        test_loader = loader_module.get_data_loader(
            split='test',
            batch_size=args.test_batch_size,
            pin_memory=args.gpu,
            input_size=input_size)
    else:
        _, val_loader, _ = loader_module.get_train_valid_loader(
            batch_size=args.train_batch_size,
            augment=args.data_aug,
            random_seed=1,
            pin_memory=args.gpu,
            input_size=input_size
        )

        test_loader = loader_module.get_test_loader(
            batch_size=args.test_batch_size,
            pin_memory=args.gpu,
            input_size=input_size
        )

    # NOTE(review): the factory is looked up by --model-name while the ViT
    # checks use --model; the two differ only when --model-name is passed
    # explicitly -- confirm that is intended.
    model = models[model_name]

    if is_vit_model(args.model) and args.use_prompt:
        net = model(num_classes=num_classes, use_prompt=True, args=args).cuda()
    else:
        net = model(num_classes=num_classes).cuda()

    if args.use_pretrained:
        print("Using pretrained model directly without loading checkpoint")
        if not is_vit_model(args.model):
            print("Warning: --use-pretrained is primarily designed for ViT models with ImageNet pretraining")
    else:
        # Checkpoints live in <save_loc>/<dataset>[_gard]/<saved_model_name>.
        # os.path.join also works when --save-path has no trailing slash.
        checkpoint_dir = (args.dataset + '_gard') if args.gard else args.dataset
        checkpoint_path = os.path.join(args.save_loc, checkpoint_dir, args.saved_model_name)
        net.load_state_dict(torch.load(checkpoint_path))

    cudnn.benchmark = True

    # Metrics of the raw (un-scaled) model on the test set.
    logits, labels = get_logits_labels(test_loader, net)
    _, p_accuracy, _, _, _ = test_classification_net_logits(logits, labels)
    p_ece, p_adaece, p_cece, p_nll = getMetrics(logits, labels)

    # Printing the required evaluation metrics
    print(f'Test Acc: {p_accuracy*100:.2f}%  Test NLL: {p_nll:.5f}   ECE: {p_ece:.5f}  AdaECE: {p_adaece:.5f}  Classwise_ECE: {p_cece:.5f}')

    scaled_model = ModelWithTemperature(net, args.log)

    # Apply Gaussian Data Refinement if requested
    if args.gaussian_refinement:
        if args.adaptive_gaussian:
            # NOTE(review): args.adaptive_levels is printed here but is not
            # forwarded to set_temperature() -- confirm whether it should be.
            print(f"Using adaptive Gaussian data refinement will optimize the {args.adaptive_levels} refinement levels")
            scaled_model.set_temperature(val_loader, cross_validate=cross_validation_error,
                                         use_gaussian_refinement=True, gaussian_eps=args.gaussian_eps,
                                         adaptive_gaussian=True, n_classes=num_classes)
        else:
            print(f"Refined using fixed Gaussian data, epsilon level: {args.gaussian_eps}")
            scaled_model.set_temperature(val_loader, cross_validate=cross_validation_error,
                                         use_gaussian_refinement=True, gaussian_eps=args.gaussian_eps)
    else:
        scaled_model.set_temperature(val_loader, cross_validate=cross_validation_error)

    # The fitted temperature is not used further in this script.
    T_opt = scaled_model.get_temperature()

    # Metrics of the temperature-scaled model on the same test set.
    logits, labels = get_logits_labels(test_loader, scaled_model)
    _, accuracy, _, _, _ = test_classification_net_logits(logits, labels)
    ece, adaece, cece, nll = getMetrics(logits, labels)

    print(f'Test Acc: {accuracy*100:.2f}%  Test NLL: {nll:.5f}   ECE: {ece:.5f}  AdaECE: {adaece:.5f}  Classwise_ECE: {cece:.5f}')

    print("----------------------------------------------------")

    # Clean-data results, in the order [accuracy, ECE, AdaECE, classwise ECE, NLL],
    # passed to the corruption test together with both models.
    clean_results = {
        'raw': [p_accuracy, p_ece, p_adaece, p_cece, p_nll],
        'ts': [accuracy, ece, adaece, cece, nll]
    }
    corruption_result = corruption_test(args, net, scaled_model, clean_results)

