import torch
import json
from models.sound_engine import GNSE_models
from modules.stable_audio_tools.models import create_model_from_config
from modules.stable_audio_tools.models.autoencoders import AudioAutoencoder
from modules.stable_audio_tools.models.utils import load_ckpt_state_dict
from modules.stable_audio_tools.utils.torch_common import copy_state_dict
from data.extracted_gamegenx import ExtractedGameGenXwithPCA
from torch.utils.data import default_collate
import os
import copy
import torchaudio

import argparse
import subprocess
import time
import glob

# Global numeric / kernel settings for inference.
# NOTE: enable_flash_sdp is a function. Assigning to it (instead of calling it)
# would replace the function with a bool and leave the backend flag untouched.
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')

# Skip the default weight initialisation of Linear and LayerNorm. The models
# built in this script load their weights from checkpoints afterwards, so the
# random init is wasted work. Any parameter that a checkpoint does NOT cover is
# therefore left as uninitialised memory.
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)

def custom_collate(batch):
    """Collate ``batch`` with ``default_collate`` after dropping ``None`` samples.

    The dataset presumably yields ``None`` for samples it failed to load; such
    entries are discarded so one bad sample does not abort the whole batch.

    Raises:
        ValueError: if no sample in ``batch`` is usable.
    """
    kept = list(filter(lambda item: item is not None, batch))
    if not kept:
        raise ValueError("No valid samples in batch")
    return default_collate(kept)

def _build_save_folder(args):
    """Return the output directory for this run.

    The directory name encodes the seed, generated token-sequence length and
    sampling hyper-parameters, so runs with different settings do not overwrite
    each other. ``_ema`` is appended when an EMA checkpoint is used, and
    ``_evaluate`` is always appended last.
    """
    run_name = "seed{}_token_seq{}-diffsteps{}-temp{}-trans_cfg{}_{}diff_cfg{}_inference_noise{}".format(
        args.global_seed,
        int(2 * args.inf_duration * args.video_fps),
        args.num_diffusion_steps,
        args.temperature,
        args.trans_cfg_scale,
        args.cfg_schedule,
        args.diff_cfg_scale,
        args.inference_noise,
    )
    save_folder = os.path.join(args.results_dir, args.gnse_model, run_name)
    if args.ema_checkpoint is not None:
        save_folder += "_ema"
    return save_folder + "_evaluate"


def _extract_model_state_dict(checkpoint):
    """Return the model weights stored inside a training checkpoint dict.

    Looks for the container keys ``"model"`` (DDP), ``"module"`` (DeepSpeed) and
    ``"state_dict"``, in that order.

    Raises:
        ValueError: if none of the known container keys is present.
    """
    for key in ("model", "module", "state_dict"):
        if key in checkpoint:
            return checkpoint[key]
    raise ValueError(
        "Unrecognized checkpoint layout: expected one of the keys "
        "'model' (DDP), 'module' (DeepSpeed) or 'state_dict'"
    )


def _load_ema_weights(model, ema_checkpoint_path):
    """Overwrite ``model``'s parameters with the EMA copies from a checkpoint.

    The checkpoint must hold an ``"ema_model"`` mapping with an entry for every
    name in ``model.named_parameters()`` (a missing name raises KeyError).
    Buffers keep the values already in ``model``. The merged state dict is
    loaded strictly, in a single ``load_state_dict`` call.
    """
    checkpoint = torch.load(ema_checkpoint_path, map_location="cpu")
    ema_weights = checkpoint["ema_model"]
    # state_dict() returns a fresh mapping: replacing entries does not touch the
    # live parameters until load_state_dict copies them in. That copy also moves
    # the CPU tensors to the parameters' device and dtype, so no .cuda() is needed.
    merged = model.state_dict()
    for name, _ in model.named_parameters():
        merged[name] = ema_weights[name]
    model.load_state_dict(merged)
    del checkpoint, ema_weights, merged


def _load_z_stats(z_stats_path, device, dtype):
    """Load the latent statistics used to de-normalize sampled latents.

    Returns:
        ``(z_mean, z_std)``, each reshaped to ``(1, 1, -1)`` so that it
        broadcasts along the last axis of the sampled latents.

    Raises:
        ValueError: if ``z_stats_path`` is None.
    """
    if z_stats_path is None:
        raise ValueError("z_stats_path is required")
    stats = torch.load(z_stats_path, map_location="cpu")
    z_mean = stats["z_mean"].to(device).to(dtype).view(1, 1, -1)
    z_std = stats["z_std"].to(device).to(dtype).view(1, 1, -1)
    return z_mean, z_std


def main(args):
    """Sample audio for every clip in the test CSV and save one FLAC per clip.

    Steps:
        1. Load pre-extracted video features.
        2. Sample audio latents with the GNSE prior
           (``offline_sample_tokens_condition_merge_pca``).
        3. De-normalize the latents with the stored z statistics.
        4. Decode them with the stage-1 autoencoder.
        5. Write the waveforms into a folder whose name encodes the sampling
           settings.
    """
    # Setup PyTorch:
    torch.manual_seed(args.global_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]

    # Fail fast on missing latent statistics before any expensive model loading.
    z_mean, z_std = _load_z_stats(args.z_stats_path, device, precision)

    save_folder = _build_save_folder(args)

    test_dataset = ExtractedGameGenXwithPCA(
        csv_path=args.test_csv_path,
        premade_feature_dir_video=args.premade_feature_dir_video,
        pca_path=args.pca_path,
        video_fps=args.video_fps,
        start_time=args.start_time,
        duration=args.inf_duration,
        error_log_path=args.error_log_path,
        vision_aggregation=args.vision_aggregation,
        source=args.source,
        no_pca=args.no_pca,
    )

    # The aggregation conv downsamples both spatial axes by spatial_ds_rate.
    grid_feature_length = test_dataset.grid_feature_length // args.spatial_ds_rate**2

    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        drop_last=False,
        shuffle=False,
        pin_memory=True,
        collate_fn=custom_collate,
    )

    # Create the GPT-style prior; weights are loaded from the checkpoint below.
    model = GNSE_models[args.gnse_model](
        seq_len=int(args.train_duration * args.video_fps),
        num_kv_heads=args.num_kv_heads,
        multiple_of=args.multiple_of,
        per_frame_len=args.per_frame_len,
        ffn_dim_multiplier=args.ffn_dim_multiplier,
        max_period=args.max_period,
        num_types=args.num_types,
        input_type=args.input_type,
        type_drop_p=args.type_dropout_p,
        token_dropout_p=args.token_dropout_p,
        attn_dropout_p=args.attn_dropout_p,
        resid_dropout_p=args.resid_dropout_p,
        ffn_dropout_p=args.ffn_dropout_p,
        drop_path_rate=args.drop_path_rate,
        audio_embed_dim=args.audio_embed_dim,
        video_embed_dim=args.video_embed_dim,
        noise_augmentation=args.noise_augmentation,
        noise_aug_max=args.noise_aug_max,
        diffusion_batch_mul=args.diffusion_batch_mul,
        P_mean=args.P_mean,
        P_std=args.P_std,
        sigma_data=args.sigma_data,
        label_drop_prob=args.label_drop_prob,
        label_balance=args.label_balance,
        trans_cfg_dropout_prob=args.trans_cfg_dropout_prob,
        audio_pretraining=args.audio_pretraining,
        vision_aggregation=args.vision_aggregation,
        num_aggregated_tokens=args.num_aggregated_tokens,
        grid_feature_length=grid_feature_length,
        d_aggregate=args.d_aggregate,
        num_heads_aggregate=args.num_heads_aggregate,
        num_layers_aggregate=args.num_layers_aggregate,
        aggregate_trans_architecture=args.aggregate_trans_architecture,
        aggregation_method=args.aggregation_method,
        norm_type=args.norm_type,
        norm_type_agg=args.norm_type_agg,
        naive_mar_mlp=args.naive_mar_mlp,
        condition_merge=args.condition_merge,
        head_dropout_p=args.head_dropout_p,
        diffusion_type=args.diffusion_type,
        no_subtract=args.no_subtract,
        ect_q=args.ect_q,
        ect_c=args.ect_c,
        ect_k=args.ect_k,
        ect_b=args.ect_b,
    ).to(device, dtype=precision)

    checkpoint = torch.load(args.model_checkpoint, map_location="cpu")
    model_weight = _extract_model_state_dict(checkpoint)
    # strict=False tolerates key mismatches, so report them: Linear/LayerNorm
    # init is disabled at import time, and any weight missing from the
    # checkpoint would otherwise silently stay uninitialised.
    load_result = model.load_state_dict(model_weight, strict=False)
    if load_result.missing_keys:
        print(f"Warning: keys missing from prior checkpoint: {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"Warning: unexpected keys in prior checkpoint: {load_result.unexpected_keys}")
    del checkpoint, model_weight
    torch.cuda.empty_cache()
    print("Prior model is loaded")

    if args.ema_checkpoint is not None:
        _load_ema_weights(model, args.ema_checkpoint)
        torch.cuda.empty_cache()
        print("EMA params are loaded")
    model.eval().requires_grad_(False)

    # Stage-1 audio autoencoder, used only to decode the sampled latents.
    with open(args.stage1_model_config) as f:
        stage1_model_config = json.load(f)
    stage1_model: AudioAutoencoder = create_model_from_config(stage1_model_config)
    copy_state_dict(stage1_model, load_ckpt_state_dict(args.stage1_ckpt_path))
    stage1_model.to(device).eval().requires_grad_(False)

    os.makedirs(save_folder, exist_ok=True)
    print("Save to:", save_folder)

    if args.compile:
        print("compiling the model...")
        # NOTE(review): torch.compile returns an OptimizedModule that compiles
        # forward() only. Attribute access such as
        # model.offline_sample_tokens_condition_merge_pca is delegated to the
        # original module, so the sampling call below likely runs uncompiled.
        model = torch.compile(
            model,
            backend="inductor",
            mode="max-autotune",  # autotune GEMM/conv kernels, enable CUDA graphs
            fullgraph=True,       # error out instead of allowing graph breaks
            dynamic=False,        # specialise on static shapes to avoid recompiles
        )
        model.eval().requires_grad_(False)
    else:
        print("no need to compile model in demo")

    max_new_tokens = int(2 * args.inf_duration * args.video_fps - 1)
    with torch.inference_mode():
        for batch in test_dataloader:
            video = batch["video_feature"].to(precision).to(device)
            audio_latents = model.offline_sample_tokens_condition_merge_pca(
                cond=video,
                context_mode=args.context_mode,  # one of "pi", "ntk", "sliding", "none"
                max_new_tokens=max_new_tokens,
                vision_aggregation=args.vision_aggregation,
                audio_pretraining=args.audio_pretraining,
                trans_cfg_scale=args.trans_cfg_scale,
                cfg_interval=args.cfg_interval,
                cfg_schedule=args.cfg_schedule,
                diff_cfg_scale=args.diff_cfg_scale,
                temperature=args.temperature,
                num_steps=args.num_diffusion_steps,
                sigma_min=args.sigma_min,
                sigma_max=args.sigma_max,
                rho=args.rho,
                S_churn=args.S_churn,
                S_min=args.S_min,
                S_max=args.S_max,
                S_noise=args.S_noise,
                inference_noise=args.inference_noise,
            )
            # NOTE(review): the factor 2 presumably undoes a 0.5 scaling applied
            # to the normalized latents at training time; confirm against the
            # training script.
            denormalized_latents = 2 * z_std * audio_latents + z_mean
            # Swap the last two axes before decoding; presumably the decoder
            # expects channels before time. TODO confirm.
            waveform = stage1_model.decode(denormalized_latents.transpose(-2, -1)).float()

            for j, wav in enumerate(waveform):
                # torchaudio.save expects a 2-D (channels, samples) tensor.
                if wav.dim() == 1:
                    wav = wav.unsqueeze(0)
                elif wav.dim() == 3:
                    wav = wav.squeeze(0)
                elif wav.dim() == 4:
                    wav = wav.squeeze(0).squeeze(0)
                out_path = os.path.join(save_folder, f'{os.path.basename(batch["filename"][j])}.flac')
                torchaudio.save(out_path, wav.to('cpu'), format='flac', sample_rate=args.sample_rate)
            del video, audio_latents, denormalized_latents, waveform
            torch.cuda.empty_cache()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate audio from pre-extracted video features with a GNSE prior "
                    "and a stage-1 audio autoencoder."
    )
    # ---- data / IO ----
    parser.add_argument("--test-csv-path", type=str, required=True)
    parser.add_argument("--premade-feature_dir-video", type=str, required=True)
    parser.add_argument("--z-stats-path", type=str, required=True, help="latent statistics file (z_mean, z_std) used to de-normalize sampled latents")
    parser.add_argument("--results-dir", type=str, default="results")
    parser.add_argument("--gt-video-dir", type=str, default="/path", help="ground-truth video dir (not read by main())")
    parser.add_argument("--error-log-path", type=str, default='error_inference_data.txt', help="error log path in dataloder")
    parser.add_argument("--source", type=str, required=True, choices=['hdf5', 'npy'], help="choose test dataset format")
    parser.add_argument("--pca-path", type=str, help="path for precomputed pca")
    parser.add_argument("--no_pca", action='store_true')

    # ---- checkpoints / runtime ----
    parser.add_argument("--model-checkpoint", type=str, required=True, help="prior (GNSE) checkpoint path")
    parser.add_argument("--ema-checkpoint", type=str, default=None, help="optional checkpoint holding an 'ema_model' entry")
    parser.add_argument("--stage1-model-config", type=str, required=True, help="model config of stage1 model")
    parser.add_argument("--stage1-ckpt-path", type=str, required=True, help="ckpt path for stage1 model")
    parser.add_argument("--batch-size", type=int, default=12)
    parser.add_argument("--global-seed", type=int, default=0)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
    parser.add_argument("--compile", action='store_true', default=False)

    # ---- model architecture ----
    parser.add_argument("--gnse-model", type=str, choices=list(GNSE_models.keys()), default="GNSE-T")
    parser.add_argument("--audio-embed-dim", type=int, default=128, help="dimension for audio compression model")
    parser.add_argument("--video-embed-dim", type=int, default=1024, help="dimension for video feature")
    parser.add_argument("--per-frame-len", type=int, default=1)
    parser.add_argument("--video-fps", type=int, default=30, help="frame rate of video")
    parser.add_argument("--channels", type=int, default=2, help="audio channels")
    parser.add_argument("--start-time", type=float, default=0, help="start time of dataset")
    parser.add_argument("--train_duration", type=float, default=8, help="clip duration (s) used in training; sets the model's seq_len")
    parser.add_argument("--inf_duration", type=float, default=16, help="duration (s) of audio to generate at inference")
    parser.add_argument("--num-types", type=int, default=2)
    parser.add_argument("--input-type", type=str, choices=['interleave', 'concat'], default="interleave")
    parser.add_argument("--max-period", type=float, default=10_000)
    parser.add_argument("--num-kv-heads", type=int, default=None)
    parser.add_argument("--drop-path-rate", type=float, default=0.1, help="drop_path_rate of attention and ffn")
    parser.add_argument("--token-dropout-p", type=float, default=0., help="dropout_p of token_dropout_p")
    parser.add_argument("--type-dropout-p", type=float, default=0., help="dropout_p of type_dropout_p")
    parser.add_argument("--resid-dropout-p", type=float, default=0.1, help="dropout_p of resid_dropout_p")
    parser.add_argument("--ffn-dropout-p", type=float, default=0.1, help="dropout_p of ffn_dropout_p")
    parser.add_argument("--attn-dropout-p", type=float, default=0.1, help="dropout_p of attn_dropout_p")
    parser.add_argument("--ffn-dim-multiplier", type=float, default=None, help="ffn_dim_multiplier")
    parser.add_argument("--multiple-of", type=int, default=256, help="multiple_of")
    parser.add_argument("--head-dropout-p", type=float, default=0.1)
    parser.add_argument("--diffusion-batch-mul", type=int, default=4)
    parser.add_argument("--label-balance", type=float, default=0.5)
    parser.add_argument("--P-mean", type=float, default=-0.4)
    parser.add_argument("--P-std", type=float, default=1.0)
    parser.add_argument("--sigma-data", type=float, default=0.5)
    parser.add_argument("--label-drop-prob", type=float, default=0.1)
    parser.add_argument("--sample-rate", type=int, default=48000, help="sample rate of audio")
    parser.add_argument("--trans-cfg-dropout-prob", type=float, default=0.1)
    # main() passes these two to the model constructor. They were previously
    # undefined, which crashed main() with AttributeError.
    # NOTE(review): type/defaults chosen conservatively; confirm against the training script.
    parser.add_argument("--noise-augmentation", action='store_true', help="noise augmentation flag forwarded to the model constructor")
    parser.add_argument("--noise-aug-max", type=float, default=0.0, help="maximum noise-augmentation level forwarded to the model constructor")
    parser.add_argument("--audio-pretraining", action='store_true', help="flag for audio only pretraining")
    parser.add_argument("--vision-aggregation", action='store_true', help="flag for vision aggregation along with spatial axis.")
    parser.add_argument("--num-aggregated-tokens", type=int, default=1, help="number of aggregated tokens for vision aggregation")
    parser.add_argument("--d-aggregate", type=int, default=512, choices=[64, 128, 256, 384, 512, 768, 1024], help="number of feature dimension of attention blocks for vision aggregation")
    parser.add_argument("--num-heads-aggregate", type=int, default=8, help="number of heads of attention blocks for vision aggregation")
    parser.add_argument("--num-layers-aggregate", type=int, default=1, help="number of attention blocks for vision aggregation")
    parser.add_argument("--spatial-ds-rate", type=int, default=2, help="Downsampling rate of conv2d in vision aggregation")
    parser.add_argument("--aggregate-trans-architecture", type=str, default='sdpa', choices=['linear', 'sdpa'], help="Use linear transformer or sdpa for vision aggregation")
    parser.add_argument("--aggregation-method", type=str, default='self', choices=['self', 'cross'], help="select self or cross attention for vision aggregation")
    parser.add_argument("--norm-type-agg", type=str, default='rms', choices=['rms', 'dyt'], help="choose normalization type for aggregation transformer")
    parser.add_argument("--norm-type", type=str, default='rms', choices=['rms', 'dyt'], help="choose normalization type")
    parser.add_argument("--naive-mar-mlp", action='store_true', help="Use MAR's diffusion head architecture")
    parser.add_argument("--condition-merge", action='store_true', help="Merge different conditions into diffusion head.")
    parser.add_argument("--diffusion-type", type=str, default='diffusion', choices=['diffusion', 'trigflow', 'diffusion_v2', 'ect'], help="choose diffusion model type")
    parser.add_argument("--merge_video_and_audio", action='store_true')
    parser.add_argument("--no_subtract", action='store_true')
    parser.add_argument("--ect-q", type=float, default=4, help="4 for imagenet, 2 for cifar10, 256 for prototyping in ect paper")
    parser.add_argument("--ect-c", type=float, default=0.06, help="hyperparameters for continous-time scheduling in ect paper.")
    parser.add_argument("--ect-k", type=float, default=8, help="hyperparameters for continous-time scheduling in ect paper. 8 for all the settings.")
    parser.add_argument("--ect-b", type=float, default=1, help="hyperparameters for continous-time scheduling in ect paper. 1 for all the settings.")
    parser.add_argument("--ect-d-rate", type=int, default=4, choices=[4, 8], help="4 for imagenet, 8 for cifar10 in ect paper")
    parser.add_argument("--train-transformer", action='store_true', help="Train transformer part with consistency training/distillation.")

    # ---- sampling ----
    parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
    parser.add_argument("--trans-cfg-scale", type=float, default=1.0, help="CFG for transformer")
    parser.add_argument("--cfg-interval", type=int, default=-1, help="CFG interval for transformer")
    parser.add_argument("--cfg-schedule", type=str, default="constant", choices=["constant", "linear"], help="CFG scheduler for diffusion head with time axis")
    parser.add_argument("--diff-cfg-scale", type=float, default=1.0, help="CFG for diffusion head")
    parser.add_argument("--num-diffusion-steps", type=int, default=20, help="number of sampling steps for diffusion")
    parser.add_argument("--sigma-min", type=float, default=0.002, help="sigma_min")
    parser.add_argument("--sigma-max", type=float, default=80.0, help="sigma_max")
    parser.add_argument("--rho", type=float, default=7.0, help="rho")
    # Stochastic-sampler knobs accept fractional values (e.g. S_noise=1.003);
    # the defaults are unchanged.
    parser.add_argument("--S-churn", type=float, default=0, help="S_churn")
    parser.add_argument("--S-min", type=float, default=0, help="S_min")
    parser.add_argument("--S-max", type=float, default=float('inf'), help="S_max")
    parser.add_argument("--S-noise", type=float, default=1, help="S_noise")
    # main() reads args.inference_noise (sampler argument and output folder
    # name); it was previously undefined.
    # NOTE(review): type/default assumed; confirm against the sampler's signature.
    parser.add_argument("--inference-noise", type=float, default=0.0, help="inference_noise passed to the sampler")
    parser.add_argument("--context-mode", type=str, default='none', choices=['pi', 'ntk', 'sliding', 'none'], help="choose context mode")
    args = parser.parse_args()
    main(args)
    
    