import math
import pandas as pd
import numpy as np
import random
from tqdm import tqdm
from typing import Mapping, Sequence, Tuple, Callable
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import LabelEncoder
import torch
from torch import nn

from .base import IndivFairModel, LogReg, NeuralNets, MLPClassifier
from .sensr import Project


class LogRegSenSubspace():
    """Learn a sensitive subspace from an L1-penalised logistic regression.

    A logistic regression is fitted to predict the joint value of the
    sensitive columns from ``X``. Its coefficient rows are kept in
    ``basic_vectors_``, and ``proj_`` holds the matrix returned by
    ``Project.compl_svd_projector`` for those vectors.
    """

    # Separator used when joining the sensitive values of a row into a single
    # class label. Plain concatenation is ambiguous ("1" + "11" == "11" + "1")
    # and would silently merge distinct groups; the ASCII unit separator is not
    # expected to occur inside the column values.
    _LABEL_SEP = "\x1f"

    def __init__(self):
        self.basic_vectors_ = None  # coef_ of the fitted logistic regression
        self.proj_ = None  # projector built from basic_vectors_

    def fit(self, X: np.ndarray, train_df: pd.DataFrame, sen_feat: Sequence[str], sen_idx: Sequence[int]):
        """Fit the sensitive-attribute classifier and build the projector.

        Args:
            X: feature matrix, passed unchanged to ``LogisticRegression.fit``;
                assumed row-aligned with ``train_df`` (TODO confirm at callers).
            train_df: frame containing the sensitive columns.
            sen_feat: names of the sensitive columns; every distinct
                combination of their (string-cast) values is one class.
            sen_idx: currently unused; kept for interface compatibility.

        Returns:
            None. Results are stored in ``basic_vectors_`` and ``proj_``.
        """
        log_reg = LogisticRegression(
            solver="liblinear",
            penalty="l1",
        )

        # One label per distinct combination of sensitive values.
        sen_df = train_df[list(sen_feat)].astype("str")
        joint_labels = sen_df.apply(self._LABEL_SEP.join, axis=1)
        sen_y = LabelEncoder().fit_transform(joint_labels.to_numpy())

        # Fit logistic regression to predict the sensitive information.
        log_reg.fit(X, sen_y)
        self.basic_vectors_ = log_reg.coef_
        self.proj_ = Project.compl_svd_projector(self.basic_vectors_)


class SenSeI(IndivFairModel):
    """
    SenSeI: Sensitive Set Invariance for Enforcing Individual Fairness
    Adopted from https://github.com/IBM/inFairness

    Trains an ``MLPClassifier`` on the full batch with ``BCELoss`` plus
    ``rho`` times the mean squared change in prediction between each sample
    and an adversarially perturbed copy of it. The perturbation is searched
    with Adam and penalised through ``dist_X``, which uses the projector
    learned by ``LogRegSenSubspace``.

    Args:
        input_dim: number of input features.
        n_input: number of training rows; the perturbation buffer has shape
            ``(n_input, input_dim)``, so ``fit`` must receive this many rows.
        n_iter: number of full-batch training iterations.
        lr: Adam learning rate of the classifier.
        adv_iter: Adam steps per adversarial search.
        adv_lr: Adam learning rate of the perturbation.
        rho: weight of the prediction-change penalty.
        eps: target mean input distance driving the ``lamb`` update.
        lamb: initial weight of the input-distance term in the search.
        min_lamb: lower bound applied to ``lamb`` after each update.
        max_noise: upper end of the uniform perturbation initialisation.
        min_noise: lower end of the uniform perturbation initialisation.
        device: torch device string.
        verbose: stored as ``self.verbose``; ``fit`` uses its own argument.
    """

    def __init__(
            self,
            input_dim: int,
            n_input: int,
            n_iter: int = 1000,
            lr: float = 1e-3,
            adv_iter: int = 100,
            adv_lr: float = 1e-3,
            rho: float = 5e+5,
            eps: float = 1e-1,
            lamb: float = 1.,
            min_lamb: float = 1e-5,
            max_noise: float = 0.1,
            min_noise: float = 0.1,
            device: str = "cuda:1",
            verbose: bool = True,
    ):
        super(SenSeI, self).__init__()

        self.input_dim = input_dim
        self.n_input = n_input
        self.n_iter = n_iter
        self.lr = lr
        self.adv_iter = adv_iter
        self.adv_lr = adv_lr
        self.rho = rho
        self.eps = eps
        self.lamb = lamb
        self.min_lamb = min_lamb
        self.max_noise = max_noise
        self.min_noise = min_noise
        self.device = torch.device(device)
        self.verbose = verbose

        self.pred_threshold = 0.5

        # Perturbation optimised by ``delta_optim``. It must be a leaf tensor
        # that requires grad (otherwise it never receives a gradient), and it
        # must only be modified in place (otherwise the optimizer keeps a stale
        # reference and updates a tensor nobody uses).
        self.delta = torch.zeros(n_input, self.input_dim, device=self.device, requires_grad=True)
        # Keep eps on the same device as the distances it is compared with.
        self.eps = torch.tensor(self.eps, dtype=torch.float32, device=self.device)

        self.model = MLPClassifier(input_dim=self.input_dim).to(self.device)
        self.model_optim = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.delta_optim = torch.optim.Adam([self.delta], lr=self.adv_lr)
        self.criterion = nn.BCELoss()

        self.sen_subspace = LogRegSenSubspace()
        self.min_lamb = torch.as_tensor(self.min_lamb, dtype=torch.float32).to(self.device)

        # StepLR computes ``epoch % step_size``; clamp so n_iter < 4 does not
        # produce step_size == 0 and a ZeroDivisionError on the first step.
        step_size = max(1, self.n_iter // 4)
        self.model_skd = torch.optim.lr_scheduler.StepLR(self.model_optim, step_size=step_size, gamma=0.1)
        self.delta_skd = torch.optim.lr_scheduler.StepLR(self.delta_optim, step_size=step_size, gamma=0.1)

    def dist_X(self, X_1, X_2, proj):
        """Return ``sum((X_1 - X_2) @ proj * (X_1 - X_2))`` over the last dim, keeping it."""
        X_diff = X_1 - X_2
        dist = torch.sum((X_diff @ proj) * X_diff, dim=-1, keepdim=True)
        return dist

    def dist_y(self, y_1, y_2):
        """Return the element-wise squared difference of two prediction tensors."""
        return torch.pow(y_1 - y_2, 2)

    def get_worst_X(self, X, proj):
        """Search a perturbation of ``X`` that maximally changes the predictions.

        Runs ``adv_iter`` Adam steps on ``self.delta``, maximising
        ``mean(dist_y) - lamb * mean(dist_X)``. The classifier parameters are
        frozen during the search and their ``requires_grad`` flags restored
        afterwards.

        Returns:
            The detached perturbed input ``X + delta``.
        """
        self.reset_delta()
        # Each search starts from a fresh Adam state (moments and step count),
        # matching the reference, which builds a new optimizer per search. The
        # learning rate lives in param_groups, so the scheduler is unaffected.
        self.delta_optim.state.clear()

        params = list(self.model.parameters())
        prev_flags = [p.requires_grad for p in params]
        for p in params:
            p.requires_grad_(False)
        try:
            for _ in range(self.adv_iter):
                self.delta_optim.zero_grad()
                X_worst = X + self.delta
                input_dist = self.dist_X(X, X_worst, proj)

                pred = self.model(X)
                pred_worst = self.model(X_worst)
                out_dist = self.dist_y(pred, pred_worst)

                # Negated so that Adam (a minimiser) performs gradient ascent.
                loss = -(torch.mean(out_dist) - self.lamb * torch.mean(input_dist))
                loss.backward()
                self.delta_optim.step()
        finally:
            for p, flag in zip(params, prev_flags):
                p.requires_grad_(flag)

        return (X + self.delta).detach()

    def reset_delta(self):
        """Re-initialise ``self.delta`` in place as ``U(0, 1) * (max_noise - min_noise) + min_noise``."""
        # In place and without autograd tracking, so the tensor registered with
        # ``delta_optim`` stays the one being used.
        with torch.no_grad():
            nn.init.uniform_(self.delta)
            self.delta.mul_(self.max_noise - self.min_noise).add_(self.min_noise)

    def fit(
            self,
            X: np.ndarray,
            y: np.ndarray,
            train_df: pd.DataFrame,
            sen_feat: Sequence[str],
            sen_idx: Sequence[int],
            verbose: bool = True,
    ):
        """Train the classifier with the SenSeI fairness penalty.

        Args:
            X: training features; must have ``n_input`` rows.
            y: training targets, converted to float32 for ``BCELoss``.
            train_df: frame with the sensitive columns (for the subspace).
            sen_feat: sensitive column names.
            sen_idx: forwarded to ``LogRegSenSubspace.fit``.
            verbose: print the loss every iteration (``self.verbose`` is not
                consulted here).

        Returns:
            ``(pert, pert_label)``: per iteration, ``ceil(n_rows / n_iter)``
            randomly chosen rows of the adversarial inputs (NumPy arrays) and
            their float32 labels.
        """
        self.sen_subspace.fit(X, train_df, sen_feat, sen_idx)
        proj = torch.as_tensor(self.sen_subspace.proj_, dtype=torch.float32, device=self.device)

        X = torch.from_numpy(X).float().to(self.device)
        y = torch.from_numpy(y).float().to(self.device)
        y_np = y.cpu().numpy()  # loop-invariant host copy for label sampling

        n_rows = X.size(0)
        n_pick = math.ceil(float(n_rows) / max(self.n_iter, 1))

        pert = []
        pert_label = []
        self.model.train()
        for i in range(self.n_iter):
            # The search freezes the model parameters, so the clean forward
            # pass used for the training loss is built after it.
            X_worst = self.get_worst_X(X, proj)
            pred = self.model(X)

            select_idx = random.sample(range(0, n_rows), n_pick)
            # Copy only the selected rows off the device.
            pert.append(X_worst[select_idx].cpu().numpy())
            pert_label.append(np.take(y_np, select_idx, axis=0))

            pred_worst = self.model(X_worst)

            dist_X = self.dist_X(X, X_worst, proj)
            mean_dist_X = dist_X.mean()
            lr_factor = torch.maximum(mean_dist_X, self.eps) / torch.minimum(mean_dist_X, self.eps)

            # Raise lamb when perturbations travel further than eps, lower it
            # otherwise, never going below min_lamb.
            self.lamb = torch.max(
                torch.stack(
                    [self.min_lamb, self.lamb + lr_factor * (mean_dist_X - self.eps)]
                )
            )

            self.model_optim.zero_grad()
            loss = self.criterion(pred, y) + self.rho * torch.mean(self.dist_y(pred, pred_worst))
            loss.backward()
            self.model_optim.step()

            self.model_skd.step()
            self.delta_skd.step()

            if verbose:
                print("Iter [%d|%d], loss: %.5f" % (i, self.n_iter, loss.item()))

        return pert, pert_label

    def infer(self, X):
        """Return the model outputs for ``X`` as a NumPy array (eval mode, no grad)."""
        self.model.eval()
        X = torch.from_numpy(X).float().to(self.device)
        with torch.no_grad():
            output = self.model(X)
        return output.cpu().numpy()

    def pred(self, X):
        """Return a boolean array: outputs strictly greater than ``pred_threshold``."""
        return self.infer(X) > self.pred_threshold

    def pred_proba(self, X):
        """Return the raw model outputs (same as ``infer``)."""
        return self.infer(X)