import argparse
import random

import torch.backends.cudnn as cudnn
from torch import optim
from torchcp.classification.predictors import SplitPredictor
from torchcp.classification.utils.metrics import Metrics
from common.scores import APS, THR, SAPS, RAPS, Margin
from common.scores import *
from torch.utils.data import TensorDataset

from lib.post_process import *
from lib.utils import *
from common.utils import add_dict_to_csv
from dataset import get_num_classes

# Shared metric factory; called as metrics("<name>")(...) throughout the file.
metrics = Metrics()

from torchcp.classification.predictors import SplitPredictor, ClusteredPredictor, ClassWisePredictor

# Conformal predictor classes, keyed by the value of --predictor.
predictors = dict(
    Split=SplitPredictor,
    ClassWise=ClassWisePredictor,
    Cluster=ClusteredPredictor,
)

# Non-conformity score classes, keyed by the value of --score.
score_functions = dict(
    APS=APS,
    THR=THR,
    SAPS=SAPS,
    RAPS=RAPS,
    Margin=Margin,
)


class experiment:
    """Repeated conformal-prediction evaluation of one model on one dataset.

    Every trial (``trial_split``) re-splits the cached test-set logits into a
    calibration and a validation part, builds the configured predictor and
    score function (optionally tuning the score's hyper-parameter on a held-out
    slice of the calibration data), optionally truncates the validation
    prediction sets, and computes coverage/size metrics. ``run`` repeats this
    ``num_trials`` times and appends one CSV row per trial.
    """

    def __init__(self, args, csv_path) -> None:
        """Store the configuration and load cached logits, features and model.

        Args:
            args: parsed CLI namespace; reads ``dataset``, ``model``, ``trials``,
                ``alpha``, ``conf_cal``, ``score``, ``predictor``,
                ``truncation_name``, ``hyperpar`` and ``tuning_criterion``.
            csv_path: CSV file that per-trial results are appended to.

        Raises:
            KeyError: if ``args.predictor`` / ``args.score`` is not a key of the
                module-level ``predictors`` / ``score_functions`` tables.
        """
        self.dataset_name = args.dataset
        self.model_name = args.model
        self.num_trials = args.trials
        self.alpha = args.alpha
        self.conf_cal = args.conf_cal
        self.score_function_name = args.score
        self.predictor_name = args.predictor
        self.truncation_name = args.truncation_name
        self.hyperpar_tuning = args.hyperpar
        self.tuning_criterion = args.tuning_criterion
        self.csv_path = csv_path
        print(
            f'Model: {self.model_name} | Desired coverage: {1 - self.alpha} | Predictor: {self.predictor_name}| Score: {self.score_function_name}| Calibration: {self.conf_cal}| Truncated: {self.truncation_name}')

        self.choosed_predictor = predictors[self.predictor_name]
        self.score_function = score_functions[self.score_function_name]

        ### Data Loading
        # ``num_calsses`` (sic) is used throughout as the number of classes; the
        # misspelt attribute name is kept so external readers keep working.
        self.dataset, self.num_calsses = get_logits_dataset(self.model_name, self.dataset_name, "test")
        self.features = torch.from_numpy(get_features_dataset(self.model_name, self.dataset_name))
        self.model = build_common_model(self.model_name, self.dataset_name)

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def run(self, n_data_conf, n_data_val, pct_paramtune, bsz):
        """Run ``self.num_trials`` seeded trials, logging each to CSV and stdout.

        Trial ``i`` uses seed ``i``. Each trial's metrics are appended to
        ``self.csv_path`` via ``add_dict_to_csv``; running means over the trials
        so far are printed after every trial.

        Args:
            n_data_conf: number of calibration examples per trial.
            n_data_val: number of validation examples per trial.
            pct_paramtune: fraction of calibration data held out for tuning.
            bsz: batch size for the internal DataLoaders.
        """
        top1s = np.zeros((self.num_trials,))
        top5s = np.zeros((self.num_trials,))
        coverages = np.zeros((self.num_trials,))
        sizes = np.zeros((self.num_trials,))
        sscvs = np.zeros((self.num_trials,))
        wscs = np.zeros((self.num_trials,))
        covgaps = np.zeros((self.num_trials,))

        for i in tqdm(range(self.num_trials)):
            self.seed = i
            self._fix_randomness(self.seed)
            (top1_avg, top5_avg, cvg_avg, sz_avg, sscv, kstar, prec_kstar,
             wsc, covgap) = self.trial_split(n_data_conf, n_data_val, pct_paramtune, bsz)
            top1s[i] = top1_avg
            top5s[i] = top5_avg
            coverages[i] = cvg_avg
            sizes[i] = sz_avg
            sscvs[i] = sscv
            wscs[i] = wsc
            covgaps[i] = covgap

            res_dict = {
                "Model": self.model_name,
                "Alpha": self.alpha,
                "Predictor": self.predictor_name,
                "Score": self.score_function_name,
                "Trans": self.truncation_name,
                "Seed": self.seed,
                "Top1": top1_avg.item(),
                "Top5": top5_avg.item(),
                "Coverage": cvg_avg,
                "Size": sz_avg,
                "SSCV": sscv,
                "WSC": wsc,
                "CovGap": covgap,
            }
            add_dict_to_csv(self.csv_path, res_dict)

            # "\033[F" moves the cursor up so the running summary overwrites itself.
            print(
                f'\n\tTop1: {np.mean(top1s[0:i + 1]):.3f}, Top5: {np.mean(top5s[0:i + 1]):.3f}, kstar: {kstar, round(prec_kstar.item(), 2)},Coverage: {np.mean(coverages[0:i + 1]):.3f}, Size: {np.mean(sizes[0:i + 1]):.3f}, SSCV: {np.mean(sscvs[0:i + 1]):.3f} WSC: {np.mean(wscs[0:i + 1]):.3f} CovGap: {np.mean(covgaps[0:i + 1]):.3f} \033[F',
                end='')
        print('')

    def trial_split(self, n_data_conf, n_data_val, pct_paramtune, bsz):
        """Run one calibration/validation trial and return its metrics.

        Returns:
            ``(prec1, prec5, coverage_rate, average_size, sscv, kstar,
            prec_kstar, wsc, covgap)`` measured on the validation split.

        Raises:
            NotImplementedError: unknown tuning criterion or score function.
            ValueError: SAPS with ``conf_cal == "TS"`` but tuning disabled.
        """
        if self.tuning_criterion == "size":
            def tuning_metric_function(prediction_sets, labels):
                return metrics("average_size")(prediction_sets, labels)
        elif self.tuning_criterion == "sscv":
            def tuning_metric_function(prediction_sets, labels):
                return metrics("SSCV")(prediction_sets, labels, self.alpha)
        else:
            raise NotImplementedError
        alpha = self.alpha
        cal_dataset, val_dataset, cal_indices, val_indices = split2(self.dataset, n_data_conf, n_data_val)

        # Materialise (logits, labels) tensors; temperature 1 is presumably an
        # identity transform here — TODO confirm against OptimalTeamperatureScaling.
        transformation = OptimalTeamperatureScaling(1)
        cal_loader = torch.utils.data.DataLoader(cal_dataset, batch_size=bsz, shuffle=False, pin_memory=True)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=bsz, shuffle=False, pin_memory=True)
        cal_logits, cal_labels = postHocLogits(transformation, cal_loader, self.device, self.num_calsses)
        val_logits, val_labels = postHocLogits(transformation, val_loader, self.device, self.num_calsses)

        kstar = self._kstar(cal_logits, cal_labels, alpha)
        prec_kstar, = accuracy(cal_logits, cal_labels, topk=(kstar,))

        self.logits_trans = torch.softmax

        if self.score_function_name in ["APS", "THR", "Margin"]:
            predictor = self.choosed_predictor(self.score_function(transform=self.logits_trans))
        else:
            if self.hyperpar_tuning:
                # Hold out a random pct_paramtune slice of the calibration set for
                # tuning; kstar is recomputed on that slice.
                pc_indices = int(cal_logits.size(0) * pct_paramtune)
                indices = torch.randperm(cal_logits.size(0))
                tuning_logits = cal_logits[indices[:pc_indices]]
                tuning_labels = cal_labels[indices[:pc_indices]]
                cal_logits = cal_logits[indices[pc_indices:]]
                cal_labels = cal_labels[indices[pc_indices:]]
                kstar = self._kstar(tuning_logits, tuning_labels, alpha)

            if self.score_function_name == "SAPS":
                if self.conf_cal == "TS":
                    if not self.hyperpar_tuning:
                        # The temperature is fitted on the tuning split, which only
                        # exists when hyper-parameter tuning is enabled.
                        raise ValueError(
                            "conf_cal='TS' requires hyper-parameter tuning (--hyperpar 1) "
                            "so that a held-out tuning split is available.")
                    tuning_dataloader = torch.utils.data.DataLoader(
                        TensorDataset(tuning_logits, tuning_labels), batch_size=bsz, shuffle=False, pin_memory=True)
                    transformation = self.get_optimal_parameters(OptimalTeamperatureScaling(1.3), tuning_dataloader)
                    tuning_logits, tuning_labels = postHocLogits(
                        transformation, tuning_dataloader, self.device, self.num_calsses)
                    cal_logits, cal_labels = self._transform_logits(transformation, cal_logits, cal_labels, bsz)
                    val_logits, val_labels = self._transform_logits(transformation, val_logits, val_labels, bsz)

                if self.hyperpar_tuning:
                    ranking_weight_star = self._grid_search(
                        np.arange(0.02, 0.65, 0.02),
                        lambda weight: self.choosed_predictor(
                            self.score_function(transform=self.logits_trans, weight=weight)),
                        tuning_logits, tuning_labels, alpha, tuning_metric_function)
                else:
                    ranking_weight_star = 0.1
                predictor = self.choosed_predictor(
                    self.score_function(transform=self.logits_trans, weight=ranking_weight_star))

            elif self.score_function_name == "RAPS":
                if self.hyperpar_tuning:
                    if self.tuning_criterion == "sscv":
                        tuning_parameters = np.array([0.00001, 0.0001, 0.0008, 0.001, 0.0015, 0.002])
                    else:
                        tuning_parameters = np.arange(0.02, 0.65, 0.02)
                    penalty_star = self._grid_search(
                        tuning_parameters,
                        lambda penalty: self.choosed_predictor(
                            self.score_function(transform=self.logits_trans, penalty=penalty, kreg=kstar)),
                        tuning_logits, tuning_labels, alpha, tuning_metric_function)
                else:
                    # Fixed defaults; this also overrides the kstar returned to run().
                    penalty_star = 0.01
                    kstar = 5
                predictor = self.choosed_predictor(
                    self.score_function(transform=self.logits_trans, penalty=penalty_star, kreg=kstar))

            else:
                raise NotImplementedError

        self._fix_randomness(self.seed)
        predictor.calculate_threshold(cal_logits, cal_labels, alpha)
        val_prediction_sets = predictor.predict_with_logits(val_logits)

        if self.truncation_name == "truncated":
            val_prediction_sets = self._truncate_topk(
                predictor, cal_logits, cal_labels, val_logits, val_prediction_sets, alpha)
        elif self.truncation_name == "class_truncated":
            val_prediction_sets = self._truncate_classwise(
                predictor, cal_logits, cal_labels, val_logits, val_prediction_sets)
        elif self.truncation_name == "cluster_truncated":
            val_prediction_sets = self._truncate_clusterwise(
                predictor, cal_logits, cal_labels, val_logits, val_prediction_sets)

        coverage_rate = metrics('coverage_rate')(val_prediction_sets, val_labels)
        average_size = metrics('average_size')(val_prediction_sets, val_labels)
        sscv = metrics('SSCV')(val_prediction_sets, val_labels, self.alpha)
        prec1, prec5 = accuracy(val_logits, val_labels, topk=(1, 5))
        wsc = metrics('WSC')(self.features[val_indices, :], val_prediction_sets, val_labels, self.alpha, M=100,
                             random_state=self.seed, verbose=False)
        covgap = metrics('CovGap')(val_prediction_sets, val_labels, self.alpha, self.num_calsses)

        return prec1, prec5, coverage_rate, average_size, sscv, kstar, prec_kstar, wsc, covgap

    @staticmethod
    def _true_label_ranks(logits, labels):
        """Return the 1-based rank (float tensor) of each row's true label by descending logit."""
        return torch.tensor(
            [torch.where(torch.argsort(logits[index, :]).flip(dims=(0,)) == labels[index])[0][0]
             for index in range(logits.shape[0])]).float() + 1

    @classmethod
    def _kstar(cls, logits, labels, alpha):
        """Return the (1 - alpha) 'higher' quantile of the true-label ranks, as an int."""
        ranks = cls._true_label_ranks(logits, labels)
        return int(torch.quantile(ranks, 1 - alpha, interpolation='higher').numpy())

    @staticmethod
    def _grid_search(candidates, build_predictor, tuning_logits, tuning_labels, alpha, tuning_metric_function):
        """Return the candidate whose predictor minimises the tuning metric.

        Each candidate is turned into a predictor with ``build_predictor``,
        calibrated and evaluated on the tuning split. Ties keep the earliest
        candidate. Starting from +inf guarantees the first finite metric is
        accepted, so a degenerate metric (e.g. every set is full) never leaves
        an invalid placeholder such as 0 selected.
        """
        best_candidate = 0
        best_metric = float("inf")
        for candidate in candidates:
            predictor = build_predictor(candidate)
            predictor.calculate_threshold(tuning_logits, tuning_labels, alpha)
            prediction_sets = predictor.predict_with_logits(tuning_logits)
            metric = tuning_metric_function(prediction_sets, tuning_labels)
            if metric < best_metric:
                best_candidate, best_metric = candidate, metric
        return best_candidate

    def _transform_logits(self, transformation, logits, labels, bsz):
        """Re-run ``postHocLogits`` with ``transformation`` over in-memory logits/labels."""
        loader = torch.utils.data.DataLoader(
            TensorDataset(logits, labels), batch_size=bsz, shuffle=False, pin_memory=True)
        return postHocLogits(transformation, loader, self.device, self.num_calsses)

    def _truncate_topk(self, predictor, cal_logits, cal_labels, val_logits, val_prediction_sets, alpha):
        """Intersect every validation set with one global top-k set.

        Searching downward from the largest calibration set size, k is one
        above the first size at which calibration coverage (true label inside
        the prediction set AND within the top-k logits) drops below 1 - alpha.
        """
        cal_prediction_sets = predictor.predict_with_logits(cal_logits)
        cal_set_sizes = torch.Tensor([len(element) for element in cal_prediction_sets])
        max_set_size = int(torch.max(cal_set_sizes).item())
        gt_locs = self._true_label_ranks(cal_logits, cal_labels)

        # NOTE(review): if coverage never drops below 1 - alpha down to size 2,
        # no truncation is applied at all (size stays num_classes) — confirm intended.
        truncated_size = self.num_calsses
        for ele in range(max_set_size, 1, -1):
            temp_coverage = torch.logical_and(gt_locs <= ele, gt_locs <= cal_set_sizes).sum() / gt_locs.shape[0]
            if temp_coverage < 1 - alpha:
                truncated_size = ele + 1
                break
        # ele + 1 can exceed the number of classes (when even the untruncated
        # sets under-cover); torch.topk would then fail.
        truncated_size = min(truncated_size, self.num_calsses)

        truncated_val_prediction_sets = []
        for set1, the_logits in zip(val_prediction_sets, val_logits):
            _, set2 = torch.topk(the_logits, truncated_size)
            truncated_val_prediction_sets.append(list(set(set1) & set(set2.tolist())))
        return truncated_val_prediction_sets

    def _truncate_classwise(self, predictor, cal_logits, cal_labels, val_logits, val_prediction_sets):
        """Drop a class from a validation set when its rank exceeds that class's limit.

        A class's limit is the largest true-label rank among calibration points
        of that class that its prediction set covered. Classes with no covered
        calibration point get ``num_classes`` (no truncation) instead of crashing.
        """
        cal_prediction_sets = predictor.predict_with_logits(cal_logits)
        gt_locs = self._true_label_ranks(cal_logits, cal_labels)

        truncated_sizes = torch.ones(self.num_calsses)
        for temp_class in range(self.num_calsses):
            class_indices = torch.where(cal_labels == temp_class)[0].tolist()
            covered = [idx for idx in class_indices if temp_class in cal_prediction_sets[idx]]
            if covered:
                truncated_sizes[temp_class] = int(torch.max(gt_locs[torch.tensor(covered)]).item())
            else:
                truncated_sizes[temp_class] = self.num_calsses

        truncated_val_prediction_sets = []
        for set1, the_logits in zip(val_prediction_sets, val_logits):
            sorted_indices = torch.argsort(the_logits, descending=True)
            # category_positions[c] = 1-based rank of class c for this sample.
            category_positions = torch.empty_like(sorted_indices)
            category_positions[sorted_indices] = torch.arange(len(the_logits)) + 1
            valid_indices = torch.where(category_positions <= truncated_sizes)[0]
            truncated_val_prediction_sets.append(list(set(set1) & set(valid_indices.tolist())))
        return truncated_val_prediction_sets

    def _truncate_clusterwise(self, predictor, cal_logits, cal_labels, val_logits, val_prediction_sets):
        """Truncate validation sets with one top-k limit per class cluster.

        Relies on ClusteredPredictor internals (``idx1``, ``cluster_assignments``).
        For each cluster, the limit is one above the first size (searched
        downward) at which class-averaged within-cluster coverage falls below
        its untruncated value; the default is the cluster size.
        """
        from collections import Counter

        # presumably idx1 is a boolean mask of calibration points consumed by the
        # clustering step; the remaining points are used here — TODO confirm.
        clustering_indices = predictor.idx1.cpu()
        cal_labels = cal_labels[~clustering_indices]
        cal_logits = cal_logits[~clustering_indices]
        cal_prediction_sets = predictor.predict_with_logits(cal_logits)

        cluster_assignments = predictor.cluster_assignments.cpu().numpy()
        cluster_labels_indexes = {value: torch.tensor([i for i, x in enumerate(cluster_assignments) if x == value])
                                  for value in set(cluster_assignments)}
        cluster_labels_indexes_set = {value: set([i for i, x in enumerate(cluster_assignments) if x == value])
                                      for value in set(cluster_assignments)}
        cluster_labels = list(set(cluster_assignments))
        counts = Counter(cluster_assignments)

        truncated_sizes = {}
        for temp_cluster in cluster_labels:
            cluster_classes = cluster_labels_indexes[temp_cluster]
            temp_indices = torch.tensor([the_cal_label in cluster_classes for the_cal_label in cal_labels])
            temp_cal_logits = cal_logits[temp_indices][:, cluster_classes]

            # Within-cluster label positions and prediction sets of the
            # calibration points whose true label belongs to this cluster.
            temp_cal_labels = []
            temp_cal_prediction_sets = []
            for index in range(cal_labels.shape[0]):
                if not temp_indices[index]:
                    continue
                temp_cal_labels.append(torch.where(cluster_classes == cal_labels[index])[0])
                temp_cal_prediction_sets.append(
                    list(set(cal_prediction_sets[index]) & set(cluster_labels_indexes_set[temp_cluster])))
            temp_cal_labels = torch.tensor(temp_cal_labels)
            temp_cal_set_sizes = torch.Tensor([len(element) for element in temp_cal_prediction_sets])

            truncated_sizes[temp_cluster] = counts[temp_cluster]
            if temp_cal_set_sizes.numel() <= 0:
                continue

            temp_max_set_size = int(torch.max(temp_cal_set_sizes).item())
            temp_gt_locs = self._true_label_ranks(temp_cal_logits, temp_cal_labels)
            temp_all_labels = list(set(temp_cal_labels.numpy().tolist()))

            best_all_coverage = 0
            coveraged_dict = {}
            for temp_label in temp_all_labels:
                mask = temp_cal_labels == temp_label
                coveraged_dict[temp_label] = temp_gt_locs[mask] <= temp_cal_set_sizes[mask]
                best_all_coverage += coveraged_dict[temp_label].sum() / temp_gt_locs[mask].shape[0]
            best_all_coverage = best_all_coverage / len(temp_all_labels)

            for ele in range(temp_max_set_size, 1, -1):
                temp_all_coverage = 0
                for temp_label in temp_all_labels:
                    mask = temp_cal_labels == temp_label
                    temp_all_coverage += torch.logical_and(
                        temp_gt_locs[mask] <= ele, coveraged_dict[temp_label]).sum() / temp_gt_locs[mask].shape[0]
                temp_all_coverage = temp_all_coverage / len(temp_all_labels)
                if temp_all_coverage < best_all_coverage:
                    truncated_sizes[temp_cluster] = ele + 1
                    break

        truncated_val_prediction_sets = []
        for set1, the_logits in zip(val_prediction_sets, val_logits):
            truncated_set = set()
            for temp_cluster in cluster_labels:
                cluster_classes = cluster_labels_indexes[temp_cluster]
                _, set2 = torch.topk(the_logits[cluster_classes], truncated_sizes[temp_cluster])
                truncated_set = truncated_set | set(cluster_classes[set2].tolist())
            truncated_val_prediction_sets.append(list(set(set1) & truncated_set))
        return truncated_val_prediction_sets

    def get_optimal_parameters(self, transformation, calib_loader):
        """Fit ``transformation.temperature`` by minimising NLL on ``calib_loader``.

        Runs up to 10 SGD epochs (lr 0.01) and stops early once the temperature
        changes by less than 0.01 over an epoch. Uses ``self.device`` throughout,
        so it also works on CPU-only machines.

        Returns:
            The same ``transformation`` object, moved to ``self.device``.
        """
        device = self.device
        transformation.to(device)
        max_iters = 10
        lr = 0.01
        epsilon = 0.01
        nll_criterion = nn.CrossEntropyLoss().to(device)

        optimizer = optim.SGD([transformation.temperature], lr=lr)
        for _ in range(max_iters):
            T_old = transformation.temperature.item()
            for x, targets in calib_loader:
                optimizer.zero_grad()
                x = x.to(device)
                out = x / transformation.temperature
                loss = nll_criterion(out, targets.long().to(device))
                loss.backward()
                optimizer.step()
            if abs(T_old - transformation.temperature.item()) < epsilon:
                break

        return transformation

    def _fix_randomness(self, seed=0):
        """Seed numpy, torch (CPU and CUDA) and ``random`` with ``seed``."""
        np.random.seed(seed=seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        random.seed(seed)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Evaluates conformal predictors',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--dataset', type=str, default='imagenet', help='dataset')
    parser.add_argument('--score', type=str, default='THR', help='the score function of CP.')
    parser.add_argument('--predictor', type=str, default='Split', help='the conformal predictor of CP.')
    parser.add_argument('--truncation_name', type=str, default='base',
                        help='the transformation of base conformal prediction. '
                             'Optional: base, truncated, class_truncated, cluster_truncated')
    parser.add_argument('--sparse_topk', type=int, default=5, help='the top-k retained probabilities.')
    parser.add_argument('--alpha', type=float, default=0.1, help='the error rate.')
    parser.add_argument('--trials', type=int, default=1, help='number of trials')
    parser.add_argument('--conf_cal', type=str, default="None", help='the confidence calibration method.')
    parser.add_argument('--hyperpar', type=int, default=1, help='tuning the hyper-parameter of score functions.')
    parser.add_argument('--tuning_criterion', type=str, default="size",
                        help='the criterion to tune the best hyper-parameter.')
    parser.add_argument('--gpu', type=int, default=4, help='chose gpu id')
    parser.add_argument('--n_cal', type=int, default=5000, help='the number of calibration examples')
    parser.add_argument('--model', type=str, default=None, help='model')

    args = parser.parse_args()

    # Restrict the visible GPU; only effective if CUDA has not been initialised yet.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    dataset_name = args.dataset
    # Models evaluated when --model is omitted; None means the dataset has no default list.
    models = None

    if dataset_name == "imagenetv2":
        n_data_conf = 5000
        n_data_val = 5000
        models = ["ResNeXt101", "ResNet152", "ResNet101", "ResNet50", "DenseNet161", "Inception", "ShuffleNet"]
    elif dataset_name == "imagenet-a":
        n_data_conf = 4500
        n_data_val = 3000
    elif dataset_name == "imagenet-r":
        n_data_conf = 15000
        n_data_val = 15000
    elif dataset_name == "imagenet":
        n_data_conf = args.n_cal
        n_data_val = 30000
        if n_data_conf > 20000:
            # cal + val must fit in 50000 examples: validation gets whatever
            # calibration leaves over.
            n_data_val = 50000 - n_data_conf
        # Full list: ["ResNeXt101", "ResNet152", "ResNet101", "ResNet50", "DenseNet161", "Inception", "ShuffleNet"]
        models = ["ResNeXt101"]
    elif dataset_name == "cifar10" or dataset_name == "cifar100":
        n_data_conf = min(args.n_cal, 5000)
        n_data_val = 5000
        models = ["ResNet101", "DenseNet161", "VGG16", "Inception", "ViT"]
    elif dataset_name == "ina21":
        n_data_conf = args.n_cal
        n_data_val = 5000
    else:
        raise NotImplementedError

    pct_paramtune = 0.2
    bsz = 128
    cudnn.benchmark = True

    # SAPS hyper-parameter tuning is always paired with temperature scaling.
    if args.score == "SAPS" and args.hyperpar:
        args.conf_cal = "TS"

    # NOTE(review): re-imports this file as module `main`; assumes it is saved as main.py — confirm.
    from main import experiment

    if args.model is None:
        if models is None:
            parser.error(f"--model is required for dataset '{dataset_name}'")
        save_csv_path = f"cache/{args.dataset}/res/predictor={args.predictor}_score={args.score}_truncation_name={args.truncation_name}_tuning_criterion={args.tuning_criterion}_cal={args.conf_cal}_cal={n_data_conf}_val={n_data_val}_hyperpar={args.hyperpar}.csv"
        print(save_csv_path)
        os.makedirs(os.path.dirname(save_csv_path), exist_ok=True)
        if os.path.exists(save_csv_path):
            os.remove(save_csv_path)
        for model in models:
            args.model = model
            this_experiment = experiment(args, save_csv_path)
            this_experiment.run(n_data_conf, n_data_val, pct_paramtune, bsz)
    else:
        save_csv_path = f"cache/{args.dataset}/res/predictor={args.predictor}_score={args.score}_truncation_name={args.truncation_name}_tuning_criterion={args.tuning_criterion}_cal={args.conf_cal}_cal={n_data_conf}_val={n_data_val}_hyperpar={args.hyperpar}_model={args.model}.csv"
        os.makedirs(os.path.dirname(save_csv_path), exist_ok=True)
        if os.path.exists(save_csv_path):
            os.remove(save_csv_path)
        this_experiment = experiment(args, save_csv_path)
        this_experiment.run(n_data_conf, n_data_val, pct_paramtune, bsz)

