# +
import numpy as np
import torch
import cooper

from torch import nn


# parts of the code have been adapted from https://github.com/ryanchankh/mcr2/blob/master/loss.py

def one_hot(labels_int, n_classes):
    """Encode integer class labels as a dense one-hot float32 tensor.

    Args:
        labels_int: sized iterable (list, array or 1-D tensor) of integer
            labels; ``len(labels_int)`` gives the number of rows.
        n_classes: number of columns of the encoding.

    Returns:
        ``torch.Tensor`` of shape ``(len(labels_int), n_classes)`` with 1.0 at
        ``[i, labels_int[i]]`` and 0.0 everywhere else.
    """
    encoded = torch.zeros((len(labels_int), n_classes), dtype=torch.float32)
    for row, label in zip(encoded, labels_int):
        # `row` is a view into `encoded`, so this writes through to it.
        row[label] = 1.0
    return encoded


def label_to_membership(targets, num_classes=None):
    """Build a diagonal membership tensor from integer class labels.

    ``Pi[k, j, j] == 1.0`` iff sample ``j`` has label ``k``; every other entry
    is 0. The result holds ``num_classes * num_samples ** 2`` entries, so its
    memory grows quadratically with the number of samples.

    Parameters:
        targets: 1-D sequence of integer labels (list, ``np.ndarray`` or
            ``torch.Tensor``; tensors may live on any device).
        num_classes: number of classes ``k``. When ``None`` it is inferred as
            ``max(targets) + 1`` (0 for empty targets). 0-d tensors and numpy
            scalars are accepted.

    Return:
        Pi: membership matrix, ``np.ndarray`` of dtype float64 and shape
        ``(num_classes, num_samples, num_samples)``.

    Raises:
        IndexError: if a label is ``>= num_classes``.
    """
    if isinstance(targets, torch.Tensor):
        # .numpy() only works on CPU tensors that do not require grad.
        targets = targets.detach().cpu().numpy()
    labels = np.asarray(targets, dtype=np.int64)
    num_samples = len(labels)

    if num_classes is None:
        # The old default crashed inside one_hot (torch.zeros((n, None))).
        num_classes = int(labels.max()) + 1 if num_samples else 0

    Pi = np.zeros(shape=(int(num_classes), num_samples, num_samples))
    diag = np.arange(num_samples)
    # Vectorized equivalent of: for j, k in enumerate(labels): Pi[k, j, j] = 1
    Pi[labels, diag, diag] = 1.0
    return Pi


class RateDistortion(torch.nn.Module):
    """Base module providing log-determinant coding-rate terms.

    Subclasses combine ``rate``, ``rate_for_mixture`` and
    ``rate_for_continuous`` in their ``forward``.

    Args:
        gam1: multiplier applied to the scaled covariance inside the
            log-determinant of ``rate`` and ``rate_for_continuous``.
        gam2: stored as an attribute; not read by any method of this class.
        eps: distortion parameter; every rate term scales its covariance by
            a factor proportional to ``1 / eps``.
    """

    def __init__(self, gam1=1.0, gam2=1.0, eps=0.01):
        super().__init__()
        self.gam1 = gam1
        self.gam2 = gam2
        self.eps = eps

    @staticmethod
    def _identity(W, device):
        """Return an identity of size ``W.shape[0]`` on ``device`` (``W.device`` if None)."""
        # The identity is added to a matrix derived from W, so it must live on
        # W's device; following W avoids assuming that 'cuda:0' exists.
        return torch.eye(W.shape[0], device=W.device if device is None else device)

    def rate(self, W, device=None):
        """Empirical Discriminative Loss.

        Computes ``0.5 * logdet(I + gam1 * p / (m * eps) * W @ W.T)``.

        Args:
            W: 2-D tensor of shape ``(p, m)``.
            device: device of the identity matrix; defaults to ``W.device``.

        Returns:
            0-d tensor holding the rate.
        """
        p, m = W.shape
        I = self._identity(W, device)
        scalar = p / (m * self.eps)
        logdet = torch.logdet(I + self.gam1 * scalar * W.matmul(W.T))
        return logdet / 2.0

    def rate_for_mixture(self, W, Pi, device=None):
        """Empirical Compressive Loss.

        For every ``j`` computes
        ``tr(Pi[j]) / m * logdet(I + p / (tr(Pi[j]) * eps) * W @ Pi[j] @ W.T)``
        and returns half of the sum over ``j``. ``gam1`` is not applied here.

        Args:
            W: 2-D tensor of shape ``(p, m)``.
            Pi: 3-D tensor of shape ``(k, m, m)``.
            device: device of the identity matrix; defaults to ``W.device``.

        Returns:
            0-d tensor (plain ``0.0`` when ``k == 0``).
        """
        p, m = W.shape
        k, _, _ = Pi.shape
        I = self._identity(W, device)
        compress_loss = 0.0
        for j in range(k):
            # Small offset keeps the division finite when Pi[j] has zero trace.
            trPi = torch.trace(Pi[j]) + 1e-8
            scalar = p / (trPi * self.eps)
            log_det = torch.logdet(I + scalar * W.matmul(Pi[j]).matmul(W.T))
            compress_loss += log_det * trPi / m
        return compress_loss / 2.0

    def rate_for_continuous(self, W, kernel, device=None):
        """Empirical Discriminative Loss with a continuous kernel.

        Computes ``0.5 * logdet(I + gam1 * p / (m * eps) * ((W @ W.T) * kernel))``
        where ``* kernel`` is an elementwise product.

        Args:
            W: 2-D tensor of shape ``(p, m)``.
            kernel: tensor broadcastable to ``(p, p)``.
            device: device of the identity matrix; defaults to ``W.device``.

        Returns:
            0-d tensor holding the kernelized rate.
        """
        p, m = W.shape
        I = self._identity(W, device)
        scalar = p / (m * self.eps)
        kernelized_cov = torch.mul(W.matmul(W.T), kernel)
        logdet = torch.logdet(I + self.gam1 * scalar * kernelized_cov)
        return logdet / 2.0


class RateDistortionUnconstrained(RateDistortion):
    """
    Rate distortion loss in an unconstrained setup.

    ``forward`` returns ``-R(W) - Rc(W, Pi_Z)`` with ``W = X.T``, where ``R`` is
    ``rate`` and ``Rc`` is ``rate_for_mixture`` over the partition induced by
    the integer labels ``Z``.
    """

    def forward(self, X, Z, device=None):
        """Compute the unconstrained objective.

        Args:
            X: 2-D tensor; ``W = X.T``, so each row of X becomes a column of W.
            Z: integer labels, one per row of ``X``; classes ``0 .. max(Z)``.
            device: device for intermediate tensors; defaults to ``X.device``.

        Returns:
            0-d tensor ``-R(W) - Rc(W, Pi_Z)``.
        """
        if device is None:
            # Follow the input instead of assuming a 'cuda:0' GPU exists.
            device = X.device
        W = X.T

        num_classes_z = int(Z.max() + 1)
        Pi_z = label_to_membership(Z, num_classes_z)
        Pi_z = torch.as_tensor(Pi_z, dtype=torch.float32, device=device)

        Rz_pi = self.rate_for_mixture(W, Pi_z, device)
        Rz = self.rate(W, device)

        J_u = -Rz - Rz_pi
        return J_u


class RateDistortionUnconstrainedMultiple(RateDistortion):
    """Rate distortion loss in an unconstrained setup with several label sets.

    ``forward`` returns ``-R(W) - sum_z Rc(W, Pi_z)`` with ``W = X.T``, summing
    one ``rate_for_mixture`` term per label vector in ``Z``.
    """

    def forward(self, X, Z, device=None):
        """Compute the objective over several label partitions.

        Args:
            X: 2-D tensor; ``W = X.T``, so each row of X becomes a column of W.
            Z: iterable of integer label vectors, each with one entry per row
                of ``X``. An empty iterable yields just ``-R(W)``.
            device: device for intermediate tensors; defaults to ``X.device``.

        Returns:
            0-d tensor ``-R(W) - sum_z Rc(W, Pi_z)``.
        """
        if device is None:
            # Follow the input instead of assuming a 'cuda:0' GPU exists.
            device = X.device
        W = X.T

        J_u = -self.rate(W, device)
        for z in Z:
            num_classes_z = int(z.max() + 1)
            Pi_z = label_to_membership(z, num_classes_z)
            Pi_z = torch.as_tensor(Pi_z, dtype=torch.float32, device=device)
            # Out-of-place update so no in-place op touches the autograd graph.
            J_u = J_u - self.rate_for_mixture(W, Pi_z, device)
        return J_u


class RateDistortionConstrained(RateDistortion):
    """
    Rate distortion loss in a constrained setup for
    debiasing a single protected attribute.
    """

    def forward(self, X, Y, device=None):
        """Compute ``R(W) - Rc(W, Pi_Y)`` with ``W = X.T``.

        Args:
            X: 2-D tensor; ``W = X.T``, so each row of X becomes a column of W.
            Y: integer label tensor, one entry per row of ``X``; it may live on
                any device and may require grad.
            device: device for intermediate tensors; defaults to ``X.device``.

        Returns:
            Tuple ``(R_z - R_z_pi, R_z, R_z_pi)`` of 0-d tensors, where ``R_z``
            is ``rate(W)`` and ``R_z_pi`` is ``rate_for_mixture(W, Pi_Y)``.
        """
        if device is None:
            # Follow the input instead of assuming a 'cuda:0' GPU exists.
            device = X.device
        # Plain int rather than a 0-d tensor, so it can size the numpy array.
        num_classes = int(Y.max()) + 1

        W = X.T
        # .numpy() requires a CPU tensor without grad, so detach and move first.
        Pi = label_to_membership(Y.detach().cpu().numpy(), num_classes)
        Pi = torch.as_tensor(Pi, dtype=torch.float32, device=device)

        R_z_pi = self.rate_for_mixture(W, Pi, device)
        R_z = self.rate(W, device)
        return R_z - R_z_pi, R_z, R_z_pi


class RateDistortionContinuous(RateDistortion):
    """
    Rate distortion loss for deleting a continuous
    protected attribute with a kernel function.
    """

    def forward(self, X, X_raw, kernel, device=None, scale=1.0):
        """Compute the kernelized rate and a rate-matching penalty.

        Args:
            X: 2-D tensor of representations.
            X_raw: 2-D tensor whose rate ``rate(X_raw.T)`` is the reference.
            kernel: tensor multiplied elementwise with ``X @ X.T`` inside
                ``rate_for_continuous``.
            device: device for intermediate tensors; defaults to ``X.device``.
            scale: multiplier applied to the reference rate.

        Returns:
            Tuple ``(-R_z_K, eq_const)``. ``R_z_K`` is
            ``rate_for_continuous(X, kernel)``. ``eq_const`` is
            ``|rate(X.T) - scale * rate(X_raw.T)|``.
        """
        if device is None:
            # Follow the input instead of assuming a 'cuda:0' GPU exists.
            device = X.device
        const = self.rate(X_raw.T, device)
        R_z = self.rate(X.T, device)
        eq_const = torch.abs(R_z - scale * const)

        # NOTE(review): unlike the two rate() calls above, X is passed
        # untransposed, so the kernel is applied to the
        # (X.shape[0], X.shape[0]) matrix X @ X.T. This is presumably
        # intentional (a kernel between rows/samples) — confirm.
        R_z_K = self.rate_for_continuous(X, kernel, device)
        return -R_z_K, eq_const



