import torch
import numpy as np
import copy
import random
import argparse
import os
from datetime import datetime
import torch.optim as optim
import math
from utils.misc import *
from utils.misc_cifar import *
from utils.models import *

from misc import *


def prepare_tta(net_init, tta_method, args):
    """Build a private copy of ``net_init`` configured for ``tta_method`` plus its optimizer.

    * ``"noadapt"``: the copy is put in eval mode.
    * ``"note"``, ``"bnadapt"``, ``"delta"``, ``"ods"``: ``configure_model_bn`` + train mode.
    * anything else: ``configure_model_tent`` + train mode.

    The optimizer is SAM-wrapped Adam over ``sam_collect_params(freeze_top=True)`` for
    ``"sar"``, and plain Adam over ``collect_params2(..., 'bn')`` otherwise
    (lr=1e-3, no weight decay in both cases).

    Returns:
        ``(net_adapt, opt)``.
    """
    net_adapt = copy.deepcopy(net_init)

    if tta_method == "noadapt":
        net_adapt.eval()
    else:
        configure = configure_model_bn if tta_method in ("note", "bnadapt", "delta", "ods") else configure_model_tent
        net_adapt = configure(net_adapt)
        net_adapt.train()

    if tta_method == "sar":
        sam_params, _ = sam_collect_params(net_adapt, freeze_top=True)
        opt = SAM(sam_params, torch.optim.Adam, rho=0.05, lr=1e-3, weight_decay=0.)
    else:
        bn_params, _ = collect_params2(net_adapt, 'bn')
        opt = torch.optim.Adam(bn_params, lr=1e-3, betas=(0.9, 0.999), weight_decay=0.)

    return net_adapt, opt


def softmax_entropy(x):
    """Return the per-row entropy of ``softmax(x)`` (classes along dim 1)."""
    probs = torch.softmax(x, dim=1)
    log_probs = torch.log_softmax(x, dim=1)
    return -(probs * log_probs).sum(dim=1)


def euclidean_distance_einsum(X, Y):
    """Pairwise squared Euclidean distances between rows of ``X`` ``[n, d]`` and ``Y`` ``[m, d]``.

    Uses ``|x|^2 + |y|^2 - 2 x.y``; the result has shape ``[n, m]``.
    """
    sq_norm_x = torch.einsum('nd,nd->n', X, X).unsqueeze(1)  # [n, 1]
    sq_norm_y = torch.einsum('md,md->m', Y, Y)  # [m]
    cross = 2 * (X @ Y.T)  # [n, m]
    return sq_norm_x + sq_norm_y - cross


def cosine_distance_einsum(X, Y):
    """Squared Euclidean distance between L2-normalised rows of ``X`` ``[n, d]`` and ``Y`` ``[m, d]``.

    For unit vectors this equals ``2 - 2 * cos(x, y)``; the result has shape ``[n, m]``.
    """
    x_unit = F.normalize(X, dim=1)
    y_unit = F.normalize(Y, dim=1)
    sq_norm_x = torch.einsum('nd,nd->n', x_unit, x_unit).unsqueeze(1)  # [n, 1]
    sq_norm_y = torch.einsum('md,md->m', y_unit, y_unit)  # [m]
    cross = 2 * (x_unit @ y_unit.T)  # [n, m]
    return sq_norm_x + sq_norm_y - cross


def compute_Wb(logits_curr, g_phi, args):
    """Compute the logit-refinement matrix ``W`` and bias ``b`` for one batch.

    Args:
        logits_curr: batch logits with classes along dim 1 (callers in this file pass
            them already divided by ``args.smax_temp``).
        g_phi: optional callable ``g_phi(mean_prediction, ce_to_uniform) -> (W, b)``,
            evaluated under ``torch.no_grad()``. ``None`` selects the identity transform.
        args: namespace providing ``num_classes``.

    Returns:
        ``(W, b)``: identity ``[C, C]`` and zeros ``[1, C]`` when ``g_phi`` is ``None``,
        otherwise whatever ``g_phi`` returns.
    """
    device = logits_curr.device
    num_classes = args.num_classes
    eps = 1e-12

    probs = torch.softmax(logits_curr, dim=1)
    # batch-averaged predicted class distribution
    mean_pred = probs.mean(0)
    # cross-entropy between the uniform distribution and each prediction, batch-averaged
    uniform = torch.ones(num_classes).to(device) / num_classes
    ce_to_uniform = (-(uniform * torch.log(probs + eps))).sum(dim=1).mean()

    if g_phi is None:
        identity = torch.eye(num_classes).to(device)
        zero_bias = torch.zeros(1, num_classes).to(device)
        return identity, zero_bias

    with torch.no_grad():
        W, b = g_phi(mean_pred, ce_to_uniform)
    return W, b


def copy_model_and_optimizer(model, optimizer):
    """Snapshot model and optimizer state so adaptation can be undone later.

    Returns:
        ``(model_state, optimizer_state)``: deep copies of both ``state_dict()``s, so
        later in-place updates to ``model`` / ``optimizer`` do not affect them.
    """
    return tuple(copy.deepcopy(obj.state_dict()) for obj in (model, optimizer))


def load_model_and_optimizer(model, optimizer, model_state, optimizer_state):
    """Restore ``model`` and ``optimizer`` in place from previously saved state dicts.

    The model state is loaded with ``strict=True``, so missing or unexpected keys raise.
    """
    # model first, then optimizer (same order as the snapshot helper returns them)
    model.load_state_dict(model_state, strict=True)
    optimizer.load_state_dict(optimizer_state)


def tta_noadapt(x_te, y_te, indices, net, args, logits_tr_info=None, g_phi=None):
    """Evaluate the source model without adaptation, with W/b logit refinement.

    The evaluated network is an eval-mode copy of ``net``. A separate train-mode copy
    (``net_ref``) produces the temperature-scaled logits from which ``compute_Wb``
    derives the refinement ``W`` and ``b``.

    Returns:
        ``(accuracy, inp_all, args.smax_temp)``. ``inp_all`` is kept for interface parity
        with the other ``tta_*`` functions and is always 0 here.
    """
    # network for adaptation (eval mode, never updated)
    net_adapt, _ = prepare_tta(net, "noadapt", args)

    # network for computing W and b
    net_ref = copy.deepcopy(net)
    net_ref.train()

    # number of test batches
    nb = math.ceil(len(indices) / args.test_batch_size)
    perm = indices

    acc_te = 0.
    # Previously never initialised here, so the final ``return`` raised NameError.
    inp_all = 0.
    for counter in range(nb):
        idx_curr = perm[counter * args.test_batch_size: (counter + 1) * args.test_batch_size]
        x_curr, y_curr = x_te[idx_curr].to(args.device), y_te[idx_curr]

        # compute W and b
        with torch.no_grad():
            outputs_ref = net_ref(x_curr) / args.smax_temp
            outputs = net_adapt(x_curr)

        W_curr, b_curr = compute_Wb(outputs_ref, g_phi, args)

        # refine: transform the temperature-scaled logits, keep each row's original L2 norm
        outputs_norm = torch.norm(outputs, dim=1, p=2, keepdim=True)
        outputs = outputs_norm * F.normalize((outputs / args.smax_temp) @ W_curr + b_curr, p=2, dim=1)

        acc_te += (outputs.argmax(dim=1).detach().cpu() == y_curr).float().sum()

    acc_te = acc_te / len(indices)
    inp_all /= nb
    return acc_te, inp_all, args.smax_temp


def tta_bnadapt(x_te, y_te, indices, net, args, logits_tr_info=None, g_phi=None):
    """Evaluate "bnadapt" (train-mode BN copy, no optimizer step) with W/b logit refinement.

    ``net_adapt`` comes from ``prepare_tta(..., "bnadapt")`` and runs in train mode.
    ``net_ref`` is a train-mode copy of ``net`` used only to derive ``W`` and ``b``
    through ``compute_Wb``.

    Returns:
        ``(accuracy, inp_all, args.smax_temp)``; ``inp_all`` is always 0 here.
    """
    batch_size = args.test_batch_size

    net_adapt, _ = prepare_tta(net, "bnadapt", args)

    net_ref = copy.deepcopy(net)
    net_ref.train()

    num_batches = math.ceil(len(indices) / batch_size)

    n_correct = 0.
    inp_all = 0.
    for batch_idx in range(num_batches):
        batch_ids = indices[batch_idx * batch_size: (batch_idx + 1) * batch_size]
        x_batch = x_te[batch_ids].to(args.device)
        y_batch = y_te[batch_ids]

        with torch.no_grad():
            ref_logits = net_ref(x_batch) / args.smax_temp
            logits = net_adapt(x_batch)

        W, b = compute_Wb(ref_logits, g_phi, args)

        # transform the temperature-scaled logits, keep each row's original L2 norm
        row_norm = torch.norm(logits, dim=1, p=2, keepdim=True)
        logits = row_norm * F.normalize((logits / args.smax_temp) @ W + b, p=2, dim=1)

        n_correct += (logits.argmax(dim=1).detach().cpu() == y_batch).float().sum()

    accuracy = n_correct / len(indices)
    inp_all /= num_batches
    return accuracy, inp_all, args.smax_temp


def tta_tent(x_te, y_te, indices, net, args, logits_tr_info=None, g_phi=None):
    """Evaluate TENT-style entropy minimisation with W/b logit refinement.

    For each batch, the refined logits of ``net_adapt`` are used both to take one
    optimizer step on their mean entropy and, as computed before that step, for scoring.

    Returns:
        ``(accuracy, inp_all, args.smax_temp)``; ``inp_all`` is always 0 here.
    """
    batch_size = args.test_batch_size

    net_adapt, opt = prepare_tta(net, "tent", args)

    net_ref = copy.deepcopy(net)
    net_ref.train()

    num_batches = math.ceil(len(indices) / batch_size)

    n_correct = 0.
    inp_all = 0.
    for batch_idx in range(num_batches):
        batch_ids = indices[batch_idx * batch_size: (batch_idx + 1) * batch_size]
        x_batch = x_te[batch_ids].to(args.device)
        y_batch = y_te[batch_ids]

        with torch.no_grad():
            ref_logits = net_ref(x_batch) / args.smax_temp
        W, b = compute_Wb(ref_logits, g_phi, args)

        logits = net_adapt(x_batch)
        row_norm = torch.norm(logits, dim=1, p=2, keepdim=True)
        logits = row_norm * F.normalize((logits / args.smax_temp) @ W + b, p=2, dim=1)

        # entropy written as logsumexp(z) - sum_k softmax(z)_k * z_k
        probs = torch.softmax(logits, dim=-1)
        entropy = torch.logsumexp(logits, dim=1) - (probs * logits).sum(dim=1)
        loss = entropy.mean()

        opt.zero_grad()
        loss.backward()
        opt.step()

        n_correct += (logits.argmax(dim=1).detach().cpu() == y_batch).float().sum()

    accuracy = n_correct / len(indices)
    inp_all /= num_batches
    return accuracy, inp_all, args.smax_temp


def tta_note(x_te, y_te, indices, _net, args, logits_tr_info=None, g_phi=None):
    """Evaluate NOTE-style adaptation (IABN layers + class-balanced PBRS memory).

    For each batch:
      1. The batch is pseudo-labelled and pushed into the PBRS memory.
      2. The IABN parameters take one entropy-minimisation step on the full memory.
      3. The batch is scored by the adapted network in eval mode.

    Source of the pseudo-labels:
      * ``g_phi is None``: the adapting network itself, in eval mode.
      * otherwise: the train-mode reference copy, refined with ``compute_Wb`` in the same
        way as the other ``tta_*`` functions (temperature applied once).

    The IABN checkpoint path is taken from ``args.note_ckpt_path`` when that attribute
    exists. Otherwise the original placeholder string is used.

    Returns:
        ``(accuracy, inp_all, args.smax_temp)``; ``inp_all`` is always 0 here.
    """
    # network for computing W and b
    net_ref = copy.deepcopy(_net)
    net_ref.train()

    # number of test batches
    nb = math.ceil(len(indices) / args.test_batch_size)
    perm = indices
    memory_size = args.test_batch_size
    iabn_k = 4
    skip_thres = 1

    # network for adaptation: BN layers replaced by instance-aware BN (in place on children)
    net = copy.deepcopy(_net)
    convert_iabn(net, iabn_k, skip_thres)

    # load fine-tuned model. NOTE: torch.load unpickles the file; only load trusted checkpoints.
    ckpt_path = getattr(args, "note_ckpt_path", "directory of fine-tuned classifier with IABN layer")
    checkpoint = torch.load(ckpt_path)
    net.load_state_dict(checkpoint, strict=True)

    mem = PBRS(capacity=memory_size, num_classes=args.num_classes)

    # only the IABN layers' parameters are trainable
    for param in net.parameters():
        param.requires_grad = False
    for module in net.modules():
        if isinstance(module, (InstanceAwareBatchNorm2d, InstanceAwareBatchNorm1d)):
            for param in module.parameters():
                param.requires_grad = True

    opt = optim.Adam(net.parameters(), lr=1e-3,
                     betas=(0.9, 0.999),
                     weight_decay=0.)
    net.train()

    acc_te = 0.
    inp_all = 0.
    for counter in range(nb):
        idx_curr = perm[counter * args.test_batch_size: (counter + 1) * args.test_batch_size]
        x_curr, y_curr = x_te[idx_curr].to(args.device), y_te[idx_curr]

        # 1) pseudo-label the incoming batch
        if g_phi is None:
            net.eval()
            with torch.no_grad():
                outputs = net(x_curr)
        else:
            net.train()
            with torch.no_grad():
                logits_ref = net_ref(x_curr)
                W_mixed, b_mixed = compute_Wb(logits_ref / args.smax_temp, g_phi, args)
                # The temperature is applied exactly once here. Previously the already-scaled
                # logits were divided by smax_temp a second time, which skewed the argmax
                # whenever b was non-zero.
                outputs_norm = torch.norm(logits_ref, dim=1, keepdim=True, p=2)
                outputs = outputs_norm * F.normalize((logits_ref / args.smax_temp) @ W_mixed + b_mixed, dim=1, p=2)

        # ... and insert every sample into the class-balanced memory
        pseudo_labels = outputs.max(dim=1)[1]
        for i in range(x_curr.size(0)):
            mem.add_instance([x_curr[i].cpu(), pseudo_labels[i].cpu(), y_curr[i].cpu(), 0])

        # 2) one-step adaptation on the whole memory
        net.train()
        memory_samples_x, _ = mem.get_memory()
        memory_samples_x = torch.stack(memory_samples_x).to(x_curr.device)

        outputs = net(memory_samples_x)
        target = torch.softmax(outputs, dim=-1)
        ent_loss = torch.logsumexp(outputs, dim=1) - (target * outputs).sum(dim=1)
        tta_loss = ent_loss.mean()

        opt.zero_grad()
        tta_loss.backward()
        opt.step()

        # 3) evaluation
        net.eval()
        with torch.no_grad():
            outputs = net(x_curr)
            acc_te += (outputs.argmax(dim=1).detach().cpu() == y_curr).float().sum()

    acc_te = acc_te / len(indices)
    return acc_te, inp_all, args.smax_temp

def tta_pl(x_te, y_te, indices, net, args, logits_tr_info=None, g_phi=None):
    """Evaluate hard pseudo-labelling (PL) with W/b logit refinement.

    Only samples whose refined softmax confidence exceeds 0.95 contribute to the
    cross-entropy against their own argmax label. If a batch has no such sample, the
    optimizer step is skipped. Otherwise the loss would be the mean of an empty tensor
    (NaN), and Adam would still move the parameters using stale momentum.

    Returns:
        ``(accuracy, inp_all, args.smax_temp)``; ``inp_all`` is always 0 here.
    """
    conf_thres = 0.95

    # network for adaptation
    net_adapt, opt = prepare_tta(net, "pl", args)

    # network for computing W and b
    net_ref = copy.deepcopy(net)
    net_ref.train()

    # number of test batches
    nb = math.ceil(len(indices) / args.test_batch_size)
    perm = indices

    acc_te = 0.
    inp_all = 0.
    for counter in range(nb):
        idx_curr = perm[counter * args.test_batch_size: (counter + 1) * args.test_batch_size]
        x_curr, y_curr = x_te[idx_curr].to(args.device), y_te[idx_curr]

        # compute W and b
        with torch.no_grad():
            outputs_ref = net_ref(x_curr) / args.smax_temp

        W_curr, b_curr = compute_Wb(outputs_ref, g_phi, args)

        # forward and refine (keep each row's original L2 norm)
        outputs = net_adapt(x_curr)
        outputs_norm = torch.norm(outputs, dim=1, p=2, keepdim=True)
        outputs = outputs_norm * F.normalize((outputs / args.smax_temp) @ W_curr + b_curr, p=2, dim=1)

        # reliable (confident) samples only
        idx_rbl = torch.softmax(outputs, dim=-1).max(dim=1)[0] > conf_thres
        if idx_rbl.any():
            logits_rbl = outputs[idx_rbl]
            target = F.one_hot(logits_rbl.argmax(dim=1), num_classes=args.num_classes).float()
            # cross-entropy with hard labels: logsumexp(z) - z[label]
            tta_loss = (torch.logsumexp(logits_rbl, dim=1) - (target * logits_rbl).sum(dim=1)).mean()

            opt.zero_grad()
            tta_loss.backward()
            opt.step()

        acc_te += (outputs.argmax(dim=1).detach().cpu() == y_curr).float().sum()

    acc_te = acc_te / len(indices)
    inp_all /= nb
    return acc_te, inp_all, args.smax_temp


def tta_lame(x_te, y_te, indices, _net, args, g_phi=None):
    """Evaluate LAME output refinement on top of W/b-refined logits (no parameter updates).

    ``net`` (a ``prepare_tta(..., "tent")`` copy) returns logits and features. The
    logits are refined with ``compute_Wb`` on ``net_ref``'s temperature-scaled logits,
    then passed through ``compute_lame`` (kNN affinity, k=5).

    Returns:
        ``(accuracy, inp_all, args.smax_temp)``; ``inp_all`` is always 0 here.
    """
    # network for adaptation (only forwarded under no_grad here)
    net, _ = prepare_tta(_net, "tent", args)

    net_ref = copy.deepcopy(_net)
    net_ref = configure_model_tent(net_ref)

    nb = math.ceil(len(indices) / args.test_batch_size)
    perm = indices

    acc_te = 0.
    inp_all = 0.

    for counter in range(nb):
        idx_curr = perm[counter * args.test_batch_size:(counter + 1) * args.test_batch_size]
        # use args.device like every other tta_* function instead of a hard-coded .cuda()
        x_curr, y_curr = x_te[idx_curr].to(args.device), y_te[idx_curr].cpu()

        with torch.no_grad():
            logits_ref = net_ref(x_curr) / args.smax_temp
            outputs, feats = net(x_curr, True)

        W_mixed, b_mixed = compute_Wb(logits_ref, g_phi, args)

        outputs_norm = torch.norm(outputs, dim=1, keepdim=True, p=2)
        outputs = outputs_norm * F.normalize((outputs / args.smax_temp) @ W_mixed + b_mixed, dim=1, p=2)

        refined_outputs = compute_lame(torch.softmax(outputs, -1), feats, metric='knn', k=5)

        acc_te += (refined_outputs.argmax(dim=1).detach().cpu() == y_curr).float().sum()

    acc_te = acc_te / len(indices)
    return acc_te, inp_all, args.smax_temp


def tta_sar(x_te, y_te, indices, net, args, logits_tr_info=None, g_phi=None):
    """Evaluate SAR (sharpness-aware, margin-filtered entropy minimisation) with W/b refinement.

    Each batch takes a two-step SAM update (``first_step`` / ``second_step``) on the
    margin-filtered entropy of the refined logits. Accuracy is measured on the first,
    pre-update forward pass.

    ``update_ema`` (defined elsewhere, presumably an exponential moving average) tracks
    the second-step loss. When it drops below 0.2, the model and optimizer are restored
    to their initial state.

    Returns:
        ``(accuracy, inp_all, args.smax_temp)``; ``inp_all`` is always 0 here.
    """
    # network for adaptation
    net_adapt, opt = prepare_tta(net, "sar", args)

    # network for computing W and b
    net_ref = copy.deepcopy(net)
    net_ref.train()

    # number of test batches
    nb = math.ceil(len(indices) / args.test_batch_size)
    perm = indices

    net_state, opt_state = copy_model_and_optimizer(net_adapt, opt)
    ema = None
    margin = math.log(args.num_classes) * 0.40
    loss_fn = HLoss_with_margin(1., margin=margin)

    acc_te = 0.
    inp_all = 0.
    for counter in range(nb):
        idx_curr = perm[counter * args.test_batch_size: (counter + 1) * args.test_batch_size]
        x_curr, y_curr = x_te[idx_curr].to(args.device), y_te[idx_curr]

        # compute W and b
        with torch.no_grad():
            outputs_ref = net_ref(x_curr) / args.smax_temp

        W_curr, b_curr = compute_Wb(outputs_ref, g_phi, args)

        net_adapt.train()

        # first step: gradient at the current weights
        logits_data1 = net_adapt(x_curr)
        outputs_norm = torch.norm(logits_data1, dim=1, p=2, keepdim=True)
        logits_data1 = outputs_norm * F.normalize((logits_data1 / args.smax_temp) @ W_curr + b_curr, p=2, dim=1)

        loss_first = loss_fn(logits_data1)
        opt.zero_grad()
        loss_first.backward()

        # compute \hat{\epsilon(\Theta)} for first order approximation
        opt.first_step(zero_grad=True)

        # second step: gradient at \Theta+\hat{\epsilon(\Theta)}, used for the actual update
        logits_data = net_adapt(x_curr)
        outputs_norm = torch.norm(logits_data, dim=1, p=2, keepdim=True)
        logits_data = outputs_norm * F.normalize((logits_data / args.smax_temp) @ W_curr + b_curr, p=2, dim=1)

        loss_second = loss_fn(logits_data)
        loss_second.backward()
        opt.second_step(zero_grad=True)

        acc_te += (logits_data1.argmax(dim=1).detach().cpu() == y_curr).float().sum()

        # The margin filter can leave no sample, making the loss NaN (mean of an empty
        # tensor). Skip such values so one NaN cannot poison the EMA and disable recovery.
        loss_second_val = loss_second.item()
        if not math.isnan(loss_second_val):
            ema = update_ema(ema, loss_second_val)

        # model recovery; ema is None right after a reset or before the first finite loss
        if ema is not None and ema < 0.2:
            load_model_and_optimizer(net_adapt, opt, net_state, opt_state)
            ema = None

    acc_te = acc_te / len(indices)
    inp_all /= nb
    return acc_te, inp_all, args.smax_temp


def tta_ods(x_te, y_te, indices, net, args, logits_tr_info=None, g_phi=None):
    """Evaluate ODS-style adaptation with W/b logit refinement.

    Prediction for each batch is the geometric mean of two distributions:
      * the adapted network's eval-mode softmax;
      * ``compute_lame`` applied to the refined reference prediction (kNN, k=5).

    Adaptation minimises the per-sample entropy weighted by ``1 - queue_dist.get() + 0.1``,
    indexed by pseudo-class. ``queue_dist`` presumably holds recent pseudo-label
    frequencies, so rarer classes get larger weights.

    Returns:
        ``(accuracy, inp_all, args.smax_temp)``; ``inp_all`` is always 0 here.
    """
    # network for adaptation
    net_adapt, opt = prepare_tta(net, "ods", args)

    # network for computing W and b
    net_ref = copy.deepcopy(net)
    net_ref.train()

    # number of test batches
    nb = math.ceil(len(indices) / args.test_batch_size)
    perm = indices
    memory_size = 20 * args.num_classes

    queue_dist = LabelDistributionQueue(num_class=args.num_classes, capacity=memory_size)

    acc_te = 0.
    inp_all = 0.
    for counter in range(nb):
        idx_curr = perm[counter * args.test_batch_size: (counter + 1) * args.test_batch_size]
        x_curr, y_curr = x_te[idx_curr].to(args.device), y_te[idx_curr]

        # memory update / prediction
        net_adapt.eval()
        with torch.no_grad():
            # feature representation from net_adapt
            outputs, feats = net_adapt(x_curr, True)

            # prediction from net_ref
            logits_ref = net_ref(x_curr)

            # compute W and b
            W_curr, b_curr = compute_Wb(logits_ref / args.smax_temp, g_phi, args)

            # prediction refinement (keep each row's original L2 norm)
            outputs_norm = torch.norm(logits_ref, dim=1, p=2, keepdim=True)
            logits_ref = outputs_norm * F.normalize((logits_ref / args.smax_temp) @ W_curr + b_curr, p=2, dim=1)

            pseudo_cls = logits_ref.max(dim=1, keepdim=False)[1].clone().detach().cpu()

        optim_dist = compute_lame(torch.softmax(logits_ref, dim=-1), feats, 'knn', 5)
        final_prediction = torch.sqrt(torch.softmax(outputs, dim=1) * optim_dist)
        acc_te += (final_prediction.argmax(dim=1).detach().cpu() == y_curr).float().sum()

        # adaptation
        net_adapt.train()
        queue_dist.update(pseudo_cls)

        weight = 1.0 - queue_dist.get() + 0.1
        weight = weight / weight.sum()
        weight = weight[pseudo_cls].to(x_curr.device)

        outputs = net_adapt(x_curr)

        target = torch.softmax(outputs, dim=-1)
        ent_loss = torch.logsumexp(outputs, dim=1) - (target * outputs).sum(dim=1)

        # Per-sample weighted mean. Previously the batch mean was taken *before*
        # weighting, which made the weights cancel out exactly.
        tta_loss = (ent_loss * weight).sum(0) / weight.sum(0)

        opt.zero_grad()
        tta_loss.backward()
        opt.step()

    acc_te = acc_te / len(indices)
    inp_all /= nb
    return acc_te, inp_all, args.smax_temp


def tta_delta(x_te, y_te, indices, net, args, logits_tr_info=None, g_phi=None):
    """Evaluate DELTA-style class-reweighted entropy minimisation with W/b refinement.

    BN layers are swapped via ``find_bns`` (defined elsewhere; called with prior 0.95).
    The ``weight``/``bias`` of every ``nn.BatchNorm2d`` are then trained.

    Per-sample entropy weights are proportional to ``1 / qhat[pseudo_label]``. ``qhat``
    is a running average (momentum ``dot``) of the batch-mean softmax prediction; the
    weights are normalised to average 1 over the batch.

    Returns:
        ``(accuracy, inp_all, args.smax_temp)``; ``inp_all`` is always 0 here.
    """
    # network for computing W and b
    net_ref = copy.deepcopy(net)
    net_ref.train()

    # number of test batches
    nb = math.ceil(len(indices) / args.test_batch_size)
    perm = indices

    # prepare DELTA
    old_prior = 0.95
    dot = 0.95

    net_adapt = copy.deepcopy(net)
    net_adapt.requires_grad_(False)
    replace_mods = find_bns(net_adapt, old_prior)
    for (parent, name, child) in replace_mods:
        setattr(parent, name, child)

    net_adapt.requires_grad_(False)
    params = []
    for m in net_adapt.modules():
        if isinstance(m, nn.BatchNorm2d):
            # loop variable renamed from ``np``, which shadowed numpy inside this function
            for pname, param in m.named_parameters():
                if pname in ['weight', 'bias']:
                    param.requires_grad_(True)
                    params.append(param)

    opt = torch.optim.Adam(params,
                           lr=1e-3,
                           betas=(0.9, 0.999),
                           weight_decay=0.)

    acc_te = 0.
    inp_all = 0.
    # args.device instead of a hard-coded .cuda(): qhat must live with the model outputs
    qhat = torch.zeros(1, args.num_classes).to(args.device) + (1. / args.num_classes)
    for counter in range(nb):
        idx_curr = perm[counter * args.test_batch_size: (counter + 1) * args.test_batch_size]
        x_curr, y_curr = x_te[idx_curr].to(args.device), y_te[idx_curr]

        with torch.no_grad():
            logits_ref = net_ref(x_curr) / args.smax_temp

        W_curr, b_curr = compute_Wb(logits_ref, g_phi, args)

        outputs = net_adapt(x_curr)

        # refine (keep each row's original L2 norm)
        outputs_norm = torch.norm(outputs, dim=1, p=2, keepdim=True)
        outputs = outputs_norm * F.normalize((outputs / args.smax_temp) @ W_curr + b_curr, p=2, dim=1)

        p = torch.softmax(outputs, dim=1)
        pmax, pls = p.max(dim=1)
        logp = F.log_softmax(outputs, dim=1)
        ent_weight = torch.ones_like(pls)
        entropys = -(p * logp).sum(dim=1)
        class_weight = 1. / qhat
        class_weight = class_weight / class_weight.sum()
        sample_weight = class_weight.gather(1, pls.view(1, -1)).squeeze()
        sample_weight = sample_weight / sample_weight.sum() * len(pls)
        ent_weight = ent_weight * sample_weight
        tta_loss = (entropys * ent_weight).mean()

        opt.zero_grad()
        tta_loss.backward()
        opt.step()

        with torch.no_grad():
            qhat = dot * qhat + (1. - dot) * p.mean(dim=0, keepdim=True)

        acc_te += (outputs.argmax(dim=1).detach().cpu() == y_curr).float().sum()

    acc_te = acc_te / len(indices)
    inp_all /= nb
    return acc_te, inp_all, args.smax_temp



##############################################################################################################################
# utils for TTA methods
class HLoss(torch.nn.Module):
    """Batch-mean Shannon entropy of ``softmax(x / temp_factor)`` (classes along dim 1).

    A constant ``1e-6`` is added inside the log for numerical safety.
    """

    def __init__(self, temp_factor=1.0):
        super().__init__()
        self.temp_factor = temp_factor

    def forward(self, x):
        probs = F.softmax(x / self.temp_factor, dim=1)
        per_sample = (-probs * torch.log(probs + 1e-6)).sum(1)
        return per_sample.mean()


class HLoss_with_margin(torch.nn.Module):
    """Mean entropy over the samples whose entropy is strictly below ``margin``.

    Entropy is taken of ``softmax(x / temp_factor)`` along dim 1 (with ``1e-6`` inside
    the log). If no sample passes the filter, the mean of an empty tensor is returned,
    i.e. NaN.
    """

    def __init__(self, temp_factor=1.0, margin=1.):
        super().__init__()
        self.temp_factor = temp_factor
        self.margin = margin

    def forward(self, x):
        probs = F.softmax(x / self.temp_factor, dim=1)
        per_sample = (-probs * torch.log(probs + 1e-6)).sum(dim=1)
        confident = per_sample < self.margin
        return per_sample[confident].mean()


# Code adapted from "SoTTA"
def sam_collect_params(model, freeze_top=False):
    """Collect affine ``weight``/``bias`` parameters of BatchNorm2d, LayerNorm and GroupNorm layers.

    With ``freeze_top=True``, modules whose qualified name contains any of ``layer4``,
    ``blocks.9``, ``blocks.10``, ``blocks.11`` or ``norm.``, or which are named exactly
    ``norm``, are skipped (top layers of ResNets / ViT-Base).

    Returns:
        ``(params, names)`` in module-traversal order; names look like ``"<module>.weight"``.
    """
    skipped_fragments = ('layer4', 'blocks.9', 'blocks.10', 'blocks.11', 'norm.')
    norm_types = (torch.nn.BatchNorm2d, torch.nn.LayerNorm, torch.nn.GroupNorm)

    params = []
    names = []
    for mod_name, module in model.named_modules():
        if freeze_top and (any(frag in mod_name for frag in skipped_fragments) or mod_name == 'norm'):
            continue
        if not isinstance(module, norm_types):
            continue
        # weight is scale, bias is shift
        for pname, param in module.named_parameters():
            if pname in ('weight', 'bias'):
                params.append(param)
                names.append(f"{mod_name}.{pname}")

    return params, names


class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization wrapper around a base optimizer.

    ``first_step`` perturbs every parameter with a gradient by ``rho * g / ||g||``.
    In the adaptive variant, the perturbation is element-wise scaled by ``p**2`` and the
    norm uses ``|p| * g``. ``second_step`` restores the original weights and lets the
    base optimizer update them with the gradients computed at the perturbed point.
    """

    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        super(SAM, self).__init__(params, dict(rho=rho, adaptive=adaptive, **kwargs))

        # The base optimizer is built on our param groups; afterwards both objects share
        # the base optimizer's group list.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        total_norm = self._grad_norm()
        for group in self.param_groups:
            step_scale = group["rho"] / (total_norm + 1e-12)
            use_adaptive = group["adaptive"]
            for param in group["params"]:
                if param.grad is None:
                    continue
                # remember "w" so second_step can return to it
                self.state[param]["old_p"] = param.data.clone()
                factor = torch.pow(param, 2) if use_adaptive else 1.0
                param.add_(factor * param.grad * step_scale.to(param))  # climb to "w + e(w)"

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for param in group["params"]:
                if param.grad is not None:
                    param.data = self.state[param]["old_p"]  # back to "w" from "w + e(w)"

        # the actual sharpness-aware update, using gradients taken at "w + e(w)"
        self.base_optimizer.step()

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        # the closure must run a full forward-backward pass, so re-enable autograd for it
        closure = torch.enable_grad()(closure)

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        # gather per-parameter norms on one device, in case of model parallelism
        target_device = self.param_groups[0]["params"][0].device
        per_param_norms = []
        for group in self.param_groups:
            for param in group["params"]:
                if param.grad is None:
                    continue
                scaled = (torch.abs(param) if group["adaptive"] else 1.0) * param.grad
                per_param_norms.append(scaled.norm(p=2).to(target_device))
        return torch.norm(torch.stack(per_param_norms), p=2)

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        # keep the base optimizer pointing at the (reloaded) shared param groups
        self.base_optimizer.param_groups = self.param_groups


def convert_iabn(module, iabn_k, skip_thres, **kwargs):
    """Recursively replace BatchNorm1d/2d layers with instance-aware BN (IABN) layers.

    A BN module becomes an ``InstanceAwareBatchNorm2d``/``1d``. Its original BN is
    deep-copied into the new module's ``_bn`` (so its weights and running stats are kept),
    and its ``skip_thres`` is set to the given value. Children of other modules are
    replaced in place via ``add_module``.

    Args:
        module: root module to convert.
        iabn_k: soft-threshold width ``k`` for the IABN layers.
        skip_thres: spatial / sequence size at or below which IABN skips its instance
            correction. This argument was previously accepted but silently ignored.
        **kwargs: forwarded to the recursive calls (otherwise unused).

    Returns:
        The converted module. It is a new object only when ``module`` itself is a BN
        layer; otherwise it is ``module``, modified in place.
    """
    module_output = module
    if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
        iabn_cls = InstanceAwareBatchNorm2d if isinstance(module, nn.BatchNorm2d) else InstanceAwareBatchNorm1d
        module_output = iabn_cls(
            num_channels=module.num_features,
            k=iabn_k,
            eps=module.eps,
            momentum=module.momentum,
            affine=module.affine
        )
        # keep the original BN (affine params + running statistics)
        module_output._bn = copy.deepcopy(module)
        module_output.skip_thres = skip_thres

    for name, child in module.named_children():
        module_output.add_module(
            name, convert_iabn(child, iabn_k, skip_thres, **kwargs)
        )
    return module_output


class InstanceAwareBatchNorm2d(nn.Module):
    """Instance-aware batch norm (IABN) for inputs of shape ``(b, c, h, w)``.

    Per-sample statistics over ``(h, w)`` are compared with reference statistics:
      * batch statistics in training mode;
      * running statistics in eval mode, when they exist.

    Deviations are soft-thresholded with width ``k`` times an estimated standard error,
    so only large deviations survive. Inputs with ``h * w <= skip_thres`` use the
    reference statistics directly. The wrapped ``_bn`` supplies the affine parameters
    and, in training mode, is run to update its running statistics.
    """

    def __init__(self, num_channels, k=3.0, eps=1e-5, momentum=0.1, affine=True):
        super(InstanceAwareBatchNorm2d, self).__init__()
        self.num_channels = num_channels
        self.eps = eps
        self.k = k
        self.affine = affine
        self._bn = nn.BatchNorm2d(num_channels, eps=eps, momentum=momentum, affine=affine)
        self.skip_thres = 1.

    def _softshrink(self, x, lbd):
        # sign(x) * max(|x| - lbd, 0)
        return F.relu(x - lbd, inplace=True) - F.relu(-(x + lbd), inplace=True)

    def forward(self, x):
        _, c, h, w = x.size()
        spatial = h * w
        # instance (per-sample, per-channel) statistics
        sigma2, mu = torch.var_mean(x, dim=[2, 3], keepdim=True, unbiased=True)

        bn = self._bn
        if self.training:
            bn(x)  # run BN only to update its running statistics
            sigma2_b, mu_b = torch.var_mean(x, dim=[0, 2, 3], keepdim=True, unbiased=True)
        elif not bn.track_running_stats and bn.running_mean is None and bn.running_var is None:
            # no stored statistics: fall back to batch statistics
            sigma2_b, mu_b = torch.var_mean(x, dim=[0, 2, 3], keepdim=True, unbiased=True)
        else:
            mu_b = bn.running_mean.view(1, c, 1, 1)
            sigma2_b = bn.running_var.view(1, c, 1, 1)

        if spatial > self.skip_thres:
            s_mu = torch.sqrt((sigma2_b + self.eps) / spatial)
            s_sigma2 = (sigma2_b + self.eps) * np.sqrt(2 / (spatial - 1))
            mu_adj = mu_b + self._softshrink(mu - mu_b, self.k * s_mu)
            # clamp to keep the adjusted variance non-negative
            sigma2_adj = F.relu(sigma2_b + self._softshrink(sigma2 - sigma2_b, self.k * s_sigma2))
        else:
            mu_adj, sigma2_adj = mu_b, sigma2_b

        normalized = (x - mu_adj) * torch.rsqrt(sigma2_adj + self.eps)
        if not self.affine:
            return normalized
        return normalized * bn.weight.view(c, 1, 1) + bn.bias.view(c, 1, 1)


class InstanceAwareBatchNorm1d(nn.Module):
    """Instance-aware batch norm (IABN) for inputs of shape ``(b, c, l)``.

    Per-sample statistics over the last dimension are compared with reference
    statistics: batch statistics in training mode, running statistics in eval mode when
    they exist. Deviations are soft-thresholded with width ``k`` times an estimated
    standard error.

    Inputs with ``l <= skip_thres`` use the reference statistics directly. The wrapped
    ``_bn`` supplies the affine parameters and, in training mode, is run to update its
    running statistics.
    """

    def __init__(self, num_channels, k=3.0, eps=1e-5, momentum=0.1, affine=True):
        super(InstanceAwareBatchNorm1d, self).__init__()
        self.num_channels = num_channels
        self.k = k
        self.eps = eps
        self.affine = affine
        self._bn = nn.BatchNorm1d(num_channels, eps=eps,
                                  momentum=momentum, affine=affine)
        # forward() used to read an undefined global ``skip_thres`` (NameError);
        # the threshold is now an attribute with the same default as the 2d variant.
        self.skip_thres = 1.

    def _softshrink(self, x, lbd):
        # sign(x) * max(|x| - lbd, 0)
        x_p = F.relu(x - lbd, inplace=True)
        x_n = F.relu(-(x + lbd), inplace=True)
        y = x_p - x_n
        return y

    def forward(self, x):
        b, c, l = x.size()
        # instance (per-sample, per-channel) statistics
        sigma2, mu = torch.var_mean(x, dim=[2], keepdim=True, unbiased=True)
        if self.training:
            _ = self._bn(x)  # update running statistics
            sigma2_b, mu_b = torch.var_mean(x, dim=[0, 2], keepdim=True, unbiased=True)
        else:
            if self._bn.track_running_stats == False and self._bn.running_mean is None and self._bn.running_var is None:  # use batch stats
                sigma2_b, mu_b = torch.var_mean(x, dim=[0, 2], keepdim=True, unbiased=True)
            else:
                mu_b = self._bn.running_mean.view(1, c, 1)
                sigma2_b = self._bn.running_var.view(1, c, 1)

        if l <= self.skip_thres:
            mu_adj = mu_b
            sigma2_adj = sigma2_b
        else:
            s_mu = torch.sqrt((sigma2_b + self.eps) / l)
            s_sigma2 = (sigma2_b + self.eps) * np.sqrt(2 / (l - 1))

            mu_adj = mu_b + self._softshrink(mu - mu_b, self.k * s_mu)
            sigma2_adj = sigma2_b + self._softshrink(sigma2 - sigma2_b, self.k * s_sigma2)
            sigma2_adj = F.relu(sigma2_adj)  # non negative

        x_n = (x - mu_adj) * torch.rsqrt(sigma2_adj + self.eps)

        if self.affine:
            weight = self._bn.weight.view(c, 1)
            bias = self._bn.bias.view(c, 1)
            x_n = x_n * weight + bias

        return x_n


class FIFO():
    """Bounded first-in-first-out memory of ``[feature, pseudo_label]`` pairs.

    ``data`` holds two parallel lists; once ``capacity`` is reached, the oldest entry is
    evicted before a new one is appended.
    """

    def __init__(self, capacity):
        # no domain label: only features and pseudo labels are stored
        self.data = [[], []]
        self.capacity = capacity

    def get_memory(self):
        return self.data

    def get_occupancy(self):
        return len(self.data[0])

    def add_instance(self, instance):
        assert (len(instance) == 2)

        if not self.get_occupancy() < self.capacity:
            self.remove_instance()

        for column, value in zip(self.data, instance):
            column.append(value)

    def remove_instance(self):
        for column in self.data:
            del column[0]


class Reservoir():  # Time uniform
    """Bounded reservoir of ``[feature, pseudo_label]`` pairs that is time-uniform.

    Once full, the n-th arriving instance is kept with probability ``capacity / n``.
    When kept, it replaces a uniformly chosen stored entry.
    """

    def __init__(self, capacity):
        # The former ``super(Reservoir, self).__init__(capacity)`` forwarded an argument
        # to object.__init__, which raises TypeError, so the class could not be built.
        self.data = [[], []]
        self.capacity = capacity
        self.counter = 0  # number of instances offered so far

    def get_memory(self):
        return self.data

    def get_occupancy(self):
        return len(self.data[0])

    def add_instance(self, instance):
        assert (len(instance) == 2)
        is_add = True
        self.counter += 1

        if self.get_occupancy() >= self.capacity:
            is_add = self.remove_instance()

        if is_add:
            for i, dim in enumerate(self.data):
                dim.append(instance[i])

    def remove_instance(self):
        """Evict a random entry with probability ``occupancy / counter``.

        Returns True when an entry was evicted, i.e. when the new instance should be kept.
        """
        m = self.get_occupancy()
        n = self.counter
        u = random.uniform(0, 1)
        if u <= m / n:
            tgt_idx = random.randrange(0, m)  # target index to remove
            for dim in self.data:
                dim.pop(tgt_idx)
        else:
            return False
        return True


class PBRS():
    """Prediction-balanced reservoir sampling memory.

    Holds at most ``capacity`` instances in per-class buckets indexed by
    ``instance[1]``. When the memory is full, a new instance either:

    * evicts a random entry from one of the currently largest buckets, if its own
      class is not among the largest; or
    * is admitted with probability ``m_c / n_c`` (stored / seen for its class),
      replacing a random entry of its own bucket.

    Only the first two fields of each 4-element instance are stored.
    """

    def __init__(self, capacity, num_classes):
        self.num_classes = num_classes
        # Per-class buckets of [field_0 values, field_1 values]; the original field
        # layout note reads: feat, pseudo_cls, domain, cls, loss.
        self.data = [[[], []] for _ in range(self.num_classes)]
        # Instances seen per class so far, whether or not they were stored.
        self.counter = [0] * self.num_classes
        self.marker = [''] * self.num_classes  # not used within this class
        self.capacity = capacity

    def print_class_dist(self):
        """Print the number of stored instances per class bucket."""
        print(self.get_occupancy_per_class())

    def print_real_class_dist(self):
        """Unsupported: ground-truth labels are not retained by this memory.

        Raises:
            NotImplementedError: always. Buckets keep only two fields per
                instance, so there is no ground-truth label column to count.
        """
        raise NotImplementedError(
            "PBRS stores only two fields per instance; ground-truth class labels "
            "are not retained, so the real class distribution is unavailable.")

    def get_memory(self):
        """Return ``[all field_0 values, all field_1 values]`` concatenated over buckets."""
        all_feats, all_labels = [], []
        for feats, labels in self.data:
            all_feats.extend(feats)
            all_labels.extend(labels)
        return [all_feats, all_labels]

    def get_occupancy(self):
        """Return the total number of stored instances."""
        return sum(len(bucket[0]) for bucket in self.data)

    def get_occupancy_per_class(self):
        """Return a list with the number of stored instances in each bucket."""
        return [len(bucket[0]) for bucket in self.data]

    def update_loss(self, loss_list):
        """Unsupported: no per-instance loss field is stored.

        Raises:
            NotImplementedError: always. ``loss_list`` is left untouched.
        """
        raise NotImplementedError(
            "PBRS stores only two fields per instance; there is no loss field to update.")

    def add_instance(self, instance):
        """Offer a 4-element ``instance`` whose class index is ``instance[1]``."""
        assert (len(instance) == 4)
        cls = instance[1]
        self.counter[cls] += 1
        is_add = True

        if self.get_occupancy() >= self.capacity:
            is_add = self.remove_instance(cls)

        if is_add:
            # Buckets have two lists, so only instance[0] and instance[1] are kept.
            for i, dim in enumerate(self.data[cls]):
                dim.append(instance[i])

    def get_largest_indices(self):
        """Return the indices of all buckets sharing the maximum occupancy."""
        occupancy_per_class = self.get_occupancy_per_class()
        max_value = max(occupancy_per_class)
        return [i for i, oc in enumerate(occupancy_per_class) if oc == max_value]

    def remove_instance(self, cls):
        """Make room for an incoming instance of class ``cls``.

        Returns:
            True if an entry was evicted (the incoming instance should be stored),
            False if the incoming instance should be discarded.
        """
        largest_indices = self.get_largest_indices()
        if cls not in largest_indices:
            # Evict from one (randomly chosen) largest bucket to rebalance classes.
            largest = random.choice(largest_indices)
            tgt_idx = random.randrange(0, len(self.data[largest][0]))  # target index to remove
            for dim in self.data[largest]:
                dim.pop(tgt_idx)
        else:
            # Per-class reservoir sampling within the incoming instance's own bucket.
            m_c = self.get_occupancy_per_class()[cls]
            n_c = self.counter[cls]
            u = random.uniform(0, 1)
            if u <= m_c / n_c:
                tgt_idx = random.randrange(0, len(self.data[cls][0]))  # target index to remove
                for dim in self.data[cls]:
                    dim.pop(tgt_idx)
            else:
                return False
        return True


class LabelDistributionQueue:
    """Fixed-length ring buffer of predicted labels used to estimate a label histogram.

    Labels are stored shifted by +1 so that a slot value of 0 means "never written".
    """

    def __init__(self, num_class, capacity=None):
        if capacity is None:
            capacity = num_class * 20
        self.queue_length = capacity
        self.queue = torch.zeros(self.queue_length)
        self.pointer = 0  # next ring slot to write
        self.num_class = num_class
        self.size = 0  # total number of labels pushed so far

    def update(self, tgt_preds):
        """Push a 1-D batch of predicted labels, overwriting the oldest slots.

        Any batch length is accepted: when the batch is longer than the queue only
        its newest ``queue_length`` labels can survive, so only those are written,
        at the same ring positions a sequential write would have left them in.
        """
        tgt_preds = tgt_preds.detach().cpu()
        batch_sz = tgt_preds.shape[0]
        self.size += batch_sz

        keep = tgt_preds[max(batch_sz - self.queue_length, 0):]
        n_keep = keep.shape[0]
        # Label k of the batch belongs in slot (pointer + k) % queue_length.
        start = (self.pointer + batch_sz - n_keep) % self.queue_length
        slots = (start + torch.arange(n_keep)) % self.queue_length
        self.queue[slots] = (keep + 1).to(self.queue.dtype)

        self.pointer = (self.pointer + batch_sz) % self.queue_length

    def get(self, ):
        """Return per-class slot counts divided by ``queue_length``.

        The counts are normalised by the full queue length even before the queue is
        full. If no labels are stored yet, a tensor of ones is returned instead.
        """
        bincounts = torch.bincount(self.queue.long(), minlength=self.num_class + 1).float() / self.queue_length
        bincounts = bincounts[1:]  # drop the "empty slot" bin
        if bincounts.sum() == 0:
            bincounts[:] = 1
        return bincounts

    def full(self):
        """Return True once at least ``queue_length`` labels have been pushed."""
        return self.size >= self.queue_length


def compute_lame(preds, feats, metric, k):
    """Refine class probabilities with LAME Laplacian regularisation.

    Args:
        preds: [N, K] class probabilities; their negative log is the unary cost.
        feats: [N, d] features, L2-normalised here before building the affinity.
        metric: affinity type forwarded to ``compute_sim``.
        k: neighbourhood size forwarded to ``compute_sim``.

    Returns:
        [N, K] refined soft assignments from ``laplacian_optimization``.
    """
    log_eps = 1e-10  # keeps log() finite for zero probabilities
    unary_cost = -torch.log(preds + log_eps)
    unit_feats = F.normalize(feats, p=2, dim=-1)
    affinity = compute_sim(unit_feats, k, metric)
    return laplacian_optimization(unary_cost, affinity)


def compute_sim(feats, k, metric):
    """Build an [N, N] affinity matrix over the rows of ``feats``.

    Args:
        feats: [N, d] feature matrix.
        k: neighbourhood size. For ``'rbf'`` the kernel bandwidth is the mean
            distance to the k-th nearest row, counting the row itself as the first.
            For ``'knn'`` it is the number of other rows marked as neighbours.
        metric: ``'rbf'`` or ``'knn'``; any other value gives the linear kernel
            ``feats @ feats.T``.

    Returns:
        [N, N] affinity tensor.
    """
    if metric == 'rbf':
        N = feats.size(0)
        dist = torch.norm(feats.unsqueeze(0) - feats.unsqueeze(1), dim=-1, p=2)  # [N, N]
        n_neighbors = min(k, N)
        # Distance to the k-th nearest row per point (self included), [N].
        kth_dist = dist.topk(k=n_neighbors, dim=-1, largest=False).values[:, -1]
        sigma = kth_dist.mean()
        if sigma == 0:
            # Degenerate bandwidth (a single sample, or all k nearest rows coincide):
            # the formula below would give 0/0 = NaN, so use its sigma -> 0 limit,
            # which is 1 where dist == 0 and 0 elsewhere.
            return (dist == 0).to(dist.dtype)
        rbf = torch.exp(- dist ** 2 / (2 * sigma ** 2))
        return rbf
    elif metric == 'knn':
        N = feats.size(0)
        dist = torch.norm(feats.unsqueeze(0) - feats.unsqueeze(1), dim=-1, p=2)  # [N, N]
        # Exclude each row from its own neighbourhood explicitly: assuming it comes
        # first in topk is wrong when another row has an identical feature (tie at 0).
        self_mask = torch.eye(N, dtype=torch.bool, device=feats.device)
        dist = dist.masked_fill(self_mask, float('inf'))
        n_neighbors = max(min(k, N - 1), 0)

        knn_index = dist.topk(n_neighbors, -1, largest=False).indices  # [N, knn]

        W = torch.zeros(N, N, device=feats.device)
        W.scatter_(dim=-1, index=knn_index, value=1.0)
        return W
    else:
        return torch.matmul(feats, feats.t())


def laplacian_optimization(unary, kernel, bound_lambda=1, max_steps=100, tol=1e-8):
    """Minimise the Laplacian-regularised assignment energy by fixed-point updates.

    Args:
        unary: [N, K] unary costs (``compute_lame`` passes -log probabilities).
        kernel: [N, N] affinity matrix.
        bound_lambda: weight of the pairwise (affinity) term.
        max_steps: maximum number of fixed-point iterations.
        tol: stop once the energy changes by at most ``tol * |previous energy|``.

    Returns:
        [N, K] soft assignments; each row is a softmax, so it sums to 1.
    """
    oldE = float('inf')
    Y = (-unary).softmax(-1)  # [N, K]
    for i in range(max_steps):
        pairwise = bound_lambda * kernel.matmul(Y)  # [N, K]
        Y = (-unary + pairwise).softmax(-1)
        E = entropy_energy(Y, unary, pairwise, bound_lambda).item()

        # The convergence test only starts from the third iteration (i > 1).
        if i > 1 and abs(E - oldE) <= tol * abs(oldE):
            break
        oldE = E

    return Y


def entropy_energy(Y, unary, pairwise, bound_lambda):
    """Return the scalar energy of the assignment ``Y``.

    The energy sums three element-wise terms over all entries:
    unary cost ``unary * Y``, minus ``bound_lambda * pairwise * Y``, plus the
    negative entropy ``Y * log(Y)``. ``Y`` is clipped at 1e-20 inside the log, so
    zero entries contribute 0 instead of NaN.
    """
    unary_term = unary * Y
    pairwise_term = bound_lambda * pairwise * Y
    neg_entropy = Y * torch.log(Y.clip(1e-20))
    return (unary_term - pairwise_term + neg_entropy).sum()


# inspired by https://github.com/bethgelab/robustness/tree/aa0a6798fe3973bae5f47561721b59b39f126ab7
def find_bns(parent, prior):
    """Recursively find ``nn.BatchNorm2d`` children and build TBR wrappers for them.

    Args:
        parent: module to search (``None`` yields an empty list).
        prior: EMA weight forwarded to each ``TBR``.

    Returns:
        List of ``(parent_module, child_name, tbr_module)`` triples. Nothing is
        replaced here; presumably the caller swaps the modules in (e.g. via
        ``setattr``) — confirm against the call site.
    """
    if parent is None:
        return []
    replace_mods = []
    for name, child in parent.named_children():
        if isinstance(child, nn.BatchNorm2d):
            # TBR owns no tensors of its own (its running stats are created from the
            # input on the first forward), so the wrapper already lives on ``child``'s
            # device. No hard-coded ``.cuda()``: that broke CPU-only runs and moved
            # the BN layer off the device the rest of the model is on.
            replace_mods.append((parent, name, TBR(child, prior)))
        else:
            replace_mods.extend(find_bns(child, prior))
    return replace_mods


class TBR(nn.Module):
    """Test-time batch-renormalisation wrapper around a ``BatchNorm2d`` layer.

    Each forward pass normalises the input with the current batch statistics. The
    result is then rescaled by ``r = std_B / std_run`` and shifted by
    ``d = (mean_B - mean_run) / std_run`` (both detached). Finally the wrapped
    layer's affine weight/bias are applied, when the layer has them.

    The running mean/std are exponential moving averages that keep weight ``prior``
    on the previous value. They are initialised from the first batch seen. The
    wrapped layer's own running buffers are not used.
    """

    def __init__(self, layer, prior):
        assert prior >= 0 and prior <= 1
        super().__init__()
        self.layer = layer
        self.layer.eval()
        self.prior = prior
        # Clamp bounds for r and d. NOTE(review): forward does not apply them
        # (clamping is currently disabled).
        self.rmax = 3.0
        self.dmax = 5.0
        self.tracked_num = 0  # number of forward passes performed
        # Plain attributes (not registered buffers), created lazily from the first input.
        self.running_mean = None
        self.running_std = None

    @staticmethod
    def _per_channel(v):
        """Reshape a per-channel [C] vector to [1, C, 1, 1] for broadcasting."""
        return v[None, :, None, None]

    def forward(self, input):
        """Renormalise a 4-D ``input`` whose channels are on dim 1, then update stats."""
        batch_mean = input.mean([0, 2, 3])
        batch_std = torch.sqrt(input.var([0, 2, 3], unbiased=False) + self.layer.eps)

        if self.running_mean is None:
            self.running_mean = batch_mean.detach().clone()
            self.running_std = batch_std.detach().clone()

        r = batch_std.detach() / self.running_std
        d = (batch_mean.detach() - self.running_mean) / self.running_std

        ch = self._per_channel
        output = (input - ch(batch_mean)) / ch(batch_std) * ch(r) + ch(d)

        self.running_mean = self.prior * self.running_mean + (1. - self.prior) * batch_mean.detach()
        self.running_std = self.prior * self.running_std + (1. - self.prior) * batch_std.detach()
        self.tracked_num += 1

        # BatchNorm2d(affine=False) has weight/bias set to None; skip the affine step then.
        if self.layer.weight is not None:
            output = output * ch(self.layer.weight)
        if self.layer.bias is not None:
            output = output + ch(self.layer.bias)
        return output


def update_ema(ema, new_data, momentum=0.9):
    """Return an exponential moving average update of ``ema`` towards ``new_data``.

    Args:
        ema: previous average, or ``None`` to start a new average.
        new_data: new observation.
        momentum: weight kept on the previous average (default 0.9).

    Returns:
        ``new_data`` unchanged when ``ema`` is None. Otherwise
        ``momentum * ema + (1 - momentum) * new_data``, computed without autograd
        tracking.
    """
    if ema is None:
        return new_data
    with torch.no_grad():
        return momentum * ema + (1 - momentum) * new_data