from abc import ABC, abstractmethod
import numpy as np
from tqdm import tqdm
from typing import Any, Sequence

from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
import torch
from torch import nn
import torch.nn.functional as F


class IndivFairModel(ABC):
    """Common interface implemented by every model wrapper in this module.

    Concrete subclasses provide three operations:

    * ``fit``        -- train the model in place (returns nothing);
    * ``pred``       -- map a feature matrix to hard predictions;
    * ``pred_proba`` -- map a feature matrix to per-sample scores.

    The base implementations only raise ``NotImplementedError``.
    """

    @abstractmethod
    def fit(self, *args, **kwargs) -> None:
        """Train the model; each subclass defines the exact arguments."""
        raise NotImplementedError()

    @abstractmethod
    def pred(self, X: np.ndarray) -> np.ndarray:
        """Return hard predictions for the rows of ``X``."""
        raise NotImplementedError()

    @abstractmethod
    def pred_proba(self, X: np.ndarray) -> np.ndarray:
        """Return a score (probability-like value) for each row of ``X``."""
        raise NotImplementedError()


class LogReg(IndivFairModel):
    """L2-regularised logistic regression backed by scikit-learn.

    Args:
        l2_reg: L2 penalty strength. scikit-learn takes the inverse strength,
            so it is passed on as ``C = 1 / l2_reg`` (must be non-zero).
        fit_intercept: whether to learn a bias term (off by default).
    """

    def __init__(self, l2_reg=1., fit_intercept=False):
        super().__init__()
        self.l2_reg = l2_reg
        inverse_strength = 1. / l2_reg
        # NOTE(review): ``multi_class`` is deprecated in recent scikit-learn
        # releases; for binary targets the default solver already behaves
        # like "ovr". Confirm before upgrading scikit-learn.
        self.model = LogisticRegression(
            penalty="l2",
            C=inverse_strength,
            fit_intercept=fit_intercept,
            multi_class="ovr",
            warm_start=False,
            max_iter=2048,
        )

    def fit(self, X: np.ndarray, y: np.ndarray):
        """Fit the underlying scikit-learn estimator on ``(X, y)``."""
        self.model.fit(X, y)

    def pred(self, X):
        """Return the estimator's predicted class labels for ``X``."""
        labels = self.model.predict(X)
        return labels

    def pred_proba(self, X):
        """Return the probability column for the second class in ``classes_``.

        For 0/1 labels this is P(y == 1).
        """
        proba = self.model.predict_proba(X)
        return proba[:, 1]


class TorchLogReg(nn.Module):
    """Bias-free logistic regression as a PyTorch module.

    ``forward`` maps rows of ``x`` to sigmoid probabilities and removes only
    the trailing size-1 output dimension, so an ``(n, input_dim)`` batch
    yields shape ``(n,)``, including ``n == 1``.

    Args:
        input_dim: number of input features.
    """

    def __init__(self, input_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # squeeze(-1) instead of squeeze(): a plain squeeze() would also drop
        # the batch dimension when the batch holds a single row, turning the
        # output into a 0-d tensor that no longer lines up with its targets.
        return self.sigmoid(self.linear(x)).squeeze(-1)

    def predict(self, x):
        """Return 0/1 int labels, thresholding probabilities at 0.5 (inclusive)."""
        return (0.5 <= self(x)).int()

    def logits(self, x):
        # NOTE(review): despite its name this returns post-sigmoid
        # probabilities, identical to ``forward``. The behavior is kept
        # because callers may rely on it; use ``self.linear(x).squeeze(-1)``
        # if real pre-sigmoid logits are needed.
        return self(x)


class Autoencoder(nn.Module):
    """Symmetric MLP autoencoder.

    The encoder is ``Linear -> ReLU -> Linear`` into a code of width
    ``int(input_dim // factor)``; the decoder mirrors it back to
    ``input_dim``. All layers carry a bias.

    Args:
        input_dim: number of input features.
        factor: compression factor applied to ``input_dim`` for the code width.
    """

    def __init__(self, input_dim: int, factor: float = 1.):
        super().__init__()

        code_dim = int(input_dim // factor)

        # Layers are created in the same order as before so random
        # initialisation under a fixed seed is unchanged.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, code_dim, bias=True),
            nn.ReLU(),
            nn.Linear(code_dim, code_dim, bias=True),
        )
        self.decoder = nn.Sequential(
            nn.Linear(code_dim, code_dim, bias=True),
            nn.ReLU(),
            nn.Linear(code_dim, input_dim, bias=True),
        )

    def forward(self, x):
        """Reconstruct ``x`` by encoding and then decoding it."""
        code = self.encoder(x)
        return self.decoder(code)

    def encode(self, x):
        """Map ``x`` to its latent code."""
        return self.encoder(x)

    def decode(self, x):
        """Map a latent code back to input space."""
        return self.decoder(x)


class RandomForest(IndivFairModel):
    """scikit-learn random forest (Gini impurity, library defaults otherwise)."""

    def __init__(self):
        super().__init__()
        forest_kwargs = {"criterion": "gini"}
        self.model = RandomForestClassifier(**forest_kwargs)

    def fit(self, X: np.ndarray, y: np.ndarray):
        """Fit the forest on ``(X, y)``."""
        self.model.fit(X, y)

    def pred(self, X):
        """Return the forest's predicted class labels for ``X``."""
        labels = self.model.predict(X)
        return labels

    def pred_proba(self, X):
        """Return the probability column for the second class in ``classes_``."""
        proba = self.model.predict_proba(X)
        return proba[:, 1]


class MLPClassifier(nn.Module):
    """Two-hidden-layer ReLU MLP that outputs one sigmoid score per row.

    Architecture: ``Linear -> ReLU -> Linear -> ReLU -> Linear(no bias)
    -> Sigmoid``, with hidden width ``int(input_dim // factor)``.

    Args:
        input_dim: number of input features.
        factor: divisor for the hidden width; an int or a float.
    """

    def __init__(self, input_dim: int, factor=1):
        super(MLPClassifier, self).__init__()
        # int() lets a float factor work: e.g. ``8 // 2.`` is the float 4.0,
        # which nn.Linear cannot use as a layer size.
        hidden = int(input_dim // factor)
        self.layer = nn.Sequential(
            nn.Linear(input_dim, hidden, bias=True),
            nn.ReLU(),
            nn.Linear(hidden, hidden, bias=True),
            nn.ReLU(),
            nn.Linear(hidden, 1, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Drop only the trailing size-1 output column. A bare squeeze() would
        # also drop the batch dimension for a single-row batch and yield a
        # 0-d tensor whose shape no longer matches a (1,) target in BCELoss.
        return self.layer(x).squeeze(-1)


class NeuralNets(IndivFairModel):
    """MLP binary classifier trained with full-batch SGD on a torch device.

    Wraps ``MLPClassifier`` behind the ``IndivFairModel`` interface. ``fit``
    runs ``n_iter`` SGD steps on the whole training set (no mini-batching)
    with BCE loss and L2 weight decay. The learning rate is halved every
    ``n_iter // 4`` steps. The optimizer and scheduler are created once in
    ``__init__``, so a second ``fit`` call continues from their current state.

    Args:
        input_dim: number of input features.
        device: torch device string. The default ``"cuda:1"`` targets the
            second GPU; pass e.g. ``"cpu"`` or ``"cuda:0"`` on other machines.
        n_iter: number of full-batch optimisation steps.
        lr: initial SGD learning rate.
        l2_reg: SGD weight decay.
        pred_threshold: score cut-off used by ``pred``.
    """

    def __init__(self, input_dim: int, device="cuda:1", n_iter: int = 10000,
                 lr: float = 1e-1, l2_reg: float = 1e-2,
                 pred_threshold: float = 0.5):
        super(NeuralNets, self).__init__()
        self.input_dim = input_dim
        self.n_iter = n_iter
        self.lr = lr
        self.l2_reg = l2_reg
        self.pred_threshold = pred_threshold
        self.device = torch.device(device)

        self.model = MLPClassifier(input_dim=self.input_dim).to(self.device)
        self.optimizer = torch.optim.SGD(self.model.parameters(), self.lr, weight_decay=self.l2_reg)

        # StepLR takes ``last_epoch % step_size``, so step_size must stay >= 1
        # even when n_iter < 4.
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=max(1, self.n_iter // 4), gamma=0.5)

    def _to_tensor(self, a) -> torch.Tensor:
        """Convert array-like ``a`` to a float32 tensor on ``self.device``."""
        # ascontiguousarray accepts lists, bool/int arrays, pandas objects and
        # negative-stride views, none of which torch.from_numpy handles
        # directly (it rejects non-ndarrays and negative strides).
        arr = np.ascontiguousarray(a, dtype=np.float32)
        return torch.from_numpy(arr).to(self.device)

    def fit(self, X: np.ndarray, y: np.ndarray):
        """Train the MLP on features ``X`` and binary targets ``y``.

        ``y`` may be shaped ``(n,)`` or ``(n, 1)``; it is flattened.
        """
        # To re-enable class balancing, build per-sample weights
        # len(y) / (n_classes * bincount(y)) and pass them as
        # nn.BCELoss(weight=...).
        X_tensor = self._to_tensor(X)
        y_tensor = self._to_tensor(y).reshape(-1)

        self.criterion = nn.BCELoss()
        self.model.train()
        for _ in range(self.n_iter):
            self.optimizer.zero_grad()
            # reshape(-1) keeps predictions aligned with y_tensor even if the
            # model squeezes a single-sample batch down to a 0-d tensor.
            pred = self.model(X_tensor).reshape(-1)

            loss = self.criterion(pred, y_tensor)
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()

        return

    def infer(self, X):
        """Return the model's sigmoid scores for ``X`` as a 1-D numpy array."""
        self.model.eval()
        X_tensor = self._to_tensor(X)
        with torch.no_grad():
            output = self.model(X_tensor)
        # Flatten so a single-row X still yields a length-1 array, not 0-d.
        return output.reshape(-1).cpu().numpy()

    def pred(self, X):
        """Return a boolean array: True where the score exceeds ``pred_threshold``."""
        pred = self.infer(X)
        pred_label = (pred > self.pred_threshold)
        return pred_label

    def pred_proba(self, X):
        """Return the model's sigmoid scores for ``X``."""
        pred = self.infer(X)
        return pred


class SenDrop(IndivFairModel):
    """Discard the sensitive feature columns, then delegate to a wrapped model.

    Args:
        drop_idx: column indices of the sensitive features. Negative indices
            count from the last column, as in NumPy indexing.
        model: the ``IndivFairModel`` that is trained and queried on the
            remaining columns.
    """

    def __init__(self, drop_idx: Sequence[int], model: IndivFairModel):
        super(SenDrop, self).__init__()
        self.model = model
        self.drop_idx = drop_idx

    def fit(self, X, y):
        """Fit the wrapped model on ``X`` without the ``drop_idx`` columns.

        Records the kept column indices in ``self.remain_idx`` for use by
        ``pred`` / ``pred_proba``.

        Raises:
            IndexError: if an entry of ``drop_idx`` lies outside
                ``[-n_features, n_features)``.
        """
        n_features = X.shape[1]
        dropped = set()
        for idx in self.drop_idx:
            idx = int(idx)
            # An unchecked bad index would silently leave the sensitive
            # column in the data, so reject it loudly instead.
            if not -n_features <= idx < n_features:
                raise IndexError(
                    f"drop_idx entry {idx} is out of range for {n_features} features"
                )
            # Normalise negatives so e.g. -1 really drops the last column.
            dropped.add(idx % n_features)
        # Keep surviving columns in their original order; iterating a set
        # difference gives no ordering guarantee (e.g. {3, 10} -> [10, 3]).
        self.remain_idx = [i for i in range(n_features) if i not in dropped]
        self.model.fit(np.take(X, self.remain_idx, axis=1), y)
        return

    def pred(self, X):
        """Predict with the wrapped model on the columns kept by ``fit``."""
        return self.model.pred(np.take(X, self.remain_idx, axis=1))

    def pred_proba(self, X):
        """Score with the wrapped model on the columns kept by ``fit``."""
        return self.model.pred_proba(np.take(X, self.remain_idx, axis=1))
