import os

import math
import pandas as pd
import torch.nn
from torch import optim
from torch.nn import KLDivLoss
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


import conf
from .dnn import DNN
from torch.utils.data import DataLoader

from utils.loss_functions import *
import copy
from torchvision.transforms.functional import to_pil_image

# Default compute device: the GPU selected by conf.args.gpu_idx, or the CPU
# when CUDA is not available.
if torch.cuda.is_available():
    device = torch.device("cuda:{:d}".format(conf.args.gpu_idx))
else:
    device = torch.device("cpu")


class BATTA_PARALLEL(DNN):
    """
    Test-time adaptation learner that runs several dropout replicas of
    ``self.net`` on (possibly) different GPUs. Each replica back-propagates its
    own loss, and the replicas' gradients are summed onto ``self.net`` before a
    single optimizer step.

    Code is partially inspired from:
    - https://github.com/taeckyung/SoTTA/blob/main/learner/sotta.py
    - https://github.com/tmllab/2023_NeurIPS_FlatMatch/blob/main/trainer.py
    """

    def __init__(self, *args, **kwargs):
        self.atta_src_net = None
        super().__init__(*args, **kwargs)
        # Snapshots of the model as it is right after base-class initialisation.
        self.src_net = copy.deepcopy(self.net)
        self.src_net_state = copy.deepcopy(self.net.state_dict())
        self.loss_fn = softmax_cross_entropy()
        # test_time_adaptation reads the correct / wrong / unlabeled memories,
        # which this memory type provides.
        assert (conf.args.memory_type == "ActivePriorityFIFO")

    def _dropout_device(self, i):
        """Return the device that hosts dropout replica ``i``.

        Replicas are assigned round-robin over the visible GPUs, starting at
        ``conf.args.gpu_idx``. When no GPU is visible, the module-level
        ``device`` is returned.
        """
        n_gpus = torch.cuda.device_count()
        if n_gpus == 0:
            return device
        return torch.device("cuda:{:d}".format((conf.args.gpu_idx + i) % n_gpus))

    @staticmethod
    def _memory_loader(*tensors):
        """Wrap equally-sized memory tensors in a shuffled DataLoader.

        The batch size is ``conf.args.update_every_x``.
        """
        dataset = torch.utils.data.TensorDataset(*tensors)
        return DataLoader(dataset, batch_size=conf.args.update_every_x,
                          shuffle=True, drop_last=False, pin_memory=False)

    def init_learner(self):
        """Prepare ``self.net`` for adaptation and build the dropout replicas.

        Every parameter of ``self.net`` is made trainable. BatchNorm1d/2d
        layers outside modules whose name contains ``'fc'`` are configured
        according to ``conf.args.use_learned_stats``. Replica 0 is ``self.net``
        itself. Replicas ``1 .. n_dropouts-1`` are deep copies placed on their
        own devices.

        Returns:
            A Nesterov SGD optimizer over all of ``self.net``'s parameters.
        """
        # All parameters (including BN affine weights/biases) are adapted.
        for param in self.net.parameters():
            param.requires_grad = True

        for name, module in self.net.named_modules():
            # Modules whose name contains 'fc' keep their BN configuration.
            if 'fc' in name:
                continue
            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
                # https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html
                # TENT: force use of batch stats in train and eval modes: https://github.com/DequanWang/tent/blob/master/tent.py
                if conf.args.use_learned_stats:  # use learn stats
                    module.track_running_stats = True
                    module.momentum = conf.args.bn_momentum
                else:
                    module.track_running_stats = False
                    module.running_mean = None
                    module.running_var = None

        optimizer = torch.optim.SGD(
            self.net.parameters(),
            conf.args.opt['learning_rate'],
            momentum=conf.args.opt['momentum'],
            weight_decay=conf.args.opt['weight_decay'],
            nesterov=True)

        self.src_net_state = copy.deepcopy(self.net.state_dict())
        self.cnt_correct_after_frze = 0

        self.dropout_nets = [self.net] + [
            copy.deepcopy(self.net).to(self._dropout_device(i))
            for i in range(1, conf.args.n_dropouts)
        ]

        return optimizer

    def test_time_adaptation(self):
        """Adapt ``self.net`` on the active memory for ``conf.args.epoch`` epochs.

        Each step does the following:
        1. Build a batch laid out as [correct | wrong | unlabeled].
        2. Run every dropout replica on the batch.
        3. For each replica, compute the loss from that replica's MC-dropout
           mean softmax and back-propagate it.
        4. Sum the replicas' gradients onto ``self.net`` and take one
           optimizer step.

        Also appends the correct/wrong memory sizes to ``self.json_active``.
        """
        # Unlabeled data
        u_feats, u_labels, _, _ = self.mem.get_u_memory()
        if len(u_feats) != 0:
            u_dataloader = self._memory_loader(
                torch.stack(u_feats).to(device),
                torch.tensor(u_labels).to(device))
        else:
            u_dataloader = [(None, None)]

        # Correct data
        correct_feats, correct_labels, _, _ = self.mem.get_correct_memory()
        self.json_active['num_correct_mem'] += [len(correct_labels)]
        if len(correct_feats) != 0:
            correct_dataloader = self._memory_loader(
                torch.stack(correct_feats).to(device),
                torch.tensor(correct_labels).to(device))
        else:
            correct_dataloader = [(None, None)]

        # Wrong data
        wrong_feats, wrong_labels, wrong_gt_labels, _ = self.mem.get_wrong_memory()
        self.json_active['num_wrong_mem'] += [len(wrong_labels)]
        if len(wrong_feats) != 0:
            wrong_dataloader = self._memory_loader(
                torch.stack(wrong_feats).to(device),
                torch.tensor(wrong_labels).to(device),
                torch.tensor(wrong_gt_labels, dtype=torch.long).to(device))
        else:
            wrong_dataloader = [(None, None, None)]

        self.net.train()
        self.disable_running_stats()
        # NOTE(review): only self.net is put in train mode here; replicas
        # 1..n keep whatever mode they had when copied in init_learner.
        # Confirm this is intended for their BatchNorm behaviour.

        devices = [self._dropout_device(i) for i in range(conf.args.n_dropouts)]

        for _ in range(conf.args.epoch):
            # NOTE(review): zip stops at the shortest loader. An empty memory
            # contributes a single (None, ...) entry, so in that case only the
            # first batch of the other memories is used. Confirm intended.
            for (u_feats_, u_labels_), (correct_feats_, correct_labels_), (wrong_feats_, wrong_labels_, _) \
                    in zip(u_dataloader, correct_dataloader, wrong_dataloader):

                # Order matters: get_loss slices rows as [correct | wrong | unlabeled].
                data = [x for x in (correct_feats_, wrong_feats_, u_feats_) if x is not None]
                assert len(data) > 0
                data = torch.cat(data, dim=0)

                # Sync replicas with the current weights. load_state_dict copies
                # into each replica's own tensors (across devices), so the state
                # dict need not be deep-copied first.
                model_state = self.net.state_dict()
                for net in self.dropout_nets[1:]:
                    net.load_state_dict(model_state, strict=True)

                # For optimization.
                # Assumes each net is indexable as [feature extractor,
                # classifier accepting a `dropout` kwarg]. TODO confirm.
                predictions = []
                for net, d in zip(self.dropout_nets, devices):
                    pred = net[1](net[0](data.to(d)), dropout=conf.args.dropout_rate)  # batch_size, n_classes
                    predictions.append(F.softmax(pred, dim=1))

                # For replica i, average every replica's softmax but keep the
                # graph only through replica i's own prediction. Replica i's
                # loss then back-propagates into replica i alone.
                mcd_mean_softmaxs = []
                for i, d in enumerate(devices):
                    mcd_softmaxs = torch.stack(
                        [p if j == i else p.detach().to(d) for j, p in enumerate(predictions)],
                        dim=1)  # batch_size, n_dropouts, n_classes
                    mcd_mean_softmaxs.append(mcd_softmaxs.mean(dim=1))

                with torch.no_grad():
                    outputs_softmax = self.net(data).softmax(dim=1)

                for i, d in enumerate(devices):
                    correct_loss, wrong_loss, unlabeled_loss = self.get_loss(
                        mcd_mean_softmaxs[i], mcd_mean_softmaxs[i], outputs_softmax,
                        correct_labels_, wrong_labels_, u_labels_, d)

                    loss = conf.args.w_final_loss_unlabeled * unlabeled_loss \
                        + conf.args.w_final_loss_correct * correct_loss \
                        + conf.args.w_final_loss_wrong * wrong_loss

                    for param in self.dropout_nets[i].parameters():
                        if param.grad is not None:
                            param.grad.zero_()

                    loss.backward()

                # Sum every replica's gradient onto self.net (replica 0), the
                # only network the optimizer updates. Parameters with no
                # gradient anywhere are skipped instead of crashing.
                for params in zip(*[net.parameters() for net in self.dropout_nets]):
                    target = params[0]
                    grads = [p.grad.to(target.device) for p in params if p.grad is not None]
                    if not grads:
                        continue
                    target.grad = sum(grads)

                self.optimizer.step()

    def get_loss(self, own_mc_dropout_softmax, total_mc_dropout_softmax, original_softmax, correct_labels_, wrong_labels_, u_labels_, d):
        """Compute the correct / wrong / unlabeled loss terms on device ``d``.

        Rows of the softmax tensors are laid out as [correct | wrong |
        unlabeled], with segment sizes given by the label tensors. A segment
        whose labels are None is treated as absent.

        Args:
            own_mc_dropout_softmax: softmax the losses are computed on.
            total_mc_dropout_softmax: softmax whose argmax is compared with the
                original prediction to select confident unlabeled rows.
            original_softmax: softmax from ``self.net``'s regular forward pass.
                Its argmax gives the pseudo-labels for unlabeled rows.
            correct_labels_, wrong_labels_, u_labels_: per-segment labels or None.
            d: device to compute the losses on.

        Returns:
            ``(correct_loss, wrong_loss, unlabeled_loss)``. A term whose labels
            are None stays a zero tensor of shape (1,) on ``d``.
            ``wrong_loss`` is the negated ``self.loss_fn``, so minimising the
            total loss increases the loss on the wrong labels.
        """
        correct_loss = torch.tensor([0.0]).to(d)
        wrong_loss = torch.tensor([0.0]).to(d)
        unlabeled_loss = torch.tensor([0.0]).to(d)
        original_softmax = original_softmax.to(d)

        n_correct = len(correct_labels_) if correct_labels_ is not None else 0
        n_wrong = len(wrong_labels_) if wrong_labels_ is not None else 0

        if correct_labels_ is not None:
            correct_loss = self.loss_fn(own_mc_dropout_softmax[:n_correct], correct_labels_.to(d))

        if wrong_labels_ is not None:
            own_wrong_dropout_outputs = own_mc_dropout_softmax[n_correct:n_correct + n_wrong]
            wrong_loss = -self.loss_fn(own_wrong_dropout_outputs, wrong_labels_.to(d))

        if u_labels_ is not None:
            start_idx = n_correct + n_wrong
            own_u_dropout_outputs = own_mc_dropout_softmax[start_idx:]
            total_u_dropout_outputs = total_mc_dropout_softmax[start_idx:]
            original_u_softmax = original_softmax[start_idx:]

            original_pred = original_u_softmax.argmax(dim=1).detach()
            # Default selection: the original prediction agrees with the
            # MC-dropout mean prediction.
            confident_idx = original_pred == total_u_dropout_outputs.argmax(dim=1)
            if conf.args.ablation_conf_th > 0:
                # Confidence of the predicted class, read from the unlabeled
                # rows so it lines up with original_pred / own_u_dropout_outputs.
                confident_idx = original_u_softmax.max(dim=1).values > conf.args.ablation_conf_th
            if conf.args.ablation_ent_th > 0:
                confident_idx = entropy(original_u_softmax) < conf.args.ablation_ent_th

            unlabeled_loss = self.loss_fn(own_u_dropout_outputs[confident_idx], original_pred[confident_idx])

        return correct_loss, wrong_loss, unlabeled_loss

    def pre_active_sample_selection(self):
        """Refresh BN running stats on the FIFO memory, then stop updating them.

        Both steps only take effect when ``conf.args.use_learned_stats`` is set.
        """
        self.enable_running_stats()
        self.disable_running_stats()

    def disable_running_stats(self):
        """Freeze BatchNorm running stats by setting their momentum to 0.

        This is a no-op unless ``conf.args.use_learned_stats`` is set.
        """
        for module in self.net.modules():
            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
                if conf.args.use_learned_stats:  # use learn stats
                    module.track_running_stats = True
                    module.momentum = 0

    def enable_running_stats(self):
        """Re-enable BN running-stat updates and refresh them on the FIFO memory.

        When ``conf.args.use_learned_stats`` is set, BN momentum is restored to
        ``conf.args.bn_momentum``. One no-grad forward pass is then run in
        train mode over every sample in ``self.fifo``. If the FIFO is empty,
        the forward pass is skipped.
        """
        for module in self.net.modules():
            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
                if conf.args.use_learned_stats:  # use learn stats
                    module.track_running_stats = True
                    module.momentum = conf.args.bn_momentum
        self.net.train()
        feats, _, _ = self.fifo.get_memory()
        if len(feats) == 0:
            # Nothing to compute statistics from (torch.stack([]) would raise).
            return
        feats = torch.stack(feats).to(device)
        with torch.no_grad():
            self.net(feats)  # train-mode forward updates tracked BN stats
