import datetime
from typing import List, Optional, Tuple

import numpy as np
import torch
from torch import nn


class SCFE:
    """Gradient-based search for a single counterfactual explanation.

    A copy of the query instance is optimised on
    ``prediction_loss + lambda * distance`` until the model output reaches
    ``target_thres`` in the requested direction.  When ``step`` is non-zero,
    ``lambda`` is lowered by ``step`` after every round of ``max_iter``
    unsuccessful optimiser steps.
    """

    def __init__(self, model_classification, target_thres: float = 0.0, step: float = 0.0,
                 lr: float = 0.05, max_iter: int = 2000, t_max_min: float = 0.5,
                 _lambda: float = 0, norm: int = 1,
                 setting: str = "classification", mode: str = 'scfe', optimizer: str = 'adam'):
        """Store the search configuration.

        Args:
            model_classification: Model to explain.  In the "classification"
                setting it is called as ``model(x)[0]``; in the "regression"
                setting ``model.predict_with_logits(x)`` is used.
            target_thres: Threshold the model output is compared against.
            step: Amount subtracted from lambda after each unsuccessful round;
                ``0.0`` disables the search over lambdas.
            lr: Learning rate handed to the optimiser.
            max_iter: Maximum number of optimiser steps per round.
            t_max_min: Time budget in minutes; it is checked between rounds.
            _lambda: Initial weight of the distance term.
            norm: Order ``p`` of the norm used as distance.
            setting: Either "classification" or "regression".
            mode: Stored on the instance; not read anywhere in this class.
            optimizer: 'adam' selects Adam, any other value selects RMSprop.
        """
        super().__init__()

        self.model_classification = model_classification
        self.mode = mode
        self.lr = lr
        self._lambda = _lambda
        self.max_iter = max_iter
        self.optimizer = optimizer
        self.t_max_min = t_max_min
        self.norm = norm
        self.sigmoid = nn.Sigmoid()
        self.target_thres = target_thres
        self.setting = setting
        self.step = step

    def generate_counterfactuals(self, query_instance: torch.Tensor, target_class: int = 1,
                                 _lambda: float = 10.0) -> Optional[torch.Tensor]:
        """Search for a counterfactual of ``query_instance``.

        Args:
            query_instance: The point to be explained.
            target_class: Direction of the desired change.  With
                ``target_class = 1`` the score has to reach at least
                ``target_thres``; otherwise it has to drop to at most
                ``target_thres`` (classification and regression alike).
            _lambda: Kept for interface compatibility only.  The value is
                ignored; the distance weight is taken from the ``_lambda``
                given to the constructor.

        Returns:
            The valid candidate closest to the query instance (same shape as
            ``query_instance``), or ``None`` when the query already satisfies
            the target or no valid candidate was found.
        """
        target_prediction = torch.tensor(1 if target_class == 1 else 0).float()

        output = self._call_model(query_instance.reshape(1, -1))
        # check if query instance is already valid
        if self._check_cf_valid(output, target_class):
            print("Query instance already has desired label.")
            return None

        # Detach first so the optimised tensor is a leaf even when the query
        # itself is part of an autograd graph.
        cf = query_instance.detach().clone().requires_grad_(True)

        if self.optimizer == 'adam':
            optim = torch.optim.Adam([cf], self.lr)
        else:
            optim = torch.optim.RMSprop([cf], self.lr)

        # Timer
        started_at = datetime.datetime.now()
        time_budget = datetime.timedelta(minutes=self.t_max_min)

        lam = self._lambda
        # Both lists grow together so that index k always refers to the same
        # candidate, no matter in which lambda round it was found.
        counterfactuals = []
        distances = []
        while not self._check_cf_valid(output, target_class):
            n_steps = 0
            while not self._check_cf_valid(output, target_class) and n_steps < self.max_iter:
                optim.zero_grad()
                total_loss, _ = self.compute_loss(_lambda=lam,
                                                  cf_candidate=cf,
                                                  original_instance=query_instance,
                                                  target=target_prediction)
                total_loss.backward()
                optim.step()

                output = self._call_model(cf)
                if self._check_cf_valid(output, target_class):
                    # Clone: a plain detach() would keep sharing storage with
                    # the tensor the optimiser updates in place.
                    found = cf.detach().clone()
                    counterfactuals.append(found)
                    # Distance of the stored candidate itself (after the step).
                    distances.append(torch.norm(found - query_instance.detach(), self.norm))

                n_steps += 1

            output = self._call_model(cf).reshape(1, -1)
            if datetime.datetime.now() - started_at > time_budget:
                # Timeout: stop searching over lambdas.
                break

            if self.step == 0.0:  # Don't search over lambdas
                break
            # NOTE(review): lam is not clamped and turns negative once the
            # accumulated steps exceed the initial lambda -- confirm intended.
            lam -= self.step

        if not counterfactuals:
            print('No counterfactual explanation found')
            return None

        # Choose the nearest counterfactual
        nearest = int(torch.argmin(torch.stack(distances)))
        return counterfactuals[nearest]

    def compute_loss(self, _lambda: float,
                     cf_candidate: torch.Tensor,
                     original_instance: torch.Tensor,
                     target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(total_loss, loss_distance)`` for one candidate.

        The prediction loss is binary cross-entropy in the "classification"
        setting and mean squared error in the "regression" setting; the
        distance is the ``self.norm``-norm between candidate and original.

        Raises:
            ValueError: If ``self.setting`` is neither "classification" nor
                "regression".
        """
        output = self._call_model(cf_candidate)
        # Flatten prediction and target so their shapes agree; BCELoss rejects
        # inputs and targets of different sizes.
        prediction = output.reshape(-1)
        wanted = target.reshape(-1)
        if self.setting == "classification":
            loss_classification = nn.BCELoss()(prediction, wanted)
        elif self.setting == "regression":
            loss_classification = nn.MSELoss()(prediction, wanted)
        else:
            raise ValueError("Illegal setting. Only classification and regression are supported.")
        # distance loss
        loss_distance = torch.norm((cf_candidate - original_instance), self.norm)
        total_loss = loss_classification + _lambda * loss_distance
        return total_loss, loss_distance

    def _call_model(self, cf_candidate):
        """Evaluate the wrapped model on ``cf_candidate`` for the active setting.

        Raises:
            ValueError: If ``self.setting`` is neither "classification" nor
                "regression".
        """
        if self.setting == "classification":
            return self.model_classification(cf_candidate)[0]
        if self.setting == "regression":
            return self.model_classification.predict_with_logits(cf_candidate).reshape(1, -1)
        raise ValueError("Illegal setting. Only classification and regression are supported.")

    def _check_cf_valid(self, output, target_class):
        """ Check if the output constitutes a sufficient CF-example.
            target_class = 1 in general means that we aim to improve the score,
            whereas for target_class = 0 we aim to decrease it.
        """
        if target_class == 1:
            return output >= self.target_thres
        return output <= self.target_thres


class DICE:
    """Gradient-based search for diverse counterfactual explanations.

    ``total_cfs`` candidates are optimised jointly on
    ``prediction_loss + lambda * proximity - diversity``.  As soon as every
    candidate is valid, one of them is picked at random and recorded.
    """

    def __init__(self, model_classification, target_thres: float = 0.0, step: float = 0.0,
                 lr: float = 0.05, max_iter: int = 2000, t_max_min: float = 0.5,
                 norm: int = 1, _lambda: float = 0,
                 setting: str = "classification", optimizer: str = 'adam', diversity_loss_type="avg_dist"):
        """Store the search configuration.

        Args:
            model_classification: Model to explain.  In the "classification"
                setting it is called as ``model(x)[0]``; in the "regression"
                setting ``model.predict_with_logits(x)`` is used.
            target_thres: Threshold the model outputs are compared against.
            step: Amount subtracted from lambda after each unsuccessful round;
                ``0.0`` disables the search over lambdas.
            lr: Learning rate handed to the optimiser.
            max_iter: Maximum number of optimiser steps per round.
            t_max_min: Time budget in minutes; it is checked between rounds.
            norm: Order ``p`` of the norm used for all distances.
            _lambda: Initial weight of the proximity term.
            setting: Either "classification" or "regression".
            optimizer: 'adam' selects Adam, any other value selects RMSprop.
            diversity_loss_type: "avg_dist", "dpp:inverse_dist" or
                "dpp:exponential_dist".
        """
        super().__init__()

        self.model_classification = model_classification
        self.total_cfs = 2  # number of candidates optimised jointly
        self.lr = lr
        self._lambda = _lambda
        self.max_iter = max_iter
        self.optimizer = optimizer
        self.t_max_min = t_max_min
        self.norm = norm
        self.sigmoid = nn.Sigmoid()
        self.target_thres = target_thres
        self.setting = setting
        self.diversity_loss_type = diversity_loss_type
        self.step = step
        self.cfs = []

    def do_cf_initializations(self, dim):
        """Initialize ``self.cfs`` with ``total_cfs`` random vectors of length ``dim``.

        Every entry is drawn uniformly from [-2.5, 2.5).
        """
        self.cfs = [
            torch.tensor([np.random.uniform(-2.5, +2.5) for _ in range(dim)]).float()
            for _ in range(self.total_cfs)
        ]

    def initialize_cfs(self, query_instance, init_near_query_instance=True):
        """Overwrite the candidates in ``self.cfs`` in place.

        With ``init_near_query_instance`` the n-th candidate becomes the first
        row of ``query_instance`` shifted by ``n * 0.01``; otherwise it is
        re-drawn uniformly from [-2.5, 2.5).
        """
        dim = query_instance.shape[1]
        for n in range(self.total_cfs):
            if init_near_query_instance:
                values = query_instance[0].detach() + (n * 0.01)
            else:
                values = torch.tensor([np.random.uniform(-2.5, +2.5) for _ in range(dim)])
            self.cfs[n].data[:dim] = values

    def generate_counterfactuals(self, query_instance: torch.Tensor,
                                 target_class: int = 1,
                                 ) -> Optional[torch.Tensor]:
        """Search for a counterfactual of ``query_instance``.

        Args:
            query_instance: The point to be explained; indexed as a 2-D tensor
                (``shape[1]`` is used as the feature dimension).
            target_class: Direction of the desired change.  With
                ``target_class = 1`` every score has to reach at least
                ``target_thres``; otherwise every score has to drop to at most
                ``target_thres`` (classification and regression alike).

        Returns:
            The recorded candidate closest to the query instance (1-D tensor),
            or ``None`` when the query already satisfies the target or no
            valid set of candidates was found.
        """
        target_prediction = torch.tensor(1 if target_class == 1 else 0).float()

        # These methods are inspired by the original DICE code
        self.do_cf_initializations(dim=query_instance.shape[1])
        self.initialize_cfs(query_instance, init_near_query_instance=True)

        output = self._call_model(query_instance.reshape(1, -1))
        # check if query instance is already valid; the helper expects a
        # collection of outputs, so the single output is wrapped in a list
        if self._check_cf_valid([output], target_class):
            print("Query instance already has desired label.")
            return None

        xprime = [cf.clone().detach().requires_grad_(True) for cf in self.cfs[:self.total_cfs]]

        if self.optimizer == 'adam':
            optim = torch.optim.Adam(xprime, self.lr)
        else:
            optim = torch.optim.RMSprop(xprime, self.lr)

        # Timer
        started_at = datetime.datetime.now()
        time_budget = datetime.timedelta(minutes=self.t_max_min)

        # Before the first step the search state is judged by the (invalid)
        # prediction for the query instance itself.
        outputs = [output for _ in range(self.total_cfs)]

        lam = self._lambda
        # Both lists grow together so that index k always refers to the same
        # candidate, no matter in which lambda round it was found.
        counterfactuals = []
        distances = []
        while not self._check_cf_valid(outputs, target_class):
            n_steps = 0
            while not self._check_cf_valid(outputs, target_class) and n_steps < self.max_iter:
                optim.zero_grad()
                total_loss, _ = self.compute_loss(_lambda=lam,
                                                  counterfactuals=xprime,
                                                  original_instance=query_instance,
                                                  target=target_prediction)
                total_loss.backward()
                optim.step()

                outputs = self._predict_all(xprime)
                if self._check_cf_valid(outputs, target_class):
                    # All candidates are valid: keep one of the first two at random.
                    coin = np.random.binomial(n=1, p=0.5)
                    # Clone: a plain detach() would keep sharing storage with
                    # the tensor the optimiser updates in place.
                    chosen = xprime[int(coin)].detach().clone()
                    counterfactuals.append(chosen)
                    # Distance of the stored candidate itself (after the step).
                    distances.append(torch.norm(chosen - query_instance.detach(), self.norm))

                n_steps += 1

            if datetime.datetime.now() - started_at > time_budget:
                # Timeout: stop searching over lambdas.
                break

            if self.step == 0.0:  # Don't search over lambdas
                break
            # NOTE(review): lam is not clamped and turns negative once the
            # accumulated steps exceed the initial lambda -- confirm intended.
            lam -= self.step

        if not counterfactuals:
            print('No counterfactual explanation found')
            return None

        # Choose the nearest counterfactual
        nearest = int(torch.argmin(torch.stack(distances)))
        return counterfactuals[nearest]

    def _predict_all(self, candidates):
        """Return the model output for each of the first ``total_cfs`` candidates."""
        return [self._call_model(candidates[i].reshape(1, -1)) for i in range(self.total_cfs)]

    def compute_loss(self, _lambda: float, original_instance: torch.Tensor,
                     counterfactuals: List[torch.Tensor],
                     target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(total_loss, loss_distance)`` for the set of candidates.

        ``total_loss`` is the mean prediction loss plus ``_lambda`` times the
        mean proximity, minus the diversity term; ``loss_distance`` is the
        mean proximity on its own.

        Raises:
            ValueError: If ``self.setting`` is neither "classification" nor
                "regression".
        """
        outputs = self._predict_all(counterfactuals)
        weight = 1 / self.total_cfs
        loss_classification = 0.0

        # classification loss
        if self.setting == "classification":
            bce_loss = nn.BCELoss()
            for output in outputs:
                # Flatten both sides: BCELoss rejects differing shapes.
                loss_classification += weight * bce_loss(output.reshape(-1), target.reshape(-1))
        elif self.setting == "regression":
            mse_loss = nn.MSELoss()
            for output in outputs:
                loss_classification += weight * mse_loss(output[0].reshape(-1), target.reshape(-1))
        else:
            raise ValueError("Illegal setting. Only classification and regression are supported.")

        # distance loss
        loss_distance = self.proximity_loss(original_instance, counterfactuals)
        # diversity term; it grows when the candidates move apart
        diversity_loss = self.compute_diversity_loss(counterfactuals)
        # The diversity term is subtracted: the optimiser minimises the total,
        # and adding a quantity that increases with pairwise distance would
        # pull the candidates together instead of spreading them.
        total_loss = loss_classification + _lambda * loss_distance - diversity_loss

        return total_loss, loss_distance

    def proximity_loss(self, original_instance, counterfactuals):
        """Return the mean ``self.norm``-distance between candidates and the original."""
        proximity_loss = 0.0
        for i in range(self.total_cfs):
            proximity_loss += torch.norm((counterfactuals[i] - original_instance), self.norm)
        return proximity_loss / self.total_cfs

    def _call_model(self, cf_candidate):
        """Evaluate the wrapped model on ``cf_candidate`` for the active setting.

        Raises:
            ValueError: If ``self.setting`` is neither "classification" nor
                "regression".
        """
        if self.setting == "classification":
            return self.model_classification(cf_candidate)[0]
        if self.setting == "regression":
            return self.model_classification.predict_with_logits(cf_candidate).reshape(1, -1)
        raise ValueError("Illegal setting. Only classification and regression are supported.")

    def _check_cf_valid(self, outputs, target_class):
        """ Check if the outputs constitute sufficient CF-examples.
            target_class = 1 in general means that we aim to improve the score,
            whereas for target_class = 0 we aim to decrease it.
            ``outputs`` is a collection of model outputs; all of them have to
            pass the threshold test.
        """
        if target_class == 1:
            return all(output >= self.target_thres for output in outputs)
        return all(output <= self.target_thres for output in outputs)

    def dpp_style(self, submethod, counterfactuals):
        """Return the determinant of the pairwise similarity kernel.

        Args:
            submethod: "inverse_dist" uses ``1 / (1 + dist)`` as similarity,
                "exponential_dist" uses ``1 / exp(dist)``.
            counterfactuals: The candidates the kernel is built from.

        Raises:
            ValueError: If ``submethod`` is not one of the two names above.
        """
        if submethod not in ("inverse_dist", "exponential_dist"):
            raise ValueError("Illegal DPP submethod. Only inverse_dist and exponential_dist are supported.")

        det_entries = torch.ones((self.total_cfs, self.total_cfs))
        for i in range(self.total_cfs):
            for j in range(self.total_cfs):
                dist = torch.norm(counterfactuals[i] - counterfactuals[j], self.norm)
                if submethod == "inverse_dist":
                    det_entries[(i, j)] = 1.0 / (1.0 + dist)
                else:
                    det_entries[(i, j)] = 1.0 / (torch.exp(dist))
                if i == j:
                    # small offset added to the diagonal entries
                    det_entries[(i, j)] += 0.0001

        return torch.det(det_entries)

    def compute_diversity_loss(self, counterfactuals):
        """Computes the third part (diversity) of the loss function.

        Returns ``0`` for a single candidate.  For "avg_dist" the result is one
        minus the mean pairwise similarity ``1 / (1 + dist)``; for "dpp:<name>"
        it is the value of ``dpp_style``.

        Raises:
            ValueError: If ``self.diversity_loss_type`` is not supported.
        """
        if self.total_cfs == 1:
            return torch.tensor(0.0)

        if "dpp" in self.diversity_loss_type:
            parts = self.diversity_loss_type.split(':')
            if len(parts) < 2:
                raise ValueError("Illegal diversity loss type. DPP needs a submethod, e.g. dpp:inverse_dist.")
            return self.dpp_style(parts[1], counterfactuals)

        if self.diversity_loss_type == "avg_dist":
            similarity_sum = 0.0
            count = 0.0
            # computing pairwise distance and transforming it to normalized similarity
            for i in range(self.total_cfs):
                for j in range(i + 1, self.total_cfs):
                    count += 1.0
                    similarity_sum += 1.0 / (1.0 + torch.norm(counterfactuals[i] - counterfactuals[j], self.norm))
            return 1.0 - (similarity_sum / count)

        raise ValueError("Illegal diversity loss type. Only avg_dist and dpp:<submethod> are supported.")
