from abc import ABC, abstractmethod
from counterfactual_explanations.input_properties import InputProperties
from counterfactual_explanations.gradient_based.losses import *
from counterfactual_explanations.cf_generator import CounterfactualGenerator
from conformal.localised_conformal_baselcp import *
from conformal.losses_conformal import *
import numpy as np

from models.abstract_model import DifferentiableModel
from models.gradientboosting_sklearn import GradientBoostingSKLearn
from models.decisiontree_sklearn import DecisionTreeSKLearn
from models.randomforest_sklearn import RandomForestSKLearn

from counterfactual_explanations.gradient_based.auxillary_models import VariationalAutoencoder
from pathlib import Path

from tqdm import tqdm

class GradientBasedGenerator(CounterfactualGenerator):
    """Base class for generators that build counterfactuals with a model-provided optimisation loop.

    Subclasses construct their loop in ``setup`` via
    ``self.model.get_optimisation_loop(...)``. This is why the model must be a
    ``DifferentiableModel``.

    Config keys read here (defaults in parentheses):
        norm (1): norm passed to ``DistanceLoss`` by subclasses.
        dist_weight (None): stored on the instance; not used by the generators in this module.
        mad (False): flag passed to ``DistanceLoss`` by subclasses.
        n_iter (1000): number of optimisation iterations passed to the loop.
        lr (0.005): learning rate passed to the loop.
        min_max_lambda (1): lambda passed to the loop by some subclasses.

    Raises:
        TypeError: if the model is not a ``DifferentiableModel``.
    """

    def __init__(self, model: DifferentiableModel, input_properties: InputProperties, config: dict, save_dir: Path=None, use_pregenerated: bool=True):
        super().__init__(model, input_properties, config, save_dir, use_pregenerated)

        # Explicit check rather than `assert`: asserts are stripped under `python -O`.
        # Without this check, a non-differentiable model would only fail later,
        # when `setup` calls `get_optimisation_loop`.
        if not isinstance(self.model, DifferentiableModel):
            raise TypeError(
                f"{type(self).__name__} requires a DifferentiableModel, "
                f"got {type(self.model).__name__}"
            )
        self.model = model

        self.norm = self.config.get('norm', 1)
        self.dist_weight = self.config.get('dist_weight', None)
        self.mad = self.config.get('mad', False)
        self.n_iter = self.config.get('n_iter', 1000)
        self.lr = self.config.get('lr', 0.005)
        self.min_max_lambda = self.config.get('min_max_lambda', 1)

class WachterGenerator(GradientBasedGenerator):
    """Generator combining a classification loss with a distance loss.

    Counterfactuals are produced by the loop's ``optimise_minmax`` routine.
    Early stopping is enabled.
    """

    def setup(self, X_train, y_train, X_calib, y_calib):
        """Build the optimisation loop.

        Only ``X_train`` is used, to construct the distance loss. The label and
        calibration arguments are accepted to match the common ``setup`` signature.
        """
        distance_term = DistanceLoss(self.norm, self.mad, X_train)
        classification_term = ClassificationLoss()

        self.optimisation_loop = self.model.get_optimisation_loop(
            self.input_properties,
            losses=[classification_term, distance_term],
            n_iter=self.n_iter,
            lr=self.lr,
            min_max_lambda=self.min_max_lambda,
            early_stopping=True,
        )

    def generate_counterfactual(self, x, y_target):
        """Return the result of the loop's min-max optimisation for ``x`` towards ``y_target``."""
        loop = self.optimisation_loop
        return loop.optimise_minmax(x, y_target)

class ECCCOGenerator(GradientBasedGenerator):
    """Generator combining distance, classification, energy-based and conformal set-size losses.

    The four losses are weighted by ``config['losses_weights']``. The default is
    ``[1, 0.2, 0.4, 0.4]``, in the order classification, distance, energy,
    set size. The weighted objective is minimised with the loop's ``optimise_min``.
    """

    # Default loss weights, in the same order as the losses handed to the loop.
    DEFAULT_LOSSES_WEIGHTS = (1, 0.2, 0.4, 0.4)

    def setup(self, X_train, y_train, X_calib, y_calib):
        """Calibrate a split-conformal predictor and build the optimisation loop.

        Raises:
            ValueError: if ``config['losses_weights']`` does not have exactly four entries.
        """
        # Validate the weights before the (potentially expensive) conformal calibration.
        losses_weights = list(self.config.get('losses_weights', self.DEFAULT_LOSSES_WEIGHTS))
        if len(losses_weights) != 4:
            raise ValueError(
                f"ECCCOGenerator expects 4 losses_weights "
                f"(classification, distance, energy, set size), got {len(losses_weights)}"
            )

        clf_loss = ClassificationLoss()
        distance_loss = DistanceLoss(self.norm, self.mad, X_train)
        energy_based_loss = EnergyLoss()

        conformal_config = self.config.get('conformal_config', {})
        conformal = SplitConformalPrediction(self.model, self.input_properties, conformal_config, self.save_dir, self.use_pregenerated)
        conformal.calibrate(X_calib, y_calib)
        # Keep the calibrated predictor on the instance, as DifferentiableCONFEXGenerator does.
        self.conformal = conformal

        set_size_loss = SetSizeLoss(conformal)

        self.optimisation_loop = self.model.get_optimisation_loop(
            self.input_properties,
            losses=[clf_loss, distance_loss, energy_based_loss, set_size_loss],
            n_iter=self.n_iter,
            lr=self.lr,
            losses_weights=losses_weights,
        )

    def generate_counterfactual(self, x, y_target):
        """Return the result of minimising the weighted losses for ``x`` towards ``y_target``."""
        return self.optimisation_loop.optimise_min(x, y_target)
    
class DifferentiableCONFEXGenerator(GradientBasedGenerator):
    """Generator using classification, distance and BaseLCP set-size losses, equally weighted.

    The calibrated ``BaseLCP`` predictor is stored on ``self.conformal``.
    Counterfactuals come from the loop's ``optimise_min``.
    """

    def setup(self, X_train, y_train, X_calib, y_calib):
        """Calibrate a ``BaseLCP`` predictor on the calibration split and build the loop."""
        classification_term = ClassificationLoss()
        distance_term = DistanceLoss(self.norm, self.mad, X_train)

        predictor = BaseLCP(
            self.model,
            self.input_properties,
            self.config.get('conformal_config', {}),
            self.save_dir,
            self.use_pregenerated,
        )
        predictor.calibrate(X_calib, y_calib)
        # Only exposed on the instance once calibration has succeeded.
        self.conformal = predictor

        set_size_term = SetSizeLossBaseLCP(predictor)

        # retain_graph=True is forwarded to the loop. Presumably the set-size loss
        # needs the autograd graph across backward passes — TODO confirm.
        self.optimisation_loop = self.model.get_optimisation_loop(
            self.input_properties,
            losses=[classification_term, distance_term, set_size_term],
            n_iter=self.n_iter,
            lr=self.lr,
            losses_weights=[1, 1, 1],
            retain_graph=True,
        )

    def generate_counterfactual(self, x, y_target):
        """Return the result of minimising the combined losses for ``x`` towards ``y_target``."""
        loop = self.optimisation_loop
        return loop.optimise_min(x, y_target)

class ReviseGenerator(GradientBasedGenerator):
    """Generator using distance and classification losses plus a VAE latent encoding.

    A ``VariationalAutoencoder`` is built from ``config['vae_config']`` and handed
    to the loop as ``latent_encoding``. Counterfactuals come from the loop's
    ``optimise_min``.
    """

    def setup(self, X_train, y_train, X_calib, y_calib):
        """Prepare the VAE on the training data and build the optimisation loop."""
        distance_term = DistanceLoss(self.norm, self.mad, X_train)
        classification_term = ClassificationLoss()

        vae = VariationalAutoencoder(self.config.get("vae_config", {}), self.input_properties)
        if self.save_dir is None:
            # No save directory: train directly.
            vae.train(X_train, y_train)
        else:
            # `load_or_train` receives <save_dir>/vae and the use_pregenerated flag.
            vae.load_or_train(self.save_dir / "vae", X_train, y_train, self.use_pregenerated)

        self.optimisation_loop = self.model.get_optimisation_loop(
            self.input_properties,
            losses=[classification_term, distance_term],
            n_iter=self.n_iter,
            lr=self.lr,
            min_max_lambda=self.min_max_lambda,
            latent_encoding=vae,
            losses_weights=[1, 1],
        )

    def generate_counterfactual(self, x, y_target):
        """Return the result of minimising the losses for ``x`` towards ``y_target``."""
        loop = self.optimisation_loop
        return loop.optimise_min(x, y_target)

class SchutGenerator(GradientBasedGenerator):
    """Generator driven only by the classification loss, with the loop's ``jsma`` option enabled."""

    def setup(self, X_train, y_train, X_calib, y_calib):
        """Build the optimisation loop; none of the data arguments are used."""
        classification_term = ClassificationLoss()
        self.optimisation_loop = self.model.get_optimisation_loop(
            self.input_properties,
            losses=[classification_term],
            n_iter=self.n_iter,
            lr=self.lr,
            min_max_lambda=None,
            jsma=True,
        )

    def generate_counterfactual(self, x, y_target):
        """Return the result of minimising the classification loss for ``x`` towards ``y_target``."""
        loop = self.optimisation_loop
        return loop.optimise_min(x, y_target)