import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from einops import rearrange
from deepspeed.ops.adam import FusedAdam, DeepSpeedCPUAdam
from deepspeed.runtime.lr_schedules import WarmupLR
import warnings
import clip
from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, PNDMScheduler, DDPMScheduler
from transformers import CLIPTokenizer, CLIPTextModel
# from transformers import T5Tokenizer, T5ForConditionalGeneration
from modules.networks import EMAModel, EMAModel_for_deepspeed_stage3
from modules.StableDiffusion import UNetModel, FrozenOpenCLIPEmbedder, AutoencoderKL, NormResBlock
from torchvision.utils import save_image
from PIL import Image
import imageio
import math

def tensor_to_gif(tensor, H, W, output_gif_path, duration=0.25, loop=0):
    """
    Split an image tensor into a grid of H x W tiles and save them as GIF frames.

    Tiles are emitted in raster order (left-to-right, then top-to-bottom).
    Trailing rows/columns that do not fill a whole tile are dropped.

    :param tensor: image tensor of shape (C, height, width) with values in
        [-1, 1]; values outside that range are clamped. Presumably C == 3 (RGB)
        -- TODO confirm for other channel counts.
    :param H: tile (frame) height in pixels
    :param W: tile (frame) width in pixels
    :param output_gif_path: path of the GIF file to write
    :param duration: per-frame display time, forwarded to ``imageio.mimsave``.
        NOTE(review): imageio's legacy GIF writer reads this as seconds, while
        newer pillow-based writers use milliseconds -- confirm for the installed
        imageio version.
    :param loop: number of times the GIF loops; 0 means loop forever
    :raises ValueError: if the image is smaller than a single H x W tile
    """
    # Map [-1, 1] -> [0, 255] in float32 so fp16/bf16 inputs keep full pixel
    # precision, and round to the nearest level instead of truncating.
    pixels = ((tensor.detach().float() + 1) * 127.5).round().clamp(0, 255)
    frame = pixels.byte().permute(1, 2, 0).cpu().numpy()  # (height, width, C) uint8

    th, tw = frame.shape[:2]
    rows, cols = th // H, tw // W
    if rows == 0 or cols == 0:
        raise ValueError(
            f'image of size {th}x{tw} is smaller than one {H}x{W} tile'
        )

    # Slice the ndarray directly; imageio writers reliably accept ndarrays.
    tiles = [
        np.ascontiguousarray(frame[i * H:(i + 1) * H, j * W:(j + 1) * W])
        for i in range(rows)
        for j in range(cols)
    ]

    imageio.mimsave(output_gif_path, tiles, format='GIF', duration=duration, loop=loop)


class StableDiffusionTrainer(pl.LightningModule):
    """Lightning module that fine-tunes a Stable Diffusion 2.1 UNet.

    The OpenCLIP text encoder and the KL autoencoder are loaded from local
    checkpoints and only used under ``torch.no_grad`` (the autoencoder is also
    frozen with ``requires_grad_(False)``). Only UNet parameters that require
    grad are given to the optimizer. Training regresses the UNet output onto
    the noise that was added by the DDIM scheduler (MSE loss).
    """

    def __init__(self, config):
        super(StableDiffusionTrainer, self).__init__()

        self.config = config
        self.name = config.name
        self.savevideo = config.savevideo  # root directory for validation/test GIFs
        self.H = config.H  # tile height used by tensor_to_gif; also passed to the UNet (unused there)
        self.W = config.W  # tile width, same usage as H
        self.T = config.T  # forwarded to the UNet as `frame`
        self.save_hyperparameters()

        # NOTE(review): torch.load unpickles arbitrary objects; only point these
        # paths at trusted checkpoints (weights_only=True is an option on newer torch).
        self.text_clip = FrozenOpenCLIPEmbedder(layer='penultimate')
        ckpt_path = 'StableDiffusion/v2-1_512-ema-pruned-openclip.ckpt'
        self.text_clip.load_state_dict(torch.load(ckpt_path, map_location='cpu'))

        # Latent scaling constant applied in encode_video / undone in decode_video.
        self.scale_factor = 0.18215
        self.ae = AutoencoderKL(
            {
                'double_z': True,
                'z_channels': 4,
                'resolution': 256,
                'in_channels': 3,
                'out_ch': 3,
                'ch': 128,
                'ch_mult': (1, 2, 4, 4),
                'num_res_blocks': 2,
                'attn_resolutions': [],
                'dropout': 0.0,
            },
            {'target': 'torch.nn.Identity'},
            4,
        )
        self.ae.eval()
        self.ae.requires_grad_(False)
        ckpt_path = 'StableDiffusion/v2-1_512-ema-pruned-autoencoder.ckpt'
        self.ae.load_state_dict(torch.load(ckpt_path, map_location='cpu'))

        self.unet = UNetModel(
            use_checkpoint = True,
            use_fp16 = True,
            image_size = [self.H, self.W], # unused
            in_channels = 4,
            out_channels = 4,
            model_channels = 320,
            attention_resolutions = [ 4, 2, 1 ],
            num_res_blocks = 2,
            channel_mult = [ 1, 2, 4, 4 ],
            num_head_channels = 64, # need to fix for flash-attn
            use_spatial_transformer = True,
            use_linear_in_transformer = True,
            transformer_depth = 1,
            context_dim = 1024,
            legacy = False,
            spatial_finetune_blocks = 1,
            frame = self.T,
        )
        ckpt_path = 'StableDiffusion/v2-1_512-ema-pruned-diffusionmodel.ckpt'
        # strict=False: the UNet may contain parameters absent from the SD checkpoint.
        self.unet.load_state_dict(torch.load(ckpt_path, map_location='cpu'), strict=False)
        self.unet.requires_grad_(True)

        if config['ema']:
            self.ema_model = EMAModel(self.unet)
        else:
            self.ema_model = None

        self.scheduler = DDIMScheduler(
            beta_start         =     0.00085, 
            beta_end           =     0.0120, 
            beta_schedule      =     'scaled_linear', 
            clip_sample        =     False,
            set_alpha_to_one   =     False,
            steps_offset       =     1,
        )

    def configure_optimizers(self):
        """Optimize only the trainable UNet parameters with DeepSpeed's CPU Adam."""
        trainable = [p for p in self.unet.parameters() if p.requires_grad]
        opt = DeepSpeedCPUAdam(
            trainable,
            lr=self.config["base_learning_rate"],
            betas=(0.9, 0.9),
            weight_decay=0.03,
        )
        return opt

    @torch.no_grad()
    def encode_text(self, text):
        """Embed a list of prompts, e.g. ['a dog', 'a cat'], with the OpenCLIP encoder.

        Returns a tensor that is (batch, tokens, channels) according to the original notes.
        """
        return self.text_clip.encode(text)

    @torch.no_grad()
    def encode_video(self, image, unflatten=True):
        """Sample autoencoder latents for a batch of images and scale them.

        ``unflatten`` is accepted for API compatibility but is currently unused.
        Note that ``.sample()`` draws from the posterior, so this consumes RNG.
        """
        z = self.ae.encode(image).sample()
        return z * self.scale_factor

    @torch.no_grad()
    def decode_video(self, z, unflatten=False):
        """Undo the latent scaling and decode latents back to image space.

        ``unflatten`` is accepted for API compatibility but is currently unused.
        """
        z = z * (1. / self.scale_factor)
        return self.ae.decode(z)

    @torch.no_grad()
    def intergrate_batch(self, batch):
        """Turn a raw batch into UNet training inputs.

        Reads ``batch['text']``, ``batch['image']`` and ``batch['fps']`` and
        returns a dict with text embeddings (``'text'``), noised latents
        (``'video'``), sampled timesteps (``'video_timestep'``), the added
        noise as regression target (``'video_target'``) and ``'fps'``.
        """
        _batch = {}

        text_latent = self.encode_text(batch['text'])
        _batch['text'] = text_latent.to(self.dtype)
        video_latent = self.encode_video(batch['image'])
        noise = torch.randn_like(video_latent)
        # Read the value from the scheduler config; direct attribute access to
        # config values is deprecated in diffusers.
        num_train_timesteps = self.scheduler.config.num_train_timesteps
        t = torch.randint(0, num_train_timesteps, (video_latent.shape[0],), dtype=torch.int64, device=self.device)
        xt_video = self.scheduler.add_noise(video_latent, noise, t)
        _batch['video_timestep'] = t
        _batch['video_target'] = noise.to(self.dtype)
        _batch['video'] = xt_video.to(self.dtype)
        _batch['fps'] = batch['fps']

        return _batch

    @property
    def dtype(self):
        """dtype of the module's first parameter (used to cast model inputs)."""
        return next(self.parameters()).data.dtype

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Update the EMA copy of the UNet after every training batch, if enabled."""
        if self.ema_model is not None:
            self.ema_model(self.unet)

    def training_step(self, batch, batch_idx):
        """Compute the noise-prediction MSE loss for one batch."""
        # Step 1: build UNet inputs (noised latents, timesteps, text embeddings).
        _batch = self.intergrate_batch(batch)

        # Step 2: predict the noise residual.
        video_out = self.unet(
            _batch['video'],
            _batch['video_timestep'],
            _batch['text'],
            _batch['fps'],
        )

        # Step 3: MSE against the injected noise, computed in float32.
        loss_video = F.mse_loss(video_out.float().flatten(1), _batch['video_target'].float().flatten(1))
        self.log('loss_video', loss_video, prog_bar=True, sync_dist=True)
        # TODO: a per-tile (H x W sub-frame) loss was sketched here but never implemented.
        loss = loss_video

        return loss

    @torch.no_grad()
    def inference(self, 
        text=None, 
        fps=30,
        num_inference_steps=50, 
        do_classifier_free_guidance=True, 
        guidance_scale=7.0, 
        image=True,
        video=None,
        **extra_step_kwargs
    ):
        """Generate images from prompts with the DDIM scheduler.

        :param text: a prompt string, or a list/tuple of prompts
        :param fps: a tensor with one value per prompt, or a scalar that is
            broadcast to every prompt. NOTE(review): a scalar becomes an
            integer tensor; confirm the UNet accepts that dtype.
        :param num_inference_steps: number of scheduler steps
        :param do_classifier_free_guidance: also run an empty-prompt branch and
            combine the two predictions
        :param guidance_scale: guidance weight (see note in the loop)
        :param image: unused, kept for API compatibility
        :param video: unused, kept for API compatibility
        :param extra_step_kwargs: forwarded to ``scheduler.step``
        :returns: dict with key ``'image'`` holding the decoded images
        :raises NotImplementedError: if ``text`` is not a str, list or tuple
        """
        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps = self.scheduler.timesteps

        if isinstance(text, (list, tuple)):
            # Copy to a list so the CFG concatenation below also works for tuples.
            text = list(text)
        elif isinstance(text, str):
            text = [text]
        else:
            raise NotImplementedError(f'unsupported prompt type: {type(text).__name__}')
        batch_size = len(text)

        if do_classifier_free_guidance:
            # Conditional prompts first, empty prompts second; the chunk(2)
            # after the UNet call relies on this order.
            text = text + [''] * batch_size
        text_latent = self.encode_text(text).to(self.dtype)

        video_latent = torch.randn((batch_size, 4, self.config['diff_latent_res_H'], self.config['diff_latent_res_W']), dtype=self.dtype, device=self.device)

        fps = torch.as_tensor(fps, device=self.device)
        if fps.dim() == 0:
            fps = fps.repeat(batch_size)
        if do_classifier_free_guidance:
            # Only duplicate when the latent batch is duplicated as well.
            fps = torch.cat([fps, fps], 0)

        for t in timesteps:
            if do_classifier_free_guidance:
                input_video_latent = torch.cat([video_latent, video_latent], 0)
            else:
                input_video_latent = video_latent

            timestep = t.repeat(input_video_latent.shape[0]).contiguous()

            noise_pred = self.unet(
                input_video_latent,
                timestep,
                text_latent,
                fps,
            )

            if do_classifier_free_guidance:
                noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
                # NOTE(review): this is cond + s * (cond - uncond), i.e. an
                # effective weight of s + 1 in the uncond + s * (cond - uncond)
                # convention -- confirm this is the intended parameterization.
                noise_pred = noise_pred_cond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            video_latent = self.scheduler.step(noise_pred, t, video_latent, **extra_step_kwargs).prev_sample

        result = {}
        result['image'] = self.decode_video(video_latent, unflatten=True)

        return result

    @staticmethod
    def _gif_filename(prefix, text, suffix='.gif', max_bytes=255):
        """Return ``prefix + text + suffix`` made safe to use as one file name.

        Path separators and NUL are replaced by '_'. ``text`` is truncated, in
        UTF-8 bytes, so the whole name fits within ``max_bytes`` (a common
        file-system limit). Names that were already valid are unchanged.
        """
        text = str(text)
        for bad in (os.sep, os.altsep, '\0'):
            if bad:
                text = text.replace(bad, '_')
        budget = max(max_bytes - len((prefix + suffix).encode('utf-8')), 0)
        # errors='ignore' drops a multi-byte character cut in half by the slice.
        text = text.encode('utf-8')[:budget].decode('utf-8', errors='ignore')
        return prefix + text + suffix

    def _save_gifs(self, batch, batch_idx, generated):
        """Save the reference and the generated image of each sample as tile GIFs.

        Files are written to ``<savevideo>/<name>/`` as
        ``<global_step>_<batch_idx>_<sample_id>_{ori,ren}_<text>.gif``.
        Saving is best-effort: a failing file produces a warning and the loop
        continues.
        """
        out_dir = os.path.join(self.savevideo, self.name)
        # Every rank runs this; exist_ok avoids the check-then-create race.
        os.makedirs(out_dir, exist_ok=True)
        # rank, rank + W, rank + 2W, ... is unique across all W processes
        # (the local GPU count collided on multi-node and CPU runs).
        world_size = self.trainer.world_size
        for i in range(len(generated)):
            prefix = f'{self.global_step}_{batch_idx}_{self.global_rank + world_size * i}'
            for tag, img in (('_ori_', batch['image'][i]), ('_ren_', generated[i])):
                path = os.path.join(out_dir, self._gif_filename(prefix + tag, batch['text'][i]))
                try:
                    tensor_to_gif(img, self.H, self.W, path)
                except Exception as err:  # best-effort: report, but keep going
                    warnings.warn(f'Could not save {path}: {err!r}')

    def validation_step(self, batch, batch_idx):
        """Generate images for the batch prompts and save them next to the originals."""
        x_ren = self.inference(batch['text'], batch['fps'])
        self._save_gifs(batch, batch_idx, x_ren['image'])

    def test_step(self, batch, batch_idx):
        """Same as validation_step, also passing encoded latents to `inference`."""
        # `video` is currently ignored by `inference`; the encode call is kept
        # because it consumes RNG and so affects the sampled noise.
        x_ren = self.inference(batch['text'], batch['fps'], video = self.encode_video(batch['image']))
        self._save_gifs(batch, batch_idx, x_ren['image'])
    