from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler, SGD
from torch.optim.lr_scheduler import LambdaLR
import math
from typing import Optional, List, Dict, Tuple, Callable
import numpy as np
import pdb
import support_alignment.core.classifiers


from support_alignment.core import (
    discriminators,
    classifiers,
    vat,
    feature_extractors,
    utils,
    importance_weighting,
)


def outer_product(z, y_softmax):
    """Flattened per-sample outer product of class probabilities and features.

    Args:
        z: features, indexed as (batch, feature_dim).
        y_softmax: class probabilities, indexed as (batch, num_classes).

    Returns:
        Tensor of shape (batch, num_classes * feature_dim) whose row b is
        y_softmax[b, c] * z[b, d] flattened with the class index varying slowest.
    """
    n_classes = y_softmax.size(1)
    n_features = z.size(1)
    # (B, C, 1) @ (B, 1, D) -> (B, C, D)
    per_sample = torch.bmm(y_softmax.unsqueeze(2), z.unsqueeze(1))
    return per_sample.view(-1, n_classes * n_features)


def create_linear_decay_lr_schedule(decay_start, decay_steps, decay_factor):
    """Build a multiplicative LR schedule suitable for ``LambdaLR``.

    The returned function yields 1.0 up to ``decay_start``, then moves linearly
    to ``decay_factor`` over ``decay_steps`` steps and stays there afterwards.
    """

    def schedule(step):
        progress = (step - decay_start) / decay_steps
        # Clamp the linear progress into [0, 1].
        progress = min(max(0.0, progress), 1.0)
        return 1.0 + progress * (decay_factor - 1.0)

    return schedule


def calc_coeff(iter_num, high=1.0, low=0.0, alpha=1.0, max_iter=1000.0):
    """Sigmoid-shaped ramp from ``low`` (at iter_num=0) toward ``high``.

    Computes ``2 * (high - low) / (1 + exp(-alpha * iter_num / max_iter))
    - (high - low) + low`` and returns it as a Python float.
    """
    span = high - low
    ramp = 2.0 * span / (1.0 + np.exp(-alpha * iter_num / max_iter))
    coeff = float(ramp - span + low)
    assert coeff <= high
    return coeff


class RandomLayer(nn.Module):
    """Random multilinear projection of several inputs.

    Input ``i`` (width ``input_dim_list[i]``) is projected to ``output_dim``
    features by a fixed Gaussian random matrix. The projections are multiplied
    element-wise, and the result is scaled by ``output_dim ** (-1 / n_inputs)``.
    This file uses it for conditional (CDAN-style) alignment.

    NOTE: the random matrices are plain tensors kept in a Python list. They are
    not registered parameters or buffers, so move them with this layer's own
    ``cuda()``.
    """

    def __init__(self, input_dim_list=None, output_dim=1024):
        super(RandomLayer, self).__init__()
        # None instead of a mutable [] default; the list is only read here.
        if input_dim_list is None:
            input_dim_list = []
        self.input_num = len(input_dim_list)
        self.output_dim = output_dim
        self.random_matrix = [torch.randn(output_dim, dim) for dim in input_dim_list]

    def forward(self, input_list):
        """Project and multiply the inputs.

        Args:
            input_list: tensors indexed as (batch, input_dim_list[i]), one per
                random matrix.

        Returns:
            Tensor of shape (batch, output_dim).
        """
        # x @ M^T equals the original bmm of M (expanded to the batch) with
        # x[..., None], without materialising a per-sample batch of matrices.
        projections = [
            input_list[i] @ matrix.t() for i, matrix in enumerate(self.random_matrix)
        ]
        result = projections[0] / math.pow(
            float(self.output_dim), 1.0 / len(projections)
        )
        for other in projections[1:]:
            result = torch.mul(result, other)
        return result.view(-1, self.output_dim)

    def cuda(self, device=None):
        """Move the module and its random matrices to CUDA; returns self like nn.Module.cuda."""
        super(RandomLayer, self).cuda(device)
        self.random_matrix = [matrix.cuda(device) for matrix in self.random_matrix]
        return self


# Adapted from https://github.com/facebookresearch/DomainBed/blob/master/domainbed/algorithms.py
class Algorithm(torch.nn.Module):
    """
    A subclass of Algorithm implements a domain adaptation algorithm.
    Subclasses should implement the following:
    - update()
    - predict()

    Subclasses that call ``setup_schedulers`` must first define ``fx_opt`` and
    ``cls_opt`` (and optionally ``disc_opt``).
    """

    def __init__(self, input_shape, num_classes, hparams, data_params):
        super(Algorithm, self).__init__()
        self.hparams = hparams
        self.data_params = data_params
        self.input_shape = input_shape
        self.num_classes = num_classes

        # Every optimizer appended here is covered by (load_)optimizers_state.
        self.optimizers = []

    def update(self, iterator):
        """
        Perform one update step, given an iterator which yields
        a labeled mini-batch from source environment
        and an unlabeled mini-batch from target environment
        """
        raise NotImplementedError

    def predict(self, x):
        raise NotImplementedError

    def optimizers_state(self):
        """Return the state dicts of all registered optimizers, in registration order."""
        return [opt.state_dict() for opt in self.optimizers]

    def load_optimizers_state(self, optimizers_state):
        """Load state dicts produced by ``optimizers_state`` (matched by position)."""
        for opt, opt_state in zip(self.optimizers, optimizers_state):
            opt.load_state_dict(opt_state)

    def setup_schedulers(self):
        """Create LR schedulers according to ``hparams["lr_type"]``.

        * "decay": LambdaLR on fx_opt and cls_opt. It uses a linear decay when
          ``fx_lr_decay_start`` is set and a constant factor of 1 otherwise.
          No discriminator scheduler is created.
        * "inv": utils.StepwiseLR on fx_opt, cls_opt and, if present, disc_opt.

        ``disc_lr_scheduler`` is always defined; it is None when no
        discriminator scheduler is created.

        Raises:
            ValueError: if ``lr_type`` is not one of the values above. Before
                this check, no schedulers were created and ``update()`` failed
                later with an AttributeError.
        """
        lr_type = self.hparams["lr_type"]
        self.disc_lr_scheduler = None

        if lr_type == "decay":
            self.fx_lr_decay_start = self.hparams["fx_lr_decay_start"]
            self.fx_lr_decay_steps = self.hparams["fx_lr_decay_steps"]
            self.fx_lr_decay_factor = self.hparams["fx_lr_decay_factor"]

            lr_schedule = lambda step: 1.0
            if self.fx_lr_decay_start is not None:
                lr_schedule = create_linear_decay_lr_schedule(
                    self.fx_lr_decay_start,
                    self.fx_lr_decay_steps,
                    self.fx_lr_decay_factor,
                )

            self.fx_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
                self.fx_opt, lr_schedule
            )
            self.cls_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
                self.cls_opt, lr_schedule
            )
        elif lr_type == "inv":
            self.fx_lr_scheduler = utils.StepwiseLR(
                self.fx_opt, init_lr=self.hparams["fx_opt"]["kwargs"]["lr"]
            )
            self.cls_lr_scheduler = utils.StepwiseLR(
                self.cls_opt, init_lr=self.hparams["cls_opt"]["kwargs"]["lr"]
            )
            if hasattr(self, "disc_opt"):
                self.disc_lr_scheduler = utils.StepwiseLR(
                    self.disc_opt, init_lr=self.hparams["disc_opt"]["kwargs"]["lr"]
                )
        else:
            raise ValueError(f"Unknown lr_type {lr_type}")


class DANetwork(nn.Module):
    """Feature extractor followed by a classifier, both built by name.

    config: {
        feature_extractor: {
            name -- attribute name in the ``feature_extractors`` module
            hparams -- feature extractor hparams dict
        }
        classifier: {
            name -- attribute name in the ``classifiers`` module
            hparams -- classifier hparams dict
        }
    }

    The classifier input width is taken from ``feature_extractor.n_outputs``.
    """

    def __init__(self, input_shape, num_classes, config):
        super().__init__()
        fx_cfg = config["feature_extractor"]
        cls_cfg = config["classifier"]

        make_fx = self._lookup(feature_extractors, fx_cfg["name"], "feature_extractor")
        self.feature_extractor = make_fx(input_shape, hparams=fx_cfg["hparams"])

        make_cls = self._lookup(classifiers, cls_cfg["name"], "classifier")
        self.classifier = make_cls(
            self.feature_extractor.n_outputs,
            num_classes,
            hparams=cls_cfg["hparams"],
        )

    @staticmethod
    def _lookup(registry, name, kind):
        """Return ``registry.<name>`` or raise ValueError naming the unknown ``kind``."""
        factory = getattr(registry, name, None)
        if factory is None:
            raise ValueError(f"Unknown {kind} {name}")
        return factory

    def forward(self, x):
        features = self.feature_extractor(x)
        return self.classifier(features)


def process_param_groups(param_groups, optim_config):
    """Turn factor-annotated parameter groups into torch optimizer param groups.

    Each input group is a dict with "params" and optional "lr_factor" /
    "wd_factor" (default 1.0). The output group gets absolute "lr" and
    "weight_decay", computed from ``optim_config["kwargs"]`` ("lr" is required;
    "weight_decay" defaults to 0.0). The factors are also kept on the group.
    """
    kwargs = optim_config["kwargs"]
    lr_base = kwargs["lr"]
    wd_base = kwargs.get("weight_decay", 0.0)

    def _expand(group):
        lr_mult = group.get("lr_factor", 1.0)
        wd_mult = group.get("wd_factor", 1.0)
        return {
            "params": group["params"],
            "lr": lr_base * lr_mult,
            "weight_decay": wd_base * wd_mult,
            "lr_factor": lr_mult,
            "wd_factor": wd_mult,
        }

    return [_expand(group) for group in param_groups]


def get_optimizer(params, optim_config):
    """Instantiate ``torch.optim.<optim_config["name"]>(params, **optim_config["kwargs"])``.

    Raises:
        ValueError: if ``torch.optim`` has no attribute with that name.
    """
    optim_name = optim_config["name"]
    optim_cls = getattr(torch.optim, optim_name, None)
    if optim_cls is not None:
        return optim_cls(params, **optim_config["kwargs"])
    raise ValueError(f"Unknown optimizer {optim_name}")


# Adapted from https://github.com/facebookresearch/DomainBed/blob/master/domainbed/algorithms.py
class ERM(Algorithm):
    """
    Empirical Risk Minimization (ERM): trains feature extractor + classifier
    with cross-entropy on the labeled source batch; the target batch is ignored.

    hparams: {
        da_network -- DANetwork config dict

        fx_opt {
            name -- feature extractor optimizer name
            kwargs -- feature extractor kwargs dict
        }

        cls_opt {
            name -- classifier optimizer name
            kwargs -- classifier kwargs dict
        }

        ema_momentum -- momentum of the exponential moving average (EMA) of the
                        feature_extractor + classifier weights (None => no EMA)

        lr_type, fx_lr_decay_start, fx_lr_decay_steps, fx_lr_decay_factor --
            learning-rate schedule settings read by Algorithm.setup_schedulers
    }
    """

    def __init__(self, input_shape, num_classes, hparams, data_params):
        super().__init__(input_shape, num_classes, hparams, data_params)

        self.network = DANetwork(input_shape, num_classes, self.hparams["da_network"])
        self.ema_network = self._build_ema_network()

        self.fx_opt = get_optimizer(
            self._feature_extractor_params(), self.hparams["fx_opt"]
        )
        self.optimizers.append(self.fx_opt)

        trainable_cls = [
            p for p in self.network.classifier.parameters() if p.requires_grad
        ]
        self.cls_opt = get_optimizer(trainable_cls, self.hparams["cls_opt"])
        self.optimizers.append(self.cls_opt)

        self.setup_schedulers()

    def _build_ema_network(self):
        """Return a frozen EMA copy of ``self.network``, or None when EMA is disabled."""
        momentum = self.hparams["ema_momentum"]
        if momentum is None:
            return None

        def ema_avg_fn(averaged_model_parameter, model_parameter, num_averaged):
            return momentum * averaged_model_parameter + (1.0 - momentum) * model_parameter

        ema = torch.optim.swa_utils.AveragedModel(self.network, avg_fn=ema_avg_fn)
        for frozen in ema.parameters():
            frozen.requires_grad = False
        return ema

    def _feature_extractor_params(self):
        """Use the extractor's own param groups when it defines them; else its trainable params."""
        extractor = self.network.feature_extractor
        if hasattr(extractor, "param_groups"):
            return process_param_groups(extractor.param_groups, self.hparams["fx_opt"])
        return [p for p in extractor.parameters() if p.requires_grad]

    def update(self, iterator):
        """One source-only cross-entropy step; returns (stats, extra_stats)."""
        x_src, y_src, _ = next(iterator)
        logits_src = self.network.classifier(self.network.feature_extractor(x_src))
        loss = torch.nn.functional.cross_entropy(logits_src, y_src)

        self.fx_opt.zero_grad()
        self.cls_opt.zero_grad()
        loss.backward()
        for opt, sched in (
            (self.fx_opt, self.fx_lr_scheduler),
            (self.cls_opt, self.cls_lr_scheduler),
        ):
            opt.step()
            sched.step()

        if self.ema_network is not None:
            self.ema_network.update_parameters(self.network)

        stats = OrderedDict(
            [
                ("lr", self.fx_lr_scheduler.get_last_lr()[0]),
                ("cls_lr", self.cls_lr_scheduler.get_last_lr()[0]),
                ("c_loss", loss.item()),
            ]
        )
        return stats, OrderedDict()

    def _inference_network(self):
        """The EMA network when enabled, otherwise the live network."""
        return self.network if self.ema_network is None else self.ema_network

    def predict(self, x):
        return self._inference_network()(x)

    def update_bn(self, loader, device):
        utils.update_bn(loader, self._inference_network(), device)


# Adapted from https://github.com/korawat-tanwisuth/Proto_DA
class ProtoLoss(nn.Module):
    """
    Bi-directional prototype/feature transport loss (used by PCT below).

    Parameters:
        - **nav_t** (float): temperature parameter (1 for all experiments)
        - **beta** (float): momentum used to update the running target class-proportion estimate
        - **num_classes** (int): total number of classes
        - **device** (torch.device): device on which the proportion estimate is created
        - **s_par** (float, optional): coefficient in front of the bi-directional loss. 0.5 corresponds to pct. 1 corresponds to using only t to mu. 0 corresponds to only using mu to t.
        - **reduction** (str, optional): accepted for API compatibility; currently unused.

    Inputs: mu_s, f_t
        - **mu_s** (tensor): weight matrix of the linear classifier, :math:`mu^s`
        - **f_t** (tensor): feature representations on target domain, :math:`f^t`

    Shape:
        - mu_s: :math:`(K, F)`, f_t: :math:`(M, F)` where F means the dimension of input features.
        - output: scalar.

    The running class-proportion estimate ``prop`` has shape (K, 1). It is a
    non-persistent buffer: it follows ``.to()`` / ``.cuda()`` / ``.double()`` of
    the module, but it is not stored in ``state_dict()``, so checkpoint keys
    are unchanged.
    """

    def __init__(
        self,
        nav_t: float,
        beta: float,
        num_classes: int,
        device: torch.device,
        s_par: Optional[float] = 0.5,
        reduction: Optional[str] = "mean",
    ):
        super(ProtoLoss, self).__init__()
        self.nav_t = nav_t
        self.s_par = s_par
        self.beta = beta
        # Registered as a buffer so module-level device/dtype moves apply to it.
        self.register_buffer(
            "prop",
            (torch.ones((num_classes, 1)) * (1 / num_classes)).to(device),
            persistent=False,
        )
        self.eps = 1e-6

    def pairwise_cosine_dist(self, x, y):
        """1 - cosine similarity between every row of x and every row of y."""
        x = F.normalize(x, p=2, dim=1)
        y = F.normalize(y, p=2, dim=1)
        return 1 - torch.matmul(x, y.T)

    def get_pos_logits(self, sim_mat, prop):
        """Temperature-scaled similarities plus the log class prior (broadcast over columns)."""
        log_prior = torch.log(prop + self.eps)
        return sim_mat / self.nav_t + log_prior

    def update_prop(self, prop):
        """Exponential moving average of the class-proportion estimate."""
        return (1 - self.beta) * self.prop + self.beta * prop

    def forward(self, mu_s: torch.Tensor, f_t: torch.Tensor) -> torch.Tensor:
        # Update proportions (computed from detached similarities, so no gradient).
        sim_mat = torch.matmul(mu_s, f_t.T)  # (K, M)
        old_logits = self.get_pos_logits(sim_mat.detach(), self.prop)
        s_dist_old = F.softmax(old_logits, dim=0)
        prop = s_dist_old.mean(1, keepdim=True)
        self.prop = self.update_prop(prop)

        # Calculate bi-directional transport loss
        new_logits = self.get_pos_logits(sim_mat, self.prop)
        s_dist = F.softmax(new_logits, dim=0)  # normalised over prototypes
        t_dist = F.softmax(sim_mat / self.nav_t, dim=1)  # normalised over target samples
        cost_mat = self.pairwise_cosine_dist(mu_s, f_t)
        source_loss = (self.s_par * cost_mat * s_dist).sum(0).mean()
        target_loss = (
            ((1 - self.s_par) * cost_mat * t_dist).sum(1) * self.prop.squeeze(1)
        ).sum()
        loss = source_loss + target_loss
        return loss


class PCT(ERM):
    """ERM plus the ProtoLoss prototype/target-feature transport term.

    Extra hparams: beta, lr_gamma (beta schedule), nav_t, s_par (ProtoLoss),
    trade_off (weight of the transfer loss).
    """

    def __init__(self, input_shape, num_classes, hparams, data_params, device="cuda"):
        super(PCT, self).__init__(input_shape, num_classes, hparams, data_params)
        # NOTE(review): beta_scheduler is only read via get_lr() and never stepped
        # in this class, so beta presumably stays at hparams["beta"]. Confirm this
        # against utils.StepwiseLR and the intended schedule.
        self.beta_scheduler = utils.StepwiseLR(
            None,
            init_lr=self.hparams["beta"],
            gamma=self.hparams["lr_gamma"],
            decay_rate=0.75,
        )
        # define loss function
        self.domain_loss = ProtoLoss(
            self.hparams["nav_t"],
            self.hparams["beta"],
            num_classes,
            device,
            self.hparams["s_par"],
        ).to(device)

    def _source_prototypes(self):
        """Detached copy of the final linear layer's weight (one row per class)."""
        cls_hparams = self.hparams["da_network"]["classifier"]["hparams"]
        if cls_hparams["num_hidden"] is not None:
            last_linear = self.network.classifier.net[-1]
        else:
            last_linear = self.network.classifier.net
        return last_linear.weight.detach().clone()

    def update(self, iterator):
        """One step on cross-entropy + trade_off * transfer loss; returns (stats, extra_stats)."""
        x_src, y_src, x_trg = next(iterator)
        z_src = self.network.feature_extractor(x_src)
        z_trg = self.network.feature_extractor(x_trg)
        logits_src = self.network.classifier(z_src)

        prototypes_s = self._source_prototypes()

        cls_loss = torch.nn.functional.cross_entropy(logits_src, y_src)

        self.domain_loss.beta = self.beta_scheduler.get_lr()
        transfer_loss = self.domain_loss(prototypes_s, z_trg)
        loss = cls_loss + transfer_loss * self.hparams["trade_off"]

        self.fx_opt.zero_grad()
        self.cls_opt.zero_grad()
        loss.backward()
        self.fx_opt.step()
        self.fx_lr_scheduler.step()
        self.cls_opt.step()
        self.cls_lr_scheduler.step()

        if self.ema_network is not None:
            self.ema_network.update_parameters(self.network)

        stats = OrderedDict()
        extra_stats = OrderedDict()
        stats["lr"] = self.fx_lr_scheduler.get_last_lr()[0]
        stats["cls_lr"] = self.cls_lr_scheduler.get_last_lr()[0]
        # Same convention as AbstractDiscriminatorAlignment: fx_loss is the total
        # objective and c_loss is the source classification loss only.
        stats["fx_loss"] = loss.item()
        stats["c_loss"] = cls_loss.item()
        stats["d_loss"] = transfer_loss.item()
        return stats, extra_stats


def entropy(logits):
    """Batch-mean Shannon entropy (nats) of softmax(logits) over dim 1."""
    probs = torch.softmax(logits, dim=1)
    log_probs = torch.log_softmax(logits, dim=1)
    per_sample = -(probs * log_probs).sum(dim=1)
    return per_sample.mean()


def cdan_entropy(logits):
    """Per-sample Shannon entropy (nats) of softmax(logits) over dim 1 (no reduction)."""
    weighted = torch.softmax(logits, dim=1) * torch.log_softmax(logits, dim=1)
    return weighted.sum(dim=1).neg()


class AbstractDiscriminatorAlignment(Algorithm):
    """
    DiscriminatorAlignment base class.

    Each ``update`` call has two phases:
    1. ``disc_steps`` discriminator updates on detached features.
    2. One feature extractor / classifier update. It uses the last
       discriminator batch, or a fresh batch when disc_steps == 0.

    hparams: {
        da_network -- DANetwork config dict

        discriminator: {
            hparams -- discriminator hparams dict
        }

        fx_opt {
            name -- feature extractor optimizer name
            kwargs -- feature extractor kwargs dict
        }

        cls_opt {
            name -- classifier optimizer name
            kwargs -- classifier kwargs dict
        }

        ema_momentum -- momentum of the exponential moving average (EMA)
                        applied to feature_extractor + classifier (None => no EMA)

        fx_lr_decay_start -- start of feature extractor learning rate decay
        fx_lr_decay_steps -- length of feature extractor learning rate decay
        fx_lr_decay_factor -- feature extractor learning rate decay factor

        cls_weight -- weight of the classifier labeled loss (labeled source)
                      in the feature extractor's objective

        cls_trg_weight -- weight of the classifier unlabeled (entropy) loss in the
                          feature extractor's objective
        cls_trg_weight_anneal -- [optional, default False] ramp cls_trg_weight
                                 linearly over alignment_w_steps updates

        alignment_weight -- weight of the alignment loss in the feature extractor's objective
        alignment_w_steps -- alignment weight annealing steps (None => no annealing)
        cdan_coeff -- [optional, default False] scale alignment_weight by
                      calc_coeff(update_count) instead of the linear ramp

        use_random, random_dim -- [if conditional] project (features, softmax) with a
                                  RandomLayer of width random_dim instead of using
                                  their full outer product

        disc_opt {
            name -- discriminator optimizer name
            kwargs -- discriminator optimizer kwargs dict
        }
        disc_steps -- number of discriminator training steps per one feature extractor step

        l2_weight -- weight of l2-norm regularizer on z (for both source and target)

        vat:   [if vat=True]
        cls_vat_src_weight -- weight of the VAT loss on source domain
        cls_vat_trg_weight -- weight of the VAT loss on target domain
        vat_radius, vat_xi -- VAT loss parameters
    }

    """

    def __init__(
        self,
        input_shape,
        num_classes,
        hparams,
        data_params,
        discriminator_fn,
        use_vat=False,
        conditional=False,
        entropy_weight_disc=False,
        entropy_weight_align=False,
    ):
        super(AbstractDiscriminatorAlignment, self).__init__(
            input_shape, num_classes, hparams, data_params
        )
        # Number of completed update() calls; drives the weight annealing.
        self.register_buffer("update_count", torch.tensor([0]))
        self.entropy_weight_disc = entropy_weight_disc
        self.entropy_weight_align = entropy_weight_align

        # feature_extractor & classifier
        self.network = DANetwork(input_shape, num_classes, self.hparams["da_network"])

        self.ema_network = None
        if self.hparams["ema_momentum"] is not None:
            ema_momentum = self.hparams["ema_momentum"]

            def ema_avg_fn(averaged_model_parameter, model_parameter, num_averaged):
                return (
                    ema_momentum * averaged_model_parameter
                    + (1.0 - ema_momentum) * model_parameter
                )

            self.ema_network = torch.optim.swa_utils.AveragedModel(
                self.network, avg_fn=ema_avg_fn
            )
            for param in self.ema_network.parameters():
                param.requires_grad = False

        if hasattr(self.network.feature_extractor, "param_groups"):
            fx_params = process_param_groups(
                self.network.feature_extractor.param_groups, self.hparams["fx_opt"]
            )
        else:
            fx_params = [
                param
                for param in self.network.feature_extractor.parameters()
                if param.requires_grad
            ]

        self.fx_opt = get_optimizer(fx_params, self.hparams["fx_opt"])
        self.optimizers.append(self.fx_opt)

        cls_params = [
            param
            for param in self.network.classifier.parameters()
            if param.requires_grad
        ]
        self.cls_opt = get_optimizer(cls_params, self.hparams["cls_opt"])
        self.optimizers.append(self.cls_opt)

        self.cls_weight = self.hparams["cls_weight"]

        self.cls_trg_weight = self.hparams["cls_trg_weight"]

        self.alignment_weight = self.hparams["alignment_weight"]
        self.alignment_w_steps = self.hparams["alignment_w_steps"]
        self.cdan_coeff = self.hparams.get("cdan_coeff", False)

        self.conditional = conditional
        dim_in = self.hparams["da_network"]["feature_extractor"]["hparams"][
            "feature_dim"
        ]
        self.random_layer = None
        if self.conditional:
            if self.hparams["use_random"]:
                self.random_layer = RandomLayer(
                    [dim_in, self.num_classes],
                    self.hparams["random_dim"],
                )
                # RandomLayer's matrices must be moved explicitly. Only do so
                # when CUDA exists, so CPU-only runs do not crash here.
                if torch.cuda.is_available():
                    self.random_layer.cuda()
                dim_in = self.hparams["random_dim"]
            else:
                # Full outer product of features and class probabilities.
                dim_in *= num_classes
        self.discriminator = discriminator_fn(
            dim_in,
            self.hparams["discriminator"]["hparams"],
        )

        self.disc_steps = self.hparams["disc_steps"]
        self.disc_opt = get_optimizer(
            self.discriminator.parameters(), self.hparams["disc_opt"]
        )
        self.optimizers.append(self.disc_opt)

        self.l2_weight = self.hparams["l2_weight"]

        self.vat_loss = None
        if use_vat:
            self.cls_vat_src_weight = self.hparams["cls_vat_src_weight"]
            self.cls_vat_trg_weight = self.hparams["cls_vat_trg_weight"]
            self.vat_loss = vat.VATLoss(
                radius=self.hparams["vat_radius"], xi=self.hparams["vat_xi"]
            )

        self.setup_schedulers()

    def _discriminator_inputs(self, z_src, z_trg, logits_src, logits_trg):
        """Return the (source, target) inputs for the discriminator.

        Unconditional: the features themselves. Conditional: features combined
        with detached class probabilities, via RandomLayer or outer_product.
        """
        if not self.conditional:
            return z_src, z_trg
        softmax_src = F.softmax(logits_src, dim=-1).detach()
        softmax_trg = F.softmax(logits_trg, dim=-1).detach()
        if self.random_layer is not None:
            return (
                self.random_layer.forward([z_src, softmax_src]),
                self.random_layer.forward([z_trg, softmax_trg]),
            )
        return outer_product(z_src, softmax_src), outer_product(z_trg, softmax_trg)

    @staticmethod
    def _entropy_weights(logits_src, logits_trg):
        """Per-sample weights 1 + exp(-entropy) for source and target batches."""
        return (
            1.0 + torch.exp(-cdan_entropy(logits_src)),
            1.0 + torch.exp(-cdan_entropy(logits_trg)),
        )

    def _current_alignment_weight(self):
        """Alignment weight after the calc_coeff ramp or the linear alignment_w_steps ramp."""
        if self.cdan_coeff:
            return calc_coeff(self.update_count.item()) * self.alignment_weight
        if self.alignment_w_steps is not None:
            t = min(self.update_count.item() / self.alignment_w_steps, 1.0)
            return t * self.alignment_weight
        return self.alignment_weight

    def _current_cls_trg_weight(self):
        """Target-entropy weight, linearly annealed only if requested and steps are set."""
        # Guard on alignment_w_steps like the alignment ramp does; previously a
        # None value raised a TypeError here.
        if (
            self.hparams.get("cls_trg_weight_anneal", False)
            and self.alignment_w_steps is not None
        ):
            t = min(self.update_count.item() / self.alignment_w_steps, 1.0)
            return t * self.cls_trg_weight
        return self.cls_trg_weight

    def disc_update(self, iterator):
        """Run ``disc_steps`` discriminator updates.

        Returns:
            (last batch or None if disc_steps == 0, stats, extra_stats). The
            stats come from the last step.
        """
        disc_stats = OrderedDict()
        disc_extra_stats = OrderedDict()

        batches = None

        weight_src = None
        weight_trg = None

        for i in range(self.disc_steps):
            batches = next(iterator)
            x_src, _, x_trg = batches

            # Detach feature_extractor's outputs so that the gradients w.r.t. feature_extractor are not computed
            with torch.no_grad():
                z_src = self.network.feature_extractor(x_src).detach()
                z_trg = self.network.feature_extractor(x_trg).detach()

                logits_src = self.network.classifier(z_src).detach()
                logits_trg = self.network.classifier(z_trg).detach()

                disc_in_src, disc_in_trg = self._discriminator_inputs(
                    z_src, z_trg, logits_src, logits_trg
                )

                if self.entropy_weight_disc:
                    weight_src, weight_trg = self._entropy_weights(
                        logits_src, logits_trg
                    )

            # update_history is False only on the final step. update() reuses
            # this same batch in alignment_loss with update_history=True.
            (
                disc_log_loss,
                disc_grad_loss,
                disc_stats,
                disc_extra_stats,
            ) = self.discriminator.disc_loss(
                disc_in_src,
                disc_in_trg,
                update_history=i != self.disc_steps - 1,
                weight_src=weight_src,
                weight_trg=weight_trg,
            )

            disc_loss = (
                disc_log_loss + self.discriminator.grad_penalty_weight * disc_grad_loss
            )

            self.disc_opt.zero_grad()
            disc_loss.backward()
            self.disc_opt.step()
            if self.disc_lr_scheduler is not None:
                self.disc_lr_scheduler.step()
        return batches, disc_stats, disc_extra_stats

    def update(self, iterator):
        """Discriminator phase, then one feature extractor / classifier step.

        Returns:
            (stats, extra_stats) OrderedDicts.
        """
        batches, disc_stats, disc_extra_stats = self.disc_update(iterator)

        x_src, y_src, x_trg = batches if batches is not None else next(iterator)

        z_src = self.network.feature_extractor(x_src)
        z_trg = self.network.feature_extractor(x_trg)

        logits_src = self.network.classifier(z_src)
        logits_trg = self.network.classifier(z_trg)

        disc_in_src, disc_in_trg = self._discriminator_inputs(
            z_src, z_trg, logits_src, logits_trg
        )

        weight_src = None
        weight_trg = None
        if self.entropy_weight_align:
            weight_src, weight_trg = self._entropy_weights(logits_src, logits_trg)

        alignment_weight = self._current_alignment_weight()

        alignment_loss, alignment_stats = self.discriminator.alignment_loss(
            disc_in_src,
            disc_in_trg,
            update_history=True,
            weight_src=weight_src,
            weight_trg=weight_trg,
        )

        cls_loss_src = torch.nn.functional.cross_entropy(logits_src, y_src)

        fx_loss = self.cls_weight * cls_loss_src + alignment_weight * alignment_loss

        if self.cls_trg_weight > 0:
            fx_loss += self._current_cls_trg_weight() * entropy(logits_trg)

        cls_vat_src = None
        cls_vat_trg = None

        if self.vat_loss is not None:
            cls_vat_src, _ = self.vat_loss(self.network, x_src)
            cls_vat_trg, _ = self.vat_loss(self.network, x_trg)

            fx_loss += self.cls_vat_src_weight * cls_vat_src
            fx_loss += self.cls_vat_trg_weight * cls_vat_trg

        self.fx_opt.zero_grad()
        self.cls_opt.zero_grad()

        fx_loss.backward()
        self.fx_opt.step()
        self.fx_lr_scheduler.step()
        self.cls_opt.step()
        self.cls_lr_scheduler.step()

        if self.ema_network is not None:
            self.ema_network.update_parameters(self.network)

        stats = OrderedDict()
        extra_stats = OrderedDict()
        stats["lr"] = self.fx_lr_scheduler.get_last_lr()[0]
        stats["cls_lr"] = self.cls_lr_scheduler.get_last_lr()[0]
        stats["fx_loss"] = fx_loss.item()
        stats["c_loss"] = cls_loss_src.item()
        if cls_vat_src is not None:
            stats["c_vat_src"] = cls_vat_src.item()
        if cls_vat_trg is not None:
            stats["c_vat_trg"] = cls_vat_trg.item()

        stats.update(disc_stats)
        stats["a_loss"] = alignment_loss.item()
        stats["a_w"] = alignment_weight

        extra_stats.update(disc_extra_stats)
        extra_stats.update(alignment_stats)

        self.update_count += 1
        return stats, extra_stats

    def predict(self, x):
        network = self.network if self.ema_network is None else self.ema_network
        return network(x)


class DANN_NS(AbstractDiscriminatorAlignment):
    """Unconditional alignment with ``discriminators.LogLossNSDiscriminator`` (no VAT)."""

    def __init__(self, input_shape, num_classes, hparams, data_params):
        super().__init__(
            input_shape,
            num_classes,
            hparams,
            data_params,
            discriminator_fn=discriminators.LogLossNSDiscriminator,
        )


class VADA(AbstractDiscriminatorAlignment):
    """Unconditional alignment with ``LogLossNSDiscriminator`` plus VAT losses on both domains."""

    def __init__(self, input_shape, num_classes, hparams, data_params):
        super().__init__(
            input_shape,
            num_classes,
            hparams,
            data_params,
            discriminator_fn=discriminators.LogLossNSDiscriminator,
            use_vat=True,
        )


class DANN_SUPP_SQ(AbstractDiscriminatorAlignment):
    """Unconditional alignment with ``discriminators.SupportLossSqDiscriminator``."""

    def __init__(self, input_shape, num_classes, hparams, data_params):
        super().__init__(
            input_shape,
            num_classes,
            hparams,
            data_params,
            discriminator_fn=discriminators.SupportLossSqDiscriminator,
        )


class DANN_SUPP_ABS(AbstractDiscriminatorAlignment):
    """Unconditional alignment with ``discriminators.SupportLossAbsDiscriminator``."""

    def __init__(self, input_shape, num_classes, hparams, data_params):
        super().__init__(
            input_shape,
            num_classes,
            hparams,
            data_params,
            discriminator_fn=discriminators.SupportLossAbsDiscriminator,
        )


class CDAN_SUPP_ABS_E(AbstractDiscriminatorAlignment):
    """Conditional alignment with ``SupportLossAbsDiscriminator``.

    Entropy weighting is on for both the discriminator and the alignment loss.
    VAT is enabled by ``hparams["use_vat"]``. The ``entropy_weight`` argument
    is accepted but ignored; entropy weighting is always enabled here.
    """

    def __init__(
        self, input_shape, num_classes, hparams, data_params, entropy_weight=False
    ):
        super().__init__(
            input_shape,
            num_classes,
            hparams,
            data_params,
            discriminator_fn=discriminators.SupportLossAbsDiscriminator,
            conditional=True,
            entropy_weight_align=True,
            entropy_weight_disc=True,
            use_vat=hparams["use_vat"],
        )


class CDAN(AbstractDiscriminatorAlignment):
    """Conditional alignment with ``LogLossNSDiscriminator``; no entropy weighting, no VAT."""

    def __init__(self, input_shape, num_classes, hparams, data_params):
        super().__init__(
            input_shape,
            num_classes,
            hparams,
            data_params,
            discriminator_fn=discriminators.LogLossNSDiscriminator,
            conditional=True,
            entropy_weight_align=False,
            entropy_weight_disc=False,
        )


class CDAN_SUPP_SQ_E(AbstractDiscriminatorAlignment):
    """Conditional alignment with ``SupportLossSqDiscriminator``.

    Entropy weighting is on for both the discriminator and the alignment loss.
    VAT is enabled by ``hparams["use_vat"]``. The ``entropy_weight`` argument
    is accepted but ignored; entropy weighting is always enabled here.
    """

    def __init__(
        self, input_shape, num_classes, hparams, data_params, entropy_weight=False
    ):
        super().__init__(
            input_shape,
            num_classes,
            hparams,
            data_params,
            discriminator_fn=discriminators.SupportLossSqDiscriminator,
            conditional=True,
            entropy_weight_align=True,
            entropy_weight_disc=True,
            use_vat=hparams["use_vat"],
        )


class SDANN(AbstractDiscriminatorAlignment):
    """Unconditional alignment with ``discriminators.SBetaDiscriminator``."""

    def __init__(self, input_shape, num_classes, hparams, data_params):
        super().__init__(
            input_shape,
            num_classes,
            hparams,
            data_params,
            discriminator_fn=discriminators.SBetaDiscriminator,
        )


class IWBase(Algorithm):
    """
    Implements IWDAN/IWCDAN proposed in
    Domain Adaptation with Conditional Distribution Matching and Generalized Label Shift
    https://arxiv.org/abs/2003.04475

    Notes:
        * does not use unlabeled classifier loss (e.g. entropy)
        * each call to ``update`` first runs ``disc_steps`` discriminator
          steps (one fresh batch each); the last of those batches is reused
          for the feature extractor / classifier step. With ``disc_steps == 0``
          a single batch is drawn for that step.

    hparams: {
        da_network -- DANetwork config dict classifier.name must be IWClassifier

        discriminator: {
            hparams -- discriminator hparams dict
        }
        disc_steps -- number of discriminator updates per call to ``update``

        fx_opt {
            name -- feature extractor optimizer name
            kwargs -- feature extractor kwargs dict
        }

        cls_opt {
            name -- classifier optimizer name
            kwargs -- classifier kwargs dict
        }

        fx_lr_decay_start -- start of feature extractor learning rate decay
        fx_lr_decay_steps -- length of feature extractor learning rate decay
        fx_lr_decay_factor -- feature extractor learning rate decay factor

        cls_weight -- weight of the classifier loss (labeled source)
                      in the feature extractor's objective

        importance_weighting -- importance_weighting hparams dict
        iw_update_period -- importance weight update period (in steps)

        alignment_weight -- weight of the alignment loss in the feature extractor's objective
        alignment_w_steps -- alignment weight annealing steps (None or 0: no annealing)

        use_random -- (conditional only) feed the discriminator a RandomLayer
                      projection instead of the feature/softmax outer product
        random_dim -- (conditional + use_random only) RandomLayer output size

        disc_opt {
            name -- discriminator optimizer name
            kwargs -- discriminator optimizer kwargs dict
        }}

    """

    def __init__(
        self,
        input_shape,
        num_classes,
        hparams,
        data_params,
        conditional=False,
        discriminator_fn=None,
    ):
        super(IWBase, self).__init__(input_shape, num_classes, hparams, data_params)
        # Step counter used for alignment-weight annealing and IW update timing.
        self.register_buffer("update_count", torch.tensor([0]))
        self.conditional = conditional

        self.network = DANetwork(input_shape, num_classes, self.hparams["da_network"])

        if hasattr(self.network.feature_extractor, "param_groups"):
            fx_params = process_param_groups(
                self.network.feature_extractor.param_groups, self.hparams["fx_opt"]
            )
        else:
            fx_params = [
                param
                for param in self.network.feature_extractor.parameters()
                if param.requires_grad
            ]
        self.fx_opt = get_optimizer(fx_params, self.hparams["fx_opt"])
        self.optimizers.append(self.fx_opt)

        cls_params = [
            param
            for param in self.network.classifier.parameters()
            if param.requires_grad
        ]
        self.cls_opt = get_optimizer(cls_params, self.hparams["cls_opt"])
        self.optimizers.append(self.cls_opt)

        # Presumably creates self.fx_lr_scheduler / self.cls_lr_scheduler,
        # which update() steps; defined on the base class (not visible here).
        self.setup_schedulers()

        self.cls_weight = self.hparams["cls_weight"]

        self.alignment_weight = self.hparams["alignment_weight"]
        self.alignment_w_steps = self.hparams["alignment_w_steps"]

        # Width of the feature vectors z produced by the feature extractor.
        feature_dim = self.network.feature_extractor.n_outputs
        disc_in_dim = feature_dim

        self.random_layer = None
        if conditional:
            # Default conditional input: flattened (num_classes x feature_dim)
            # outer product of softmax(logits) and z.
            disc_in_dim *= self.num_classes
            if self.hparams["use_random"]:
                disc_in_dim = self.hparams["random_dim"]
                # The random projection multiplies z directly, so its input
                # width must be the extractor's real output width.
                self.random_layer = RandomLayer(
                    [feature_dim, self.num_classes],
                    disc_in_dim,
                )
                self.random_layer.cuda()

        if discriminator_fn is None:
            self.discriminator = discriminators.IWDiscriminator(
                disc_in_dim, self.hparams["discriminator"]["hparams"]
            )
        else:
            self.discriminator = discriminator_fn(
                disc_in_dim,
                self.hparams["discriminator"]["hparams"],
            )

        self.disc_steps = self.hparams["disc_steps"]
        self.disc_opt = get_optimizer(
            self.discriminator.parameters(), self.hparams["disc_opt"]
        )
        self.optimizers.append(self.disc_opt)

        self.iw = importance_weighting.ImportanceWeighting(
            num_classes, self.hparams["importance_weighting"]
        )
        self.iw_update_period = self.hparams["iw_update_period"]
        self.register_buffer(
            "source_class_distribution",
            torch.tensor(data_params["source_class_distribution"]),
        )
        self.register_buffer(
            "source_class_distribution_inv", 1.0 / self.source_class_distribution
        )

    def _disc_inputs(self, z, logits):
        """Build the discriminator input for features ``z``.

        Unconditional mode returns ``z`` unchanged (``logits`` is ignored and
        may be None). Conditional mode combines ``z`` with
        ``softmax(logits, dim=-1)``, either through ``self.random_layer``
        (when it was created) or as the flattened outer product
        ``outer_product(z, softmax)``.
        """
        if not self.conditional:
            return z
        softmax = F.softmax(logits, dim=-1)
        if self.random_layer is not None:
            return self.random_layer.forward([z, softmax])
        return outer_product(z, softmax)

    def _alignment_weight(self):
        """Current alignment weight, linearly annealed over ``alignment_w_steps``.

        ``alignment_w_steps`` of None or 0 disables annealing (full weight);
        0 previously raised ZeroDivisionError.
        """
        if not self.alignment_w_steps:
            return self.alignment_weight
        t = min(self.update_count.item() / self.alignment_w_steps, 1.0)
        return t * self.alignment_weight

    def update(self, iterator):
        """Run one training step and return ``(stats, extra_stats)``.

        Draws ``(x_src, y_src, x_trg)`` batches from ``iterator``: trains the
        discriminator for ``disc_steps`` steps, then takes one
        feature-extractor/classifier step on the importance-weighted source
        cross-entropy plus the weighted alignment loss.
        """
        batches = None
        disc_stats = OrderedDict()
        disc_extra_stats = OrderedDict()
        for _ in range(self.disc_steps):
            batches = next(iterator)
            x_src, y_src, x_trg = batches
            weights_src = self.iw.get_sample_weights(y_src)

            # No gradients w.r.t. feature extractor / classifier in the
            # discriminator step.
            with torch.no_grad():
                z_src = self.network.feature_extractor(x_src)
                z_trg = self.network.feature_extractor(x_trg)
                logits_src = logits_trg = None
                if self.conditional:
                    logits_src = self.network.classifier(z_src)
                    logits_trg = self.network.classifier(z_trg)
                disc_in_src = self._disc_inputs(z_src, logits_src)
                disc_in_trg = self._disc_inputs(z_trg, logits_trg)

            disc_loss, disc_stats, disc_extra_stats = self.discriminator.disc_loss(
                disc_in_src, disc_in_trg, weights_src
            )

            self.disc_opt.zero_grad()
            disc_loss.backward()
            self.disc_opt.step()

        # Reuse the last discriminator batch; draw one if no disc step ran.
        x_src, y_src, x_trg = batches if batches is not None else next(iterator)
        weights_src = self.iw.get_sample_weights(y_src)

        z_src = self.network.feature_extractor(x_src)
        z_trg = self.network.feature_extractor(x_trg)

        logits_src = self.network.classifier(z_src)
        logits_trg = self.network.classifier(z_trg)
        self.iw.update_stats(logits_src, y_src, logits_trg)

        # Alignment loss
        disc_in_src = self._disc_inputs(z_src, logits_src)
        disc_in_trg = self._disc_inputs(z_trg, logits_trg)

        alignment_weight = self._alignment_weight()

        alignment_loss, alignment_stats = self.discriminator.alignment_loss(
            disc_in_src, disc_in_trg, weights_src
        )

        # Classifier loss: per-sample CE weighted by inverse source class
        # frequency and by the importance weights, averaged and divided by
        # num_classes.
        class_weights = self.source_class_distribution_inv
        cls_loss_src = (
            torch.mean(
                F.cross_entropy(
                    logits_src, y_src, weight=class_weights, reduction="none"
                )
                * weights_src
            )
            / self.num_classes
        )

        fx_loss = self.cls_weight * cls_loss_src + alignment_weight * alignment_loss

        self.fx_opt.zero_grad()
        self.cls_opt.zero_grad()

        fx_loss.backward()
        self.fx_opt.step()
        self.fx_lr_scheduler.step()
        self.cls_opt.step()
        self.cls_lr_scheduler.step()

        # Refresh importance weights once every iw_update_period steps.
        if self.update_count % self.iw_update_period == self.iw_update_period - 1:
            self.iw.update_weights(self.source_class_distribution)

        stats = OrderedDict()
        extra_stats = OrderedDict()

        stats["fx_loss"] = fx_loss.item()
        stats["c_loss"] = cls_loss_src.item()
        stats.update(disc_stats)
        stats["a_loss"] = alignment_loss.item()
        stats["a_w"] = alignment_weight

        extra_stats.update(disc_extra_stats)
        extra_stats.update(alignment_stats)

        self.update_count += 1
        return stats, extra_stats

    def predict(self, x):
        """Return the network output (classifier logits) for ``x``."""
        return self.network(x)

    def update_bn(self, loader, device):
        """Recompute BatchNorm statistics of ``self.network`` over ``loader``."""
        utils.update_bn(loader, self.network, device)


class IWDAN(IWBase):
    """Importance-weighted DANN: ``IWBase`` aligning raw features (unconditional)."""

    def __init__(self, input_shape, num_classes, hparams, data_params):
        super().__init__(
            input_shape,
            num_classes,
            hparams,
            data_params,
            conditional=False,
        )


class IWCDAN(IWBase):
    """Importance-weighted CDAN: ``IWBase`` aligning features conditioned on predictions."""

    def __init__(self, input_shape, num_classes, hparams, data_params):
        super().__init__(
            input_shape,
            num_classes,
            hparams,
            data_params,
            conditional=True,
        )


class SENTRY(Algorithm):
    """Committee-consistency selective entropy optimization (SENTRY-style).

    The per-step objective is the sum of:
      * ``src_weight`` * cross-entropy on labeled source data;
      * ``unsup_weight`` * information-entropy term: mean over the target
        batch of ``sum_c p_c * log q_c``, where ``q`` is the class histogram
        of the last ``queue_length`` target predictions;
      * ``ent_weight`` * committee term: each augmented view in
        ``x_trg_ra_list`` votes on whether it agrees with the prediction on
        the clean target view. Majority-consistent samples get entropy
        minimized; majority-inconsistent samples get entropy maximized.
        Each term is scaled by the fraction of the batch it selects.

    hparams: src_weight, unsup_weight, ent_weight, plus da_network, fx_opt,
    cls_opt (and whatever ``setup_schedulers`` reads).
    """

    def __init__(self, input_shape, num_classes, hparams, data_params, device="cuda"):
        super(SENTRY, self).__init__(input_shape, num_classes, hparams, data_params)

        self.network = DANetwork(input_shape, num_classes, self.hparams["da_network"])

        if hasattr(self.network.feature_extractor, "param_groups"):
            fx_params = process_param_groups(
                self.network.feature_extractor.param_groups, self.hparams["fx_opt"]
            )
        else:
            fx_params = [
                param
                for param in self.network.feature_extractor.parameters()
                if param.requires_grad
            ]

        self.fx_opt = get_optimizer(fx_params, self.hparams["fx_opt"])
        self.optimizers.append(self.fx_opt)

        cls_params = [
            param
            for param in self.network.classifier.parameters()
            if param.requires_grad
        ]
        self.cls_opt = get_optimizer(cls_params, self.hparams["cls_opt"])
        self.optimizers.append(self.cls_opt)

        self.setup_schedulers()

        # Queue length for computing target information entropy loss
        self.queue_length = 256
        self.committee_size = 3  # Committee size
        # Majority voting: a sample is selected when more than half of the
        # committee agrees (or disagrees) with the clean-view prediction.
        majority = self.committee_size // 2 + 1
        self.positive_threshold = majority
        self.negative_threshold = majority

        self.device = device
        self.lambda_src = self.hparams["src_weight"]
        self.lambda_unsup = self.hparams["unsup_weight"]
        self.lambda_ent = self.hparams["ent_weight"]
        # Circular buffer of (prediction + 1); 0 marks a slot not yet filled.
        self.queue = torch.zeros(self.queue_length).to(self.device)
        self.pointer = 0

    def _enqueue_predictions(self, preds):
        """Write ``preds + 1`` into the circular prediction queue.

        NOTE(review): assumes len(preds) <= queue_length; a larger batch
        would overflow the wrap-around slice below.
        """
        batch_sz = preds.shape[0]
        values = preds.detach() + 1
        if self.pointer + batch_sz > self.queue_length:
            # Wrap around when queue_length % batch_size != 0.
            rem_space = self.queue_length - self.pointer
            self.queue[self.pointer : self.queue_length] = values[:rem_space]
            self.queue[0 : batch_sz - rem_space] = values[rem_space:]
        else:
            self.queue[self.pointer : self.pointer + batch_sz] = values
        self.pointer = (self.pointer + batch_sz) % self.queue_length

    def _info_entropy_loss(self, scores):
        """``unsup_weight`` * mean_i sum_c softmax(scores)_ic * log q_c.

        ``q`` is the normalized histogram of queued predictions. Index 0
        (empty slots) is dropped, and ``q`` is treated as a constant.
        """
        bincounts = (
            torch.bincount(self.queue.long(), minlength=self.num_classes + 1).float()
            / self.queue_length
        )[1:]
        log_q = torch.log(bincounts + 1e-12).detach()
        return self.lambda_unsup * torch.mean(
            torch.sum(
                scores.softmax(dim=1) * log_q.reshape(1, self.num_classes),
                dim=1,
            )
        )

    @staticmethod
    def _mean_neg_entropy(scores, mask):
        """Mean over rows selected by ``mask`` of ``sum_c p log p`` (negative entropy)."""
        probs = F.softmax(scores, dim=1)
        return torch.mean(
            torch.sum(probs[mask] * (torch.log(probs[mask] + 1e-12)), 1)
        )

    def _committee_loss_terms(self, x_trg, x_trg_ra_list, batch_sz):
        """Return the list of (0-2) committee entropy loss terms."""
        score_t_og = self.network(x_trg).detach()
        tgt_preds = score_t_og.max(dim=1)[1].reshape(-1)

        # zeros_like keeps the device of the network outputs.
        consistent_votes = torch.zeros_like(tgt_preds)
        inconsistent_votes = torch.zeros_like(tgt_preds)
        score_t_aug_pos = torch.zeros_like(score_t_og)
        score_t_aug_neg = torch.zeros_like(score_t_og)

        for x_aug in x_trg_ra_list:
            score_aug = self.network(x_aug)
            preds_aug = score_aug.max(dim=1)[1].reshape(-1)
            consistent = (tgt_preds == preds_aug).detach()
            inconsistent = (tgt_preds != preds_aug).detach()
            consistent_votes = consistent_votes + consistent.long()
            inconsistent_votes = inconsistent_votes + inconsistent.long()

            # Later augmentations overwrite earlier ones for the same row.
            score_t_aug_pos[consistent, :] = score_aug[consistent, :]
            score_t_aug_neg[inconsistent, :] = score_aug[inconsistent, :]

        correct_mask = consistent_votes >= self.positive_threshold
        incorrect_mask = inconsistent_votes >= self.negative_threshold
        correct_ratio = correct_mask.sum().item() / batch_sz
        incorrect_ratio = incorrect_mask.sum().item() / batch_sz

        terms = []
        if correct_ratio > 0.0:
            # Entropy minimization on committee-consistent samples.
            terms.append(
                self.lambda_ent
                * correct_ratio
                * -self._mean_neg_entropy(score_t_aug_pos, correct_mask)
            )
        if incorrect_ratio > 0.0:
            # Entropy maximization on committee-inconsistent samples.
            terms.append(
                self.lambda_ent
                * incorrect_ratio
                * self._mean_neg_entropy(score_t_aug_neg, incorrect_mask)
            )
        return terms

    def update(self, iterator):
        """Run one optimization step and return ``(stats, extra_stats)``.

        ``iterator`` yields ``(x_src, y_src, x_trg, x_trg_ra_list)``, where
        ``x_trg_ra_list`` is an iterable of augmented target batches.
        """
        x_src, y_src, x_trg, x_trg_ra_list = next(iterator)
        batch_sz = x_trg.shape[0]

        # Supervised source loss.
        logits_src = self.network(x_src)
        loss_src = self.lambda_src * F.cross_entropy(logits_src, y_src)

        # Information-entropy loss on the clean target view.
        score_t_og = self.network(x_trg)
        self._enqueue_predictions(score_t_og.max(dim=1)[1].reshape(-1))
        loss_infoent = self._info_entropy_loss(score_t_og)

        # Accumulate out of place: ``loss = loss_src; loss += x`` runs
        # Tensor.__iadd__ in place, mutating loss_src and corrupting the
        # reported c_loss.
        loss = loss_src + loss_infoent
        for term in self._committee_loss_terms(x_trg, x_trg_ra_list, batch_sz):
            loss = loss + term

        self.fx_opt.zero_grad()
        self.cls_opt.zero_grad()
        loss.backward()
        self.fx_opt.step()
        self.fx_lr_scheduler.step()

        self.cls_opt.step()
        self.cls_lr_scheduler.step()

        extra_stats = OrderedDict()
        stats = OrderedDict()
        stats["lr"] = self.fx_lr_scheduler.get_last_lr()[0]
        stats["cls_lr"] = self.cls_lr_scheduler.get_last_lr()[0]
        stats["c_loss"] = loss_src.item()
        return stats, extra_stats

    def predict(self, x):
        """Return the network output (classifier logits) for ``x``."""
        return self.network(x)

    def update_bn(self, loader, device):
        """Recompute BatchNorm statistics of ``self.network`` over ``loader``."""
        utils.update_bn(loader, self.network, device)
