from counterfactual_explanations.cf_generator import CounterfactualGenerator
import gurobipy as gp
from gurobipy import GRB
from counterfactual_explanations.input_properties import InputProperties
from conformal.split_conformal import SplitConformalPrediction
from conformal.conformal_helpers import median_pairwise_distances
import numpy as np
from models.abstract_model import MILPEncodableModel
import os

class ConformalCF(CounterfactualGenerator):
    """Counterfactual generator that solves a Gurobi MILP with a conformal-prediction constraint.

    The MILP has two parts:

    * A static part, built once in ``setup``. It covers the input-domain
      constraints (from ``input_properties``), the MILP encoding of ``model``
      and the conformal predictor's prediction constraint.
    * A per-query part, replaced in ``_set_instance``. It covers the conformal
      singleton constraint for the target class and the constraints that make
      each ``d_i`` an upper bound on ``|input_i - x_i|``.

    The objective minimises ``sum(d_i)``, i.e. the L1 distance between the
    counterfactual and the query point.
    """

    def __init__(self, model: MILPEncodableModel, input_properties: InputProperties, config: dict, save_dir=None, use_pregenerated=True):
        """Store the configuration; the MILP itself is built later in ``setup``.

        Args:
            model: Predictive model that can be encoded as MILP constraints.
            input_properties: Description of the input domain; used in ``setup``
                to create the Gurobi input variables.
            config: Generator configuration. Keys read by this class:
                ``'conformal_class'``: conformal predictor class to instantiate
                (default ``SplitConformalPrediction``).
                ``'conformal_config'``: config passed to that class
                (default ``{'alpha': 0.05}``).
                ``'gurobi_params'``: optional mapping of Gurobi parameter names
                to values (e.g. ``{'TimeLimit': 60}``), applied in ``setup``.
            save_dir: Forwarded to the base class. ``setup`` later passes
                ``self.save_dir`` on as the conformal predictor's ``save_path``.
            use_pregenerated: Forwarded to the base class. ``setup`` later passes
                ``self.use_pregenerated`` on as ``use_pretrained``.
        """
        super().__init__(model, input_properties, config, save_dir, use_pregenerated)

        self.conformal_class = self.config.get('conformal_class', SplitConformalPrediction)
        self.conformal_config = self.config.get('conformal_config', {'alpha': 0.05})

    def setup(self, X_train, y_train, X_calib, y_calib):
        """Build and calibrate the conformal predictor, then build the static MILP.

        Must be called before ``generate_counterfactual``.

        Args:
            X_train, y_train: Data handed to the conformal predictor's
                dimensionality-reduction component. Only used when it has one.
            X_calib, y_calib: Calibration data for the conformal predictor.
        """
        self.conformal = self.conformal_class(
            self.model,
            self.input_properties,
            config=self.conformal_config,
            save_path=self.save_dir,
            use_pretrained=self.use_pregenerated,
        )

        if self.conformal.dim_reduction:
            self.conformal.dim_reduction.setup(
                self.model, self.input_properties, X_train, y_train, self.save_dir, self.use_pregenerated
            )

        self.conformal.calibrate(X_calib, y_calib)

        self.grb_model = gp.Model("model")
        self.grb_model.setParam('OutputFlag', 1)
        # Optional solver overrides. They are applied after the default above,
        # so a user-supplied 'OutputFlag' takes precedence.
        for param_name, param_value in (self.config.get('gurobi_params') or {}).items():
            self.grb_model.setParam(param_name, param_value)

        # Static part of the MILP: input domain, model encoding, conformal constraint.
        self.input_vars, self.input_mvar = self.input_properties.gp_set_input_var_constraints(self.grb_model)
        self.output_vars = self.model.gp_set_model_constraints(self.grb_model, self.input_mvar)
        self.conformal.gp_set_conformal_prediction_constraint(self.grb_model, self.output_vars, self.input_mvar)

        # Handles to the per-query constraints, replaced on every _set_instance call.
        self.singleton_constraints = []
        self.distance_constrs = []

        # One non-negative d_i per input feature. _set_instance bounds |input_i - x_i| <= d_i,
        # so minimising sum(d_i) minimises the L1 distance to the query.
        self.distance_vars = self.grb_model.addVars(len(self.input_vars), lb=0, name="d")
        # Sum the variables explicitly. Iterating a tupledict, like any dict, yields
        # its integer keys, which would make the objective a constant.
        self.grb_model.setObjective(gp.quicksum(self.distance_vars.values()), GRB.MINIMIZE)

    def _set_instance(self, x, y_target):
        """Swap in the per-query constraints for query point ``x`` and target ``y_target``.

        The constraints added for the previous query are removed first. This method
        then adds the conformal singleton constraint for ``y_target`` and, for every
        input feature ``i``, the pair ``-d_i <= input_i - x_i <= d_i``.
        """
        self.grb_model.remove(self.singleton_constraints)
        self.singleton_constraints = self.conformal.gp_set_singleton_constraint(self.grb_model, y_target)

        self.grb_model.remove(self.distance_constrs)
        self.distance_constrs = []
        # Positional indexing keeps this correct whether input_vars is a list or an
        # int-keyed tupledict. Assumes x is indexable by feature position 0..n-1;
        # TODO confirm against callers.
        for i in range(len(self.input_vars)):
            c1 = self.grb_model.addConstr(-1 * self.distance_vars[i] <= self.input_vars[i] - x[i], name=f"abs_pos_{i}")
            c2 = self.grb_model.addConstr(self.input_vars[i] - x[i] <= self.distance_vars[i], name=f"abs_neg_{i}")
            self.distance_constrs.extend([c1, c2])

    def generate_counterfactual(self, x, y_target):
        """Solve the MILP for query ``x`` and target class ``y_target``.

        Returns:
            When Gurobi reports ``GRB.OPTIMAL``, the result of
            ``self.check_solution(self.input_mvar, y_target)``. Otherwise
            (infeasible, time limit, ...), a float array of NaNs with the same
            shape as the input variables.
        """
        self._set_instance(x, y_target)
        self.grb_model.optimize()

        if self.grb_model.status == GRB.OPTIMAL:
            return self.check_solution(self.input_mvar, y_target)

        # Build the placeholder from the shape only. np.full_like on a Gurobi MVar
        # would take its dtype from the variable object (non-float, if it works at
        # all), so np.isnan would fail on the result.
        return np.full(self.input_mvar.shape, np.nan)