import os
import torch
import json
from diffusers import AutoencoderKLWan
from diffusers.utils import export_to_video
import numpy as np
import mediapy as media
import matplotlib.pyplot as plt
from s3_ar.wan.transformer_joint_s3 import WanTransformer3DModel
from wan.pipelines.pipeline_wan import WanPipeline
from wan.pipelines.pipeline_flow import WanPipeline_flow
from wan.scheduler_pcm import UniPCMultistepScheduler
from wan.flow_frame import FlowMatchScheduler
import torchvision.transforms as transforms
from PIL import Image
import torch.nn.functional as F
from torchvision.transforms.functional import resize
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from training.dataloader.dataset_depth_normalize import ScaleShiftDepthNormalizer

def colorize_video_depth(depth_video, colormap="Spectral"):
    """Map a single-channel depth video to RGB frames with a matplotlib colormap.

    Args:
        depth_video: ``torch.Tensor`` or array-like with three dimensions
            ``(T, H, W)``. Values are handed to the colormap unchanged, so
            floating-point depth is expected to be pre-normalized (matplotlib
            maps floats over ``[0, 1]``).
        colormap: Name of a registered matplotlib colormap.

    Returns:
        ``np.ndarray`` of shape ``(T, H, W, 3)`` holding the RGB channels of
        the byte-valued colormap output (alpha channel dropped).

    Raises:
        ValueError: If ``depth_video`` does not have exactly three dimensions.
    """
    if isinstance(depth_video, torch.Tensor):
        # detach(): .numpy() refuses tensors that are part of an autograd graph.
        depth_video = depth_video.detach()
        # NumPy has no bfloat16 dtype, so .numpy() would raise; upcast first.
        if depth_video.dtype == torch.bfloat16:
            depth_video = depth_video.float()
        depth_video = depth_video.cpu().numpy()
    else:
        depth_video = np.asarray(depth_video)

    if depth_video.ndim != 3:
        raise ValueError(
            f"depth_video must have shape (T, H, W), got {depth_video.shape}"
        )

    # The colormap is applied element-wise and accepts N-d input, so the whole
    # clip is colorized in one call instead of once per frame.
    rgba = plt.get_cmap(colormap)(depth_video, bytes=True)
    # Drop alpha; copy so callers receive a contiguous array, not a strided view.
    return np.ascontiguousarray(rgba[..., :3])

def read_json_metadata(json_path):
    """Load and validate the metadata list stored in a JSON file.

    The file must contain a non-empty JSON array of objects. Every object must
    provide the keys ``video_path``, ``depth_path`` and ``prompts_text``, and
    both paths must exist on disk.

    Args:
        json_path: Path to the JSON metadata file.

    Returns:
        The parsed list of metadata dictionaries, unchanged.

    Raises:
        FileNotFoundError: If the JSON file, or a video/depth path referenced
            by an entry, does not exist.
        ValueError: If the file is not valid JSON, is empty, is not a list of
            objects, or an entry lacks a required key.
    """
    required_keys = ("video_path", "depth_path", "prompts_text")

    if not os.path.exists(json_path):
        raise FileNotFoundError(f"JSON file {json_path} not found")
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except json.JSONDecodeError as err:
        # Keep the parser's line/column information attached as the cause.
        raise ValueError(f"Invalid JSON format in {json_path}") from err
    if not data:
        raise ValueError(f"JSON file {json_path} is empty")
    # A dict would be iterated by key and a scalar is not iterable at all;
    # reject both up front with a clear message instead of a stray TypeError.
    if not isinstance(data, list):
        raise ValueError(
            f"JSON file {json_path} must contain a list of entries, "
            f"got {type(data).__name__}"
        )

    for item in data:
        if not isinstance(item, dict) or any(key not in item for key in required_keys):
            raise ValueError(f"Missing required fields in JSON item: {item}")
        if not os.path.exists(item["video_path"]):
            raise FileNotFoundError(f"Video path {item['video_path']} not found")
        if not os.path.exists(item["depth_path"]):
            raise FileNotFoundError(f"Depth path {item['depth_path']} not found")
    return data

def load_and_preprocess_video(video_path, height=480, width=832, num_frames=81, transform=None):
    """Read a video and return it as a fixed-length, resized float tensor.

    The clip is truncated or zero-padded at the end to exactly ``num_frames``
    frames, moved to channels-first layout, resized frame by frame to
    ``(height, width)`` and optionally passed through ``transform``.

    Args:
        video_path: Path of the video file read with ``media.read_video``.
        height: Target frame height in pixels.
        width: Target frame width in pixels.
        num_frames: Number of frames in the returned clip.
        transform: Optional callable applied to each resized frame.

    Returns:
        A float32 tensor with a leading batch dimension of size 1, i.e.
        ``(1, num_frames, C, height, width)`` when ``transform`` keeps the
        frame shape.
    """
    # Assumes media.read_video yields (T, H, W, C); the 4-D pad spec and the
    # permute below both rely on that layout.
    frames = torch.from_numpy(media.read_video(video_path)).float()

    frame_count = frames.shape[0]
    if frame_count > num_frames:
        frames = frames[:num_frames]
    elif frame_count < num_frames:
        # F.pad reads (before, after) pairs starting from the LAST dimension,
        # so only the trailing side of dim 0 (time) receives zero frames.
        frames = F.pad(frames, (0, 0, 0, 0, 0, 0, 0, num_frames - frame_count))

    frames = frames.permute(0, 3, 1, 2)
    # Honour the requested size; it was previously hard-coded to (480, 832),
    # which silently ignored the height/width arguments.
    frames = torch.stack([resize(frame, (height, width)) for frame in frames], dim=0)

    if transform is not None:
        frames = torch.stack([transform(frame) for frame in frames])

    return frames.unsqueeze(0)

def load_and_preprocess_depth(depth_path, height=480, width=832, num_frames=81, transform=None):
    """Read a depth video and return it as a fixed-length, resized float tensor.

    Steps, in order: keep only the first channel of a 4-D clip, divide by the
    global maximum when it exceeds 1.0, truncate or zero-pad at the end to
    ``num_frames`` frames, bilinearly resize to ``(height, width)``, apply the
    optional ``transform`` to the whole clip, and add a batch dimension.

    Args:
        depth_path: Path of the depth video read with ``media.read_video``.
        height: Target frame height in pixels.
        width: Target frame width in pixels.
        num_frames: Number of frames in the returned clip.
        transform: Optional callable applied once to the resized clip.

    Returns:
        A float32 tensor with a leading batch dimension of size 1, i.e.
        ``(1, num_frames, 1, height, width)`` when ``transform`` keeps the
        clip shape.
    """
    depth_frames = torch.from_numpy(media.read_video(depth_path)).float()

    # A 4-D clip carries a trailing channel axis; only channel 0 is used.
    if depth_frames.ndim == 4:
        depth_frames = depth_frames[..., 0]

    # Rescale by the clip-wide maximum only when values exceed 1.0.
    peak = depth_frames.max()
    if peak > 1.0:
        depth_frames = depth_frames / peak

    missing = num_frames - depth_frames.shape[0]
    if missing < 0:
        depth_frames = depth_frames[:num_frames]
    elif missing > 0:
        # Pad pairs run from the last dim backwards: W, H, then time (end only).
        depth_frames = F.pad(depth_frames, (0, 0, 0, 0, 0, missing))

    # Insert a channel axis so F.interpolate sees (N, C, H, W).
    resized = F.interpolate(
        depth_frames.unsqueeze(1),
        size=(height, width),
        mode='bilinear',
        align_corners=False,
    )

    if transform:
        resized = transform(resized)

    return resized.unsqueeze(0)

def normalize_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor) -> torch.Tensor:
    """Shift and scale latents per channel: ``(latents - mean) * std``.

    ``latents_mean`` and ``latents_std`` are reshaped to ``(1, C, 1, 1, 1)`` so
    they broadcast over dimension 1 of a 5-D ``latents`` tensor. Note that
    ``latents_std`` is used as a multiplier, not a divisor.

    Args:
        latents: Tensor to normalize.
        latents_mean: Per-channel offsets subtracted from ``latents``.
        latents_std: Per-channel factors multiplied after the subtraction.

    Returns:
        The normalized tensor, cast back to the dtype and device of ``latents``.
    """
    target_device = latents.device
    shift = latents_mean.view(1, -1, 1, 1, 1).to(device=target_device)
    scale = latents_std.view(1, -1, 1, 1, 1).to(device=target_device)
    # Do the arithmetic in float32, then restore the caller's dtype/device.
    centered = latents.float() - shift
    return (centered * scale).to(latents)

def encode_data(data, vae, weight_dtype=torch.bfloat16):
    """Encode a clip with the VAE and sample normalized latents.

    Dimensions 1 and 2 of ``data`` are swapped before encoding, and an input
    with a single entry along the new dimension 1 is repeated three times
    there. The encoder output is split in two along dimension 1, both halves
    are normalized with the VAE's configured latent mean and reciprocal std,
    and a sample is drawn from the resulting diagonal Gaussian.

    Args:
        data: 5-D input tensor; presumably ``(B, T, C, H, W)`` — confirm
            against the callers.
        vae: Model exposing ``_encode`` and ``config.latents_mean`` /
            ``config.latents_std``.
        weight_dtype: Dtype the encoder output is cast to before normalizing.

    Returns:
        Latents sampled from the normalized posterior.
    """
    pixels = data.permute(0, 2, 1, 3, 4)

    # Single-channel input is tiled to three channels for the encoder.
    if pixels.shape[1] == 1:
        pixels = pixels.repeat(1, 3, 1, 1, 1)

    raw_moments = vae._encode(pixels)
    moments = raw_moments.to(dtype=weight_dtype)

    device = moments.device
    mean = torch.tensor(vae.config.latents_mean, device=device)
    # normalize_latents multiplies by its third argument, hence the reciprocal.
    inv_std = 1.0 / torch.tensor(vae.config.latents_std, device=device)

    mean_half, logvar_half = torch.chunk(moments, 2, dim=1)
    mean_half = normalize_latents(mean_half, mean, inv_std)
    logvar_half = normalize_latents(logvar_half, mean, inv_std)
    normalized = torch.cat([mean_half, logvar_half], dim=1)

    posterior = DiagonalGaussianDistribution(normalized)
    sampled_latents = posterior.sample()

    # Drop intermediates before asking CUDA to release cached blocks.
    del raw_moments, moments, mean_half, logvar_half, normalized, posterior
    torch.cuda.empty_cache()

    return sampled_latents

def run_inference(checkpoints, json_path, output_dir, model_id, height=480, width=832, num_frames=81, num_inference_steps=60, guidance_scale=5.0, fps=15):
    """Generate video/depth samples for every checkpoint and metadata entry.

    For each checkpoint a transformer is loaded and wrapped in a Wan pipeline.
    For each metadata entry the reference video and depth clip are encoded to
    latents, concatenated along dimension 1 and handed to the pipeline. The
    returned video, ground-truth video and colorized depth are written as
    ``.mp4`` files under ``<output_dir>/<checkpoint_name>_<flow flag>/``.

    Entries whose inputs fail to load or encode are reported on stdout and
    skipped; errors raised by the pipeline itself propagate.

    Args:
        checkpoints: Iterable of checkpoint paths containing a ``transformer``
            subfolder.
        json_path: Metadata file accepted by ``read_json_metadata``.
        output_dir: Directory that receives one subfolder per checkpoint.
        model_id: Base model location used for the VAE and the pipeline.
        height: Frame height passed to preprocessing and the pipeline.
        width: Frame width passed to preprocessing and the pipeline.
        num_frames: Clip length passed to preprocessing and the pipeline.
        num_inference_steps: Number of sampling steps.
        guidance_scale: Guidance scale passed to the pipeline.
        fps: Frame rate of the written videos.
    """
    metadata = read_json_metadata(json_path)
    print(f"Loaded {len(metadata)} entries from {json_path}")

    negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

    # Two copies of the same VAE weights: the float32 one is handed to the
    # pipeline, the bfloat16 one is used by encode_data for the conditioning.
    pipeline_vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32).to("cuda")
    vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16).to("cuda")
    vae.requires_grad_(False)

    random_flip = False
    video_transforms = transforms.Compose([
        transforms.RandomHorizontalFlip(random_flip) if random_flip else transforms.Lambda(lambda x: x),
        transforms.Lambda(lambda x: x / 255.0),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ])
    depth_transforms = ScaleShiftDepthNormalizer(
        norm_min=-1.0,
        norm_max=1.0,
        min_max_quantile=0.01,
        clip=True,
    )

    # Fixed run settings; none of them changes between checkpoints or entries.
    flow = True
    infer_real = True
    flow_shift = 5.0
    seed = 2
    num_sample_groups = 8 if infer_real else 4

    for checkpoint in checkpoints:
        # normpath strips a trailing separator; without it basename() returns
        # "" and the outputs of different checkpoints would share one folder.
        checkpoint_name = os.path.basename(os.path.normpath(checkpoint))
        subfolder = os.path.join(output_dir, f'{checkpoint_name}_{flow}')
        os.makedirs(subfolder, exist_ok=True)

        transformer = WanTransformer3DModel.from_pretrained(
            checkpoint,
            in_channels=16,
            out_channels=16,
            ignore_mismatched_sizes=True,
            subfolder="transformer",
            torch_dtype=torch.bfloat16
        ).to("cuda")

        if flow:
            # Use the same shift value that is embedded in the output names.
            scheduler = FlowMatchScheduler(shift=flow_shift, sigma_min=0.0, extra_one_step=True)
            pipe = WanPipeline_flow.from_pretrained(
                model_id,
                vae=pipeline_vae,
                transformer=transformer,
                torch_dtype=torch.bfloat16
            ).to("cuda")
            pipe.scheduler = scheduler
        else:
            scheduler = UniPCMultistepScheduler(
                prediction_type='flow_prediction',
                use_flow_sigmas=True,
                num_train_timesteps=1000,
                flow_shift=flow_shift
            )
            pipe = WanPipeline.from_pretrained(
                model_id,
                vae=pipeline_vae,
                transformer=transformer,
                scheduler=scheduler,
                torch_dtype=torch.bfloat16
            ).to("cuda")

        for idx, entry in enumerate(metadata):
            video_path = entry["video_path"]
            depth_path = entry["depth_path"]
            prompt = entry["prompts_text"]

            # Best effort: a broken input file skips this entry, not the run.
            try:
                video = load_and_preprocess_video(
                    video_path,
                    height=height,
                    width=width,
                    num_frames=num_frames,
                    transform=video_transforms
                ).to("cuda", torch.bfloat16)

                depth = load_and_preprocess_depth(
                    depth_path,
                    height=height,
                    width=width,
                    num_frames=num_frames,
                    transform=depth_transforms
                ).to("cuda", torch.bfloat16)

                video_latents = encode_data(video, vae, weight_dtype=torch.bfloat16)
                depth_latents = encode_data(depth, vae, weight_dtype=torch.bfloat16)

                latents = torch.cat([video_latents, depth_latents], dim=1).to("cuda", torch.bfloat16)

                del video, depth
                torch.cuda.empty_cache()
            except Exception as e:
                print(f"Error processing data for entry {idx}: {e}")
                continue

            # A fresh generator per entry so every sample starts from `seed`.
            video, depth, gt = pipe(
                prompts=[prompt] * 2,
                negative_prompt=negative_prompt,
                num_inference_steps=num_inference_steps,
                height=height,
                width=width,
                num_frames=num_frames,
                guidance_scale=guidance_scale,
                latents=latents,
                generator=torch.Generator(device='cuda').manual_seed(seed),
                num_sample_groups=num_sample_groups,
                num_noise_groups=4,
                infer_real=infer_real
            )

            video = video.frames[0]
            gt = gt.frames[0]
            depth = depth.frames
            colored_depth_video = colorize_video_depth(depth)

            stem = f"{checkpoint_name}_flow_{flow_shift}_json{idx}_cfg{guidance_scale}_seed{seed}"
            video_filename = os.path.join(subfolder, f"{stem}_video.mp4")
            gt_filename = os.path.join(subfolder, f"{stem}_video_gt.mp4")
            depth_filename = os.path.join(subfolder, f"{stem}_depth.mp4")

            export_to_video(video, video_filename, fps=fps)
            export_to_video(gt, gt_filename, fps=fps)
            media.write_video(depth_filename, colored_depth_video, fps=fps)

            print(f"Saved: {video_filename}, {depth_filename}")

        # Release this checkpoint's weights before loading the next one.
        del transformer, pipe
        torch.cuda.empty_cache()

def main():
    """Run inference with the placeholder paths and default sampling settings."""
    settings = {
        "checkpoints": ["path/to/checkpoint"],
        "json_path": "path/to/metadata.json",
        "output_dir": "path/to/output",
        "model_id": "path/to/model",
        "height": 480,
        "width": 832,
        "num_frames": 81,
        "num_inference_steps": 60,
        "guidance_scale": 5.0,
        "fps": 15,
    }
    run_inference(**settings)

# Run the inference only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()