import torch
import torch.nn.functional as F
import lightning as L
import random
import matplotlib.pyplot as plt
import cmaps
import io
from PIL import Image
import torchvision.transforms as T
import numpy as np
from main import instantiate_from_config
from contextlib import contextmanager
from collections import OrderedDict
from einops import rearrange
import torch
from torchvision import transforms
from PIL import Image
from functools import partial

import numpy as np
import lzma
from omegaconf import OmegaConf
import os

from transcoder.scheduler.lr_scheduler import Scheduler_LinearWarmup, Scheduler_LinearWarmup_CosineDecay, Scheduler_LinearWarmup_CosineDecay_BSQ
from transcoder.scheduler.ema import LitEma


import torch
import torch.nn as nn
import torch.nn.functional as F

# from transcoder.models.transformer import TransformerDecoder, TransformerEncoder
from transcoder.models.TSSUN.vit_nlc_tout import ViT_Encoder, ViT_Decoder

class BSQModel(L.LightningModule):
    """Binary segmentation model built from a ViT encoder/decoder pair.

    ``forward`` runs ``ViT_Encoder`` followed by ``ViT_Decoder``. The training and
    validation steps treat the decoder output as per-pixel logits: they threshold
    at 0 to get the binary mask and train with ``BCEWithLogitsLoss`` plus a soft
    Dice loss. Confusion-matrix counts (tp/fp/fn/tn) are accumulated per dataset
    tag to report running precision / recall / F1 / IoU.

    NOTE(review): ``configure_optimizers`` reads ``self.learning_rate``, which is
    not set in ``__init__``. It is presumably assigned by the training script;
    confirm.
    """

    # Substrings of state-dict keys that are dropped when the state dict is saved.
    _EXCLUDED_STATE_KEYS = ("inception_model", "lpips_vgg", "lpips_alex")

    def __init__(self,
                 feature_dim=1024,
                 enc_dim=192,
                 dec_dim=192,
                 in_chans=32,
                 out_chans=32,
                 perci_weight=1,
                 weather_weight=0.5,
                 sat_weight=0.1,
                 unify_time_length=2,
                 img_size=(256, 256),

                 ### scheduler config
                 resume_lr=None,
                 min_learning_rate=0,
                 use_ema=False,
                 stage=None,
                 lr_drop_epoch=None,
                 lr_drop_rate=0.1,
                 warmup_epochs=1.0,  # warmup length in epochs (converted to steps)
                 scheduler_type="linear-warmup_cosine-decay-bsq",
                 lr_start=0.1,
                 lr_max=1.0,
                 lr_min=0.5,

                 ### debugging / figure export
                 figure_save_dir=None,
                 ):
        """Build the encoder/decoder and store training hyper-parameters.

        Args:
            feature_dim: ``embed_dim`` passed to both ViT encoder and decoder.
            enc_dim: forwarded to ``ViT_Encoder`` as ``enc_dim``.
            dec_dim: forwarded to ``ViT_Decoder`` as ``dec_dim``.
            in_chans, out_chans: forwarded to both ViT modules.
            perci_weight, weather_weight, sat_weight: stored on the module; not
                used inside this class.
            unify_time_length, img_size: forwarded to both ViT modules.
            resume_lr: if not None, every optimizer param group's lr is set to
                this value in ``on_train_start``.
            min_learning_rate: minimum lr for the "linear-warmup_cosine-decay"
                scheduler (used as ``min_learning_rate / learning_rate``).
            use_ema: keep an EMA copy of the weights. Only honoured when
                ``stage`` is None.
            stage: when not None, EMA is disabled.
            lr_drop_epoch, lr_drop_rate: stored; not used inside this class.
            warmup_epochs: warmup length in epochs, converted to optimizer steps.
            scheduler_type: "None", "linear-warmup", "linear-warmup_cosine-decay"
                or "linear-warmup_cosine-decay-bsq".
            lr_start, lr_max, lr_min: parameters of the BSQ scheduler.
            figure_save_dir: if set, training batches whose running F1 exceeds
                0.9 have their predict/label/input arrays saved here as ``.npy``.
                Defaults to None (disabled).
        """
        super().__init__()

        def vit_config(**extra):
            """Return the ViT kwargs shared by encoder and decoder, plus ``extra``."""
            # A fresh dict (with fresh list/partial objects) per call, so the
            # encoder and decoder never alias mutable config values.
            return dict(
                drop_path_rate=0, use_abs_pos_emb=True,  # as in table 11
                patch_size=(12, 12), patch_stride=(8, 8), patch_padding=(2, 2),
                in_chans=in_chans, out_chans=out_chans, embed_dim=feature_dim, depth=8,
                num_heads=16, mlp_ratio=4, qkv_bias=True,
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
                z_dim=None,
                learnable_pos=True,
                window=True,
                window_size=[(32, 8), (16, 16), (8, 32)],  # designed for a 32x32 latent grid
                interval=4,
                round_padding=True,
                pad_attn_mask=True,  # to_do: ablation
                test_pos_mode='learnable_simple_interpolate',  # to_do: ablation
                lms_checkpoint_train=True,
                img_size=img_size,
                unify_time_length=unify_time_length,
                **extra,
            )

        # Author's note (translated): a unify module maps the input to a fixed
        # (b, T, 32, 256, 256) feature, which the stride-8 patch embedding
        # reduces to (b, T, 1024, 32, 32).
        encoder_config = vit_config(enc_dim=enc_dim)
        decoder_config = vit_config(dec_dim=dec_dim)

        # Latent grid height from the patch-embedding arithmetic:
        # (H - kernel + 2 * padding) // stride + 1.
        self.latent_h = (encoder_config['img_size'][0] - encoder_config['patch_size'][0]
                         + 2 * encoder_config['patch_padding'][0]) // encoder_config['patch_stride'][0] + 1
        self.encoder = ViT_Encoder(**encoder_config)
        self.decoder = ViT_Decoder(**decoder_config)
        self.criterion = nn.BCEWithLogitsLoss()

        self.in_chans = encoder_config['in_chans']
        self.perci_weight = perci_weight
        self.weather_weight = weather_weight
        self.sat_weight = sat_weight

        # The EMA copy is only built when stage is None (author: not needed when
        # training the Transformer stage). Tie use_ema to that condition so that
        # ema_scope / on_train_batch_end never touch a missing self.model_ema.
        self.use_ema = bool(use_ema) and stage is None
        if self.use_ema:
            self.model_ema = LitEma(self)

        self.resume_lr = resume_lr
        self.lr_drop_epoch = lr_drop_epoch
        self.lr_drop_rate = lr_drop_rate
        self.scheduler_type = scheduler_type
        self.warmup_epochs = warmup_epochs
        self.min_learning_rate = min_learning_rate

        self.lr_start = lr_start
        self.lr_max = lr_max
        self.lr_min = lr_min

        self.figure_save_dir = figure_save_dir

        self.strict_loading = False
        self.img2tensor = T.ToTensor()

        self.dataset_name_list = ['levircd']
        self.metrics_list = ['tp', 'fp', 'fn', 'tn']

        # Running confusion-matrix counters keyed "<dataset>_<metric>".
        self.train_metrics = self._empty_metrics()
        self.validate_metrics = self._empty_metrics()

    def _empty_metrics(self):
        """Return a zeroed counter dict with one "<dataset>_<metric>" key per combination."""
        return {
            f'{dataset_name_log}_{metric}': 0
            for dataset_name_log in self.dataset_name_list
            for metric in self.metrics_list
        }

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap in the EMA weights (if EMA is enabled) for the ``with`` body.

        The training weights are restored on exit, even if the body raises.
        ``context`` is an optional label printed on swap/restore.
        """
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def load_state_dict(self, *args, strict=False, **kwargs):
        """Load weights non-strictly by default so checkpoints with extra or missing keys still resume.

        Extra keyword arguments (e.g. ``assign`` in newer PyTorch) are forwarded unchanged.
        """
        return super().load_state_dict(*args, strict=strict, **kwargs)

    def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
        """Return the module state dict without keys of auxiliary evaluation networks.

        Keys containing any of ``_EXCLUDED_STATE_KEYS`` are dropped. The options
        are forwarded as keywords; passing them positionally is deprecated in
        PyTorch. The ``_metadata`` attribute used by ``load_state_dict`` is
        preserved on the filtered dict.
        """
        full_state = super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
        filtered = OrderedDict(
            (key, value) for key, value in full_state.items()
            if not any(tag in key for tag in self._EXCLUDED_STATE_KEYS)
        )
        metadata = getattr(full_state, "_metadata", None)
        if metadata is not None:
            filtered._metadata = metadata
        return filtered

    def forward(self, input_tensor, dataset_name):
        """Encode ``input_tensor`` and decode it; both stages also receive ``dataset_name``.

        The training/validation steps treat the returned tensor as per-pixel
        logits. They squeeze a trailing singleton channel and threshold at 0.
        """
        h = self.encoder(input_tensor, dataset_name)
        return self.decoder(h, dataset_name)

    def on_train_start(self):
        """Override the learning rate of every optimizer param group after resuming.

        ``configure_optimizers`` returns a single optimizer, so ``self.optimizers()``
        yields one optimizer rather than a (gen, disc) pair. Both cases are handled.
        """
        if self.resume_lr is None:
            return
        optimizers = self.optimizers()
        if not isinstance(optimizers, (list, tuple)):
            optimizers = [optimizers]
        for optimizer in optimizers:
            for param_group in optimizer.param_groups:
                param_group["lr"] = self.resume_lr

    @staticmethod
    def _scores_from_counts(metrics_store, dataset_name_log, eps=1e-8):
        """Compute (precision, recall, f1, iou) from accumulated counts in ``metrics_store``.

        Missing counters are treated as 0. ``eps`` avoids division by zero, so an
        empty store yields 1.0 for every score.
        """
        tp = metrics_store.get(f'{dataset_name_log}_tp', 0)
        fp = metrics_store.get(f'{dataset_name_log}_fp', 0)
        fn = metrics_store.get(f'{dataset_name_log}_fn', 0)
        precision = (tp + eps) / (tp + fp + eps)
        recall = (tp + eps) / (tp + fn + eps)
        f1_score = 2 * (precision * recall) / (precision + recall + eps)
        iou = (tp + eps) / (tp + fp + fn + eps)
        return precision, recall, f1_score, iou

    def compute_metrics(self, predict, target, dataset_name_log, metrics_store=None):
        """Accumulate confusion counts for one batch and return running scores.

        Args:
            predict: binary prediction tensor (callers pass (b*t, h, w)); cast to bool.
            target: binary label tensor with the same shape; cast to bool.
            dataset_name_log: dataset tag used to namespace counters and returned keys.
            metrics_store: counter dict updated in place. Defaults to
                ``self.train_metrics``; validation passes ``self.validate_metrics``
                so its counts do not leak into the training metrics.

        Returns:
            Dict with ``precision_<tag>``, ``recall_<tag>``, ``f1_score_<tag>`` and
            ``iou_<tag>`` as Python floats. The values are computed from all counts
            accumulated so far in ``metrics_store``, not just this batch.
        """
        if metrics_store is None:
            metrics_store = self.train_metrics

        predict = predict.bool()
        target = target.bool()

        # Counts are summed over the whole batch (all pixels of all samples).
        batch_counts = {
            'tp': (predict & target).sum(),
            'fp': (predict & ~target).sum(),
            'fn': (~predict & target).sum(),
            'tn': (~predict & ~target).sum(),
        }
        for metric, count in batch_counts.items():
            key = f'{dataset_name_log}_{metric}'
            metrics_store[key] = metrics_store.get(key, 0) + count

        precision, recall, f1_score, iou = self._scores_from_counts(metrics_store, dataset_name_log)
        return {
            f"precision_{dataset_name_log}": precision.mean().item(),
            f"recall_{dataset_name_log}": recall.mean().item(),
            f"f1_score_{dataset_name_log}": f1_score.mean().item(),
            f"iou_{dataset_name_log}": iou.mean().item(),
        }

    def _segmentation_loss(self, predict_tensor, label_tensor):
        """Return ``(bce + dice, bce, dice)`` for logits vs. binary labels.

        Both tensors have a trailing singleton dim squeezed (if present) and are
        flattened from (b, t, h, w) to (b*t, h, w) in float32. Dice is computed
        on sigmoid probabilities over the whole batch.
        """
        logits = rearrange(predict_tensor.squeeze(-1), 'b t h w -> (b t) h w').to(torch.float32)
        target = rearrange(label_tensor.squeeze(-1), 'b t h w -> (b t) h w').to(torch.float32)

        bce_loss = self.criterion(logits, target)

        probs = torch.sigmoid(logits)
        intersection = (probs * target).sum()
        cardinality = probs.sum() + target.sum()
        dice_loss = 1 - (2 * intersection + 1e-8) / (cardinality + 1e-8)

        return bce_loss + dice_loss, bce_loss, dice_loss

    def _save_figure_arrays(self, batch_idx, dataset_name_log, predict_mask, label_tensor, input_tensor):
        """Save predict/label/input arrays of one batch as ``.npy`` files in ``self.figure_save_dir``."""
        os.makedirs(self.figure_save_dir, exist_ok=True)
        prefix = os.path.join(self.figure_save_dir, f'{batch_idx}_{dataset_name_log}')
        np.save(f'{prefix}_predict.npy', predict_mask.detach().cpu().numpy())
        np.save(f'{prefix}_label.npy', label_tensor.detach().cpu().numpy())
        np.save(f'{prefix}_input.npy', input_tensor.detach().cpu().numpy())

    def training_step(self, batch, batch_idx):
        """Compute BCE + Dice loss on one batch, update running train metrics, and log them.

        ``batch`` is ``(input_tensor, label_tensor, dataset_name)``. Per the
        author: input (b, t, c, h, w), labels (b, t, h, w). Only
        ``dataset_name[0]`` is used as the logging tag.
        """
        input_tensor, label_tensor, dataset_name = batch
        dataset_name_log = dataset_name[0]

        predict_tensor = self(input_tensor, dataset_name)
        loss, bce_loss, dice_loss = self._segmentation_loss(predict_tensor, label_tensor)

        predict_mask = predict_tensor.squeeze(-1) > 0  # b,t,h,w
        metrics = self.compute_metrics(
            rearrange(predict_mask, 'b t h w -> (b t) h w'),
            rearrange(label_tensor, 'b t h w -> (b t) h w'),
            dataset_name_log,
        )

        # Optional export of well-predicted batches (running F1 > 0.9) for figures.
        if self.figure_save_dir is not None and metrics[f'f1_score_{dataset_name_log}'] > 0.9:
            self._save_figure_arrays(batch_idx, dataset_name_log, predict_mask, label_tensor, input_tensor)

        # NOTE: metric keys already end with the dataset tag, so these log names
        # look like "train/precision_<tag>_<tag>". Kept unchanged because
        # monitors/dashboards may reference them.
        metrics_log = {f'train/{k}_{dataset_name_log}': v for k, v in metrics.items()}

        loss_log_dict = {
            f'train/loss_{dataset_name_log}': loss,
            f'train/bce_loss_{dataset_name_log}': bce_loss,
            f'train/dice_loss_{dataset_name_log}': dice_loss,
        }
        loss_log_dict.update(metrics_log)

        self.log_dict(loss_log_dict, prog_bar=False, logger=True, on_step=True, on_epoch=True)

        return loss

    def on_train_batch_end(self, *args, **kwargs):
        """Update the EMA weights after every training batch when EMA is enabled."""
        if self.use_ema:
            self.model_ema(self)

    def on_train_epoch_end(self):
        """Log epoch-level train precision/recall/F1/IoU per dataset, then reset the counters."""
        for dataset_name_log in self.dataset_name_list:
            precision, recall, f1_score, iou = self._scores_from_counts(self.train_metrics, dataset_name_log)
            self.log_dict({
                f'epoch_train/precision_{dataset_name_log}': precision,
                f'epoch_train/recall_{dataset_name_log}': recall,
                f'epoch_train/f1_score_{dataset_name_log}': f1_score,
                f'epoch_train/iou_{dataset_name_log}': iou,
            }, prog_bar=False, logger=True, on_step=False, on_epoch=True)

        # Reset to zeroed counters (not {}), so a dataset with no batches next
        # epoch still yields defined scores instead of a KeyError.
        self.train_metrics = self._empty_metrics()
        self.validate_metrics = self._empty_metrics()

    def on_validation_epoch_start(self):
        """Start every validation run (including the sanity check) with fresh counters."""
        self.validate_metrics = self._empty_metrics()

    def validation_step(self, batch, batch_idx):
        """Run validation, using the EMA weights (keys suffixed "_ema") when EMA is enabled."""
        if self.use_ema:
            with self.ema_scope():
                self._validation_step(batch, batch_idx, suffix="_ema")
        else:
            self._validation_step(batch, batch_idx)

    def _validation_step(self, batch, batch_idx, suffix=""):
        """Compute and log validation loss/metrics for one batch; return the loss.

        Counts go to ``self.validate_metrics`` (not the training counters).
        ``suffix`` is appended to every logged key, e.g. "_ema".
        """
        input_tensor, label_tensor, dataset_name = batch
        dataset_name_log = dataset_name[0]

        predict_tensor = self(input_tensor, dataset_name)
        loss, bce_loss, dice_loss = self._segmentation_loss(predict_tensor, label_tensor)

        predict_mask = predict_tensor.squeeze(-1) > 0  # b,t,h,w
        metrics = self.compute_metrics(
            rearrange(predict_mask, 'b t h w -> (b t) h w'),
            rearrange(label_tensor, 'b t h w -> (b t) h w'),
            dataset_name_log,
            metrics_store=self.validate_metrics,
        )

        loss_log_dict = {
            f'val/loss{suffix}': loss,
            f'val/bce_loss{suffix}': bce_loss,
            f'val/dice_loss{suffix}': dice_loss,
        }
        loss_log_dict.update({f'{k}{suffix}': v for k, v in metrics.items()})

        self.log_dict(loss_log_dict, prog_bar=False, logger=True, on_step=True, on_epoch=True)

        return loss

    def configure_optimizers(self):
        """Create an AdamW optimizer over encoder+decoder and an optional per-step LambdaLR.

        Warmup and decay lengths are expressed in optimizer steps:
        ``len(train dataloader) // world_size`` steps per epoch. Every scheduler
        is therefore stepped with ``interval="step"``.
        """
        lr = self.learning_rate  # assigned externally; not set in __init__
        print(f'optimizer lr: {lr}')
        opt_gen = torch.optim.AdamW(list(self.encoder.parameters()) +
                                    list(self.decoder.parameters()),
                                    lr=lr, betas=(0.9, 0.99), weight_decay=1e-4, eps=1e-8)

        # Build the dataloader only once just to measure its length.
        step_per_epoch = len(self.trainer.datamodule._train_dataloader()) // self.trainer.world_size
        if self.trainer.is_global_zero:
            print("step_per_epoch: {}".format(step_per_epoch))

        warmup_steps = step_per_epoch * self.warmup_epochs
        training_steps = step_per_epoch * self.trainer.max_epochs
        max_decay_steps = training_steps
        # Author's note (translated): max_decay_steps was once set to
        # training_steps * 0.7 because single-variable runs diverged with a high
        # lr. Shortening the decay lowers the lr sooner; not needed normally.

        if self.scheduler_type == "None":
            return {"optimizer": opt_gen}

        if self.scheduler_type == "linear-warmup":
            lr_lambda = Scheduler_LinearWarmup(warmup_steps)
        elif self.scheduler_type == "linear-warmup_cosine-decay":
            multipler_min = self.min_learning_rate / self.learning_rate
            lr_lambda = Scheduler_LinearWarmup_CosineDecay(warmup_steps=warmup_steps, max_steps=training_steps, multipler_min=multipler_min)
        elif self.scheduler_type == "linear-warmup_cosine-decay-bsq":
            lr_lambda = Scheduler_LinearWarmup_CosineDecay_BSQ(warmup_steps=warmup_steps, lr_min=self.lr_min, lr_max=self.lr_max, lr_start=self.lr_start, max_decay_steps=max_decay_steps)
        else:
            raise NotImplementedError(f"Unknown scheduler_type: {self.scheduler_type!r}")

        scheduler_ae = {
            "scheduler": torch.optim.lr_scheduler.LambdaLR(opt_gen, lr_lambda),
            "interval": "step",  # lambdas count optimizer steps, so update every step
            "frequency": 1,
        }
        return [
            {"optimizer": opt_gen, "lr_scheduler": scheduler_ae},
        ]

    def visualize_tensor_as_heatmap(self, tensor):
        """Render a 2-D tensor as a filled-contour heatmap image tensor in [-1, 1].

        The tensor is standardised (zero mean, unit std; a zero std is treated as
        1) and drawn with fixed contour levels from -5 to 4.8, plus a colorbar.
        The figure is rasterised to PNG and converted with ``ToTensor``.
        """
        fig, ax = plt.subplots()
        buf = io.BytesIO()
        try:
            colormap = cmaps.WhiteBlueGreenYellowRed

            values = tensor.detach().cpu().float().numpy()
            std = values.std()
            values = (values - values.mean()) / (std if std > 0 else 1.0)

            contour = ax.contourf(values, levels=[-5 + 0.2 * x for x in range(50)], cmap=colormap, extend='both')
            fig.colorbar(contour, ax=ax)

            # Render this figure (not whatever pyplot considers "current") to memory.
            fig.savefig(buf, format='png')
            buf.seek(0)

            with Image.open(buf) as pil_image:
                image = self.img2tensor(pil_image)
        finally:
            # Close the figure so later plots do not draw on top of it, and free the buffer.
            plt.close(fig)
            buf.close()

        return image * 2 - 1  # map from [0, 1] to [-1, 1]

    def log_images(self, batch, **kwargs):
        """Return per-time-step prediction and label images for the first sample.

        Keys are ``<tag>_predict_time_<i>`` and ``<tag>_label_time_<i>``. Each value
        is a float tensor of shape (1, 3, h, w) with values in {-1, 1}.
        """
        log = {}

        with torch.inference_mode():
            input_tensor, label_tensor, dataset_name = batch
            dataset_name_log = dataset_name[0]

            # The full batch goes through the model even though only sample 0 is
            # logged. Per the original author, running with batch size 1
            # noticeably degraded the results.
            predict_tensor = self(input_tensor, dataset_name)
            predict_mask = predict_tensor.squeeze(-1) > 0  # b,t,h,w

            # Cast to float before "*2 - 1" so unsigned label dtypes cannot wrap around.
            predict_img = predict_mask[0].float() * 2 - 1
            label_img = label_tensor[0].float() * 2 - 1

            for i in range(predict_img.shape[0]):
                log[f'{dataset_name_log}_predict_time_{i}'] = predict_img[i].unsqueeze(0).unsqueeze(0).repeat(1, 3, 1, 1)
                log[f'{dataset_name_log}_label_time_{i}'] = label_img[i].unsqueeze(0).unsqueeze(0).repeat(1, 3, 1, 1)

        return log
