from models.abstract_model import DifferentiableModel, MILPEncodableModel

import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
import torch.optim as optim
from models.latent_encodings import IdentityEncoding

from sklearn.metrics import precision_score, f1_score
from gurobi_ml.torch.sequential import add_sequential_constr

import json

import gurobipy as gp
from gurobipy import GRB
import os
import numpy as np

from abc import ABC, abstractmethod
from counterfactual_explanations.gradient_based.losses import *

import torch
from torch.optim.optimizer import Optimizer
from counterfactual_explanations.input_properties import InputProperties

class PyTorchModel(DifferentiableModel):
    """Base class for differentiable models backed by a ``torch.nn.Module``.

    Subclasses implement :meth:`_build_model`, which must return the
    ``(module, loss_fn, optimiser)`` triple used by the methods below.
    """

    def __init__(self, config, input_properties: InputProperties):
        """Read hyper-parameters from the config and build the network.

        ``batch_size``, ``epochs`` and ``lr`` are read from ``self.config``
        (expected to be set by the parent constructor) with defaults 64, 50
        and 0.01.  The network is only built when ``input_properties`` is not
        ``None``; otherwise a module has to be supplied afterwards through
        :meth:`load` or :meth:`load_external`.
        """
        super().__init__(config, input_properties)
        self.batch_size = self.config.get('batch_size', 64)
        self.epochs = self.config.get('epochs', 50)
        self.lr = self.config.get("lr", 0.01)

        # torch.backends.mps.is_available() is the documented availability
        # check; torch.mps.is_available() is missing from older releases.
        if torch.cuda.is_available():
            device_name = "cuda"
        elif torch.backends.mps.is_available():
            device_name = "mps"
        else:
            device_name = "cpu"
        self.device = torch.device(device_name)

        if input_properties is not None:
            self.pytorch_model, self.loss_fn, self.optimiser = self._build_model()
            self.pytorch_model.to(self.device)

    @abstractmethod
    def _build_model(self):
        """Return the ``(module, loss_fn, optimiser)`` triple for this model."""
        pass

    def train(self, X_train, y_train):
        """Fit the network with mini-batch gradient descent.

        Args:
            X_train: Training inputs, converted to a float32 tensor.
            y_train: Integer class labels, converted to a long tensor and,
                when ``input_properties.y_onehot`` is set, one-hot encoded
                with ``input_properties.n_targets`` classes.

        Returns:
            The trained ``torch.nn.Module`` (also kept as
            ``self.pytorch_model``).
        """
        # self.random_state is expected to be provided by the parent class.
        torch.manual_seed(self.random_state)

        X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(self.device)
        y_train_tensor = torch.tensor(y_train, dtype=torch.long).to(self.device)

        if self.input_properties.y_onehot:
            n_classes = self.input_properties.n_targets
            targets = nn.functional.one_hot(y_train_tensor, num_classes=n_classes).float()
        else:
            targets = y_train_tensor

        train_loader = DataLoader(
            TensorDataset(X_train_tensor, targets),
            batch_size=self.batch_size,
            shuffle=True,
        )

        self.pytorch_model.train()

        for _ in range(self.epochs):
            # The dataset tensors already live on self.device, so the batches
            # need no further transfer.
            for X_batch, y_batch in train_loader:
                self.optimiser.zero_grad()
                loss = self.loss_fn(self.pytorch_model(X_batch), y_batch)
                loss.backward()
                self.optimiser.step()

        return self.pytorch_model

    def predict(self, x):
        """Return the raw network outputs for ``x`` as a NumPy array."""
        inputs = torch.tensor(x, dtype=torch.float).to(self.device)
        # Inference only: skip building an autograd graph that would be
        # discarded immediately.
        with torch.no_grad():
            outputs = self.pytorch_model(inputs)
        return outputs.detach().cpu().numpy()

    def load(self, save_path) -> None:
        """Load a whole pickled module from ``save_path`` onto ``self.device``.

        SECURITY: ``weights_only=False`` unpickles arbitrary objects, which
        can execute code; only load files from a trusted source.
        """
        # map_location lets a checkpoint saved on one accelerator be restored
        # on a machine where that accelerator is not present.
        self.pytorch_model = torch.load(save_path, map_location=self.device, weights_only=False)
        self.pytorch_model.to(self.device)
        self.save_path = save_path

    def save(self, save_path):
        """Pickle the whole module to ``save_path`` and remember the path."""
        torch.save(self.pytorch_model, save_path)
        self.save_path = save_path

    def evaluate(self, X_test, y_test):
        """Score the network on a labelled test set.

        Switches the module to evaluation mode and returns a dict with the
        keys ``loss``, ``accuracy``, ``precision`` and ``f1_score``; the last
        three are percentages, and precision/F1 use weighted averaging.
        """
        self.pytorch_model.eval()

        X_test_tensor = torch.tensor(X_test, dtype=torch.float32).to(self.device)
        y_test_tensor = torch.tensor(y_test, dtype=torch.long).to(self.device)

        with torch.no_grad():
            test_outputs = self.pytorch_model(X_test_tensor)
            test_loss = self.loss_fn(test_outputs, y_test_tensor).item()
            _, predicted = torch.max(test_outputs, 1)
            accuracy = (predicted == y_test_tensor).sum().item() / y_test_tensor.size(0)

            y_true = y_test_tensor.cpu().numpy()
            y_pred = predicted.cpu().numpy()
            precision = precision_score(y_true, y_pred, average='weighted')
            f1 = f1_score(y_true, y_pred, average='weighted')

        return {
            'loss': test_loss,
            'accuracy': accuracy * 100,
            'precision': precision * 100,
            'f1_score': f1 * 100
        }

    def compute_loss(self, y1, y2):
        """Apply the model's loss function to ``(y1, y2)``."""
        return self.loss_fn(y1, y2)

    def get_optimisation_loop(self, input_properties, losses, n_iter, lr, min_max_lambda=None, losses_weights=None, jsma=False, latent_encoding=IdentityEncoding(), early_stopping=False, retain_graph=False):
        """Create a ``PyTorchMLP_Optimisation`` bound to this model.

        All arguments are forwarded unchanged to the optimisation object.
        """
        return PyTorchMLP_Optimisation(
            self,
            input_properties,
            losses,
            n_iter,
            lr,
            min_max_lambda,
            losses_weights,
            jsma=jsma,
            latent_encoding=latent_encoding,
            early_stopping=early_stopping,
            retain_graph=retain_graph,
        )

    def load_external(self, model):
        """Use an already constructed module as this wrapper's network."""
        self.pytorch_model = model
        # predict/evaluate move their inputs to self.device, so the module has
        # to live there as well (load() does the same after unpickling).
        if isinstance(model, nn.Module):
            self.pytorch_model.to(self.device)
        

class PyTorchMLP(PyTorchModel, MILPEncodableModel):
    """Fully connected ReLU network that can also be encoded in a Gurobi model."""

    def _build_model(self):
        """Build the ``nn.Sequential`` network, its loss and its optimiser.

        Layer sizes come from ``config["dims"]`` when that entry is truthy.
        Otherwise they are ``n_features``, the ``hidden_dims`` entries
        (default ``[50]``) and ``n_targets``.  A ReLU follows every linear
        layer except the last one.

        Returns:
            ``(model, loss_fn, optimiser)`` with a cross-entropy loss and an
            Adam optimiser using ``self.lr``.
        """
        if self.config.get("dims"):
            self.dims = self.config["dims"]
        else:
            self.dims = self.config.get('hidden_dims', [50])
            self.dims = [self.input_properties.n_features] + self.dims + [self.input_properties.n_targets]

        layers = []
        last_layer = len(self.dims) - 2
        for i in range(len(self.dims) - 1):
            layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))
            # No activation after the output layer.
            if i < last_layer:
                layers.append(nn.ReLU())

        model = nn.Sequential(*layers)
        model.to(self.device)

        loss_fn = nn.CrossEntropyLoss()
        optimiser = optim.Adam(model.parameters(), lr=self.lr)

        return model, loss_fn, optimiser

    def gp_set_model_constraints(self, grb_model: gp.Model, input_mvar: gp.MVar) -> gp.MVar:
        """Add the network as constraints linking ``input_mvar`` to a new output MVar.

        Returns:
            The unbounded output variable, with one entry per output unit of
            the last layer.
        """
        output_mvar = grb_model.addMVar(shape=(self.pytorch_model[-1].out_features,), lb=-GRB.INFINITY, ub=GRB.INFINITY, name="output")
        try:
            # The network is handed to gurobi_ml on the CPU.
            add_sequential_constr(grb_model, self.pytorch_model.cpu(), input_mvar, output_mvar)
        finally:
            # Restore the original device even if the encoding fails, so the
            # model is not left stranded on the CPU.
            self.pytorch_model.to(self.device)
        return output_mvar

    def gp_set_classification_constraint(self, grb_model: gp.Model, output_vars: gp.MVar, target_class: int, db_distance=1e-6) -> list:
        """Constrain the output variables so that ``target_class`` is predicted.

        With a single output the value is thresholded at 0.5.  With several
        outputs the target output must exceed every other output by at least
        ``db_distance``.

        Returns:
            The list of constraints that were added to ``grb_model``.
        """
        classification_constrs = []

        if output_vars.shape[0] == 1:
            # Single output.
            # NOTE(review): class 0 is mapped to output >= 0.5 and class 1 to
            # output <= 0.5, and db_distance is not applied here - confirm
            # this matches the convention of the single-output model.
            assert target_class in [0, 1], "Target class must be 0 or 1 for a single logit output"

            if target_class == 0:
                constr = grb_model.addConstr(output_vars[0] >= 0.5, name="Output class 0 constraint")
            else:
                constr = grb_model.addConstr(output_vars[0] <= 0.5, name="Output class 1 constraint")
            classification_constrs.append(constr)

        else:
            # One output per class: the target must beat every other class.
            for i in range(output_vars.shape[0]):
                if i != target_class:
                    c = grb_model.addConstr(output_vars[target_class] >= output_vars[i] + db_distance)
                    classification_constrs.append(c)

        return classification_constrs


class DifferentiableOptimisation(ABC):
    """Shared configuration for gradient-based optimisation loops.

    Stores the model, the loss terms with their weights and the loop
    settings; concrete subclasses implement the optimisation itself.
    """

    def __init__(self, model, input_properties, losses, n_iter, lr, min_max_lambda=None, losses_weights=None, latent_encoding=IdentityEncoding(), jsma=False, early_stopping=False, retain_graph=False):
        """Store the optimisation settings.

        When ``losses_weights`` is omitted and ``losses`` is given, every
        loss term gets weight 1.

        Raises:
            ValueError: if both ``losses`` and ``losses_weights`` are given
                but differ in length.
        """
        self.model = model
        self.input_properties = input_properties
        self.losses = losses
        self.n_iter = n_iter
        self.lr = lr
        self.early_stopping = early_stopping
        self.min_max_lambda = min_max_lambda
        self.losses_weights = losses_weights
        self.latent_encoding = latent_encoding
        self.retain_graph = retain_graph

        if self.losses is not None:
            if self.losses_weights is None:
                self.losses_weights = np.ones(len(self.losses))
            elif len(self.losses_weights) != len(self.losses):
                # The check has to cover caller-supplied weights; on freshly
                # created default weights it can never fail.
                raise ValueError("losses and losses_weight length mismatch")

        self.jsma = jsma
        # Per-feature bound tensors; starts empty and is filled in by
        # subclasses when they need it.
        self.tensor_bounds = None

    @abstractmethod
    def optimise_minmax(self, x, y_target):
        """Run the min-max optimisation for ``x`` towards ``y_target``."""
        pass

    def optimise_min(self, x, y_target):
        """Run the plain minimisation; this base implementation does nothing."""
        pass

class SalientFeatureOptimizer(Optimizer):
    """Optimiser that changes a single entry of each parameter per step.

    For every parameter that has a gradient, the entry whose gradient has the
    largest absolute value is moved by ``lr`` against the sign of that
    gradient; all other entries are left untouched.
    """

    def __init__(self, params, lr=0.1):
        """Register ``params`` with step size ``lr``."""
        defaults = dict(lr=lr)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one update.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The closure's return value, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad

                # reshape (unlike view) also accepts non-contiguous gradients,
                # e.g. those of transposed parameters.
                grad_flat = grad.reshape(-1)
                if grad_flat.numel() == 0:
                    continue

                idx = torch.argmax(grad_flat.abs())
                update_val = -1 * lr * torch.sign(grad_flat[idx])

                # A contiguous buffer keeps the flat index aligned with the
                # row-major order used by reshape above.
                grad_update = torch.zeros_like(grad, memory_format=torch.contiguous_format)
                grad_update.view(-1)[idx] = update_val

                p.add_(grad_update)

        return loss


class PyTorchMLP_Optimisation(DifferentiableOptimisation):
    """Gradient-based search over a latent point ``z`` for a PyTorch model.

    Each iteration takes one optimiser step on ``z``, decodes it, projects
    the result with :meth:`fix_encoding` and re-evaluates the model.
    """

    # Stop iterating once the projected input moves less than this (L2 norm).
    _MIN_CHANGE = 1e-4

    @staticmethod
    def _select_device():
        """Pick cuda, then mps, then cpu, whichever is available first."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        # torch.backends.mps.is_available() is the documented availability
        # check; torch.mps.is_available() is missing from older releases.
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def setup(self, x_factual, y_target):
        """Prepare tensors, the optimiser and the initial optimisation state.

        Args:
            x_factual: NumPy array holding the factual input.
            y_target: Target label(s); one-hot encoded here when
                ``input_properties.y_onehot`` is set, otherwise expected to
                be a NumPy array.

        Returns:
            ``(device, optimiser, opt_state)``.
        """
        device = self._select_device()
        torch.manual_seed(0)

        x_factual = torch.from_numpy(x_factual).float().to(device)
        y_factual = self.model.pytorch_model(x_factual)

        if self.input_properties.y_onehot:
            y_target = torch.nn.functional.one_hot(torch.tensor(y_target), self.input_properties.n_targets).float().to(device)
        else:
            y_target = torch.from_numpy(y_target).float().to(device)

        # z is the leaf tensor being optimised (replaces the deprecated
        # torch.autograd.Variable wrapper).
        z = self.latent_encoding.encode(x_factual.clone()).detach().requires_grad_(True)
        # The factual reference must be a constant: a clone that stays
        # attached to z would send gradients back into z and cancel those of
        # any loss comparing z with z_factual.
        z_factual = z.detach().clone()
        x = self.latent_encoding.decode(z)
        x_enc = self.fix_encoding(x, self.latent_encoding)

        if self.jsma:
            optimiser = SalientFeatureOptimizer([z], self.lr)
        else:
            optimiser = torch.optim.Adam([z], self.lr, amsgrad=True)

        y_enc = self.model.pytorch_model(x_enc).to(device)

        opt_state = OptimisationState(self.model, z, z_factual, x_enc, y_enc, x_factual, y_factual, y_target, 0, self.n_iter)

        return device, optimiser, opt_state

    def correct_classification(self, y_enc, y_target):
        """Return whether the arg-max of ``y_enc`` equals that of ``y_target``."""
        return torch.argmax(y_enc) == torch.argmax(y_target)

    def _ensure_tensor_bounds(self, device):
        """Build, once, the per-feature bound tensors used for ordinal snapping."""
        if self.tensor_bounds is not None:
            return
        self.tensor_bounds = []
        for j in range(self.input_properties.n_features):
            if self.input_properties.bound[j]:
                self.tensor_bounds.append(torch.tensor(self.input_properties.bound[j], device=device))
            else:
                self.tensor_bounds.append(None)

    def fix_encoding(self, x, latent_encoding = None):
        """Project ``x`` onto the valid values of every feature.

        Only applied when the configured latent encoding is an
        ``IdentityEncoding``; for any other encoding ``x`` is returned
        unchanged.  Numeric features with a bound are clamped to it, ordinal
        features are snapped to the nearest listed value, and each
        categorical group becomes a hard one-hot vector at its arg-max.
        Features matching none of these cases keep their value.

        ``x`` is indexed per feature along its first axis - assumes a 1-D
        feature vector; TODO confirm.  The ``latent_encoding`` argument is
        not used.
        """
        if not isinstance(self.latent_encoding, IdentityEncoding):
            return x

        x_enc = torch.zeros_like(x)

        for i in range(self.input_properties.n_features):
            feature_class = self.input_properties.feature_classes[i]
            bound = self.input_properties.bound[i]

            # Default: pass the value through.  Without it, a feature that
            # matches no branch would reuse the previous feature's projection
            # (or raise NameError if it is the first feature).
            proj = x[i]

            if feature_class == 'numeric' and bound is not None:
                proj = torch.clamp(x[i], bound[0], bound[1])

            elif feature_class == 'ordinal' or feature_class == 'ordinal_normalised':
                self._ensure_tensor_bounds(x.device)

                # Snap to the closest admissible value.
                diffs = (x[i] - self.tensor_bounds[i]) ** 2
                idx = torch.argmin(diffs, dim=-1)
                proj = bound[idx]

            x_enc[i] = proj

        for group in self.input_properties.categorical_groups:
            group_vals = x[group]

            idx = torch.argmax(group_vals, dim=-1)
            onehot = torch.nn.functional.one_hot(idx, num_classes=group_vals.shape[-1])

            x_enc[group] = onehot.to(group_vals.dtype)

        return x_enc

    def _target_reached(self, opt_state, y_target):
        """Return a truthy value when early stopping is on and the target class is predicted."""
        return self.early_stopping and self.correct_classification(opt_state.y_enc, y_target)

    def _refresh_state(self, opt_state, device):
        """Decode the current ``z`` and update ``x_enc`` / ``y_enc`` in ``opt_state``."""
        x = self.latent_encoding.decode(opt_state.z)
        opt_state.x_enc = self.fix_encoding(x, self.latent_encoding)
        opt_state.y_enc = self.model.pytorch_model(opt_state.x_enc).to(device)

    def optimise_minmax(self, x, y_target):
        """Minimise ``loss_0 + lambda * (weighted sum of the other losses)``.

        Runs at most two rounds of up to ``n_iter`` iterations each and
        lowers ``lambda`` by 0.05 after every round.  A round ends early when
        the projected input barely changes between iterations, or when early
        stopping is enabled and the target class is reached.

        Returns:
            The final projected input as a NumPy array.
        """
        device, optimiser, opt_state = self.setup(x, y_target)
        y_target = opt_state.y_target

        min_max_lambda = torch.tensor(self.min_max_lambda).float().to(device)

        # Both persist across rounds, as does the optimiser's internal state.
        prev_solution = None
        change = torch.inf

        rounds_left = 2
        while not self._target_reached(opt_state, y_target) and rounds_left > 0:
            opt_state.it = 0
            while not self._target_reached(opt_state, y_target) and opt_state.it < self.n_iter:
                optimiser.zero_grad()

                loss_term_0 = self.losses[0].loss(opt_state) * self.losses_weights[0]

                # Starts as a grad-enabled zero so backward() still works when
                # there is only a single loss term.
                remaining_losses = torch.tensor(0.0, requires_grad=True)
                for i in range(1, len(self.losses)):
                    remaining_losses = remaining_losses + self.losses[i].loss(opt_state) * self.losses_weights[i]

                loss = loss_term_0 + min_max_lambda * remaining_losses
                loss.backward(retain_graph=self.retain_graph)
                optimiser.step()

                self._refresh_state(opt_state, device)

                if prev_solution is not None:
                    change = torch.norm(opt_state.x_enc - prev_solution).item()
                if change < self._MIN_CHANGE:
                    break

                prev_solution = opt_state.x_enc.clone()

                opt_state.it += 1

            rounds_left -= 1
            min_max_lambda -= 0.05

        return opt_state.x_enc.cpu().detach().numpy()

    def optimise_min(self, x, y_target):
        """Minimise the weighted sum of all loss terms.

        Runs up to ``n_iter`` iterations.  Stops early when the projected
        input barely changes between iterations, or when early stopping is
        enabled and the target class is reached.

        Returns:
            The final projected input as a NumPy array.
        """
        device, optimiser, opt_state = self.setup(x, y_target)
        y_target = opt_state.y_target
        prev_solution = None
        change = torch.inf

        while (not self._target_reached(opt_state, y_target)) and opt_state.it < self.n_iter:
            optimiser.zero_grad()

            total = 0.0
            for i in range(0, len(self.losses)):
                total = total + self.losses[i].loss(opt_state) * self.losses_weights[i]

            total.backward(retain_graph=self.retain_graph)

            optimiser.step()

            self._refresh_state(opt_state, device)

            if prev_solution is not None:
                change = torch.norm(opt_state.x_enc - prev_solution).item()
            if change < self._MIN_CHANGE:
                break

            prev_solution = opt_state.x_enc.clone()
            opt_state.it += 1

        return opt_state.x_enc.cpu().detach().numpy()


