# general imports
import gc
import numpy as np
import seaborn as sns
from tqdm.auto import tqdm
import random
import torch
import torchvision
import os
import pickle
import sys
import argparse
from torchvision import transforms, datasets
from torch.utils.data.dataset import random_split
from sklearn.model_selection import train_test_split
import pandas as pd
#from my_utils import SavePlot, Plot_Calscores
import cv2
import numpy as np
import glob 
# from glob import glob

import matplotlib.pyplot as plt
import matplotlib.lines as mlines
from matplotlib.ticker import FormatStrFormatter

from math import ceil
from PIL.Image import BICUBIC
from PIL import Image
from torchvision.datasets.cifar import CIFAR100, CIFAR10
from torchvision.datasets import EMNIST
from torchvision.transforms import Compose, RandomCrop, Pad, RandomHorizontalFlip, Resize, RandomAffine
from torchvision.transforms import ToTensor, Normalize

from torch.utils.data import Subset,Dataset, Sampler

import torchvision.utils as vutils
import random
from torch.utils.data import DataLoader

# My imports
sys.path.insert(0, './')
import ICP.Score_Functions as scores

from ICP.utils_others import *
#from ICP.my_utils import *
#from ICP.clustering_utils import *
from ICP.conformal_utils2 import *

from collections import Counter
from scipy import stats, cluster

import models
from dataset.cifar10 import load_cifar10
from dataset.cifar100 import load_cifar100
from dataset.Eurosat import load_Eurosat, EurosatDataset
from dataset.EMNIST import load_Emnist
from train.imbalance_mini import IMBALANEMINIIMGNET
from train.imbalance_food import IMBALANCEFOOD
#from dataset.tiny_imagenet import load_Tiny

from ICP.generate_score import *

import matplotlib.pylab as pylab
# Global matplotlib defaults: every text element uses the same named size
# and figures default to 10x7.
_XX_LARGE = 'xx-large'
params = {
    'legend.fontsize': _XX_LARGE,
    'figure.figsize': (10, 7),
    'axes.labelsize': _XX_LARGE,
    'axes.titlesize': _XX_LARGE,
    'xtick.labelsize': _XX_LARGE,
    'ytick.labelsize': _XX_LARGE,
}
pylab.rcParams.update(params)


# Collect the lowercase, non-dunder, callable attributes of the `models`
# package; the list is used below as the valid choices for --arch.
model_names = sorted(
    attr_name
    for attr_name, attr_value in models.__dict__.items()
    if attr_name.islower()
    and not attr_name.startswith("__")
    and callable(attr_value)
)

print(model_names, 'model_names')

    
# parameters
# Command-line interface. Option names, dests and effective defaults are
# unchanged; only help texts were corrected and typed options now carry
# typed (instead of string) defaults.
parser = argparse.ArgumentParser(description='Experiments')
parser.add_argument('-a', '--alpha', default=0.1, type=float, help='Desired nominal marginal coverage')
parser.add_argument('-s', '--splits', default=10, type=int, help='Number of experiments to estimate coverage')

parser.add_argument('--coverage_on_label', action='store_true', help='True for getting coverage and size for each label')
parser.add_argument('--dataset', default='cifar10', help='dataset setting')
parser.add_argument('-ar', '--arch', metavar='ARCH', default='resnet32',
                    choices=model_names,
                    help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet32)')
parser.add_argument('--loss_type', default="CE", type=str, help='loss type')
parser.add_argument('--imb_type', default="exp", type=str, help='imbalance type')
parser.add_argument('--rho', default=0.01, type=float, help='imbalance factor')
parser.add_argument('--train_rule', default='None', type=str, help='data sampling strategy for train loader')
parser.add_argument('--rand_number', default=0, type=int, help='fix random number for data sampling')
parser.add_argument('--exp_str', default='0', type=str, help='number to indicate which experiment it is')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=10, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch_size', default=128, type=int,
                    metavar='N',
                    help='mini-batch size')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
# The help text previously advertised 1e-4 although the default is 2e-4.
parser.add_argument('--wd', '--weight-decay', default=2e-4, type=float,
                    metavar='W', help='weight decay (default: 2e-4)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
# NOTE: '--seeds' (single int) and '-seeds' (list of ints, declared below)
# both store into args.seeds; whichever appears last on the command line wins.
parser.add_argument('--seeds', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')

parser.add_argument('-score_functions', type=str, nargs='+',
                    help='Conformal score functions to use. List with a space in between. Options are '
                    '"softmax", "APS", "RAPS"')
parser.add_argument('-methods', type=str, nargs='+',
                    help='Conformal methods to use. List with a space in between. Options include '
                    '"MCP", "CCP", "k-CCP", "always_cluster"')
parser.add_argument('-seeds', type=int, nargs='+',
                    help='Seeds for random splits into calibration and validation sets. '
                    'List with spaces in between')
parser.add_argument('-avg_num_per_class', type=int,
                    help='Number of examples per class, on average, to include in calibration dataset')
parser.add_argument('--calibration_sampling', type=str, default='random',
                    help='How to sample the calibration set. Options are "random" and "balanced"')
parser.add_argument('--bins', type=int, default=10,
                    help='Histogram range to plot')
parser.add_argument('--t_gap', type=float, default=0.9,
                    help='Concentration gap of truncated')
parser.add_argument('--c_gap', type=float, default=0.9,
                    help='Concentration gap of classwise')

parser.add_argument('--lmbda_val', type=float, default=0.01,
                    help='lmbda val for RAPS')
parser.add_argument('--k_reg', type=int, default=5,
                    help='k_reg value for RAPS')

args = parser.parse_args()
print(f"args = {args}")
# parameters
# Core experiment parameters taken from the command line.
alpha = args.alpha  # desired nominal marginal coverage
n_experiments = args.splits  # number of experiments to estimate coverage

dataset = args.dataset  # dataset to be used, e.g. 'CIFAR100', 'CIFAR10', 'ImageNet'
calibration_scores = ['SC', 'HCC', 'SC_Reg']  # score function to check 'HCC', 'SC', 'SC_Reg'
# NOTE(review): hard-coded to True; the --coverage_on_label flag is not consulted here.
coverage_on_label = True  # Whether to calculate coverage and size per class

# number of test points (if larger than available it takes the entire set)
_N_TEST_BY_DATASET = {
    'ImageNet': 50000,
    'eurosat': 5400,
    'EMNIST': 20800,
    'tiny': 100000,
}
n_test = _N_TEST_BY_DATASET.get(dataset, 10000)  # any other dataset: 10000

# Validate parameters with explicit raises: `assert` statements are removed
# when Python runs with -O, which would silently skip these checks.
if not 0 <= alpha <= 1:
    raise ValueError('Nominal level must be between 0 to 1')
if not (isinstance(n_experiments, int) and n_experiments >= 1):
    raise ValueError('number of splits must be a positive integer.')


# The GPU used for our experiments can only handle the following quantities of images per batch
GPU_CAPACITY = args.batch_size

# Pick the compute device. --gpu defaults to None and
# torch.cuda.set_device(None) raises, so only select a GPU when an id was
# given, and fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
    device = torch.cuda.current_device()
else:
    device = torch.device("cpu")
print(f"device = {device}")

# set random seed (fixed to 0 for every random number generator in use)
seed = 0
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def SavePlotHistgram(args, path, x, random, nbins, results=None):
    """Save a per-method histogram of column ``x`` as a PDF.

    Args:
        args: Parsed command-line namespace. Reads ``alpha``, ``dataset``,
            ``imb_type``, ``rho``, ``t_gap``, ``c_gap`` and ``bins``.
        path: Only its directory component is used; the PDF is written into
            that directory (created if missing).
        x: Column of ``results`` to plot. Must be
            ``'Class Conditional Coverage'`` or ``'Prediction Set Size'``.
        random: Seed value embedded in the output file name. (The name
            shadows the ``random`` module inside this function.)
        nbins: Forwarded to ``seaborn.histplot`` as ``bins``.
        results: Column-indexable table (presumably a pandas DataFrame)
            holding the ``x`` column and a ``'Method'`` column used as hue.

    Raises:
        ValueError: If ``results`` is None or ``x`` is not a supported column.
    """
    # Supported columns and the tag each one contributes to the file name.
    file_tags = {
        'Class Conditional Coverage': 'CovgHist',
        'Prediction Set Size': 'SizeHist',
    }

    # Validate before touching the filesystem or the plotting state.
    if results is None:
        raise ValueError("results must be provided")
    if x not in file_tags:
        raise ValueError(
            f"Unsupported column {x!r}; expected one of {sorted(file_tags)}"
        )

    directory = os.path.dirname(path)
    if directory:  # os.makedirs('') would raise for a bare file name
        os.makedirs(directory, exist_ok=True)

    sns.set_style("whitegrid")
    sns.set(font_scale=1)

    legend_fontsize = 48
    y_tick_fontsize = 40
    x_label_fontsize = 48
    y_label_fontsize = 40

    # The x-axis is clipped to the observed range of the plotted column.
    x_min = results[x].min()
    x_max = results[x].max()

    ax = sns.histplot(data=results, x=x, hue='Method', bins=nbins, element='step',
                      common_norm=False, kde=True, line_kws={"linewidth": 5}, legend=True)

    if x == 'Class Conditional Coverage':
        sns.move_legend(ax, "center left")
        # Dashed red reference line at the nominal coverage level 1 - alpha.
        plt.axvline(x=1 - args.alpha, linestyle='--', color='red', label='1 - alpha', linewidth=5)

    legend = ax.get_legend()
    plt.setp(legend.get_texts(), fontsize=legend_fontsize)  # for legend text
    plt.setp(legend.get_title(), fontsize=legend_fontsize)

    ax.set_xlabel(x, fontsize=x_label_fontsize)
    ax.set_ylabel('Frequency', fontsize=y_label_fontsize)

    ax.set_yticklabels(ax.get_yticks(), size=y_tick_fontsize)
    plt.xticks([])  # x tick marks/labels are hidden

    filename = (
        f"{args.dataset}_{args.imb_type}_{args.rho}_{file_tags[x]}_seed_{random}"
        f"_tgap_{args.t_gap}_cgap_{args.c_gap}_bins_{args.bins}.pdf"
    )

    # Adjust x-axis range
    plt.tight_layout()
    plt.xlim(x_min, x_max)
    plt.savefig(os.path.join(directory, filename))
    plt.close()

def SavePlot_Q_Histgram(args, path, x, random, nbins, mq, results=None):
    """Save a histogram of column ``x`` with a marker at ``mq`` as a PDF.

    Args:
        args: Parsed command-line namespace. Reads ``dataset``, ``imb_type``,
            ``rho``, ``c_gap`` and ``bins`` for the file name.
        path: Only its directory component is used; the PDF is written into
            that directory (created if missing).
        x: Column of ``results`` to plot. The axis range and label are always
            taken from / set to ``'Class Quantile'`` regardless of ``x``.
        random: Seed value embedded in the output file name. (The name
            shadows the ``random`` module inside this function.)
        nbins: Forwarded to ``seaborn.histplot`` as ``bins``.
        mq: x position of the dashed red line labelled 'Marginal Q'.
        results: Column-indexable table (presumably a pandas DataFrame) with
            a ``'Class Quantile'`` column.

    Raises:
        ValueError: If ``results`` is None.
    """
    # Validate before touching the filesystem or the plotting state.
    if results is None:
        raise ValueError("results must be provided")

    directory = os.path.dirname(path)
    if directory:  # os.makedirs('') would raise for a bare file name
        os.makedirs(directory, exist_ok=True)

    sns.set_style("whitegrid")
    sns.set(font_scale=1)

    legend_fontsize = 48
    y_tick_fontsize = 40
    x_label_fontsize = 48
    y_label_fontsize = 40

    # The x-axis is clipped to the observed range of the quantile column.
    q_min = results['Class Quantile'].min()
    q_max = results['Class Quantile'].max()

    ax = sns.histplot(data=results, x=x, bins=nbins, element='step',
                      common_norm=False, kde=True, line_kws={"linewidth": 5})

    # Dashed red reference line at the value passed as `mq`.
    plt.axvline(x=mq, linestyle='--', color='red', label='Marginal Q', linewidth=5)

    # Display the legend
    plt.legend(fontsize=legend_fontsize)

    ax.set_yticklabels(ax.get_yticks(), size=y_tick_fontsize)
    plt.xticks([])  # x tick marks/labels are hidden

    ax.set_xlabel('Class Quantile', fontsize=x_label_fontsize)
    ax.set_ylabel('Frequency', fontsize=y_label_fontsize)

    filename = f"{args.dataset}_{args.imb_type}_{args.rho}_QuanHist_seed_{random}_cgap_{args.c_gap}_bins_{args.bins}.pdf"

    # Adjust x-axis range
    plt.tight_layout()
    plt.xlim(q_min, q_max)
    plt.savefig(os.path.join(directory, filename))
    plt.close()

def SavePlot_S_Histgram(args, path, x, random, nbins, results=None):
    """Save a histogram of column ``x`` with a marker at 1.0 as a PDF.

    Args:
        args: Parsed command-line namespace. Reads ``dataset``, ``imb_type``,
            ``rho``, ``c_gap`` and ``bins`` for the file name.
        path: Only its directory component is used; the PDF is written into
            that directory (created if missing).
        x: Column of ``results`` to plot; the axis is labelled 'Sigma_y'.
        random: Seed value embedded in the output file name. (The name
            shadows the ``random`` module inside this function.)
        nbins: Forwarded to ``seaborn.histplot`` as ``bins``.
        results: Table passed to ``seaborn.histplot`` as ``data``
            (presumably a pandas DataFrame).

    Raises:
        ValueError: If ``results`` is None.
    """
    # Validate before touching the filesystem or the plotting state.
    if results is None:
        raise ValueError("results must be provided")

    directory = os.path.dirname(path)
    if directory:  # os.makedirs('') would raise for a bare file name
        os.makedirs(directory, exist_ok=True)

    sns.set_style("whitegrid")
    sns.set(font_scale=1)

    legend_fontsize = 48
    y_tick_fontsize = 40
    x_label_fontsize = 48
    y_label_fontsize = 40

    ax = sns.histplot(data=results, x=x, bins=nbins, element='step',
                      common_norm=False, kde=True, line_kws={"linewidth": 5})

    # Dashed red reference line at x = 1.0.
    plt.axvline(x=1.00, linestyle='--', color='red', label='1.0', linewidth=5)

    # Display the legend
    plt.legend(fontsize=legend_fontsize)

    ax.set_xlabel('Sigma_y', fontsize=x_label_fontsize)
    ax.set_ylabel('Frequency', fontsize=y_label_fontsize)

    ax.set_yticklabels(ax.get_yticks(), size=y_tick_fontsize)
    plt.xticks([])  # x tick marks/labels are hidden

    filename = f"{args.dataset}_{args.imb_type}_{args.rho}_SigHist_seed_{random}_cgap_{args.c_gap}_bins_{args.bins}.pdf"

    # The x-axis range is left at matplotlib's automatic limits here.
    plt.tight_layout()
    plt.savefig(os.path.join(directory, filename))
    plt.close()

def Save_bar_plot(args, path, x, random, results=None):
    """Save a grouped bar plot of 'Prediction Set Size' per method as a PDF.

    Args:
        args: Parsed command-line namespace. Reads ``dataset``, ``imb_type``,
            ``rho`` and ``score_functions`` for the file name.
        path: Only its directory component is used; the PDF is written into
            that directory (created if missing).
        x: Column of ``results`` placed on the x-axis (axis label is 'Group').
        random: Seed value embedded in the output file name. (The name
            shadows the ``random`` module inside this function.)
        results: Table passed to ``seaborn.barplot`` as ``data``; must provide
            the ``x``, ``'Prediction Set Size'`` and ``'Method'`` columns.

    Raises:
        ValueError: If ``results`` is None.
    """
    # Validate before touching the filesystem or the plotting state.
    if results is None:
        raise ValueError("results must be provided")

    directory = os.path.dirname(path)
    if directory:  # os.makedirs('') would raise for a bare file name
        os.makedirs(directory, exist_ok=True)

    sns.set_style("whitegrid")
    sns.set(font_scale=1)

    y_tick_fontsize = 40
    x_tick_fontsize = 32
    x_label_fontsize = 40
    y_label_fontsize = 40

    ax = sns.barplot(x=x, y='Prediction Set Size', hue='Method', data=results)

    # The figure is saved without a legend; guard against no legend existing.
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()

    ax.set_xlabel('Group', fontsize=x_label_fontsize)
    ax.set_ylabel('Average size', fontsize=y_label_fontsize)

    ax.set_yticklabels(ax.get_yticks(), size=y_tick_fontsize)
    ax.tick_params(axis='x', labelsize=x_tick_fontsize)

    # args.score_functions is formatted as-is (a list prints with brackets).
    filename = f"{args.dataset}_{args.imb_type}_{args.rho}_score_{args.score_functions}_tail_seed_{random}.pdf"

    plt.tight_layout()
    plt.savefig(os.path.join(directory, filename), bbox_inches='tight')
    plt.close()

# Directory suffix that encodes the training configuration.
# NOTE: the backslash at the end of the first line sits INSIDE the string
# literal, so the leading spaces of the continuation line become part of the
# path (between "batch-size=..." and "/lr=..."). Presumably this matches the
# layout already written to disk by the training code -- do not "clean up"
# the literal without migrating existing directories. TODO confirm.
base_path = "dataset={}/architecture={}/loss_type={}/imb_type={}/imb_factor={}/train_rule={}/epochs={}/batch-size={}\
        /lr={}/momentum={}/".format(args.dataset, args.arch, args.loss_type, args.imb_type, args.rho, args.train_rule,\
             args.epochs, args.batch_size, args.lr, args.momentum)
# Absolute path of the best checkpoint under <cwd>/checkpoint/<base_path>/...
# It is only built here; nothing in this section reads it.
patha_model = '/checkpoint/' + base_path + "{}_{}_{}_{}_{}_{}_{}".format(args.dataset, args.arch, args.loss_type, \
    args.train_rule, args.imb_type, args.rho, args.rand_number) + '/ckpt.best.pth.tar'
patha_model = os.getcwd() + patha_model
#print(patha_model)

# Root folder of the stored results. There is no separator between 'new' and
# base_path, so the first component is literally 'newdataset=<name>'.
patha = './Results/new' + base_path 
if not os.path.exists(patha):
    os.makedirs(patha)

# Output folder for the generated figures.
figures_path = './figure/new/'

if not os.path.exists(figures_path):
    os.makedirs(figures_path)

# Folder of the run whose results are plotted. args.score_functions is a list
# (nargs='+'), so it is first formatted with brackets and quotes.
result_folder = os.path.join(patha, f'{args.calibration_sampling}_calset/n_totalcal={args.avg_num_per_class}/score={args.score_functions}/lmbda={args.lmbda_val}_kreg = {args.k_reg}')
# Strip the list formatting around the score name: "score=['X']" becomes
# "score=X/". This only fully cleans the text for a single score function.
result_folder = result_folder.replace("score=['", "score=").replace("']", "/")
# print(result_folder)

# Results file of the run to plot (the seed in the file name is fixed to 9).
file_path = os.path.join(result_folder, 'seed=9_allresults.pkl')

# Load the results from the file. Fail immediately when it is missing:
# continuing would only crash later with a NameError on `all_results`.
if not os.path.exists(file_path):
    raise FileNotFoundError(f"File {file_path} does not exist.")

# NOTE: unpickling can execute arbitrary code; only load result files that
# were produced by trusted runs of this project.
with open(file_path, 'rb') as f:
    all_results = pickle.load(f)

# Long-format table: one row per (method, group) pair.
data_lists = {
    'Prediction Set Size': [],
    'Group': [],
    'Method': []
}

# Display name of each group paired with its key in the tail-metrics dict.
_group_metric_keys = (
    ('Majority', 'majority_avg_set_size'),
    ('Medium', 'medium_avg_set_size'),
    ('Minority', 'minority_avg_set_size'),
)

# Each stored result keeps its tail metrics at index 4.
for method, result in all_results.items():
    tail_metrics = result[4]
    for group_name, metric_key in _group_metric_keys:
        data_lists['Prediction Set Size'].append(tail_metrics[metric_key])
        data_lists['Group'].append(group_name)
        data_lists['Method'].append(method)

# Convert the lists into a DataFrame
results = pd.DataFrame(data_lists)

# Desired order of the methods.
method_order = ['MCP', 'CCP', 'cluster_CP', 'k-CCP']

# Re-type the 'Method' column as an ordered categorical over that list.
results['Method'] = pd.Categorical(
    results['Method'],
    categories=method_order,
    ordered=True,
)

# Only the directory part of this path is used by Save_bar_plot.
_bar_plot_path = os.path.join(
    figures_path,
    f'{args.arch}_rho_{args.rho}_loss_{args.loss_type}_score_{args.score_functions}_epoch_{args.epochs}',
)
Save_bar_plot(args, _bar_plot_path, x='Group', random=args.seeds, results=results)

