import torch
import os
import numpy as np
import argparse
import json
import random
from WanVideoModel import wan
from WanVideoModel.wan.configs import WAN_CONFIGS
from WanVideoModel.wan.utils.utils import cache_video
from VideoReward.score import *
from LatentReward.generate_model import LatentReward


# Command-line interface for the generation script. Parsing happens at module
# level, so `args` is available to everything below.
parser = argparse.ArgumentParser(description="VideoGen")

# Sampling settings.
parser.add_argument(
    '--infer_step',
    type=int,
    default=50,
    help='total inference timestep T',
)
parser.add_argument(
    '--frame_number',
    type=int,
    default=33,
    help='number of video frames',
)
parser.add_argument(
    '--seed',
    type=int,
    default=42,
    help='Random seed to determine the initial latent.',
)
parser.add_argument(
    '--device',
    type=str,
    default='cuda',
    help='Device where the model inference is performed.',
)

# Model selection: `--model_name` is the key looked up in WAN_CONFIGS and
# `--model_path` is handed to the pipeline as its checkpoint directory.
parser.add_argument(
    '--model_name',
    type=str,
    default='t2v-1.3B',
)
parser.add_argument(
    '--model_path',
    type=str,
    default='./WanVideoModel/Wan2.1-T2V-1.3B',
)
# File passed to torch.load() and loaded into the LatentReward verifier.
parser.add_argument(
    '--reward_model_path',
    type=str,
    default="/data/latent_reward.pt",
)

# Prompt selection and output location.
parser.add_argument(
    '--prompt_path',
    type=str,
    default='/data/VBench2_full_info.json',
    help='prompt file path',
)
parser.add_argument(
    '--dimension_list',
    nargs='+',
    default=["Camera_Motion"],
    help='List of dimensions (one or more)',
)
parser.add_argument(
    '--save_dir',
    type=str,
    default='/data/generated_videos',
    help='Path to save the generated videos.',
)

# Forwarded unchanged to pipe.generate(beta=...).
# NOTE(review): its exact meaning is defined by the pipeline, not visible here.
parser.add_argument(
    '--beta',
    type=float,
    default=0.7,
)

args = parser.parse_args()


# Seed every random number generator this script touches with the same value:
# Python's `random`, NumPy, and torch (CPU, current CUDA device, all CUDA
# devices), in that order.
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(args.seed)

# Request deterministic cuDNN kernels and turn off its autotuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


if __name__ == '__main__':
    # Script entry point:
    #   1. build the Wan text-to-video pipeline and the latent reward verifier,
    #   2. read the prompt list from --prompt_path,
    #   3. for every prompt whose first dimension is in --dimension_list,
    #      generate several videos and save them under
    #      <save_dir>/<dimension>/.

    # exist_ok=True already covers the "directory exists" case, so no separate
    # os.path.exists() check is needed.
    os.makedirs(args.save_dir, exist_ok=True)
    device = torch.device(args.device)

    cfg = WAN_CONFIGS[args.model_name]
    pipe = wan.WanT2VWithLatentReward(
        config=cfg,
        checkpoint_dir=args.model_path,
        # Follow the index given in --device (e.g. "cuda:1") so the pipeline
        # and the verifier end up on the same GPU; a device string without an
        # index keeps the previous behaviour of using 0.
        device_id=device.index if device.index is not None else 0,
    )

    latent_verifier = LatentReward(
        load_from_pretrained='./VideoReward',
        device=device,
        dtype=torch.bfloat16,
    ).to(device)
    # map_location makes the checkpoint loadable regardless of which device
    # it was saved from.
    latent_verifier.load_state_dict(
        torch.load(args.reward_model_path, map_location=device)
    )
    latent_verifier.eval()

    # Use a context manager so the prompt file is closed once it is parsed.
    with open(args.prompt_path, encoding='utf-8') as prompt_file:
        prompts = json.load(prompt_file)

    # Build the lookup set once instead of scanning the list for each prompt.
    wanted_dimensions = set(args.dimension_list)

    for entry in prompts:
        # Only the first listed dimension of an entry decides whether it is used.
        dimension = entry["dimension"][0]
        if dimension not in wanted_dimensions:
            continue

        # "Diversity" prompts are sampled 20 times, every other dimension 3 times.
        num_iters = 20 if dimension == 'Diversity' else 3
        prompt = entry["prompt_en"]

        os.makedirs(f'{args.save_dir}/{dimension}', exist_ok=True)

        for index in range(num_iters):
            video = pipe.generate(
                input_prompt=prompt,
                size=(832, 480),
                frame_num=args.frame_number,
                shift=5.0,
                sample_solver='unipc',
                sampling_steps=args.infer_step,
                guide_scale=5.0,
                # A different seed per sample so repeated runs of the same
                # prompt do not all start from the same seed.
                seed=args.seed + index,
                verifier=latent_verifier,
                # NOTE(review): search_schedule / num_candidates / beta are
                # passed through to the pipeline as-is; their semantics are
                # defined by WanT2VWithLatentReward.generate.
                search_schedule=[10, 15, 20],
                num_candidates=4,
                beta=args.beta,
            )
            # The file name is the prompt text truncated to 180 characters
            # plus the sample index.
            # NOTE(review): a prompt containing a path separator would change
            # the resulting path; left untouched because downstream tooling
            # may rely on this exact naming.
            video_path = f'{args.save_dir}/{dimension}/{prompt[:180]}-{index}.mp4'
            cache_video(
                tensor=video[None],
                save_file=video_path,
                fps=16,
                nrow=1,
                normalize=True,
                value_range=(-1, 1))
            print(f"Saved: {video_path}")

    print('Generation End!')