import math
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import copy
import random
import torchvision.transforms as transforms

from PIL import Image, ImageOps, ImageEnhance
import torch
import numpy as np
import torch
from pathlib import Path
import os
from robustbench.zenodo_download import DownloadError, zenodo_download
from robustbench.model_zoo.enums import BenchmarkDataset, ThreatModel
from typing import Callable, Dict, Optional, Sequence, Set, Tuple


def collect_params(model, ft_layers):
    """Gather the ``weight``/``bias`` parameters that test-time adaptation tunes.

    Args:
        model: the ``nn.Module`` being adapted.
        ft_layers: ``'whole'`` selects the ``weight``/``bias`` owned directly
            by every submodule; any other value restricts the selection to
            ``nn.BatchNorm2d`` layers.

    Returns:
        A ``(params, names)`` pair of parallel lists, each name formatted as
        ``"<module path>.<param name>"``.
    """
    bn_only = ft_layers != 'whole'
    # Only parameters named exactly 'weight'/'bias' pass the filter, i.e. the
    # ones a module owns directly (nested ones carry a dotted prefix).
    selected = [
        (f"{mod_name}.{p_name}", param)
        for mod_name, module in model.named_modules()
        if not bn_only or isinstance(module, nn.BatchNorm2d)
        for p_name, param in module.named_parameters()
        if p_name in ('weight', 'bias')
    ]
    names = [name for name, _ in selected]
    params = [param for _, param in selected]
    return params, names


def setup_optimizer(params, optim_type, lr=1e-3):
    """Build the optimizer used for test-time adaptation.

    Args:
        params: iterable of parameters (or param groups) to optimize.
        optim_type: ``'adam'`` or ``'sgd'``.
        lr: learning rate; the default keeps the previously hard-coded ``1e-3``.

    Returns:
        A configured ``torch.optim`` optimizer.

    Raises:
        NotImplementedError: if ``optim_type`` is not supported.
    """
    if optim_type == 'adam':
        return optim.Adam(params,
                          lr=lr,
                          betas=(0.9, 0.999),
                          weight_decay=0.)
    if optim_type == 'sgd':
        # Nesterov momentum requires zero dampening, hence the explicit 0.
        return optim.SGD(params,
                         lr=lr,
                         momentum=0.9,
                         dampening=0.,
                         weight_decay=0.,
                         nesterov=True)
    raise NotImplementedError(f"Unsupported optim_type: {optim_type!r}")


# https://github.com/fiveai/LAME/blob/master/src/adaptation/lame.py
def compute_lame(preds, feats, metric, k):
    """Refine class probabilities with LAME's Laplacian-regularized objective.

    Args:
        preds: per-sample class probabilities, ``[N, K]``.
        feats: per-sample features, ``[N, d]``; L2-normalized here before use.
        metric: affinity type forwarded to ``compute_sim``.
        k: neighbourhood size forwarded to ``compute_sim``.

    Returns:
        The refined assignment matrix ``[N, K]`` from ``laplacian_optimization``.
    """
    # The small epsilon keeps log() finite when a probability is exactly 0.
    neg_log_preds = -torch.log(preds + 1e-10)
    affinity = compute_sim(F.normalize(feats, p=2, dim=-1), k, metric)  # [N, N]
    return laplacian_optimization(neg_log_preds, affinity)


def compute_sim(feats, k, metric):
    """Build an ``[N, N]`` affinity matrix over the rows of ``feats``.

    Args:
        feats: feature matrix ``[N, d]``.
        k: neighbourhood size (clipped to ``N``).
        metric: ``'rbf'`` gives a Gaussian kernel whose bandwidth is the mean
            distance to each row's k-th nearest row (the row itself counts).
            ``'knn'`` gives a binary matrix marking each row's ``k`` nearest
            other rows. Anything else gives the dot product
            ``feats @ feats.T``.

    Returns:
        The ``[N, N]`` affinity tensor.
    """
    if metric not in ('rbf', 'knn'):
        return torch.matmul(feats, feats.t())

    num = feats.size(0)
    pairwise_dist = torch.norm(feats.unsqueeze(0) - feats.unsqueeze(1), dim=-1, p=2)  # [N, N]

    if metric == 'rbf':
        # Column -1 of the ascending top-k is the k-th smallest distance per row: [N].
        nearest = pairwise_dist.topk(k=min(k, num), dim=-1, largest=False).values
        sigma = nearest[:, -1].mean()
        return torch.exp(-pairwise_dist ** 2 / (2 * sigma ** 2))

    # 'knn': take the k + 1 nearest and drop column 0 (presumably the row
    # itself at distance 0 -- ties between duplicate rows could reorder this).
    neighbours = pairwise_dist.topk(min(k + 1, num), -1, largest=False).indices[:, 1:]
    adjacency = torch.zeros(num, num, device=feats.device)
    adjacency.scatter_(dim=-1, index=neighbours, value=1.0)
    return adjacency


def laplacian_optimization(unary, kernel, bound_lambda=1, max_steps=100, tol=1e-8):
    """Run LAME's fixed-point bound-optimization iterations.

    Args:
        unary: per-sample unary costs ``[N, K]`` (negative log-probabilities).
        kernel: affinity matrix ``[N, N]``.
        bound_lambda: weight of the pairwise (Laplacian) term.
        max_steps: maximum number of iterations.
        tol: relative energy change at or below which iteration stops; the
            default keeps the previously hard-coded ``1e-8``.

    Returns:
        The assignment matrix ``Y`` ``[N, K]``; each row is a softmax output.
    """
    old_energy = float('inf')
    Y = (-unary).softmax(-1)  # [N, K]
    for i in range(max_steps):
        pairwise = bound_lambda * kernel.matmul(Y)  # [N, K]
        Y = (-unary + pairwise).softmax(-1)
        energy = entropy_energy(Y, unary, pairwise, bound_lambda).item()
        # Convergence is only tested from the third iteration on (i > 1).
        if i > 1 and abs(energy - old_energy) <= tol * abs(old_energy):
            break
        old_energy = energy
    return Y


def entropy_energy(Y, unary, pairwise, bound_lambda):
    """Return LAME's scalar energy for the assignment ``Y``.

    The energy sums, over all entries, three terms:
    - the unary cost,
    - minus the ``bound_lambda``-scaled pairwise term,
    - plus the negative entropy ``Y * log(Y)``.

    ``Y`` is clipped at ``1e-20`` inside the log to avoid ``log(0)``.
    """
    unary_term = unary * Y
    pairwise_term = bound_lambda * pairwise * Y
    neg_entropy = Y * torch.log(torch.clamp(Y, min=1e-20))
    return (unary_term - pairwise_term + neg_entropy).sum()


# inspired by https://github.com/bethgelab/robustness/tree/aa0a6798fe3973bae5f47561721b59b39f126ab7
def find_bns(parent, prior):
    """Recursively plan the replacement of every ``nn.BatchNorm2d`` with a ``TBR``.

    Nothing is swapped here. The caller applies each returned
    ``(parent, name, new_module)`` triple, e.g. with ``setattr``.

    Args:
        parent: module to search; ``None`` yields an empty list.
        prior: EMA coefficient for the ``TBR`` running statistics, in ``[0, 1]``.

    Returns:
        A list of ``(parent_module, child_name, TBR_module)`` triples.
    """
    if parent is None:
        return []
    replace_mods = []
    for name, child in parent.named_children():
        if isinstance(child, nn.BatchNorm2d):
            # The wrapped BN stays on its current device. The former forced
            # `.cuda()` crashed CPU runs and ignored the model's real device.
            replace_mods.append((parent, name, TBR(child, prior)))
        else:
            replace_mods.extend(find_bns(child, prior))
    return replace_mods


class TBR(nn.Module):
    """Batch-renormalization wrapper around a frozen ``nn.BatchNorm2d``.

    Each forward pass:
    1. normalizes with the current batch statistics;
    2. rescales with ``r = batch_std / running_std``;
    3. shifts with ``d = (batch_mean - running_mean) / running_std``
       (``r`` and ``d`` use detached statistics);
    4. applies the wrapped layer's affine ``weight``/``bias``.

    The running statistics are plain tensor attributes (not registered
    buffers). They are initialized from the first batch and updated as an
    EMA with coefficient ``prior``.
    """

    def __init__(self, layer, prior):
        assert prior >= 0 and prior <= 1
        super().__init__()
        self.layer = layer
        self.layer.eval()
        self.prior = prior
        # NOTE(review): rmax/dmax are stored but never used to clamp r/d below.
        self.rmax = 3.0
        self.dmax = 5.0
        self.tracked_num = 0  # number of forward passes seen
        self.running_mean = None
        self.running_std = None

    def forward(self, input):
        # Per-channel statistics over the batch and spatial dims (NCHW input).
        batch_mean = input.mean([0, 2, 3])
        batch_std = torch.sqrt(input.var([0, 2, 3], unbiased=False) + self.layer.eps)
        mean_d = batch_mean.detach()
        std_d = batch_std.detach()

        if self.running_mean is None:
            self.running_mean = mean_d.clone()
            self.running_std = std_d.clone()

        r = std_d / self.running_std
        d = (mean_d - self.running_mean) / self.running_std

        def per_channel(t):
            # Broadcast a [C] vector against an [N, C, H, W] tensor.
            return t[None, :, None, None]

        normalized = (input - per_channel(batch_mean)) / per_channel(batch_std) * per_channel(r) + per_channel(d)

        self.running_mean = self.prior * self.running_mean + (1. - self.prior) * mean_d
        self.running_std = self.prior * self.running_std + (1. - self.prior) * std_d
        self.tracked_num += 1

        return normalized * per_channel(self.layer.weight) + per_channel(self.layer.bias)


## 1. noadapt
def tta_noadapt(_x, _y, _model, batch_size, num_classes, T=None, g_phi=None):
    """Evaluate ``_model`` in eval mode without any adaptation.

    Args:
        _x, _y: test inputs and labels (indexed along dim 0).
        _model: network returning ``(logits, feats)`` for ``_model(x, True)``.
        batch_size: evaluation batch size.
        num_classes: width of the stored logits.
        T: optional matrix right-multiplied onto the logits before scoring.
        g_phi: optional network mapping the first batch's mean softmax to such
            a matrix; used only when ``T`` is None.

    Returns:
        ``(acc, logits_bank, labels_bank)``:
        - ``acc``: accuracy as a 0-d tensor;
        - ``logits_bank``: raw logits (stored before ``T``/``g_phi`` is
          applied), at each sample's original index;
        - ``labels_bank``: labels, at each sample's original index.
    """
    _model.eval()
    num_examples = _x.size(0)
    n_batches = math.ceil(num_examples / batch_size)
    logits_bank = torch.zeros(num_examples, num_classes)
    labels_bank = torch.zeros(num_examples).long()
    # Run on the model's own device instead of hard-coding CUDA, and move T
    # there once rather than on every batch.
    device = next(_model.parameters()).device
    T_dev = T.to(device) if T is not None else None

    acc = 0.
    T_online = None
    with torch.no_grad():
        perm = torch.randperm(num_examples)
        for counter in range(n_batches):
            idx_curr = perm[counter * batch_size:(counter + 1) * batch_size]
            x_curr = _x[idx_curr].to(device)
            y_curr = _y[idx_curr]

            outputs, _ = _model(x_curr, True)

            logits_bank[idx_curr] = outputs.detach().cpu()
            labels_bank[idx_curr] = y_curr.detach().cpu()

            # NOTE(review): unlike the adaptive variants, T is applied here
            # without preserving each row's logit norm -- confirm intended.
            if T_dev is not None:
                outputs = outputs @ T_dev
            elif g_phi is not None:
                if T_online is None:
                    # Estimated once, from the first batch's mean prediction.
                    T_online = g_phi(torch.softmax(outputs, dim=-1).mean(0))
                outputs = outputs @ T_online

            acc += (outputs.argmax(-1).cpu() == y_curr.cpu()).float().sum()

    acc = acc / num_examples
    return acc, logits_bank, labels_bank


# 2. bnadapt
def tta_bnadapt(_x, _y, _model, args, T=None, g_phi=None, need_feats=False):
    """Evaluate ``_model`` in train mode, i.e. with per-batch BN statistics.

    No parameters are updated. Forward passes run under ``no_grad`` with the
    model in train mode (any BN layers therefore use batch statistics).

    Args:
        _x, _y: test inputs and labels (indexed along dim 0).
        _model: network returning ``(logits, feats)`` for ``_model(x, True)``.
        args: needs ``batch_size`` and ``num_classes``.
        T: optional logit transform (right-multiplied), applied so each row
            keeps its original logit norm.
        g_phi: optional network mapping the first batch's mean softmax to a
            transform; used only when ``T`` is None.
        need_feats: also return the per-sample features.

    Returns:
        ``(acc, logits_bank, labels_bank)``, plus ``feats_bank`` when
        ``need_feats``. Banks are indexed by original sample position, and
        logits are stored after the transform.
    """
    batch_size, num_classes = args.batch_size, args.num_classes
    _model.train()
    num_examples = _x.size(0)
    n_batches = math.ceil(num_examples / batch_size)
    logits_bank = torch.zeros(num_examples, num_classes)
    labels_bank = torch.zeros(num_examples).long()
    # Sized from the first batch's features. The width used to be guessed
    # from num_classes (2048 if 7 classes, else 512), which crashed for any
    # other backbone even when features were not requested.
    feats_bank = None
    device = next(_model.parameters()).device
    T_dev = T.to(device) if T is not None else None

    acc = 0.
    T_online = None
    with torch.no_grad():
        perm = torch.randperm(num_examples)
        for counter in range(n_batches):
            idx_curr = perm[counter * batch_size:(counter + 1) * batch_size]
            x_curr = _x[idx_curr].to(device)
            y_curr = _y[idx_curr]

            outputs, feats = _model(x_curr, True)

            if T_dev is not None:
                logit_norm = torch.norm(outputs, dim=1, keepdim=True)
                outputs = outputs @ T_dev
                outputs = F.normalize(outputs, dim=1)
                outputs = logit_norm * outputs
            elif g_phi is not None:
                if T_online is None:
                    T_online = g_phi(torch.softmax(outputs, dim=-1).mean(0))
                logit_norm = torch.norm(outputs, dim=1, keepdim=True)
                outputs = outputs @ T_online
                outputs = F.normalize(outputs, dim=1)
                outputs = logit_norm * outputs

            logits_bank[idx_curr] = outputs.detach().cpu()
            labels_bank[idx_curr] = y_curr.detach().cpu()
            if need_feats:
                if feats_bank is None:
                    feats_bank = torch.zeros(num_examples, *feats.shape[1:])
                feats_bank[idx_curr] = feats.detach().cpu()

            acc += (outputs.argmax(-1).cpu() == y_curr.cpu()).float().sum()

    acc = acc / num_examples
    if need_feats:
        if feats_bank is None:  # no batch was processed
            feats_bank = torch.zeros(num_examples, 0)
        return acc, logits_bank, labels_bank, feats_bank
    return acc, logits_bank, labels_bank


# 3. tent
def tta_tent(_x, _y, _model, args, T=None, g_phi=None):
    """Adapt ``_model`` with TENT (entropy minimization) while evaluating it.

    Each batch is predicted and then used for one optimizer step on the mean
    prediction entropy. Accuracy is measured on the pre-update predictions.

    Args:
        _x, _y: test inputs and labels (indexed along dim 0).
        _model: network returning logits for ``_model(x)``; updated in place.
        args: needs ``ft_layers``, ``optim_type``, ``num_epochs``, ``batch_size``.
        T: optional logit transform (right-multiplied), applied so each row
            keeps its original logit norm.
        g_phi: optional network mapping the first batch's mean softmax to a
            transform; used only when ``T`` is None.

    Returns:
        A list with the accuracy of each epoch.
    """
    params, _ = collect_params(_model, args.ft_layers)
    _optimizer = setup_optimizer(params, args.optim_type)
    _model.train()

    num_examples = _x.size(0)
    num_batches = math.ceil(num_examples / args.batch_size)
    # Follow the model's device instead of hard-coding CUDA; move T once.
    device = next(_model.parameters()).device
    T_dev = T.to(device) if T is not None else None
    accs = []

    T_online = None
    for _ in range(args.num_epochs):
        perm = torch.randperm(num_examples)
        acc_temp = 0.
        for counter in range(num_batches):
            idx_curr = perm[counter * args.batch_size:(counter + 1) * args.batch_size]
            x_curr = _x[idx_curr].to(device)
            y_curr = _y[idx_curr]

            outputs = _model(x_curr)

            if T_dev is not None:
                logit_norm = torch.norm(outputs, dim=1, keepdim=True)
                outputs = outputs @ T_dev
                outputs = F.normalize(outputs, dim=1)
                outputs = logit_norm * outputs
            elif g_phi is not None:
                if T_online is None:
                    with torch.no_grad():
                        T_online = g_phi(torch.softmax(outputs, dim=-1).mean(0))
                logit_norm = torch.norm(outputs, dim=1, keepdim=True)
                outputs = outputs @ T_online
                outputs = F.normalize(outputs, dim=1)
                outputs = logit_norm * outputs

            # Entropy of softmax(z): logsumexp(z) - sum_k p_k * z_k.
            target = torch.softmax(outputs, dim=-1)
            tta_loss = torch.logsumexp(outputs, dim=1) - (target * outputs).sum(dim=1)
            tta_loss = tta_loss.mean()

            _optimizer.zero_grad()
            tta_loss.backward()
            _optimizer.step()

            acc_temp += (outputs.max(1)[1].detach().cpu() == y_curr).float().sum()
        # float() also works when no batch ran and acc_temp is still 0.
        accs.append(float(acc_temp) / num_examples)
    return accs


# 4. pseudo label
def tta_pseudolabel(_x, _y, _model, args, T=None, g_phi=None, is_hard=False):
    """Adapt ``_model`` on its own confident hard pseudo-labels while evaluating it.

    In each batch, samples whose max softmax probability exceeds
    ``args.pl_threshold`` are trained with cross-entropy against their own
    argmax. Accuracy is measured on the pre-update predictions.

    Args:
        _x, _y: test inputs and labels (indexed along dim 0).
        _model: network returning logits for ``_model(x)``; updated in place.
        args: needs ``ft_layers``, ``optim_type``, ``num_epochs``,
            ``batch_size``, ``num_classes`` and ``pl_threshold``.
        T: optional logit transform (right-multiplied), applied so each row
            keeps its original logit norm.
        g_phi: optional network mapping the first batch's mean softmax to a
            transform; used only when ``T`` is None.
        is_hard: unused; kept for a uniform signature.

    Returns:
        A list with the accuracy of each epoch.
    """
    params, _ = collect_params(_model, args.ft_layers)
    _optimizer = setup_optimizer(params, args.optim_type)
    _model.train()

    num_examples = _x.size(0)
    num_batches = math.ceil(num_examples / args.batch_size)
    device = next(_model.parameters()).device
    T_dev = T.to(device) if T is not None else None
    accs = []

    T_online = None
    for _ in range(args.num_epochs):
        perm = torch.randperm(num_examples)
        acc_temp = 0.
        for counter in range(num_batches):
            idx_curr = perm[counter * args.batch_size:(counter + 1) * args.batch_size]
            x_curr = _x[idx_curr].to(device)
            y_curr = _y[idx_curr]

            outputs = _model(x_curr)

            if T_dev is not None:
                logit_norm = torch.norm(outputs, dim=1, keepdim=True)
                outputs = outputs @ T_dev
                outputs = F.normalize(outputs, dim=1)
                outputs = logit_norm * outputs
            elif g_phi is not None:
                if T_online is None:
                    with torch.no_grad():
                        T_online = g_phi(torch.softmax(outputs, dim=-1).mean(0))
                logit_norm = torch.norm(outputs, dim=1, keepdim=True)
                outputs = outputs @ T_online
                outputs = F.normalize(outputs, dim=1)
                outputs = logit_norm * outputs

            target = F.one_hot(outputs.argmax(-1), num_classes=args.num_classes)
            py = torch.softmax(outputs, dim=-1).max(1)[0]
            mask = py > args.pl_threshold

            # With no confident sample the masked mean is NaN and the step
            # would only replay optimizer momentum, so skip the update.
            if mask.any():
                tta_loss = torch.logsumexp(outputs[mask], dim=1) - (target[mask] * outputs[mask]).sum(dim=1)
                tta_loss = tta_loss.mean()

                _optimizer.zero_grad()
                tta_loss.backward()
                _optimizer.step()

            acc_temp += (outputs.max(1)[1].detach().cpu() == y_curr).float().sum()
        accs.append(float(acc_temp) / num_examples)
    return accs


# 5. lame
def tta_lame(_x, _y, _model, args, T=None, g_phi=None, is_hard=False):
    """Evaluate with LAME refinement of the model's predictions (no weight updates).

    The model runs in train mode, so any BN layers use batch statistics.
    Each batch's softmax predictions are refined by ``compute_lame`` over a
    5-NN affinity of the batch features.

    Args:
        _x, _y: test inputs and labels (indexed along dim 0).
        _model: network returning ``(logits, feats)`` for ``_model(x, True)``.
        args: needs ``batch_size``.
        T: optional logit transform (right-multiplied), applied so each row
            keeps its original logit norm.
        g_phi: optional network mapping the first batch's mean softmax to a
            transform; used only when ``T`` is None.
        is_hard: unused; kept for a uniform signature.

    Returns:
        Accuracy (a Python float) of the refined predictions over ``_x``.
    """
    _model.train()

    num_examples = _x.size(0)
    num_batches = math.ceil(num_examples / args.batch_size)
    device = next(_model.parameters()).device
    T_dev = T.to(device) if T is not None else None
    correct = 0.

    T_online = None
    perm = torch.randperm(num_examples)
    # Nothing is back-propagated, so skip building autograd graphs. Train-mode
    # BN running statistics are still updated under no_grad.
    with torch.no_grad():
        for counter in range(num_batches):
            idx_curr = perm[counter * args.batch_size:(counter + 1) * args.batch_size]
            x_curr = _x[idx_curr].to(device)
            y_curr = _y[idx_curr]

            outputs, feats = _model(x_curr, True)

            if T_dev is not None:
                logit_norm = torch.norm(outputs, dim=1, keepdim=True)
                outputs = outputs @ T_dev
                outputs = F.normalize(outputs, dim=1)
                outputs = logit_norm * outputs
            elif g_phi is not None:
                if T_online is None:
                    T_online = g_phi(torch.softmax(outputs, dim=-1).mean(0))
                logit_norm = torch.norm(outputs, dim=1, keepdim=True)
                outputs = outputs @ T_online
                outputs = F.normalize(outputs, dim=1)
                outputs = logit_norm * outputs

            refined_outputs = compute_lame(torch.softmax(outputs, -1), feats, metric='knn', k=5)

            correct += (refined_outputs.argmax(-1).cpu() == y_curr).float().sum()

    return float(correct) / num_examples


# 6. delta
def tta_delta(_x, _y, _model, args, T=None, g_phi=None, is_hard=False):
    """Adapt ``_model`` with DELTA: TBR-wrapped BN plus class-reweighted entropy.

    Setup:
    - every ``nn.BatchNorm2d`` is wrapped in a ``TBR`` (swapped in place);
    - only the BN affine ``weight``/``bias`` are trained.

    Loss: each sample's prediction entropy is weighted by the inverse of a
    running estimate ``qhat`` of the predicted class distribution.
    Accuracy is measured on the pre-update predictions.

    Args:
        _x, _y: test inputs and labels (indexed along dim 0).
        _model: network returning logits for ``_model(x)``; modified in place.
        args: needs ``optim_type``, ``num_epochs``, ``batch_size``, ``num_classes``.
        T: optional logit transform (right-multiplied), applied so each row
            keeps its original logit norm.
        g_phi: optional network mapping the first batch's mean softmax to a
            transform; used only when ``T`` is None.
        is_hard: unused; kept for a uniform signature.

    Returns:
        A list with the accuracy of each epoch.
    """
    old_prior = 0.95  # EMA coefficient of the TBR running statistics
    dot = 0.95        # EMA coefficient of qhat

    # Freeze everything, then swap each BatchNorm2d for a TBR wrapper.
    _model.eval()
    _model.requires_grad_(False)
    for (parent, name, child) in find_bns(_model, old_prior):
        setattr(parent, name, child)

    # Unfreeze only the BN affine parameters. The BN layers now sit inside the
    # TBR wrappers, where named_modules() still finds them.
    params = []
    for _, m in _model.named_modules():
        if isinstance(m, nn.BatchNorm2d):
            for p_name, p in m.named_parameters():
                if p_name in ['weight', 'bias']:
                    p.requires_grad_(True)
                    params.append(p)

    _optimizer = setup_optimizer(params, args.optim_type)
    _model.train()

    num_examples = _x.size(0)
    num_batches = math.ceil(num_examples / args.batch_size)
    device = next(_model.parameters()).device
    T_dev = T.to(device) if T is not None else None
    accs = []

    # Running estimate of the predicted class distribution, starting uniform.
    qhat = torch.zeros(1, args.num_classes, device=device) + (1. / args.num_classes)

    T_online = None
    for _ in range(args.num_epochs):
        perm = torch.randperm(num_examples)
        acc_temp = 0.
        for counter in range(num_batches):
            idx_curr = perm[counter * args.batch_size:(counter + 1) * args.batch_size]
            x_curr = _x[idx_curr].to(device)
            y_curr = _y[idx_curr]

            outputs = _model(x_curr)

            if T_dev is not None:
                logit_norm = torch.norm(outputs, dim=1, keepdim=True)
                outputs = outputs @ T_dev
                outputs = F.normalize(outputs, dim=1)
                outputs = logit_norm * outputs
            elif g_phi is not None:
                if T_online is None:
                    with torch.no_grad():
                        T_online = g_phi(torch.softmax(outputs, dim=-1).mean(0))
                logit_norm = torch.norm(outputs, dim=1, keepdim=True)
                outputs = outputs @ T_online
                outputs = F.normalize(outputs, dim=1)
                outputs = logit_norm * outputs

            p = torch.softmax(outputs, dim=1)
            _, pls = p.max(dim=1)
            logp = F.log_softmax(outputs, dim=1)
            entropys = -(p * logp).sum(dim=1)

            # Weight each sample by the inverse running frequency of its
            # predicted class, rescaled so the batch weights average to 1.
            class_weight = 1. / qhat
            class_weight = class_weight / class_weight.sum()
            sample_weight = class_weight.gather(1, pls.view(1, -1)).squeeze()
            sample_weight = sample_weight / sample_weight.sum() * len(pls)
            tta_loss = (entropys * sample_weight).mean()

            _optimizer.zero_grad()
            tta_loss.backward()
            _optimizer.step()

            acc_temp += (outputs.max(1)[1].detach().cpu() == y_curr).float().sum()

            # update qhat
            with torch.no_grad():
                qhat = dot * qhat + (1. - dot) * p.mean(dim=0, keepdim=True)
        accs.append(float(acc_temp) / num_examples)
    return accs


# 7. note
def tta_note(_x, _y, _model, args, T=None, g_phi=None, is_hard=False):
    """Adapt ``_model`` with NOTE (IABN layers + PBRS memory) while evaluating it.

    BN layers are converted to IABN. For each batch:
    1. predict in eval mode (this is what accuracy measures);
    2. pseudo-label every sample one at a time and push it into a PBRS memory;
    3. take one entropy-minimization step on the whole memory in train mode.

    Args:
        _x, _y: test inputs and labels (indexed along dim 0).
        _model: network returning logits for ``_model(x)``; modified in place.
        args: needs ``batch_size`` (also the memory capacity), ``num_classes``,
            ``optim_type`` and ``num_epochs``.
        T: optional logit transform (right-multiplied), applied so each row
            keeps its original logit norm.
        g_phi: optional network mapping mean softmax predictions to a
            transform; used only when ``T`` is None.
        is_hard: unused; kept for a uniform signature.

    Returns:
        A list with the accuracy of each epoch.
    """
    # hparams
    memory_size = args.batch_size
    iabn_k = 4
    skip_thres = 1.
    num_examples = _x.size(0)
    num_batches = math.ceil(num_examples / args.batch_size)

    # build model: nested BN layers are swapped in place. The return value is
    # ignored, so a bare BN passed as `_model` would not be converted.
    convert_iabn(_model, iabn_k, skip_thres)
    _model.train()

    # build memory
    mem = PBRS(capacity=memory_size, num_classes=args.num_classes)

    # build optimizer over normalization-layer parameters only
    for param in _model.parameters():
        param.requires_grad = False

    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d)
    iabn_types = (InstanceAwareBatchNorm2d, InstanceAwareBatchNorm1d)
    for module in _model.modules():
        if isinstance(module, iabn_types + bn_types):
            for param in module.parameters():
                param.requires_grad = True

    params = []
    for _, m in _model.named_modules():
        if isinstance(m, bn_types + iabn_types):
            for p_name, p in m.named_parameters():
                if p_name in ['weight', 'bias']:
                    params.append(p)

    _optimizer = setup_optimizer(params, args.optim_type)
    _model.train()

    device = next(_model.parameters()).device
    T_dev = T.to(device) if T is not None else None
    accs = []

    T_online = None
    for _ in range(args.num_epochs):
        perm = torch.randperm(num_examples)
        acc_temp = 0.
        for counter in range(num_batches):
            idx_curr = perm[counter * args.batch_size:(counter + 1) * args.batch_size]
            x_curr = _x[idx_curr].to(device)
            y_curr = _y[idx_curr]

            # make prediction (eval mode, no gradients)
            with torch.no_grad():
                _model.eval()
                outputs = _model(x_curr)

                if T_dev is not None:
                    logit_norm = torch.norm(outputs, dim=1, keepdim=True)
                    outputs = outputs @ T_dev
                    outputs = F.normalize(outputs, dim=1)
                    outputs = logit_norm * outputs
                elif g_phi is not None:
                    if T_online is None:
                        T_online = g_phi(torch.softmax(outputs, dim=-1).mean(0))
                    logit_norm = torch.norm(outputs, dim=1, keepdim=True)
                    outputs = outputs @ T_online
                    outputs = F.normalize(outputs, dim=1)
                    outputs = logit_norm * outputs

                acc_temp += (outputs.argmax(-1).cpu() == y_curr.cpu()).float().sum()

            # update memory with per-sample eval-mode pseudo-labels
            with torch.no_grad():
                _model.eval()
                for f, c in zip(x_curr, y_curr):
                    logit = _model(f.unsqueeze(0))

                    if T_dev is not None:
                        logit_norm = torch.norm(logit, dim=1, keepdim=True)
                        logit = logit @ T_dev
                        logit = F.normalize(logit, dim=1)
                        logit = logit_norm * logit
                    elif g_phi is not None:
                        # Per-sample transform kept in its own local. It used
                        # to overwrite T_online, so later batches were
                        # predicted with the last sample's transform.
                        T_sample = g_phi(torch.softmax(logit, dim=-1).mean(0))
                        logit_norm = torch.norm(logit, dim=1, keepdim=True)
                        logit = logit @ T_sample
                        logit = F.normalize(logit, dim=1)
                        logit = logit_norm * logit
                    pseudo_cls = logit.max(1, keepdim=False)[1][0]
                    mem.add_instance([f.cpu(), pseudo_cls.cpu(), c.cpu(), 0])

            # one-step adaptation on the whole memory
            _model.train()
            memory_samples_x, _ = mem.get_memory()
            memory_samples_x = torch.stack(memory_samples_x).to(device)

            logits = _model(memory_samples_x)  # [N, K]
            pseudo_prob = torch.softmax(logits, dim=-1)

            # Entropy of softmax(z): logsumexp(z) - sum_k p_k * z_k.
            loss = torch.logsumexp(logits, dim=1) - (pseudo_prob * logits).sum(dim=1)
            loss = loss.mean()

            _optimizer.zero_grad()
            loss.backward()
            _optimizer.step()

        accs.append(float(acc_temp) / num_examples)

    return accs


def tta_ods_note(_x, _y, _model, args, T=None, g_phi=None, is_hard=False):
    """Adapt ``_model`` with NOTE plus ODS-style LAME refinement and loss reweighting.

    Setup: BN layers are converted to IABN, and a frozen copy taken right
    after conversion (``_model_init``) is kept.

    Evaluation: the copy's LAME-refined predictions are combined with the
    adapted model's predictions by a normalized geometric mean.

    Adaptation: one step per batch on a PBRS memory. Each memory sample's
    entropy is weighted by ``1 - q + 0.1`` (normalized) for its pseudo-label,
    where ``q = queue_dist.get()``.

    Args:
        _x, _y: test inputs and labels (indexed along dim 0).
        _model: network returning logits for ``_model(x)`` and
            ``(logits, feats)`` for ``_model(x, True)``; modified in place.
        args: needs ``batch_size`` (also the memory capacity), ``num_classes``,
            ``optim_type`` and ``num_epochs``.
        T: optional logit transform (right-multiplied), applied so each row
            keeps its original logit norm.
        g_phi: optional network mapping mean softmax predictions to a
            transform; used only when ``T`` is None.
        is_hard: unused; kept for a uniform signature.

    Returns:
        A list with the accuracy of each epoch.
    """
    # hparams
    memory_size = args.batch_size
    iabn_k = 4
    skip_thres = 1.
    num_examples = _x.size(0)
    num_classes = args.num_classes
    num_batches = math.ceil(num_examples / args.batch_size)

    # Nested BN layers are swapped in place; the return value is ignored.
    convert_iabn(_model, iabn_k, skip_thres)
    _model_init = copy.deepcopy(_model)
    mem = PBRS(capacity=memory_size, num_classes=num_classes)
    queue_dist = LabelDistributionQueue(num_class=num_classes, capacity=memory_size)

    # build optimizer: only normalization-layer parameters of _model train;
    # the initial copy stays fully frozen.
    for param in _model.parameters():
        param.requires_grad = False
    for param in _model_init.parameters():
        param.requires_grad = False

    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d)
    iabn_types = (InstanceAwareBatchNorm2d, InstanceAwareBatchNorm1d)
    for module in _model.modules():
        if isinstance(module, iabn_types + bn_types):
            for param in module.parameters():
                param.requires_grad = True

    params = []
    for _, m in _model.named_modules():
        if isinstance(m, bn_types + iabn_types):
            for p_name, p in m.named_parameters():
                if p_name in ['weight', 'bias']:
                    params.append(p)
    _optimizer = setup_optimizer(params, args.optim_type)

    _model.train()
    _model_init.train()
    device = next(_model.parameters()).device
    T_dev = T.to(device) if T is not None else None
    accs = []

    T_online = None
    for _ in range(args.num_epochs):
        perm = torch.randperm(num_examples)
        acc_temp = 0.
        for counter in range(num_batches):
            idx_curr = perm[counter * args.batch_size:(counter + 1) * args.batch_size]
            x_curr = _x[idx_curr].to(device)
            y_curr = _y[idx_curr]

            # make prediction (nothing here is back-propagated)
            with torch.no_grad():
                # NOTE(review): this train-mode pass's output is discarded.
                # Its visible effect is that train-mode IABN layers update
                # their `_bn` running statistics before the eval-mode
                # prediction below -- confirm this is intended.
                _model(x_curr)

                _model.eval()
                outputs, feats = _model(x_curr, True)
                if T_dev is not None:
                    logit_norm = torch.norm(outputs, dim=1, keepdim=True)
                    outputs = outputs @ T_dev
                    outputs = F.normalize(outputs, dim=1)
                    outputs = logit_norm * outputs
                elif g_phi is not None:
                    if T_online is None:
                        T_online = g_phi(torch.softmax(outputs, dim=-1).mean(0))
                    logit_norm = torch.norm(outputs, dim=1, keepdim=True)
                    outputs = outputs @ T_online
                    outputs = F.normalize(outputs, dim=1)
                    outputs = logit_norm * outputs

                outputs_init = _model_init(x_curr)

                optim_dist = compute_lame(torch.softmax(outputs_init, -1), feats, 'knn', 5)  # [B,K]

                # Normalized geometric mean of the two class distributions.
                probas = torch.sqrt(torch.softmax(outputs, 1) * optim_dist)  # [B,K]
                unary = - torch.log(probas + 1e-10)
                outputs = (-unary).softmax(-1)

            acc_temp += (outputs.argmax(-1).cpu() == y_curr.cpu()).float().sum()

            # update memory with per-sample eval-mode pseudo-labels
            with torch.no_grad():
                _model.eval()
                for f, c in zip(x_curr, y_curr):
                    logit = _model(f.unsqueeze(0))

                    if T_dev is not None:
                        logit_norm = torch.norm(logit, dim=1, keepdim=True)
                        logit = logit @ T_dev
                        logit = F.normalize(logit, dim=1)
                        logit = logit_norm * logit
                    elif g_phi is not None:
                        # Per-sample transform, kept separate from T_online.
                        T_temp = g_phi(torch.softmax(logit, dim=-1).mean(0))
                        logit_norm = torch.norm(logit, dim=1, keepdim=True)
                        logit = logit @ T_temp
                        logit = F.normalize(logit, dim=1)
                        logit = logit_norm * logit
                    pseudo_cls = logit.max(1, keepdim=False)[1][0]
                    mem.add_instance([f.cpu(), pseudo_cls.cpu(), c.cpu(), 0])

            # one-step adaptation
            _model.train()
            memory_samples_x, memory_samples_p = mem.get_memory()
            memory_samples_x = torch.stack(memory_samples_x).to(device)
            memory_samples_p = torch.stack(memory_samples_p)

            queue_dist.update(memory_samples_p.clone().detach())

            # Per-class weights, presumably [num_classes] (indexed by the
            # pseudo-labels below).
            weight = 1.0 - queue_dist.get() + 0.1
            weight = weight / weight.sum()
            weight = weight[memory_samples_p].to(device)

            outputs = _model(memory_samples_x)
            target = torch.softmax(outputs, dim=-1)

            ent_loss = torch.logsumexp(outputs, dim=1) - (target * outputs).sum(dim=1)
            # Weighted mean of the per-sample entropies. The former
            # `ent_loss.mean() * weight` then `/ weight.sum()` cancelled the
            # weights and reduced to the plain unweighted mean.
            loss = (ent_loss * weight).sum(0) / weight.sum(0)

            _optimizer.zero_grad()
            loss.backward()
            _optimizer.step()

        accs.append(float(acc_temp) / num_examples)

    return accs


def convert_iabn(module, iabn_k, skip_thres, **kwargs):
    """Recursively replace BatchNorm1d/2d layers with instance-aware versions.

    Non-BN modules are modified in place: their BN children are swapped via
    ``add_module``. A BN ``module`` itself is returned as a new IABN layer
    wrapping a deep copy of it, so callers must use the return value when
    ``module`` may be a BN layer.

    Args:
        module: root module to convert.
        iabn_k: shrinkage multiplier ``k`` for the IABN layers.
        skip_thres: spatial/temporal size at or below which IABN uses the
            batch/running statistics directly. Stored on every created IABN
            layer (this argument was previously accepted but ignored).
        **kwargs: forwarded unchanged to the recursive calls.

    Returns:
        The converted module.
    """
    module_output = module
    if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
        iabn_cls = InstanceAwareBatchNorm2d if isinstance(module, nn.BatchNorm2d) else InstanceAwareBatchNorm1d
        module_output = iabn_cls(
            num_channels=module.num_features,
            k=iabn_k,
            eps=module.eps,
            momentum=module.momentum,
            affine=module.affine,
        )
        # Reuse the trained affine parameters and running statistics.
        module_output._bn = copy.deepcopy(module)
        module_output.skip_thres = skip_thres

    for name, child in module.named_children():
        module_output.add_module(
            name, convert_iabn(child, iabn_k, skip_thres, **kwargs)
        )
    return module_output


class InstanceAwareBatchNorm2d(nn.Module):
    """Instance-aware BatchNorm for ``[B, C, H, W]`` inputs.

    Per-instance spatial mean/variance are soft-shrunk towards the batch
    statistics (training) or the wrapped BN's running statistics (eval). The
    shrinkage threshold is ``k`` times the standard errors of those
    statistics. Inputs whose spatial size ``h * w`` is at most ``skip_thres``
    use the batch/running statistics directly.

    Args:
        num_channels: number of channels ``C``.
        k: shrinkage threshold multiplier.
        eps, momentum, affine: forwarded to the wrapped ``nn.BatchNorm2d``.
        skip_thres: size threshold described above; the default keeps the
            previously hard-coded ``1.``.
    """

    def __init__(self, num_channels, k=3.0, eps=1e-5, momentum=0.1, affine=True, skip_thres=1.):
        super(InstanceAwareBatchNorm2d, self).__init__()
        self.num_channels = num_channels
        self.eps = eps
        self.k = k
        self.affine = affine
        self._bn = nn.BatchNorm2d(num_channels, eps=eps,
                                  momentum=momentum, affine=affine)
        self.skip_thres = skip_thres

    def _softshrink(self, x, lbd):
        """Elementwise soft-shrinkage of ``x`` by a broadcastable tensor threshold ``lbd``."""
        x_p = F.relu(x - lbd, inplace=True)
        x_n = F.relu(-(x + lbd), inplace=True)
        y = x_p - x_n
        return y

    def forward(self, x):
        _, c, h, w = x.size()
        # Per-instance statistics over the spatial dims.
        sigma2, mu = torch.var_mean(x, dim=[2, 3], keepdim=True, unbiased=True)  # IN

        if self.training:
            _ = self._bn(x)  # run the wrapped BN so its running statistics update
            sigma2_b, mu_b = torch.var_mean(x, dim=[0, 2, 3], keepdim=True, unbiased=True)
        else:
            if self._bn.track_running_stats == False and self._bn.running_mean is None and self._bn.running_var is None:  # use batch stats
                sigma2_b, mu_b = torch.var_mean(x, dim=[0, 2, 3], keepdim=True, unbiased=True)
            else:
                mu_b = self._bn.running_mean.view(1, c, 1, 1)
                sigma2_b = self._bn.running_var.view(1, c, 1, 1)

        if h * w <= self.skip_thres:
            mu_adj = mu_b
            sigma2_adj = sigma2_b
        else:
            # Standard errors of the mean and (Gaussian) variance estimates.
            s_mu = torch.sqrt((sigma2_b + self.eps) / (h * w))
            s_sigma2 = (sigma2_b + self.eps) * np.sqrt(2 / (h * w - 1))

            mu_adj = mu_b + self._softshrink(mu - mu_b, self.k * s_mu)

            sigma2_adj = sigma2_b + self._softshrink(sigma2 - sigma2_b, self.k * s_sigma2)

            sigma2_adj = F.relu(sigma2_adj)  # variance must stay non-negative

        x_n = (x - mu_adj) * torch.rsqrt(sigma2_adj + self.eps)
        if self.affine:
            weight = self._bn.weight.view(c, 1, 1)
            bias = self._bn.bias.view(c, 1, 1)
            x_n = x_n * weight + bias
        return x_n


class InstanceAwareBatchNorm1d(nn.Module):
    """Instance-aware BatchNorm for ``[B, C, L]`` inputs.

    Per-instance statistics over ``L`` are soft-shrunk towards the batch
    statistics (training) or the wrapped BN's running statistics (eval).
    Inputs with ``L <= skip_thres`` use the batch/running statistics directly.

    Args:
        num_channels: number of channels ``C``.
        k: shrinkage threshold multiplier.
        eps, momentum, affine: forwarded to the wrapped ``nn.BatchNorm1d``.
        skip_thres: length threshold described above (default ``1.``).
    """

    def __init__(self, num_channels, k=3.0, eps=1e-5, momentum=0.1, affine=True, skip_thres=1.):
        super(InstanceAwareBatchNorm1d, self).__init__()
        self.num_channels = num_channels
        self.k = k
        self.eps = eps
        self.affine = affine
        self._bn = nn.BatchNorm1d(num_channels, eps=eps,
                                  momentum=momentum, affine=affine)
        # forward() used to read an undefined global `skip_thres` (NameError);
        # the threshold is now an attribute, as in the 2-d class.
        self.skip_thres = skip_thres

    def _softshrink(self, x, lbd):
        """Elementwise soft-shrinkage of ``x`` by a broadcastable tensor threshold ``lbd``."""
        x_p = F.relu(x - lbd, inplace=True)
        x_n = F.relu(-(x + lbd), inplace=True)
        y = x_p - x_n
        return y

    def forward(self, x):
        _, c, l = x.size()
        # Per-instance statistics over the length dimension.
        sigma2, mu = torch.var_mean(x, dim=[2], keepdim=True, unbiased=True)
        if self.training:
            _ = self._bn(x)  # run the wrapped BN so its running statistics update
            sigma2_b, mu_b = torch.var_mean(x, dim=[0, 2], keepdim=True, unbiased=True)
        else:
            if self._bn.track_running_stats == False and self._bn.running_mean is None and self._bn.running_var is None:  # use batch stats
                sigma2_b, mu_b = torch.var_mean(x, dim=[0, 2], keepdim=True, unbiased=True)
            else:
                mu_b = self._bn.running_mean.view(1, c, 1)
                sigma2_b = self._bn.running_var.view(1, c, 1)

        if l <= self.skip_thres:
            mu_adj = mu_b
            sigma2_adj = sigma2_b
        else:
            # Standard errors of the mean and (Gaussian) variance estimates.
            s_mu = torch.sqrt((sigma2_b + self.eps) / l)
            s_sigma2 = (sigma2_b + self.eps) * np.sqrt(2 / (l - 1))

            mu_adj = mu_b + self._softshrink(mu - mu_b, self.k * s_mu)
            sigma2_adj = sigma2_b + self._softshrink(sigma2 - sigma2_b, self.k * s_sigma2)
            sigma2_adj = F.relu(sigma2_adj)  # variance must stay non-negative

        x_n = (x - mu_adj) * torch.rsqrt(sigma2_adj + self.eps)

        if self.affine:
            weight = self._bn.weight.view(c, 1)
            bias = self._bn.bias.view(c, 1)
            x_n = x_n * weight + bias

        return x_n


class FIFO():
    """First-in, first-out memory of two-field instances.

    Instances are stored column-wise in ``self.data = [field0_list, field1_list]``.
    Once ``capacity`` instances are held, each new one evicts the oldest.
    """

    def __init__(self, capacity):
        # One list per instance field; no domain label is kept.
        self.data = [[], []]
        self.capacity = capacity

    def get_memory(self):
        """Return the live ``[field0_list, field1_list]`` (not a copy)."""
        return self.data

    def get_occupancy(self):
        """Return how many instances are currently stored."""
        first_field = self.data[0]
        return len(first_field)

    def add_instance(self, instance):
        """Store a 2-element ``instance``, evicting the oldest entry when full."""
        assert (len(instance) == 2)

        if self.get_occupancy() >= self.capacity:
            self.remove_instance()

        for field, value in zip(self.data, instance):
            field.append(value)

    def remove_instance(self):
        """Drop the oldest stored instance from every field list."""
        for field in self.data:
            field.pop(0)


class Reservoir():  # Time uniform
    """Reservoir-sampling memory of two-field instances.

    Instances are stored column-wise in ``self.data = [field0_list, field1_list]``.
    While not full, every instance is stored. Once full, the n-th arrival
    (``n = self.counter``) is kept with probability ``occupancy / n`` and
    replaces a uniformly chosen stored instance; otherwise it is dropped.
    """

    def __init__(self, capacity):
        # Do not forward ``capacity`` to object.__init__: the previous
        # ``super().__init__(capacity)`` raised TypeError on every construction.
        self.data = [[], []]
        self.capacity = capacity
        self.counter = 0  # total number of instances offered so far

    def get_memory(self):
        """Return the live ``[field0_list, field1_list]`` (not a copy)."""
        return self.data

    def get_occupancy(self):
        """Return how many instances are currently stored."""
        return len(self.data[0])

    def add_instance(self, instance):
        """Offer a 2-element ``instance`` to the reservoir."""
        assert (len(instance) == 2)
        self.counter += 1

        is_add = True
        if self.get_occupancy() >= self.capacity:
            is_add = self.remove_instance()

        if is_add:
            for i, dim in enumerate(self.data):
                dim.append(instance[i])

    def remove_instance(self):
        """Decide whether the incoming instance is kept.

        Returns:
            True if a random stored instance was evicted (so the caller should
            store the new one), False if the new instance should be dropped.
        """
        m = self.get_occupancy()
        n = self.counter
        if random.uniform(0, 1) > m / n:
            return False
        tgt_idx = random.randrange(0, m)  # target index to remove
        for dim in self.data:
            dim.pop(tgt_idx)
        return True


class PBRS():
    """Prediction-balanced reservoir memory, bucketed by ``instance[1]``.

    Each class bucket stores two parallel lists: features (``instance[0]``) and
    predicted classes (``instance[1]``). ``add_instance`` requires 4-element
    instances but keeps only those first two fields.

    When the memory is full:
      * an instance whose class is not among the most-populated classes evicts
        a random entry of one of those classes;
      * otherwise it is kept with probability ``m_c / n_c`` (stored / seen for
        its class) and replaces a random entry of its own class.
    """

    def __init__(self, capacity, num_classes):
        self.num_classes = num_classes
        # Per class: [features, predicted classes].
        self.data = [[[], []] for _ in range(self.num_classes)]
        self.counter = [0] * self.num_classes  # instances seen per class
        self.marker = [''] * self.num_classes  # not read inside this class
        self.capacity = capacity

    def print_class_dist(self):
        """Print the number of stored instances per class bucket."""
        print(self.get_occupancy_per_class())

    def print_real_class_dist(self):
        """Unsupported: each bucket keeps only features and predicted classes.

        Raises:
            NotImplementedError: always, because true labels are not stored.
        """
        raise NotImplementedError(
            "PBRS stores only (feature, predicted class); true labels are not kept")

    def get_memory(self):
        """Return ``[features, predicted_classes]`` concatenated bucket by bucket."""
        feats, labels = [], []
        for cls_feats, cls_labels in self.data:
            feats.extend(cls_feats)
            labels.extend(cls_labels)
        return [feats, labels]

    def get_occupancy(self):
        """Return the total number of stored instances."""
        return sum(len(cls_data[0]) for cls_data in self.data)

    def get_occupancy_per_class(self):
        """Return a list with the number of stored instances for each class."""
        return [len(cls_data[0]) for cls_data in self.data]

    def update_loss(self, loss_list):
        """Unsupported: per-instance losses are not stored.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("PBRS does not store per-instance losses")

    def add_instance(self, instance):
        """Offer ``instance`` (4 elements; class index at position 1) to memory."""
        assert (len(instance) == 4)
        cls = instance[1]
        self.counter[cls] += 1

        is_add = True
        if self.get_occupancy() >= self.capacity:
            is_add = self.remove_instance(cls)

        if is_add:
            # Only the first len(self.data[cls]) fields (feature, class) are stored.
            for i, dim in enumerate(self.data[cls]):
                dim.append(instance[i])

    def get_largest_indices(self):
        """Return the indices of every class tied for the highest occupancy."""
        occupancy_per_class = self.get_occupancy_per_class()
        max_value = max(occupancy_per_class)
        return [i for i, oc in enumerate(occupancy_per_class) if oc == max_value]

    def remove_instance(self, cls):
        """Make room for an incoming instance of class ``cls``.

        Returns:
            True if an entry was evicted (the caller should store the new
            instance), False if the new instance should be dropped.
        """
        largest_indices = self.get_largest_indices()
        if cls not in largest_indices:
            # Take the slot from one (randomly chosen) most-populated class.
            largest = random.choice(largest_indices)
            tgt_idx = random.randrange(0, len(self.data[largest][0]))
            for dim in self.data[largest]:
                dim.pop(tgt_idx)
            return True

        # Reservoir sampling within the incoming instance's own class.
        m_c = len(self.data[cls][0])
        n_c = self.counter[cls]
        if random.uniform(0, 1) > m_c / n_c:
            return False
        tgt_idx = random.randrange(0, len(self.data[cls][0]))
        for dim in self.data[cls]:
            dim.pop(tgt_idx)
        return True


class LabelDistributionQueue:
    """Circular buffer of the most recent predicted class labels.

    Labels are stored shifted by +1 so that slot value 0 marks an empty slot.
    ``get`` returns per-class counts divided by the queue capacity, so the
    result sums to less than 1 until the queue is full.
    """

    def __init__(self, num_class, capacity=None):
        if capacity is None:
            capacity = num_class * 20
        self.queue_length = capacity
        self.queue = torch.zeros(self.queue_length)
        self.pointer = 0  # next slot to write
        self.num_class = num_class
        self.size = 0  # total number of labels ever pushed

    def update(self, tgt_preds):
        """Push a 1-D tensor of integer class predictions into the queue.

        When the batch is longer than the queue, only its last
        ``queue_length`` entries can survive, so only those are written. This
        gives the same final state as writing every element circularly.
        """
        tgt_preds = tgt_preds.detach().cpu()
        batch_sz = tgt_preds.shape[0]
        self.size += batch_sz

        start = self.pointer
        if batch_sz > self.queue_length:
            skip = batch_sz - self.queue_length
            tgt_preds = tgt_preds[skip:]
            start = (start + skip) % self.queue_length

        n = tgt_preds.shape[0]  # n <= queue_length from here on
        if start + n > self.queue_length:  # wrap around the end of the buffer
            first = self.queue_length - start
            self.queue[start:] = tgt_preds[:first] + 1
            self.queue[:n - first] = tgt_preds[first:] + 1
        else:
            self.queue[start:start + n] = tgt_preds + 1

        self.pointer = (self.pointer + batch_sz) % self.queue_length

    def get(self):
        """Return a length-``num_class`` tensor of counts / capacity.

        Returns all ones when no label has been stored yet.
        """
        bincounts = torch.bincount(self.queue.long(), minlength=self.num_class + 1).float() / self.queue_length
        bincounts = bincounts[1:]  # drop the "empty slot" bin
        if bincounts.sum() == 0:
            bincounts[:] = 1
        return bincounts

    def full(self):
        """Return True once at least ``queue_length`` labels have been pushed."""
        return self.size >= self.queue_length


### memo

def tta_memo(_x, _y, _model, args, _use_test_bn=False, _adaptive=False, T=None, g_phi=None):
    """Evaluate ``_model`` with MEMO-style test-time adaptation.

    For every test batch, the three AugMix views (``images[1..3]`` from
    ``AugMixDataset``) are passed through the model. Their softmax outputs are
    averaged, and one optimizer step minimizes the entropy of that average. The
    updated model then classifies the clean view ``images[0]``, and top-1
    accuracy is accumulated.

    Args:
        _x: image tensor, converted via ``(_x.numpy() * 255)`` and transposed
            with ``(0, 2, 3, 1)``. It is presumably a CPU (N, C, H, W) tensor
            with values in [0, 1] (TODO: confirm against callers).
        _y: label tensor with one entry per image.
        _model: network to adapt; it is modified in place.
        args: must provide ``batch_size``, ``ft_layers`` and ``optim_type``.
        _use_test_bn: configure with ``configure_model_tent`` instead of
            ``configure_model_eval``.
        _adaptive: if False, model and optimizer are restored to their initial
            state before every batch; if True, updates accumulate.
        T: optional matrix right-multiplied onto every logit tensor.
        g_phi: optional callable. It maps the mean averaged prediction of the
            first batch to a matrix, which is then used like ``T`` for every
            batch. Ignored when ``T`` is given.

    Returns:
        Top-1 accuracy over all of ``_x`` as a float.
    """
    if _use_test_bn:
        _model = configure_model_tent(_model)
    else:
        _model = configure_model_eval(_model)

    # AugMixDataset feeds each item to Image.fromarray, so build uint8 NHWC arrays.
    _x, _y = (_x.numpy() * 255).astype(np.uint8).transpose(0, 2, 3, 1), _y.numpy()
    num_samples = _x.shape[0]

    preprocess = transforms.Compose([transforms.ToTensor()])
    test_data = AugMixDataset(_x, _y, preprocess, False)
    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True)

    params, _ = collect_params(_model, args.ft_layers)
    _optimizer = setup_optimizer(params, args.optim_type)

    if not _adaptive:
        # Episodic mode: snapshot the initial state, restored before every batch.
        model_state, optimizer_state = copy_model_and_optimizer(_model, _optimizer)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    acc = 0.
    T_online = None
    for images, y_curr in test_loader:
        if not _adaptive:
            load_model_and_optimizer(_model, _optimizer,
                                     model_state, optimizer_state)

        y_curr = y_curr.to(device)

        # images[0] is the clean view; images[1..3] are AugMix views.
        logits_aug1 = _model(images[1].to(device))
        logits_aug2 = _model(images[2].to(device))
        logits_aug3 = _model(images[3].to(device))

        if T is not None:
            transform = T
        elif g_phi is not None:
            if T_online is None:
                # Estimated once, from the first batch, then reused.
                with torch.no_grad():
                    p_first = 1 / 3 * (torch.softmax(logits_aug1, dim=-1)
                                       + torch.softmax(logits_aug2, dim=-1)
                                       + torch.softmax(logits_aug3, dim=-1))
                    T_online = g_phi(p_first.mean(0))
            transform = T_online
        else:
            transform = None

        if transform is not None:
            # Follow the logits' device instead of hard-coding .cuda() (breaks CPU runs).
            transform = transform.to(logits_aug1.device)
            logits_aug1 = logits_aug1 @ transform
            logits_aug2 = logits_aug2 @ transform
            logits_aug3 = logits_aug3 @ transform

        p_avg = (F.softmax(logits_aug1, dim=1) + F.softmax(logits_aug2, dim=1)
                 + F.softmax(logits_aug3, dim=1)) / 3
        # Marginal entropy. The eps keeps log() finite, so a probability that
        # underflows to 0 yields 0 instead of NaN (0 * -inf) in loss and gradient.
        tta_loss = -(p_avg * torch.log(p_avg + 1e-10)).sum(dim=1).mean()

        _optimizer.zero_grad()
        tta_loss.backward()
        _optimizer.step()

        # The clean view must be on the model's device too; no graph is needed here.
        with torch.no_grad():
            outputs_new = _model(images[0].to(device))
        acc += (outputs_new.max(1)[1] == y_curr).float().sum()

    return float(acc) / num_samples


class AugMixDataset(torch.utils.data.Dataset):
    """Dataset over in-memory arrays that yields AugMix views.

    ``x_test[i]`` is converted with ``Image.fromarray``, so it is presumably an
    HxWxC uint8 array (confirm against callers).

    Item layout:
      * ``no_jsd=True``: ``(augmented_view, label)``;
      * otherwise: ``((clean_view, aug_view_1, aug_view_2, aug_view_3), label)``.
    The clean view is ``preprocess(img)``; augmented views come from ``aug``.
    """

    def __init__(self, x_test, y_test, preprocess, no_jsd=False):
        self.x_test, self.y_test = x_test, y_test
        self.preprocess = preprocess
        self.no_jsd = no_jsd

    def __getitem__(self, i):
        pil = Image.fromarray(self.x_test[i])
        label = self.y_test[i]

        if self.no_jsd:
            return aug(pil, self.preprocess), label

        clean_view = self.preprocess(pil)
        augmented = tuple(aug(pil, self.preprocess) for _ in range(3))
        return (clean_view,) + augmented, label

    def __len__(self):
        num_items = self.x_test.shape[0]
        return num_items


def aug(image, preprocess, mixture_width=3, mixture_depth=-1, aug_severity=3, aug_list=None):
    """Perform AugMix augmentations and compute mixture.

    Builds ``mixture_width`` independent chains of random operations. The
    chains are mixed with Dirichlet(1, ..., 1) weights, and the result is
    blended with the clean image using a fixed weight ``m = 0.4``.

    Args:
    image: PIL.Image input image (anything with ``.copy()`` that the ops accept).
    preprocess: Preprocessing function which should return a torch tensor.
    mixture_width: number of augmentation chains mixed together.
    mixture_depth: ops per chain; a value <= 0 draws a depth in {1, 2, 3}
        independently for each chain.
    aug_severity: level passed to every augmentation op.
    aug_list: ops to sample from; defaults to ``augmentations_all``.

    Returns:
    mixed: Augmented and mixed image.
    """
    if aug_list is None:
        aug_list = augmentations_all

    ws = np.float32(np.random.dirichlet([1] * mixture_width))
    m = 0.4  # fixed weight of the augmentation mix vs. the clean image

    clean = preprocess(image)
    mix = torch.zeros_like(clean)
    for i in range(mixture_width):
        image_aug = image.copy()
        depth = mixture_depth if mixture_depth > 0 else np.random.randint(1, 4)
        # Both the op chain and the accumulation belong inside the width loop;
        # previously only the last chain was augmented and mixed.
        for _ in range(depth):
            op = np.random.choice(aug_list)
            image_aug = op(image_aug, aug_severity)
        # Preprocessing commutes since all coefficients are convex
        mix += ws[i] * preprocess(image_aug)

    mixed = (1 - m) * clean + m * mix
    return mixed


# Nominal input image side length in pixels (32 here).
# ImageNet code should change this value.
IMAGE_SIZE = 32


def int_parameter(level, maxval):
    """Scale ``maxval`` by ``level / 10`` and truncate toward zero.

    Args:
      level: Operation strength on a 0..10 scale (10 acts as PARAMETER_MAX).
      maxval: Value produced at ``level == 10``.
    Returns:
      ``int(level * maxval / 10)``.
    """
    scaled = level * maxval / 10
    return int(scaled)


def float_parameter(level, maxval):
    """Scale ``maxval`` by ``level / 10`` as a float.

    Args:
      level: Operation strength on a 0..10 scale (10 acts as PARAMETER_MAX).
      maxval: Value produced at ``level == 10``.
    Returns:
      ``float(level) * maxval / 10``.
    """
    numerator = float(level) * maxval
    return numerator / 10.


def sample_level(n):
    """Return the operation level to use: ``n`` itself.

    Random sampling of the level (e.g. uniform in [0.1, n]) is disabled, so
    every op runs at exactly the requested severity.
    """
    level = n
    return level


def autocontrast(pil_img, _):
    """Apply ``ImageOps.autocontrast``; the level argument is ignored."""
    result = ImageOps.autocontrast(pil_img)
    return result


def equalize(pil_img, _):
    """Apply ``ImageOps.equalize``; the level argument is ignored."""
    result = ImageOps.equalize(pil_img)
    return result


def posterize(pil_img, level):
    """Keep ``4 - int_parameter(level, 4)`` bits per channel."""
    bits = 4 - int_parameter(sample_level(level), 4)
    return ImageOps.posterize(pil_img, bits)


def rotate(pil_img, level):
    """Rotate by ``int_parameter(level, 30)`` degrees with a random sign (bilinear)."""
    degrees = int_parameter(sample_level(level), 30)
    flip_sign = np.random.uniform() > 0.5
    angle = -degrees if flip_sign else degrees
    return pil_img.rotate(angle, resample=Image.BILINEAR)


def solarize(pil_img, level):
    """Solarize with threshold ``256 - int_parameter(level, 256)``."""
    threshold = 256 - int_parameter(sample_level(level), 256)
    return ImageOps.solarize(pil_img, threshold)


def shear_x(pil_img, level):
    """Shear horizontally by ``float_parameter(level, 0.3)`` with a random sign.

    The output keeps the input's size, so the result can be mixed with the
    other AugMix views of the same image.
    """
    level = float_parameter(sample_level(level), 0.3)
    if np.random.uniform() > 0.5:
        level = -level
    # Use the image's own size rather than IMAGE_SIZE: a fixed canvas would
    # crop or pad inputs of any other size and break the later tensor mixing.
    return pil_img.transform(pil_img.size,
                             Image.AFFINE, (1, level, 0, 0, 1, 0),
                             resample=Image.BILINEAR)


def shear_y(pil_img, level):
    """Shear vertically by ``float_parameter(level, 0.3)`` with a random sign.

    The output keeps the input's size, so the result can be mixed with the
    other AugMix views of the same image.
    """
    level = float_parameter(sample_level(level), 0.3)
    if np.random.uniform() > 0.5:
        level = -level
    # Use the image's own size rather than IMAGE_SIZE: a fixed canvas would
    # crop or pad inputs of any other size and break the later tensor mixing.
    return pil_img.transform(pil_img.size,
                             Image.AFFINE, (1, 0, 0, level, 1, 0),
                             resample=Image.BILINEAR)


def translate_x(pil_img, level):
    """Translate horizontally by up to a third of the image width, random sign.

    The shift is ``int_parameter(level, width / 3)`` pixels, and the output
    keeps the input's size.
    """
    width, _ = pil_img.size
    # Scale with the actual width instead of IMAGE_SIZE; identical for
    # IMAGE_SIZE-wide images and correct for other sizes.
    level = int_parameter(sample_level(level), width / 3)
    if np.random.random() > 0.5:
        level = -level
    return pil_img.transform(pil_img.size,
                             Image.AFFINE, (1, 0, level, 0, 1, 0),
                             resample=Image.BILINEAR)


def translate_y(pil_img, level):
    """Translate vertically by up to a third of the image height, random sign.

    The shift is ``int_parameter(level, height / 3)`` pixels, and the output
    keeps the input's size.
    """
    _, height = pil_img.size
    # Scale with the actual height instead of IMAGE_SIZE; identical for
    # IMAGE_SIZE-tall images and correct for other sizes.
    level = int_parameter(sample_level(level), height / 3)
    if np.random.random() > 0.5:
        level = -level
    return pil_img.transform(pil_img.size,
                             Image.AFFINE, (1, 0, 0, 0, 1, level),
                             resample=Image.BILINEAR)


# operation that overlaps with ImageNet-C's test set
def color(pil_img, level):
    """Adjust colour balance by ``float_parameter(level, 1.8) + 0.1`` (0.1 at level 0, 1.9 at level 10)."""
    factor = float_parameter(sample_level(level), 1.8) + 0.1
    enhancer = ImageEnhance.Color(pil_img)
    return enhancer.enhance(factor)


# operation that overlaps with ImageNet-C's test set
def contrast(pil_img, level):
    """Adjust contrast by ``float_parameter(level, 1.8) + 0.1`` (0.1 at level 0, 1.9 at level 10)."""
    factor = float_parameter(sample_level(level), 1.8) + 0.1
    enhancer = ImageEnhance.Contrast(pil_img)
    return enhancer.enhance(factor)


# operation that overlaps with ImageNet-C's test set
def brightness(pil_img, level):
    """Adjust brightness by ``float_parameter(level, 1.8) + 0.1`` (0.1 at level 0, 1.9 at level 10)."""
    factor = float_parameter(sample_level(level), 1.8) + 0.1
    enhancer = ImageEnhance.Brightness(pil_img)
    return enhancer.enhance(factor)


# operation that overlaps with ImageNet-C's test set
def sharpness(pil_img, level):
    """Adjust sharpness by ``float_parameter(level, 1.8) + 0.1`` (0.1 at level 0, 1.9 at level 10)."""
    factor = float_parameter(sample_level(level), 1.8) + 0.1
    enhancer = ImageEnhance.Sharpness(pil_img)
    return enhancer.enhance(factor)


# Every op defined above, including the four (color, contrast, brightness,
# sharpness) marked as overlapping with ImageNet-C's test set.
augmentations_all = [
    autocontrast,
    equalize,
    posterize,
    rotate,
    solarize,
    shear_x,
    shear_y,
    translate_x,
    translate_y,
    color,
    contrast,
    brightness,
    sharpness,
]


def configure_model_eval(model):
    """Put ``model`` in eval mode with only BatchNorm2d parameters trainable.

    Every parameter is frozen with ``requires_grad_(False)``, then gradients are
    re-enabled for each ``nn.BatchNorm2d`` submodule. Because the model is in
    eval mode, those BatchNorm layers keep normalizing with their running
    statistics (when they track them); no statistics are reset.

    Returns:
        The same ``model`` object, modified in place.
    """
    model.eval()
    model.requires_grad_(False)
    bn_layers = (mod for mod in model.modules() if isinstance(mod, nn.BatchNorm2d))
    for bn in bn_layers:
        bn.requires_grad_(True)
    return model


def copy_model_and_optimizer(model, optimizer):
    """Snapshot model and optimizer state for a later reset.

    Returns:
        ``(model_state, optimizer_state)``: deep copies of both ``state_dict()``s.
    """
    snapshots = tuple(copy.deepcopy(obj.state_dict()) for obj in (model, optimizer))
    return snapshots


def load_model_and_optimizer(model, optimizer, model_state, optimizer_state):
    """Restore ``model`` and ``optimizer`` in place from saved state dicts.

    The model is loaded with ``strict=True``, so missing or unexpected keys raise.
    """
    model.load_state_dict(state_dict=model_state, strict=True)
    optimizer.load_state_dict(state_dict=optimizer_state)
