import torch.nn
from torch import optim

import conf
from utils.memory import FIFO
from .dnn import DNN
from torch.utils.data import DataLoader

from utils.loss_functions import *

# Select the GPU configured by ``conf.args.gpu_idx`` when CUDA is present,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:{:d}".format(conf.args.gpu_idx))
else:
    device = torch.device("cpu")


class TENT(DNN):
    """TENT test-time adaptation learner built on top of ``DNN``.

    Only normalization-layer parameters are left trainable; they are updated
    by minimizing the prediction entropy (``HLoss``) on samples held in a FIFO
    memory, optionally adding the BATTA self-supervised loss.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def init_learner(self):
        """Freeze the network except normalization layers and build the optimizer.

        BatchNorm1d/2d layers either keep tracking running statistics (when
        ``conf.args.use_learned_stats`` is set, with momentum
        ``conf.args.bn_momentum``) or have their running statistics removed so
        batch statistics are used in both train and eval mode, following the
        reference TENT implementation. BatchNorm affine weight/bias (when
        present) and all LayerNorm/GroupNorm parameters are made trainable.

        Returns:
            ``optim.SGD`` with Nesterov momentum for ImageNet, otherwise
            ``optim.Adam`` with zero weight decay, both over
            ``self.net.parameters()``.
        """
        # Start from a fully frozen network; only norm layers are re-enabled below.
        for param in self.net.parameters():
            param.requires_grad = False

        for module in self.net.modules():
            # Reference ``torch.nn`` explicitly: ``import torch.nn`` binds the
            # name ``torch``, not ``nn``, so a bare ``nn`` would only resolve
            # through a star import from another module.
            if isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d)):
                # https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html
                # TENT: force use of batch stats in train and eval modes: https://github.com/DequanWang/tent/blob/master/tent.py
                if conf.args.use_learned_stats:
                    module.track_running_stats = True
                    module.momentum = conf.args.bn_momentum
                else:
                    module.track_running_stats = False
                    module.running_mean = None
                    module.running_var = None

                # BatchNorm built with affine=False has weight/bias set to None;
                # there is nothing to adapt there, so skip instead of crashing.
                if module.weight is not None:
                    module.weight.requires_grad_(True)
                if module.bias is not None:
                    module.bias.requires_grad_(True)

            if isinstance(module, (torch.nn.LayerNorm, torch.nn.GroupNorm)):
                module.requires_grad_(True)

        if conf.args.dataset == "imagenet":
            optimizer = optim.SGD(
                self.net.parameters(),
                conf.args.opt['learning_rate'],
                momentum=conf.args.opt['momentum'],
                weight_decay=conf.args.opt['weight_decay'],
                nesterov=True)
        else:
            optimizer = optim.Adam(self.net.parameters(), lr=conf.args.opt['learning_rate'], weight_decay=0.0)
        return optimizer

    def test_time_adaptation(self):
        """Run entropy-minimization updates over the samples stored in memory.

        For ``conf.args.epoch`` epochs, iterates over shuffled mini-batches of
        size ``conf.args.tta_batch_size`` taken from the FIFO memory and performs
        one optimizer step per batch on the prediction-entropy loss (plus the
        BATTA SSL loss when ``conf.args.enable_batta`` is set). Labels stored in
        memory are not used. Returns without updating when memory is empty.
        """
        assert isinstance(self.mem, FIFO)
        feats, _, _ = self.mem.get_memory()
        if len(feats) == 0:
            # torch.stack raises on an empty sequence; nothing to adapt on.
            return
        feats = torch.stack(feats).to(device)

        dataset = torch.utils.data.TensorDataset(feats)
        data_loader = DataLoader(dataset, batch_size=conf.args.tta_batch_size,
                                 shuffle=True, drop_last=False, pin_memory=False)

        # Built once instead of per batch.
        # NOTE(review): assumes HLoss holds no per-call state — TODO confirm.
        entropy_loss = HLoss()

        for _ in range(conf.args.epoch):
            for (batch_feats,) in data_loader:
                if len(batch_feats) == 1:
                    self.net.eval()  # avoid BN error on a single-sample batch
                else:
                    self.net.train()

                preds_of_data = self.net(batch_feats)

                loss = entropy_loss(preds_of_data)

                if conf.args.enable_batta:
                    # Out-of-place add: an in-place update of a tensor in the
                    # autograd graph can invalidate values saved for backward.
                    loss = loss + self.get_batta_ssl_loss()

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                    
              
