import mediapy as media
import gc
import logging
import math
import os
import shutil
from typing import Any, Callable, Dict, List, Optional, Union
from datetime import timedelta, datetime
from pathlib import Path
from typing import Any, Dict
import argparse
import importlib
import json
import copy
import cv2
import diffusers
import torch
import transformers
import matplotlib.pyplot as plt
import torch.distributed as dist
import wandb
from diffusers.utils import export_to_video
from accelerate import Accelerator, DistributedType, init_empty_weights
from accelerate.logging import get_logger
from accelerate.utils import (
    DistributedDataParallelKwargs,
    InitProcessGroupKwargs,
    ProjectConfiguration,
    set_seed,
    broadcast_object_list
)
import numpy as np
from wan.autoencoder import AutoencoderKLWan
from wan.scheduler import FlowMatchEulerDiscreteScheduler
from wan.pipelines import WanPipeline
from s3_ar.wan.transformer_joint_s3 import WanTransformer3DModel
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from diffusers.optimization import get_scheduler
from diffusers.training_utils import cast_training_params, compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, UMT5EncoderModel
from args import get_args
from text_encoder import compute_prompt_embeddings
from dataloader.dataset_human_load_tensor import VideoParquetDataset
from utils import (
    get_gradient_norm,
    get_optimizer,
    prepare_rotary_positional_embeddings,
    print_memory,
    reset_memory,
    unwrap_model,
    put_object_from_file,
    sample_timesteps_with_group_new
)
from collections import OrderedDict

def log_per_layer_grad_norms(model, global_step, output_dir, total_grad_norm, logger, threshold=0.3, top_k=None):
    """Append per-parameter gradient norms to ``per_layer_grad_norms.log`` when the total is large.

    Nothing is computed or written unless ``total_grad_norm`` is strictly greater
    than ``threshold``.

    Args:
        model: Module whose ``named_parameters()`` are inspected; parameters with
            no ``.grad`` are skipped.
        global_step: Training step written in the report header.
        output_dir: Directory receiving ``per_layer_grad_norms.log`` (created if missing).
        total_grad_norm: Overall gradient norm for this step (a float).
        logger: Logger used to announce where the report was written.
        threshold: Only report when ``total_grad_norm`` exceeds this value.
        top_k: If given, write only the ``top_k`` largest norms; ``None`` writes all.

    Returns:
        List of ``(name, norm)`` pairs sorted by descending norm, or ``None`` when
        the threshold was not exceeded.
    """
    if total_grad_norm <= threshold:
        return None

    layer_norms = [
        (name, param.grad.detach().norm(2).item())
        for name, param in model.named_parameters()
        if param.grad is not None
    ]
    layer_norms.sort(key=lambda item: item[1], reverse=True)
    selected = layer_norms if top_k is None else layer_norms[:top_k]

    os.makedirs(output_dir, exist_ok=True)
    grad_norm_file = os.path.join(output_dir, "per_layer_grad_norms.log")
    lines = [f"step {global_step}: total_grad_norm={total_grad_norm:.6f}\n"]
    lines.extend(f"  {name}: {norm:.6f}\n" for name, norm in selected)
    # Append so successive offending steps accumulate in one file.
    with open(grad_norm_file, "a", encoding="utf-8") as f:
        f.writelines(lines)

    logger.info(
        "Step %s: grad norm %.4f > %s, per-layer norms appended to %s",
        global_step, total_grad_norm, threshold, grad_norm_file,
    )
    return layer_norms

def save_batch_tensors_and_metadata(global_step, batch, output_dir, max_entries=10, expected_batch_size=2):
    """Save a batch's video/depth tensors to disk and record them in ``metadata.json``.

    Each sample is written as ``videos/video_step{step}_idx{i}.pt`` and
    ``depths/depth_step{step}_idx{i}.pt`` under ``output_dir``. Each successfully
    saved sample appends ``{"video_path", "depth_path", "prompts_text"}`` to the JSON
    list, which never grows beyond ``max_entries`` items.

    Args:
        global_step: Step number embedded in the file names.
        batch: Mapping with indexable ``"videos"`` and ``"depths"`` tensors and a
            ``"prompts_text"`` list of strings.
        output_dir: Root directory for tensors and metadata.
        max_entries: Upper bound on entries kept in ``metadata.json``.
        expected_batch_size: Required length of every batch field; other sizes are
            rejected with a warning (defaults to 2, the previously hard-coded value).

    Returns:
        True if at least one new entry was recorded. False when the metadata is
        already full, the batch size is unexpected, no tensor could be saved, or the
        JSON file could not be written.
    """
    json_path = os.path.join(output_dir, "metadata.json")
    existing_data = []
    if os.path.exists(json_path):
        with open(json_path, 'r', encoding='utf-8') as f:
            try:
                existing_data = json.load(f)
            except json.JSONDecodeError:
                existing_data = []
    if not isinstance(existing_data, list):
        # Valid JSON that is not a list cannot be extended; start a fresh list.
        existing_data = []

    if len(existing_data) >= max_entries:
        return False

    videos = batch["videos"]
    depths = batch["depths"]
    prompts_text = batch["prompts_text"]

    if (
        len(videos) != expected_batch_size
        or len(depths) != expected_batch_size
        or len(prompts_text) != expected_batch_size
    ):
        print(f"Warning: Expected batch_size={expected_batch_size}, got videos={len(videos)}, depths={len(depths)}, prompts={len(prompts_text)}")
        return False

    videos_dir = os.path.join(output_dir, "videos")
    depths_dir = os.path.join(output_dir, "depths")
    os.makedirs(videos_dir, exist_ok=True)
    os.makedirs(depths_dir, exist_ok=True)

    new_entries = []
    for idx in range(expected_batch_size):
        if len(existing_data) + len(new_entries) >= max_entries:
            break

        video_path = os.path.join(videos_dir, f"video_step{global_step}_idx{idx}.pt")
        depth_path = os.path.join(depths_dir, f"depth_step{global_step}_idx{idx}.pt")

        try:
            torch.save(videos[idx].cpu(), video_path)
            torch.save(depths[idx].cpu(), depth_path)
        except Exception as e:
            # Best effort: skip this sample but keep trying the rest.
            print(f"Error saving tensors for step {global_step}, idx {idx}: {e}")
            continue

        new_entries.append({
            "video_path": video_path,
            "depth_path": depth_path,
            "prompts_text": prompts_text[idx]
        })

    if not new_entries:
        # Every save failed, so nothing was recorded.
        return False

    existing_data.extend(new_entries)
    try:
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(existing_data, f, ensure_ascii=False, indent=2)
        print(f"Saved {len(new_entries)} entries for step {global_step} to {json_path}")
    except Exception as e:
        print(f"Error writing to JSON {json_path}: {e}")
        return False

    return True

def colorize_video_depth(depth_video, colormap="Spectral"):
    """Map a single-channel depth video to RGB frames using a matplotlib colormap.

    Args:
        depth_video: Array or tensor of shape (T, H, W). Values are passed straight
            to the colormap; presumably floats are expected in [0, 1], since
            matplotlib clips out-of-range values — TODO confirm against the caller.
        colormap: Name of a matplotlib colormap.

    Returns:
        Contiguous ``uint8`` array of shape (T, H, W, 3).

    Raises:
        ValueError: if the input is not three-dimensional.
    """
    if isinstance(depth_video, torch.Tensor):
        # detach(): .numpy() refuses tensors that require grad.
        depth_video = depth_video.detach().cpu()
        if depth_video.dtype == torch.bfloat16:
            # NumPy has no bfloat16, so .numpy() would raise; widen first.
            depth_video = depth_video.float()
        depth_video = depth_video.numpy()
    depth_video = np.asarray(depth_video)
    if depth_video.ndim != 3:
        raise ValueError(f"Expected a (T, H, W) depth video, got shape {depth_video.shape}")

    cmap = plt.get_cmap(colormap)
    # The colormap is applied element-wise, so one call over the whole clip yields
    # exactly the per-frame results; drop the alpha channel.
    return np.ascontiguousarray(cmap(depth_video, bytes=True)[..., :3])

class PathSimplifierFormatter(logging.Formatter):
    """Formatter that adds ``%(short_path)s``: the record's source path relative to the CWD."""

    def format(self, record):
        try:
            record.short_path = os.path.relpath(record.pathname)
        except ValueError:
            # On Windows relpath raises when the source file is on another drive
            # than the CWD; fall back to the full path instead of failing to log.
            record.short_path = record.pathname
        return super().format(record)

def setup_logger(log_directory, experiment_name, process_rank, source_module=__name__):
    """Configure root logging for this process and return the logger named ``source_module``.

    Every rank logs to stderr; rank 0 also writes ``<log_directory>/<experiment_name>.log``.
    Records are formatted as ``[time short_path:lineno] message`` by ``PathSimplifierFormatter``.

    Args:
        log_directory: Directory for the rank-0 log file; created if it does not exist.
        experiment_name: Base name of the log file.
        process_rank: Distributed rank of this process.
        source_module: Name of the logger to return.

    Note:
        ``logging.basicConfig`` does nothing when the root logger already has
        handlers, in which case the handlers built here are not attached.
    """
    handlers = [logging.StreamHandler()]
    if process_rank == 0:
        # FileHandler opens the file immediately; callers may pass a directory
        # that has not been created yet.
        os.makedirs(log_directory, exist_ok=True)
        log_file_path = os.path.join(log_directory, f"{experiment_name}.log")
        # Explicit UTF-8 so non-ASCII prompt text does not depend on the locale.
        handlers.append(logging.FileHandler(log_file_path, encoding="utf-8"))

    log_formatter = PathSimplifierFormatter(
        fmt='[%(asctime)s %(short_path)s:%(lineno)d] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    for handler in handlers:
        handler.setFormatter(log_formatter)

    logging.basicConfig(level=logging.INFO, handlers=handlers)
    return logging.getLogger(source_module)

def normalize_latents(
    latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
) -> torch.Tensor:
    """Centre latents on a per-channel mean and multiply them by ``latents_std``.

    The 1-D statistics are reshaped to (1, C, 1, 1, 1) so they broadcast over
    channel dim 1 of a 5-D latent tensor. The arithmetic runs in float32, and the
    result is cast back to the input's dtype and device.

    NOTE(review): the statistic is *multiplied*, not divided, so callers presumably
    pass the reciprocal of the standard deviation — confirm at the call sites.
    """
    broadcast_shape = (1, -1, 1, 1, 1)
    target_device = latents.device
    mean = latents_mean.view(broadcast_shape).to(device=target_device)
    scale = latents_std.view(broadcast_shape).to(device=target_device)
    centered = latents.float() - mean
    return (centered * scale).to(latents)

def log_validation(
    pipe: WanPipeline,
    args: Any,
    pipeline_args: Dict[str, Any],
    is_final_validation: bool = False,
    step: int = 0,
    fps: int = 16,
    device: Union[str, torch.device] = 'cuda',
    process_rank: int = 0
) -> List[str]:
    """Generate validation videos (RGB + depth), write them as mp4, and log them to wandb.

    Args:
        pipe: Pipeline whose call returns ``(video, video_depth)`` outputs exposing ``.frames``.
        args: Config object read via attributes: ``output_dir`` and ``num_validation_videos``.
        pipeline_args: Keyword arguments for ``pipe``; must contain a string ``"prompt"``.
        is_final_validation: Use the ``"test"`` prefix instead of ``"validation"``.
        step: Training step used in file names and as the wandb step.
        fps: Frame rate of the written videos.
        device: Device the pipeline and generator run on.
        process_rank: Only rank 0 logs to wandb.

    Returns:
        Paths of the written RGB videos (the depth videos sit next to them with a
        ``_depth`` suffix).
    """
    videos_dir = os.path.join(args.output_dir, "videos")
    os.makedirs(videos_dir, exist_ok=True)
    phase_name = "test" if is_final_validation else "validation"
    pipe = pipe.to(device)
    # Seeded with 0 on every call, so each validation round starts from the same random state.
    generator = torch.Generator(device=device).manual_seed(0)

    samples = []
    for _ in range(args.num_validation_videos):
        video, video_depth = pipe(**pipeline_args, generator=generator, output_type="np")
        samples.append((video.frames[0], video_depth.frames))

    # The filename slug depends only on the prompt, so build it once.
    prompt_slug = (
        pipeline_args["prompt"][:25]
        .replace(" ", "_")
        .replace("'", "_")
        .replace('"', "_")
        .replace("/", "_")
    )

    video_filenames = []
    for i, (rgb_frames, depth_frames) in enumerate(samples):
        filename = os.path.join(videos_dir, f"{phase_name}_step_{step}_video_{i}_{prompt_slug}.mp4")
        filename_depth = os.path.join(videos_dir, f"{phase_name}_step_{step}_video_{i}_{prompt_slug}_depth.mp4")
        media.write_video(filename_depth, colorize_video_depth(depth_frames), fps=fps)
        export_to_video(rgb_frames, filename, fps=fps)
        video_filenames.append(filename)
        if process_rank == 0:
            wandb.log({
                f"{phase_name}_video_{i}": wandb.Video(filename, fps=fps, format="mp4"),
                f"{phase_name}_depth_video_{i}": wandb.Video(filename_depth, fps=fps, format="mp4")
            }, step=step)
    return video_filenames

class CollateFunction:
    """Collate dataset samples into stacked tensors cast to ``weight_dtype``.

    Each sample is a dict with tensor entries ``"prompt"``, ``"video"`` and ``"depth"``
    (shapes must match across samples for ``torch.stack``) plus a ``"prompt_text"``
    entry, which is passed through unchanged as a list.
    """

    def __init__(self, weight_dtype: torch.dtype) -> None:
        self.weight_dtype = weight_dtype

    def _stack(self, data: List[Dict[str, Any]], key: str) -> torch.Tensor:
        """Stack ``sample[key]`` across ``data`` and cast to ``self.weight_dtype``."""
        return torch.stack([sample[key] for sample in data]).to(dtype=self.weight_dtype, non_blocking=True)

    def __call__(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Return ``videos``/``prompts``/``depths`` as batched tensors and ``prompts_text`` as a list."""
        return {
            "videos": self._stack(data, "video"),
            "prompts": self._stack(data, "prompt"),
            "prompts_text": [sample["prompt_text"] for sample in data],
            "depths": self._stack(data, "depth"),
        }

@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """Blend ``model``'s parameters into ``ema_model`` in place.

    For every named parameter of ``model``: ``ema = decay * ema + (1 - decay) * param``.
    ``decay=0`` therefore copies the weights outright. Parameters are matched by name,
    so ``ema_model`` must expose every name ``model`` has; buffers are not touched.

    NOTE(review): if the parameters are bfloat16 (as the training script loads them),
    a decay this close to 1 may produce increments below bf16 resolution that round
    away — consider keeping the EMA copy in float32.
    """
    target_by_name = dict(ema_model.named_parameters())
    for param_name, source in model.named_parameters():
        target = target_by_name[param_name]
        target.mul_(decay)
        target.add_(source.data, alpha=1 - decay)

def requires_grad(model, flag=True):
    """Enable or disable gradient tracking for every parameter of ``model``.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        flag: Value assigned to each parameter's ``requires_grad`` attribute.
    """
    for tensor in model.parameters():
        tensor.requires_grad = flag

def _find_latest_checkpoint(output_dir):
    """Return the highest-numbered ``checkpoint-<step>`` entry name in ``output_dir``, or ``None``.

    Entries whose suffix is not a plain integer (e.g. ``checkpoints``) are ignored
    rather than crashing the numeric sort; a missing ``output_dir`` yields ``None``.
    """
    if not os.path.isdir(output_dir):
        return None
    candidates = []
    for entry in os.listdir(output_dir):
        prefix, sep, step = entry.partition("-")
        if prefix == "checkpoint" and sep and step.isdecimal():
            candidates.append((int(step), entry))
    return max(candidates)[1] if candidates else None


def _resolve_resume_checkpoint(args):
    """Return the checkpoint directory name (inside ``args.output_dir``) to resume from, or ``None``.

    ``args.resume_from_checkpoint`` may be falsy (no resume), ``"latest"`` (newest
    ``checkpoint-<step>``), or a path whose basename names the checkpoint.
    """
    if not args.resume_from_checkpoint:
        return None
    if args.resume_from_checkpoint == "latest":
        return _find_latest_checkpoint(args.output_dir)
    # Only the basename is used: checkpoints are always looked up inside output_dir.
    return os.path.basename(os.path.normpath(args.resume_from_checkpoint))


def main(args):
    """Train the joint RGB + depth ``WanTransformer3DModel`` while maintaining an EMA copy.

    Uses NCCL data parallelism when ``RANK``/``WORLD_SIZE`` are set in the environment,
    otherwise trains the bare module in a single process. Rank 0 owns wandb logging,
    ``config.json``, checkpointing and validation. Each ``checkpoint-<step>``
    directory holds the transformer weights (via ``save_pretrained``) plus ``ema.pt``
    and ``optimizer.pt``; all three are restored when resuming.

    Args:
        args: Config object read via attributes (paths, optimizer and schedule settings).
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ['RANK'])
        local_rank = int(os.environ['LOCAL_RANK'])
        torch.cuda.set_device(local_rank)
        dist.init_process_group(
            backend='nccl',
            init_method='env://',
            timeout=timedelta(seconds=args.nccl_timeout)
        )
    else:
        rank = 0
        local_rank = 0
    device = torch.device(f'cuda:{local_rank}')
    # DistributedDataParallel cannot be built without a process group, so
    # single-process runs train the bare module.
    distributed = dist.is_available() and dist.is_initialized()

    # The rank-0 log file and wandb both write into output_dir, so create it first
    # (exist_ok makes this safe on every rank).
    os.makedirs(args.output_dir, exist_ok=True)
    exp_name = datetime.now().strftime('%Y%m%d-%H%M%S')
    logger = setup_logger(args.output_dir, exp_name, rank, __name__)
    if args.seed is not None:
        torch.manual_seed(args.seed + rank)

    if rank == 0:
        wandb.init(
            project="0509_long",
            name=exp_name,
            config=vars(args),
            dir=args.output_dir
        )
        # default=str stringifies values json cannot encode (paths, dtypes, ...)
        # instead of aborting the dump.
        with open(os.path.join(args.output_dir, "config.json"), "w") as f:
            json.dump(vars(args), f, indent=4, default=str)

    load_dtype = torch.bfloat16
    resume_checkpoint = _resolve_resume_checkpoint(args)
    if resume_checkpoint is None:
        transformer_path = args.pretrained_model_name_or_path
        transformer_subfolder = "transformer"
        transformer_variant = args.variant
    else:
        transformer_path = os.path.join(args.output_dir, resume_checkpoint)
        # Checkpoints saved below call save_pretrained(save_path), which writes the
        # weights at the checkpoint root (no "transformer" subfolder) and without a
        # variant suffix. Accept that layout as well as a "transformer" subfolder.
        has_subfolder = os.path.isdir(os.path.join(transformer_path, "transformer"))
        transformer_subfolder = "transformer" if has_subfolder else None
        transformer_variant = args.variant if has_subfolder else None

    transformer = WanTransformer3DModel.from_pretrained(
        transformer_path,
        subfolder=transformer_subfolder,
        torch_dtype=load_dtype,
        revision=args.revision,
        variant=transformer_variant,
        in_channels=16,
        out_channels=16,
        ignore_mismatched_sizes=True
    )
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler", shift=5.0)
    transformer.requires_grad_(True)
    ema_model = copy.deepcopy(transformer).to(device)
    requires_grad(ema_model, False)

    weight_dtype = torch.bfloat16
    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
        raise ValueError("Mixed precision training with bfloat16 is not supported on MPS.")

    transformer.to(device)

    # decay=0 makes the EMA an exact copy of the starting weights.
    update_ema(ema_model, transformer, decay=0)

    if args.gradient_checkpointing:
        transformer.enable_gradient_checkpointing()

    optimizer = torch.optim.AdamW(
        [p for p in transformer.parameters() if p.requires_grad],
        lr=args.learning_rate,
        betas=(args.beta1, args.beta2),
        weight_decay=args.weight_decay,
        eps=args.epsilon,
    )

    global_step = 0
    if resume_checkpoint is None:
        if args.resume_from_checkpoint:
            logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting new training.")
            args.resume_from_checkpoint = None
    else:
        global_step = int(resume_checkpoint.split("-")[1])
        # Restore the EMA weights and optimizer moments saved with the checkpoint;
        # without this the EMA restarts from the resumed weights and AdamW from zero.
        ema_state_path = os.path.join(transformer_path, "ema.pt")
        if os.path.exists(ema_state_path):
            ema_model.load_state_dict(torch.load(ema_state_path, map_location="cpu"))
        optimizer_state_path = os.path.join(transformer_path, "optimizer.pt")
        if os.path.exists(optimizer_state_path):
            optimizer.load_state_dict(torch.load(optimizer_state_path, map_location="cpu"))
        logger.info(f"Resuming from {transformer_path} at step {global_step}")
    initial_global_step = global_step

    dataset_init_kwargs = {
        "path": args.data_root,
        "max_num_frames": args.max_num_frames,
        "height_buckets": args.height_buckets,
        "width_buckets": args.width_buckets,
        "frame_buckets": args.frame_buckets,
        "infinite": False,
        "refs": (
            "NebuTosClient(retry=3)",
            "clip_toskey",
            "video_bytes",
            "partial(default_get_hash, key='clip_toskey')",
        ),
    }

    train_dataset = VideoParquetDataset(**dataset_init_kwargs)

    collate_fn = CollateFunction(weight_dtype)

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        collate_fn=collate_fn,
        num_workers=args.dataloader_num_workers,
        pin_memory=True,
        drop_last=True,
    )

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        disable=rank != 0,  # one progress bar per job, not one per process
    )
    train_dataloader_iter = iter(train_dataloader)
    transformer.train()
    ema_model.eval()
    # EMA updates and checkpoints read the bare module; only the training forward/backward goes through DDP.
    unwrapped_transformer = transformer
    if distributed:
        transformer = torch.nn.parallel.DistributedDataParallel(transformer, device_ids=[local_rank], output_device=local_rank)

    # Per-timestep loss weights: a Gaussian bump centred on timestep 500, shifted to a
    # zero minimum and rescaled so the weights over all scheduler timesteps sum to 1000.
    schedule_timesteps = scheduler.timesteps
    bump = torch.exp(-2 * ((schedule_timesteps - 1000 / 2) / 1000) ** 2)
    bump_shifted = bump - bump.min()
    linear_timesteps_weights = bump_shifted * (1000 / bump_shifted.sum())

    def training_target(sample, noise):
        """Regression target for the model output: ``noise - sample``."""
        return noise - sample

    while global_step < args.max_train_steps:
        try:
            batch = next(train_dataloader_iter)
        except Exception as err:  # StopIteration is an Exception subclass
            # End of an epoch (the dataset is built with infinite=False) or a data
            # error: restart the loader as before, but report data errors and no
            # longer swallow KeyboardInterrupt/SystemExit.
            if not isinstance(err, StopIteration):
                logger.exception("Failed to fetch a batch; restarting the dataloader.")
            train_dataloader_iter = iter(train_dataloader)
            batch = next(train_dataloader_iter)
        videos = batch["videos"].to(device, non_blocking=True)
        prompt_embeds = batch["prompts"].to(device, non_blocking=True)
        depths = batch['depths'].to(device, non_blocking=True)
        prompts_text = batch["prompts_text"]
        logger.debug("%s on %s", prompts_text, rank)
        videos = videos.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
        depths = depths.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
        prompt_embeds = prompt_embeds.to(memory_format=torch.contiguous_format, dtype=weight_dtype)
        model_input_rgb = videos
        model_input_depth = depths
        noise_rgb = torch.randn_like(model_input_rgb)
        noise_depth = torch.randn_like(model_input_depth)

        batch_size, num_channels, num_frames, height, width = model_input_rgb.shape
        # NOTE(review): num_frames=21 is hard-coded; the per-frame weights below are
        # reshaped with the actual num_frames, so these presumably must agree.
        timestep_id = sample_timesteps_with_group_new(
            batch_size,
            num_frames=21,
            device=device,
            num_train_timesteps=scheduler.num_train_timesteps
        )
        timestep = scheduler.timesteps[timestep_id].to(dtype=weight_dtype, device=device)
        # Frames 3-4 get timestep 0 (presumably left clean as conditioning); the loss
        # below only covers frames 5 onward.
        timestep[:, 3:5] = 0
        noisy_model_input_rgb = scheduler.add_noise(model_input_rgb, noise_rgb, timestep)
        noisy_model_input_depth = scheduler.add_noise(model_input_depth, noise_depth, timestep)
        # RGB and depth latents are concatenated on the channel axis; the output is split back at 16.
        noisy_model_input = torch.cat([noisy_model_input_rgb, noisy_model_input_depth], dim=1)
        target_rgb = training_target(model_input_rgb, noise_rgb)
        target_depth = training_target(model_input_depth, noise_depth)

        optimizer.zero_grad()
        model_output = transformer(
            hidden_states=noisy_model_input.to(dtype=weight_dtype),
            timestep=timestep,
            encoder_hidden_states=prompt_embeds,
            return_dict=False,
        )[0]

        model_output_depth = model_output[:, 16:, :, :, :]
        model_output_rgb = model_output[:, :16, :, :, :]
        # Map the (bf16-rounded) timesteps back to their nearest schedule index to look up loss weights.
        timestep_cpu = timestep.cpu()
        timestep_diff = torch.abs(scheduler.timesteps.view(1, -1) - timestep_cpu.view(-1, 1))
        timestep_id = torch.argmin(timestep_diff, dim=1)
        weights = linear_timesteps_weights[timestep_id].to(device=device)
        weights = weights.view(batch_size, 1, num_frames, 1, 1)

        loss_rgb = torch.mean(
            (weights[:, :, 5:] * (model_output_rgb[:, :, 5:].float() - target_rgb[:, :, 5:].float()) ** 2).reshape(target_rgb.shape[0], -1),
            1,
        )
        loss_depth = torch.mean(
            (weights[:, :, 5:] * (model_output_depth[:, :, 5:].float() - target_depth[:, :, 5:].float()) ** 2).reshape(target_depth.shape[0], -1),
            1,
        )
        loss = loss_rgb + loss_depth
        loss_rgb = loss_rgb.mean()
        loss_depth = loss_depth.mean()
        loss = loss.mean()

        loss.backward()

        grad_norm = torch.nn.utils.clip_grad_norm_(transformer.parameters(), args.max_grad_norm)
        if rank == 0 and global_step > 200:
            log_per_layer_grad_norms(
                model=transformer,
                global_step=global_step,
                output_dir=args.output_dir,
                total_grad_norm=grad_norm.item(),
                logger=logger,
                threshold=0.3
            )
        optimizer.step()

        update_ema(ema_model, unwrapped_transformer)

        global_step += 1
        loss_value = loss.item()
        progress_bar.update(1)
        progress_bar.set_postfix({
            "loss":      f"{loss_value:.6f}",
            "loss_rgb":  f"{loss_rgb:.6f}",
            "grad_norm": f"{grad_norm:.6f}"
        })
        logger.info(
            f"Step {global_step}: loss = {loss_value:.6f} "
            f"loss_rgb = {loss_rgb:.6f} grad_norm = {grad_norm:.6f}"
        )

        if rank == 0:
            wandb.log({
                "step": global_step,
                "loss": loss_value,
                "loss_rgb": loss_rgb.item(),
                "loss_depth": loss_depth.item(),
                "grad_norm": grad_norm.item()
            }, step=global_step)

        if rank == 0:
            if global_step % args.checkpointing_steps == 0:
                save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                # An empty-weight shell adopts the live tensors (assign=True), so saving
                # does not allocate a second copy of the model.
                with init_empty_weights():
                    transformer_copy = WanTransformer3DModel.from_config(
                        args.pretrained_model_name_or_path,
                        subfolder="transformer",
                        torch_dtype=load_dtype,
                        revision=args.revision,
                        variant=args.variant,
                        in_channels=16,
                        out_channels=16,
                        ignore_mismatched_sizes=True
                    )
                transformer_copy.load_state_dict(unwrapped_transformer.state_dict(), strict=True, assign=True)
                transformer_copy.save_pretrained(save_path, safe_serialization=True, max_shard_size="10GB")
                torch.save(ema_model.state_dict(), os.path.join(save_path, "ema.pt"))
                torch.save(optimizer.state_dict(), os.path.join(save_path, "optimizer.pt"))
                logger.info(f"Saved state to {save_path}")

            if global_step % args.validation_steps == 0:
                pipe = WanPipeline.from_pretrained(
                    args.pretrained_model_name_or_path,
                    transformer=ema_model,
                    torch_dtype=weight_dtype,
                )
                validation_prompts = [
                    "A man in a striped shirt and cap is crouching down, carefully examining a suitcase. The background features a textured wall with graffiti, suggesting an urban setting. The man appears focused and methodical as he inspects the suitcase."
                ]
                for validation_prompt in validation_prompts:
                    pipeline_args = {
                        "prompt": validation_prompt,
                        "guidance_scale": args.guidance_scale,
                        "height": args.height,
                        "width": args.width,
                    }
                    log_validation(
                        pipe=pipe,
                        args=args,
                        pipeline_args=pipeline_args,
                        step=global_step,
                        fps=args.fps,
                        device=device,
                        process_rank=rank
                    )
                del pipe
                torch.cuda.empty_cache()

        if global_step >= args.max_train_steps:
            break

    progress_bar.close()
    if rank == 0:
        wandb.finish()
    if distributed:
        dist.destroy_process_group()

def config_path_to_module(config_path):
    """Convert a config file path such as ``configs/sft_config.py`` into a dotted module path.

    Only a trailing ``.py`` is stripped (``str.replace`` would also remove ``.py`` in
    the middle of a name), and the path is normalized first so ``./configs/x.py``
    maps to ``configs.x`` rather than the invalid ``..configs.x``.
    """
    normalized = os.path.normpath(config_path)
    if normalized.endswith(".py"):
        normalized = normalized[:-3]
    return normalized.replace(os.sep, ".").replace("/", ".")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/sft_config.py", help="Path to config.py")
    parser.add_argument("--resume", type=str, default=None, help="save folder, e.g. 2025-01-13_11-57-29")
    parser.add_argument("--lr", type=float, default=None, help="Learning rate")
    main_args = parser.parse_args()

    # The config module is expected to define a module-level `args` object.
    config = importlib.import_module(config_path_to_module(main_args.config))
    args = config.args

    if main_args.resume is not None:
        # Resume inside a sibling run folder: swap the last path component of output_dir.
        args.resume_from_checkpoint = "latest"
        out_path = args.output_dir.split("/")
        out_path[-1] = main_args.resume
        args.output_dir = "/".join(out_path)

    if main_args.lr is not None:
        args.learning_rate = main_args.lr

    main(args)