# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
from torch.autograd import Variable

import copy
import numpy as np
from collections import defaultdict

from domainbed import networks
from domainbed.lib.misc import random_pairs_of_minibatches, ParamDict

# Names of the algorithms this module advertises.  get_algorithm_class()
# resolves a name by looking it up in this module's globals(), so every entry
# has to match a module-level name exactly (the lookup is case-sensitive).
ALGORITHMS = [
    'ERM',
    'Fish',
    'IRM',
    'GroupDRO',
    'Mixup',
    'MLDG',
    'CORAL',
    'MMD',
    'DANN',
    'CDANN',
    'SagNet',
    'VREx',
    'RSC',
    'CMMD',
    'CACM_ACause',
    'CACM_AInd',
    'CACM_ACauseUAInd'
]

def get_algorithm_class(algorithm_name):
    """Look up ``algorithm_name`` among this module's global names.

    Returns whatever object is bound to that name at module level and raises
    NotImplementedError when no such global exists.
    """
    namespace = globals()
    if algorithm_name in namespace:
        return namespace[algorithm_name]
    raise NotImplementedError("Algorithm not found: {}".format(algorithm_name))


def _get_optimizer(optimizer_name, params, lr, weight_decay, betas=(0.9, 0.999)):
    """Build the torch optimizer selected by ``optimizer_name``.

    'adam' yields torch.optim.Adam (honouring ``betas``); 'sgdm' yields
    torch.optim.SGD with momentum 0.9 (``betas`` is ignored).  Any other name
    raises NotImplementedError.
    """
    if optimizer_name not in ('adam', 'sgdm'):
        raise NotImplementedError

    if optimizer_name == 'sgdm':
        return torch.optim.SGD(
            params, lr=lr, momentum=0.9, weight_decay=weight_decay)

    # 'adam' is what callers in this file fall back to when no optimizer
    # hyper-parameter is given.
    return torch.optim.Adam(
        params, lr=lr, weight_decay=weight_decay, betas=betas)


# def _get_lr_scheduler(optimizer, hparams):
#     if hparams.get('lr_scheduler', None) == 'reduce_on_plateau':
#         return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience, verbose=True)
#     else:  # domainbed uses constant learning rate by default
#         return torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 1)


class Algorithm(torch.nn.Module):
    """
    Base class of every domain generalization algorithm in this module.

    Concrete algorithms are expected to override:
    - update()
    - predict()
    """
    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super().__init__()
        # Only the hyper-parameters are stored here; the remaining arguments
        # are consumed by subclasses.
        self.hparams = hparams

    def update(self, minibatches, unlabeled=None):
        """
        Run a single optimisation step on one minibatch per environment
        (a list of (x, y) tuples).

        ``unlabeled`` optionally carries unlabeled minibatches from the test
        domains, for the domain_adaptation task.
        """
        raise NotImplementedError

    def predict(self, x):
        """Return the model output for ``x``; must be overridden."""
        raise NotImplementedError

class ERM(Algorithm):
    """
    Empirical Risk Minimization (ERM)

    Pools the minibatches of all environments and takes one cross-entropy
    gradient step on a featurizer followed by a classifier.
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(ERM, self).__init__(input_shape, num_classes, num_domains,
                                  hparams)
        self.featurizer = networks.Featurizer(input_shape, self.hparams)
        self.classifier = networks.Classifier(
            self.featurizer.n_outputs,
            num_classes,
            self.hparams['nonlinear_classifier'])

        self.network = nn.Sequential(self.featurizer, self.classifier)
        self.optimizer = _get_optimizer(
            self.hparams.get('optimizer', 'adam'),
            self.network.parameters(),
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay']
        )

    def update(self, minibatches, unlabeled=None):
        """Take one optimiser step on the pooled cross-entropy loss.

        Each minibatch is a tuple whose first two entries are the inputs and
        the labels; any further entries (e.g. attribute labels) are ignored,
        so tuples of any length >= 2 are accepted.

        Returns a dict with the scalar 'loss'.
        """
        # Index by position instead of unpacking so the same code serves
        # 2-, 3- and 4-field minibatches alike.
        all_x = torch.cat([mb[0] for mb in minibatches])
        all_y = torch.cat([mb[1] for mb in minibatches])

        loss = F.cross_entropy(self.predict(all_x), all_y)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {'loss': loss.item()}

    def predict(self, x):
        """Return the classifier logits for ``x``."""
        return self.network(x)

class AbstractDANN(Algorithm):
    """Domain-Adversarial Neural Networks (abstract class)

    A featurizer/classifier pair (the "generator") is trained against a
    discriminator that predicts the environment index from the features.
    With ``conditional`` the class embedding of the label is added to the
    discriminator input; with ``class_balance`` the discriminator loss is
    weighted by the inverse in-batch class frequency.
    """

    def __init__(self, input_shape, num_classes, num_domains,
                 hparams, conditional, class_balance):

        super(AbstractDANN, self).__init__(input_shape, num_classes, num_domains,
                                  hparams)

        # Step counter; stored as a buffer so it is part of the state dict.
        self.register_buffer('update_count', torch.tensor([0]))
        self.conditional = conditional
        self.class_balance = class_balance

        # Algorithms
        self.featurizer = networks.Featurizer(input_shape, self.hparams)
        self.classifier = networks.Classifier(
            self.featurizer.n_outputs,
            num_classes,
            self.hparams['nonlinear_classifier'])
        self.discriminator = networks.MLP(self.featurizer.n_outputs,
            num_domains, self.hparams)
        self.class_embeddings = nn.Embedding(num_classes,
            self.featurizer.n_outputs)

        # Optimizers
        self.disc_opt = _get_optimizer(
            self.hparams.get('optimizer', 'adam'),
            (list(self.discriminator.parameters()) +
                list(self.class_embeddings.parameters())),
            lr=self.hparams["lr_d"],
            weight_decay=self.hparams['weight_decay_d'],
            betas=(self.hparams['beta1'], 0.9))

        self.gen_opt = _get_optimizer(
            self.hparams.get('optimizer', 'adam'),
            (list(self.featurizer.parameters()) +
                list(self.classifier.parameters())),
            lr=self.hparams["lr_g"],
            weight_decay=self.hparams['weight_decay_g'],
            betas=(self.hparams['beta1'], 0.9))

    def update(self, minibatches, unlabeled=None):
        """Run either a discriminator step or a generator step.

        Each minibatch is a tuple whose first two entries are inputs and
        labels; extra entries are ignored.  The position of a minibatch in
        ``minibatches`` is used as its domain label.

        Returns {'disc_loss': ...} on discriminator steps and
        {'gen_loss': ...} on generator steps.
        """
        self.update_count += 1

        # Positional access keeps one code path for 2-, 3- and 4-field
        # minibatches.
        all_x = torch.cat([mb[0] for mb in minibatches])
        all_y = torch.cat([mb[1] for mb in minibatches])
        all_z = self.featurizer(all_x)
        if self.conditional:
            disc_input = all_z + self.class_embeddings(all_y)
        else:
            disc_input = all_z
        disc_out = self.discriminator(disc_input)
        # Create the domain labels on the inputs' own device rather than on a
        # generic "cuda"/"cpu" string, which would pick the current default
        # GPU instead of the one the data lives on.
        disc_labels = torch.cat([
            torch.full((mb[0].shape[0], ), i, dtype=torch.int64,
                       device=all_x.device)
            for i, mb in enumerate(minibatches)
        ])

        if self.class_balance:
            y_counts = F.one_hot(all_y).sum(dim=0)
            weights = 1. / (y_counts[all_y] * y_counts.shape[0]).float()
            disc_loss = F.cross_entropy(disc_out, disc_labels, reduction='none')
            disc_loss = (weights * disc_loss).sum()
        else:
            disc_loss = F.cross_entropy(disc_out, disc_labels)

        # Gradient penalty on the discriminator input.
        # NOTE(review): `disc_softmax[:, disc_labels]` selects, for every row,
        # all columns listed in disc_labels (not only the row's own label) --
        # kept as is; confirm this is intended.
        disc_softmax = F.softmax(disc_out, dim=1)
        input_grad = autograd.grad(disc_softmax[:, disc_labels].sum(),
            [disc_input], create_graph=True)[0]
        grad_penalty = (input_grad**2).sum(dim=1).mean(dim=0)
        disc_loss += self.hparams['grad_penalty'] * grad_penalty

        # Out of every (1 + d_steps_per_g) updates, d_steps_per_g train the
        # discriminator and the remaining one trains the generator.
        d_steps_per_g = self.hparams['d_steps_per_g_step']
        if (self.update_count.item() % (1+d_steps_per_g) < d_steps_per_g):

            self.disc_opt.zero_grad()
            disc_loss.backward()
            self.disc_opt.step()
            return {'disc_loss': disc_loss.item()}
        else:
            all_preds = self.classifier(all_z)
            classifier_loss = F.cross_entropy(all_preds, all_y)
            # The generator minimises the classification loss while
            # maximising the discriminator loss.
            gen_loss = (classifier_loss +
                        (self.hparams['lambda'] * -disc_loss))
            self.disc_opt.zero_grad()
            self.gen_opt.zero_grad()
            gen_loss.backward()
            self.gen_opt.step()
            return {'gen_loss': gen_loss.item()}

    def predict(self, x):
        """Return the classifier logits for ``x``."""
        return self.classifier(self.featurizer(x))

class DANN(AbstractDANN):
    """DANN without class conditioning and without class balancing"""
    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super().__init__(
            input_shape, num_classes, num_domains, hparams,
            conditional=False,
            class_balance=False)


class CDANN(AbstractDANN):
    """DANN with class-conditioned discriminator input and class balancing"""
    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super().__init__(
            input_shape, num_classes, num_domains, hparams,
            conditional=True,
            class_balance=True)


class IRM(ERM):
    """Invariant Risk Minimization"""

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(IRM, self).__init__(input_shape, num_classes, num_domains,
                                  hparams)
        # Step counter; stored as a buffer so it is part of the state dict.
        self.register_buffer('update_count', torch.tensor([0]))

    @staticmethod
    def _irm_penalty(logits, y):
        """Return the IRM penalty for one environment.

        The batch is split into its even and odd rows; the penalty is the
        product of the two halves' loss gradients with respect to a dummy
        scale of 1 applied to the logits.
        """
        # Create the dummy scale directly on the logits' device.
        scale = torch.tensor(1., device=logits.device, requires_grad=True)
        loss_1 = F.cross_entropy(logits[::2] * scale, y[::2])
        loss_2 = F.cross_entropy(logits[1::2] * scale, y[1::2])
        grad_1 = autograd.grad(loss_1, [scale], create_graph=True)[0]
        grad_2 = autograd.grad(loss_2, [scale], create_graph=True)[0]
        result = torch.sum(grad_1 * grad_2)
        return result

    def update(self, minibatches, unlabeled=None):
        """Take one step on mean NLL + penalty_weight * mean IRM penalty.

        Each minibatch is a tuple whose first two entries are inputs and
        labels; extra entries are ignored.  The penalty weight is 1.0 until
        ``irm_penalty_anneal_iters`` updates have run and ``irm_lambda``
        afterwards.

        Returns a dict with 'loss', 'nll' and 'penalty'.
        """
        penalty_weight = (self.hparams['irm_lambda'] if self.update_count
                          >= self.hparams['irm_penalty_anneal_iters'] else
                          1.0)
        nll = 0.
        penalty = 0.

        # One forward pass over all environments, then slice the logits back
        # into per-environment chunks.
        all_x = torch.cat([mb[0] for mb in minibatches])
        all_logits = self.network(all_x)
        all_logits_idx = 0
        for mb in minibatches:
            x, y = mb[0], mb[1]
            logits = all_logits[all_logits_idx:all_logits_idx + x.shape[0]]
            all_logits_idx += x.shape[0]
            nll += F.cross_entropy(logits, y)
            penalty += self._irm_penalty(logits, y)
        nll /= len(minibatches)
        penalty /= len(minibatches)
        loss = nll + (penalty_weight * penalty)

        if self.update_count == self.hparams['irm_penalty_anneal_iters']:
            # Reset Adam, because it doesn't like the sharp jump in gradient
            # magnitudes that happens at this step.
            self.optimizer = _get_optimizer(
                self.hparams.get('optimizer', 'adam'),
                self.network.parameters(),
                lr=self.hparams["lr"],
                weight_decay=self.hparams['weight_decay'])

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.update_count += 1
        return {'loss': loss.item(), 'nll': nll.item(), 'penalty': penalty.item()}


class Mixup(ERM):
    """
    ERM on convex combinations of minibatches drawn from different domains
    https://arxiv.org/pdf/2001.00677.pdf
    https://arxiv.org/pdf/1912.01805.pdf
    """
    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super().__init__(input_shape, num_classes, num_domains, hparams)

    def update(self, minibatches, unlabeled=None):
        """Take one step on the mixup loss averaged over minibatch pairs.

        For every pair a coefficient is sampled from
        Beta(mixup_alpha, mixup_alpha); inputs are blended with it and the
        two cross-entropy terms are weighted accordingly.
        """
        total = 0

        for (xi, yi), (xj, yj) in random_pairs_of_minibatches(minibatches):
            alpha = self.hparams["mixup_alpha"]
            lam = np.random.beta(alpha, alpha)

            blended_logits = self.predict(lam * xi + (1 - lam) * xj)
            loss_i = F.cross_entropy(blended_logits, yi)
            loss_j = F.cross_entropy(blended_logits, yj)

            total += lam * loss_i
            total += (1 - lam) * loss_j

        total /= len(minibatches)

        self.optimizer.zero_grad()
        total.backward()
        self.optimizer.step()

        return {'loss': total.item()}


class GroupDRO(ERM):
    """
    Robust ERM minimizes the error at the worst minibatch
    Algorithm 1 from [https://arxiv.org/pdf/1911.08731.pdf]
    """
    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(GroupDRO, self).__init__(input_shape, num_classes, num_domains,
                                        hparams)
        # Per-environment weights; starts empty and is sized on the first
        # update, once the number of minibatches is known.
        self.register_buffer("q", torch.Tensor())

    def update(self, minibatches, unlabeled=None):
        """Take one step on the q-weighted sum of per-environment losses.

        Each minibatch is a tuple whose first two entries are inputs and
        labels; extra entries are ignored, so the same minibatch layouts that
        ERM accepts work here too.  Every weight q[m] is multiplied by
        exp(groupdro_eta * loss_m) and the weights are then renormalised.

        Returns a dict with the scalar 'loss'.
        """
        device = minibatches[0][0].device

        if not len(self.q):
            self.q = torch.ones(len(minibatches)).to(device)

        losses = torch.zeros(len(minibatches)).to(device)

        for m, minibatch in enumerate(minibatches):
            # Positional access instead of a fixed 4-way unpacking, which
            # failed for (x, y) and (x, y, a) minibatches.
            x, y = minibatch[0], minibatch[1]
            losses[m] = F.cross_entropy(self.predict(x), y)
            # `.data` keeps the weight update out of the autograd graph.
            self.q[m] *= (self.hparams["groupdro_eta"] * losses[m].data).exp()

        self.q /= self.q.sum()

        loss = torch.dot(losses, self.q)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {'loss': loss.item()}


class MLDG(ERM):
    """
    Model-Agnostic Meta-Learning
    Algorithm 1 / Equation (3) from: https://arxiv.org/pdf/1710.03463.pdf
    Related: https://arxiv.org/pdf/1703.03400.pdf
    Related: https://arxiv.org/pdf/1910.13580.pdf
    """
    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super().__init__(input_shape, num_classes, num_domains, hparams)

    def update(self, minibatches, unlabeled=None):
        """
        First-order meta-update over random pairs (i, j) of minibatches.

        For each pair:
            * Li = Loss(xi, yi, params),           Gi = Grad(Li, params)
            * Lj = Loss(xj, yj, params after one optimizer step on Li),
              Gj = Grad(Lj, stepped params)

        The gradient handed to the outer optimizer is the average over pairs
        of Gi + beta * Gj.  Second derivatives are not computed.

        Returns a dict whose 'loss' is the averaged Li + beta * Lj (a plain
        float, for reporting only).
        """
        n_envs = len(minibatches)
        reported = 0

        # Make sure every parameter owns a gradient tensor to accumulate into.
        self.optimizer.zero_grad()
        for param in self.network.parameters():
            if param.grad is None:
                param.grad = torch.zeros_like(param)

        for (xi, yi), (xj, yj) in random_pairs_of_minibatches(minibatches):
            # Work on a throw-away copy of the network for the inner step.
            clone = copy.deepcopy(self.network)
            clone_opt = _get_optimizer(
                self.hparams.get('optimizer', 'adam'),
                clone.parameters(),
                lr=self.hparams["lr"],
                weight_decay=self.hparams['weight_decay']
            )

            loss_i = F.cross_entropy(clone(xi), yi)

            clone_opt.zero_grad()
            loss_i.backward()
            clone_opt.step()

            # Accumulate Gi (still stored on the copy) into the real network.
            for dst, src in zip(self.network.parameters(),
                                clone.parameters()):
                if src.grad is None:
                    continue
                dst.grad.data.add_(src.grad.data / n_envs)

            reported += loss_i.item()

            # Gj is evaluated on the copy after its inner step.
            loss_j = F.cross_entropy(clone(xj), yj)
            grads_j = autograd.grad(loss_j, clone.parameters(),
                allow_unused=True)

            reported += (self.hparams['mldg_beta'] * loss_j).item()

            for dst, grad_j in zip(self.network.parameters(), grads_j):
                if grad_j is None:
                    continue
                dst.grad.data.add_(
                    self.hparams['mldg_beta'] * grad_j.data / n_envs)

        reported /= len(minibatches)

        # The real network now holds the averaged Gi + beta * Gj.
        self.optimizer.step()

        return {'loss': reported}

    # This commented "update" method back-propagates through the gradients of
    # the inner update, as suggested in the original MAML paper.  However, this
    # is twice as expensive as the uncommented "update" method, which does not
    # compute second-order derivatives, implementing the First-Order MAML
    # method (FOMAML) described in the original MAML paper.

    # def update(self, minibatches, unlabeled=None):
    #     objective = 0
    #     beta = self.hparams["beta"]
    #     inner_iterations = self.hparams["inner_iterations"]

    #     self.optimizer.zero_grad()

    #     with higher.innerloop_ctx(self.network, self.optimizer,
    #         copy_initial_weights=False) as (inner_network, inner_optimizer):

    #         for (xi, yi), (xj, yj) in random_pairs_of_minibatches(minibatches):
    #             for inner_iteration in range(inner_iterations):
    #                 li = F.cross_entropy(inner_network(xi), yi)
    #                 inner_optimizer.step(li)
    #
    #             objective += F.cross_entropy(self.network(xi), yi)
    #             objective += beta * F.cross_entropy(inner_network(xj), yj)

    #         objective /= len(minibatches)
    #         objective.backward()
    #
    #     self.optimizer.step()
    #
    #     return objective


class AbstractMMD(ERM):
    """
    Perform ERM while matching the pair-wise domain feature distributions
    using MMD (abstract class)

    ``gaussian`` selects the Gaussian-kernel discrepancy; otherwise the
    mean/covariance difference is used.  Exactly one regularisation mode is
    applied per update, chosen in this order:

    - ``causal``: within each environment and class, match the logits of
      samples that carry different attribute labels (third minibatch field).
    - ``ind``: match the logits of every (environment, label) group, where
      the label is the fourth minibatch field.
    - ``causalUind``: both of the above, per environment, with separate
      weights ``mmd_gamma`` and ``mmd_gamma_ind``.
    - ``conditional``: for each class, match the features across
      environments.
    - none of the flags: match whole-environment features pair-wise.
    """
    def __init__(self, input_shape, num_classes, num_domains, hparams, gaussian,
                 conditional=False, causal=False, ind=False, causalUind=False):
        super(AbstractMMD, self).__init__(input_shape, num_classes, num_domains,
                                  hparams)
        self.kernel_type = "gaussian" if gaussian else "mean_cov"

        self.conditional = conditional
        self.causal = causal
        self.ind = ind
        self.causalUind = causalUind

    def my_cdist(self, x1, x2):
        """Return the pairwise squared Euclidean distances between rows.

        Computed as |x1|^2 - 2 * x1 @ x2^T + |x2|^2 and clamped from below at
        1e-30.
        """
        x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
        x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
        res = torch.addmm(x2_norm.transpose(-2, -1),
                          x1,
                          x2.transpose(-2, -1), alpha=-2).add_(x1_norm)
        return res.clamp_min_(1e-30)

    def gaussian_kernel_main(self, x, y, gamma=(0.001, 0.01, 0.1, 1, 10, 100,
                                                1000)):
        """Return the sum over ``gamma`` of exp(-g * squared distance).

        ``gamma`` may be any iterable of bandwidths; the default is an
        immutable tuple so it cannot be shared and mutated across calls.
        """
        D = self.my_cdist(x, y)
        K = torch.zeros_like(D)

        for g in gamma:
            K.add_(torch.exp(D.mul(-g)))

        return K

    def gaussian_kernel(self, x, y):
        """Return exp(-sigma * squared distance) with sigma = hparams['sigma']."""
        D = self.my_cdist(x, y)
        return torch.exp(D.mul(-self.hparams['sigma']))

    def mmd(self, x, y):
        """Return the discrepancy between the sample sets ``x`` and ``y``.

        With the Gaussian kernel this is mean(Kxx) + mean(Kyy) - 2 * mean(Kxy);
        otherwise it is the mean squared difference of the means plus the
        mean squared difference of the covariance matrices.
        """
        if self.kernel_type == "gaussian":
            Kxx = self.gaussian_kernel(x, x).mean()
            Kyy = self.gaussian_kernel(y, y).mean()
            Kxy = self.gaussian_kernel(x, y).mean()
            return Kxx + Kyy - 2 * Kxy
        else:
            mean_x = x.mean(0, keepdim=True)
            mean_y = y.mean(0, keepdim=True)
            cent_x = x - mean_x
            cent_y = y - mean_y
            cova_x = (cent_x.t() @ cent_x) / (len(x) - 1)
            cova_y = (cent_y.t() @ cent_y) / (len(y) - 1)

            mean_diff = (mean_x - mean_y).pow(2).mean()
            cova_diff = (cova_x - cova_y).pow(2).mean()

            return mean_diff + cova_diff

    @staticmethod
    def _require_fields(minibatches, n_fields, mode):
        """Raise ValueError unless each minibatch has at least ``n_fields`` entries."""
        found = len(minibatches[0])
        if found < n_fields:
            raise ValueError(
                "{} mode needs minibatches with at least {} fields, "
                "got {}".format(mode, n_fields, found))

    @staticmethod
    def _indices_where(values, value):
        """Return the positions (a list of ints) at which ``values == value``."""
        mask = (values == value).reshape(-1)
        return torch.nonzero(mask).flatten().tolist()

    def _forward_all(self, minibatches):
        """Return per-environment (features, logits, targets) lists."""
        features = [self.featurizer(mb[0]) for mb in minibatches]
        classifs = [self.classifier(fi) for fi in features]
        targets = [mb[1] for mb in minibatches]
        return features, classifs, targets

    def _add_pairwise_mmd(self, total, groups):
        """Add the discrepancy of every unordered pair in ``groups`` to ``total``."""
        for a in range(len(groups)):
            for b in range(a + 1, len(groups)):
                total += self.mmd(groups[a], groups[b])
        return total

    def _add_acause_terms(self, objective, penalty, logits, targets, attrs):
        """Add one environment's class-wise loss and attribute penalty.

        For every class present in ``targets`` the cross-entropy of that
        class's samples is added to ``objective``; the samples are then
        grouped by attribute label and the pairwise discrepancy between the
        groups' logits is added to ``penalty``.  Returns the updated
        (objective, penalty).
        """
        for label in torch.unique(targets):
            label_idx = self._indices_where(targets, label)
            objective += F.cross_entropy(logits[label_idx], targets[label_idx])

            attrs_in_class = attrs[label_idx]
            groups = []
            for attr in torch.unique(attrs_in_class):
                # Positions are relative to label_idx; map them back to
                # indices into the full minibatch.
                positions = self._indices_where(attrs_in_class, attr)
                groups.append(logits[[label_idx[p] for p in positions]])
            penalty = self._add_pairwise_mmd(penalty, groups)
        return objective, penalty

    def update(self, minibatches, unlabeled=None):
        """Take one step on the classification loss plus the MMD penalty.

        Minibatch tuples start with (inputs, labels); the ``causal`` mode
        additionally reads a third field and the ``ind`` / ``causalUind``
        modes a fourth one.  A ValueError is raised when a required field is
        missing.

        Returns a dict with 'loss' (classification term) and 'penalty'.
        """
        objective = 0
        penalty = 0
        penalty_ind = 0
        nmb = len(minibatches)

        # The modes are mutually exclusive: exactly one branch runs, so the
        # loss and penalty of an environment are never accumulated twice.
        if self.causal:
            self._require_fields(minibatches, 3, "causal")
            _, classifs, targets = self._forward_all(minibatches)
            for i in range(nmb):
                objective, penalty = self._add_acause_terms(
                    objective, penalty, classifs[i], targets[i],
                    minibatches[i][2])

        elif self.ind:
            self._require_fields(minibatches, 4, "ind")
            _, classifs, targets = self._forward_all(minibatches)
            for i in range(nmb):
                objective += F.cross_entropy(classifs[i], targets[i])

            # Collect one group per (environment, label) and match all of
            # them against each other, across environments as well.
            groups = []
            for i in range(nmb):
                ind_labels = minibatches[i][3]
                for attr in torch.unique(ind_labels):
                    groups.append(
                        classifs[i][self._indices_where(ind_labels, attr)])
            penalty = self._add_pairwise_mmd(penalty, groups)

        elif self.causalUind:
            self._require_fields(minibatches, 4, "causalUind")
            _, classifs, targets = self._forward_all(minibatches)
            for i in range(nmb):
                # Aind regularization
                aind_labels = minibatches[i][3]
                groups = [
                    classifs[i][self._indices_where(aind_labels, label)]
                    for label in torch.unique(aind_labels)
                ]
                penalty_ind = self._add_pairwise_mmd(penalty_ind, groups)

                # Acause regularization
                objective, penalty = self._add_acause_terms(
                    objective, penalty, classifs[i], targets[i],
                    minibatches[i][2])

        elif self.conditional:
            features, classifs, targets = self._forward_all(minibatches)
            for i in range(nmb):
                objective += F.cross_entropy(classifs[i], targets[i])

            # Only the classes present in the first environment are matched.
            for label in torch.unique(targets[0]):
                groups = [
                    features[i][self._indices_where(targets[i], label)]
                    for i in range(nmb)
                ]
                # An environment without samples of this class would yield an
                # empty group and a NaN discrepancy; leave it out.
                groups = [g for g in groups if len(g) > 0]
                penalty = self._add_pairwise_mmd(penalty, groups)

        else:
            features, classifs, targets = self._forward_all(minibatches)
            for i in range(nmb):
                objective += F.cross_entropy(classifs[i], targets[i])
                for j in range(i + 1, nmb):
                    penalty += self.mmd(features[i], features[j])

        objective /= nmb
        if nmb > 1:
            penalty /= (nmb * (nmb - 1) / 2)

        self.optimizer.zero_grad()
        loss = objective + (self.hparams['mmd_gamma'] * penalty)
        if self.causalUind:
            # Only this mode produces penalty_ind, so only here is the
            # separate 'mmd_gamma_ind' weight needed.
            loss = loss + (self.hparams['mmd_gamma_ind'] * penalty_ind)
        loss.backward()
        self.optimizer.step()

        if torch.is_tensor(penalty):
            penalty = penalty.item()

        return {'loss': objective.item(), 'penalty': penalty}


class MMD(AbstractMMD):
    """
    Gaussian-kernel MMD with the default (unconditional) matching mode
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super().__init__(
            input_shape, num_classes, num_domains, hparams,
            gaussian=True)

class CMMD(AbstractMMD):
    """
    Gaussian-kernel MMD with the conditional matching mode enabled
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super().__init__(
            input_shape, num_classes, num_domains, hparams,
            gaussian=True,
            conditional=True)

class CACM_ACause(AbstractMMD):
    """
    CACM-Causal shift using Gaussian kernel
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        # The super() call must name this class; the previous reference to a
        # non-existent `CACM_Cause` raised NameError on construction.
        super(CACM_ACause, self).__init__(input_shape, num_classes,
                                          num_domains, hparams, gaussian=True, causal=True)

class CACM_Aind(AbstractMMD):
    """
    CACM-Independent shift using Gaussian kernel.

    AbstractMMD constructed with ``gaussian=True`` and ``ind=True``.
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(CACM_Aind, self).__init__(input_shape, num_classes,
                                          num_domains, hparams, gaussian=True, ind=True)


# ALGORITHMS lists this algorithm as 'CACM_AInd' and get_algorithm_class()
# resolves names through globals(); expose the class under that spelling too
# so the registered name can actually be looked up.
CACM_AInd = CACM_Aind

class CACM_ACauseUAind(AbstractMMD):
    """
    CACM-Causal+Independent shift using Gaussian kernel.

    AbstractMMD constructed with ``gaussian=True`` and ``causalUind=True``.
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(CACM_ACauseUAind, self).__init__(input_shape, num_classes,
                                          num_domains, hparams, gaussian=True, causalUind=True)


# ALGORITHMS lists this algorithm as 'CACM_ACauseUAInd' and
# get_algorithm_class() resolves names through globals(); expose the class
# under that spelling too so the registered name can actually be looked up.
CACM_ACauseUAInd = CACM_ACauseUAind

class CORAL(AbstractMMD):
    """
    AbstractMMD with the Gaussian kernel disabled (``gaussian=False``):
    MMD using mean and covariance difference.
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super().__init__(
            input_shape,
            num_classes,
            num_domains,
            hparams,
            gaussian=False,
        )

class SagNet(Algorithm):
    """
    Style Agnostic Network
    Algorithm 1 from: https://arxiv.org/abs/1910.11645

    Holds three sub-networks, each with its own optimizer:
    ``network_f`` (featurizer), ``network_c`` (content classifier) and
    ``network_s`` (style classifier). Prediction uses only the featurizer
    followed by the content classifier.
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(SagNet, self).__init__(input_shape, num_classes, num_domains,
                                  hparams)
        # featurizer network
        self.network_f = networks.Featurizer(input_shape, self.hparams)
        # content network
        self.network_c = networks.Classifier(
            self.network_f.n_outputs,
            num_classes,
            self.hparams['nonlinear_classifier'])
        # style network
        self.network_s = networks.Classifier(
            self.network_f.n_outputs,
            num_classes,
            self.hparams['nonlinear_classifier'])

        # # This commented block of code implements something closer to the
        # # original paper, but is specific to ResNet and puts in disadvantage
        # # the other algorithms.
        # resnet_c = networks.Featurizer(input_shape, self.hparams)
        # resnet_s = networks.Featurizer(input_shape, self.hparams)
        # # featurizer network
        # self.network_f = torch.nn.Sequential(
        #         resnet_c.network.conv1,
        #         resnet_c.network.bn1,
        #         resnet_c.network.relu,
        #         resnet_c.network.maxpool,
        #         resnet_c.network.layer1,
        #         resnet_c.network.layer2,
        #         resnet_c.network.layer3)
        # # content network
        # self.network_c = torch.nn.Sequential(
        #         resnet_c.network.layer4,
        #         resnet_c.network.avgpool,
        #         networks.Flatten(),
        #         resnet_c.network.fc)
        # # style network
        # self.network_s = torch.nn.Sequential(
        #         resnet_s.network.layer4,
        #         resnet_s.network.avgpool,
        #         networks.Flatten(),
        #         resnet_s.network.fc)

        def opt(p):
            # One optimizer per sub-network, all sharing lr / weight decay.
            return _get_optimizer(
                self.hparams.get('optimizer', 'adam'), p, lr=hparams["lr"],
                weight_decay=hparams["weight_decay"])

        self.optimizer_f = opt(self.network_f.parameters())
        self.optimizer_c = opt(self.network_c.parameters())
        self.optimizer_s = opt(self.network_s.parameters())
        self.weight_adv = hparams["sag_w_adv"]

    def forward_c(self, x):
        """Content logits: features with randomized style -> network_c."""
        # learning content network on randomized style
        return self.network_c(self.randomize(self.network_f(x), "style"))

    def forward_s(self, x):
        """Style logits: features with randomized content -> network_s."""
        # learning style network on randomized content
        return self.network_s(self.randomize(self.network_f(x), "content"))

    def randomize(self, x, what="style", eps=1e-5):
        """Randomize the style or the content of a batch of features.

        Statistics (mean / variance) are taken over the last dimension; a
        4-D input is first flattened to ``(sizes[0], sizes[1], -1)``.

        Args:
            x: feature tensor whose first dimension is the batch.
            what: ``"style"`` mixes each sample's mean/variance with those of
                a randomly chosen other sample; any other value keeps each
                sample's own statistics and swaps in the normalized features
                of a randomly chosen other sample.
            eps: added to the variance before taking the square root.

        Returns:
            Tensor with the same size as the input ``x``.
        """
        sizes = x.size()
        # Sample on CPU (keeps the original RNG stream) and move to the exact
        # device of ``x``. The plain string "cuda" would resolve to the
        # *current* CUDA device, which breaks when ``x`` lives on another GPU.
        alpha = torch.rand(sizes[0], 1).to(x.device)

        if len(sizes) == 4:
            x = x.view(sizes[0], sizes[1], -1)
            alpha = alpha.unsqueeze(-1)

        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, keepdim=True)

        # Normalize each sample with its own statistics.
        x = (x - mean) / (var + eps).sqrt()

        idx_swap = torch.randperm(sizes[0])
        if what == "style":
            # Per-sample convex combination of own and swapped statistics.
            mean = alpha * mean + (1 - alpha) * mean[idx_swap]
            var = alpha * var + (1 - alpha) * var[idx_swap]
        else:
            # Swapped content is detached, so no gradient flows through it.
            x = x[idx_swap].detach()

        x = x * (var + eps).sqrt() + mean
        return x.view(*sizes)

    def update(self, minibatches, unlabeled=None):
        """Run one content step, one style step and one adversarial step.

        Args:
            minibatches: sequence of tuples whose first two entries are the
                inputs and the labels; any further entries are ignored.
            unlabeled: unused.

        Returns:
            dict with float entries ``'loss_c'``, ``'loss_s'`` and
            ``'loss_adv'``.
        """
        # Index instead of unpacking so tuples of any length >= 2 work.
        all_x = torch.cat([mb[0] for mb in minibatches])
        all_y = torch.cat([mb[1] for mb in minibatches])

        # learn content
        self.optimizer_f.zero_grad()
        self.optimizer_c.zero_grad()
        loss_c = F.cross_entropy(self.forward_c(all_x), all_y)
        loss_c.backward()
        self.optimizer_f.step()
        self.optimizer_c.step()

        # learn style (only the style classifier is stepped)
        self.optimizer_s.zero_grad()
        loss_s = F.cross_entropy(self.forward_s(all_x), all_y)
        loss_s.backward()
        self.optimizer_s.step()

        # learn adversary: only the featurizer is stepped, using the negated
        # mean log-probability of the style predictions over all classes.
        self.optimizer_f.zero_grad()
        loss_adv = -F.log_softmax(self.forward_s(all_x), dim=1).mean(1).mean()
        loss_adv = loss_adv * self.weight_adv
        loss_adv.backward()
        self.optimizer_f.step()

        return {'loss_c': loss_c.item(), 'loss_s': loss_s.item(),
                'loss_adv': loss_adv.item()}

    def predict(self, x):
        """Return content-classifier logits for ``x`` (no randomization)."""
        return self.network_c(self.network_f(x))


class RSC(ERM):
    """ERM variant that trains on features whose highest-gradient entries
    have been zeroed out.

    The equation and section numbers in the comments of ``update`` refer to
    the paper this method comes from.
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(RSC, self).__init__(input_shape, num_classes, num_domains,
                                   hparams)
        # Percentile thresholds (0-100) derived from the drop factors.
        self.drop_f = (1 - hparams['rsc_f_drop_factor']) * 100
        self.drop_b = (1 - hparams['rsc_b_drop_factor']) * 100
        self.num_classes = num_classes

    def update(self, minibatches, unlabeled=None):
        """Perform one optimization step on the muted features.

        Args:
            minibatches: sequence of tuples whose first two entries are the
                inputs and the integer labels; further entries are ignored.
            unlabeled: unused.

        Returns:
            dict with the float entry ``'loss'``.
        """
        # inputs and labels (index so tuples of any length >= 2 work)
        all_x = torch.cat([mb[0] for mb in minibatches])
        all_y = torch.cat([mb[1] for mb in minibatches])
        # one-hot labels
        all_o = torch.nn.functional.one_hot(all_y, self.num_classes)
        # features
        all_f = self.featurizer(all_x)
        # predictions
        all_p = self.classifier(all_f)

        # Equation (1): compute gradients with respect to representation
        all_g = autograd.grad((all_p * all_o).sum(), all_f)[0]

        # Equation (2): compute top-gradient-percentile mask
        percentiles = np.percentile(all_g.cpu(), self.drop_f, axis=1)
        percentiles = torch.Tensor(percentiles)
        percentiles = percentiles.unsqueeze(1).repeat(1, all_g.size(1))
        # Compare on the device the gradients actually live on; the bare
        # string "cuda" would mean the current CUDA device, not necessarily
        # the one holding ``all_g``.
        mask_f = all_g.lt(percentiles.to(all_g.device)).float()

        # Equation (3): mute top-gradient-percentile activations
        all_f_muted = all_f * mask_f

        # Equation (4): compute muted predictions
        all_p_muted = self.classifier(all_f_muted)

        # Section 3.3: Batch Percentage
        all_s = F.softmax(all_p, dim=1)
        all_s_muted = F.softmax(all_p_muted, dim=1)
        # Drop in true-class probability caused by the muting, per sample.
        changes = (all_s * all_o).sum(1) - (all_s_muted * all_o).sum(1)
        percentile = np.percentile(changes.detach().cpu(), self.drop_b)
        # Samples whose change is below the batch percentile get a row of
        # ones after broadcasting, i.e. their features are left unmuted.
        mask_b = changes.lt(percentile).float().view(-1, 1)
        mask = ((mask_f > 0) | (mask_b > 0)).float()
        # mask = torch.logical_or(mask_f, mask_b).float()   # not available until pytorch 1.5

        # Equations (3) and (4) again, this time muting over examples
        all_p_muted_again = self.classifier(all_f * mask)

        # Equation (5): update
        loss = F.cross_entropy(all_p_muted_again, all_y)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {'loss': loss.item()}


# +
class IB_ERM(ERM):
    """Information Bottleneck based ERM on feature with conditioning.

    The loss is the mean per-minibatch cross-entropy plus a weighted
    penalty equal to the mean per-minibatch feature variance. The penalty
    weight is 0 until ``update_count`` reaches
    ``hparams['ib_penalty_anneal_iters']`` and ``hparams['ib_lambda']``
    afterwards.
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(IB_ERM, self).__init__(input_shape, num_classes, num_domains,
                                  hparams)
        self.optimizer = torch.optim.Adam(
            list(self.featurizer.parameters()) + list(self.classifier.parameters()),
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay']
        )
        # Number of update() calls so far; registered as a buffer.
        self.register_buffer('update_count', torch.tensor([0]))

    def update(self, minibatches, unlabeled=None):
        """Perform one optimization step.

        Args:
            minibatches: sequence of tuples whose first two entries are the
                inputs and the labels; further entries are ignored.
            unlabeled: unused.

        Returns:
            dict with float entries ``'loss'``, ``'nll'`` and
            ``'IB_penalty'``.
        """
        ib_penalty_weight = (self.hparams['ib_lambda'] if self.update_count
                          >= self.hparams['ib_penalty_anneal_iters'] else
                          0.0)

        nll = 0.
        ib_penalty = 0.

        # Single forward pass over the concatenated inputs; the outputs are
        # sliced back into per-minibatch chunks below. Indexing (rather than
        # unpacking) handles minibatch tuples of any length >= 2 uniformly.
        all_x = torch.cat([mb[0] for mb in minibatches])
        all_features = self.featurizer(all_x)
        all_logits = self.classifier(all_features)
        start = 0
        for mb in minibatches:
            x, y = mb[0], mb[1]
            stop = start + x.shape[0]
            features = all_features[start:stop]
            logits = all_logits[start:stop]
            start = stop
            nll += F.cross_entropy(logits, y)
            # Penalty: variance of each feature across the minibatch,
            # averaged over features.
            ib_penalty += features.var(dim=0).mean()

        nll /= len(minibatches)
        ib_penalty /= len(minibatches)

        # Compile loss
        loss = nll
        loss += ib_penalty_weight * ib_penalty

        if self.update_count == self.hparams['ib_penalty_anneal_iters']:
            # Reset Adam, because it doesn't like the sharp jump in gradient
            # magnitudes that happens at this step.
            self.optimizer = torch.optim.Adam(
                list(self.featurizer.parameters()) + list(self.classifier.parameters()),
                lr=self.hparams["lr"],
                weight_decay=self.hparams['weight_decay'])

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.update_count += 1
        return {'loss': loss.item(),
                'nll': nll.item(),
                'IB_penalty': ib_penalty.item()}

class IB_IRM(ERM):
    """Information Bottleneck based IRM on feature with conditioning.

    The loss is the mean per-minibatch cross-entropy plus a weighted IRM
    penalty and a weighted feature-variance (IB) penalty. The IRM weight is
    1.0 until ``update_count`` reaches
    ``hparams['irm_penalty_anneal_iters']`` and ``hparams['irm_lambda']``
    afterwards; the IB weight is 0.0 until
    ``hparams['ib_penalty_anneal_iters']`` and ``hparams['ib_lambda']``
    afterwards.
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(IB_IRM, self).__init__(input_shape, num_classes, num_domains,
                                  hparams)
        self.optimizer = torch.optim.Adam(
            list(self.featurizer.parameters()) + list(self.classifier.parameters()),
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay']
        )
        # Number of update() calls so far; registered as a buffer.
        self.register_buffer('update_count', torch.tensor([0]))

    @staticmethod
    def _irm_penalty(logits, y):
        """Return the IRM penalty for one minibatch.

        The minibatch is split into even- and odd-indexed halves; the
        penalty is the product of the gradients of the two halves'
        cross-entropy losses with respect to a dummy scale of 1.0 applied
        to the logits.

        Args:
            logits: classifier outputs for the minibatch.
            y: labels aligned with ``logits``.

        Returns:
            Scalar tensor that remains differentiable (the gradients are
            taken with ``create_graph=True``).
        """
        # Create the dummy scale directly on the logits' device; deriving a
        # "cuda"/"cpu" string from ``is_cuda`` would target the current CUDA
        # device rather than the one actually holding ``logits``.
        scale = torch.tensor(1., device=logits.device).requires_grad_()
        loss_1 = F.cross_entropy(logits[::2] * scale, y[::2])
        loss_2 = F.cross_entropy(logits[1::2] * scale, y[1::2])
        grad_1 = autograd.grad(loss_1, [scale], create_graph=True)[0]
        grad_2 = autograd.grad(loss_2, [scale], create_graph=True)[0]
        result = torch.sum(grad_1 * grad_2)
        return result

    def update(self, minibatches, unlabeled=None):
        """Perform one optimization step.

        Args:
            minibatches: sequence of tuples whose first two entries are the
                inputs and the labels; further entries are ignored.
            unlabeled: unused.

        Returns:
            dict with float entries ``'loss'``, ``'nll'``,
            ``'IRM_penalty'`` and ``'IB_penalty'``.
        """
        irm_penalty_weight = (self.hparams['irm_lambda'] if self.update_count
                          >= self.hparams['irm_penalty_anneal_iters'] else
                          1.0)
        ib_penalty_weight = (self.hparams['ib_lambda'] if self.update_count
                          >= self.hparams['ib_penalty_anneal_iters'] else
                          0.0)

        nll = 0.
        irm_penalty = 0.
        ib_penalty = 0.

        # Single forward pass over the concatenated inputs; the outputs are
        # sliced back into per-minibatch chunks below. Indexing (rather than
        # unpacking) handles minibatch tuples of any length >= 2 uniformly.
        all_x = torch.cat([mb[0] for mb in minibatches])
        all_features = self.featurizer(all_x)
        all_logits = self.classifier(all_features)
        start = 0
        for mb in minibatches:
            x, y = mb[0], mb[1]
            stop = start + x.shape[0]
            features = all_features[start:stop]
            logits = all_logits[start:stop]
            start = stop
            nll += F.cross_entropy(logits, y)
            irm_penalty += self._irm_penalty(logits, y)
            # IB penalty: variance of each feature across the minibatch,
            # averaged over features.
            ib_penalty += features.var(dim=0).mean()

        nll /= len(minibatches)
        irm_penalty /= len(minibatches)
        ib_penalty /= len(minibatches)

        # Compile loss
        loss = nll
        loss += irm_penalty_weight * irm_penalty
        loss += ib_penalty_weight * ib_penalty

        if self.update_count == self.hparams['irm_penalty_anneal_iters'] or self.update_count == self.hparams['ib_penalty_anneal_iters']:
            # Reset Adam, because it doesn't like the sharp jump in gradient
            # magnitudes that happens at this step.
            self.optimizer = torch.optim.Adam(
                list(self.featurizer.parameters()) + list(self.classifier.parameters()),
                lr=self.hparams["lr"],
                weight_decay=self.hparams['weight_decay'])

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.update_count += 1
        return {'loss': loss.item(),
                'nll': nll.item(),
                'IRM_penalty': irm_penalty.item(),
                'IB_penalty': ib_penalty.item()}
