from conformal.split_conformal import SplitConformalPrediction

from counterfactual_explanations.input_properties import InputProperties
from counterfactual_explanations.dim_reduction import DimensionalityReduction
from typing import Literal
import numpy as np
from conformal.milp_utils import *
from sklearn.ensemble import RandomForestRegressor
from gurobi_ml.sklearn import add_random_forest_regressor_constr
from conformal.conformal_helpers import *

import numpy as np
import gurobipy as gp
from gurobipy import GRB
import torch

def feature_distance(x1, x2, ord, input_properties):
    """Return the ``ord``-norm distance between two feature vectors.

    If ``input_properties`` is None this is simply
    ``np.linalg.norm(x1 - x2, ord)``.  Otherwise a reduced vector with one
    entry per distinct feature is built and its ``ord``-norm is returned:

    * each non-categorical feature contributes the difference
      ``x1[i] - x2[i]``;
    * each categorical group (an index collection taken from
      ``input_properties.categorical_groups``) contributes the number of
      positions inside the group where ``x1`` and ``x2`` differ.

    Consequently the distance of a point to itself is 0.

    Args:
        x1, x2: Feature vectors supporting integer and group (fancy)
            indexing, e.g. 1-D NumPy arrays.
        ord: Norm order forwarded to ``np.linalg.norm`` (1, 2, ``np.inf``...).
        input_properties: Object providing ``n_features``,
            ``n_distinct_features``, ``feature_classes`` and
            ``categorical_groups``, or None for a plain vector norm.

    Returns:
        The distance as returned by ``np.linalg.norm``.
    """
    if input_properties is None:
        return np.linalg.norm(x1 - x2, ord)

    dist_elements = np.zeros(input_properties.n_distinct_features)
    j = 0

    # Non-categorical features first, one slot each.
    for i in range(input_properties.n_features):
        if input_properties.feature_classes[i] != 'categorical':
            dist_elements[j] = x1[i] - x2[i]
            j += 1

    # Then one slot per categorical group.
    for group in input_properties.categorical_groups:
        # Count *mismatching* positions so identical groups are at distance 0.
        dist_elements[j] = np.sum(x1[group] != x2[group])
        j += 1

    return np.linalg.norm(dist_elements, ord)

def gaussian_kernel(x1, x2, h, input_properties=None):
    """Gaussian (RBF) kernel with bandwidth ``h``.

    Computes ``exp(-d**2 / (2 * h**2))`` where ``d`` is the L2
    ``feature_distance`` between ``x1`` and ``x2``.
    """
    dist = feature_distance(x1, x2, 2, input_properties)
    exponent = -(dist * dist) / (2 * h * h)
    return np.exp(exponent)

def box_kernel_l1(x1, x2, h, input_properties=None):
    """Box kernel in the L1 norm: 1 if the distance is at most ``h``, else 0."""
    inside = feature_distance(x1, x2, 1, input_properties) <= h
    return 1 if inside else 0

def box_kernel_l2(x1, x2, h, input_properties=None):
    """Box kernel in the L2 norm: 1 if the distance is at most ``h``, else 0."""
    inside = feature_distance(x1, x2, 2, input_properties) <= h
    return 1 if inside else 0

def box_kernel_linf(x1, x2, h, input_properties=None):
    """Box kernel in the L-infinity norm.

    Returns 1 if ``feature_distance(x1, x2, np.inf, input_properties) <= h``
    and 0 otherwise, consistent with ``box_kernel_l1`` / ``box_kernel_l2``
    and with the ``np.inf`` norm used for ``'box_linf'`` in the MILP encoding.
    """
    # The max-abs reduction must happen inside the norm; taking np.max of an
    # already-reduced scalar distance has no effect.
    dist = feature_distance(x1, x2, np.inf, input_properties)
    return int(dist <= h)

class BaseLCP(SplitConformalPrediction):
    """Localised conformal prediction (LCP) on top of split conformal prediction.

    Calibration scores are re-weighted by a kernel centred on the test point,
    so the conformal quantile adapts to each query's neighbourhood.  The same
    weighted quantile can also be encoded as Gurobi constraints
    (``gp_set_conformal_prediction_constraint``).

    Config keys read here (defaults in parentheses): ``kernel_name``
    (``'box_l1'``), ``kernel_bandwidth`` (1), ``kernel_bandwidth_scaling``
    (True), ``sample_threshold`` (1000), ``dim_reduction`` (None).
    """

    # Kernel name (config['kernel_name']) -> kernel(x1, x2, h, input_properties).
    kernels = {'box_l1': box_kernel_l1, 'box_l2': box_kernel_l2, 'gaussian': gaussian_kernel, 'box_linf': box_kernel_linf}

    def __init__(self, model, input_properties, config, save_path=None, use_pretrained=None):
        super().__init__(model, input_properties, config, save_path, use_pretrained)

        self.kernel_name = self.config.get('kernel_name', 'box_l1')
        # Relative bandwidth; when scaling is enabled it is multiplied by the
        # median pairwise calibration distance on the first calibrate() call.
        self.kernel_bandwidth = self.config.get('kernel_bandwidth', 1)
        self.kernel_bandwidth_scaling = self.config.get('kernel_bandwidth_scaling', True)
        self.kernel = BaseLCP.kernels[self.kernel_name]
        # Maximum number of calibration points kept (None disables sub-sampling).
        self.sample_threshold = self.config.get('sample_threshold', 1000)
        # Optional object exposing name()/encode()/gp_dim_encoding(); when set,
        # kernels are evaluated in the encoded space.
        self.dim_reduction = self.config.get('dim_reduction', None)
        self.scores_nonlocalised = None
        self.med_pairwise_distance = None
        self.is_calibrated = False

        self.X_calib = None
        self.y_calib = None
        self.X_calib_encoded = None

    def calibrate(self, X_calib, y_calib, test_point=None):
        """Calibrate the predictor and return the conformal quantile.

        On the first call with bandwidth scaling enabled, the kernel bandwidth
        is multiplied by the median pairwise distance of (at most 10000 of)
        the calibration inputs.  That median is cached under ``save_path``
        when ``use_pretrained`` is set.  The calibration set is then
        sub-sampled to ``sample_threshold`` points if it is larger.

        Args:
            X_calib: Calibration inputs.
            y_calib: Calibration targets.
            test_point: Optional query input.  If None, the non-localised
                quantile of ``self.scores`` (with ``inf`` appended) is used.
                Otherwise calibration scores are weighted by ``self.kernel``
                centred on the query.

        Returns:
            The quantile at level ``1 - self.alpha``, also stored in
            ``self.quantile_val``.  In the localised case it is ``inf`` when
            the kernel mass on calibration points never reaches that level.
        """
        if self.kernel_bandwidth_scaling and not self.is_calibrated:
            X_calib_r = X_calib

            # Bound the cost of the pairwise-distance median with a fixed-seed
            # subsample (note: this reseeds NumPy's global RNG).
            if len(X_calib) > 10000:
                np.random.seed(2)
                random_indices = np.random.choice(len(X_calib), size=10000, replace=False)
                X_calib_r = X_calib[random_indices]

            # Reuse (or create) an on-disk cache of the median distance, keyed
            # by the dimensionality-reduction name.
            if self.med_pairwise_distance is None and self.save_path and self.use_pretrained:
                dim_reduction_name = None
                if self.dim_reduction:
                    dim_reduction_name = self.dim_reduction.name()

                med_pairwise_distances_path = self.save_path / f"med_pairwise_distances_{dim_reduction_name}.npy"

                if med_pairwise_distances_path.is_file():
                    self.med_pairwise_distance = np.load(med_pairwise_distances_path)
                else:
                    self.med_pairwise_distance = median_pairwise_distances(X_calib_r, self.dim_reduction)
                    np.save(med_pairwise_distances_path, self.med_pairwise_distance)

            if self.med_pairwise_distance is None:
                self.med_pairwise_distance = median_pairwise_distances(X_calib_r, self.dim_reduction)

            # Make the bandwidth relative to the typical spread of the data.
            self.kernel_bandwidth = self.kernel_bandwidth * self.med_pairwise_distance

        if self.sample_threshold is not None and len(X_calib) > self.sample_threshold:
            self.X_calib, self.y_calib = sample_points(X_calib, y_calib, self.sample_threshold)
        else:
            self.X_calib, self.y_calib = X_calib, y_calib

        self.scores_nonlocalised = self.get_scores(self.X_calib, self.y_calib)

        # NOTE(review): ``self.scores`` is not assigned in this class
        # (presumably by the parent); confirm it should not be
        # ``self.scores_nonlocalised``.
        scores = np.append(self.scores, float('inf'))

        self.X_calib_encoded = self.X_calib
        kernel_input_properties = self.input_properties

        if self.dim_reduction is not None:
            self.X_calib_encoded = self.dim_reduction.encode(self.X_calib)
            # Only encode an actual query; the non-localised path has none.
            if test_point is not None:
                test_point = self.dim_reduction.encode([test_point])[0]
            # Encoded vectors carry no per-feature metadata.
            kernel_input_properties = None

        if test_point is None:
            self.is_calibrated = True
            self.quantile_val = np.quantile(scores, 1-self.alpha)
            return self.quantile_val

        calib_len = len(self.y_calib)
        # The last slot holds the query's weight.  It counts toward the
        # normalisation, but its score is implicitly +inf: it is dropped after
        # sorting, so the quantile stays inf if calibration mass falls short.
        weights = np.zeros((calib_len+1,))

        for j in range(calib_len):
            weights[j] = self.kernel(self.X_calib_encoded[j], test_point, self.kernel_bandwidth, kernel_input_properties)

        weights[calib_len] = self.kernel(test_point, test_point, self.kernel_bandwidth, kernel_input_properties)
        weights /= np.sum(weights)

        sorted_indices = np.argsort(self.scores_nonlocalised)
        scores = self.scores_nonlocalised[sorted_indices]
        weights = weights[sorted_indices]

        # Smallest score whose cumulative kernel weight reaches 1 - alpha.
        self.quantile_val = float('inf')
        cumulative_prob = 0.0
        for i in range(calib_len):
            cumulative_prob += weights[i]
            if cumulative_prob >= 1.0 - self.alpha:
                self.quantile_val = scores[i]
                break

        self.is_calibrated = True
        return self.quantile_val

    def predict_batch(self, X):
        """Return one prediction set (list of labels) per row of ``X``.

        For each input, the predictor is re-calibrated around that point on
        the stored calibration data (overwriting ``self.quantile_val``).  Every
        label from ``input_properties.get_labels()`` whose score
        ``self.scorefn(prediction, label)`` is at most the localised quantile
        is included in the set.  Requires a prior ``calibrate`` call.
        """
        assert self.is_calibrated

        y_labels = self.input_properties.get_labels()
        predictions = self.model.predict(X)
        pred_intervals = []

        for i in range(len(predictions)):
            self.calibrate(X_calib=self.X_calib, y_calib=self.y_calib, test_point=X[i])
            pred_interval = []
            for element in y_labels:
                score = self.scorefn(predictions[i], element)
                if score <= self.quantile_val:
                    pred_interval.append(element)
            pred_intervals.append(pred_interval)

        return pred_intervals

    def gp_set_conformal_prediction_constraint(self,
                                               grb_model: gp.Model,
                                               output_vars: gp.MVar,
                                               input_vars: gp.MVar):
        """Encode the localised conformal prediction step in a Gurobi model.

        The sorted calibration scores become model constants, followed by a
        final sentinel score ``max + 100`` (presumably standing in for the
        query's unbounded score).  The correspondingly sorted encoded
        calibration inputs also become constants.  Kernel weights are then
        built against ``input_vars``, first mapped through
        ``dim_reduction.gp_dim_encoding`` when set.  The weighted
        ``1 - alpha`` quantile is stored in ``self.quantile_val``.  Finally,
        one free continuous score variable per target is created in
        ``self.scores_c`` and linked to ``output_vars`` via
        ``set_score_constraint``.

        Raises:
            ValueError: If the kernel is not ``'box_l1'``/``'box_linf'`` or
                the score function is not one of the linear ones.
        """
        # Both conditions are required by the MILP encoding, so either one
        # being unsupported is an error.
        if self.kernel_name not in ['box_l1', 'box_linf'] or self.scorefn_name not in ['linear', 'linear2', 'linear_logits']:
            raise ValueError("Can only use linear scorefn and box_l1 kernel in MILP")

        norm = np.inf if self.kernel_name == "box_linf" else 1

        assert self.is_calibrated

        # NOTE(review): this slice drops the last calibration score, and hence
        # the last point from sorted_Xcalib, whereas calibrate() uses all of
        # them.  Confirm this is intended.
        scores = self.scores_nonlocalised[:-1]
        sorted_indices = np.argsort(scores)
        scores = self.scores_nonlocalised[sorted_indices]
        scores = np.append(scores, [np.max(scores) + 100], axis=0)
        scores_mvar = gp_set_np_mvar(grb_model, scores, "scores")

        sorted_Xcalib = self.X_calib_encoded[sorted_indices]
        values_mvar = gp_set_np_mvar(grb_model, sorted_Xcalib, "X_calib")

        if self.dim_reduction:
            input_vars_reduced = self.dim_reduction.gp_dim_encoding(grb_model, input_vars)
            weights_c_mvar = gp_get_weights(grb_model, values_mvar, input_vars_reduced, self.kernel_bandwidth, norm=norm)
        else:
            weights_c_mvar = gp_get_weights(grb_model, values_mvar, input_vars, self.kernel_bandwidth, input_properties=self.input_properties, norm=norm)

        self.quantile_val = gp_get_weighted_quantile_new(grb_model, scores_mvar, weights_c_mvar, 1-self.alpha)

        # Per-class test scores, linked to the model outputs.
        num_classes = self.input_properties.n_targets

        self.scores_c = grb_model.addVars(num_classes, lb=-float('inf'), vtype=GRB.CONTINUOUS, name="scores_test")

        self.set_score_constraint(grb_model, self.scores_c, output_vars, num_classes)

