import numpy as np
import gurobipy as gp
from gurobipy import GRB
from pathlib import Path
from counterfactual_explanations.input_properties import InputProperties
from models.abstract_model import AbstractModel
import json
from typing import Type
from counterfactual_explanations.dim_reduction import DimensionalityReduction

def scorefn_softmax(fx_logits, y_correct):
    """Return the softmax nonconformity score ``1 - softmax(fx_logits)[y_correct]``.

    The logits are shifted by their maximum before exponentiation. Softmax is
    invariant to that shift, but it keeps ``np.exp`` from overflowing to
    ``inf`` (which would turn the ratio into ``nan``) when logits are large.

    Args:
        fx_logits: Sequence of raw model outputs, one entry per class.
        y_correct: Index of the class whose score is computed.

    Returns:
        One minus the softmax probability assigned to ``y_correct``.
    """
    logits = np.asarray(fx_logits)
    # exp(z - max z) is at most 1, so neither the numerator nor the sum overflows.
    exp_shifted = np.exp(logits - np.max(logits))
    fx_softmax = exp_shifted / np.sum(exp_shifted)
    return 1 - fx_softmax[y_correct]

def scorefn_linear(fx_logits, y_correct):
    """Return the gap between the largest logit and the logit of ``y_correct``.

    The maximum is taken over all classes, ``y_correct`` included, so the
    score is zero when ``y_correct`` is the top-scoring class.
    """
    top_logit = max(fx_logits)
    return -1 * fx_logits[y_correct] + top_logit

def scorefn_linear_2(fx_logits, y_correct):
    """Return the largest logit among the *other* classes minus the logit of ``y_correct``.

    Unlike ``scorefn_linear`` the maximum excludes ``y_correct`` itself, so the
    score is negative when ``y_correct`` is the strict top class.
    """
    rival_logits = [logit for idx, logit in enumerate(fx_logits) if idx != y_correct]
    return -1 * fx_logits[y_correct] + max(rival_logits)

def scorefn_linear_logits(fx_logits, y_correct):
    """Return ``1 - fx_logits[y_correct]``; no other class output is consulted."""
    own_output = fx_logits[y_correct]
    return 1 - own_output


class SplitConformalPrediction:
    """Split conformal predictor built on top of a model's per-class outputs.

    Calibration scores are computed with a selectable score function, their
    ``1 - alpha`` quantile becomes the inclusion threshold for prediction sets,
    and the linear score functions can also be encoded as Gurobi constraints.
    """
    # Registry of score functions, looked up by the 'scorefn_name' config entry.
    scorefns = {'softmax': scorefn_softmax, 'linear': scorefn_linear, 'linear2': scorefn_linear_2, 'linear_logits': scorefn_linear_logits}

    def __init__(self, model: AbstractModel, input_properties: InputProperties, config: dict, save_path: Path=None, use_pretrained: bool=True):
        """Store collaborators and read the conformal settings from ``config``.

        Args:
            model: Model whose ``predict`` output is scored.
            input_properties: Provider of the label set and number of targets.
            config: Settings dict. Recognised keys: ``'alpha'`` (default 0.05),
                ``'scorefn_name'`` (default ``'linear2'``) and
                ``'dim_reduction'`` (default ``None``).
            save_path: When truthy, computed artefacts are written to disk.
            use_pretrained: When true, previously saved artefacts may be reused.

        Raises:
            KeyError: If ``scorefn_name`` is not a key of ``scorefns``.
        """
        # Collaborators.
        self.model = model
        self.input_properties = input_properties

        # Persistence switches.
        self.save_path = save_path
        self.use_pretrained = use_pretrained

        # Settings, with defaults for anything the config omits.
        self.config = config
        self.alpha = config.get('alpha', 0.05)
        self.scorefn_name = config.get('scorefn_name', 'linear2')
        self.dim_reduction = config.get('dim_reduction', None)
        self.scorefn = SplitConformalPrediction.scorefns[self.scorefn_name]

        # Calibration state; the quantile threshold itself is only created by calibrate().
        self.is_calibrated = False
        self.scores = None
        self.calib_preds = None

    def name(self, exclude=None):
        """Return the class name followed by the compact JSON dump of the config.

        Args:
            exclude: Optional single config key to leave out of the dump.

        Returns:
            A string such as ``SplitConformalPrediction{"alpha":0.05}``.
        """
        class CustomEncoder(json.JSONEncoder):
            """JSON encoder that also handles classes and dimensionality reducers."""

            def default(self, obj):
                # Classes are rendered through str(); reducers through their own name().
                if isinstance(obj, Type):
                    return str(obj)
                elif isinstance(obj, DimensionalityReduction):
                    return obj.name()
                return super().default(obj)

        if exclude:
            shown_config = {key: value for key, value in self.config.items() if key != exclude}
        else:
            shown_config = self.config

        encoded = json.dumps(shown_config, separators=(',', ':'), cls=CustomEncoder)
        return self.__class__.__name__ + encoded

    def get_scores(self, X_calib, y_calib):
        """Compute (or reload) the calibration scores for ``(X_calib, y_calib)``.

        When ``self.model.save_dir`` is set, scores are cached in
        ``scores_<scorefn_name>.npz`` inside that directory. A cached file is
        reused only if ``use_pretrained`` is true and the stored calibration
        inputs and labels equal the ones passed in. Freshly computed scores are
        written back only when ``self.save_path`` is truthy and a cache
        location exists.

        Args:
            X_calib: Calibration inputs passed to ``self.model.predict``.
            y_calib: Calibration labels, indexable by position.

        Returns:
            1-D array holding one score per calibration point. The scores and
            the model outputs are also stored on ``self.scores`` and
            ``self.calib_preds``.
        """
        # Bound up front so the save step below never sees an undefined name
        # when the model has no save_dir.
        scores_path = None
        if self.model.save_dir:
            scores_path = self.model.save_dir / f"scores_{self.scorefn_name}.npz"

            if self.use_pretrained and scores_path.is_file():
                # Context manager closes the underlying archive handle.
                with np.load(scores_path) as cached:
                    # array_equal also rejects shape mismatches instead of broadcasting.
                    same_inputs = np.array_equal(cached["X_calib"], X_calib)
                    same_labels = np.array_equal(cached["y_calib"], y_calib)
                    if same_inputs and same_labels:
                        self.calib_preds = cached["calib_preds"]
                        self.scores = cached["scores"]
                        return self.scores

        preds = self.model.predict(X_calib)

        scores = np.zeros((len(y_calib),))
        for j in range(len(y_calib)):
            scores[j] = self.scorefn(preds[j], y_calib[j])

        if self.save_path and scores_path is not None:
            scores_path.parent.mkdir(parents=True, exist_ok=True)
            np.savez(scores_path, X_calib=X_calib, y_calib=y_calib, calib_preds=preds, scores=scores)

        self.calib_preds = preds
        self.scores = scores
        return self.scores

    def calibrate(self, X_calib, y_calib, test_point=None):
        """Fit the score threshold on the calibration set.

        The threshold is the ``1 - alpha`` quantile of the calibration scores,
        as computed by ``np.quantile`` with its default settings.
        NOTE(review): no finite-sample ``(n + 1)`` correction is applied to the
        quantile level; confirm that this is intended.

        Args:
            X_calib: Calibration inputs.
            y_calib: Calibration labels.
            test_point: Accepted but not used by this implementation.

        Returns:
            The threshold, which is also stored on ``self.quantile_val``.
        """
        calib_scores = self.get_scores(X_calib, y_calib)
        threshold = np.quantile(calib_scores, 1 - self.alpha)
        self.quantile_val = threshold
        self.is_calibrated = True
        return threshold

    def predict(self, X):
        """Return the prediction set for a single input.

        ``X`` is reshaped to one row before being passed to the model. A label
        is included when its score does not exceed ``self.quantile_val``, so
        ``calibrate`` must have been called first.

        Returns:
            List of the included labels, in the order given by
            ``input_properties.get_labels()``.
        """
        candidate_labels = self.input_properties.get_labels()
        model_output = self.model.predict(X.reshape(1, -1))[0]
        return [
            label
            for label in candidate_labels
            if self.scorefn(model_output, label) <= self.quantile_val
        ]
    
    def predict_batch(self, X):
        """Return one prediction set per row of ``X``.

        The model is queried once for the whole batch. For each output row a
        label is included when its score does not exceed ``self.quantile_val``.

        Returns:
            List of label lists, aligned with the rows of the model output.
        """
        candidate_labels = self.input_properties.get_labels()
        batch_outputs = self.model.predict(X)

        return [
            [
                label
                for label in candidate_labels
                if self.scorefn(row_output, label) <= self.quantile_val
            ]
            for row_output in batch_outputs
        ]

    def compute_stats(self, pred_intervals, y_correct, indices=None, cov_gap=False):
        """Return the mean prediction-set size and the empirical coverage.

        Args:
            pred_intervals: Sequence of prediction sets, one per point.
            y_correct: True labels; must support fancy indexing when
                ``indices`` is given.
            indices: Optional positions restricting the evaluation to a subset.
            cov_gap: When true, the second return value is the signed
                difference to ``1 - alpha`` expressed in percentage points
                rather than the raw coverage.

        Returns:
            ``(set_size, coverage)``; both are ``nan`` when there are no points.
        """
        if indices is not None:
            pred_intervals = [pred_intervals[i] for i in indices]
            y_correct = y_correct[indices]

        num_points = len(pred_intervals)
        if num_points == 0:
            return np.nan, np.nan

        total_size = sum(len(interval) for interval in pred_intervals)
        num_covered = sum(
            1 for idx, interval in enumerate(pred_intervals) if y_correct[idx] in interval
        )

        set_size = total_size / num_points
        coverage = num_covered / num_points

        if cov_gap:
            coverage = (coverage - (1 - self.alpha)) * 100

        return set_size, coverage
    
    def cov_gap(self, set_sizes, coverage, partition_sizes, only_penalise_undercoverage=False, gap=False):
        """Aggregate per-partition statistics into a set size and a coverage figure.

        Partitions with size 0 are ignored. Their statistics are ``nan`` (that
        is what ``compute_stats`` returns for an empty selection) and would
        otherwise turn every aggregate into ``nan``.

        Args:
            set_sizes: Mean prediction-set size of each partition.
            coverage: Empirical coverage of each partition.
            partition_sizes: Number of points in each partition.
            only_penalise_undercoverage: When true, partitions covering more
                than ``1 - alpha`` contribute a zero gap.
            gap: When false the second return value is the unweighted mean
                coverage; when true it is the mean absolute deviation from
                ``1 - alpha`` in percentage points.

        Returns:
            ``(avg_set_size, value)``, where ``avg_set_size`` is weighted by
            partition size. Both are ``nan`` if no partition is non-empty.
        """
        set_sizes = np.asarray(set_sizes, dtype=float)
        coverage = np.asarray(coverage, dtype=float)
        partition_sizes = np.asarray(partition_sizes, dtype=float)

        non_empty = partition_sizes > 0
        if not np.any(non_empty):
            return np.nan, np.nan

        set_sizes = set_sizes[non_empty]
        coverage = coverage[non_empty]
        partition_sizes = partition_sizes[non_empty]

        avg_set_size = np.sum(set_sizes * partition_sizes) / np.sum(partition_sizes)

        if not gap:
            return avg_set_size, np.average(coverage)

        cov_gaps = coverage - (1 - self.alpha)

        if only_penalise_undercoverage:
            # Clamp over-coverage (positive gaps) to zero.
            cov_gaps = np.minimum(cov_gaps, 0)

        cov_gap = np.sum(np.abs(cov_gaps)) * 100 / len(cov_gaps)

        return avg_set_size, cov_gap

    
    def evaluate_conditional(self, X, y, n_bins=10, cov_gap=False):
        """Evaluate marginal and conditional behaviour of the calibrated predictor.

        Four views are computed on the prediction sets of ``X``:

        * marginal statistics over all points;
        * class-conditional statistics, partitioning by the values of ``y``;
        * random-binning statistics over ``n_bins`` random partitions;
        * a counterfactual simulation: for every point the target label is
          ``(y + 1) % n_targets``, and the nearest point (L2 distance) whose
          prediction set is exactly that single label is selected; statistics
          are computed over the selected points.

        The selected indices are cached in ``indices_<len(X)>.npz`` under
        ``self.model.save_dir``. They are read when ``use_pretrained`` is true
        and written when ``self.save_path`` is truthy.
        NOTE(review): the cache file is keyed only by ``len(X)``, so a
        different ``X`` of the same length would reuse stale indices.

        Args:
            X: Evaluation inputs; indexed by row and read through ``X.shape``.
            y: True labels as an integer array.
            n_bins: Number of random partitions.
            cov_gap: Forwarded to ``compute_stats`` for the marginal and
                counterfactual views and to ``self.cov_gap`` (as ``gap``) for
                the partitioned views.

        Returns:
            Tuple ``(set_size, coverage, set_size_cc, cov_gap_cc, set_size_rb,
            cov_gap_rb, set_size_cf, coverage_cf)``.
        """
        assert self.is_calibrated, "calibrate() must be called before evaluate_conditional()"
        pred_intervals = self.predict_batch(X)

        # Marginal statistics. They get their own names so that the
        # per-partition loops below cannot overwrite the returned values.
        marginal_set_size, marginal_coverage = self.compute_stats(pred_intervals, y, cov_gap=cov_gap)

        # Class-conditional statistics: one partition per distinct label.
        set_sizes_cc = []
        coverages_cc = []
        partition_sizes_cc = []

        for label in np.unique(y):
            class_indices = np.where(y == label)[0]
            part_set_size, part_coverage = self.compute_stats(pred_intervals, y, class_indices)

            set_sizes_cc.append(part_set_size)
            coverages_cc.append(part_coverage)
            partition_sizes_cc.append(len(class_indices))

        set_size_cc, cov_gap_cc = self.cov_gap(set_sizes_cc, coverages_cc, partition_sizes_cc, gap=cov_gap)

        # Random binning: n_bins partitions of a random permutation.
        set_sizes_rb = []
        coverages_rb = []
        partition_sizes_rb = []

        for bin_indices in np.array_split(np.random.permutation(len(X)), n_bins):
            part_set_size, part_coverage = self.compute_stats(pred_intervals, y, bin_indices)

            set_sizes_rb.append(part_set_size)
            coverages_rb.append(part_coverage)
            partition_sizes_rb.append(len(bin_indices))

        set_size_rb, cov_gap_rb = self.cov_gap(set_sizes_rb, coverages_rb, partition_sizes_rb, gap=cov_gap)

        # Counterfactual simulation. The path is bound up front so the save
        # step never sees an undefined name when the model has no save_dir.
        indices_path = None
        if self.model.save_dir:
            indices_path = self.model.save_dir / f"indices_{len(X)}.npz"

            if indices_path.is_file() and self.use_pretrained:
                # Context manager closes the underlying archive handle.
                with np.load(indices_path) as cached:
                    indices_to_include = cached['ind']
                # -1 marks points for which no counterfactual was found.
                indices_to_include = indices_to_include[indices_to_include != -1]
                set_size_cf, coverage_cf = self.compute_stats(pred_intervals, y, cov_gap=cov_gap, indices=indices_to_include)
                return (marginal_set_size, marginal_coverage, set_size_cc, cov_gap_cc,
                        set_size_rb, cov_gap_rb, set_size_cf, coverage_cf)

        y_targets = (y + 1) % self.input_properties.n_targets
        indices_to_include = np.full((X.shape[0],), -1, dtype=np.int32)

        # For every label, the positions whose prediction set is exactly that label.
        singleton_points = {}
        for label in self.input_properties.get_labels():
            singleton_points[label] = [i for i, interval in enumerate(pred_intervals) if interval == [label]]

        for i, x in enumerate(X):
            min_distance = float('inf')

            # Strict '<' keeps the first candidate among equally distant ones.
            for candidate_idx in singleton_points[y_targets[i]]:
                distance = np.linalg.norm(x - X[candidate_idx], ord=2)
                if distance < min_distance:
                    min_distance = distance
                    indices_to_include[i] = candidate_idx

        if self.save_path and indices_path is not None:
            indices_path.parent.mkdir(parents=True, exist_ok=True)
            np.savez(indices_path, ind=indices_to_include)

        indices_to_include = indices_to_include[indices_to_include != -1]
        set_size_cf, coverage_cf = self.compute_stats(pred_intervals, y, cov_gap=cov_gap, indices=indices_to_include)

        return (marginal_set_size, marginal_coverage, set_size_cc, cov_gap_cc,
                set_size_rb, cov_gap_rb, set_size_cf, coverage_cf)


    
    def gp_set_conformal_prediction_constraint(self, grb_model: gp.Model, output_vars: gp.MVar, input_vars: gp.MVar):
        """Add per-class score variables to ``grb_model`` and tie them to the outputs.

        One unbounded continuous variable per class is created and stored on
        ``self.scores_c``; ``set_score_constraint`` then links each of them to
        ``output_vars`` according to the configured score function. The
        threshold constraints on these scores (target class at or below the
        quantile, other classes above it) are added separately by
        ``gp_set_singleton_constraint``.

        Args:
            grb_model: Gurobi model to extend.
            output_vars: Variables holding the model outputs, one per class.
            input_vars: Accepted but not used by this method.

        Raises:
            ValueError: If the configured score function is not one of the
                linear ones.
        """
        milp_scorefns = ('linear', 'linear2', 'linear_logits')
        if self.scorefn_name not in milp_scorefns:
            raise ValueError("Can only use linear scorefn in MILP")

        n_classes = self.input_properties.n_targets

        score_vars = grb_model.addVars(n_classes, lb=-float('inf'), vtype=GRB.CONTINUOUS, name="scores")
        self.scores_c = score_vars
        self.set_score_constraint(grb_model, score_vars, output_vars, n_classes)


    def gp_set_singleton_constraint(self, grb_model: gp.Model, target_class: int):
        """Constrain the prediction set to be exactly ``{target_class}``.

        The score of the target class must not exceed ``self.quantile_val``,
        while every other class score must be at least ``quantile_val + 1e-6``
        (the margin stands in for a strict inequality). Reads ``self.scores_c``,
        which is set by ``gp_set_conformal_prediction_constraint``.

        Returns:
            List of the added constraints, one per class in class order.
        """
        n_classes = self.input_properties.n_targets
        added = []

        for cls in range(n_classes):
            if cls == target_class:
                constr = grb_model.addConstr(self.scores_c[cls] <= self.quantile_val, name=f"target_{cls}")
            else:
                constr = grb_model.addConstr(self.scores_c[cls] >= self.quantile_val + 1e-6, name=f"other_{cls}")
            added.append(constr)

        return added

    def set_linear_score_constraint(self, grb_model, scores, output_vars, num_classes):
        """Encode the 'linear' score: ``scores[i] == max(outputs) - outputs[i]``.

        Mirrors ``scorefn_linear``; the maximum ranges over all output
        variables, including class ``i`` itself.
        """
        # Auxiliary variable, unbounded below, that carries the maximum output.
        max_logit = grb_model.addVar(lb=-float('inf'), vtype=GRB.CONTINUOUS, name="max_logits")

        # General max constraint over every entry of output_vars.
        grb_model.addConstr(max_logit == gp.max_(*output_vars))

        grb_model.addConstrs(scores[i] == -1 * output_vars[i] + max_logit for i in range(num_classes))

    def set_linear_score_rf_constraint(self, grb_model, scores, output_vars, num_classes):
        """Encode the 'linear_logits' score: ``scores[i] == 1 - outputs[i]``.

        Mirrors ``scorefn_linear_logits``; this is the builder that
        ``set_score_constraint`` selects for ``'linear_logits'``.
        """
        grb_model.addConstrs(scores[i] == 1 - output_vars[i] for i in range(num_classes))
    
    def set_linear_score_2_constraint(self, grb_model: gp.Model, scores, output_vars, num_classes):
        """Encode the 'linear2' score: ``scores[i] == max_{j != i} outputs[j] - outputs[i]``.

        Mirrors ``scorefn_linear_2``; the maximum excludes class ``i`` itself.
        """
        if num_classes == 2:
            # With two classes the "other" maximum is simply the other output,
            # so no auxiliary max variables are needed.
            grb_model.addConstr(scores[0] == -1 * output_vars[0] + output_vars[1])
            grb_model.addConstr(scores[1] == -1 * output_vars[1] + output_vars[0])

        else:
            # One auxiliary variable per class holding the maximum over the other classes.
            # NOTE(review): these loops run over output_vars.shape[0] while the score
            # constraints below use num_classes; presumably the two are equal, confirm.
            max_other_logit = grb_model.addMVar(shape=(output_vars.shape[0],), lb=-float('inf'), vtype=GRB.CONTINUOUS, name="max_logits")
            for i in range(output_vars.shape[0]):
                grb_model.addConstr(max_other_logit[i] == gp.max_([output_vars[j] for j in range(output_vars.shape[0]) if j != i]))

            grb_model.addConstrs(scores[i] == -1 * output_vars[i] + max_other_logit[i] for i in range(num_classes))

    def set_score_constraint(self, grb_model, scores, output_vars, num_classes):
        """Dispatch to the constraint builder matching ``self.scorefn_name``.

        Handles ``'linear'``, ``'linear2'`` and ``'linear_logits'``. Any other
        name adds no constraints and returns silently.
        """
        builders = {
            "linear": self.set_linear_score_constraint,
            "linear2": self.set_linear_score_2_constraint,
            "linear_logits": self.set_linear_score_rf_constraint,
        }

        builder = builders.get(self.scorefn_name)
        if builder is None:
            return

        builder(grb_model, scores, output_vars, num_classes)