
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch.distributions as dist

import copy
import numpy as np

import networks
from lib.misc import random_pairs_of_minibatches
from loss import *

ALGORITHMS = [
    # 'ERM',
    'ERM_GP',
    # 'DANN',
    # 'WD',
    # 'KL',
    # 'KLUP'
]

def get_algorithm_class(algorithm_name):
    """Return the algorithm class with the given name."""
    if algorithm_name not in globals():
        raise NotImplementedError("Algorithm not found: {}".format(algorithm_name))
    return globals()[algorithm_name]

class Algorithm(torch.nn.Module):
    """
    A subclass of Algorithm implements a domain generalization algorithm.
    Subclasses should implement the following:
    - update()
    - predict()
    """
    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(Algorithm, self).__init__()
        self.hparams = hparams

    def update(self, minibatches, unlabeled=None):
        """
        Perform one update step, given a list of (x, y) tuples for all
        environments.
        Admits an optional list of unlabeled minibatches from the test domains,
        when task is domain_adaptation.
        """
        raise NotImplementedError

    def predict(self, x):
        raise NotImplementedError

class ERM(Algorithm):
    """
    Empirical Risk Minimization (ERM)
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(ERM, self).__init__(input_shape, num_classes, num_domains,
                                  hparams)
        self.featurizer = networks.Featurizer(input_shape, self.hparams)
        self.classifier = networks.Classifier(
            self.featurizer.n_outputs,
            num_classes,
            self.hparams['nonlinear_classifier'])

        self.network = nn.Sequential(self.featurizer, self.classifier)
        self.optimizer = torch.optim.Adam(
            self.network.parameters(),
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay']
        )

    def update(self, minibatches, unlabeled=None):
        all_x = torch.cat([x for x,y in minibatches])
        all_y = torch.cat([y for x,y in minibatches])
        loss = F.cross_entropy(self.predict(all_x), all_y)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {'loss': loss.item()}

    def predict(self, x):
        return self.network(x)

class ERM_GP(Algorithm):
    """
    ERM with Gradient Penalty
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(ERM_GP, self).__init__(input_shape, num_classes, num_domains,
                                  hparams)
        self.featurizer = networks.Featurizer(input_shape, self.hparams)
        self.classifier = networks.Classifier(
            self.featurizer.n_outputs,
            num_classes,
            self.hparams['nonlinear_classifier'])

        self.network = nn.Sequential(self.featurizer, self.classifier)
        self.optimizer = torch.optim.Adam(
            self.network.parameters(),
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay']
        )

    def update(self, minibatches, unlabeled=None):
        all_x = torch.cat([x for x,y in minibatches])
        all_y = torch.cat([y for x,y in minibatches])
        loss = F.cross_entropy(self.predict(all_x), all_y)
        grad_params = torch.autograd.grad(loss, self.network.parameters(), create_graph=True)
        grad_norm = 0
        for grad in grad_params:
            grad_norm += grad.pow(2).sum()
        grad_norm = grad_norm.sqrt()
        reg_loss = loss + self.hparams["grad_decay"] * grad_norm
        # loss.backward()
        # epsilon = 1e-10  # for stable torch.sqrt
        # grad = autograd.grad(loss.sum(),
        #                      self.network, create_graph=True)[0]
        # grad_penalty = ((torch.sqrt((grad ** 2).sum(dim=1) + epsilon) - 1) ** 2).mean(dim=0)
        # reg_loss = loss + self.hparams['grad_penalty'] * grad_penalty

        self.optimizer.zero_grad()
        reg_loss.backward()
        self.optimizer.step()

        return {'loss': loss.item()}

    def predict(self, x):
        return self.network(x)




class CLIGAv3(ERM):
    """
    Controlling label information and gradient alignment version 3: using only target pseudo label distribution to train
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(CLIGAv3, self).__init__(input_shape, num_classes, num_domains,
                                  hparams)
        self.featurizer = networks.Featurizer(input_shape, self.hparams)
        self.classifier = networks.Classifier(
            self.featurizer.n_outputs,
            num_classes,
            self.hparams['nonlinear_classifier'])

        self.network = nn.Sequential(self.featurizer, self.classifier)
        self.aux_net = copy.deepcopy(self.network)
        self.optimizer = torch.optim.Adam(
            self.network.parameters(),
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay']
        )

        self.aux_optimizer = torch.optim.Adam(
            self.aux_net.parameters(),
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay']
        )
        self.iter = 0

    def update(self, minibatches, unlabeled=None):
        all_x = torch.cat([x for x,y in minibatches])
        all_y = torch.cat([y for x,y in minibatches])
        x_target = torch.cat(unlabeled)

        # total_x = torch.cat([all_x, x_target]) self.aux_net(x_target).softmax(dim=1).detach()
        self.network.zero_grad()
        loss = F.cross_entropy(self.network(all_x), all_y)
        l2_dis = sum((w1 - w2.detach()).pow(2.0).sum()
                     for w1, w2 in zip(self.network.parameters(), self.aux_net.parameters()))
        loss = loss + self.hparams['dis_decay'] * l2_dis
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        #updating auxiliary network by using source domain data
        self.aux_net.zero_grad()
        loss_aux = F.cross_entropy(self.aux_net(x_target), self.network(x_target).softmax(dim=1).detach())
        self.aux_optimizer.zero_grad()
        loss_aux.backward()
        self.aux_optimizer.step()

        self.iter += 1
        return {'loss': loss_aux.item()}

class CLIGAv4(ERM):
    """
    Controlling label information and gradient alignment version 4: combining with gradient penalty
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(CLIGAv4, self).__init__(input_shape, num_classes, num_domains,
                                  hparams)
        self.featurizer = networks.Featurizer(input_shape, self.hparams)
        self.classifier = networks.Classifier(
            self.featurizer.n_outputs,
            num_classes,
            self.hparams['nonlinear_classifier'])

        self.network = nn.Sequential(self.featurizer, self.classifier)
        self.aux_net = copy.deepcopy(self.network)
        self.optimizer = torch.optim.Adam(
            self.network.parameters(),
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay']
        )

        self.aux_optimizer = torch.optim.Adam(
            self.aux_net.parameters(),
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay']
        )
        self.iter = 0

    def update(self, minibatches, unlabeled=None):
        all_x = torch.cat([x for x,y in minibatches])
        all_y = torch.cat([y for x,y in minibatches])
        x_target = torch.cat(unlabeled)

        total_x = torch.cat([all_x, x_target])


        #updating auxiliary network by using source domain data
        self.aux_net.zero_grad()
        loss_aux = F.cross_entropy(self.aux_net(all_x), all_y)
        loss_aux_reg = self.grad_penalty(loss_aux, self.aux_net.parameters())
        self.aux_optimizer.zero_grad()
        loss_aux_reg.backward()
        self.aux_optimizer.step()

        if self.iter > self.hparams['warm_up']:
            #updating the real classifier by using the pseudo labels given by the auxiliary network
            self.network.zero_grad()
            loss = F.cross_entropy(self.network(total_x), self.aux_net(total_x).softmax(dim=1).detach())
            l2_dis = sum((w1-w2.detach()).pow(2.0).sum()
                          for w1, w2 in zip(self.network.parameters(), self.aux_net.parameters()))
            loss = loss + self.hparams['dis_decay']*l2_dis
            loss_reg = self.grad_penalty(loss, self.network.parameters())
            self.optimizer.zero_grad()
            loss_reg.backward()
            self.optimizer.step()

        self.iter += 1
        return {'loss': loss_aux.item()}

    def grad_penalty(self, loss, param):
        grad_params = torch.autograd.grad(loss, param, create_graph=True)
        grad_norm = 0
        for grad in grad_params:
            grad_norm += grad.pow(2).sum()
        grad_norm = grad_norm.sqrt()
        reg_loss = loss + self.hparams["grad_decay"] * grad_norm
        return reg_loss


class CLIGAv5(ERM):
    """
    Controlling label information and gradient alignment version 5: combining with representation alignment
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(CLIGAv5, self).__init__(input_shape, num_classes, num_domains,
                                  hparams)
        self.featurizer = networks.Featurizer(input_shape, self.hparams)
        self.classifier = networks.Classifier(
            self.featurizer.n_outputs,
            num_classes,
            self.hparams['nonlinear_classifier'])

        # self.network = nn.Sequential(self.featurizer, self.classifier)
        self.aux_featurizer = copy.deepcopy(self.featurizer)
        self.aux_classifier = copy.deepcopy(self.classifier)
        self.optimizer = torch.optim.Adam(
            list(self.featurizer.parameters()) + list(self.classifier.parameters()),
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay']
        )

        self.aux_optimizer = torch.optim.Adam(
            list(self.aux_featurizer.parameters()) + list(self.aux_classifier.parameters()),
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay']
        )
        self.iter = 0

    def update(self, minibatches, unlabeled=None):
        all_x = torch.cat([x for x,y in minibatches])
        all_y = torch.cat([y for x,y in minibatches])
        x_target = torch.cat(unlabeled)

        total_x = torch.cat([all_x, x_target])

        # source_T_target = self.featurizer(all_x)
        target_T = self.aux_featurizer(x_target)

        #updating auxiliary network by using source domain data
        source_T = self.aux_featurizer(all_x)
        loss_aux = F.cross_entropy(self.aux_classifier(source_T), all_y)
        # loss_aux_reg = self.grad_penalty(loss_aux, list(self.aux_featurizer.parameters()) + list(self.aux_classifier.parameters()))
        l2_dis = ((target_T.mean(dim=0) - source_T.mean(dim=0)) ** 2).sum()
        loss_aux += self.hparams['dis_decay'] * l2_dis
        self.aux_optimizer.zero_grad()
        loss_aux.backward()
        self.aux_optimizer.step()

        if self.iter > self.hparams['warm_up']:
            #updating the real classifier by using the pseudo labels given by the auxiliary network
            # target_T_source = self.aux_featurizer(x_target)
            target_T = self.featurizer(x_target)
            # source_T = self.aux_featurizer(all_x)
            loss = F.cross_entropy(self.classifier(target_T), self.aux_classifier(self.aux_featurizer(x_target)).softmax(dim=1).detach())
            # l2_dis = sum((w1-w2.detach()).pow(2.0).sum()
            #               for w1, w2 in zip(self.network.parameters(), self.aux_net.parameters()))
            # l2_dis = ((target_T.mean(dim=0)-source_T.mean(dim=0).detach())**2).sum()
            # loss += self.hparams['dis_decay']*l2_dis
            # loss_reg = self.grad_penalty(loss, self.network.parameters())
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        self.iter += 1
        return {'loss': loss_aux.item()}

    def grad_penalty(self, loss, param):
        grad_params = torch.autograd.grad(loss, param, create_graph=True)
        grad_norm = 0
        for grad in grad_params:
            grad_norm += grad.pow(2).sum()
        grad_norm = grad_norm.sqrt()
        reg_loss = loss + self.hparams["grad_decay"] * grad_norm
        return reg_loss

    def predict(self, x):
        # return self.aux_classifier(self.aux_featurizer(x))
        return self.classifier(self.featurizer(x))



class PERM(Algorithm):
    """
    Empirical Risk Minimization (ERM) with probabilistic representation network
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(PERM, self).__init__(input_shape, num_classes, num_domains,
                                  hparams)
        self.featurizer = networks.Featurizer(input_shape, self.hparams, probabilistic=True)
        self.classifier = networks.Classifier(
            self.featurizer.n_outputs,
            num_classes,
            self.hparams['nonlinear_classifier'])

        self.optimizer = torch.optim.Adam(
            list(self.featurizer.parameters()) + list(self.classifier.parameters()),
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay']
        )
        self.num_samples = hparams['num_samples']

    def update(self, minibatches, unlabeled=None):
        all_x = torch.cat([x for x,y in minibatches])
        all_y = torch.cat([y for x,y in minibatches])

        all_z_params = self.featurizer(all_x)
        z_dim = int(all_z_params.shape[-1]/2)
        z_mu = all_z_params[:,:z_dim]
        z_sigma = F.softplus(all_z_params[:,z_dim:])

        all_z_dist = dist.Independent(dist.normal.Normal(z_mu,z_sigma),1)
        all_z = all_z_dist.rsample()

        loss = F.cross_entropy(self.classifier(all_z), all_y)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {'loss': loss.item()}

    def predict(self, x):
        z_params = self.featurizer(x)
        z_dim = int(z_params.shape[-1]/2)
        z_mu = z_params[:,:z_dim]
        z_sigma = F.softplus(z_params[:,z_dim:])

        z_dist = dist.Independent(dist.normal.Normal(z_mu,z_sigma),1)
        
        probs = 0.0
        for s in range(self.num_samples):
            z = z_dist.rsample()
            probs += F.softmax(self.classifier(z),1)
        probs = probs/self.num_samples
        return probs




class KL(Algorithm):
    """
    KL
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(KL, self).__init__(input_shape, num_classes, num_domains,
                                  hparams)
        self.featurizer = networks.Featurizer(input_shape, self.hparams, probabilistic=True)
        self.classifier = networks.Classifier(
            self.featurizer.n_outputs,
            num_classes,
            self.hparams['nonlinear_classifier'])

        # cls_lr = 100*self.hparams["lr"] if hparams['nonlinear_classifier'] else self.hparams["lr"]
        cls_lr = self.hparams["lr"]


        self.optimizer = torch.optim.Adam(
            #list(self.featurizer.parameters()) + list(self.classifier.parameters()),
            [{'params': self.featurizer.parameters(), 'lr': self.hparams["lr"]},
                {'params': self.classifier.parameters(), 'lr': cls_lr}],
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay']
        )
        self.num_samples = hparams['num_samples']
        self.kl_reg = hparams['kl_reg']
        self.kl_reg_aux = hparams['kl_reg_aux']
        self.augment_softmax = hparams['augment_softmax']

    def update(self, minibatches, unlabeled=None):

        x = torch.cat([x for x,y in minibatches])
        y = torch.cat([y for x,y in minibatches])

        x_target = torch.cat(unlabeled)

        total_x = torch.cat([x,x_target])
        total_z_params = self.featurizer(total_x)
        z_dim = int(total_z_params.shape[-1]/2)
        total_z_mu = total_z_params[:,:z_dim]
        total_z_sigma = F.softplus(total_z_params[:,z_dim:]) 

        z_mu, z_sigma = total_z_mu[:x.shape[0]], total_z_sigma[:x.shape[0]]
        z_mu_target, z_sigma_target = total_z_mu[x.shape[0]:], total_z_sigma[x.shape[0]:]

        z_dist = dist.Independent(dist.normal.Normal(z_mu,z_sigma),1)
        z = z_dist.rsample()

        z_dist_target = dist.Independent(dist.normal.Normal(z_mu_target,z_sigma_target),1)
        z_target = z_dist_target.rsample()

        preds = torch.softmax(self.classifier(z),1)
        if self.augment_softmax != 0.0:
            K = 1 - self.augment_softmax * preds.shape[1]
            preds = preds*K + self.augment_softmax
        loss = F.nll_loss(torch.log(preds),y) 

        mix_coeff = dist.categorical.Categorical(x.new_ones(x.shape[0]))
        mixture = dist.mixture_same_family.MixtureSameFamily(mix_coeff,z_dist)
        mix_coeff_target = dist.categorical.Categorical(x_target.new_ones(x_target.shape[0]))
        mixture_target = dist.mixture_same_family.MixtureSameFamily(mix_coeff_target,z_dist_target)
        
        obj = loss
        kl = loss.new_zeros([])
        kl_aux = loss.new_zeros([])
        if self.kl_reg != 0.0:
            kl = (mixture_target.log_prob(z_target)-mixture.log_prob(z_target)).mean()
            obj = obj + self.kl_reg*kl
        if self.kl_reg_aux != 0.0:
            kl_aux = (mixture.log_prob(z)-mixture_target.log_prob(z)).mean()
            obj = obj + self.kl_reg_aux*kl_aux

        self.optimizer.zero_grad()
        obj.backward()
        self.optimizer.step()

        return {'loss': loss.item(), 'kl': kl.item(), 'kl_aux': kl_aux.item(), 'Jeffrey': (kl.item()+kl_aux.item())**(1/2)}

    def predict(self, x):
        z_params = self.featurizer(x)
        z_dim = int(z_params.shape[-1]/2)
        z_mu = z_params[:,:z_dim]
        z_sigma = F.softplus(z_params[:,z_dim:]) 

        z_dist = dist.Independent(dist.normal.Normal(z_mu,z_sigma),1)
        
        preds = 0.0
        for s in range(self.num_samples):
            z = z_dist.rsample()
            preds += F.softmax(self.classifier(z),1)
        preds = preds/self.num_samples

        K = 1 - self.augment_softmax * preds.shape[1]
        preds = preds*K + self.augment_softmax
        return preds



class KLGP(KL):
    """
    KL with gradient penalty
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(KLGP, self).__init__(input_shape, num_classes, num_domains,
                                 hparams)
        self.featurizer = networks.Featurizer(input_shape, self.hparams, probabilistic=True)
        self.classifier = networks.Classifier(
            self.featurizer.n_outputs,
            num_classes,
            self.hparams['nonlinear_classifier'])

        cls_lr = 100 * self.hparams["lr"] if hparams['nonlinear_classifier'] else self.hparams["lr"]

        self.optimizer = torch.optim.Adam(
            # list(self.featurizer.parameters()) + list(self.classifier.parameters()),
            [{'params': self.featurizer.parameters(), 'lr': self.hparams["lr"]},
             {'params': self.classifier.parameters(), 'lr': cls_lr}],
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay']
        )
        self.num_samples = hparams['num_samples']
        self.kl_reg = hparams['kl_reg']
        self.kl_reg_aux = hparams['kl_reg_aux']
        self.augment_softmax = hparams['augment_softmax']

    def update(self, minibatches, unlabeled=None):

        x = torch.cat([x for x, y in minibatches])
        y = torch.cat([y for x, y in minibatches])

        x_target = torch.cat(unlabeled)

        total_x = torch.cat([x, x_target])
        total_z_params = self.featurizer(total_x)
        z_dim = int(total_z_params.shape[-1] / 2)
        total_z_mu = total_z_params[:, :z_dim]
        total_z_sigma = F.softplus(total_z_params[:, z_dim:])

        z_mu, z_sigma = total_z_mu[:x.shape[0]], total_z_sigma[:x.shape[0]]
        z_mu_target, z_sigma_target = total_z_mu[x.shape[0]:], total_z_sigma[x.shape[0]:]

        z_dist = dist.Independent(dist.normal.Normal(z_mu, z_sigma), 1)
        z = z_dist.rsample()

        z_dist_target = dist.Independent(dist.normal.Normal(z_mu_target, z_sigma_target), 1)
        z_target = z_dist_target.rsample()

        preds = torch.softmax(self.classifier(z), 1)
        if self.augment_softmax != 0.0:
            K = 1 - self.augment_softmax * preds.shape[1]
            preds = preds * K + self.augment_softmax
        loss = F.nll_loss(torch.log(preds), y)
        loss = self.grad_penalty(loss, self.classifier.parameters())

        mix_coeff = dist.categorical.Categorical(x.new_ones(x.shape[0]))
        mixture = dist.mixture_same_family.MixtureSameFamily(mix_coeff, z_dist)
        mix_coeff_target = dist.categorical.Categorical(x_target.new_ones(x_target.shape[0]))
        mixture_target = dist.mixture_same_family.MixtureSameFamily(mix_coeff_target, z_dist_target)

        obj = loss
        kl = loss.new_zeros([])
        kl_aux = loss.new_zeros([])
        if self.kl_reg != 0.0:
            kl = (mixture_target.log_prob(z_target) - mixture.log_prob(z_target)).mean()
            obj = obj + self.kl_reg * kl
        if self.kl_reg_aux != 0.0:
            kl_aux = (mixture.log_prob(z) - mixture_target.log_prob(z)).mean()
            obj = obj + self.kl_reg_aux * kl_aux

        self.optimizer.zero_grad()
        obj.backward()
        self.optimizer.step()

        return {'loss': loss.item(), 'kl': kl.item(), 'kl_aux': kl_aux.item()}

    def grad_penalty(self, loss, param):
        grad_params = torch.autograd.grad(loss, param, create_graph=True)
        grad_norm = 0
        for grad in grad_params:
            grad_norm += grad.pow(2).sum()
        grad_norm = grad_norm.sqrt()
        reg_loss = loss + self.hparams["grad_decay"] * grad_norm
        return reg_loss


#Using this
class KLGPv2(KLGP):
    """
    KL with controlling label information
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(KLGPv2, self).__init__(input_shape, num_classes, num_domains,
                                   hparams)
        self.featurizer = networks.Featurizer(input_shape, self.hparams, probabilistic=True)
        self.classifier = networks.Classifier(
            self.featurizer.n_outputs,
            num_classes,
            self.hparams['nonlinear_classifier'])

        cls_lr = 100 * self.hparams["lr"] if hparams['nonlinear_classifier'] else self.hparams["lr"]

        self.optimizer = torch.optim.Adam(
            # list(self.featurizer.parameters()) + list(self.classifier.parameters()),
            [{'params': self.featurizer.parameters(), 'lr': self.hparams["lr"]},
             {'params': self.classifier.parameters(), 'lr': cls_lr}],
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay']
        )
        self.num_samples = hparams['num_samples']
        self.kl_reg = hparams['kl_reg']
        self.kl_reg_aux = hparams['kl_reg_aux']
        self.augment_softmax = hparams['augment_softmax']

    def update(self, minibatches, unlabeled=None):

        x = torch.cat([x for x, y in minibatches])
        y = torch.cat([y for x, y in minibatches])

        x_target = torch.cat(unlabeled)

        total_x = torch.cat([x, x_target])
        total_z_params = self.featurizer(total_x)
        z_dim = int(total_z_params.shape[-1] / 2)
        total_z_mu = total_z_params[:, :z_dim]
        total_z_sigma = F.softplus(total_z_params[:, z_dim:])

        z_mu, z_sigma = total_z_mu[:x.shape[0]], total_z_sigma[:x.shape[0]]
        z_mu_target, z_sigma_target = total_z_mu[x.shape[0]:], total_z_sigma[x.shape[0]:]

        z_dist = dist.Independent(dist.normal.Normal(z_mu, z_sigma), 1)
        z = z_dist.rsample()

        z_dist_target = dist.Independent(dist.normal.Normal(z_mu_target, z_sigma_target), 1)
        z_target = z_dist_target.rsample()

        preds = torch.softmax(self.classifier(z), 1)
        if self.augment_softmax != 0.0:
            K = 1 - self.augment_softmax * preds.shape[1]
            preds = preds * K + self.augment_softmax
        loss = F.nll_loss(torch.log(preds), y)
        loss = self.grad_penalty(loss, list(self.featurizer.parameters()) + list(self.classifier.parameters()))

        mix_coeff = dist.categorical.Categorical(x.new_ones(x.shape[0]))
        mixture = dist.mixture_same_family.MixtureSameFamily(mix_coeff, z_dist)
        mix_coeff_target = dist.categorical.Categorical(x_target.new_ones(x_target.shape[0]))
        mixture_target = dist.mixture_same_family.MixtureSameFamily(mix_coeff_target, z_dist_target)

        obj = loss
        kl = loss.new_zeros([])
        kl_aux = loss.new_zeros([])
        if self.kl_reg != 0.0:
            kl = (mixture_target.log_prob(z_target) - mixture.log_prob(z_target)).mean()
            obj = obj + self.kl_reg * kl
        if self.kl_reg_aux != 0.0:
            kl_aux = (mixture.log_prob(z) - mixture_target.log_prob(z)).mean()
            obj = obj + self.kl_reg_aux * kl_aux

        self.optimizer.zero_grad()
        obj.backward()
        self.optimizer.step()

        return {'loss': loss.item(), 'kl': kl.item(), 'kl_aux': kl_aux.item()}

class KLCL(KL):
    """
    KL with controlling label information
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(KLCL, self).__init__(input_shape, num_classes, num_domains,
                                   hparams)

        self.featurizer = networks.Featurizer(input_shape, self.hparams, probabilistic=True)
        self.classifier = networks.Classifier(
            self.featurizer.n_outputs,
            num_classes,
            self.hparams['nonlinear_classifier'])

        # cls_lr = 100 * self.hparams["lr"] if hparams['nonlinear_classifier'] else self.hparams["lr"]
        cls_lr = self.hparams["lr"] if hparams['nonlinear_classifier'] else self.hparams["lr"]

        self.aux_classifier = copy.deepcopy(self.classifier)
        self.aux_optimizer = torch.optim.Adam(
            self.aux_classifier.parameters(),
            lr=cls_lr,
            weight_decay=self.hparams['weight_decay']
        )


        self.optimizer = torch.optim.Adam(
            # list(self.featurizer.parameters()) + list(self.classifier.parameters()),
            [{'params': self.featurizer.parameters(), 'lr': self.hparams["lr"]},
             {'params': self.classifier.parameters(), 'lr': cls_lr}],
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay']
        )
        self.num_samples = hparams['num_samples']
        self.kl_reg = hparams['kl_reg']
        self.kl_reg_aux = hparams['kl_reg_aux']
        self.augment_softmax = hparams['augment_softmax']

    def update(self, minibatches, unlabeled=None):

        x = torch.cat([x for x, y in minibatches])
        y = torch.cat([y for x, y in minibatches])
        x_target = torch.cat(unlabeled)
        total_x = torch.cat([x, x_target])

        total_z_params = self.featurizer(total_x)
        z_dim = int(total_z_params.shape[-1] / 2)
        total_z_mu = total_z_params[:, :z_dim]
        total_z_sigma = F.softplus(total_z_params[:, z_dim:])

        z_mu, z_sigma = total_z_mu[:x.shape[0]], total_z_sigma[:x.shape[0]]
        z_mu_target, z_sigma_target = total_z_mu[x.shape[0]:], total_z_sigma[x.shape[0]:]

        z_dist = dist.Independent(dist.normal.Normal(z_mu, z_sigma), 1)
        z = z_dist.rsample()

        z_dist_target = dist.Independent(dist.normal.Normal(z_mu_target, z_sigma_target), 1)
        z_target = z_dist_target.rsample()

        # total_z = torch.cat([z, z_target])
        preds = torch.softmax(self.classifier(z), 1)
        if self.augment_softmax != 0.0:
            K = 1 - self.augment_softmax * preds.shape[1]
            preds = preds * K + self.augment_softmax
        loss = F.nll_loss(torch.log(preds), y)
        # loss = F.nll_loss(torch.log(preds), torch.max(self.network(total_x).detach(), 1)[1])

        l2_dis = sum((w1 - w2.detach()).pow(2.0).sum()
                     for w1, w2 in zip(self.classifier.parameters(), self.aux_classifier.parameters()))

        loss += self.hparams['dis_decay'] * l2_dis

        # loss = self.grad_penalty(loss, list(self.featurizer.parameters()) + list(self.classifier.parameters()))


        mix_coeff = dist.categorical.Categorical(x.new_ones(x.shape[0]))
        mixture = dist.mixture_same_family.MixtureSameFamily(mix_coeff, z_dist)
        mix_coeff_target = dist.categorical.Categorical(x_target.new_ones(x_target.shape[0]))
        mixture_target = dist.mixture_same_family.MixtureSameFamily(mix_coeff_target, z_dist_target)

        obj = loss
        kl = loss.new_zeros([])
        kl_aux = loss.new_zeros([])
        if self.kl_reg != 0.0:
            kl = (mixture_target.log_prob(z_target) - mixture.log_prob(z_target)).mean()
            obj = obj + self.kl_reg * kl
        if self.kl_reg_aux != 0.0:
            kl_aux = (mixture.log_prob(z) - mixture_target.log_prob(z)).mean()
            obj = obj + self.kl_reg_aux * kl_aux

        # obj = self.grad_penalty(obj, list(self.featurizer.parameters()) + list(self.classifier.parameters()))

        self.optimizer.zero_grad()
        obj.backward()
        self.optimizer.step()

        # updating auxiliary network by using source domain data
        # total_z = torch.cat([z, z_target])
        total_z = z_target
        pseudo_y = torch.softmax(self.classifier(total_z), 1)
        # pseudo_z, pseudo_y = self.pseudo(total_x)

        cls_loss = F.cross_entropy(self.aux_classifier(total_z.detach()), pseudo_y.detach())
        # cls_loss = self.grad_penalty(cls_loss, self.aux_classifier.parameters())
        self.aux_optimizer.zero_grad()
        cls_loss.backward()
        self.aux_optimizer.step()

        return {'loss': loss.item(), 'kl': kl.item(), 'kl_aux': kl_aux.item()}

    # def pseudo(self, x):
    #     z_params = self.featurizer(x)
    #     z_dim = int(z_params.shape[-1] / 2)
    #     z_mu = z_params[:, :z_dim]
    #     z_sigma = F.softplus(z_params[:, z_dim:])
    #
    #     z_dist = dist.Independent(dist.normal.Normal(z_mu, z_sigma), 1)
    #
    #     # preds = 0.0
    #     # for s in range(self.num_samples):
    #     #     z = z_dist.rsample()
    #     #     preds += F.softmax(self.classifier(z), 1)
    #     # preds = preds / self.num_samples
    #     #
    #     # K = 1 - 0.05 * preds.shape[1]
    #     # preds = preds * K + 0.05
    #     # return preds
    #
    #     probs = 0.0
    #     total_z = 0.0
    #     for s in range(self.num_samples):
    #         z = z_dist.rsample()
    #         probs += F.softmax(self.classifier(z), 1)
    #         total_z += z
    #     probs = probs / self.num_samples
    #     total_z = total_z / self.num_samples
    #     return total_z, probs

    def grad_penalty(self, loss, param):
        grad_params = torch.autograd.grad(loss, param, create_graph=True)
        grad_norm = 0
        for grad in grad_params:
            grad_norm += grad.pow(2).sum()
        grad_norm = grad_norm.sqrt()
        reg_loss = loss + self.hparams["grad_decay"] * grad_norm
        return reg_loss


class KLCLGP(KLCL):
    """
    KL with controlling label information
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(KLCLGP, self).__init__(input_shape, num_classes, num_domains,
                                   hparams)

        self.featurizer = networks.Featurizer(input_shape, self.hparams, probabilistic=True)
        self.classifier = networks.Classifier(
            self.featurizer.n_outputs,
            num_classes,
            self.hparams['nonlinear_classifier'])

        self.network = nn.Sequential(networks.Featurizer(input_shape, self.hparams), copy.deepcopy(self.classifier))
        self.net_optimizer = torch.optim.Adam(
            self.network.parameters(),
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay']
        )

        cls_lr = 100 * self.hparams["lr"] if hparams['nonlinear_classifier'] else self.hparams["lr"]

        self.optimizer = torch.optim.Adam(
            # list(self.featurizer.parameters()) + list(self.classifier.parameters()),
            [{'params': self.featurizer.parameters(), 'lr': self.hparams["lr"]},
             {'params': self.classifier.parameters(), 'lr': cls_lr}],
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay']
        )
        self.num_samples = hparams['num_samples']
        self.kl_reg = hparams['kl_reg']
        self.kl_reg_aux = hparams['kl_reg_aux']
        self.augment_softmax = hparams['augment_softmax']

    def update(self, minibatches, unlabeled=None):

        x = torch.cat([x for x, y in minibatches])
        y = torch.cat([y for x, y in minibatches])

        x_target = torch.cat(unlabeled)

        total_x = torch.cat([x, x_target])
        total_z_params = self.featurizer(total_x)
        z_dim = int(total_z_params.shape[-1] / 2)
        total_z_mu = total_z_params[:, :z_dim]
        total_z_sigma = F.softplus(total_z_params[:, z_dim:])

        z_mu, z_sigma = total_z_mu[:x.shape[0]], total_z_sigma[:x.shape[0]]
        z_mu_target, z_sigma_target = total_z_mu[x.shape[0]:], total_z_sigma[x.shape[0]:]

        z_dist = dist.Independent(dist.normal.Normal(z_mu, z_sigma), 1)
        z = z_dist.rsample()

        z_dist_target = dist.Independent(dist.normal.Normal(z_mu_target, z_sigma_target), 1)
        z_target = z_dist_target.rsample()

        preds = torch.softmax(self.classifier(z), 1)
        if self.augment_softmax != 0.0:
            K = 1 - self.augment_softmax * preds.shape[1]
            preds = preds * K + self.augment_softmax
        loss = F.nll_loss(torch.log(preds), y)
        loss = self.grad_penalty(loss, list(self.featurizer.parameters()) + list(self.classifier.parameters()))

        mix_coeff = dist.categorical.Categorical(x.new_ones(x.shape[0]))
        mixture = dist.mixture_same_family.MixtureSameFamily(mix_coeff, z_dist)
        mix_coeff_target = dist.categorical.Categorical(x_target.new_ones(x_target.shape[0]))
        mixture_target = dist.mixture_same_family.MixtureSameFamily(mix_coeff_target, z_dist_target)

        obj = loss
        kl = loss.new_zeros([])
        kl_aux = loss.new_zeros([])
        if self.kl_reg != 0.0:
            kl = (mixture_target.log_prob(z_target) - mixture.log_prob(z_target)).mean()
            obj = obj + self.kl_reg * kl
        if self.kl_reg_aux != 0.0:
            kl_aux = (mixture.log_prob(z) - mixture_target.log_prob(z)).mean()
            obj = obj + self.kl_reg_aux * kl_aux

        self.optimizer.zero_grad()
        obj.backward()
        self.optimizer.step()

        cls_loss = F.cross_entropy(self.network(x_target), self.pseudo(x_target).detach())
        cls_loss = self.grad_penalty(cls_loss, self.network.parameters())
        self.net_optimizer.zero_grad()
        cls_loss.backward()
        self.net_optimizer.step()

        return {'loss': loss.item(), 'kl': kl.item(), 'kl_aux': kl_aux.item()}

    def pseudo(self, x):
        z_params = self.featurizer(x)
        z_dim = int(z_params.shape[-1] / 2)
        z_mu = z_params[:, :z_dim]
        z_sigma = F.softplus(z_params[:, z_dim:])

        z_dist = dist.Independent(dist.normal.Normal(z_mu, z_sigma), 1)

        # preds = 0.0
        # for s in range(self.num_samples):
        #     z = z_dist.rsample()
        #     preds += F.softmax(self.classifier(z), 1)
        # preds = preds / self.num_samples
        #
        # K = 1 - 0.05 * preds.shape[1]
        # preds = preds * K + 0.05
        # return preds

        probs = 0.0
        for s in range(self.num_samples):
            z = z_dist.rsample()
            probs += F.softmax(self.classifier(z), 1)
        probs = probs / self.num_samples
        return probs

    def grad_penalty(self, loss, param):
        grad_params = torch.autograd.grad(loss, param, create_graph=True)
        grad_norm = 0
        for grad in grad_params:
            grad_norm += grad.pow(2).sum()
        grad_norm = grad_norm.sqrt()
        reg_loss = loss + self.hparams["grad_decay"] * grad_norm
        return reg_loss




class WD(Algorithm):
    """Wasserstein Distance guided Representation Learning"""

    def __init__(self, input_shape, num_classes, num_domains,
                 hparams, class_balance=False):

        super(WD, self).__init__(input_shape, num_classes, num_domains,
                                  hparams)

        self.register_buffer('update_count', torch.tensor([0]))
        self.class_balance = class_balance

        # Algorithms
        self.featurizer = networks.Featurizer(input_shape, self.hparams)
        self.classifier = networks.Classifier(
            self.featurizer.n_outputs,
            num_classes,
            self.hparams['nonlinear_classifier'])

        self.fw = networks.MLP(self.featurizer.n_outputs,1,self.hparams)

        # Optimizers
        self.wd_opt = torch.optim.Adam(
            self.fw.parameters(),
            lr=self.hparams["lr_wd"],
            weight_decay=self.hparams['weight_decay_wd'])

        self.main_opt = torch.optim.Adam(
            (list(self.featurizer.parameters()) +
                list(self.classifier.parameters())),
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay'])

    def wd_loss(self,h_s,h_t,for_fw=True):
        batch_size = h_s.shape[0]
        alpha = torch.rand([batch_size,1]).to(h_s.device)
        h_inter = h_s*alpha + h_t*(1-alpha)
        h_whole = torch.cat([h_s,h_t,h_inter],0)
        critic = self.fw(h_whole)

        critic_s = critic[:h_s.shape[0]]
        critic_t = critic[h_s.shape[0]:h_s.shape[0]+h_t.shape[0]]
        wd_loss = critic_s.mean() - critic_t.mean()

        if for_fw==False:
            return wd_loss
        else:
            epsilon = 1e-10 # for stable torch.sqrt
            grad = autograd.grad(critic.sum(),
                [h_whole], create_graph=True)[0]
            grad_penalty = ((torch.sqrt((grad**2).sum(dim=1)+epsilon)-1)**2).mean(dim=0)
            return -wd_loss + self.hparams['grad_penalty'] * grad_penalty
            

    def update(self, minibatches, unlabeled=None):
        objective = 0
        penalty = 0
        nmb = len(minibatches)

        features = [self.featurizer(xi) for xi, _ in minibatches]
        classifs = [self.classifier(fi) for fi in features]
        targets = [yi for _, yi in minibatches]

        features_target = [self.featurizer(xit) for xit in unlabeled]
        total_features = features+features_target
        total_d = len(total_features)

        for _ in range(self.hparams['wd_steps_per_step']):
            # train fw
            fw_loss = 0.0
            for i in range(total_d):
                for j in range(i + 1, total_d):
                    fw_loss += self.wd_loss(total_features[i], total_features[j], True)
            fw_loss /= (total_d * (total_d - 1) / 2)
            self.wd_opt.zero_grad()
            fw_loss.backward(retain_graph=True)
            self.wd_opt.step()


        # Train main network
        for i in range(nmb):
            objective += F.cross_entropy(classifs[i], targets[i])
        for i in range(total_d):
            for j in range(i + 1, total_d):
                penalty += self.wd_loss(total_features[i], total_features[j],False)

        objective /= nmb
        if nmb > 1:
            penalty /= (total_d * (total_d - 1) / 2)

        self.main_opt.zero_grad()
        (objective + (self.hparams['lambda_wd']*penalty)).backward()
        self.main_opt.step()

        if torch.is_tensor(penalty):
            penalty = penalty.item()

        return {'loss': objective.item(), 'penalty': penalty}

    def predict(self, x):
        return self.classifier(self.featurizer(x))


class DANN(Algorithm):
    """Domain-Adversarial Neural Networks"""

    def __init__(self, input_shape, num_classes, num_domains,
                 hparams, class_balance=False):

        super(DANN, self).__init__(input_shape, num_classes, num_domains,
                                  hparams)

        self.register_buffer('update_count', torch.tensor([0]))
        self.class_balance = class_balance

        # Algorithms
        self.featurizer = networks.Featurizer(input_shape, self.hparams)
        self.classifier = networks.Classifier(
            self.featurizer.n_outputs,
            num_classes,
            self.hparams['nonlinear_classifier'])
        self.discriminator = networks.MLP(self.featurizer.n_outputs,
            num_domains, self.hparams)
        self.class_embeddings = nn.Embedding(num_classes,
            self.featurizer.n_outputs)

        # Optimizers
        self.disc_opt = torch.optim.Adam(
            (list(self.discriminator.parameters()) +
                list(self.class_embeddings.parameters())),
            lr=self.hparams["lr_d"],
            weight_decay=self.hparams['weight_decay_d'],
            betas=(self.hparams['beta1'], 0.9))

        self.gen_opt = torch.optim.Adam(
            (list(self.featurizer.parameters()) +
                list(self.classifier.parameters())),
            lr=self.hparams["lr_g"],
            weight_decay=self.hparams['weight_decay_g'],
            betas=(self.hparams['beta1'], 0.9))

    def update(self, minibatches, unlabeled=None):
        device = "cuda" if minibatches[0][0].is_cuda else "cpu"
        self.update_count += 1
        x_each_domain = [x for x,y in minibatches] + unlabeled
        x = torch.cat([x for x,y in minibatches])
        y = torch.cat([y for x,y in minibatches])
        x_target = torch.cat(unlabeled)
        total_x = torch.cat([x,x_target])
        total_z = self.featurizer(total_x)

        z = total_z[:x.shape[0]]

        disc_input = total_z
        disc_out = self.discriminator(disc_input)
        disc_labels = torch.cat([
            torch.full((x.shape[0], ), i, dtype=torch.int64, device=device)
            for i, x in enumerate(x_each_domain)
        ])

        if self.class_balance:
            y_counts = F.one_hot(all_y).sum(dim=0)
            weights = 1. / (y_counts[all_y] * y_counts.shape[0]).float()
            disc_loss = F.cross_entropy(disc_out, disc_labels, reduction='none')
            disc_loss = (weights * disc_loss).sum()
        else:
            disc_loss = F.cross_entropy(disc_out, disc_labels)

        disc_softmax = F.softmax(disc_out, dim=1)
        input_grad = autograd.grad(disc_softmax[:, disc_labels].sum(),
            [disc_input], create_graph=True)[0]
        grad_penalty = (input_grad**2).sum(dim=1).mean(dim=0)
        disc_loss += self.hparams['grad_penalty'] * grad_penalty

        d_steps_per_g = self.hparams['d_steps_per_g_step']
        if (self.update_count.item() % (1+d_steps_per_g) < d_steps_per_g):

            self.disc_opt.zero_grad()
            disc_loss.backward()
            self.disc_opt.step()
            return {'disc_loss': disc_loss.item()}
        else:
            preds = self.classifier(z)
            classifier_loss = F.cross_entropy(preds, y)
            gen_loss = (classifier_loss +
                        (self.hparams['lambda'] * -disc_loss))
            self.disc_opt.zero_grad()
            self.gen_opt.zero_grad()
            gen_loss.backward()
            self.gen_opt.step()
            return {'gen_loss': gen_loss.item()}

    def predict(self, x):
        return self.classifier(self.featurizer(x))



class AbstractMMD(ERM):
    """
    Perform ERM while matching the pair-wise domain feature distributions
    using MMD (abstract class)
    """
    def __init__(self, input_shape, num_classes, num_domains, hparams, gaussian):
        super(AbstractMMD, self).__init__(input_shape, num_classes, num_domains,
                                  hparams)
        if gaussian:
            self.kernel_type = "gaussian"
        else:
            self.kernel_type = "mean_cov"

    def my_cdist(self, x1, x2):
        x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
        x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
        res = torch.addmm(x2_norm.transpose(-2, -1),
                          x1,
                          x2.transpose(-2, -1), alpha=-2).add_(x1_norm)
        return res.clamp_min_(1e-30)

    def gaussian_kernel(self, x, y, gamma=[0.001, 0.01, 0.1, 1, 10, 100,
                                           1000]):
        D = self.my_cdist(x, y)
        K = torch.zeros_like(D)

        for g in gamma:
            K.add_(torch.exp(D.mul(-g)))

        return K

    def mmd(self, x, y):
        if self.kernel_type == "gaussian":
            Kxx = self.gaussian_kernel(x, x).mean()
            Kyy = self.gaussian_kernel(y, y).mean()
            Kxy = self.gaussian_kernel(x, y).mean()
            return Kxx + Kyy - 2 * Kxy
        else:
            mean_x = x.mean(0, keepdim=True)
            mean_y = y.mean(0, keepdim=True)
            cent_x = x - mean_x
            cent_y = y - mean_y
            cova_x = (cent_x.t() @ cent_x) / (len(x) - 1)
            cova_y = (cent_y.t() @ cent_y) / (len(y) - 1)

            mean_diff = (mean_x - mean_y).pow(2).mean()
            cova_diff = (cova_x - cova_y).pow(2).mean()

            return mean_diff + cova_diff

    def update(self, minibatches, unlabeled=None):
        objective = 0
        penalty = 0
        nmb = len(minibatches)

        features = [self.featurizer(xi) for xi, _ in minibatches]
        classifs = [self.classifier(fi) for fi in features]
        targets = [yi for _, yi in minibatches]

        for i in range(nmb):
            objective += F.cross_entropy(classifs[i], targets[i])
            for j in range(i + 1, nmb):
                penalty += self.mmd(features[i], features[j])

        objective /= nmb
        if nmb > 1:
            penalty /= (nmb * (nmb - 1) / 2)

        self.optimizer.zero_grad()
        (objective + (self.hparams['mmd_gamma']*penalty)).backward()
        self.optimizer.step()

        if torch.is_tensor(penalty):
            penalty = penalty.item()

        return {'loss': objective.item(), 'penalty': penalty}


class MMD(AbstractMMD):
    """
    MMD using Gaussian kernel
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(MMD, self).__init__(input_shape, num_classes,
                                          num_domains, hparams, gaussian=True)


class CORAL(AbstractMMD):
    """
    MMD using mean and covariance difference
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(CORAL, self).__init__(input_shape, num_classes,
                                         num_domains, hparams, gaussian=False)


