import gc
import logging

from utils.dataset import ShardingLMDB_T2V_Dataset, cycle
from utils.distributed import EMA_FSDP, fsdp_wrap, fsdp_state_dict, launch_distributed_job, reduce_mean, reduce_dict
from utils.misc import set_seed, merge_step_dicts
import torch.distributed as dist
from omegaconf import OmegaConf
from model import T2V_DMD_GRPO
import torch
import wandb
import time
import os

import datetime
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from torch.utils.data import Sampler


class GroupSharedSampler(Sampler):
    """Sampler in which every ``group_size`` consecutive ranks share one sample.

    The world is split into ``world_size // group_size`` groups of consecutive
    ranks. All ranks of a group iterate over the same index sequence, so per
    step the whole world consumes ``world_size // group_size`` distinct
    samples. Example: 64 ranks with ``group_size=32`` -> 2 samples per step.

    Args:
        dataset: Sized dataset the indices refer to.
        group_size: Number of consecutive ranks that receive the same index.
        shuffle: If True, permute the indices with a generator seeded by
            ``seed + epoch``.
        seed: Base seed of the permutation. It must be identical on all ranks
            so that every rank builds the same permutation.
        drop_last: If True, only the first ``len(dataset) // num_groups *
            num_groups`` dataset indices are ever used. If False, all indices
            take part in the shuffle, but each group still receives only
            ``len(dataset) // num_groups`` of them, so the leftover tail of
            the permutation is skipped for that epoch.

    Raises:
        RuntimeError: If torch.distributed is unavailable or the default
            process group has not been initialised.
        ValueError: If the world size is not divisible by ``group_size``.
    """

    # Class-level default so that iterating before the first set_epoch() call
    # (or on an instance built without __init__) does not hit AttributeError.
    epoch = 0

    def __init__(self,
                 dataset,
                 group_size,
                 shuffle=True,
                 seed=0,
                 drop_last=False):
        self.dataset = dataset
        self.group_size = group_size
        self.shuffle = shuffle
        self.seed = seed
        self.drop_last = drop_last
        self.epoch = 0

        if not dist.is_available():
            raise RuntimeError("Requires torch.distributed")
        if not dist.is_initialized():
            raise RuntimeError(
                "Requires an initialized torch.distributed process group")

        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()

        if self.world_size % self.group_size != 0:
            raise ValueError("world_size must be divisible by group_size")

        self.num_groups = self.world_size // self.group_size
        # Consecutive ranks belong to the same group.
        self.group_id = self.rank // self.group_size

        # Number of dataset indices that take part in one epoch.
        total = len(self.dataset)
        if self.drop_last:
            self.num_samples = total // self.num_groups * self.num_groups
        else:
            self.num_samples = total
        self.total_size = self.num_samples

    def __iter__(self):
        """Yield this group's slice of the (optionally shuffled) indices."""
        if self.shuffle:
            # Same seed and epoch on every rank -> same permutation everywhere.
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(self.total_size, generator=g).tolist()
        else:
            indices = list(range(self.total_size))

        # Cut the index list into num_groups equal chunks, one per group.
        chunk_size = self.total_size // self.num_groups
        start = self.group_id * chunk_size

        # Every rank of the group yields the same indices, one at a time.
        yield from indices[start:start + chunk_size]

    def __len__(self):
        """Return the number of indices each rank yields per epoch."""
        return self.total_size // self.num_groups

    def set_epoch(self, epoch):
        """Set the epoch so that the next iteration uses a new permutation."""
        self.epoch = epoch

# Create and cache all rank subgroups once during initialization
def init_subgroups(group_size):
    """Create one process group per block of ``group_size`` consecutive ranks.

    Ranks ``[0, group_size)`` form the first subgroup, ``[group_size,
    2 * group_size)`` the second, and so on. Every rank creates all subgroups
    in the same order, including the ones it is not a member of.

    Args:
        group_size: Number of consecutive ranks per subgroup.

    Returns:
        List of process groups, indexed by ``rank // group_size``.

    Raises:
        ValueError: If the world size is not divisible by ``group_size``.
    """
    world_size = dist.get_world_size()
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if world_size % group_size != 0:
        raise ValueError("world_size must be divisible by group_size")

    return [
        dist.new_group(list(range(start, start + group_size)))
        for start in range(0, world_size, group_size)
    ]

class Trainer:
    def __init__(self, config):
        """Initialise distributed state, models, optimizers, data and EMA.

        Args:
            config: OmegaConf configuration object. Fields read here include
                ``GAS``, ``step``, ``mixed_precision``, ``disable_wandb``,
                ``seed``, ``logdir``, the ``wandb_*`` entries, the FSDP
                sharding / wrap strategies, the optimizer hyper-parameters
                and the dataset paths. ``config.seed`` is overwritten in
                place with a broadcast random value when it is 0.

        Raises:
            KeyError: If ``config.generator_ckpt`` is set but the loaded
                checkpoint has no "generator" (or "critic") entry.
        """
        # NOTE(review): group size and timestep count are hard-coded here
        # rather than read from the config.
        self.group_size = 8
        self.num_timesteps = 4
        self.GAS = config.GAS

        self.config = config
        self.step = config.step

        # Step 1: Initialize the distributed training environment (rank, seed, dtype, logging etc.)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

        launch_distributed_job()
        global_rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        print(f"---------- global_rank={global_rank}")
        print(f"---------- world_size={self.world_size}")
        # Process index within a single node (local rank, e.g. 0-7 on an 8-GPU node)
        local_rank = int(os.environ["LOCAL_RANK"])
        print(f"---------- local_rank={local_rank}")

        self.subgroups = init_subgroups(self.group_size)

        self.dtype = torch.bfloat16 if config.mixed_precision else torch.float32
        self.device = torch.cuda.current_device()
        self.is_main_process = global_rank == 0
        self.disable_wandb = config.disable_wandb

        # A seed of 0 means "pick a random seed": rank 0's draw is broadcast
        # so that every rank ends up with the same base seed.
        if config.seed == 0:
            random_seed = torch.randint(0, 10000000, (1,), device=self.device)
            dist.broadcast(random_seed, src=0)
            config.seed = random_seed.item()

        # Offset by the rank so that each rank gets its own random stream.
        set_seed(config.seed + global_rank)

        if self.is_main_process and not self.disable_wandb:
            wandb.login(key=config.wandb_key)
            wandb.init(
                config=OmegaConf.to_container(config, resolve=True),
                name=config.config_name,
                mode="online",  # os.environ["WANDB_MODE"] = "offline"
                entity=config.wandb_entity,
                project=config.wandb_project,
                dir=config.wandb_save_dir
            )

        self.output_path = config.logdir

        # Step 2: Initialize the model and optimizer
        self.model = T2V_DMD_GRPO(config, device=self.device)
        self.model.subgroups = self.subgroups
        self.model.group_size = self.group_size
        # ---------------------------------------------------------------------------
        # Optionally load pretrained generator / critic (fake_score) weights.
        if getattr(config, "generator_ckpt", False):
            print(f"Loading pretrained generator from {config.generator_ckpt}")
            state_dict = torch.load(config.generator_ckpt, map_location="cpu", mmap=True)
            # Fail with a clear message instead of an UnboundLocalError when
            # the checkpoint does not have the expected layout.
            if "generator" not in state_dict:
                raise KeyError(
                    f"Checkpoint {config.generator_ckpt} has no 'generator' entry; "
                    "expected a dict with 'generator' and 'critic' state dicts"
                )
            state_dict_gen = state_dict["generator"]
            state_dict_critic = state_dict["critic"]
            print("----------- LOADING PRE GEN STATE DICT")
            self.model.generator.load_state_dict(
                state_dict_gen, strict=True
            )
            print("----------- LOADING PRE FAKE STATE DICT")
            self.model.fake_score.load_state_dict(
                state_dict_critic, strict=True
            )

            # Drop the CPU copies before wrapping the models with FSDP.
            del state_dict_gen
            del state_dict_critic
            del state_dict
            gc.collect()
        # ---------------------------------------------------------------------------

        print("------------ We begin to fsdp all the models")
        self.model.generator = self.model.generator.float()
        self.model.generator = fsdp_wrap(
            self.model.generator,
            sharding_strategy=config.sharding_strategy,
            mixed_precision=config.mixed_precision,
            wrap_strategy=config.generator_fsdp_wrap_strategy,
            cpu_offload=False
        )
        # Print this device's currently allocated GPU memory (MB)
        allocated = torch.cuda.memory_allocated() / 1024**2
        print(f"[after gen model R{os.environ.get('RANK', '?')}] allocated: {allocated:.1f} MB")
        print(f"gen {self.model.generator}")

        self.model.real_score = fsdp_wrap(
            self.model.real_score,
            sharding_strategy=config.sharding_strategy,
            mixed_precision=config.mixed_precision,
            wrap_strategy=config.real_score_fsdp_wrap_strategy,
            cpu_offload=False
        )
        # Print this device's currently allocated GPU memory (MB)
        allocated = torch.cuda.memory_allocated() / 1024**2
        print(f"[after real model R{os.environ.get('RANK', '?')}] allocated: {allocated:.1f} MB")

        self.model.fake_score = self.model.fake_score.float()
        self.model.fake_score = fsdp_wrap(
            self.model.fake_score,
            sharding_strategy=config.sharding_strategy,
            mixed_precision=config.mixed_precision,
            wrap_strategy=config.fake_score_fsdp_wrap_strategy,
            cpu_offload=False
        )
        allocated = torch.cuda.memory_allocated() / 1024**2
        print(f"[after fake model R{os.environ.get('RANK', '?')}] allocated: {allocated:.1f} MB")

        print("------------ We begin to fsdp the reward model")
        self.model.reward_model.inferencer.model = self.model.reward_model.inferencer.model.to(torch.bfloat16)
        # The reward model reuses the generator's FSDP wrap strategy.
        self.model.reward_model.inferencer.model = fsdp_wrap(
            self.model.reward_model.inferencer.model,
            sharding_strategy=config.sharding_strategy,
            mixed_precision=config.mixed_precision,
            wrap_strategy=config.generator_fsdp_wrap_strategy,
            cpu_offload=False
        )
        allocated = torch.cuda.memory_allocated() / 1024**2
        print(f"[after reward model R{os.environ.get('RANK', '?')}] allocated: {allocated:.1f} MB")

        # for online vis and reward
        self.model.vae = self.model.vae.to(device=self.device, dtype=self.dtype)
        allocated = torch.cuda.memory_allocated() / 1024**2
        print(f"[after vae model R{os.environ.get('RANK', '?')}] allocated: {allocated:.1f} MB")

        # Resolve the critic learning rate once so the optimizer and the log
        # line below agree, and a missing ``lr_critic`` falls back to ``lr``.
        lr_critic = config.lr_critic if hasattr(config, "lr_critic") else config.lr

        self.generator_optimizer = torch.optim.AdamW(
            [param for param in self.model.generator.parameters() if param.requires_grad],
            lr=config.lr,  # e.g. 2.0e-06
            betas=(config.beta1, config.beta2),
            foreach=False,
            weight_decay=config.weight_decay
        )
        self.critic_optimizer = torch.optim.AdamW(
            [param for param in self.model.fake_score.parameters() if param.requires_grad],
            lr=lr_critic,  # e.g. 4.0e-07
            betas=(config.beta1_critic, config.beta2_critic),
            foreach=False,
            weight_decay=config.weight_decay
        )
        print(f"config.lr = {config.lr} weight_decay = {config.weight_decay}")
        print(f"config.lr_critic = {lr_critic}")

        print("------------ We begin to initialize all the dataset")
        # Step 3: Initialize the dataloader
        # 1. train dataloader: ranks of one group share the same sample
        train_dataset = ShardingLMDB_T2V_Dataset(config.train_data_path, max_pair=int(1e8))
        train_sampler = GroupSharedSampler(train_dataset, group_size=self.group_size, shuffle=True, drop_last=True)
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset, batch_size=config.batch_size,  # e.g. 1
            sampler=train_sampler, num_workers=0)

        # 2. train_eval dataloader, used for eval when training, only 128 prompts
        train_eval_dataset = ShardingLMDB_T2V_Dataset(config.train_eval_data_path, max_pair=int(1e8))
        train_eval_sampler = torch.utils.data.distributed.DistributedSampler(train_eval_dataset, shuffle=False, drop_last=False)
        train_eval_dataloader = torch.utils.data.DataLoader(
            train_eval_dataset, batch_size=config.batch_size,  # e.g. 1
            sampler=train_eval_sampler, num_workers=0)

        if dist.get_rank() == 0:
            print("DATASET SIZE %d" % len(train_dataset))
        self.dataloader = cycle(train_dataloader, train_sampler)
        self.train_eval_dataloader = train_eval_dataloader

        ##############################################################################################################
        # Set up EMA parameter containers
        self.ema_weight = config.get("ema_weight", 0.99)  # e.g. 0.99
        self.ema_start_step = config.get("ema_start_step", 0)  # e.g. 400
        # When the start step has not been reached yet, the EMA is created
        # later inside train().
        self.generator_ema = None
        if (self.ema_weight > 0.0) and (self.step >= self.ema_start_step):
            print(f"Setting up EMA with weight {self.ema_weight}")
            self.generator_ema = EMA_FSDP(self.model.generator, decay=self.ema_weight)

        self.max_grad_norm_generator = getattr(config, "max_grad_norm_generator", 5.0)
        print(f"self.max_grad_norm_generator {self.max_grad_norm_generator}")
        self.max_grad_norm_critic = getattr(config, "max_grad_norm_critic", 5.0)
        print(f"self.max_grad_norm_critic {self.max_grad_norm_critic}")
        # Wall-clock time of the previous iteration, used for timing logs.
        self.previous_time = None

    def save(self):
        """Collect the model weights and write a checkpoint from the main rank.

        The generator and critic (fake_score) state dicts are obtained through
        ``fsdp_state_dict`` on every rank (presumably a collective gather, so
        all ranks must call this method). The EMA weights are included once
        ``self.step`` has passed ``ema_start_step`` and EMA is enabled. Only
        the main process writes
        ``<output_path>/checkpoint_model_<step:06d>/model.pt``.
        """
        print("Start gathering distributed model states...")
        checkpoint = {
            "generator": fsdp_state_dict(self.model.generator),
            "critic": fsdp_state_dict(self.model.fake_score),
        }

        ema_enabled = self.ema_weight > 0.0
        ema_started = self.ema_start_step < self.step
        if ema_enabled and ema_started:
            checkpoint["generator_ema"] = self.generator_ema.state_dict()

        # Only rank 0 touches the file system.
        if not self.is_main_process:
            return

        ckpt_dir = os.path.join(self.output_path, f"checkpoint_model_{self.step:06d}")
        ckpt_file = os.path.join(ckpt_dir, "model.pt")
        os.makedirs(ckpt_dir, exist_ok=True)
        torch.save(checkpoint, ckpt_file)
        print("Model saved to", ckpt_file)

    def fwdbwd_one_step(self, batch, train_generator):
        self.model.eval()  # prevent any randomness (e.g. dropout)

        if self.is_main_process:
            print(f"------------[TRAIN] We begin to one batch with train_generator {train_generator} + fake model")

        if self.step % 5 == 0:
            torch.cuda.empty_cache()

        # Step 1: Get the next batch of text prompts
        noise_shape = batch["noise_shape"][0].numpy().tolist()
        noise_shape = tuple(map(int, noise_shape))  # 1, 21, 16, tgt_h // 8, tgt_w // 8
        text_feature = batch['text_feature'][0].to(device=self.device, dtype=self.dtype)
        neg_text_feature = batch['neg_text_feature'][0].to(device=self.device, dtype=self.dtype)
        prompt = batch['prompt']

        conditional_dict = dict()
        conditional_dict['prompt_embeds'] = text_feature
        unconditional_dict = dict()
        unconditional_dict['prompt_embeds'] = neg_text_feature

        # Step 2: Extract the conditional infos get from batch

        # Step 3: Store gradients for the generator (if training the generator)
        if train_generator:
            if self.is_main_process:
                print(f"------------[TRAIN] We begin to train train_generator")
            generator_loss, generator_log_dict = self.model.generator_loss(
                prompt = prompt,
                image_or_video_shape=noise_shape,
                conditional_dict=conditional_dict, unconditional_dict=unconditional_dict,
            )
            generator_loss /= self.GAS
            generator_loss.backward()
            generator_log_dict.update({"generator_loss": generator_loss})

            return generator_log_dict

        # Step 4: Store gradients for the critic (if training the critic)
        critic_loss, critic_log_dict = self.model.critic_loss(
            prompt = prompt,
            image_or_video_shape=noise_shape,
            conditional_dict=conditional_dict, unconditional_dict=unconditional_dict,
        )

        critic_loss.backward()
        critic_grad_norm = self.model.fake_score.clip_grad_norm_(self.max_grad_norm_critic)
        critic_log_dict.update({"critic_loss": critic_loss, "critic_grad_norm": critic_grad_norm})

        return critic_log_dict

    def train(self):
        """Run the training loop.

        Each iteration optionally performs a generator update (with gradient
        accumulation over ``GAS`` micro-batches), always performs one critic
        update, then handles EMA creation, checkpointing, periodic evaluation
        on the train_eval set, logging and garbage collection.

        The ``while True`` loop has no exit condition, so this method does not
        return on its own.
        """
        start_step = self.step

        if self.is_main_process:
            print(f"Update Gen Every {self.config.dfake_gen_update_ratio} steps")
        
        while True:
            if self.is_main_process:
                print(f"training step {self.step} ...")
            # The generator is trained on the last step of every
            # dfake_gen_update_ratio-step cycle; the critic on every step.
            TRAIN_GENERATOR = self.step % self.config.dfake_gen_update_ratio == self.config.dfake_gen_update_ratio-1 # one generator update per dfake_gen_update_ratio steps
            self.model.step = self.step

            if TRAIN_GENERATOR:
                # Train the generator + critic
                if self.is_main_process:
                    print("------------[TRAIN] We begin to train geneartor + fake model")
                self.generator_optimizer.zero_grad(set_to_none=True)
                step_dicts = []
                # Accumulate gradients over GAS micro-batches; fwdbwd_one_step
                # already divides each micro-batch loss by GAS.
                for gas_idx in range(self.GAS):
                    batch = next(self.dataloader)
                    generator_log_dict_gas = self.fwdbwd_one_step(batch, True)
                    step_dicts.append(generator_log_dict_gas)
                # Clip once on the accumulated gradients, then step.
                generator_grad_norm = self.model.generator.clip_grad_norm_(self.max_grad_norm_generator)
                self.generator_optimizer.step()
                generator_log_dict = merge_step_dicts(step_dicts)
                generator_log_dict["generator_grad_norm"] = generator_grad_norm.item()
                if self.generator_ema is not None:
                    self.generator_ema.update(self.model.generator)
                
            
            # Train the critic (gradient clipping happens inside fwdbwd_one_step)
            if self.is_main_process:
                print("------------[TRAIN] We begin to train fake model")
            self.critic_optimizer.zero_grad(set_to_none=True)
            batch = next(self.dataloader)
            critic_log_dict = self.fwdbwd_one_step(batch, False)
            self.critic_optimizer.step()

            # Increment the step since we finished gradient update
            self.step += 1

            # Create EMA params (if not already created)
            if (self.step >= self.ema_start_step) and (self.generator_ema is None) and (self.ema_weight > 0):
                self.generator_ema = EMA_FSDP(self.model.generator, decay=self.ema_weight)

            # Save the model every log_iters steps (never on the very first
            # step of this run) unless saving is disabled.
            if (not self.config.no_save) and (self.step - start_step) > 0 and self.step % self.config.log_iters == 0:
                torch.cuda.empty_cache()
                self.save()
                torch.cuda.empty_cache()
            # NOTE(review): hard-coded extra checkpoint at step 81. It ignores
            # config.no_save and, if log_iters divides 81, saves a second time
            # right after the periodic save — looks like a debugging leftover;
            # confirm whether it is still wanted.
            if self.step == 81:
                torch.cuda.empty_cache()
                self.save()
                torch.cuda.empty_cache()

            # Gen videos on train_eval dataset (every 10 steps, hard-coded)
            reward_log_dict = None
            if self.step % 10 == 0:   
                self.model.eval()  
                # Running sums of the per-batch rewards on this rank.
                reward_dict = {
                    "eval_avg_reward": 0.0,
                    "eval_mq_reward": 0.0,
                    "eval_vq_reward": 0.0,
                    "eval_ta_reward": 0.0,
                    "eval_avg_reward_mv": 0.0
                }
                total_batches = 0
                with torch.no_grad():
                    for batch in self.train_eval_dataloader:
                        cur_reward_dict = self.gen_vids(batch, total_batches)   

                        reward_dict["eval_avg_reward"] += cur_reward_dict["avg_reward"]
                        reward_dict["eval_mq_reward"] += cur_reward_dict["mq_reward"]
                        reward_dict["eval_vq_reward"] += cur_reward_dict["vq_reward"]
                        reward_dict["eval_ta_reward"] += cur_reward_dict["ta_reward"]
                        reward_dict["eval_avg_reward_mv"] += cur_reward_dict["avg_reward_mv"]
                        total_batches += 1
                # NOTE(review): raises ZeroDivisionError if this rank's eval
                # dataloader yields no batches.
                for k in reward_dict:
                    reward_dict[k] /= total_batches
                # reward_dict now holds this rank's per-batch averages;
                # reduce_dict presumably combines them across ranks — confirm
                # against utils.distributed.
                reward_log_dict = reduce_dict(reward_dict, self.device)


            ## For Logging
            if TRAIN_GENERATOR:
                generator_log_dict = reduce_dict(generator_log_dict, self.device)

            critic_log_dict = reduce_dict(critic_log_dict, self.device)

            # Logging (main process only)
            if self.is_main_process:
                wandb_loss_dict = {}
                if TRAIN_GENERATOR:
                    wandb_loss_dict.update(generator_log_dict)
                if reward_log_dict:
                    wandb_loss_dict.update(reward_log_dict)
                wandb_loss_dict.update(critic_log_dict)

                time_str = datetime.datetime.now().strftime('%m-%d %H:%M:%S')
                print(f"[{time_str}] step={self.step}", end='')
                for k, v in wandb_loss_dict.items():
                    print(f' | {k}: {v:.6f}', end='')
                if not self.disable_wandb:
                    wandb.log(wandb_loss_dict, step=self.step)

            # Periodic Python GC plus CUDA cache flush on every rank.
            if self.step % self.config.gc_interval == 0:
                if dist.get_rank() == 0:
                    logging.info("DistGarbageCollector: Running GC.")
                gc.collect()
                torch.cuda.empty_cache()

            # Per-iteration wall-clock timing (main process only). On the
            # first iteration only the timestamp is recorded, so the log line
            # printed above is not terminated with a newline.
            if self.is_main_process:
                current_time = time.time()
                if self.previous_time is None:
                    self.previous_time = current_time
                else:
                    if not self.disable_wandb:
                        wandb.log({"per iteration time": current_time - self.previous_time}, step=self.step)
                    print(f' | per iteration time: {(current_time - self.previous_time):.6f}', end='')
                    print()  # terminate the log line
                    self.previous_time = current_time
    
    def gen_vids(self, batch, batch_idx):
        if self.is_main_process:
            print(f"------------[EVAL TRAIN] We begin eval generator on train_eval dataset")

        # Step 1: Get the next batch of text prompts
        noise_shape = batch["noise_shape"][0].numpy().tolist()
        noise_shape = tuple(map(int, noise_shape))  # 1, 21, 16, tgt_h // 8, tgt_w // 8
        text_feature = batch['text_feature'][0].to(device=self.device, dtype=self.dtype)
        neg_text_feature = batch['neg_text_feature'][0].to(device=self.device, dtype=self.dtype)
        prompt = batch['prompt']

        clean_latent = None
        conditional_dict = dict()
        conditional_dict['prompt_embeds'] = text_feature
        unconditional_dict = dict()
        unconditional_dict['prompt_embeds'] = neg_text_feature

        # Step 2: Gen videos
        video_path = self.model.gen_video(
            idx = batch_idx, prompt = prompt, image_or_video_shape=noise_shape,
            conditional_dict=conditional_dict, unconditional_dict=unconditional_dict,
        )

        # Step 3: reward
        all_reward = self.model.reward_model.inferencer.reward([video_path], prompt, use_norm=False)[0]
        avg_reward = (all_reward['MQ'] + all_reward["VQ"] + all_reward["TA"]) / 3
        avg_reward_mv = 0.1 * all_reward['MQ'] + 0.1 * all_reward["VQ"] + 0.8  * all_reward["TA"]
        reward_log_dict = {
            "avg_reward" : avg_reward,
            "mq_reward" : all_reward['MQ'],
            "vq_reward" : all_reward['VQ'],
            "ta_reward" : all_reward['TA'],
            "avg_reward_mv" : avg_reward_mv
        }
        return reward_log_dict