import torch
import torch.nn as nn
import torch.nn.functional as F
from baseline import baselineNN


class VFair(nn.Module):
    """Wrap a ``baselineNN`` learner and build a combined gradient for it.

    ``learner_step`` computes two scalar objectives on the selected samples of
    a batch -- their mean loss ``g`` and the root-mean-square deviation ``f``
    of their losses from a reference mean loss -- and writes
    ``grad_f + lam * grad_g`` into the ``.grad`` fields of the parameters.
    No optimizer step is taken here.
    """

    def __init__(
            self,
            embedding_size,
            n_num_cols,
            learner_hidden_units=[64, 32],
            batch_size=256,
            activation_fn=nn.ReLU,
            device='cpu',
            phi=0.9,
            epsilon=3,
            train_dataset=None):
        """Build the learner and store the step hyper-parameters.

        :param embedding_size: forwarded to ``baselineNN``.
        :param n_num_cols: forwarded to ``baselineNN``.
        :param learner_hidden_units: forwarded to ``baselineNN`` (as a copy).
        :param batch_size: accepted for interface compatibility; not used by
            this class.
        :param activation_fn: forwarded to ``baselineNN``.
        :param device: device the learner is moved to; also used for the
            global-data forward pass in ``learner_step``.
        :param phi: stored as ``1 - phi``; a sample is selected when
            ``|target - sig|`` exceeds that stored value.
        :param epsilon: constant used in the first constraint (``cons1``).
        :param train_dataset: object whose ``[:]`` yields
            ``(x_cat, x_num, targets)``; used to compute the global mean loss.
            May be ``None``, in which case the batch mean is used instead.
        """
        super().__init__()

        self.device = device
        # TODO: this handles binary classification only; extend to multi-class.
        self.learner = baselineNN(
            embedding_size,
            n_num_cols,
            # Pass a copy so the shared default list can never be mutated.
            list(learner_hidden_units),
            activation_fn=activation_fn,
            device=device
        )
        # NOTE: the attribute holds the selection threshold 1 - phi, not phi.
        self.phi = 1 - phi
        self.learner.to(device)
        self.global_data = train_dataset
        self.epsilon = epsilon

    def learner_step(self, x_cat, x_num, targets):
        """Run one forward/backward pass and store the combined gradient.

        :param x_cat: first input passed to the learner.
        :param x_num: second input passed to the learner.
        :param targets: binary-cross-entropy targets; assumed to have the
            same shape as the learner's logits -- TODO confirm against caller.
        :return: ``(cons1, cons2, logging_dict)``. When no sample is selected
            the result is ``(None, None, logging_dict)`` with zeroed entries
            and the gradients are left untouched.
        """
        self.learner.zero_grad()
        logits, sig, _ = self.learner(x_cat, x_num)

        # Keep only the samples whose prediction is far enough from the target.
        selected_idx = torch.where(torch.abs(targets - sig) > self.phi)[0]
        n_selected = selected_idx.numel()
        if n_selected == 0:
            # 'lambda' and 'variance' are kept for existing consumers; 'lam'
            # and 'selected' mirror the keys of the regular return path.
            return None, None, {
                'full_loss': 0, 'variance': 0, "lambda": 0, "cons1": 0, "cons2": 0,
                'lam': 0, 'selected': 0,
            }

        loss = F.binary_cross_entropy_with_logits(
            logits[selected_idx], targets[selected_idx], reduction='none')
        g = torch.mean(loss)

        mean_loss = self._reference_mean_loss(logits, targets, len(x_cat))

        # Root-mean-square deviation of the selected losses from mean_loss.
        f = torch.sqrt(torch.mean(torch.square(loss - mean_loss)))

        # retain_graph: f shares the forward graph and is back-propagated next.
        g.backward(retain_graph=True)
        grad_g = self.flatten_grads().detach().clone()

        self.learner.zero_grad()
        f.backward()
        grad_f = self.flatten_grads().detach().clone()

        # The multipliers are plain numbers for the update and for logging;
        # computing them without autograd keeps the freed graph out of them.
        with torch.no_grad():
            cons1 = self.epsilon - torch.dot(grad_f, grad_g) / torch.square(torch.norm(grad_g))
            cons2 = torch.clamp((mean_loss - torch.min(loss)) / f, min=0)
            lam = torch.maximum(cons1, cons2)
            grad_final = grad_f + lam * grad_g

        self.assign_grads(grad_final)
        logging_dict = {
            'full_loss': mean_loss,
            'cons1': cons1,
            'cons2': cons2,
            'lam': lam,
            'selected': n_selected
        }
        return cons1, cons2, logging_dict

    def _reference_mean_loss(self, logits, targets, batch_len):
        """Return the mean loss the selected losses are compared against.

        A batch whose length is a power of two is treated as a mini-batch and
        the mean loss is recomputed over ``self.global_data``; any other
        length is treated as a full batch and the batch's own mean is used.
        The batch mean is also used when no global dataset was supplied.
        """
        is_power_of_two = (batch_len & (batch_len - 1)) == 0
        with torch.no_grad():
            if is_power_of_two and self.global_data is not None:
                g_cat, g_num, g_tar = self.global_data[:]
                # Use the configured device rather than a hard-coded 'cuda'.
                g_log, _, _ = self.learner(g_cat.to(self.device), g_num.to(self.device))
                return F.binary_cross_entropy_with_logits(
                    g_log.squeeze(), g_tar.to(self.device), reduction="mean")
            return F.binary_cross_entropy_with_logits(logits, targets, reduction="mean")

    def flatten_grads(self):
        """
        Flattens the gradients of a model (after `.backward()` call) as a single, large vector.
        Parameters whose `.grad` is None are skipped.
        :return: 1D torch Tensor
        """
        # reshape (not view) so non-contiguous gradients are handled as well.
        return torch.cat([
            param.grad.reshape(-1)
            for param in self.parameters()
            if param.grad is not None
        ])

    def assign_grads(self, grads):
        """
        Manually assign a flat `grads` vector to the `.grad` fields of the model.
        Walks the parameters in the same order, and with the same skip rule,
        as `flatten_grads`, so the two methods stay aligned.
        :param grads: 1D gradient vector laid out as produced by `flatten_grads`.
        """
        offset = 0
        for param in self.parameters():
            # Parameters without a gradient were not part of the flat vector.
            if param.grad is None:
                continue
            count = param.numel()
            param.grad = grads[offset:offset + count].reshape(param.shape).clone()
            offset += count
