from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY
from dassl.engine.trainer_myada import TrainerX
from dassl.metrics import compute_accuracy
from collections import OrderedDict
import datetime
import time
import copy


import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import random
from dassl.optim import build_optimizer, build_lr_scheduler
import numpy as np
import matplotlib.pyplot as plt
import os

from dassl.utils import MetricMeter, AverageMeter



# Module-level loss objects.
# Binary cross-entropy on probabilities (not logits), averaged over all elements.
criterion_bce = nn.BCELoss()
# Multi-class cross-entropy on raw logits, averaged over the batch.
# NOTE(review): not referenced in the visible part of this file.
criterion_ce = nn.CrossEntropyLoss()
# Element-wise binary cross-entropy (no reduction) so callers can weight each element.
criterion_bce_red = nn.BCELoss(reduction='none')


def cross_entropy_loss_with_2logits(logits1, logits2):
    """Cross-entropy between two sets of logits (soft targets).

    ``softmax(logits2)`` acts as the target distribution and
    ``log_softmax(logits1)`` as the log-predictions, both taken along
    dimension 1. The per-row cross-entropies are averaged into one scalar.

    Args:
        logits1: logits whose log-probabilities are scored.
        logits2: logits converted to the weighting probabilities.

    Returns:
        Scalar tensor holding the mean cross-entropy.
    """
    target_probs = F.softmax(logits2, dim=1)
    log_pred = F.log_softmax(logits1, dim=1)
    # Expected log-likelihood of each row under the target distribution.
    expected_log_lik = (target_probs * log_pred).sum(dim=1)
    return -expected_log_lik.mean()


def sample_batch(images, labels, batch_size):
    """Draw ``batch_size`` aligned (image, label) pairs at random.

    When ``batch_size`` does not exceed the number of samples, pairs are
    drawn without replacement. Otherwise the pool is virtually repeated
    ``batch_size // len(images) + 1`` times and sampled without replacement,
    so every sample appears at most that many times in the result.

    The repetition is done on indices (``i % len(images)``) rather than with
    ``images * repeat_times``, so any indexable sequence works (lists,
    tuples, ranges, tensors, arrays), not only those where ``*`` means
    repetition.

    Args:
        images: indexable sequence of samples supporting ``len()``.
        labels: indexable sequence aligned with ``images``.
        batch_size: number of pairs to return.

    Returns:
        Tuple ``(sampled_images, sampled_labels)`` of two lists with
        ``batch_size`` entries each, in matching order.

    Raises:
        ValueError: if ``images`` is empty while ``batch_size`` is positive,
            or (from ``random.sample``) if ``batch_size`` is negative.
    """
    num_samples = len(images)
    pool_size = num_samples
    if batch_size > num_samples:
        if num_samples == 0:
            raise ValueError(
                "cannot sample a batch of size {} from an empty set".format(batch_size)
            )
        # Same pool size the explicit list repetition would have produced.
        pool_size = num_samples * (batch_size // num_samples + 1)

    indices = [i % num_samples for i in random.sample(range(pool_size), batch_size)]
    sampled_images = [images[i] for i in indices]
    sampled_labels = [labels[i] for i in indices]
    return sampled_images, sampled_labels



@TRAINER_REGISTRY.register()
class Vanilla_Ada(TrainerX):

    def fix_bn(self, m):
        classname = m.__class__.__name__
        if classname.find('BatchNorm') != -1:
            m.eval()

    def enable_bn(self, m):
        classname = m.__class__.__name__
        if classname.find('BatchNorm') != -1:
            m.train()

    def forward_FDA(self, rois, targets):
        unk_pixel_list = []
        k_pixel_list = []
        k_label_list = []
        loss_mining_dict = {}
        bs, c, h, w = rois.size()
        p_xk_x = 1.0

        for idx, roi in enumerate(rois): 
            roi_label = targets[idx]
            roi_flatten = roi.view(c, -1).permute(1, 0)
            self.model.classifier.apply(self.fix_bn)
            
            pixel_logits = self.model.classifier(roi_flatten, adaption=True, constant=-1 * self.cfg.ADA.mining_grl) 
            pixel_scores = pixel_logits.softmax(-1) 

            self.model.classifier.apply(self.enable_bn)

            sorted_value, sorted_index = pixel_scores[:, :-1].sort(-1, descending=True)  
            k_mask = (sorted_index[:, :1] == roi_label).sum(-1).bool()
            unk_mask = ~((sorted_index[:, :self.cfg.ADA.topk] == roi_label).sum(-1).bool()) 

            if k_mask.any() and unk_mask.any() and k_mask.sum() > unk_mask.sum():
                unk_pixels = pixel_logits[unk_mask] 
                k_pixels = pixel_logits[k_mask] 
                unk_pixel_list.append(unk_pixels)
                k_pixel_list.append(k_pixels)
                k_label_list += [int(roi_label)]*k_pixels.shape[0] 


        if len(unk_pixel_list) > 1:
            num_lk = len(torch.cat(k_pixel_list))
            num_pu = len(torch.cat(unk_pixel_list))
            p_xu_x = num_pu / (num_pu + num_lk)
            p_xk_x = num_lk / (num_pu + num_lk)

            mined_scores = torch.cat(unk_pixel_list).softmax(-1)[:, -1]
            loss_mining_unk = criterion_bce(mined_scores, torch.tensor([self.cfg.ADA.mining_th] * len(mined_scores)).cuda())
            loss_mining_dict.update(loss_mining_unk=loss_mining_unk)


        return loss_mining_dict, p_xk_x





    def forward_backward(self, batch):
        """Run one optimisation step on a training batch.

        Loss terms:

        * ``loss_ce``: cross-entropy on the logits divided by
          ``cfg.ADA.smooth_coef``.
        * ``loss_ua``: cross-entropy towards the last output index after the
          ground-truth logit of every sample has been masked out with a large
          negative constant.
        * ``penalty1``: mean squared logit.
        * ``loss_fda``: sum of the losses returned by ``forward_FDA``.

        While ``self.epoch < cfg.ADA.warmup_epoch`` only ``loss_ce`` is
        optimised; afterwards the remaining terms are added with their
        configured coefficients. All terms are computed in both phases so the
        logged values are always available.

        Args:
            batch: mapping with at least ``"img"``, ``"label"`` and
                ``"domain"`` entries.

        Returns:
            dict of plain Python scalars for logging.
        """
        images, label = self.parse_batch_train(batch)
        domain_label = batch['domain'].to(self.device)
        feat, rois = self.model(images, domain_label, label, return_rois=True)

        output = self.model.classifier(feat)
        loss_ce = F.cross_entropy(output / self.cfg.ADA.smooth_coef, label)

        batch_size, num_outputs = output.size(0), output.size(1)
        # Every sample is pushed towards the last output index ...
        fake_label = torch.full(
            (batch_size,), num_outputs - 1, dtype=torch.long, device=output.device
        )
        # ... after its ground-truth logit has been replaced by a large negative
        # constant, which also blocks the gradient through that entry.
        true_class_mask = label.unsqueeze(1) == torch.arange(num_outputs, device=output.device)
        output_masked = output.masked_fill(true_class_mask, -1e8)
        loss_ua = F.cross_entropy(output_masked, fake_label)

        penalty1 = (output ** 2).mean()

        loss_mining_dict, p_xk_x = self.forward_FDA(rois, label)
        # Python int 0 when no mining loss was produced, a tensor otherwise.
        loss_fda = sum(loss_mining_dict.values())

        if self.epoch < self.cfg.ADA.warmup_epoch:
            loss = loss_ce
        else:
            loss = (
                loss_ce
                + self.cfg.ADA.ua_loss_coef * loss_ua
                + penalty1 * self.cfg.ADA.penalty_coef
                + self.cfg.ADA.fda_loss_coef * loss_fda
            )

        # The backbone and the classifier are updated through their own
        # (separately registered) optimizers.
        self.model_backward_and_update(loss, names=['backbone', 'classifier'])

        loss_summary = {
            "l_ce": loss_ce.item(),
            "l_ua": loss_ua.item(),
            # Always log a detached float, like the other entries.
            "l_fda": loss_fda.item() if torch.is_tensor(loss_fda) else float(loss_fda),
            "pxk|x": p_xk_x,
            "train_acc": compute_accuracy(output, label)[0].item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
        

    def parse_batch_train(self, batch):
        input = batch["img"]
        label = batch["label"]
        input = input.to(self.device)
        label = label.to(self.device)
        return input, label
    

    def forward_DCA(self, rois, domain='source'):
        """Compute the two class-weighted adversarial domain losses for ``rois``.

        Every spatial location is turned into one feature vector. The
        classifier (BatchNorm held in eval mode, output detached) provides a
        softmax over its outputs: the last column weights the ``adv_unk``
        discriminator loss, the sum of the remaining columns weights the
        ``adv_k`` discriminator loss. The domain target is 1.0 for
        ``domain == 'source'`` and 0.0 for anything else.

        Args:
            rois: 4-D tensor unpacked as (batch, channels, height, width).
            domain: ``'source'`` or any other value for the target domain.

        Returns:
            dict with ``loss_adv_k`` and ``loss_adv_unk``.
        """
        _, c, _, _ = rois.size()
        # (bs, c, h, w) -> (bs * h * w, c)
        pixel_feats = rois.permute(0, 2, 3, 1).contiguous().view(-1, c)

        classifier = self.model.classifier
        classifier.apply(self.fix_bn)
        class_probs = classifier(pixel_feats).softmax(-1).detach()
        classifier.apply(self.enable_bn)

        domain_target = torch.full(
            (pixel_feats.size(0),),
            1.0 if domain == 'source' else 0.0,
            dtype=torch.float,
            device=pixel_feats.device,
        )

        w_unk = class_probs[:, -1]
        w_k = class_probs[:, :-1].sum(-1)

        grl_coef = self.cfg.ADA.adv_grl
        pred_k = self.model.adv_k(pixel_feats, grl_coef)
        pred_unk = self.model.adv_unk(pixel_feats, grl_coef)

        losses = {}
        losses['loss_adv_k'] = (criterion_bce_red(pred_k, domain_target) * w_k).mean()
        losses['loss_adv_unk'] = (criterion_bce_red(pred_unk, domain_target) * w_unk).mean()
        return losses




    def tta_model_zero_grad(self, names=None):
        """Clear the gradients held by the TTA optimizers selected by ``names``.

        ``names`` is resolved through ``self.get_model_names``; entries whose
        optimizer in ``self._tta_optims`` is ``None`` are skipped.
        """
        for name in self.get_model_names(names):
            optimizer = self._tta_optims[name]
            if optimizer is None:
                continue
            optimizer.zero_grad()

    def tta_model_backward(self, loss):
        """Back-propagate ``loss`` after passing it to ``self.detect_anomaly``.

        ``detect_anomaly`` is defined outside this file; presumably it rejects
        non-finite losses -- TODO confirm.
        """
        self.detect_anomaly(loss)
        loss.backward()

    def tta_model_update(self, names=None):
        """Take one optimizer step for the TTA optimizers selected by ``names``.

        ``names`` is resolved through ``self.get_model_names``; entries whose
        optimizer in ``self._tta_optims`` is ``None`` are skipped.
        """
        for name in self.get_model_names(names):
            optimizer = self._tta_optims[name]
            if optimizer is None:
                continue
            optimizer.step()

    def tta_backward_and_update(self, loss, names=None):
        """Run one full TTA optimisation step for the optimizers in ``names``.

        Order matters: gradients are cleared, ``loss`` is back-propagated,
        then the selected optimizers step.
        """
        self.tta_model_zero_grad(names)
        self.tta_model_backward(loss)
        self.tta_model_update(names)


    # Active variant: offline domain adaptation (several epochs over the whole test loader).
    def test(self, split=None):
        """Adapt the model on the test loader and evaluate before/after each epoch.

        Steps:

        1. Evaluate once with ``test_after_DA`` ("Before DA").
        2. Register the backbone, classifier and the two domain
           discriminators under TTA names and build one optimizer for each
           from ``cfg.ADA.TTA.OPTIM``; the backbone and classifier learning
           rates are divided by 10.
        3. For ``cfg.ADA.TTA.epoch`` epochs, combine on every test batch:
           the adversarial alignment losses of ``forward_DCA`` for the test
           batch (target) and for a randomly drawn source batch loaded from
           ``cfg.ADA.DSPATH``; a cross-entropy on source samples passed
           through ``backbone.forward_DM`` with the styles extracted from the
           test batch; and a consistency loss between a frozen copy of the
           pre-adaptation model and the current model.
        4. Evaluate again after every epoch.

        Args:
            split: accepted for interface compatibility only. Adaptation
                always iterates ``self.test_loader`` and evaluation uses the
                default split of ``test_after_DA``.

        Returns:
            The value returned by the most recent ``test_after_DA`` call.
        """
        self.set_model_mode("train")

        print("Before DA: ====================================================== ")
        # NOTE(review): test_after_DA() calls set_model_mode("eval") and nothing
        # switches back afterwards, so the adaptation loop below presumably runs
        # in eval mode (apart from what forward_DCA's enable_bn re-enables).
        # Left unchanged -- confirm whether this is intended.
        last_result = self.test_after_DA()

        tta_modules = OrderedDict([
            ('backbone_tta', self.model.backbone),
            ('adv_k', self.model.adv_k),
            ('adv_unk', self.model.adv_unk),
            ('classifier_tta', self.model.classifier),
        ])
        self._tta_optims = OrderedDict()
        for name, module in tta_modules.items():
            self._models[name] = module
            optimizer = build_optimizer(module, self.cfg.ADA.TTA.OPTIM)
            if name in ('backbone_tta', 'classifier_tta'):
                # The feature extractor and classifier adapt 10x more slowly
                # than the freshly trained discriminators.
                for param_group in optimizer.param_groups:
                    param_group['lr'] /= 10
            self._tta_optims[name] = optimizer
        tta_names = ['backbone_tta', 'classifier_tta', 'adv_k', 'adv_unk']

        test_loader = self.test_loader
        # Frozen snapshot of the pre-adaptation model for the consistency loss.
        original_model = copy.deepcopy(self.model)

        # NOTE(security): torch.load unpickles the file; only point
        # cfg.ADA.DSPATH at trusted files. The data is converted to Python
        # lists right away, so it is loaded onto the CPU.
        prototype_file = torch.load(self.cfg.ADA.DSPATH, map_location="cpu")
        source_samples = prototype_file['data'][2][0].tolist()
        source_labels = prototype_file['data'][2][1].tolist()

        for epoch in range(self.cfg.ADA.TTA.epoch):

            Loss_s = AverageMeter()
            Loss_t = AverageMeter()
            Loss_ce = AverageMeter()
            Loss_sc = AverageMeter()

            for batch in tqdm(test_loader):

                test_input, test_label = self.parse_batch_test(batch)
                sampled_val_images, sampled_val_labels = sample_batch(
                    source_samples, source_labels, self.cfg.DATALOADER.TEST.BATCH_SIZE
                )

                # Put the source batch on the same device as the test batch
                # instead of hard-coding CUDA.
                device = test_input.device
                sampled_val_images = torch.tensor(sampled_val_images, device=device)
                sampled_val_labels = torch.tensor(sampled_val_labels, device=device)

                # Adversarial alignment: test batch as target, sampled batch as source.
                _, rois_t = self.model(test_input, None, test_label, return_rois=True)
                loss_t = sum(self.forward_DCA(rois_t, domain='target').values())

                _, rois_s = self.model(sampled_val_images, None, sampled_val_labels, return_rois=True)
                loss_s = sum(self.forward_DCA(rois_s, domain='source').values())

                # Meters get plain floats so they never hold autograd graphs.
                Loss_s.update(loss_s.item())
                Loss_t.update(loss_t.item())

                # Classify source samples rendered with the styles of the test batch.
                _, test_styles = self.model.backbone.forward_DM(x=test_input, style_index=None, styles=None, norm_layer=[], inject_layer=[], return_style=True)
                feat_style = self.model.backbone.forward_DM(x=sampled_val_images, style_index=None, styles=test_styles, norm_layer=[], inject_layer=[2], return_style=False)
                output_style = self.model.classifier(feat_style)
                loss_style_ce = F.cross_entropy(output_style, sampled_val_labels)

                Loss_ce.update(loss_style_ce.item())

                # Consistency with the pre-adaptation model. Batches of size
                # one are skipped (as the original guard intended); before,
                # such a batch hit an undefined or stale `output_now`.
                if test_input.shape[0] != 1:
                    # The snapshot is never optimised, so no graph is needed.
                    with torch.no_grad():
                        feat_original, _ = original_model(test_input)
                        output_original = original_model.classifier(feat_original)
                    feat_now, _ = self.model(test_input)
                    output_now = self.model.classifier(feat_now)
                    loss_sc = cross_entropy_loss_with_2logits(output_original, output_now)
                else:
                    loss_sc = loss_style_ce.new_zeros(())

                Loss_sc.update(loss_sc.item())

                loss = loss_s + loss_t + self.cfg.ADA.TTA.tta_ce_coef * loss_style_ce + self.cfg.ADA.TTA.tta_sc_coef * loss_sc
                self.tta_backward_and_update(loss=loss, names=tta_names)

            print(f"[DA Epoch {epoch+1}/{self.cfg.ADA.TTA.epoch}], loss_s: {Loss_s.avg}, loss_t: {Loss_t.avg}, loss_ce: {Loss_ce.avg}, loss_sc :{Loss_sc.avg}")

            # after DA, evaluate
            last_result = self.test_after_DA()

        return last_result






    # # Alternative variant (disabled): online test-time adaptation -- adapt on each batch, then evaluate that batch.
    # def test(self, split=None):
    #     self.set_model_mode("train")

    #     print("Before DA: ====================================================== ")
    #     self.test_after_DA()


    #     self._tta_optims = OrderedDict()
            
    #     self._models['backbone_tta'] = self.model.backbone
    #     self._models['adv_k'] = self.model.adv_k
    #     self._models['adv_unk'] = self.model.adv_unk
    #     self._models['classifier_tta'] = self.model.classifier
    #     tta_backbone_optimizer = build_optimizer(self._models['backbone_tta'], self.cfg.ADA.TTA.OPTIM)
    #     for param_group in tta_backbone_optimizer.param_groups:
    #         param_group['lr'] /= 10
    #     tta_classifier_optimizer = build_optimizer(self._models['classifier_tta'], self.cfg.ADA.TTA.OPTIM)
    #     for param_group in tta_classifier_optimizer.param_groups:
    #         param_group['lr'] /= 10
    #     tta_adv_k_optimizer = build_optimizer(self._models['adv_k'], self.cfg.ADA.TTA.OPTIM)
    #     tta_adv_unk_optimizer = build_optimizer(self._models['adv_unk'], self.cfg.ADA.TTA.OPTIM)
    #     self._tta_optims["backbone_tta"] = tta_backbone_optimizer
    #     self._tta_optims["adv_k"] = tta_adv_k_optimizer
    #     self._tta_optims["adv_unk"] = tta_adv_unk_optimizer
    #     self._tta_optims["classifier_tta"] = tta_classifier_optimizer

    #     prototype_file = torch.load(self.cfg.ADA.DSPATH)
    #     source_samples = prototype_file['data'][2][0].tolist()
    #     source_labels = prototype_file['data'][2][1].tolist()



    #     test_loader = self.test_loader

    #     original_model = copy.deepcopy(self.model)
        
    #     for epoch in range(1): # only once

    #         Loss_s = AverageMeter()
    #         Loss_t = AverageMeter()
    #         Loss_ce = AverageMeter()
    #         Loss_sc = AverageMeter()
    #         Loss_fda = AverageMeter()

    #         for batch_idx, batch in enumerate(tqdm(test_loader)):

    #             test_input, test_label = self.parse_batch_test(batch)
    #             sampled_val_images, sampled_val_labels =  sample_batch(source_samples, source_labels, self.cfg.DATALOADER.TEST.BATCH_SIZE)

    #             # if the condensed set is used
    #             sampled_val_images = torch.tensor(sampled_val_images).cuda()
    #             sampled_val_labels = torch.tensor(sampled_val_labels).cuda()

    #             for step in range(self.cfg.ADA.TTA.steps):
    #                 feat, rois_t = self.model(test_input, None, test_label, return_rois=True)
    #                 loss_align_t = self.forward_DCA(rois_t, domain='target')
    #                 loss_t = sum(loss for loss in loss_align_t.values())


    #                 _, rois_s = self.model(sampled_val_images, None, sampled_val_labels, return_rois=True)
    #                 loss_align_s = self.forward_DCA(rois_s, domain='source')
    #                 loss_s = sum(loss for loss in loss_align_s.values())
                

    #                 Loss_s.update(loss_s)
    #                 Loss_t.update(loss_t)
                

    #                 _, test_styles = self.model.backbone.forward_DM(x=test_input, style_index=None, styles=None, norm_layer=[], inject_layer=[], return_style=True)
    #                 feat_style = self.model.backbone.forward_DM(x=sampled_val_images, style_index=None, styles=test_styles, norm_layer=[], inject_layer=[2], return_style=False)
    #                 output_style = self.model.classifier(feat_style)
    #                 loss_style_ce = F.cross_entropy(output_style, sampled_val_labels)

    #                 Loss_ce.update(loss_style_ce)

    #                 feat_original, _ = original_model(test_input)
    #                 output_original = original_model.classifier(feat_original)
    #                 if test_input.shape[0] !=1:
    #                     feat_now, rois_now = self.model(test_input)
    #                     output_now = self.model.classifier(feat_now)
    #                 loss_sc = cross_entropy_loss_with_2logits(output_original, output_now)


    #                 loss = loss_s + loss_t + self.cfg.ADA.TTA.tta_ce_coef * loss_style_ce + self.cfg.ADA.TTA.tta_sc_coef * loss_sc 
    #                 self.tta_backward_and_update(loss=loss, names=['backbone_tta', 'classifier_tta', 'adv_k', 'adv_unk'])
       
    #             ############# evaluate
    #             self.set_model_mode("eval") 
    #             feat, _ = self.model(test_input, None, test_label, return_rois=True)
    #             output = self.model.classifier(feat)
    #             if self.cfg.MODEL.BACKBONE.NAME == 'mire_resnet18':
    #                 self.evaluator.process(output[0], test_label)
    #             else:
    #                 self.evaluator.process(output, test_label, None, self.model, self.num_classes+self.cfg.OSDG.ADD_DIMS, 'test')
    #                 if batch_idx == 22:
    #                     a=1

    #         results = self.evaluator.evaluate()
    #         for k, v in results.items():
    #             tag = "{}/{}".format(split, k)
    #             self.write_scalar(tag, v, self.epoch)

    #         return list(results.values())[0]




    @torch.no_grad()
    def test_after_DA(self, split=None):
        """Evaluate the current model on the validation or test loader.

        The model is switched to eval mode and the evaluator is reset. The
        validation loader is used only when ``split == "val"`` and a
        validation loader exists; every other case falls back to the test
        loader and the results are then tagged as ``test``.

        Args:
            split: split to evaluate; defaults to ``cfg.TEST.SPLIT``.

        Returns:
            The first value of the evaluator's result mapping.
        """
        self.set_model_mode("eval")
        self.evaluator.reset()

        if split is None:
            split = self.cfg.TEST.SPLIT

        if split == "val" and self.val_loader is not None:
            data_loader = self.val_loader
            stage = 'val'
            print("Do evaluation on {} set".format(split))
        else:
            # Keep the scalar tags consistent with the loader actually used.
            split = "test"
            data_loader = self.test_loader
            stage = 'test'
            print("Do evaluation on test set")

        is_mire = self.cfg.MODEL.BACKBONE.NAME == 'mire_resnet18'

        for batch in tqdm(data_loader):
            images, label = self.parse_batch_test(batch)

            output = self.model_inference(images)  # [bs, num_classes+1]
            if is_mire:
                self.evaluator.process(output[0], label)
            else:
                self.evaluator.process(output, label, None, self.model, self.num_classes+self.cfg.OSDG.ADD_DIMS, stage)

        results = self.evaluator.evaluate()

        for k, v in results.items():
            tag = "{}/{}".format(split, k)
            self.write_scalar(tag, v, self.epoch)

        return list(results.values())[0]

