# Utils
import torch
from torch import nn
import datetime


class REVISE:
    """Counterfactual search in the latent space of an auto-encoder.

    The query is encoded with ``model_vae.encode_csearch``; the latent code ``z`` is
    then optimized so that the decoded point ``model_vae.decode_csearch(z)`` reaches the
    desired side of ``target_thres`` while staying close (in ``norm``) to the query.
    """

    def __init__(self, classifier, model_vae, optimizer: str = "adam", max_iter: int = 750, target_thres: float = 0.0,
                 t_max_min: float = 0.5, _lambda: float = 10.0,
                 lr: float = 0.05, norm: int = 1, step: float = 0.05, setting: str = "classification"):
        """
            classifier: model scoring candidates. Called directly in the "classification" setting,
                via ``predict_with_logits`` in the "regression" setting.
            model_vae: object providing ``encode_csearch`` and ``decode_csearch``.
            optimizer: "adam" selects Adam; any other value selects RMSprop.
            max_iter: maximum number of optimizer steps per value of lambda.
            target_thres: threshold used by ``_check_cf_valid``.
            t_max_min: wall-clock budget in minutes, checked after each lambda round.
            _lambda: initial weight of the distance term in the loss.
            lr: optimizer learning rate.
            norm: order of the norm used for the distance term.
            step: amount subtracted from lambda after each round; 0.0 disables the lambda search.
            setting: "classification" or "regression".
        """
        super().__init__()
        self.model_classification = classifier
        self.model_vae = model_vae
        self.lr = lr
        self.norm = norm
        self.max_iter = max_iter
        self.optimizer = optimizer
        self.target_threshold = target_thres
        self._lambda = _lambda
        self.t_max_min = t_max_min
        self.step = step
        self.setting = setting

    def generate_counterfactuals(self, query_instance: torch.Tensor, target_class: int = 1) -> "torch.Tensor | None":
        """Search for the closest valid counterfactual of ``query_instance``.

            query_instance: the point to be explained (reshaped to ``(1, -1)`` for the initial
                check and passed as-is to the encoder; presumably a single sample - TODO confirm).
            target_class: direction of the desired change. If target_class = 1, we aim to increase the
                score; if target_class = 0, we aim to decrease it (in classification and regression problems).

        Returns the detached valid candidate with the smallest distance to the query, or None
        (after printing a message) if the query already has the desired label or no valid
        candidate was found within the iteration / time budget.
        """
        target_prediction = torch.tensor(1 if target_class == 1 else 0).float()

        output = self._call_model(query_instance.reshape(1, -1))
        # check if query instance is already valid
        if self._check_cf_valid(output, target_class):
            print("Query instance already has desired label.")
            return None
        # From here on `output` is known to be invalid, so it is a correct starting value
        # for the loop conditions below for either target_class (a fixed sentinel such as
        # -1 would count as valid for target_class == 0 and skip the search entirely).

        query_instance = query_instance.clone().detach()
        z = self.model_vae.encode_csearch(query_instance).clone().detach()
        z.requires_grad = True

        if self.optimizer == "adam":
            optim = torch.optim.Adam([z], self.lr)
        else:
            optim = torch.optim.RMSprop([z], self.lr)

        # Timer
        t0 = datetime.datetime.now()
        t_max = datetime.timedelta(minutes=self.t_max_min)

        # Parallel lists: distances[i] is the distance of counterfactuals[i] from the query.
        counterfactuals = []
        distances = []
        lam = self._lambda

        while not self._check_cf_valid(output, target_class):
            it = 0
            while not self._check_cf_valid(output, target_class) and it < self.max_iter:
                cf = self.model_vae.decode_csearch(z)
                total_loss, distance = self.compute_loss(cf_proposal=cf,
                                                         query_instance=query_instance,
                                                         target=target_prediction,
                                                         _lambda=lam)

                optim.zero_grad()
                total_loss.backward()
                optim.step()

                # Validity is checked on the same decoded point the loss was computed for,
                # so the stored candidate and its distance always refer to one point.
                # No gradients are needed for this check.
                with torch.no_grad():
                    output = self._call_model(cf)
                if self._check_cf_valid(output, target_class):
                    counterfactuals.append(cf.detach().clone())
                    distances.append(distance.detach().clone())

                it += 1

            if datetime.datetime.now() - t0 > t_max:
                break
            if self.step == 0.0:  # Don't search over lambdas
                break
            # NOTE(review): lambda has no lower bound; with a large time budget it can become
            # negative, which turns the distance penalty into a reward. Confirm this is intended.
            lam -= self.step

        if not counterfactuals:
            print('No counterfactual explanation found')
            return None

        # Choose the nearest counterfactual
        index = torch.argmin(torch.stack(distances))
        return torch.stack(counterfactuals)[index]

    def compute_loss(self, cf_proposal, query_instance, target, _lambda):
        """Return ``(total_loss, distance)`` for a candidate.

        total_loss = prediction loss (BCE for "classification", MSE for "regression")
                     + _lambda * ||cf_proposal - query_instance||_norm
        """
        # NOTE(review): this calls the classifier directly even in the "regression" setting,
        # whereas _call_model uses predict_with_logits there. Confirm this is intended.
        output = self.model_classification(cf_proposal)[0]
        if self.setting == "classification":
            bce_loss = nn.BCELoss()
            loss_classification = bce_loss(output, target)

        elif self.setting == "regression":
            mse_loss = nn.MSELoss()
            loss_classification = mse_loss(output, target)
        else:
            raise ValueError("Illegal setting. Only classification and regression are supported.")

        loss_distance = torch.norm((cf_proposal - query_instance), self.norm)
        total_loss = loss_classification + _lambda * loss_distance
        return total_loss, loss_distance

    def _check_cf_valid(self, output, target_class):
        """ Check if the output constitutes a sufficient CF-example.
            target_class = 1: valid when output >= threshold (we aim to increase the score);
            otherwise: valid when output <= threshold (we aim to decrease it).
            Returns the (tensor) result of the comparison.
        """
        if target_class == 1:
            return output >= self.target_threshold
        return output <= self.target_threshold

    def _call_model(self, cf_candidate):
        """Score a candidate: ``classifier(x)[0]`` for "classification",
        ``classifier.predict_with_logits(x).reshape(1, -1)`` for "regression"."""
        if self.setting == "classification":
            output = self.model_classification(cf_candidate)[0]

        elif self.setting == "regression":
            output = self.model_classification.predict_with_logits(cf_candidate).reshape(1, -1)

        else:
            raise ValueError("Illegal setting. Only classification and regression are supported.")
        return output


class CEM:
    """Counterfactual search directly in input space.

    The candidate starts at the query and is optimized with a loss combining a prediction
    term, a distance-to-query term weighted by lambda, and an auto-encoder reconstruction
    term (``model_vae.predict_csearch``) weighted by ``recon_weight``.
    """

    def __init__(self, classifier, model_vae, optimizer: str = "adam", max_iter: int = 750,  target_thres: float = 0.0,
                 t_max_min: float = 0.5, _lambda: float = 0.25,
                 lr: float = 0.05, norm: int = 1, step: float = 0.00, setting: str = "classification",
                 recon_weight: float = 0.15):
        """
            classifier: model scoring candidates. Called directly in the "classification" setting,
                via ``predict_with_logits`` in the "regression" setting.
            model_vae: object providing ``predict_csearch`` (used for the reconstruction term).
            optimizer: "adam" selects Adam; any other value selects RMSprop.
            max_iter: maximum number of optimizer steps per value of lambda.
            target_thres: threshold used by ``_check_cf_valid``.
            t_max_min: wall-clock budget in minutes, checked after each lambda round.
            _lambda: initial weight of the distance term in the loss.
            lr: optimizer learning rate.
            norm: order of the norm used for the distance term.
            step: amount subtracted from lambda after each round; 0.0 disables the lambda search.
            setting: "classification" or "regression".
            recon_weight: weight of the squared L2 reconstruction term (previously hard-coded 0.15).
        """
        super().__init__()
        self.model_classification = classifier
        self.model_vae = model_vae
        self.lr = lr
        self.norm = norm
        self.max_iter = max_iter
        self.optimizer = optimizer
        self.target_threshold = target_thres
        self._lambda = _lambda
        self.t_max_min = t_max_min
        self.step = step
        self.setting = setting
        self.recon_weight = recon_weight

    def generate_counterfactuals(self, query_instance: torch.Tensor, target_class: int = 1) -> "torch.Tensor | None":
        """Search for the closest valid counterfactual of ``query_instance``.

            query_instance: the point to be explained (reshaped to ``(1, -1)`` before scoring in
                the initial check and in the loss; presumably a single sample - TODO confirm).
            target_class: direction of the desired change. If target_class = 1, we aim to increase the
                score; if target_class = 0, we aim to decrease it (in classification and regression problems).

        Returns the detached valid candidate with the smallest distance to the query, or None
        (after printing a message) if the query already has the desired label or no valid
        candidate was found within the iteration / time budget.
        """
        target_prediction = torch.tensor(1 if target_class == 1 else 0).float()

        output = self._call_model(query_instance.reshape(1, -1))
        # check if query instance is already valid
        if self._check_cf_valid(output, target_class):
            print("Query instance already has desired label.")
            return None
        # From here on `output` is known to be invalid, so it is a correct starting value
        # for the loop conditions below for either target_class (a fixed sentinel such as
        # -1 would count as valid for target_class == 0 and skip the search entirely).

        query_instance = query_instance.clone().detach()
        cf = query_instance.clone().requires_grad_(True)

        if self.optimizer == "adam":
            optim = torch.optim.Adam([cf], self.lr)
        else:
            optim = torch.optim.RMSprop([cf], self.lr)

        # Timer
        t0 = datetime.datetime.now()
        t_max = datetime.timedelta(minutes=self.t_max_min)

        # Parallel lists: distances[i] is the distance of counterfactuals[i] from the query.
        counterfactuals = []
        distances = []
        lam = self._lambda

        while not self._check_cf_valid(output, target_class):
            it = 0
            while not self._check_cf_valid(output, target_class) and it < self.max_iter:
                total_loss, _ = self.compute_loss(cf_proposal=cf,
                                                  query_instance=query_instance,
                                                  target=target_prediction,
                                                  _lambda=lam)

                optim.zero_grad()
                total_loss.backward()
                optim.step()

                with torch.no_grad():
                    output = self._call_model(cf)
                    if self._check_cf_valid(output, target_class):
                        # `cf` is updated in place by the optimizer, so store an independent copy,
                        # and measure the distance of this post-step point (the loss above was
                        # computed on the pre-step point).
                        candidate = cf.detach().clone()
                        counterfactuals.append(candidate)
                        distances.append(torch.norm(candidate - query_instance, self.norm))

                it += 1

            if datetime.datetime.now() - t0 > t_max:
                break
            if self.step == 0.0:  # Don't search over lambdas
                break
            # NOTE(review): lambda has no lower bound; with a large time budget it can become
            # negative, which turns the distance penalty into a reward. Confirm this is intended.
            lam -= self.step

        if not counterfactuals:
            print('No counterfactual explanation found')
            return None

        # Choose the nearest counterfactual
        index = torch.argmin(torch.stack(distances))
        return torch.stack(counterfactuals)[index]

    def compute_loss(self, cf_proposal, query_instance, target, _lambda):
        """Return ``(total_loss, distance_to_query)`` for a candidate.

        total_loss = prediction loss (BCE for "classification", MSE for "regression")
                     + _lambda * ||cf_proposal - query_instance||_norm
                     + recon_weight * ||predict_csearch(cf_proposal) - cf_proposal||_2 ** 2
        """
        output = self._call_model(cf_proposal.reshape(1, -1))
        if self.setting == "classification":
            bce_loss = nn.BCELoss()
            loss_classification = bce_loss(output, target)

        elif self.setting == "regression":
            mse_loss = nn.MSELoss()
            loss_classification = mse_loss(output, target)
        else:
            raise ValueError("Illegal setting. Only classification and regression are supported.")

        loss_distance_input = torch.norm((cf_proposal - query_instance), self.norm)
        reconstruction = self.model_vae.predict_csearch(cf_proposal)
        loss_distance_reconstructed = torch.norm((reconstruction - cf_proposal), 2) ** 2

        total_loss = (loss_classification
                      + _lambda * loss_distance_input
                      + self.recon_weight * loss_distance_reconstructed)
        return total_loss, loss_distance_input

    def _check_cf_valid(self, output, target_class):
        """ Check if the output constitutes a sufficient CF-example.
            target_class = 1: valid when output >= threshold (we aim to increase the score);
            otherwise: valid when output <= threshold (we aim to decrease it).
            Returns the (tensor) result of the comparison.
        """
        if target_class == 1:
            return output >= self.target_threshold
        return output <= self.target_threshold

    def _call_model(self, cf_candidate):
        """Score a candidate: ``classifier(x)[0]`` for "classification",
        ``classifier.predict_with_logits(x).reshape(1, -1)`` for "regression"."""
        if self.setting == "classification":
            output = self.model_classification(cf_candidate)[0]

        elif self.setting == "regression":
            output = self.model_classification.predict_with_logits(cf_candidate).reshape(1, -1)

        else:
            raise ValueError("Illegal setting. Only classification and regression are supported.")
        return output
