import pandas as pd
import numpy as np
from tqdm import tqdm
from typing import Mapping, Sequence, Tuple, Callable
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import LabelEncoder
from sklearn.decomposition import TruncatedSVD
import torch
from torch import nn

from .base import IndivFairModel, LogReg, NeuralNets, MLPClassifier


class Project(IndivFairModel):
    """
    Wrap a base model so that it is trained and queried on projected features.

    Re-implementation of https://github.com/IBM/sensitive-subspace-robustness

    A set of "sensitive directions" is estimated from the training data, and
    every feature matrix is multiplied by the projector onto the orthogonal
    complement of those directions before being handed to the base model.
    """

    def __init__(self, base: IndivFairModel):
        """
        :param base: the model that is fitted / queried on the projected features
        """
        super(Project, self).__init__()
        self.model = base

    @staticmethod
    def compl_svd_projector(basis: np.ndarray) -> np.ndarray:
        """
        Build the projector onto the orthogonal complement of the row space of ``basis``.

        :param basis: matrix of shape (k, d) whose rows span the subspace to remove;
            ``basis @ basis.T`` must be invertible (``np.linalg.inv`` raises otherwise)
        :return: (d, d) matrix ``I - basis.T (basis basis.T)^-1 basis``
        """
        gram_inv = np.linalg.inv(basis @ basis.T)
        proj = basis.T @ gram_inv @ basis

        return np.eye(proj.shape[0]) - proj

    @staticmethod
    def get_sen_dirs_and_proj_mat(
            X: np.ndarray,
            train_df: pd.DataFrame,
            sen_feat: Sequence[str],
            sen_idx: Sequence[int],
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Estimate the sensitive directions and the projector that removes them.

        :param X: the feature matrix after one-hot encoding
        :param train_df: the dataframe for training
        :param sen_feat: the name of sensitive features in train_df
        :param sen_idx: the index of sensitive features in X after one-hot encoding
            (any order; duplicates are ignored)
        :return: sensitive directions and projection matrix
        """

        log_reg = LogisticRegression(
            penalty="l2",
            C=10.,
            fit_intercept=False,
            multi_class="ovr",
            warm_start=False,
            max_iter=2048,
        )

        # Every distinct combination of sensitive attribute values becomes one
        # class: the values are cast to strings and concatenated row-wise.
        sen_df = train_df[list(sen_feat)].astype("str")
        sen_y = LabelEncoder().fit_transform(sen_df.sum(axis=1).values)  # targets for logistic regression

        # Work with sorted, de-duplicated indices so that the coefficient
        # columns are mapped back to the correct feature positions regardless
        # of the order in which the caller listed the sensitive columns.
        n_features = X.shape[1]
        sen_cols = sorted(set(sen_idx))
        sen_col_set = set(sen_cols)
        remain_idx = [col for col in range(n_features) if col not in sen_col_set]

        # fit logistic regression to predict sensitive information
        # from the non-sensitive columns only
        log_reg.fit(np.take(X, remain_idx, axis=1), sen_y)

        # Scatter the learned coefficients into full-width rows; the sensitive
        # columns keep a zero weight.
        weight = np.zeros((log_reg.coef_.shape[0], n_features))
        weight[:, remain_idx] = log_reg.coef_

        # TODO: determine the constant values in sensitive directions
        cnt = 1.
        # One axis-aligned direction per sensitive column.
        indicator_dirs = np.zeros((len(sen_cols), n_features))
        indicator_dirs[np.arange(len(sen_cols)), sen_cols] = cnt

        sen_dirs = np.vstack([weight, indicator_dirs])

        # Compress the stacked directions; the fitted components are used as
        # the basis of the sensitive subspace.
        tSVD = TruncatedSVD(n_components=2 + len(sen_cols))
        tSVD.fit(sen_dirs)
        sen_dirs = tSVD.components_
        proj_mat = Project.compl_svd_projector(sen_dirs)

        return sen_dirs, proj_mat

    def fit(self, X: np.ndarray, y: np.ndarray, train_df: pd.DataFrame, sen_feat: Sequence[str],
            sen_idx: Sequence[int]):
        """
        Compute the projection matrix from the training data and fit the base model on ``X @ proj_mat``.

        :param X: the feature matrix after one-hot encoding
        :param y: training targets, passed through to the base model
        :param train_df: the dataframe for training
        :param sen_feat: the name of sensitive features in train_df
        :param sen_idx: the index of sensitive features in X after one-hot encoding
        """
        _, self.proj_mat = self.get_sen_dirs_and_proj_mat(X, train_df, sen_feat, sen_idx)
        self.model.fit(X @ self.proj_mat, y)
        return

    def pred(self, X):
        """ Return the base model's ``pred`` on the projected features; requires a prior ``fit`` """
        return self.model.pred(X @ self.proj_mat)

    def pred_proba(self, X):
        """ Return the base model's ``pred_proba`` on the projected features; requires a prior ``fit`` """
        return self.model.pred_proba(X @ self.proj_mat)


class SenSR(IndivFairModel):
    """
    Re-implementation of https://github.com/IBM/sensitive-subspace-robustness
    Single batch training with all data

    Each training iteration first searches for an adversarial perturbation of
    the inputs (along the sensitive directions, and optionally in the full
    input space under a fair-distance penalty) and then updates the classifier
    on the perturbed inputs.
    """

    # Floor for the denominator of the lambda update, so that an exactly-zero
    # distance (or eps) does not raise ZeroDivisionError.
    _MIN_DIST = 1e-12

    def __init__(
            self,
            input_dim: int,
            sen_dirs_dim: int,
            n_input: int,
            enable_full_adv: bool = False,
            n_iter: int = 1000,
            n_sen_adv_iter: int = 50,
            n_full_adv_iter: int = 40,
            lr: float = 1e-3,
            sen_adv_lr: float = 1e-1,
            full_adv_lr: float = 1e-4,
            lamb_init: float = 2.,
            eps: float = 1e-3,
            device: str = "cuda:1",
    ):
        """
        :param input_dim: number of input features
        :param sen_dirs_dim: number of sensitive directions; must equal the number of
            rows of the sensitive-direction matrix computed in ``fit``
        :param n_input: number of rows of the training matrix passed to ``fit``
            (the perturbation tensors are allocated with this many rows)
        :param enable_full_adv: also run the original (full) input space attack
        :param n_iter: number of model update iterations
        :param n_sen_adv_iter: optimizer steps of the sensitive subspace attack per iteration
        :param n_full_adv_iter: optimizer steps of the original space attack per iteration
        :param lr: learning rate of the model optimizer
        :param sen_adv_lr: learning rate of the sensitive subspace attack
        :param full_adv_lr: learning rate of the original space attack
        :param lamb_init: initial weight of the fair-distance penalty
        :param eps: target value for the mean fair distance used when updating the penalty weight
        :param device: torch device string
        """
        super(SenSR, self).__init__()

        self.input_dim = input_dim
        self.sen_dirs_dim = sen_dirs_dim
        self.n_input = n_input
        self.enable_full_adv = enable_full_adv
        self.n_iter = n_iter
        self.n_sen_adv_iter = n_sen_adv_iter
        self.n_full_adv_iter = n_full_adv_iter
        self.lr = lr
        self.sen_adv_lr = sen_adv_lr
        self.full_adv_lr = full_adv_lr
        self.lamb = lamb_init
        self.eps = eps
        self.device = torch.device(device)
        self.pred_threshold = 0.5

        self.model = MLPClassifier(input_dim=self.input_dim).to(self.device)
        self.model_optim = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        # NOTE(review): BCELoss expects probabilities, so MLPClassifier presumably
        # ends with a sigmoid -- confirm against .base
        self.criterion = nn.BCELoss()

        # sensitive subspace attack: one coefficient per sample and sensitive direction
        self.sen_adv = torch.zeros(n_input, self.sen_dirs_dim, device=self.device, requires_grad=True)
        self.sen_adv_optim = torch.optim.Adam([self.sen_adv], lr=self.sen_adv_lr)
        # original space attack: one offset per sample and input feature
        self.full_adv = torch.zeros(n_input, self.input_dim, device=self.device, requires_grad=True)
        self.full_adv_optim = torch.optim.Adam([self.full_adv], lr=self.full_adv_lr)

        # Decay every learning rate by 10x four times over the run. The step
        # size is clamped to 1 because StepLR divides by it (n_iter < 4 would
        # otherwise give a step size of 0).
        step_size = max(1, self.n_iter // 4)
        self.model_skd = torch.optim.lr_scheduler.StepLR(self.model_optim, step_size=step_size, gamma=0.1)
        self.sen_adv_skd = torch.optim.lr_scheduler.StepLR(self.sen_adv_optim, step_size=step_size, gamma=0.1)
        self.full_adv_skd = torch.optim.lr_scheduler.StepLR(self.full_adv_optim, step_size=step_size, gamma=0.1)

    def reset_sen_adv(self):
        """ Zero the sensitive subspace perturbation in place """
        nn.init.zeros_(self.sen_adv)
        return

    def reset_full_adv(self):
        """ Zero the original space perturbation in place """
        nn.init.zeros_(self.full_adv)
        return

    def fair_dist(self, x, y, proj):
        """
        Mean over rows of the squared Euclidean norm of ``(x - y) @ proj``.

        :param x: first batch of points
        :param y: second batch of points, same shape as ``x``
        :param proj: matrix applied to the difference before measuring its norm
        :return: scalar tensor
        """
        return torch.mean(torch.sum(torch.square((x - y) @ proj), dim=1))

    def _sensitive_subspace_attack(self, X, y, sen_dirs):
        """ Maximise the loss over ``self.sen_adv``; reset it to zero if the attack did not raise the loss """
        with torch.no_grad():
            init_loss = self.criterion(self.model(X), y)

        loss = None
        for _ in range(self.n_sen_adv_iter):
            self.sen_adv_optim.zero_grad()

            adv_X = torch.add(X, torch.matmul(self.sen_adv, sen_dirs))
            # the optimizer minimises, so negate the loss to ascend on it
            loss = torch.neg(self.criterion(self.model(adv_X), y))
            loss.backward()
            self.sen_adv_optim.step()

        # The loss of the final attack step is compared to the clean loss.
        if loss is not None and torch.neg(loss) <= init_loss:
            # sensitive subspace attack failed
            print("=> Sensitive subspace attack failed")
            self.reset_sen_adv()
        return

    def _full_space_attack(self, X, y, sen_dirs, proj_mat):
        """ Maximise the distance-penalised loss over ``self.full_adv`` and update ``self.lamb`` """
        adv_X_after_sen_adv = torch.add(X, torch.matmul(self.sen_adv.detach(), sen_dirs))

        with torch.no_grad():
            loss_after_sen_adv = self.criterion(self.model(adv_X_after_sen_adv), y)

        loss = None
        dist_loss = None
        for _ in range(self.n_full_adv_iter):
            self.full_adv_optim.zero_grad()

            adv_X = torch.add(adv_X_after_sen_adv, self.full_adv)
            dist_loss = self.fair_dist(X, adv_X, proj_mat)
            loss = torch.neg(self.criterion(self.model(adv_X), y)) + self.lamb * dist_loss

            loss.backward()
            self.full_adv_optim.step()

        if loss is None:
            # no attack step was run, nothing to evaluate
            return

        if torch.neg(loss) <= loss_after_sen_adv:
            print("=> Original space attack failed")
            # discard the perturbation that failed, i.e. the full space one
            self.reset_full_adv()

        self._update_lamb(dist_loss.detach().item())
        return

    def _update_lamb(self, mean_dist: float):
        """ Move ``self.lamb`` up when ``mean_dist`` exceeds ``self.eps`` and down otherwise (floored at 1e-5) """
        ratio = max(mean_dist, self.eps) / max(min(mean_dist, self.eps), self._MIN_DIST)
        self.lamb = max(1e-5, self.lamb + ratio * (mean_dist - self.eps))
        return

    def fit(
            self,
            X: np.ndarray,
            y: np.ndarray,
            train_df: pd.DataFrame,
            sen_feat: Sequence[str],
            sen_idx: Sequence[int],
            verbose: bool = False,
    ):
        """
        Train the classifier adversarially on the whole data set as a single batch.

        :param X: the feature matrix after one-hot encoding; its number of rows has to
            be compatible with the ``n_input`` given at construction
        :param y: training targets (assumed to match the shape of the model output
            as required by BCELoss -- TODO confirm against MLPClassifier)
        :param train_df: the dataframe for training
        :param sen_feat: the name of sensitive features in train_df
        :param sen_idx: the index of sensitive features in X after one-hot encoding
        :param verbose: print the training loss after every iteration
        """
        sen_dirs, proj_mat = Project.get_sen_dirs_and_proj_mat(X, train_df, sen_feat, sen_idx)
        sen_dirs = torch.from_numpy(sen_dirs).float().to(self.device)
        proj_mat = torch.from_numpy(proj_mat).float().to(self.device)

        X = torch.from_numpy(X).float().to(self.device)
        y = torch.from_numpy(y).float().to(self.device)

        self.model.train()
        for i in range(self.n_iter):
            # every iteration starts the attacks from a zero perturbation
            self.reset_sen_adv()
            if self.enable_full_adv:
                self.reset_full_adv()

            # sensitive subspace attack
            self._sensitive_subspace_attack(X, y, sen_dirs)

            # original space attack
            if self.enable_full_adv:
                self._full_space_attack(X, y, sen_dirs, proj_mat)

            # Model update on the perturbed inputs. The perturbations are
            # detached: only the model parameters need gradients here.
            self.model_optim.zero_grad()
            adv_X = torch.add(X, torch.matmul(self.sen_adv.detach(), sen_dirs))
            if self.enable_full_adv:
                adv_X = torch.add(adv_X, self.full_adv.detach())
            loss = self.criterion(self.model(adv_X), y)
            loss.backward()
            self.model_optim.step()

            self.model_skd.step()
            self.sen_adv_skd.step()
            if self.enable_full_adv:
                self.full_adv_skd.step()

            if verbose:
                print("Iter [%d|%d], loss: %.5f" % (i, self.n_iter, loss.item()))

        return

    def infer(self, X):
        """
        Run the model in eval mode without gradients.

        :param X: numpy feature matrix
        :return: the raw model output as a numpy array
        """
        self.model.eval()
        X = torch.from_numpy(X).float().to(self.device)
        with torch.no_grad():
            output = self.model(X)
        output = output.cpu().numpy()
        return output

    def pred(self, X):
        """ Boolean array: model output strictly greater than ``self.pred_threshold`` """
        pred = self.infer(X)
        pred_label = (pred > self.pred_threshold)
        return pred_label

    def pred_proba(self, X):
        """ Raw model output for ``X`` as a numpy array """
        pred = self.infer(X)
        return pred
