# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import math
from typing import Any, Dict, Mapping, Optional, Tuple

import attrs
import torch
import os
from einops import rearrange
from megatron.core import parallel_state
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import DTensor
from torch.nn.modules.module import _IncompatibleKeys
from cosmos_predict2.utils.checkpointer import non_strict_load_model
from cosmos_predict2.utils.dtensor_helper import broadcast_dtensor_model_states
from cosmos_predict2.utils.optim_instantiate import get_base_scheduler
from cosmos_predict2.utils.torch_future import clip_grad_norm_
from imaginaire.lazy_config import LazyDict, instantiate
from imaginaire.model import ImaginaireModel
from imaginaire.utils import log
from wan.modules.block_attention import get_flex_causal_block_mask_for_tf_training
try:
    import wan
except ImportError:
    raise ImportError("WAN is not installed. Please set PYTHONPATH to the root of the WAN repository")
from wan.modules.model_causal import WanModelCausal
from wan.modules.vae2_2 import Wan2_2_VAE
from wan.utils.fm import FlowMatchScheduler



@attrs.define(slots=False)
class WanWarpedConfig:
    """Configuration consumed by ``WanWarpedModel``.

    Every field has a default, so ``WanWarpedConfig()`` is a valid configuration.
    """

    # Compute dtype for the DiT and autocast; one of "float32", "float16", "bfloat16".
    precision: str = "bfloat16"
    # Loss reduction over elements: "mean" or "sum" (sum over C,T,H,W then mean over batch).
    loss_reduce: str = "mean"
    # Multiplier applied to the reduced loss.
    loss_scale: float = 10.0
    # Passed to ``WanModelCausal.from_pretrained`` to build the DiT.
    dit_pretrain_path: str = "Wan2.2/Wan2.2-TI2V-5B"
    # Optional torch-saved DiT state dict, loaded with strict=True on top of the pretrained
    # weights. None keeps the ``from_pretrained`` weights.
    dit_pt_path: Optional[str] = None
    # Checkpoint path passed to ``Wan2_2_VAE(vae_pth=...)``.
    vae_pretrain_path: str = "Wan2.2/Wan2.2-TI2V-5B/Wan2.2_VAE.pth"
    # Copied onto ``dit.gradient_checkpoint``.
    gradient_checkpoint: bool = True
    # True: scheduler.set_timesteps (shifted schedule); False: set_timesteps_no_shift.
    using_shift: bool = True

    # Debug flag (not referenced by the model code visible alongside this config).
    debug_without_randomness: bool = False
    # FSDP shard size: 0 disables FSDP, -1 shards across the whole world size,
    # N > 0 shards over min(N, world_size) ranks and replicates across the rest.
    fsdp_shard_size: int = 8
    # High-sigma sampling strategy ratio (not referenced by the model code visible
    # alongside this config).
    high_sigma_ratio: float = 0.0


class WanWarpedModel(ImaginaireModel):
    """Rectified-flow trainer for a causal WAN 2.2 DiT.

    Pixel videos are encoded to latents by the WAN 2.2 VAE without gradients. The latents
    are mixed with Gaussian noise at a ``FlowMatchScheduler`` sigma, and the DiT is trained
    to predict the velocity ``noise - latent``. When ``torch.distributed`` is initialized
    and ``config.fsdp_shard_size`` is not 0, the DiT is sharded with FSDP2 ``fully_shard``.
    """

    def __init__(self, config: WanWarpedConfig):
        """Build the scheduler, VAE and DiT, and optionally shard the DiT with FSDP.

        Args:
            config (WanWarpedConfig): Model / training configuration.

        Raises:
            ValueError: If ``config.loss_reduce`` is not "mean" or "sum", or if the world
                size is not divisible by the effective FSDP shard size.
        """
        super().__init__()

        self.config = config
        self.precision = {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
        }[config.precision]
        self.tensor_kwargs = {"device": "cuda", "dtype": self.precision}
        self.device = torch.device("cuda")

        # Flow-matching noise schedule with 1000 training timesteps.
        self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
        if config.using_shift:
            log.info("Using shift in WanWarpedModel")
            self.scheduler.set_timesteps(1000, training=True)
        else:
            log.info("without shift in WanWarpedModel")
            self.scheduler.set_timesteps_no_shift(1000, training=True)

        # Loss options: reduction mode and scaling.
        self.loss_reduce = getattr(config, "loss_reduce", "mean")
        if self.loss_reduce not in ("mean", "sum"):
            raise ValueError(f"Invalid loss_reduce: {self.loss_reduce}")
        self.loss_scale = getattr(config, "loss_scale", 1.0)
        log.critical(f"Using {self.loss_reduce} loss reduce with loss scale {self.loss_scale}")

        # Training states.
        if parallel_state.is_initialized():
            self.data_parallel_size = parallel_state.get_data_parallel_world_size()
        else:
            self.data_parallel_size = 1

        self.vae = Wan2_2_VAE(vae_pth=config.vae_pretrain_path, device=self.device)
        self.dit = WanModelCausal.from_pretrained(config.dit_pretrain_path)
        if config.dit_pt_path is not None:
            log.info(f"Loading DIT model from {config.dit_pt_path}")
            self.dit.load_state_dict(torch.load(config.dit_pt_path, map_location='cpu'), strict=True)
        else:
            log.info("No DIT pretrain path provided, using the default DIT model")
        self.dit = self.dit.to(**self.tensor_kwargs)
        self.dit.gradient_checkpoint = config.gradient_checkpoint
        # Block-causal teacher-forcing attention mask. It is built lazily in
        # ``training_step`` and cached together with the (T, H, W) it was built for, so a
        # batch with a different latent shape gets a fresh mask instead of a stale one.
        self.attention_mask = None
        self._attention_mask_key = None
        self.dit.train()
        self.dit.requires_grad_(True)

        total_params = sum(p.numel() for p in self.parameters())
        frozen_params = sum(p.numel() for p in self.parameters() if not p.requires_grad)
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        # Total in billions; frozen/trainable with thousands separators.
        log.info(
            f"Total parameters: {total_params / 1e9:.2f}B, Frozen parameters: {frozen_params:,}, Trainable parameters: {trainable_params:,}"
        )

        # 0 disables FSDP; -1 means "shard across the whole world" (previously the -1 case
        # was unreachable because the guard only accepted values > 0).
        use_fsdp = config.fsdp_shard_size > 0 or config.fsdp_shard_size == -1
        if torch.distributed.is_initialized() and use_fsdp:
            world_size = torch.distributed.get_world_size()
            if config.fsdp_shard_size == -1:
                fsdp_shard_size = world_size
            else:
                fsdp_shard_size = min(config.fsdp_shard_size, world_size)
            # The 2-D mesh must cover every rank exactly once.
            if world_size % fsdp_shard_size != 0:
                raise ValueError(
                    f"World size {world_size} is not divisible by fsdp_shard_size {fsdp_shard_size}"
                )
            replica_group_size = world_size // fsdp_shard_size
            dp_mesh = init_device_mesh(
                "cuda", (replica_group_size, fsdp_shard_size), mesh_dim_names=("replicate", "shard")
            )
            log.info(f"Using FSDP with shard size {fsdp_shard_size} | device mesh: {dp_mesh}")
            # NOTE(review): WanModelCausal.fully_shard presumably shards its sub-modules;
            # the root module is then wrapped here. Confirm against WanModelCausal.
            self.dit.fully_shard(mesh=dp_mesh)
            self.dit = fully_shard(self.dit, mesh=dp_mesh, reshard_after_forward=True)
            # Make every rank start from identical weights.
            broadcast_dtensor_model_states(self.dit, dp_mesh)
        else:
            log.info("FSDP (Fully Sharded Data Parallel) is disabled.")

    # New function, added for i4 adaption
    @property
    def net(self) -> torch.nn.Module:
        """The trainable DiT (possibly FSDP-wrapped)."""
        return self.dit

    # New function, added for i4 adaption
    @property
    def net_ema(self) -> torch.nn.Module:
        """EMA copy of the DiT; not supported by this model.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    # New function, added for i4 adaption
    def init_optimizer_scheduler(
        self, optimizer_config: LazyDict, scheduler_config: LazyDict
    ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
        """Create the optimizer and scheduler for the DiT.

        Args:
            optimizer_config (LazyDict): Lazy config instantiated with ``model=self.net``.
            scheduler_config (LazyDict): Config passed to ``get_base_scheduler``.

        Returns:
            optimizer (torch.optim.Optimizer): The model optimizer.
            scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
        """
        optimizer = instantiate(optimizer_config, model=self.net)
        scheduler = get_base_scheduler(optimizer, self, scheduler_config)
        return optimizer, scheduler

    # ------------------------ training hooks ------------------------
    def on_before_zero_grad(
        self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int
    ) -> None:
        """Hook called before gradients are zeroed.

        This would be where the EMA weights are updated. EMA is not implemented for this
        model (``net_ema`` raises), so the hook is a no-op.
        """

    # New function, added for i4 adaption
    def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None:
        """Hook called at the start of training; intentionally a no-op.

        The DiT is already moved to ``self.tensor_kwargs`` in ``__init__``.
        ``memory_format`` is accepted for interface compatibility and ignored.
        """
        return

    def get_per_sigma_loss_weights(self, sigma: torch.Tensor) -> torch.Tensor:
        """
        Args:
            sigma (tensor): noise level

        Returns:
            loss weights per sigma noise level
        """
        # NOTE(review): this class defines no ``self.pipe``, so calling this raises
        # AttributeError. It appears unused by the training steps below.
        return (sigma**2 + self.pipe.sigma_data**2) / (sigma * self.pipe.sigma_data) ** 2

    # ------------------------ training helpers ------------------------
    def _encode_video(self, data_batch: dict) -> torch.Tensor:
        """Map ``data_batch['video']`` from [0, 255] to [-1, 1] and VAE-encode it without gradients.

        ``data_batch['video']`` is presumably uint8 of shape (B, 3, T, H, W) — TODO confirm.
        Returns the per-sample latents stacked into a 5-D tensor.
        """
        video_pix = data_batch['video'].float() / 127.5 - 1.0
        with torch.no_grad():
            # Encoding sample by sample (list input) rather than via the batched
            # ``self.vae.model.encode`` path; earlier notes recorded ~22G vs ~38G VRAM.
            return torch.stack(self.vae.encode(list(video_pix)))

    def _sample_noise_and_timestep(
        self, video_latent: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Draw noise and one training timestep per sample.

        Returns:
            noise: ``randn_like(video_latent)``.
            timestep: (B, T * H * W // 4) scheduler timesteps, one per latent token. The
                ``// 4`` presumably reflects a 2x2 spatial patchify in the DiT — TODO confirm.
            sigma: (B, 1, 1, 1, 1) noise level matching each sample's timestep.
        """
        batch, _, frames, height, width = video_latent.shape
        noise = torch.randn_like(video_latent)
        timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (batch,))
        timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.precision, device=self.device)
        timestep = timestep.view(batch, 1).repeat(1, frames * height * width // 4)
        sigma = self.scheduler.sigmas[timestep_id].to(dtype=self.precision, device=self.device)
        return noise, timestep, sigma.view(batch, 1, 1, 1, 1)

    def _reduce_loss(self, loss: torch.Tensor) -> torch.Tensor:
        """Reduce a per-element (B, C, T, H, W) loss per ``self.loss_reduce`` and scale it.

        Raises:
            ValueError: If ``self.loss_reduce`` is neither "mean" nor "sum".
        """
        if self.loss_reduce == "mean":
            return loss.mean() * self.loss_scale
        if self.loss_reduce == "sum":
            # Sum within each sample, then average over the batch.
            return loss.sum(dim=(1, 2, 3, 4)).mean() * self.loss_scale
        raise ValueError(f"Invalid loss_reduce: {self.loss_reduce}")

    def _get_attention_mask(self, num_frames: int, height: int, width: int):
        """Return the teacher-forcing block mask for this latent shape, rebuilding it when the shape changes."""
        key = (num_frames, height, width)
        if self.attention_mask is None or self._attention_mask_key != key:
            # Clean frames and noisy frames are concatenated, hence the factor 2.
            seq_len = 2 * num_frames * height * width // 4
            self.attention_mask = get_flex_causal_block_mask_for_tf_training(
                seq_len, num_frames, height * width // 4, self.device
            )
            self._attention_mask_key = key
        return self.attention_mask

    # ------------------------ training steps ------------------------
    def training_step_full(self, data_batch: dict, data_batch_idx: int) -> tuple[dict, torch.Tensor]:
        """Plain rectified-flow step: denoise the whole noisy clip in a single pass.

        Args:
            data_batch (dict): Needs ``'video'`` and ``'umt5_embedding'``.
            data_batch_idx (int): Unused.

        Returns:
            ``({'loss', 'no_scale_loss'}, loss)``. ``no_scale_loss`` is the unscaled
            element mean, as a Python float.
        """
        umt5_embedding = data_batch['umt5_embedding']
        video_latent = self._encode_video(data_batch)

        noise, timestep, sigma = self._sample_noise_and_timestep(video_latent)
        noisy_latent = (1 - sigma) * video_latent + sigma * noise
        target = noise - video_latent  # rectified-flow velocity
        with torch.autocast(device_type="cuda", dtype=self.precision):
            pred = self.dit(noisy_latent, timestep, umt5_embedding)

        per_element_loss = torch.nn.functional.mse_loss(pred, target, reduction='none')
        no_scale_loss = per_element_loss.detach().mean().item()
        rectified_flow_loss = self._reduce_loss(per_element_loss)

        output_batch = {
            'loss': rectified_flow_loss,
            "no_scale_loss": no_scale_loss,
        }
        return output_batch, rectified_flow_loss

    def training_step(self, data_batch: dict, data_batch_idx: int) -> tuple[dict, torch.Tensor]:
        """Teacher-forcing rectified-flow step for the causal DiT.

        Clean latents and their noisy counterparts are concatenated along the time axis
        (dim 2). The clean half is fed at timestep 0, and only the noisy half contributes
        to the loss. With ``loss_reduce == "mean"`` the average also counts the zeroed
        clean half.

        Args:
            data_batch (dict): Needs ``'video'`` and ``'umt5_embedding'``.
            data_batch_idx (int): Unused.

        Returns:
            ``({'loss': loss}, loss)``.
        """
        umt5_embedding = data_batch['umt5_embedding']
        video_latent = self._encode_video(data_batch)
        _, _, T, H, W = video_latent.shape

        noise, timestep, sigma = self._sample_noise_and_timestep(video_latent)
        # Give sigma a T axis so the loss mask built from it has the right time length.
        sigma = sigma.repeat(1, 1, T, 1, 1)
        noisy_latent = (1 - sigma) * video_latent + sigma * noise
        input_latent = torch.cat([video_latent, noisy_latent], dim=2)
        # 0 on the clean (first T) frames, 1 on the noisy (last T) frames.
        loss_mask = torch.cat([torch.zeros_like(sigma), torch.ones_like(sigma)], dim=2)

        attention_mask = self._get_attention_mask(T, H, W)
        timestep = torch.cat([torch.zeros_like(timestep), timestep], dim=1)
        with torch.autocast(device_type="cuda", dtype=self.precision):
            pred = self.dit(input_latent, timestep, umt5_embedding, attention_mask, train=True)

        target = torch.cat([video_latent, noise - video_latent], dim=2)
        per_element_loss = torch.nn.functional.mse_loss(pred * loss_mask, target * loss_mask, reduction='none')
        rectified_flow_loss = self._reduce_loss(per_element_loss)

        output_batch = {
            'loss': rectified_flow_loss,
        }
        return output_batch, rectified_flow_loss

    # ------------------ Checkpointing ------------------

    def state_dict(self) -> Dict[str, Any]:
        """Return the DiT state dict with every tensor materialized on CPU.

        FSDP ``DTensor`` entries are gathered to full tensors. Values are replaced in
        place so the OrderedDict (and its ``_metadata``) returned by the DiT is kept.
        EMA weights are not included.
        """
        net_state_dict = self.dit.state_dict()
        for key, val in net_state_dict.items():
            if isinstance(val, DTensor):
                val = val.full_tensor()
            net_state_dict[key] = val.detach().cpu()
        return net_state_dict

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
        """Load a DiT state dict (as produced by ``state_dict``) into ``self.dit``.

        Args:
            state_dict (Mapping[str, Any]): DiT parameters/buffers keyed like ``self.dit.state_dict()``.
            strict (bool, optional): If True, delegate to ``self.dit.load_state_dict`` with
                ``strict=True``, which raises on missing/unexpected keys. If False, use
                ``non_strict_load_model`` and log its report instead of raising.
                Defaults to True.
            assign (bool, optional): Forwarded to ``self.dit.load_state_dict`` in strict
                mode only. Defaults to False.

        Returns:
            ``_IncompatibleKeys`` in strict mode; None in non-strict mode.
        """
        if strict:
            reg_results: _IncompatibleKeys = self.dit.load_state_dict(state_dict, strict=strict, assign=assign)
            return _IncompatibleKeys(
                missing_keys=reg_results.missing_keys,
                unexpected_keys=reg_results.unexpected_keys,
            )
        log.critical("load model in non-strict mode")
        log.critical(non_strict_load_model(self.dit, state_dict), rank0_only=False)

    # ------------------ public methods ------------------
    def ema_beta(self, iteration: int) -> float:
        """
        Calculate the beta value for EMA update.
        weights = weights * beta + (1 - beta) * new_weights

        Args:
            iteration (int): Current iteration number.

        Returns:
            float: The calculated beta value.
        """
        # NOTE(review): WanWarpedConfig has no ``pipe_config`` and this class has no
        # ``self.pipe``, so this raises AttributeError as written; EMA is unsupported here.
        iteration = iteration + self.config.pipe_config.ema.iteration_shift
        if iteration < 1:
            return 0.0
        return (1 - 1 / (iteration + 1)) ** (self.pipe.ema_exp_coefficient + 1)

    def clip_grad_norm_(
        self,
        max_norm: float,
        norm_type: float = 2.0,
        error_if_nonfinite: bool = False,
        foreach: Optional[bool] = None,
    ) -> torch.Tensor:
        """Clip the DiT's gradient norm in place and return the total norm (DTensor-aware helper)."""
        return clip_grad_norm_(
            self.net.parameters(),
            max_norm,
            norm_type=norm_type,
            error_if_nonfinite=error_if_nonfinite,
            foreach=foreach,
        )
