
import conf
from data_loader.CIFAR100Dataset import CIFAR100Dataset
from data_loader.CIFAR10Dataset import CIFAR10Dataset

from data_loader.PACSDataset import PacsDataset
from data_loader.TINYIMAGENETDataset import TinyImageNetDataset
from utils.loss_functions import *
from utils.memory import FIFO

# Device used by this module: the CUDA device indexed by conf.args.gpu_idx
# when CUDA is available, otherwise the CPU.
device = torch.device(
    f"cuda:{conf.args.gpu_idx:d}" if torch.cuda.is_available() else "cpu"
)

from torch import optim

import conf
from data_loader.data_loader import datasets_to_dataloader
from utils.reset_utils import copy_model_and_optimizer
from .dnn import DNN
from torch.utils.data import random_split, DataLoader

from utils.loss_functions import *
from utils import memory


class ETA(DNN):
    """EATA without anti-forgetting.

    Adapts the network by minimising the (re-weighted) prediction entropy of
    samples that pass two filters: an entropy threshold (``e_margin``) and a
    cosine-similarity threshold against a moving-average probability vector
    (``d_margin``). ``self.fishers`` is left as ``None`` here, so the EWC
    penalty in ``forward_and_adapt_eata`` is skipped.
    """

    def __init__(self, model_, corruption_list_):
        # There is no ``steps``/``episodic`` state: the number of adaptation
        # passes is taken from conf.args.epoch in test_time_adaptation.
        # NOTE(review): the attributes below are assigned *before* the base
        # initializer runs — presumably because DNN.__init__ may already use
        # them; confirm before reordering.

        self.num_samples_update_1 = 0  # number of samples after first filtering (unreliable samples excluded)
        self.num_samples_update_2 = 0  # number of samples after second filtering (unreliable and redundant excluded)
        self.e_margin = conf.args.e_margin  # hyper-parameter E_0 (Eqn. 3)
        self.d_margin = conf.args.d_margin  # hyper-parameter \epsilon for cosine similarity thresholding (Eqn. 5)

        self.current_model_probs = None  # the moving average of probability vector (Eqn. 4)

        self.fishers = None  # fisher regularizer items for anti-forgetting, computed before adaptation (Eqn. 9)
        self.fisher_alpha = conf.args.fisher_alpha  # trade-off \beta for two losses (Eqn. 8)

        super(ETA, self).__init__(model_, corruption_list_)

    def init_learner(self):
        """Configure ``self.net`` for adaptation and return its SGD optimizer.

        Learning rate, momentum and weight decay are read from
        ``conf.args.opt``.
        """
        configure_model(self.net)
        optimizer = optim.SGD(self.net.parameters(),
                              lr=conf.args.opt['learning_rate'],
                              momentum=conf.args.opt['momentum'],
                              weight_decay=conf.args.opt['weight_decay'])
        return optimizer

    @torch.enable_grad()  # ensure grads in possible no grad context for testing
    def forward_and_adapt_eata(self, x, model, optimizer, fishers, e_margin, current_model_probs, fisher_alpha=50.0,
                            d_margin=0.05, scale_factor=2, num_samples_update=0):
        """Forward and adapt model on batch of data.

        Measure entropy of the model prediction, take gradients, and update
        params. The optimizer step is skipped when no sample survives the
        filtering.

        Args:
            x: input batch fed to ``model``.
            model: network to run and adapt.
            optimizer: optimizer whose ``zero_grad``/``step`` are called.
            fishers: ``None`` or a mapping ``name -> [fisher, reference_param]``
                used for the EWC penalty.
            e_margin: entropy threshold; samples with entropy below it are kept.
            current_model_probs: moving-average probability vector, or ``None``.
            fisher_alpha: weight of the EWC penalty.
            d_margin: cosine-similarity threshold; samples whose absolute
                similarity to ``current_model_probs`` is below it are kept.
            scale_factor: unused; kept for interface compatibility.
            num_samples_update: unused; kept for interface compatibility.

        Return:
        1. model outputs;
        2. the number of reliable and non-redundant samples;
        3. the number of reliable samples;
        4. the moving average probability vector over all previous samples;
        5. the loss (a mean over the selected samples, so it is not a
           meaningful value when no sample was selected).
        """
        # forward
        outputs = model(x)
        # adapt
        entropys = softmax_entropy(outputs)
        # filter unreliable samples
        filter_ids_1 = torch.where(entropys < e_margin)
        entropys = entropys[filter_ids_1]
        # filter redundant samples
        if current_model_probs is not None:
            cosine_similarities = F.cosine_similarity(current_model_probs.unsqueeze(dim=0),
                                                      outputs[filter_ids_1].softmax(1), dim=1)
            filter_ids_2 = torch.where(torch.abs(cosine_similarities) < d_margin)
            entropys = entropys[filter_ids_2]
            updated_probs = update_model_probs(current_model_probs, outputs[filter_ids_1][filter_ids_2].softmax(1))
        else:
            updated_probs = update_model_probs(current_model_probs, outputs[filter_ids_1].softmax(1))
        # ``entropys`` now holds exactly the selected samples; its length is
        # used below both as a return value and to decide whether to step.
        coeff = 1 / (torch.exp(entropys.clone().detach() - e_margin))
        # implementation version 1, compute loss, all samples backward (some unselected are masked)
        entropys = entropys.mul(coeff)  # reweight entropy losses for diff. samples
        loss = entropys.mean(0)
        # implementation version 2 (not used): forward all batch, then forward
        # and backward the selected samples again.
        if fishers is not None:
            ewc_loss = 0
            for name, param in model.named_parameters():
                if name in fishers:
                    ewc_loss += fisher_alpha * (fishers[name][0] * (param - fishers[name][1]) ** 2).sum()
            loss += ewc_loss

        if conf.args.enable_batta:
            # get_batta_ssl_loss is not defined in this class; it is expected
            # to be provided elsewhere (e.g. by the base class).
            loss += self.get_batta_ssl_loss()

        optimizer.zero_grad()  # move it here to make PETAL reset works!
        # Only update when at least one sample was selected. The selected
        # count equals entropys.size(0), so the input batch does not need to
        # be re-indexed just to test for emptiness.
        if entropys.size(0) != 0:
            loss.backward()
            optimizer.step()
        return outputs, entropys.size(0), filter_ids_1[0].size(0), updated_probs, loss

    def test_time_adaptation(self):
        """Adapt ``self.net`` on the samples currently held in ``self.mem``.

        Runs ``conf.args.epoch`` passes over the memory contents in shuffled
        mini-batches of ``conf.args.tta_batch_size`` and updates the sample
        counters and the moving-average probability vector after each batch.
        """
        assert isinstance(self.mem, FIFO)
        feats, labels, _ = self.mem.get_memory()
        feats = torch.stack(feats).to(device)
        labels = torch.Tensor(labels).type(torch.long).to(device)

        dataset = torch.utils.data.TensorDataset(feats, labels)
        data_loader = DataLoader(dataset, batch_size=conf.args.tta_batch_size,
                                 shuffle=True, drop_last=False, pin_memory=False)

        for _ in range(conf.args.epoch):
            for batch_feats, _ in data_loader:
                if len(batch_feats) == 1:
                    self.net.eval()  # avoid BN error
                else:
                    self.net.train()

                result = self.forward_and_adapt_eata(batch_feats, self.net,
                                                     self.optimizer,
                                                     self.fishers,
                                                     self.e_margin,
                                                     self.current_model_probs,
                                                     fisher_alpha=self.fisher_alpha,
                                                     num_samples_update=self.num_samples_update_2,
                                                     d_margin=self.d_margin)
                _, num_counts_2, num_counts_1, updated_probs, _ = result

                self.num_samples_update_2 += num_counts_2
                self.num_samples_update_1 += num_counts_1
                self.reset_model_probs(updated_probs)

    def reset_model_probs(self, probs):
        """Replace the stored moving-average probability vector with ``probs``."""
        self.current_model_probs = probs


class EATA(ETA):
    """ETA plus the anti-forgetting (Fisher/EWC) regularizer.

    On construction, a diagonal Fisher estimate is computed on
    ``conf.args.fisher_size`` samples from the first entry of
    ``corruption_list_`` and stored, together with the current parameter
    values, in ``self.fishers`` as ``name -> [fisher, reference_param]``.
    """

    def __init__(self, model_, corruption_list_):
        super(EATA, self).__init__(model_, corruption_list_)

        # only use the first domain for fisher importance calculation
        if conf.args.dataset == "cifar10":
            fisher_dataset = CIFAR10Dataset(file="", domains=[corruption_list_[0]], max_source=9999, transform='val')
        elif conf.args.dataset == "cifar100":
            fisher_dataset = CIFAR100Dataset(file="", domains=[corruption_list_[0]], max_source=9999, transform='val')
        elif conf.args.dataset == "pacs":
            fisher_dataset = PacsDataset(file="", domains=[corruption_list_[0]], max_source=9999, transform='val')
        elif conf.args.dataset == "tiny-imagenet":
            # this dataset takes a single ``domain`` rather than a ``domains`` list
            fisher_dataset = TinyImageNetDataset(file="", domain=corruption_list_[0], max_source=9999, transform='val')
        else:
            raise NotImplementedError

        # keep a random subset of fisher_size samples
        fisher_dataset = random_split(fisher_dataset, [conf.args.fisher_size, len(fisher_dataset)-conf.args.fisher_size])[0]
        fisher_loader = datasets_to_dataloader([fisher_dataset], batch_size=64, concat=True, shuffle=True)

        subnet = configure_model(self.net)
        fishers = {}
        train_loss_fn = nn.CrossEntropyLoss().to(device)
        num_batches = len(fisher_loader)
        for iter_, (images, _, _) in enumerate(fisher_loader, start=1):
            # Use the module-level device instead of a hard-coded .cuda() so
            # the constructor also works on CPU-only hosts.
            images = images.to(device, non_blocking=True)
            outputs = subnet(images)
            # the model's own predictions serve as targets; dataset labels are not used
            _, targets = outputs.max(1)
            loss = train_loss_fn(outputs, targets)
            loss.backward()
            for name, param in subnet.named_parameters():
                if param.grad is not None:
                    if iter_ > 1:
                        fisher = param.grad.data.clone().detach() ** 2 + fishers[name][0]
                    else:
                        fisher = param.grad.data.clone().detach() ** 2
                    if iter_ == num_batches:
                        # last batch: turn the running sum into a mean over batches
                        fisher = fisher / iter_
                    fishers.update({name: [fisher, param.data.clone().detach()]})
            # Clear the gradients of *every* parameter of the network, so that
            # all grad-enabled norm layers start each batch from a clean slate
            # and no throw-away optimizer (which fails on an empty parameter
            # list) is needed.
            subnet.zero_grad()
        self.fishers = fishers  # fisher regularizer items for anti-forgetting, computed before adaptation (Eqn. 9)


def collect_params(model):
    """Collect the affine scale + shift parameters of the normalization layers.

    Walks the model's modules and gathers the ``weight`` and ``bias``
    parameters of every ``BatchNorm2d``, ``LayerNorm`` and ``GroupNorm``
    layer, i.e. the same layer types for which ``configure_model`` enables
    gradients.

    Returns:
        A pair ``(params, names)``: the parameter tensors and their
        ``"<module name>.<parameter name>"`` identifiers, in module order.

    Note: other choices of parameterization are possible!
    """
    norm_types = (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)
    params = []
    names = []
    for module_name, module in model.named_modules():
        if not isinstance(module, norm_types):
            continue
        for param_name, param in module.named_parameters():
            if param_name in ('weight', 'bias'):  # weight is scale, bias is shift
                params.append(param)
                names.append(f"{module_name}.{param_name}")
    return params, names


def configure_model(model):
    """Prepare ``model`` for eata adaptation and return it.

    The model is switched to train mode (eata optimizes it to minimize
    entropy), all gradients are disabled, and gradients are then re-enabled
    only on ``BatchNorm2d``, ``LayerNorm`` and ``GroupNorm`` layers.
    ``BatchNorm2d`` layers additionally drop their running statistics so that
    batch statistics are used in both train and eval modes.
    """
    model.train()
    # freeze everything first; only the norm layers below are re-enabled
    model.requires_grad_(False)

    affine_only_norms = (nn.LayerNorm, nn.GroupNorm)
    for layer in model.modules():
        is_batch_norm = isinstance(layer, nn.BatchNorm2d)
        if not is_batch_norm and not isinstance(layer, affine_only_norms):
            continue

        layer.requires_grad_(True)
        if is_batch_norm:
            # force use of batch stats in train and eval modes
            layer.track_running_stats = False
            layer.running_mean = None
            layer.running_var = None
    return model



@torch.jit.script
def softmax_entropy(x: torch.Tensor) -> torch.Tensor:
    """Return the entropy of the softmax distribution over dim 1 of logits ``x``."""
    temperature = 1
    logits = x / temperature
    probs = logits.softmax(1)
    log_probs = logits.log_softmax(1)
    return -(probs * log_probs).sum(1)


def update_model_probs(current_model_probs, new_probs):
    """Fold a batch of probability vectors into the moving average.

    Returns ``current_model_probs`` unchanged (possibly ``None``) when
    ``new_probs`` has no rows. Otherwise returns the batch mean when there is
    no previous average, or ``0.9 * current + 0.1 * batch mean``. The result
    is computed without gradient tracking.
    """
    if new_probs.size(0) == 0:
        # nothing to fold in: keep whatever we had
        return current_model_probs

    with torch.no_grad():
        batch_mean = new_probs.mean(0)
        if current_model_probs is None:
            return batch_mean
        return 0.9 * current_model_probs + (1 - 0.9) * batch_mean


def check_model(model):
    """Check model for compatibility with eata.

    Asserts that the model is in train mode, that some but not all of its
    parameters require gradients, and that it contains at least one of the
    normalization layer types that ``configure_model`` prepares for
    adaptation (``BatchNorm2d``, ``LayerNorm`` or ``GroupNorm``).

    Raises:
        AssertionError: if any of the conditions above does not hold.
    """
    assert model.training, "eata needs train mode: call model.train()"

    param_grads = [p.requires_grad for p in model.parameters()]
    assert any(param_grads), "eata needs params to update: " \
                             "check which require grad"
    assert not all(param_grads), "eata should not update all params: " \
                                 "check which require grad"

    # accept every norm type that configure_model enables gradients for,
    # not just BatchNorm2d
    norm_types = (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)
    has_norm = any(isinstance(m, norm_types) for m in model.modules())
    assert has_norm, "eata needs normalization for its optimization"
