###########################################################################
#
# Much of this code was taken and edited from the following GitHub repo:
# https://github.com/kohpangwei/group_DRO/tree/master
#
###########################################################################
import torch
from scipy.spatial.distance import cdist, pdist, squareform
import numpy as np
from sklearn.gaussian_process.kernels import Kernel

# Module-wide default torch device: the first CUDA GPU when CUDA is available,
# otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')



class LossComputer:
    '''
    Used for calculating the GDRO loss. Much of this code was taken and edited from
    the following GitHub repo:
    https://github.com/kohpangwei/group_DRO/tree/master

    Besides returning the training loss for a batch, the object keeps running
    per-group statistics (loss, accuracy, sample counts, adversarial group
    weights) that can be read with ``get_stats`` / ``log_stats`` and cleared
    with ``reset_stats``.
    '''

    def __init__(self, criterion, is_robust, dataset_info, alpha=0.2, gamma=0.1, adj=None, min_var_weight=0, step_size=0.01, normalize_loss=False, btl=False):
        '''
        Args:
            criterion: callable ``criterion(yhat, y)`` whose result is treated
                as per-sample losses (it is flattened and averaged per group).
            is_robust: if True, use a group-weighted (robust) loss instead of
                the plain mean of the per-sample losses.
            dataset_info: mapping with keys ``'n_groups'`` (int) and
                ``'group_counts'`` (tensor with one entry per group).
            alpha: fraction used by the greedy (``btl``) weighting; must be
                truthy when ``is_robust`` is set.
            gamma: step of the exponential moving average of group losses.
            adj: optional numpy array of per-group adjustments.
            min_var_weight: mixing weight between the group fractions and the
                greedy weights in the ``btl`` variant.
            step_size: step size of the exponentiated update of ``adv_probs``.
            normalize_loss: normalize group losses to sum to one before the
                ``adv_probs`` update.
            btl: if True (and ``is_robust``), use the greedy weighting.
        '''
        self.criterion = criterion
        self.is_robust = is_robust
        self.gamma = gamma
        self.alpha = alpha
        self.min_var_weight = min_var_weight
        self.step_size = step_size
        self.normalize_loss = normalize_loss
        self.btl = btl

        # Use the current CUDA device when there is one (same placement as
        # ``.cuda()``), otherwise fall back to the CPU so the class stays usable
        # on machines without a GPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.n_groups = dataset_info['n_groups']
        self.group_counts = dataset_info['group_counts'].to(self.device)

        self.group_frac = self.group_counts/self.group_counts.sum()

        if adj is not None:
            self.adj = torch.from_numpy(adj).float().to(self.device)
        else:
            self.adj = torch.zeros(self.n_groups, device=self.device).float()

        if is_robust:
            assert alpha, 'alpha must be specified'

        # quantities maintained throughout training
        self.adv_probs = torch.ones(self.n_groups, device=self.device)/self.n_groups
        self.exp_avg_loss = torch.zeros(self.n_groups, device=self.device)
        self.exp_avg_initialized = torch.zeros(self.n_groups, device=self.device).byte()

        self.reset_stats()

    @staticmethod
    def _as_float(value):
        # Running averages start out as Python floats (see reset_stats) and only
        # become tensors once a batch has been processed; accept both.
        return value.item() if torch.is_tensor(value) else float(value)

    def group_str(self, group_idx):
        '''
        Default human-readable group label used by ``log_stats``. Assigning a
        callable to ``instance.group_str`` overrides it.
        '''
        return f'group {group_idx}'

    def loss(self, yhat, y, dataset, group_idx=None, is_training=False):
        '''
        Compute the loss for one batch and update all running statistics.

        Args:
            yhat: predictions; values >= 0.5 are counted as class 1 for accuracy.
            y: targets, compared element-wise with the thresholded predictions.
            dataset: unused.
            group_idx: tensor with the group id of every sample. It must live on
                ``self.device`` (NOTE(review): the ``None`` default is not
                usable - confirm callers always pass it).
            is_training: unused.

        Returns:
            The robust loss when ``is_robust`` is set, otherwise the mean of the
            per-sample losses.
        '''
        # compute per-sample and per-group losses
        per_sample_losses = self.criterion(yhat, y)
        group_loss, group_count = self.compute_group_avg(per_sample_losses, group_idx)
        predictions = torch.where(yhat >= 0.5, 1, 0)
        group_acc, group_count = self.compute_group_avg((predictions==y).float(), group_idx)

        # update historical losses
        self.update_exp_avg_loss(group_loss, group_count)

        # compute overall loss
        if self.is_robust and not self.btl:
            actual_loss, weights = self.compute_robust_loss(group_loss, group_count)
        elif self.is_robust and self.btl:
            actual_loss, weights = self.compute_robust_loss_btl(group_loss, group_count)
        else:
            actual_loss = per_sample_losses.mean()
            weights = None

        # update stats
        self.update_stats(actual_loss, group_loss, group_acc, group_count, weights)

        return actual_loss

    def compute_robust_loss(self, group_loss, group_count):
        '''
        Update ``adv_probs`` with an exponentiated step on the group losses and
        return ``(group_loss @ adv_probs, adv_probs)``.
        '''
        adjusted_loss = group_loss
        if torch.all(self.adj>0):
            # NOTE: this is an in-place add on ``group_loss`` itself, so the
            # adjustment is also part of the returned robust loss and of the
            # group losses logged afterwards (kept as-is from upstream).
            adjusted_loss += self.adj/torch.sqrt(self.group_counts)
        if self.normalize_loss:
            adjusted_loss = adjusted_loss/(adjusted_loss.sum())
        # The weight update must not be part of the autograd graph.
        self.adv_probs = self.adv_probs * torch.exp(self.step_size*adjusted_loss.detach())
        self.adv_probs = self.adv_probs/(self.adv_probs.sum())

        robust_loss = group_loss @ self.adv_probs
        return robust_loss, self.adv_probs

    def compute_robust_loss_btl(self, group_loss, group_count):
        '''
        Greedy robust loss that ranks groups by their adjusted exponential
        moving-average loss.
        '''
        adjusted_loss = self.exp_avg_loss + self.adj/torch.sqrt(self.group_counts)
        return self.compute_robust_loss_greedy(group_loss, adjusted_loss)

    def compute_robust_loss_greedy(self, group_loss, ref_loss):
        '''
        Weight groups in decreasing order of ``ref_loss``: groups whose
        cumulative fraction stays within ``alpha`` get ``frac/alpha``, the next
        group receives the remaining mass.

        Returns:
            ``(robust_loss, weights)`` with ``weights`` in the original group order.
        '''
        sorted_idx = ref_loss.sort(descending=True)[1]
        sorted_loss = group_loss[sorted_idx]
        sorted_frac = self.group_frac[sorted_idx]

        mask = torch.cumsum(sorted_frac, dim=0)<=self.alpha
        weights = mask.float() * sorted_frac /self.alpha
        last_idx = int(mask.sum())
        # When every group fits within alpha (e.g. alpha >= 1) there is no
        # "next" group; indexing past the end used to raise an IndexError.
        if last_idx < weights.numel():
            weights[last_idx] = 1 - weights.sum()
        weights = sorted_frac*self.min_var_weight + weights*(1-self.min_var_weight)

        robust_loss = sorted_loss @ weights

        # sort the weights back
        _, unsort_idx = sorted_idx.sort()
        unsorted_weights = weights[unsort_idx]
        return robust_loss, unsorted_weights

    def compute_group_avg(self, losses, group_idx):
        '''
        Return ``(group_mean, group_count)``: the mean of ``losses`` within each
        group (0 for groups absent from the batch) and the number of samples
        observed per group.
        '''
        # compute observed counts and mean loss for each group
        group_ids = torch.arange(self.n_groups, device=self.device).unsqueeze(1).long()
        group_map = (group_idx == group_ids).float()
        group_count = group_map.sum(1)
        group_denom = group_count + (group_count==0).float() # avoid nans
        group_loss = (group_map @ losses.view(-1))/group_denom
        return group_loss, group_count

    def update_exp_avg_loss(self, group_loss, group_count):
        '''
        Update the exponential moving average of the per-group losses. Groups
        seen for the first time take the current loss as their initial value.
        '''
        # Detach so the moving average does not chain the autograd graphs of
        # every batch together.
        group_loss = group_loss.detach()
        prev_weights = (1 - self.gamma*(group_count>0).float()) * (self.exp_avg_initialized>0).float()
        curr_weights = 1 - prev_weights
        self.exp_avg_loss = self.exp_avg_loss * prev_weights + group_loss*curr_weights
        self.exp_avg_initialized = (self.exp_avg_initialized>0) + (group_count>0)

    def reset_stats(self):
        '''Clear the running statistics (not ``adv_probs`` / ``exp_avg_loss``).'''
        self.processed_data_counts = torch.zeros(self.n_groups, device=self.device)
        self.update_data_counts = torch.zeros(self.n_groups, device=self.device)
        self.update_batch_counts = torch.zeros(self.n_groups, device=self.device)
        self.avg_group_loss = torch.zeros(self.n_groups, device=self.device)
        self.avg_group_acc = torch.zeros(self.n_groups, device=self.device)
        self.avg_per_sample_loss = 0.
        self.avg_actual_loss = 0.
        self.avg_acc = 0.
        self.batch_count = 0.

    def update_stats(self, actual_loss, group_loss, group_acc, group_count, weights=None):
        '''
        Fold the quantities of one batch into the running averages and counts.
        ``weights`` is required when ``is_robust`` is set.
        '''
        # Statistics are for reporting only; detaching keeps them from holding
        # on to the autograd graph of every processed batch.
        actual_loss = actual_loss.detach()
        group_loss = group_loss.detach()
        group_acc = group_acc.detach()

        # avg group loss
        denom = self.processed_data_counts + group_count
        denom += (denom==0).float()
        prev_weight = self.processed_data_counts/denom
        curr_weight = group_count/denom
        self.avg_group_loss = prev_weight*self.avg_group_loss + curr_weight*group_loss

        # avg group acc
        self.avg_group_acc = prev_weight*self.avg_group_acc + curr_weight*group_acc

        # batch-wise average actual loss
        denom = self.batch_count + 1
        self.avg_actual_loss = (self.batch_count/denom)*self.avg_actual_loss + (1/denom)*actual_loss

        # counts
        self.processed_data_counts += group_count
        if self.is_robust:
            self.update_data_counts += group_count*((weights>0).float())
            self.update_batch_counts += ((group_count*weights)>0).float()
        else:
            self.update_data_counts += group_count
            self.update_batch_counts += (group_count>0).float()
        self.batch_count+=1

        # avg per-sample quantities
        group_frac = self.processed_data_counts/(self.processed_data_counts.sum())
        self.avg_per_sample_loss = group_frac @ self.avg_group_loss
        self.avg_acc = group_frac @ self.avg_group_acc

    def get_model_stats(self, model, args, stats_dict):
        '''
        Add the squared parameter norm of ``model`` and the matching
        ``args.weight_decay / 2 * norm`` term to ``stats_dict`` and return it.
        '''
        model_norm_sq = 0.
        # Only the scalar value is reported, so no graph is needed.
        with torch.no_grad():
            for param in model.parameters():
                model_norm_sq += torch.norm(param) ** 2
        norm_value = self._as_float(model_norm_sq)
        stats_dict['model_norm_sq'] = norm_value
        stats_dict['reg_loss'] = args.weight_decay / 2 * norm_value
        return stats_dict

    def get_stats(self, model=None, args=None):
        '''
        Return the running statistics as a flat dict of Python floats. When
        ``model`` is given, ``args`` must be given too and model statistics are
        included.
        '''
        stats_dict = {}
        for idx in range(self.n_groups):
            stats_dict[f'avg_loss_group:{idx}'] = self.avg_group_loss[idx].item()
            stats_dict[f'exp_avg_loss_group:{idx}'] = self.exp_avg_loss[idx].item()
            stats_dict[f'avg_acc_group:{idx}'] = self.avg_group_acc[idx].item()
            stats_dict[f'processed_data_count_group:{idx}'] = self.processed_data_counts[idx].item()
            stats_dict[f'update_data_count_group:{idx}'] = self.update_data_counts[idx].item()
            stats_dict[f'update_batch_count_group:{idx}'] = self.update_batch_counts[idx].item()

        stats_dict['avg_actual_loss'] = self._as_float(self.avg_actual_loss)
        stats_dict['avg_per_sample_loss'] = self._as_float(self.avg_per_sample_loss)
        stats_dict['avg_acc'] = self._as_float(self.avg_acc)

        # Model stats
        if model is not None:
            assert args is not None
            stats_dict = self.get_model_stats(model, args, stats_dict)

        return stats_dict

    def log_stats(self, logger, is_training):
        '''
        Write the running statistics to ``logger`` (an object with ``write`` and
        ``flush``); does nothing when ``logger`` is None. ``is_training`` is unused.
        '''
        if logger is None:
            return

        logger.write(f'Average incurred loss: {self._as_float(self.avg_per_sample_loss):.3f}  \n')
        logger.write(f'Average sample loss: {self._as_float(self.avg_actual_loss):.3f}  \n')
        logger.write(f'Average acc: {self._as_float(self.avg_acc):.3f}  \n')
        for group_idx in range(self.n_groups):
            logger.write(
                f'  {self.group_str(group_idx)}  '
                f'[n = {int(self.processed_data_counts[group_idx])}]:\t'
                f'loss = {self.avg_group_loss[group_idx]:.3f}  '
                f'exp loss = {self.exp_avg_loss[group_idx]:.3f}  '
                f'adjusted loss = {self.exp_avg_loss[group_idx] + self.adj[group_idx]/torch.sqrt(self.group_counts)[group_idx]:.3f}  '
                f'adv prob = {self.adv_probs[group_idx]:3f}   '
                f'acc = {self.avg_group_acc[group_idx]:.3f}\n')
        logger.flush()



class MMD_Loss(torch.nn.Module):
    '''
    Used for calculating the MMD loss.

    The loss compares, separately for samples with y == 1 and with y == 0, the
    log-predictions of samples whose auxiliary label is non-zero against those
    whose auxiliary label is zero, using an RBF-kernel MMD estimate.
    '''

    def __init__(self, mmd_sigma):
        '''
        Args:
            mmd_sigma: bandwidth of the RBF kernel (must be non-zero).
        '''
        super(MMD_Loss, self).__init__()
        self.mmd_sigma = mmd_sigma


    def rbf_kernel(self, x, y):
        '''
        Return the RBF kernel matrix ``exp(-||x_i - y_j||^2 / (2 * sigma^2))``
        for the rows of the 2-D tensors ``x`` and ``y``.
        '''
        x_norm = (x ** 2).sum(1).view(-1, 1)
        y_norm = (y ** 2).sum(1).view(1, -1)
        # Clamp: rounding can make the expanded squared distance slightly negative.
        sq_dist = torch.clamp(x_norm + y_norm - 2.0 * torch.matmul(x, torch.transpose(y, 0, 1)), min=0)

        gamma = - 0.5 / (self.mmd_sigma ** 2)
        return torch.exp(gamma * sq_dist)


    def mmd(self, x, y):
        '''
        Return ``mean(k(x, x)) + mean(k(y, y)) - 2 * mean(k(x, y))`` clipped at
        zero. If either sample set is empty the result is a constant 0, since
        averaging an empty kernel matrix would otherwise yield NaN.
        '''
        if x.shape[0] == 0 or y.shape[0] == 0:
            return x.new_zeros(())

        xx_kernel = self.rbf_kernel(x, x)
        yy_kernel = self.rbf_kernel(y, y)
        xy_kernel = self.rbf_kernel(x, y)

        mmd_loss = torch.mean(xx_kernel) + torch.mean(yy_kernel) - 2 * torch.mean(xy_kernel)
        # The estimate can dip below zero numerically; clip it.
        return torch.clamp(mmd_loss, min=0)


    def forward(self, yhat, y, aux_labels):
        '''
        Args:
            yhat: predicted values; their logarithm is taken, so they are
                expected to be positive (presumably probabilities - confirm).
            y: binary targets, broadcast against ``aux_labels`` after it is
                given a trailing dimension (assumes ``yhat`` and ``y`` are
                ``(batch, 1)`` - TODO confirm against callers).
            aux_labels: binary auxiliary labels, one per sample.

        Returns:
            Sum of the MMD terms for the y == 1 and the y == 0 samples.
        '''
        # Split data into different groups
        aux_labels = torch.unsqueeze(aux_labels, dim=1)

        # Group y == 1
        y_1_aux_1_indices = aux_labels * y != 0
        y_1_aux_0_indices = (1 - aux_labels) * y != 0
        y_1_aux_1 = yhat[y_1_aux_1_indices]
        y_1_aux_0 = yhat[y_1_aux_0_indices]

        # Group y == 0
        y_0_aux_1_indices = aux_labels * (1 - y) != 0
        y_0_aux_0_indices = (1 - aux_labels) * (1 - y) != 0
        y_0_aux_1 = yhat[y_0_aux_1_indices]
        y_0_aux_0 = yhat[y_0_aux_0_indices]

        # Expand dims and convert to log probs
        y_1_aux_1 = torch.log(torch.unsqueeze(y_1_aux_1, 1))
        y_1_aux_0 = torch.log(torch.unsqueeze(y_1_aux_0, 1))
        y_0_aux_1 = torch.log(torch.unsqueeze(y_0_aux_1, 1))
        y_0_aux_0 = torch.log(torch.unsqueeze(y_0_aux_0, 1))

        # Get mmd loss
        mmd_loss_1 = self.mmd(y_1_aux_1, y_1_aux_0)
        mmd_loss_0 = self.mmd(y_0_aux_1, y_0_aux_0)
        mmd_loss = mmd_loss_1 + mmd_loss_0

        return mmd_loss



class KCIT_Loss(torch.nn.Module):
    '''
    Kernel-based dependence statistic between ``x`` and ``yhat`` after
    regressing out ``z`` (presumably a Kernel Conditional Independence Test
    statistic, judging by the class name - confirm).
    '''

    def __init__(self, epsilon=1e-3, sigma=0):
        '''
        Args:
            epsilon: ridge regularisation used when inverting the centered
                kernel matrix of ``z``.
            sigma: RBF bandwidth; ``0`` selects the bandwidth per kernel with
                the median heuristic.
        '''
        super(KCIT_Loss, self).__init__()
        self.epsilon = epsilon
        self.sigma = sigma

    # kernel width using median trick
    def set_width_median(self, x, y=None):
        '''
        Return the (negative) exponent coefficient ``-0.5 / width**2`` where
        ``width = sqrt(2) * median`` of the positive pairwise Euclidean
        distances within ``x`` (or between ``x`` and ``y`` when ``y`` is given).
        '''
        x_np = x.detach().cpu().numpy()
        if y is None:
            dists = pdist(x_np, 'euclidean')
        else:
            dists = cdist(x_np, y.detach().cpu().numpy(), 'euclidean')

        positive = dists[dists > 0]
        # With no positive distance (all points coincide) the median would be
        # NaN. Every kernel entry is exp(gamma * 0) = 1 in that case, so any
        # finite width gives the same kernel; use 1.
        median_dist = np.median(positive) if positive.size else 1.0

        width = np.sqrt(2.) * median_dist
        return float(- 0.5 / (width ** 2))


    def rbf_kernel(self, x, y=None):
        '''
        Return the RBF kernel matrix between the rows of ``x`` and ``y`` (or of
        ``x`` with itself when ``y`` is None). Inputs are 2-D tensors.
        '''
        if y is not None:
            x_norm = (x ** 2).sum(1).view(-1, 1)
            y_norm = (y ** 2).sum(1).view(1, -1)
            sq_dist = x_norm + y_norm - 2.0 * x @ y.T
        else:
            norm = torch.sum(x**2, axis=-1)
            sq_dist = norm[:, None] + norm[None, :] - 2 * torch.matmul(x, torch.transpose(x, dim0=1, dim1=0))
        # Rounding can push the expanded squared distance slightly below zero.
        sq_dist = torch.clamp(sq_dist, min=0)

        if self.sigma == 0:
            gamma = self.set_width_median(x, y)
        else:
            gamma = -0.5 / (self.sigma ** 2)

        return torch.exp(gamma * sq_dist)


    def center_kernel_matrix(self, kernel_matrix):
        '''Return ``H K H`` with ``H = I - (1/n) * ones``.'''
        n = kernel_matrix.shape[0]
        # Build H on the same device/dtype as the kernel instead of forcing CUDA.
        opts = {'dtype': kernel_matrix.dtype, 'device': kernel_matrix.device}
        h = torch.eye(n, **opts) - (1/n) * torch.ones(size=(n, n), **opts)
        centered_kernel_matrix = h.matmul(kernel_matrix).matmul(h)

        return centered_kernel_matrix


    def forward(self, yhat, x, z):
        '''
        Args:
            yhat: predictions, used as-is (assumed 2-D ``(batch, d)`` - TODO confirm).
            x, z: tensors whose first dimension is the batch; the remaining
                dimensions are flattened.

        Returns:
            ``trace(Kx|z @ Ky|z) / n`` where ``K.|z = Rz @ K_centered @ Rz`` and
            ``Rz = epsilon * (Kz_centered + epsilon * I)^-1``.
        '''
        n = x.shape[0]

        # Reshape tensors so that the shape is (batch_size, <feature_size>)
        x = torch.reshape(x, (x.shape[0], -1))
        z = torch.reshape(z, (z.shape[0], -1))

        # Get the kernel matrices
        kx = self.rbf_kernel(x)
        ky = self.rbf_kernel(yhat)
        kz = self.rbf_kernel(z)

        # Center the kernel matrices so that the mean of each column is 0
        ckx = self.center_kernel_matrix(kx)
        cky = self.center_kernel_matrix(ky)
        ckz = self.center_kernel_matrix(kz)

        identity = torch.eye(n, dtype=ckz.dtype, device=ckz.device)
        rz = self.epsilon * torch.linalg.inv(ckz + self.epsilon * identity)
        kxz = rz.matmul(ckx).matmul(rz)
        kyz = rz.matmul(cky).matmul(rz)

        test_statistic = torch.trace(kxz.matmul(kyz)) / n
        return test_statistic



class IRM:
    """
    Code originally from: https://github.com/YyzHarry/SubpopBench/blob/main/subpopbench/learning/algorithms.py
    Invariant Risk Minimization
    """
    def __init__(self, loss_function, irm_penalty):
        self.loss_function = loss_function
        self.irm_penalty = irm_penalty

    def _irm_penalty(self, yhat, y):
        scale = torch.tensor(1.).to(device).requires_grad_()
        loss_1 = self.loss_function(yhat[::2] * scale, y[::2])
        loss_2 = self.loss_function(yhat[1::2] * scale, y[1::2])
        grad_1 = torch.autograd.grad(loss_1, [scale], create_graph=True)[0]
        grad_2 = torch.autograd.grad(loss_2, [scale], create_graph=True)[0]
        result = torch.sum(grad_1 * grad_2)
        return result

    def loss(self, yhat, y, aux_labels):
        avg_loss = 0
        penalty = 0
        for env in range(2):
            # Get indices for the enviornments
            if env == 0:
                env_ids = y == aux_labels
            else:
                env_ids = y != aux_labels
            # If no samples in environment, set loss and penalty to 0
            if len(yhat[env_ids]) == 0:
                avg_loss += 0
                penalty += 0
                continue
            # Calculate loss and penalty
            avg_loss += self.loss_function(yhat[env_ids], y[env_ids])
            penalty += self._irm_penalty(yhat[env_ids], y[env_ids])
        avg_loss /= 2
        penalty /= 2
        loss_value = avg_loss + (self.irm_penalty * penalty)
        return loss_value