from pathlib import Path
import warnings
import time

import torch
from tqdm import tqdm
from omegaconf import OmegaConf
import os

from models.vaura_model import VAURAModel
from utils.train_utils import get_datamodule_from_type
from utils.demo_utils import resolve_ckpt_demo, resolve_hparams_demo
from models.data.vggsound_dataset import EPS as EPS_VGGSOUND
from scripts.generate import save_results, get_original_data
import random
import numpy as np
import argparse

def get_args():
    """Parse the command-line options of the V-AURA demo script.

    Returns:
        argparse.Namespace with the attributes ``save_v2a_dir`` (required str),
        ``data_json_file`` (str or None), ``device_id`` (int, default 0) and
        ``seed`` (int, default 0).
    """
    cli = argparse.ArgumentParser(description="V-AURA Demo")

    # (flag, keyword arguments for add_argument), registered in declaration order.
    option_specs = (
        (
            "--save_v2a_dir",
            {"type": str, "required": True, "help": "Directory to save generated samples"},
        ),
        (
            "--data_json_file",
            {"type": str, "default": None, "help": "Path to the JSONL file containing the data"},
        ),
        (
            "--device_id",
            {"type": int, "default": 0, "help": "Device ID to use for generation"},
        ),
        (
            "--seed",
            {"type": int, "default": 0, "help": "Random seed for reproducibility"},
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()

def generate_v2a(save_v2a_dir, data_json_file, device_id, seed):
    """Generate audio for the videos of a metadata file and save the results.

    Loads the V-AURA checkpoint resolved from the hard-coded experiment
    directory, builds the "motionformer_gen" test dataloader from
    ``./data/demo/dataloader_config.yaml``, runs ``model.generate`` on every
    batch and hands each generated clip to ``save_results``. Outputs are
    written to ``<save_v2a_dir>/<parent dir of the input file>/<input file name>``.

    Args:
        save_v2a_dir: Output root directory; created if it does not exist.
        data_json_file: Value stored in the dataloader config under
            ``path_to_metadata``.
        device_id: CUDA device index; the model and the frames are placed on
            ``cuda:{device_id}``.
        seed: Seed applied to ``random``, NumPy and torch before every batch.

    Raises:
        AssertionError: If the duration constants are inconsistent or a
            sample is shorter than the generated duration.
    """
    EXPERIMENT_DIR = "V-AURA/logs/24-08-01T08-34-26"
    AVCLIP_CKPT = "V-AURA/segment_avclip/vggsound/best.pt"
    OUTPUT_DIR = save_v2a_dir

    DURATION = 2.56
    # STRIDE is not consumed by the single-pass generation below; it is only
    # validated here.
    STRIDE = 1.28
    assert DURATION % 0.64 == 0
    assert STRIDE % 0.64 == 0

    DEVICE = f"cuda:{device_id}"
    TEMP = 0.95
    TOP_K = 128
    CFG_SCALE = 6.0

    # Resolve paths and download checkpoints if needed
    output_path = Path(OUTPUT_DIR)
    output_path.mkdir(exist_ok=True, parents=True)
    checkpoint_path = resolve_ckpt_demo(EXPERIMENT_DIR)
    hparams_path = resolve_hparams_demo(checkpoint_path, AVCLIP_CKPT)

    print(f"Using checkpoint: {checkpoint_path}")
    print(f"Using hparams: {hparams_path}")
    print(f"Using output path: {output_path}")

    # Load the model (warnings raised while loading are silenced)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        model = VAURAModel.load_from_checkpoint(
            checkpoint_path, hparams_file=hparams_path, map_location=DEVICE
        )
    model.eval()

    # Resolve dataloader
    dl_cfg = OmegaConf.load("./data/demo/dataloader_config.yaml")
    dl_cfg["sample_duration"] = DURATION
    dl_cfg["path_to_metadata"] = data_json_file
    OmegaConf.resolve(dl_cfg)  # resolve durations

    datamodule = get_datamodule_from_type("motionformer_gen", dl_cfg)
    datamodule.setup("test")
    dataloader = datamodule.test_dataloader()

    # Resolve generation parameters
    MODEL_MAX_DURATION = 2.56  # do not modify
    COMPRESSION_MODEL_FRAME_RATE = 86  # do not modify
    # Only one generation pass per batch is performed below, so the requested
    # duration has to fit into a single model window. The check involves
    # constants only, hence it runs once here instead of once per batch.
    assert DURATION <= MODEL_MAX_DURATION, "Sample duration can not exceed conditional video duration"

    total_gen_len = int(DURATION * COMPRESSION_MODEL_FRAME_RATE)
    model.sampler.audio_tokens_per_video_frame = 7

    # Generate
    for sample in tqdm(dataloader):
        # Re-seed every RNG before each batch so the output for a batch does
        # not depend on which batches were processed before it.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        meta = sample["meta"]
        # Reduce with .all() so the check also works when the collated duration
        # holds more than one element (assumed to be a tensor or a scalar).
        durations_ok = torch.as_tensor(meta["duration"] >= DURATION).all()
        assert bool(durations_ok), "Sample duration can not exceed conditional video duration"

        frames = sample["frames"].to(DEVICE)

        # get original data without transformations
        original_frames, _ = get_original_data(
            meta,
            0.0,
            EPS_VGGSOUND,
            DURATION,
            0,
        )

        item = model.generate(
            frames=frames,
            audio=None,  # no audio prompt: generation is conditioned on the frames only
            max_new_tokens=total_gen_len,
            return_sampled_indices=True,
            use_sampling=True,
            temp=TEMP,
            top_k=TOP_K,
            cfg_scale=CFG_SCALE,
            remove_prompts=False,
            prompt_is_encoded=True,
        )

        # Save results: one output per batch element. The loop variable has its
        # own name so that each clip is passed on whole instead of being
        # indexed a second time with the batch index.
        for i, generated_audio in enumerate(item["generated_audio"]):
            source_file = Path(meta["filepath"][i])
            # Keep the immediate parent directory of the input so files with
            # equal names from different folders do not overwrite each other.
            target_file = output_path / source_file.parent.name / source_file.name
            target_file.parent.mkdir(parents=True, exist_ok=True)
            save_results(
                generated_audio,
                original_frames[i],
                str(target_file),
                meta["video_fps"][i].item(),
                44100,
                meta["audio_fps"][i].item(),
            )
            
if __name__ == "__main__":
    # Script entry point: read the CLI options and run generation with them.
    cli_args = get_args()
    generate_v2a(
        save_v2a_dir=cli_args.save_v2a_dir,
        data_json_file=cli_args.data_json_file,
        device_id=cli_args.device_id,
        seed=cli_args.seed,
    )