import torch
import json
from models.sound_engine import GNSE_models
# causal dec
from modules_2.stable_audio_tools.models import create_model_from_config
from modules_2.stable_audio_tools.models.autoencoders import AudioAutoencoder
from modules_2.stable_audio_tools.models.utils import load_ckpt_state_dict
from modules_2.stable_audio_tools.utils.torch_common import copy_state_dict
from models.waveform_streaming import convert_decoder_to_streaming_inplace
from torch.utils.data import default_collate
from models.vision_streaming import DINOv2StreamingEncoder
from data.video_stream_dataset import OnlineVideoDataset_2
import os
import copy
import numpy as np
import torchaudio

import argparse
import subprocess
import time
import glob

# Global kernel / numeric-precision settings for inference. Note that main()
# later sets cudnn.benchmark = False and cudnn.deterministic = True, which
# overrides the benchmark flag set here.
# enable_flash_sdp is a function; assigning `= True` would replace it with a
# bool instead of enabling the flash scaled-dot-product-attention backend.
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('high')

# Skip default weight initialisation for every Linear / LayerNorm built in this
# process (presumably to speed up model construction, since weights are loaded
# from checkpoints afterwards). Any parameter NOT overwritten by a checkpoint
# keeps uninitialised (torch.empty) contents.
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)


def custom_collate(batch):
    """Collate a batch after discarding samples the dataset returned as ``None``.

    Args:
        batch: list of samples produced by the dataset; ``None`` entries are dropped.

    Returns:
        The result of ``default_collate`` on the remaining samples.

    Raises:
        ValueError: if every sample in ``batch`` was ``None``.
    """
    kept = list(filter(lambda item: item is not None, batch))
    if not kept:
        raise ValueError("No valid samples in batch")
    return default_collate(kept)


# Upper bound on the number of dataloader batches evaluated per run.
_MAX_EVAL_BATCHES = 56
# Leading per-batch latency measurements dropped as warm-up before summarising.
_WARMUP_BATCHES = 6


def _build_save_folder(args):
    """Return the base output directory, named after the sampling hyper-parameters."""
    run_name = "seed{}_token_seq{}-diffsteps{}-temp{}-trans_cfg{}_{}diff_cfg{}_inference_noise{}".format(
        args.global_seed,
        int(2 * args.inf_duration * args.video_fps),
        args.num_diffusion_steps,
        args.temperature,
        args.trans_cfg_scale,
        args.cfg_schedule,
        args.diff_cfg_scale,
        args.inference_noise,
    )
    return os.path.join(args.results_dir, args.gnse_model, run_name)


def _load_prior_weights(model, checkpoint_path):
    """Load the GNSE ("prior") weights from a checkpoint into ``model``.

    The state dict is looked up under "model" (DDP), "module" (DeepSpeed) or
    "state_dict", in that order, and loaded with ``strict=False``. Missing and
    unexpected keys are printed rather than silently ignored.

    Raises:
        ValueError: if none of the known state-dict keys is present.
    """
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    for key in ("model", "module", "state_dict"):
        if key in checkpoint:
            model_weight = checkpoint[key]
            break
    else:
        raise ValueError(
            f"please check model weight: {checkpoint_path} has none of the keys "
            f"'model', 'module', 'state_dict' (found: {list(checkpoint)[:10]})"
        )
    incompatible = model.load_state_dict(model_weight, strict=False)
    # With reset_parameters disabled at module level, missing keys would stay
    # uninitialised, so surface them.
    if incompatible.missing_keys:
        print(f"[warn] {len(incompatible.missing_keys)} missing keys in prior checkpoint, "
              f"e.g. {incompatible.missing_keys[:5]}")
    if incompatible.unexpected_keys:
        print(f"[warn] {len(incompatible.unexpected_keys)} unexpected keys in prior checkpoint, "
              f"e.g. {incompatible.unexpected_keys[:5]}")
    del checkpoint, model_weight
    torch.cuda.empty_cache()
    print("Prior model is loaded")


def _load_ema_weights(model, checkpoint_path):
    """Replace every parameter of ``model`` with its value from ``checkpoint['ema_model']``.

    Buffers keep their current values. The merged state dict is loaded once,
    strictly. ``load_state_dict`` copies across devices itself, so the EMA
    tensors stay on CPU.
    """
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    ema_source = checkpoint["ema_model"]
    # Mutating the fresh OrderedDict from state_dict() only swaps references;
    # it does not write into the model's tensors (and keeps its _metadata).
    merged = model.state_dict()
    for name, _ in model.named_parameters():
        merged[name] = ema_source[name]
    model.load_state_dict(merged)
    del checkpoint, ema_source, merged
    torch.cuda.empty_cache()
    print("EMA params are loaded")


def _report_latency(frame_ms, warmup=_WARMUP_BATCHES):
    """Print mean / variance / std of per-batch mean frame latency after dropping warm-up.

    Args:
        frame_ms: per-batch mean frame latency values in milliseconds.
        warmup: number of leading entries to ignore.
    """
    trimmed = np.asarray(frame_ms, dtype=np.float64)[warmup:]
    if trimmed.size == 0:
        print("frame_ms: not enough samples")
        return
    mean = trimmed.mean()
    # Sample variance (ddof=1) needs at least two points.
    var = trimmed.var(ddof=1) if trimmed.size > 1 else 0.0
    std = var ** 0.5
    print(f"{'frame_ms':>9}: mean={mean:.2f} ms, var={var:.2f}, std={std:.2f} (n={trimmed.size})")


def main(args):
    """Run streaming video-to-audio generation on a CSV test set and report latency.

    Builds the test dataloader, the GNSE model (weights from
    ``args.model_checkpoint``, optionally EMA weights), the DINOv2 streaming
    encoder and the stage-1 autoencoder with a streaming decoder. It then calls
    ``model.online_sample_tokens_from_video`` on at most ``_MAX_EVAL_BATCHES``
    batches. Each returned waveform is written as FLAC, and per-frame latency
    statistics are printed at the end.

    Raises:
        ValueError: if ``args.z_stats_path`` is None or the checkpoint layout is unknown.
    """
    # Setup PyTorch. Deterministic cuDNN overrides the module-level benchmark=True.
    torch.manual_seed(args.global_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Fail fast, before any heavy model loading.
    if args.z_stats_path is None:
        raise ValueError("z_stats_path is required")

    save_folder = _build_save_folder(args)

    # NOTE(review): args.error_log_path is parsed but "/log" is hard-coded here; confirm which is intended.
    test_dataset = OnlineVideoDataset_2.from_csv(
        csv_path=args.test_csv_path,
        column="video_folder",
        start_time=args.start_time,
        duration_sec=args.inf_duration,
        fps=args.video_fps,
        resize_hw=(480, 854),
        drop_short=False,
        error_log_path="/log",
        exist_filter=True,
    )

    # 540 is the grid feature length previously used for ogamedata (336 for
    # vggsound). It is divided by the conv2d downsampling of the vision aggregation.
    grid_feature_length = 540 // args.spatial_ds_rate**2

    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        drop_last=False,
        shuffle=False,
        pin_memory=True,
        collate_fn=custom_collate,
    )

    # Create the GNSE model and load its weights.
    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
    model = GNSE_models[args.gnse_model](
        seq_len = int(args.train_duration * args.video_fps),
        num_kv_heads = args.num_kv_heads,
        multiple_of = args.multiple_of,
        per_frame_len = args.per_frame_len,
        ffn_dim_multiplier = args.ffn_dim_multiplier,
        max_period = args.max_period,
        num_types = args.num_types,
        input_type = args.input_type,
        type_drop_p = args.type_dropout_p,
        token_dropout_p = args.token_dropout_p,
        attn_dropout_p = args.attn_dropout_p,
        resid_dropout_p = args.resid_dropout_p,
        ffn_dropout_p = args.ffn_dropout_p,
        drop_path_rate = args.drop_path_rate,
        audio_embed_dim = args.audio_embed_dim,
        video_embed_dim = args.video_embed_dim,
        noise_augmentation = args.noise_augmentation,
        noise_aug_max = args.noise_aug_max,
        diffusion_batch_mul = args.diffusion_batch_mul,
        P_mean = args.P_mean,
        P_std = args.P_std,
        sigma_data = args.sigma_data,
        label_drop_prob = args.label_drop_prob,
        label_balance = args.label_balance,
        trans_cfg_dropout_prob = args.trans_cfg_dropout_prob,
        audio_pretraining = args.audio_pretraining,
        vision_aggregation = args.vision_aggregation,
        num_aggregated_tokens = args.num_aggregated_tokens,
        grid_feature_length = grid_feature_length,
        d_aggregate = args.d_aggregate,
        num_heads_aggregate = args.num_heads_aggregate,
        num_layers_aggregate = args.num_layers_aggregate,
        aggregate_trans_architecture = args.aggregate_trans_architecture,
        aggregation_method = args.aggregation_method,
        norm_type=args.norm_type,
        norm_type_agg=args.norm_type_agg,
        naive_mar_mlp = args.naive_mar_mlp,
        condition_merge = args.condition_merge,
        head_dropout_p = args.head_dropout_p,
        diffusion_type = args.diffusion_type,
        no_subtract = args.no_subtract,
        ect_q = args.ect_q,
        ect_c = args.ect_c,
        ect_k = args.ect_k,
        ect_b = args.ect_b,
    ).to(device, dtype=precision)

    dino = DINOv2StreamingEncoder(
        pca_path=args.pca_path,
        model_name="dinov2_vits14_reg",
        device=device, dtype=precision,
        resize_grid_factor=2,  # hard coded for 854x480 input
        maintain_aspect=True,
    ).eval().requires_grad_(False)

    _load_prior_weights(model, args.model_checkpoint)
    if args.ema_checkpoint is not None:
        _load_ema_weights(model, args.ema_checkpoint)
        save_folder = save_folder + "_ema"
    model.eval().requires_grad_(False)

    # Stage-1 audio autoencoder; its decoder is converted in place to streaming
    # mode using the transformer's parameter dtype.
    with open(args.stage1_model_config) as f:
        stage1_model_config = json.load(f)
    stage1_model: AudioAutoencoder = create_model_from_config(stage1_model_config)
    copy_state_dict(stage1_model, load_ckpt_state_dict(args.stage1_ckpt_path))
    stage1_model.to(device).eval().requires_grad_(False)
    dec_dtype = next(model.transformer.parameters()).dtype
    convert_decoder_to_streaming_inplace(stage1_model, device=device, dtype=dec_dtype)
    # Re-apply after conversion, which may have replaced submodules.
    stage1_model.to(device).eval().requires_grad_(False)

    save_folder = save_folder + "_evaluate"
    os.makedirs(save_folder, exist_ok=True)
    print("Save to:", save_folder)

    if args.compile:
        # NOTE(review): torch.compile wraps forward(); online_sample_tokens_from_video
        # is reached via attribute forwarding and is presumably not compiled. Verify.
        print("compiling the model...")
        model = torch.compile(
            model,
            backend="inductor",
            mode="reduce-overhead",
            fullgraph=False,
            dynamic=True,
        )
        model.eval().requires_grad_(False)
    else:
        print("no need to compile model in demo")

    # Latent normalisation statistics, reshaped to (1, 1, D) for broadcasting.
    stats = torch.load(args.z_stats_path, map_location="cpu")
    z_mean = stats["z_mean"].to(device=device, dtype=precision).view(1, 1, -1)
    z_std = stats["z_std"].to(device=device, dtype=precision).view(1, 1, -1)

    acc = {"frame_ms": []}
    with torch.inference_mode():
        for i, batch in enumerate(test_dataloader):
            if i >= _MAX_EVAL_BATCHES:
                break
            frames_u8 = batch["video_frames"]
            # NOTE: vision_aggregation is hard-coded True here, independent of args.vision_aggregation.
            audio_latents, incremental_waveform, lat = model.online_sample_tokens_from_video(
                frames_u8, dino,
                stage1_model=stage1_model,
                z_mean=z_mean,
                z_std=z_std,
                device=device,
                noncausal_right_margin_latents=0,
                context_mode=args.context_mode,
                measure_latency=True,
                vision_aggregation=True,
                audio_pretraining=args.audio_pretraining,
                trans_cfg_scale=args.trans_cfg_scale,
                cfg_schedule=args.cfg_schedule,
                diff_cfg_scale=args.diff_cfg_scale,
                temperature=args.temperature,
                num_steps=args.num_diffusion_steps,
                sigma_min=args.sigma_min,
                sigma_max=args.sigma_max,
                rho=args.rho,
                S_churn=args.S_churn,
                S_min=args.S_min,
                S_max=args.S_max,
                S_noise=args.S_noise,
                inference_noise=args.inference_noise,
            )
            if lat is not None and "frame_ms" in lat and len(lat["frame_ms"]) > 0:
                mean_frame = float(np.mean(lat["frame_ms"]))
                print(f"[sample {i}] per-frame mean (ms): frame={mean_frame:.2f}")
                acc["frame_ms"].append(mean_frame)
            else:
                print(f"[sample {i}] no frame_ms recorded")
            # Normalise each waveform to 2-D before writing it as FLAC.
            for j, wav in enumerate(incremental_waveform):
                wav = wav.float()
                if wav.dim() == 1:
                    wav = wav.unsqueeze(0)
                elif wav.dim() == 3:
                    wav = wav.squeeze(0)
                elif wav.dim() == 4:
                    wav = wav.squeeze(0).squeeze(0)
                out_path = os.path.join(save_folder, f'{os.path.basename(batch["filename"][j])}.flac')
                torchaudio.save(out_path, wav.to('cpu'), format='flac', sample_rate=args.sample_rate)

    _report_latency(acc["frame_ms"])
def build_arg_parser(gnse_model_names=None):
    """Build the command-line parser for this evaluation script.

    Args:
        gnse_model_names: allowed values for ``--gnse-model``. Defaults to the
            keys of ``GNSE_models``.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    if gnse_model_names is None:
        gnse_model_names = list(GNSE_models.keys())
    parser = argparse.ArgumentParser()
    parser.add_argument("--test-csv-path", type=str, required=True)
    # Not read by main() in this script; kept optional for CLI compatibility.
    parser.add_argument("--premade-feature_dir-video", type=str, default=None)
    parser.add_argument("--z-stats-path", type=str, required=True, help="z_stats path for training")
    parser.add_argument("--results-dir", type=str, default="results")
    parser.add_argument("--gt-video-dir", type=str, default="/home2")  # not read by main()
    parser.add_argument("--error-log-path", type=str, default='error_inference_data.txt', help="error log path in dataloader")
    parser.add_argument("--audio-embed-dim", type=int, default=128, help="dimension for audio compression model")
    parser.add_argument("--video-embed-dim", type=int, default=1024, help="dimension for video feature")
    parser.add_argument("--model-checkpoint", type=str, default=None, help="ckpt path")
    parser.add_argument("--ema-checkpoint", type=str, default=None, help="ckpt path")
    parser.add_argument("--batch-size", type=int, default=12)
    parser.add_argument("--global-seed", type=int, default=0)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
    parser.add_argument("--compile", action='store_true', default=False)
    parser.add_argument("--gnse-model", type=str, choices=gnse_model_names, default="GNSE-T")
    parser.add_argument("--stage1-model-config", type=str, required=True, help="model config of stage1 model")
    parser.add_argument("--stage1-ckpt-path", type=str, required=True, help="ckpt path for stage1 model")
    parser.add_argument("--per-frame-len", type=int, default=1)
    parser.add_argument("--video-fps", type=int, default=30, help="frame rate of video")
    parser.add_argument("--channels", type=int, default=2, help="audio channels")  # not read by main()
    parser.add_argument("--start-time", type=float, default=0, help="start time of dataset")
    parser.add_argument("--train_duration", type=float, default=8, help="duration of dataset")
    parser.add_argument("--inf_duration", type=float, default=16, help="duration of dataset")
    parser.add_argument("--num-types", type=int, default=2)
    parser.add_argument("--input-type", type=str, choices=['interleave', 'concat'], default="interleave")
    parser.add_argument("--max-period", type=float, default=10_000)
    parser.add_argument("--num-kv-heads", type=int, default=None)
    parser.add_argument("--drop-path-rate", type=float, default=0.1, help="drop_path_rate of attention and ffn")
    parser.add_argument("--token-dropout-p", type=float, default=0., help="dropout_p of token_dropout_p")
    parser.add_argument("--type-dropout-p", type=float, default=0., help="dropout_p of type_dropout_p")
    parser.add_argument("--resid-dropout-p", type=float, default=0.1, help="dropout_p of resid_dropout_p")
    parser.add_argument("--ffn-dropout-p", type=float, default=0.1, help="dropout_p of ffn_dropout_p")
    parser.add_argument("--attn-dropout-p", type=float, default=0.1, help="dropout_p of attn_dropout_p")
    parser.add_argument("--ffn-dim-multiplier", type=float, default=None, help="ffn_dim_multiplier")
    parser.add_argument("--multiple-of", type=int, default=256, help="multiple_of")
    parser.add_argument("--head-dropout-p", type=float, default=0.1)
    parser.add_argument("--diffusion-batch-mul", type=int, default=4)
    parser.add_argument("--label-balance", type=float, default=0.5)
    parser.add_argument("--P-mean", type=float, default=-0.4)
    parser.add_argument("--P-std", type=float, default=1.0)
    parser.add_argument("--sigma-data", type=float, default=0.5)
    parser.add_argument("--label-drop-prob", type=float, default=0.1)
    parser.add_argument("--sample-rate", type=int, default=48000, help="sample rate of audio")
    parser.add_argument("--trans-cfg-dropout-prob", type=float, default=0.1)
    parser.add_argument("--noise-augmentation", action='store_true')
    parser.add_argument("--noise-aug-max", type=float, default=0.5)
    parser.add_argument("--inference-noise", type=float, default=0.1)
    parser.add_argument("--audio-pretraining", action='store_true', help="flag for audio only pretraining")
    parser.add_argument("--vision-aggregation", action='store_true', help="flag for vision aggregation along with spatial axis.")
    parser.add_argument("--num-aggregated-tokens", type=int, default=1, help="number of aggregated tokens for vision aggregation")
    parser.add_argument("--d-aggregate", type=int, default=512, choices=[64, 128, 256, 384, 512, 768, 1024], help="number of feature dimension of attention blocks for vision aggregation")
    parser.add_argument("--num-heads-aggregate", type=int, default=8, help="number of heads of attention blocks for vision aggregation")
    parser.add_argument("--num-layers-aggregate", type=int, default=1, help="number of attention blocks for vision aggregation")
    parser.add_argument("--spatial-ds-rate", type=int, default=2, help="Downsampling rate of conv2d in vision aggregation")
    parser.add_argument("--aggregate-trans-architecture", type=str, default='sdpa', choices=['linear', 'sdpa'], help="Use linear transformer or sdpa for vision aggregation")
    parser.add_argument("--aggregation-method", type=str, default='self', choices=['self', 'cross'], help="select self or cross attention for vision aggregation")
    parser.add_argument("--norm-type-agg", type=str, default='rms', choices=['rms', 'dyt'], help="choose normalization type for aggregation transformer")
    parser.add_argument("--norm-type", type=str, default='rms', choices=['rms', 'dyt'], help="choose normalization type")
    parser.add_argument("--naive-mar-mlp", action='store_true', help="Use MAR's diffusion head architecture")
    parser.add_argument("--condition-merge", action='store_true', help="Merge different conditions into diffusion head.")
    parser.add_argument("--diffusion-type", type=str, default='diffusion', choices=['diffusion', 'trigflow', 'diffusion_v2', 'ect'], help="choose diffusion model type")

    # Sampling hyper-parameters.
    parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
    parser.add_argument("--trans-cfg-scale", type=float, default=1.0, help="CFG for transformer")
    parser.add_argument("--cfg-interval", type=int, default=-1, help="CFG interval for transformer")  # not read by main()
    parser.add_argument("--cfg-schedule", type=str, default="constant", choices=["constant", "linear"], help="CFG scheduler for diffusion head with time axis")
    parser.add_argument("--diff-cfg-scale", type=float, default=1.0, help="CFG for diffusion head")
    parser.add_argument("--num-diffusion-steps", type=int, default=20, help="number of sampling steps for diffusion")
    parser.add_argument("--sigma-min", type=float, default=0.002, help="sigma_min")
    parser.add_argument("--sigma-max", type=float, default=80.0, help="sigma_max")
    parser.add_argument("--rho", type=float, default=7.0, help="rho")
    # float (was int) so fractional values such as --S-noise 1.003 are accepted.
    parser.add_argument("--S-churn", type=float, default=0, help="S_churn")
    parser.add_argument("--S-min", type=float, default=0, help="S_min")
    parser.add_argument("--S-max", type=float, default=float('inf'), help="S_max")
    parser.add_argument("--S-noise", type=float, default=1, help="S_noise")
    # Not read by main(); no longer required (it already had a default).
    parser.add_argument("--source", type=str, default='hdf5', choices=['hdf5', 'npy'], help="choose test dataset format")

    parser.add_argument("--pca-path", type=str, help="path for precomputed pca")
    parser.add_argument("--merge_video_and_audio", action='store_true')  # not read by main()
    parser.add_argument("--no_pca", action='store_true')  # not read by main()
    parser.add_argument("--no_subtract", action='store_true')
    parser.add_argument("--ect-q", type=float, default=4, help="4 for imagenet, 2 for cifar10, 256 for prototyping in ect paper")
    parser.add_argument("--ect-c", type=float, default=0.06, help="hyperparameters for continuous-time scheduling in ect paper.")
    parser.add_argument("--ect-k", type=float, default=8, help="hyperparameters for continuous-time scheduling in ect paper. 8 for all the settings.")
    parser.add_argument("--ect-b", type=float, default=1, help="hyperparameters for continuous-time scheduling in ect paper. 1 for all the settings.")
    parser.add_argument("--ect-d-rate", type=int, default=4, choices=[4, 8], help="4 for imagenet, 8 for cifar10 in ect paper")  # not read by main()
    parser.add_argument("--train-transformer", action='store_true', help="Train transformer part with consistency training/distillation.")  # not read by main()
    parser.add_argument("--context-mode", type=str, default='none', choices=['pi', 'ntk', 'sliding', 'none'], help="choose context mode")
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    main(args)