import numpy as np
from typing import Mapping, Sequence, Tuple, Callable, Dict

import torch
from torch import nn, Tensor
import torch.nn.functional as F

from .base import IndivFairModel, TorchLogReg, Autoencoder
from ._lcifr_utils import LT, Negate


class CategoricalBox():
    """Axis-aligned box ``[a, b]`` with one-hot encoded categorical groups.

    Attributes:
        a: Tensor of lower bounds (indexed as ``(batch, features)`` below).
        b: Tensor of upper bounds, same shape as ``a``.
        cat: Mapping from a categorical feature name to the column indices
            that hold its encoding.
        name: Fixed identifier string ``'categorical_box'``.
    """

    def __init__(self, a, b, cat):
        self.a = a
        self.b = b
        self.cat = cat
        self.name = 'categorical_box'

    def project(self, x):
        """Clamp ``x`` into ``[a, b]`` and L1-normalise every categorical group.

        The clamp produces a new tensor, so the caller's ``x`` is not modified.
        """
        x = torch.max(torch.min(x, self.b), self.a)
        for ids in self.cat.values():
            x[:, ids] = F.normalize(x[:, ids], p=1, dim=1)
        return x

    def sample(self) -> Tensor:
        """Draw one random point per row of the box.

        Every column is first sampled uniformly between ``a`` and ``b``; then
        each categorical group is overwritten with a one-hot vector whose
        active column is chosen uniformly among that group's indices.
        """
        device = self.a.device
        x = (self.b - self.a) * torch.rand(self.a.size(), device=device) + self.a
        batch_size = self.a.size()[0]
        rows = torch.arange(batch_size, device=device)
        for ids in self.cat.values():
            if len(ids) == 0:
                # Nothing to one-hot encode for an empty group.
                continue
            cols = torch.as_tensor(ids, dtype=torch.long, device=device)
            # Pick a position *within* the group and map it back to a column,
            # so groups with non-contiguous or unsorted indices never activate
            # a column that belongs to another feature.
            chosen = cols[torch.randint(len(ids), (batch_size,), device=device)]
            x[:, cols] = 0.0
            x[rows, chosen] = 1.0

        return x


class CatConstraint():
    """Constraint comparing the latent codes of an input and a perturbed copy.

    The perturbation domain lets numeric columns move by ``epsilon`` around
    the input, while the sensitive (categorical) columns get fixed bounds of
    ``[-1e-4, 1e-4]``. The condition is an ``LT`` term built from the per-row
    L-infinity distance between the two encodings and ``delta``.
    """

    def __init__(
            self,
            model: nn.Module,
            sen_feat2idx: Dict[str, Sequence[int]],
            num_feat_idx: Sequence[int],
            delta: float = 0.01,
            epsilon: float = 0.3,
    ):
        self.model = model
        self.sen_feat2idx = sen_feat2idx
        self.num_feat_idx = num_feat_idx
        self.delta = delta
        # A small slack is added on top of the requested radius.
        self.epsilon = epsilon + 1e-4

        # Flat list of every column that belongs to some sensitive feature.
        self.sen_feat_idx = [
            col for group in self.sen_feat2idx.values() for col in group
        ]

    def get_domain(self, X: Tensor) -> CategoricalBox:
        """Build the box of admissible perturbations around ``X``."""
        radius = torch.zeros(1, X.shape[1]).to(X.device, dtype=X.dtype)
        radius[:, self.num_feat_idx] = self.epsilon

        lower, upper = X - radius, X + radius
        # Sensitive columns ignore X and are pinned to a tiny interval at 0.
        lower[:, self.sen_feat_idx] = -1e-4
        upper[:, self.sen_feat_idx] = +1e-4

        return CategoricalBox(lower, upper, self.sen_feat2idx)

    def get_condition(self, X: Tensor, Z: Tensor) -> LT:
        """Return ``LT(max_j |enc(X) - enc(Z)|_j, delta)``, reduced over dim 1."""
        gap = (self.model.encode(X) - self.model.encode(Z)).abs()
        return LT(gap.max(dim=1).values, self.delta)

    def loss(self, X: Tensor, Z: Tensor):
        """Evaluate the condition on ``(X, Z)``.

        Returns:
            A tuple ``(neg_losses, pos_losses, sat)``: the loss of the negated
            condition, the loss of the condition itself, and the result of
            the condition's ``satisfy()``.
        """
        condition = self.get_condition(X, Z)
        violated = Negate(condition).loss()
        fulfilled = condition.loss()
        return violated, fulfilled, condition.satisfy()


class Oracle():
    """ Supervised DL2 Oracle in "DL2: Training and Querying Neural Networks with Logic" """

    def __init__(self, model: nn.Module, constraint: "CatConstraint", lr: float = 0.05):
        self.model = model
        self.constraint = constraint
        self.lr = lr

    def _descend(self, X: Tensor, domains: "CategoricalBox", num_iters: int):
        """Run one signed-gradient descent from a fresh sample of ``domains``."""
        Z = domains.sample()
        for _ in range(num_iters):
            Z = Z.clone().detach().requires_grad_(True)
            neg_losses, _, _ = self.constraint.loss(X, Z)
            avg_neg_loss = torch.mean(neg_losses)
            # Differentiate w.r.t. Z only, so the attack does not leave stale
            # gradients on the model parameters.
            (grad,) = torch.autograd.grad(avg_neg_loss, Z)
            Z = domains.project(Z.detach() - self.lr * torch.sign(grad))
        return Z

    def general_attack(self, X: Tensor, domains: "CategoricalBox", num_restarts: int = 1, num_iters: int = 25):
        """Search ``domains`` for points that minimise the negated-constraint loss.

        Args:
            X: Reference inputs passed to ``constraint.loss``.
            domains: Object providing ``sample()`` and ``project(Z)``.
            num_restarts: Number of independent random starts; with more than
                one, the restart with the lowest mean negated loss is kept.
            num_iters: Signed-gradient steps per restart.

        Returns:
            The perturbed points ``Z`` found by the best restart.

        Raises:
            ValueError: If ``num_restarts`` is smaller than 1.
        """
        if num_restarts < 1:
            raise ValueError("num_restarts must be at least 1")

        best_Z, best_loss = None, None
        for _ in range(num_restarts):
            Z = self._descend(X, domains, num_iters)
            if num_restarts == 1:
                # Nothing to compare against; skip the extra forward pass.
                return Z

            with torch.no_grad():
                neg_losses, _, _ = self.constraint.loss(X, Z)
                final_loss = float(torch.mean(neg_losses))
            if best_Z is None or final_loss < best_loss:
                best_Z, best_loss = Z, final_loss

        return best_Z

    def eval(self, X, Z):
        """Return ``(neg_losses, pos_losses, sat)`` with ``sat`` as a NumPy array."""
        neg_losses, pos_losses, sat = self.constraint.loss(X, Z)
        if not isinstance(sat, np.ndarray):
            sat = sat.detach().cpu().numpy()
        return neg_losses, pos_losses, sat


class LCIFR(IndivFairModel):
    """ Learning Certified Individually Fair Representations

    Training happens in two stages (see ``fit``):

    1. The autoencoder and classifier are trained jointly on a weighted sum of
       classification loss, DL2 constraint loss on adversarial inputs found by
       the oracle, and reconstruction loss.
    2. The autoencoder is frozen and the classifier alone is trained on latent
       codes perturbed by a projected-gradient attack of radius ``delta``.
    """

    def __init__(
            self,
            sen_feat2idx: Dict[str, Sequence[int]],
            num_feat_idx: Sequence[int],
            dl2_weight: float = 1,
            dec_weight: float = 0.,
            n_epochs: int = 100,
            lr: float = 1e-3,
            delta: float = 1e-2,
            pred_threshold: float = 0.5,
            device: str = "cuda:1"
    ):
        """
        Args:
            sen_feat2idx: Sensitive feature name -> column indices of its encoding.
            num_feat_idx: Column indices of the numeric features.
            dl2_weight: Weight of the DL2 constraint loss in stage 1.
            dec_weight: Weight of the reconstruction loss in stage 1.
            n_epochs: Number of full-batch steps in each of the two stages.
            lr: Learning rate of both Adam optimizers.
            delta: Latent-space radius, used by the DL2 constraint and by the
                projected-gradient attack.
            pred_threshold: Probability above which ``pred`` returns True.
            device: Torch device string on which models and data are placed.
        """
        self.sen_feat2idx = sen_feat2idx
        self.num_feat_idx = num_feat_idx
        self.dl2_weight = dl2_weight
        self.dec_weight = dec_weight
        self.n_epochs = n_epochs
        self.lr = lr
        self.delta = delta
        self.pred_threshold = pred_threshold
        self.device = torch.device(device)

    def _init_model(self):
        """Create networks, optimizers, schedulers and the DL2 oracle."""
        self.autoencoder = Autoencoder(self.input_dim).to(self.device)
        self.clf = TorchLogReg(self.input_dim).to(self.device)
        self.criterion = nn.BCELoss()
        # Stage 1 updates encoder/decoder and classifier together.
        self.optimizer = torch.optim.Adam(
            list(self.autoencoder.parameters()) + list(self.clf.parameters()),
            self.lr,
            weight_decay=1e-2,
        )
        # Stage 2 updates the classifier only.
        self.clf_optimizer = torch.optim.Adam(
            self.clf.parameters(),
            self.lr,
            weight_decay=1e-2,
        )
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, 'min', patience=5, factor=0.5
        )
        self.clf_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.clf_optimizer, 'min', patience=5, factor=0.5
        )
        self.oracle = Oracle(
            model=self.autoencoder,
            constraint=CatConstraint(
                model=self.autoencoder,
                sen_feat2idx=self.sen_feat2idx,
                num_feat_idx=self.num_feat_idx,
                # Keep the constraint radius in sync with the radius the
                # classifier is hardened against in stage 2.
                delta=self.delta,
            )
        )

    def _train(self, X: Tensor, y: Tensor) -> float:
        """ Step 1: one joint update of autoencoder and classifier; returns the loss """

        latent = self.autoencoder.encode(X)
        decode = self.autoencoder.decode(latent)
        # Reconstruction error is measured against the input, not the latent
        # code (assumes decode() returns tensors shaped like X).
        l2_loss = torch.norm(decode - X, dim=1)
        cls_loss = self.criterion(self.clf(latent), y)

        # Adversarial inputs within the similarity domain of X, then the
        # constraint loss evaluated on them.
        domains = self.oracle.constraint.get_domain(X)
        Z = self.oracle.general_attack(X, domains)
        _, dl2_loss, _ = self.oracle.eval(X, Z)

        loss = torch.mean(cls_loss + self.dl2_weight * dl2_loss + self.dec_weight * l2_loss)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def _train_clf(self, X: Tensor, y: Tensor) -> float:
        """ Step 2: train the classifier with adversarial attack """

        latent = self.autoencoder.encode(X)
        latent = self._pdg_attack(latent, y)
        loss = self.criterion(self.clf(latent), y)
        self.clf_optimizer.zero_grad()
        loss.backward()
        self.clf_optimizer.step()

        return loss.item()

    def _pdg_attack(self, latent: Tensor, y: Tensor, num_iter: int = 20, scale: float = 20.) -> Tensor:
        """ Projected Gradient Descent attack on the latent codes

        Starts from a uniform random point in the L-infinity ball of radius
        ``delta`` around ``latent`` and takes ``num_iter`` signed-gradient
        ascent steps of size ``delta / scale`` on the classification loss,
        projecting back into the ball after each step. Returns a detached tensor.
        """

        latent = latent.clone().detach()
        latent_min, latent_max = latent - self.delta, latent + self.delta

        latent = latent + self.delta * torch.rand_like(latent, device=self.device).uniform_(-1, 1)
        step = self.delta / scale
        for _ in range(num_iter):
            latent = latent.clone().detach().requires_grad_(True)

            loss = self.criterion(self.clf(latent), y)
            # Only the gradient w.r.t. the latent point is needed; this avoids
            # accumulating gradients on the classifier parameters.
            (grad,) = torch.autograd.grad(loss, latent)

            latent = latent.detach() + step * grad.sign()
            latent = torch.min(latent_max, torch.max(latent_min, latent))

        return latent.detach()

    def fit(self, X: np.ndarray, y: np.ndarray, verbose=True):
        """Train on the full batch ``(X, y)`` and return ``self``.

        Args:
            X: 2-D feature array; ``X.shape[1]`` fixes the model input size.
            y: Target array, converted to float for ``nn.BCELoss``.
            verbose: Print the loss after every step when True.
        """
        self.input_dim = X.shape[1]
        self._init_model()

        X = torch.from_numpy(X).float().to(self.device)
        y = torch.from_numpy(y).float().to(self.device)

        self.autoencoder.train()
        self.clf.train()
        for i in range(self.n_epochs):
            loss = self._train(X, y)
            self.scheduler.step(loss)

            if verbose:
                print("Step 1: Iter [%d|%d], loss: %.5f" % (i, self.n_epochs, loss))

        # Freeze the representation before hardening the classifier.
        self.autoencoder.requires_grad_(False)
        for i in range(self.n_epochs):
            loss = self._train_clf(X, y)
            self.clf_scheduler.step(loss)

            if verbose:
                print("Step 2: Iter [%d|%d], loss: %.5f" % (i, self.n_epochs, loss))

        return self

    def pred(self, X: np.ndarray):
        """Return a boolean array: predicted probability > ``pred_threshold``."""
        pred_proba = self.pred_proba(X)
        pred_label = (pred_proba > self.pred_threshold)

        return pred_label

    def pred_proba(self, X: np.ndarray):
        """Return the classifier output on the encoded ``X`` as a NumPy array."""
        self.autoencoder.eval()
        self.clf.eval()
        X = torch.from_numpy(X).float().to(self.device)
        with torch.no_grad():
            pred_proba = self.clf(self.autoencoder.encode(X))
        pred_proba = pred_proba.cpu().numpy()

        return pred_proba
