
from typing import Tuple
from einops import rearrange
from torch import nn
import torch.distributed as dist
import torch

from utils.loss import get_denoising_loss
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper, WanCLIPEncoder

import os 

from pipeline import BidirectionalTrainingPipeline

import torch.nn.functional as F
from typing import Optional, Tuple
import torch

from torchvision.io import write_video
from einops import rearrange
import os
import math

from utils.rm_video_align_wrapper import VideoAlignWrapper

import torch
import torch.distributed as dist
from tqdm.auto import tqdm

def sd3_time_shift(shift, t):
    """Warp a flow-matching time / sigma value ``t`` with an SD3-style shift.

    Computes ``shift * t / (1 + (shift - 1) * t)``. The endpoints are fixed
    (0 -> 0, 1 -> 1), and the expression works element-wise on tensors as
    well as on plain Python numbers.
    """
    numerator = shift * t
    denominator = 1 + (shift - 1) * t
    return numerator / denominator

def all_gather_same_group(tensor, group_size, subgroups):
    """All-gather ``tensor`` across the ranks of the caller's subgroup.

    The caller's subgroup is looked up as ``subgroups[rank // group_size]``
    (presumably ranks are partitioned into consecutive blocks of
    ``group_size`` — verify against whoever builds ``subgroups``).

    Returns:
        A tensor of shape ``[group_size, *tensor.shape]`` whose i-th entry is
        the tensor contributed by the member with group-rank ``i``.
    """
    # Pre-built subgroup for this rank (not created per call).
    my_group = subgroups[dist.get_rank() // group_size]

    buffers = [torch.empty_like(tensor) for _ in range(group_size)]
    dist.all_gather(buffers, tensor, group=my_group)
    return torch.stack(buffers)



def sde_step(pred_flow, x_t, x_t_prev, sigmas, time_idx, eta=0.25):
    """Take one stochastic flow-matching step and score it.

    Moves from ``sigmas[time_idx]`` to ``sigmas[time_idx + 1]`` using an Euler
    step plus a score-based correction. The step noise has std
    ``eta * sqrt(sigma_t - sigma_{t+1})``.

    Args:
        pred_flow: flow predicted by the model at ``x_t``.
        x_t: current latent.
        x_t_prev: latent to score. If ``None``, a new one is sampled from the
            step's Gaussian.
        sigmas: sigma schedule, indexed at ``time_idx`` and ``time_idx + 1``.
        time_idx: index of the current sigma.
        eta: noise scale of the SDE step.

    Returns:
        ``(x_t_prev, pred_original_sample, log_prob)``.
        ``pred_original_sample = x_t - sigma * pred_flow``. ``log_prob`` is the
        Gaussian log-density of ``x_t_prev``, averaged over every dimension
        except the first. ``x_t_prev`` is detached inside the density, so
        gradients only flow through the predicted mean.
    """
    sigma_now = sigmas[time_idx]
    sigma_next = sigmas[time_idx + 1]
    dsigma = sigma_next - sigma_now  # negative for a decreasing schedule

    # Clean-sample estimate implied by the predicted flow.
    pred_original_sample = x_t - sigma_now * pred_flow

    # Score correction added on top of the deterministic Euler update.
    score_estimate = -(x_t - pred_original_sample * (1 - sigma_now)) / sigma_now ** 2
    log_term = -0.5 * eta ** 2 * score_estimate
    x_t_prev_mean = (x_t + dsigma * pred_flow) + log_term * dsigma

    std_dev_t = eta * math.sqrt(sigma_now - sigma_next)

    if x_t_prev is None:
        x_t_prev = x_t_prev_mean + torch.randn_like(x_t_prev_mean) * std_dev_t

    # Element-wise log N(x_t_prev; x_t_prev_mean, std_dev_t**2), computed in float32.
    residual = x_t_prev.detach().to(torch.float32) - x_t_prev_mean.to(torch.float32)
    elementwise = (
        -(residual ** 2) / (2 * (std_dev_t ** 2))
        - math.log(std_dev_t)
        - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
    )

    non_batch_dims = tuple(range(1, elementwise.ndim))
    return x_t_prev, pred_original_sample, elementwise.mean(dim=non_batch_dims)





class T2V_DANCE_GRPO(nn.Module):
    """DanceGRPO-style GRPO fine-tuning of a Wan video generator.

    Each call to ``grpo_loss`` does the following:

    * The frozen ``real_score`` model samples a trajectory with the
      stochastic ``sde_step`` sampler.
    * The final clean prediction is decoded, saved and scored by
      ``VideoAlignWrapper``.
    * Rewards are normalized within a group of ranks.
    * The trainable ``generator`` gets a clipped PPO-style objective on a
      random subset of the sampled steps.

    NOTE: ``self.group_size`` and ``self.subgroups`` are read by
    ``_get_noise_from_rank0`` and ``grpo_loss`` but are never assigned in this
    class. The caller must set them so that ``subgroups[rank // group_size]``
    is the process group containing ``rank``.
    """

    def __init__(self, args, device):
        super().__init__()

        # Sampling schedule: sample_steps + 1 sigmas going from 1 to 0,
        # warped by the SD3 time shift.
        self.sample_steps = 16
        sigma_schedule = torch.linspace(1, 0, self.sample_steps + 1)
        self.timestep_shift = getattr(args, "timestep_shift", 5.0)
        self.sigma_schedule = sd3_time_shift(self.timestep_shift, sigma_schedule)
        print(f"self.sigma_schedule = {self.sigma_schedule}")

        self._initialize_models(args, device)

        self.global_rank = dist.get_rank()
        self.is_main_process = self.global_rank == 0

        # Used in saved file names; it is not incremented anywhere in this class.
        self.step = 0

        self.gen_train_videos_dir = os.path.join(args.logdir, "gen_train_videos")
        self.gen_eval_videos_dir = os.path.join(args.logdir, "gen_eval_videos")
        self.gen_log_videos_dir = os.path.join(args.logdir, "gen_log_videos")
        if self.is_main_process:
            os.makedirs(self.gen_train_videos_dir, exist_ok=True)
            os.makedirs(self.gen_eval_videos_dir, exist_ok=True)
            os.makedirs(self.gen_log_videos_dir, exist_ok=True)

        self.device = device
        self.args = args

        self.dtype = torch.bfloat16 if args.mixed_precision else torch.float32

        # Hyperparameters read from args (not used by the methods below).
        self.num_train_timestep = args.num_train_timestep
        self.min_step = int(0.02 * self.num_train_timestep)
        self.max_step = int(0.98 * self.num_train_timestep)

        self.real_guidance_scale = args.guidance_scale

        self.min_score_timestep = getattr(args, "min_score_timestep", 0)

    def _initialize_models(self, args, device):
        """Load the trainable generator, frozen reference model, VAE and reward model.

        Also fetches the generator's scheduler and moves its timesteps to ``device``.

        Raises:
            ValueError: if ``args.video_align_path`` is not set.
        """
        self.real_model_name = getattr(args, "real_name", "Wan2.1-I2V-14B-720P")
        self.generator_name = getattr(args, "generator_name", "Wan2.1-I2V-14B-720P")

        # Both checkpoint paths are required (no default).
        self.real_model_path = getattr(args, "real_path")
        self.generator_path = getattr(args, "generator_path")

        print("--------- loading generator model")
        self.generator = WanDiffusionWrapper(
            **getattr(args, "model_kwargs", {}),  # e.g. timestep_shift: 5.0
            model_name=self.generator_name, model_path=self.generator_path, wan_type="gen"
        )
        self.generator.model.requires_grad_(True)

        print("--------- loading real_score model")
        self.real_score = WanDiffusionWrapper(model_name=self.real_model_name, model_path=self.real_model_path, wan_type="real")
        self.real_score.model.requires_grad_(False)

        print(f"loading WanVAEWrapper ...")
        self.vae = WanVAEWrapper(model_name=self.generator_name).eval()
        self.vae.requires_grad_(False)
        self.vae = self.vae.to(dtype=torch.bfloat16)

        self.generator.enable_gradient_checkpointing()

        print(f"--------- loading RewardModel ...")
        # The reward-model checkpoint location comes from the config.
        video_align_path = getattr(args, "video_align_path", None)
        if video_align_path is None:
            raise ValueError("args.video_align_path must be set to load the VideoAlign reward model")
        self.reward_model = VideoAlignWrapper(video_align_path, device)
        self.reward_model.inferencer.model.requires_grad_(False)

        self.scheduler = self.generator.get_scheduler()
        self.scheduler.timesteps = self.scheduler.timesteps.to(device)
        self.scheduler.alphas_cumprod = None

    def _get_noise_from_rank0(self, noise_shape):
        """Return initial noise that is identical for every rank in this rank's group.

        Every rank draws its own noise. The tensors are all-gathered inside
        the subgroup, and all members keep the one contributed by group-rank
        0. This is global rank 0 only for the first group.
        """
        local_noise = torch.randn(*noise_shape, device=self.device, dtype=self.dtype)
        gathered_noise = all_gather_same_group(local_noise, self.group_size, self.subgroups)
        return gathered_noise[0]  # same shape as noise_shape

    def run_sample_step(self, z, progress_bar, conditional_dict):
        """Sample a trajectory with the frozen ``real_score`` model and ``sde_step``.

        Args:
            z: initial noise latent.
            progress_bar: iterable of indices into ``self.sigma_schedule``.
                ``grpo_loss`` passes a tqdm over ``range(self.sample_steps)``.
            conditional_dict: conditioning forwarded to ``real_score``.

        Returns:
            ``(z, latents, all_latents, all_log_probs)``:

            * ``z``: the last sampled latent (float32).
            * ``latents``: the last clean prediction, cast to bfloat16.
            * ``all_latents``: every trajectory latent stacked on dim 1
              (float32, ``num_steps + 1`` entries).
            * ``all_log_probs``: the per-step log-probs stacked on dim 1.

        Raises:
            ValueError: if ``progress_bar`` yields no step.
        """
        # The trajectory is kept in float32. grpo_one_step feeds these exact
        # tensors back into sde_step, so they must equal the values the
        # log-probs below were computed from.
        all_latents = [z.to(torch.float32)]
        all_log_probs = []
        pred_original = None

        for i in progress_bar:
            sigma = self.sigma_schedule[i]
            timestep_value = int(sigma * 1000)
            timestep = torch.ones(z.shape[:2], dtype=torch.long, device=z.device) * timestep_value
            pred_flow, _ = self.real_score(
                noisy_image_or_video=z,
                conditional_dict=conditional_dict,
                timestep=timestep,
            )

            z, pred_original, log_prob = sde_step(
                pred_flow, z.to(torch.float32), x_t_prev=None,
                sigmas=self.sigma_schedule, time_idx=i,
            )
            all_latents.append(z)
            all_log_probs.append(log_prob)

        if pred_original is None:
            raise ValueError("run_sample_step needs at least one sampling step")

        latents = pred_original.to(torch.bfloat16)
        all_latents = torch.stack(all_latents, dim=1)  # [B, num_steps + 1, ...latent dims]
        all_log_probs = torch.stack(all_log_probs, dim=1)  # [B, num_steps]
        return z, latents, all_latents, all_log_probs

    @staticmethod
    def _group_advantages(rewards, clip_max):
        """Normalize a group's rewards to zero mean / unit std, clamped to +-clip_max.

        A single reward has no spread (its unbiased ``std`` is NaN), so it
        yields a zero advantage instead of a NaN loss.
        """
        if rewards.numel() < 2:
            return torch.zeros_like(rewards)
        advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
        return advantages.clamp(-clip_max, clip_max)

    def grpo_loss(
        self, prompt, gas,
        image_or_video_shape,
        conditional_dict: dict, unconditional_dict: dict,
    ) -> Tuple[torch.Tensor, dict]:
        """Sample, score and back-propagate one GRPO update; gradients accumulate here.

        ``loss.backward()`` is called once per trained timestep inside this
        method, so the caller only needs to step the optimizer.

        Args:
            prompt: prompt list. It is passed to the reward model, and
                ``prompt[0]`` is written next to the saved video.
            gas: divisor applied to every per-step loss (presumably the
                gradient-accumulation step count).
            image_or_video_shape: shape of the initial noise latent.
            conditional_dict: conditioning forwarded to both models.
            unconditional_dict: unused; kept for interface compatibility.

        Returns:
            ``(loss, reward_log_dict)``. ``loss`` is the detached sum of the
            per-timestep losses that were back-propagated, or 0 if none were.
            ``reward_log_dict`` comes from ``save_latent_and_reward``.

        NOTE(review): trajectories are sampled with the frozen ``real_score``,
        while the new log-probs come from ``generator``. The PPO ratio is
        therefore taken against the reference model rather than the previous
        policy. Confirm this is intended.
        """
        clip_range = 1e-4
        adv_clip_max = 5.0
        timestep_fraction = 0.6

        # Step 1: sample a reference trajectory from noise shared by the group.
        noise = self._get_noise_from_rank0(image_or_video_shape)
        progress_bar = tqdm(range(0, self.sample_steps), desc="Sampling Progress")
        with torch.no_grad():
            _, latents, all_latents, all_log_probs = self.run_sample_step(
                noise, progress_bar, conditional_dict
            )

        # Step 2: decode, save and score the final clean prediction.
        reward_log_dict = self.save_latent_and_reward(latents, f"ref_sample", prompt)

        # Step 3: build the per-step training samples. Entry j pairs the
        # latent before sampling step j with the latent after it. The last
        # sampling step is excluded.
        timestep_value = [int(sigma * 1000) for sigma in self.sigma_schedule][:self.sample_steps]
        print(f"timestep_value = {timestep_value}")
        device = all_latents.device
        T = len(timestep_value)
        timesteps = torch.tensor(timestep_value, device=device, dtype=torch.long).unsqueeze(0)  # [1, T]
        samples = {
            "timesteps": timesteps[:, :-1],
            "latents": all_latents[:, :-2],         # latent before step j
            "next_latents": all_latents[:, 1:-1],   # latent after step j
            "log_probs": all_log_probs[:, :-1],
            "avg_rewards": reward_log_dict['avg_reward'],
        }

        # Step 4: group-normalized, clipped advantage for this rank.
        reward = reward_log_dict['avg_reward']
        reward_tensor = torch.tensor([reward], dtype=torch.float32, device=self.device)
        all_rewards = all_gather_same_group(reward_tensor, self.group_size, self.subgroups).squeeze(-1)  # [group_size]
        advantages = self._group_advantages(all_rewards, adv_clip_max)
        local_advantage = advantages[self.global_rank % self.group_size]

        # Step 5: clipped-ratio update on a random subset of the T - 1 recorded steps.
        train_timesteps = int(T * timestep_fraction)
        perm = torch.randperm(T - 1, device=device)
        chosen_ids = perm[:train_timesteps].tolist()

        total_loss = torch.zeros((), device=device)
        for t_idx in chosen_ids:
            # Samples are stored in trajectory order, so t_idx is both the
            # sample index and the sigma-schedule index of this transition.
            new_log_probs = self.grpo_one_step(
                samples["latents"][:, t_idx],
                samples["next_latents"][:, t_idx],
                conditional_dict,
                samples["timesteps"][:, t_idx],
                t_idx,
            )
            print(f"new_log_probs = {new_log_probs}")
            ratio = torch.exp(new_log_probs - samples["log_probs"][:, t_idx])

            unclipped_loss = -local_advantage * ratio
            clipped_loss = -local_advantage * torch.clamp(
                ratio,
                1.0 - clip_range,
                1.0 + clip_range,
            )
            loss = torch.mean(torch.maximum(unclipped_loss, clipped_loss)) / (train_timesteps * gas)
            loss.backward()
            print(f"loss = {loss}")
            total_loss = total_loss + loss.detach()

        return total_loss, reward_log_dict

    def grpo_one_step(self,
        latents, pre_latents, conditional_dict,
        timesteps, i):
        """Recompute the log-prob of one recorded transition under ``self.generator``.

        Args:
            latents: latent before the step (x_t).
            pre_latents: recorded latent after the step (x_t_prev), which is scored.
            conditional_dict: conditioning forwarded to the generator.
            timesteps: one-element tensor holding the integer model timestep.
            i: index of the step's sigma in ``self.sigma_schedule``.

        Returns:
            Per-sample log-probability, with gradients flowing through the generator.
        """
        pred_flow, _ = self.generator(
            noisy_image_or_video=latents, conditional_dict=conditional_dict,
            timestep=torch.ones(latents.shape[:2], dtype=torch.long, device=latents.device) * timesteps.item(),
        )

        _, _, log_prob = sde_step(
            pred_flow, latents.to(torch.float32), x_t_prev=pre_latents.to(torch.float32),
            sigmas=self.sigma_schedule, time_idx=i)

        return log_prob

    def save_latent_and_reward(self, video_latents, task_type, prompt):
        """Decode latents, save the first video with its prompt, and score it.

        The VAE is moved to ``self.device`` for decoding and back to the CPU
        afterwards. Files go to ``self.gen_train_videos_dir`` as
        ``step{step}-{task_type}-rank{rank}.mp4`` and ``.txt``.

        Returns:
            dict with ``avg_reward`` (0.1*MQ + 0.1*VQ + 0.8*TA) and the
            individual ``mq_reward`` / ``vq_reward`` / ``ta_reward`` values.
        """
        save_dir = self.gen_train_videos_dir
        # 1. Decode to pixel values in [0, 255], laid out as (b, t, h, w, c).
        # The latents are cast to the VAE's device/dtype: sampled latents are
        # bfloat16, while self.dtype is float32 when mixed precision is off.
        video_latents = video_latents.detach().to(device=self.device, dtype=self.dtype)
        self.vae.to(device=self.device, dtype=self.dtype)
        videos = self.vae.decode_to_pixel(video_latents)
        videos = (videos * 0.5 + 0.5).clamp(0, 1)  # presumably decoder output is in [-1, 1]
        videos = rearrange(videos, 'b t c h w -> b t h w c').cpu() * 255.0
        self.vae.model.clear_cache()
        self.vae.to(device='cpu')
        # 2. Save the first video of the batch as mp4.
        base_name = f'step{self.step}-{task_type}-rank{self.global_rank}'
        video_path = os.path.join(save_dir, f'{base_name}.mp4')
        write_video(video_path, videos[0], fps=16)
        torch.cuda.empty_cache()
        # 3. Score the saved file with the reward model.
        all_reward = self.reward_model.inferencer.reward([video_path], prompt, use_norm=False)[0]
        avg_reward = 0.1 * all_reward['MQ'] + 0.1 * all_reward["VQ"] + 0.8 * all_reward["TA"]

        reward_log_dict = {
            "avg_reward" : avg_reward,
            "mq_reward" : all_reward['MQ'],
            "vq_reward" : all_reward['VQ'],
            "ta_reward" : all_reward['TA']
        }
        # 4. Save the prompt next to the video.
        txt_path = os.path.join(save_dir, f'{base_name}.txt')
        with open(txt_path, 'w', encoding='utf-8') as f:
            f.write(prompt[0])

        return reward_log_dict
