# general imports
import gc
import numpy as np
import seaborn as sns
from tqdm.auto import tqdm
import random
import torch
import torchvision
import os
import pickle
import sys
import argparse
from torchvision import transforms, datasets
from torch.utils.data.dataset import random_split
from sklearn.model_selection import train_test_split
import pandas as pd
#from my_utils import SavePlot, Plot_Calscores
import cv2
import numpy as np
import glob 
# from glob import glob

import matplotlib.pyplot as plt
import matplotlib.lines as mlines
from matplotlib.ticker import FormatStrFormatter

from math import ceil
from PIL.Image import BICUBIC
from PIL import Image
from torchvision.datasets.cifar import CIFAR100, CIFAR10
from torchvision.datasets import EMNIST
from torchvision.transforms import Compose, RandomCrop, Pad, RandomHorizontalFlip, Resize, RandomAffine
from torchvision.transforms import ToTensor, Normalize

from torch.utils.data import Subset,Dataset, Sampler

import torchvision.utils as vutils
import random
from torch.utils.data import DataLoader

# My imports
sys.path.insert(0, './')
import ICP.Score_Functions as scores

from ICP.utils_others import *
#from ICP.my_utils import *
#from ICP.clustering_utils import *
from ICP.conformal_utils2 import *

from collections import Counter
from scipy import stats, cluster

import models
from dataset.cifar10 import load_cifar10
from dataset.cifar100 import load_cifar100
from dataset.Eurosat import load_Eurosat, EurosatDataset
from dataset.EMNIST import load_Emnist
from train.imbalance_mini import IMBALANEMINIIMGNET
from train.imbalance_food import IMBALANCEFOOD
#from dataset.tiny_imagenet import load_Tiny

from ICP.generate_score import *

import matplotlib.pylab as pylab
# Update the global matplotlib rcParams: larger legend / axis-label / title /
# tick fonts and a 10x7 inch default figure size.
params = dict([
    ('legend.fontsize', 'xx-large'),
    ('figure.figsize', (10, 7)),
    ('axes.labelsize', 'xx-large'),
    ('axes.titlesize', 'xx-large'),
    ('xtick.labelsize', 'xx-large'),
    ('ytick.labelsize', 'xx-large'),
])
pylab.rcParams.update(params)


# Names of all public, lowercase callables exported by the project `models`
# package, sorted alphabetically; they serve as the allowed --arch choices.
model_names = sorted(
    attr_name
    for attr_name, attr_value in models.__dict__.items()
    if attr_name.islower()
    and not attr_name.startswith("__")
    and callable(attr_value)
)

print(model_names, 'model_names')

    
# Command-line interface for the ablation plotting script.
parser = argparse.ArgumentParser(description='Experiments')
parser.add_argument('-a', '--alpha', default=0.1, type=float, help='Desired nominal marginal coverage')
parser.add_argument('-s', '--splits', default=10, type=int, help='Number of experiments to estimate coverage')

# NOTE(review): the module-level `coverage_on_label` is later hard-coded to True,
# so this flag is not read by any code visible in this file.
parser.add_argument('--coverage_on_label', action='store_true', help='True for getting coverage and size for each label')
parser.add_argument('--dataset', default='cifar10', help='dataset setting')
parser.add_argument('-ar', '--arch', metavar='ARCH', default='resnet32',
                    choices=model_names,
                    help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet32)')
parser.add_argument('--loss_type', default="CE", type=str, help='loss type')
parser.add_argument('--imb_type', default="exp", type=str, help='imbalance type')
parser.add_argument('--rho', default=0.01, type=float, help='imbalance factor')
parser.add_argument('--train_rule', default='None', type=str, help='data sampling strategy for train loader')
parser.add_argument('--rand_number', default=0, type=int, help='fix random number for data sampling')
parser.add_argument('--exp_str', default='0', type=str, help='number to indicate which experiment it is')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=10, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch_size', default=128, type=int,
                    metavar='N',
                    help='mini-batch size')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
# Help text now states the real default (2e-4); it previously claimed 1e-4.
parser.add_argument('--wd', '--weight-decay', default=2e-4, type=float,
                    metavar='W', help='weight decay (default: 2e-4)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
# NOTE(review): '--seeds' (one int) and '-seeds' (a list, below) both store into
# `args.seeds` -- argparse derives dest 'seeds' for each -- so whichever option
# is parsed last wins. Kept as-is to avoid changing the CLI.
parser.add_argument('--seeds', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')

parser.add_argument('-score_functions', type=str, nargs='+',
                    help='Conformal score functions to use. List with a space in between. '
                         'Options are "softmax", "APS", "RAPS"')
parser.add_argument('-methods', type=str, nargs='+',
                    help='Conformal methods to use. List with a space in between. '
                         'Options include "MCP", "CCP", "k-CCP", "always_cluster"')
parser.add_argument('-seeds', type=int, nargs='+',
                    help='Seeds for random splits into calibration and validation sets. '
                         'List with spaces in between')
parser.add_argument('-avg_num_per_class', type=int,
                    help='Number of examples per class, on average, to include in calibration dataset')
parser.add_argument('--calibration_sampling', type=str, default='random',
                    help='How to sample the calibration set. Options are "random" and "balanced"')
# Numeric defaults are given as numbers; argparse previously converted the
# equivalent string defaults through `type`, so the parsed values are unchanged.
parser.add_argument('--bins', type=int, default=10,
                    help='Histogram range to plot')
parser.add_argument('--mgap', type=float, default=1.0,
                    help='Maximum Concentration gap for ablation study')
parser.add_argument('--lmbda_val', type=float, default=0.01,
                    help='lmbda val for RAPS')
parser.add_argument('--k_reg', type=int, default=5,
                    help='k_reg value for RAPS')

args = parser.parse_args()
print(f"args = {args}")
# Shorthand copies of frequently used CLI parameters.
alpha = args.alpha  # desired nominal marginal coverage
n_experiments = args.splits  # number of experiments to estimate coverage

dataset = args.dataset  # dataset name, e.g. 'cifar10' (default), 'eurosat', 'EMNIST', 'tiny', 'ImageNet'
calibration_scores = ['SC', 'HCC', 'SC_Reg']  # score functions to check
# NOTE(review): hard-coded to True, so the --coverage_on_label CLI flag is ignored here.
coverage_on_label = True  # Whether to calculate coverage and size per class

# Number of test points per dataset (if larger than what is available, the
# entire set is used). Any dataset not listed falls back to 10000.
_N_TEST_BY_DATASET = {
    'ImageNet': 50000,
    'eurosat': 5400,
    'EMNIST': 20800,
    'tiny': 100000,
}
n_test = _N_TEST_BY_DATASET.get(dataset, 10000)

# Validate parameters. Explicit raises instead of `assert`, because asserts
# are stripped when Python runs with -O.
if not 0 <= alpha <= 1:
    raise ValueError('Nominal level must be between 0 to 1')
if not (isinstance(n_experiments, int) and n_experiments >= 1):
    raise ValueError('number of splits must be a positive integer.')


# The GPU used for our experiments can only handle the following quantities of images per batch
GPU_CAPACITY = args.batch_size

# Select a GPU only when one was requested: torch.cuda.set_device(None) raises
# a ValueError in PyTorch, so the default `--gpu` (None) used to crash here.
# Without CUDA, fall back to 'cpu'; nothing visible in this file uses `device`.
# Note that `device` is an int (CUDA index) on GPU machines and the string
# 'cpu' otherwise.
if args.gpu is not None:
    torch.cuda.set_device(args.gpu)
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
print(f"device = {device}")

# Seed every RNG in use and force deterministic cuDNN kernels for reproducibility.
seed = 0
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def SavePlot(args, path, x=None, y=None, col=None, hue=None, data=None, kind=None, rotation=45):
    """Draw a seaborn categorical plot of ``data`` and save it to ``path``.

    The plot shows ``y`` against ``x``, split by ``hue``, with error bars of one
    standard deviation. X tick labels are hidden and the x axis is labelled
    "g". Two metric names get special treatment:
    'Under Coverage Ratio' is labelled "UCR", with the legend upper-right and
    the y range [0, max + 0.05]; 'Average Prediction Set Size' is labelled
    "APSS", with the legend upper-left.

    Args:
        args: Parsed CLI arguments. Unused; kept for call-site compatibility.
        path: Output file path passed to ``FacetGrid.savefig``.
        x, y, hue: Column names in ``data`` forwarded to ``sns.catplot``.
        col: Unused (no column faceting is performed).
        data: DataFrame with the values to plot.
        kind: seaborn catplot kind, e.g. 'point'.
        rotation: Unused; x tick labels are removed entirely.
    """
    # `sns.set` (alias of `set_theme`) resets the style to its default unless a
    # style is passed, which silently discarded a preceding
    # set_style("whitegrid"). Set style and font scale in one call.
    sns.set(style="whitegrid", font_scale=1.5)
    legend_fontsize = 40
    tick_fontsize = 40
    label_fontsize = 40

    # Create the catplot and get the FacetGrid object.
    # NOTE: newer seaborn releases deprecate `ci` in favour of `errorbar`.
    g = sns.catplot(data=data, x=x, y=y, hue=hue, kind=kind, ci='sd', height=4.5, aspect=2.2,
                    palette=['black', 'darkblue', 'red'], legend_out=False)

    for ax in g.axes.flat:
        ax.tick_params(labelsize=tick_fontsize)
        # Hide x ticks on this facet; plt.xticks would only affect the current axes.
        ax.set_xticks([])
        ax.set_xlabel('g', fontsize=label_fontsize)

        if y == 'Under Coverage Ratio':
            ax.set_ylabel('UCR', fontsize=label_fontsize)
            ax.legend(loc='upper right', fontsize=legend_fontsize)
            max_ucr = data[y].max()
            # An empty or all-NaN column gives NaN, which set_ylim rejects.
            if pd.notna(max_ucr):
                ax.set_ylim(0, max_ucr + 0.05)

        if y == 'Average Prediction Set Size':
            ax.set_ylabel('APSS', fontsize=label_fontsize)
            ax.legend(loc='upper left', fontsize=legend_fontsize)

    # Save the plot and release every open figure.
    plt.tight_layout()
    g.savefig(path, dpi=100, bbox_inches='tight', pad_inches=0.1)
    plt.close('all')

# Directory fragment encoding the training configuration.
# NOTE(review): the backslash continuation sits *inside* the string literal, so
# the leading spaces of the following line become part of the path
# ("batch-size=<N><spaces>/lr=..."). It is kept verbatim because results and
# checkpoints on disk were presumably written with the same string -- confirm
# against the writer script before changing it.
base_path = "dataset={}/architecture={}/loss_type={}/imb_type={}/imb_factor={}/train_rule={}/epochs={}/batch-size={}\
        /lr={}/momentum={}/".format(args.dataset, args.arch, args.loss_type, args.imb_type, args.rho, args.train_rule,\
             args.epochs, args.batch_size, args.lr, args.momentum)
# Absolute path to the best checkpoint for this configuration (not used by the
# code visible in this file).
patha_model = '/checkpoint/' + base_path + "{}_{}_{}_{}_{}_{}_{}".format(args.dataset, args.arch, args.loss_type, \
    args.train_rule, args.imb_type, args.rho, args.rand_number) + '/ckpt.best.pth.tar'
patha_model = os.getcwd() + patha_model

# Root of the saved ablation results for this configuration.
patha = './Results/ablation/' + base_path
os.makedirs(patha, exist_ok=True)

# Output directory for the generated figures.
figures_path = './figure/ablation_new/'
os.makedirs(figures_path, exist_ok=True)

# Score functions requested on the command line, e.g. ['RAPS'].
score_func = args.score_functions
if not score_func:
    # `-score_functions` has no default; without it the membership test below
    # would raise TypeError on None.
    sys.exit('Please pass at least one score function via -score_functions.')

# The folder name embeds the list repr of --score_functions and then strips the
# "['" / "']" brackets, so ['RAPS'] becomes "score=RAPS/". This only yields a
# clean name for a single score function. RAPS results sit in an extra
# "ablation" sub-folder under a slightly different file name; both names are
# kept verbatim, presumably to match what the writer script saved.
if 'RAPS' in score_func:
    result_folder = os.path.join(patha, f'{args.calibration_sampling}_calset/n_totalcal={args.avg_num_per_class}/score={args.score_functions}/ablation')
    result_folder = result_folder.replace("score=['", "score=").replace("']", "/")
    print(result_folder)
    file_path = os.path.join(result_folder, 'g=0.75_ablationresults.pkl')
else:
    result_folder = os.path.join(patha, f'{args.calibration_sampling}_calset/n_totalcal={args.avg_num_per_class}/score={args.score_functions}')
    result_folder = result_folder.replace("score=['", "score=").replace("']", "/")
    file_path = os.path.join(result_folder, 'g=0.75_ablation_results.pkl')

# Without the results there is nothing to plot: an empty frame would only fail
# later inside the plotting code with a far less helpful error.
if not os.path.exists(file_path):
    print(f"File {file_path} does not exist.")
    sys.exit(1)

# Loaded object is expected to be a nested mapping {gap: {method: result}}.
# NOTE: pickle.load can execute arbitrary code; only load files from your own runs.
with open(file_path, 'rb') as f:
    all_results = pickle.load(f)

# Flatten the nested results into one DataFrame row per (gap, method).
# result[3] is read as a dict with 'average set size' and 'undercov ratio'
# entries; its full layout is defined by whatever produced the pickle.
data_lists = {'Average Prediction Set Size': [], 'Under Coverage Ratio': [], 'Gap': [], 'Method': []}
for gap, methods in all_results.items():
    for method, result in methods.items():
        summary = result[3]
        data_lists['Average Prediction Set Size'].append(summary['average set size'])
        data_lists['Under Coverage Ratio'].append(summary['undercov ratio'])
        data_lists['Gap'].append(gap)
        data_lists['Method'].append(method)

results = pd.DataFrame(data_lists)

# Hyphenated display name for the legend.
results['Method'] = results['Method'].replace('cluster_CP', 'cluster-CP')

# Fix the hue / legend order. Methods not listed here become NaN in the
# categorical column.
method_order = ['CCP', 'cluster-CP', 'k-CCP']
results['Method'] = pd.Categorical(results['Method'], categories=method_order, ordered=True)

# One PDF per metric; the file-name stem encodes the run configuration.
figure_stem = (f'{args.dataset}_rho_{args.rho}_loss_{args.loss_type}'
               f'_type_{args.imb_type}_score_{args.score_functions}')
for metric, metric_tag in (('Under Coverage Ratio', 'UCR'),
                           ('Average Prediction Set Size', 'APSS')):
    SavePlot(args, os.path.join(figures_path, f'{figure_stem}_{metric_tag}_ablation.pdf'),
             x='Gap', y=metric, hue='Method', data=results, kind='point', rotation=90)
           

