import numpy as np
from tqdm import tqdm
from typing import Callable, Sequence

import torch
from torch import nn

from ..base import IndivFairModel, MLPClassifier


class DRO(IndivFairModel):
    """Distributionally robust optimization (DRO) for individual fairness.

    Trains an ``MLPClassifier`` with full-batch SGD on the sum of two terms:

    * the mean binary cross-entropy on the original samples, and
    * for every sample, the largest ``comp_mat``-weighted binary cross-entropy
      over its "antidote" counterparts (which share the sample's label),
      averaged over samples.

    Args:
        sen_idx: Column indices of the sensitive features.
        drop_sen: If true, the sensitive columns are removed from the inputs
            before training and inference.
        device: Torch device the model and all tensors are placed on.
        n_iter: Number of full-batch SGD iterations.
        lr: Initial SGD learning rate; halved every ``max(1, n_iter // 4)``
            iterations by a ``StepLR`` schedule.
        l2_reg: SGD weight decay.
        pred_threshold: Probability threshold used by :meth:`pred`.
    """

    def __init__(
        self,
        sen_idx: Sequence[int],
        drop_sen=False,
        device="cuda:1",
        n_iter: int = 10000,
        lr: float = 1e-1,
        l2_reg: float = 0.,
        pred_threshold: float = 0.5,
    ):
        super().__init__()
        self.sen_idx = sen_idx
        self.drop_sen = drop_sen
        self.n_iter = n_iter
        self.lr = lr
        self.l2_reg = l2_reg
        self.pred_threshold = pred_threshold
        self.device = torch.device(device)

    def fit(self, X: np.ndarray, y: np.ndarray, antidote_X: np.ndarray, comp_mat: np.ndarray):
        """Train the classifier with the DRO objective.

        Args:
            X: Training features; axis 1 indexes feature columns.
            y: Binary training labels, one per row of ``X``.
            antidote_X: Counterpart samples; axis 1 indexes the counterparts
                of each sample and axis 2 indexes feature columns. Row ``i``'s
                counterparts are trained against label ``y[i]``.
            comp_mat: Weights multiplied elementwise into the per-counterpart
                losses before taking the max over axis 1. Presumably
                ``(n_samples, n_antidotes)`` — TODO confirm against the model's
                output shape.

        Sets ``self.model``, ``self.optimizer``, ``self.criterion``,
        ``self.anti_criterion`` and ``self.scheduler`` (and ``self.remain_idx``
        when ``drop_sen`` is set).
        """
        if self.drop_sen:
            sen_set = set(self.sen_idx)
            # Keep the remaining columns in their original order so that the
            # same selection can be replayed deterministically in `infer`.
            self.remain_idx = [i for i in range(X.shape[1]) if i not in sen_set]
            X = np.take(X, self.remain_idx, axis=1)
            antidote_X = np.take(antidote_X, self.remain_idx, axis=2)

        X_tensor = torch.from_numpy(X).float().to(self.device)
        y_tensor = torch.from_numpy(y).float().to(self.device)
        anti_X_tensor = torch.from_numpy(antidote_X).float().to(self.device)
        # Each counterpart inherits the label of the sample it belongs to.
        anti_y_tensor = y_tensor.unsqueeze(1).repeat(1, anti_X_tensor.shape[1])
        # Cast to float32 so the weighted loss stays in the same dtype as the
        # other tensors, whatever the incoming mask dtype (bool/int/float64).
        comp_tensor = torch.from_numpy(comp_mat).float().to(self.device)

        self.model = MLPClassifier(input_dim=X_tensor.size(1)).to(self.device)
        self.optimizer = torch.optim.SGD(self.model.parameters(), self.lr, weight_decay=self.l2_reg)

        self.criterion = nn.BCELoss()
        # Unreduced, so each counterpart's loss can be weighted and max-pooled.
        self.anti_criterion = nn.BCELoss(reduction="none")
        # StepLR divides by step_size, so it must be at least 1 for tiny n_iter.
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=max(1, self.n_iter // 4), gamma=0.5
        )

        self.model.train()
        for _ in tqdm(range(self.n_iter), desc="Training neural networks with DRO..."):
            self.optimizer.zero_grad()

            # Empirical risk on the original samples.
            pred = self.model(X_tensor)
            loss = self.criterion(pred, y_tensor)

            # Worst-case weighted loss over each sample's counterparts, averaged.
            anti_pred = self.model(anti_X_tensor)
            anti_loss = self.anti_criterion(anti_pred, anti_y_tensor) * comp_tensor
            anti_loss = torch.max(anti_loss, dim=1).values.mean()

            # A single backward pass on the summed objective accumulates the
            # same gradients as two separate backward() calls, without having
            # to retain the first graph in memory.
            (loss + anti_loss).backward()

            self.optimizer.step()
            self.scheduler.step()

        return

    def infer(self, X):
        """Return the model's outputs for ``X`` as a NumPy array.

        The outputs are what ``BCELoss`` was trained on, so they are expected
        to be probabilities in [0, 1]. When ``drop_sen`` is set, the same
        columns removed in :meth:`fit` are removed here. Requires :meth:`fit`
        to have been called first.
        """
        if self.drop_sen:
            X = np.take(X, self.remain_idx, axis=1)
        self.model.eval()
        X_tensor = torch.from_numpy(X).float().to(self.device)
        with torch.no_grad():
            output = self.model(X_tensor)
        return output.cpu().numpy()

    def pred(self, X):
        """Return boolean labels: model output strictly above ``pred_threshold``."""
        pred = self.infer(X)
        return pred > self.pred_threshold

    def pred_proba(self, X):
        """Return the raw model outputs for ``X`` (same as :meth:`infer`)."""
        return self.infer(X)
