import json
import os
from abc import ABC, abstractmethod
from pathlib import Path

import gurobipy as gp
from joblib import dump, load
from sklearn.metrics import precision_score, f1_score

from counterfactual_explanations.input_properties import InputProperties

class AbstractModel(ABC):
    """Base class for trainable classifiers wrapped by this package.

    Subclasses implement :meth:`train` and :meth:`predict` and keep the
    underlying estimator in ``self.model``. This base class provides
    evaluation, persistence (joblib for the model, JSON for ``config``) and a
    "load if already saved, otherwise train and save" helper.

    Files written into a model directory:
      * ``savename()`` (default ``"saved.model"``): the joblib-dumped model
      * ``config.json``: the model configuration
      * ``eval.json``: cached evaluation metrics
    """

    def __init__(self, config, input_properties: InputProperties):
        """Store configuration and input metadata; the model starts untrained.

        Args:
            config: Mapping of model settings. A shallow copy is stored, so the
                caller's mapping is not modified. ``"random_state"`` defaults
                to ``0`` when absent.
            input_properties: Stored unchanged as ``self.input_properties``.
        """
        self.model = None
        # Copy before filling in defaults so a config dict shared between
        # several models (or reused by the caller) is never mutated here.
        self.config = dict(config)
        self.config.setdefault("random_state", 0)
        self.random_state = self.config["random_state"]
        self.input_properties = input_properties
        # Directory (as a Path) the model was last saved to / loaded from.
        self.save_dir = None

    @abstractmethod
    def train(self, X_train, y_train):
        """Fit the model on the training data; implemented by subclasses."""

    @abstractmethod
    def predict(self, x):
        """Return predictions for ``x``; implemented by subclasses."""

    def evaluate(self, X_test, y_test):
        """Score ``self.model`` on held-out data.

        Calls the wrapped estimator's own ``score`` and ``predict`` methods
        (presumably a scikit-learn style estimator, whose ``score`` is mean
        accuracy for classifiers).

        Returns:
            dict with keys ``'accuracy'``, ``'precision'`` and ``'f1_score'``,
            each multiplied by 100. Precision and F1 use
            ``average='weighted'`` (per-class scores weighted by support).
        """
        accuracy = self.model.score(X_test, y_test) * 100
        y_pred = self.model.predict(X_test)
        precision = precision_score(y_test, y_pred, average='weighted') * 100
        f1 = f1_score(y_test, y_pred, average='weighted') * 100
        return {
            'accuracy': accuracy,
            'precision': precision,
            'f1_score': f1,
        }

    def load_or_save_evaluation(self, X_test, y_test, use_pretrained=True):
        """Return metrics cached in ``<save_dir>/eval.json``, or compute and cache them.

        Args:
            X_test, y_test: Held-out data passed to :meth:`evaluate` when the
                metrics are (re)computed.
            use_pretrained: When True and ``eval.json`` exists, return it
                without re-evaluating. When False, always re-evaluate and
                overwrite the cache.

        Raises:
            RuntimeError: if ``self.save_dir`` is unset, i.e. the model has
                not been saved to or loaded from a directory yet.
        """
        if self.save_dir is None:
            raise RuntimeError(
                "save_dir is not set; call save_to_dir, load_from_dir or "
                "load_or_train before load_or_save_evaluation"
            )
        # NOTE: nothing invalidates eval.json when a model is retrained into
        # the same directory, so a cached file may describe an older model.
        evaluation_path = Path(self.save_dir) / "eval.json"
        if use_pretrained and evaluation_path.is_file():
            with open(evaluation_path, 'r', encoding='utf-8') as f:
                return json.load(f)

        evaluation = self.evaluate(X_test, y_test)
        with open(evaluation_path, 'w', encoding='utf-8') as f:
            json.dump(evaluation, f)
        return evaluation

    def savename(self):
        """File name of the serialized model inside a model directory."""
        return "saved.model"

    def save(self, save_path):
        """Serialize ``self.model`` to ``save_path`` with ``joblib.dump``."""
        dump(self.model, save_path)

    def load(self, save_path):
        """Replace ``self.model`` with the object stored at ``save_path``.

        ``joblib.load`` unpickles the file, so only load trusted files.
        """
        self.model = load(save_path)

    def save_to_dir(self, save_dir):
        """Write the model and ``config.json`` into ``save_dir``, creating it if needed.

        ``save_dir`` may be a ``str`` or path-like object; it is remembered
        as a ``Path`` in ``self.save_dir``.
        """
        save_dir = Path(save_dir)
        os.makedirs(save_dir, exist_ok=True)
        self.save_dir = save_dir
        self.save(save_dir / self.savename())

        with open(save_dir / "config.json", 'w', encoding='utf-8') as f:
            json.dump(self.config, f)

    def load_from_dir(self, save_dir):
        """Restore the model and its config from a directory written by :meth:`save_to_dir`.

        ``self.config`` is replaced by the stored ``config.json`` and
        ``self.random_state`` is re-synced from it.
        """
        save_dir = Path(save_dir)
        self.save_dir = save_dir
        self.load(save_dir / self.savename())

        with open(save_dir / "config.json", 'r', encoding='utf-8') as f:
            self.config = json.load(f)
        # Keep the cached attribute consistent with the restored config.
        self.random_state = self.config.get("random_state", self.random_state)

    def load_external(self, model):
        """Wrap an already-built estimator instead of training one."""
        self.model = model

    def load_or_train(self, save_dir, X_train, y_train, use_pretrained):
        """Load the model from ``save_dir`` if saved there, else train and save it.

        A saved model is used only when ``use_pretrained`` is true and both
        the model file and ``config.json`` exist in ``save_dir``.

        Returns:
            The underlying ``self.model``.
        """
        save_dir = Path(save_dir)
        model_path = save_dir / self.savename()
        config_path = save_dir / "config.json"
        if use_pretrained and model_path.is_file() and config_path.is_file():
            self.load_from_dir(save_dir)
        else:
            self.train(X_train, y_train)
            self.save_to_dir(save_dir)
        return self.model
    
    
class DifferentiableModel(AbstractModel):
    """Interface for models that expose a loss and an optimisation loop.

    Presumably used for gradient-based counterfactual search (the loop takes
    a learning rate ``lr``); confirm against concrete subclasses.
    """

    @abstractmethod
    def compute_loss(self, y1, y2):
        """Return the model's loss between ``y1`` and ``y2``.

        The meaning and order of the two arguments are defined by subclasses.
        """

    @abstractmethod
    def get_optimisation_loop(self, input_properties, losses, n_iter, lr,
                              min_max_lambda, early_stopping):
        """Build the optimisation loop used to search over model inputs.

        Declared as an instance method, so ``self`` is the first parameter.
        Parameter semantics (e.g. what ``losses`` and ``min_max_lambda``
        contain) are defined by the concrete implementations.
        """

class MILPEncodableModel(AbstractModel):
    """Interface for models whose prediction can be written as gurobipy constraints."""

    @abstractmethod
    def gp_set_model_constraints(self, grb_model: gp.Model, input_mvar: gp.MVar) -> gp.MVar:
        """Encode the model inside ``grb_model`` on top of ``input_mvar``.

        Returns:
            The ``gp.MVar`` holding the encoded model's output variables.
        """
        ...

    @abstractmethod
    def gp_set_classification_constraint(self, grb_model: gp.Model, output_vars: gp.MVar, target_class: int, db_distance=1e-6) -> None:
        """Add to ``grb_model`` the constraint that ``output_vars`` yield ``target_class``.

        ``db_distance`` presumably is a margin from the decision boundary;
        confirm against concrete subclasses.
        """
        ...