import os
import copy
import json
import time
import torch
import numpy as np
import random
import argparse
import soundfile as sf
import wandb
import matplotlib.pyplot as plt
import csv
import auraloss
from tqdm import tqdm
from accelerate.utils import set_seed
import torch.nn as nn
# from diffusers import DDPMScheduler
from audioldm_eval_dev import EvaluationHelper
from tango_edm.models_edm import build_pretrained_models
# from transformers import AutoProcessor, ClapModel
import torchaudio
from loss_guided_diffusion.operator import Declipping, Intensity, BWE, StyleGram
import tools.torch_tools as torch_tools
from cm.script_util_v4_invp import (
    create_model_and_diffusion,
)
from cm.karras_diffusion_v3_invp import karras_sample_guided



class dotdict(dict):
    """A ``dict`` whose keys can also be used as attributes.

    ``d.key`` reads ``d["key"]`` and yields ``None`` when the key is absent
    (no ``AttributeError``); ``d.key = v`` stores into the dict, and
    ``del d.key`` removes the key (raising ``KeyError`` if it is missing).
    """

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails; mirrors dict.get.
        return dict.get(self, key)

    def __setattr__(self, key, value):
        dict.__setitem__(self, key, value)

    def __delattr__(self, key):
        dict.__delitem__(self, key)
    
def chunks(lst, n):
    """Lazily yield consecutive slices of ``lst`` holding ``n`` items each.

    The final slice is shorter when ``len(lst)`` is not a multiple of ``n``.
    """
    yield from (lst[start:start + n] for start in range(0, len(lst), n))

def rand_fix(seed):
    """Seed the ``random``, NumPy and torch RNGs and make cuDNN deterministic."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    cudnn = torch.backends.cudnn
    # Disable autotuning and request deterministic kernels for reproducibility.
    cudnn.benchmark = False
    cudnn.deterministic = True
    

def normalize_wav(waveform):
    """Remove the mean and rescale so the absolute peak is (about) 0.5.

    A tiny epsilon in the divisor keeps an all-zero input at zero.
    """
    centered = waveform - waveform.mean()
    peak = centered.abs().max()
    return 0.5 * (centered / (peak + 1e-8))

def pad_wav(waveform, segment_length):
    """Zero-pad or truncate a 1-D waveform to ``segment_length`` samples.

    Returns the input object unchanged when ``segment_length`` is ``None`` or
    already equals the waveform length.
    """
    if segment_length is None:
        return waveform
    shortfall = segment_length - len(waveform)
    if shortfall == 0:
        return waveform
    if shortfall < 0:
        return waveform[:segment_length]
    silence = torch.zeros(shortfall).to(waveform.device)
    return torch.cat([waveform, silence])

def read_wav_file(filename, segment_length, new_freq=16000):
    """Load an audio file as a normalised, fixed-length single-channel waveform.

    The file is resampled to ``new_freq`` and only its first channel is kept.
    The signal is mean-centred and peak-normalised via ``normalize_wav``,
    padded/trimmed with ``pad_wav``, then peak-scaled to 0.5 again (trimming
    may have removed the previous peak).

    Args:
        filename: path passed to ``torchaudio.load``.
        segment_length: target number of samples, or ``None`` to keep the
            resampled length.
        new_freq: sample rate to resample to.

    Returns:
        Tensor of shape ``(1, segment_length)`` (``(1, num_samples)`` when
        ``segment_length`` is ``None``).
    """
    waveform, sr = torchaudio.load(filename)
    waveform = torchaudio.functional.resample(
        waveform,
        orig_freq=sr,
        new_freq=new_freq,
        rolloff=0.95,
        resampling_method='sinc_interp_kaiser',
    )[0]
    try:
        waveform = normalize_wav(waveform)
    except Exception:
        # Presumably triggered by degenerate input (torch.max raises on an
        # empty tensor); fall back to a constant placeholder signal.
        print("Exception normalizing:", filename)
        waveform = torch.ones(160000)
    waveform = pad_wav(waveform, segment_length).unsqueeze(0)
    # Epsilon matches normalize_wav: an all-silent clip stays zero instead of
    # turning into NaN (0 / 0).
    waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
    return 0.5 * waveform
import torch

def curve(T: int, start=-15, end=-45, steepness=6.0):
    """Return a length-``T`` sigmoid ramp going from about ``start`` to about ``end``.

    The logistic function is sampled on ``linspace(-steepness, steepness, T)``,
    so the endpoints only approach ``start``/``end``. With the default
    steepness of 6 they fall within ~0.25% of ``end - start`` of them. Larger
    values give a sharper transition that gets closer to the endpoints.

    Args:
        T: number of points.
        start: value approached at the first point.
        end: value approached at the last point.
        steepness: half-width of the sampled logistic domain.

    Returns:
        1-D float tensor of length ``T``.
    """
    t = torch.linspace(-steepness, steepness, T)
    return start + (end - start) * torch.sigmoid(t)

def parse_args():
    parser = argparse.ArgumentParser(description="Inference for text to audio generation task.")
    parser.add_argument(
        "--training_args", type=str, default=None,
        help="Path for summary jsonl file saved during training."
    )
    parser.add_argument(
        "--output_dir", type=str, default=None,
        help="Where to store the output."
    )
    parser.add_argument(
        "--seed", type=int, default=5031, #43
        help="A seed for reproducible training."
    )
    parser.add_argument(
        "--text_encoder_name", type=str, default="google/flan-t5-large",
        help="Text encoder identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--ctm_unet_model_config", type=str, default=None,
        help="UNet model config json path.",
    )
    parser.add_argument(
        "--model", type=str, default=None,
        help="Path for saved model bin file."
    )
    parser.add_argument(
        "--ema_model", type=str, default=None,
        help="Path for saved EMA model bin file."
    )
    parser.add_argument(
        "--sampler", type=str, default='exact_cfg_guided',
        help="Inference sampling methods. [exact_cfg_guided, cm_multistep_cfg_guided, gamma_multistep_cfg_guided]"
    )
    parser.add_argument(
        "--sampling_gamma", type=float, default=0.9,
        help="\gamma for CTM's gamma-sampling"
    )
    parser.add_argument(
        "--test_file", type=str, default="data/test_audiocaps_subset.json",
        help="json file containing the test prompts for generation."
    )
    # parser.add_argument(
    #     "--text_key", type=str, default="caption",
    #     help="Key containing the text in the json file."
    # )
    parser.add_argument(
        "--test_references", type=str, default="data/audiocaps_test_references/subset",
        help="Folder containing the test reference wav files."
    )
    parser.add_argument(
        "--num_steps", type=int, default=200,
        help="How many denoising steps for generation.",
    )
    parser.add_argument(
        "--guidance", type=float, default=3,
        help="Guidance scale for classifier free guidance."
    )
    parser.add_argument(
        "--batch_size", type=int, default=8,
        help="Batch size for generation.",
    )
    parser.add_argument(
        "--num_samples", type=int, default=1,
        help="How many samples per prompt.",
    )
    parser.add_argument(
        "--sigma_data", type=float, default=0.25,
        help="Sigma data",
    )
    parser.add_argument(
        "--stocastic", type=bool, default=False,
        help="Enable stocastic sampling of EDM Heun sampler",
    )
    parser.add_argument(
        "--prefix", type=str, default=None,
        help="Add prefix in text prompts.",
    )
    parser.add_argument(
        "--stage1_ckpt", type=str, default='ckpt/audioldm-s-full.ckpt',
        help="Path for stage1 model's checkpoint",
    )
    parser.add_argument(
        "--sdr", type=float, default=0.15,
    )
    parser.add_argument(
        "--save_dir", type=str, default="/path/to/"
    )
    parser.add_argument(
        "--sdedit", type=bool, default=False,
    )
    parser.add_argument(
        "--loss_guided_sampling", type=bool, default=False,
    )
    parser.add_argument(
        "--initial_index", type=int, default=40,
    )
    parser.add_argument(
        "--fs_start", type=float, default=0.20,
    )
    parser.add_argument(
        "--fs_end", type=float, default=0.70,
    )
    
    parser.add_argument(
        "--optimization_steps", type=int, default=200,
    )
    parser.add_argument(
        "--task", type=str, default='declip',
    )
    parser.add_argument(
        "--curve_shape", type=str, default='linear-down',
    )
    parser.add_argument(
        "--generated_length", type=int, default=160000,
    )
    
    args = parser.parse_args()

    return args

# Curve shapes accepted by --curve_shape for --task intensity_control.
_INTENSITY_CURVE_SHAPES = (
    'linear-down', 'linear-up', 'linear-down-up',
    'linear-up-down', 'curve-up', 'curve-down',
)

# Context window handed to the Intensity operator; half of it is trimmed from
# each end of the plotted result curve.
_INTENSITY_CTX_WINDOW = 28


def _target_intensity_curve(curve_shape, length):
    """Return the target curve of ``length`` points for ``curve_shape``.

    Targets move between -15 and -40 (units depend on the Intensity operator;
    presumably dB — confirm). NOTE(review): the 'linear-down'/'linear-up'
    targets are shaped ``(1, length)`` while the others are 1-D ``(length,)``;
    this is kept as-is, so consumers must accept both.

    Raises:
        ValueError: if ``curve_shape`` is not in ``_INTENSITY_CURVE_SHAPES``.
    """
    if curve_shape == 'linear-down':
        return torch.linspace(-15, -40, length).unsqueeze(0)
    if curve_shape == 'linear-up':
        return torch.linspace(-40, -15, length).unsqueeze(0)
    if curve_shape == 'linear-down-up':
        return torch.cat([torch.linspace(-15, -40, length // 2),
                          torch.linspace(-40, -15, (length + 1) // 2)])
    if curve_shape == 'linear-up-down':
        return torch.cat([torch.linspace(-40, -15, length // 2),
                          torch.linspace(-15, -40, (length + 1) // 2)])
    if curve_shape == 'curve-up':
        return curve(length, -40, -15)
    if curve_shape == 'curve-down':
        return curve(length, -15, -40)
    raise ValueError(f"Unknown curve_shape {curve_shape!r}; expected one of {_INTENSITY_CURVE_SHAPES}")


def _read_test_csv(path):
    """Read ``(captions, file_names)`` from a CSV with 'caption' and 'file_name' columns."""
    with open(path, mode='r', encoding='utf-8', newline='') as file:
        rows = list(csv.DictReader(file))
    return [row['caption'] for row in rows], [row['file_name'] for row in rows]


def _load_models(args, train_args, device):
    """Build the frozen stage-1 VAE/STFT and the CTM model, then load its weights.

    The full model state comes from ``args.model``. Every ``ctm_unet``
    parameter is then overwritten with the matching entry of ``args.ema_model``.

    Returns:
        ``(vae, stft, model, diffusion)``, all modules in eval mode on ``device``.

    Raises:
        KeyError: if a ctm_unet parameter is missing from its own state dict or
            from the EMA checkpoint.
    """
    vae, stft = build_pretrained_models("audioldm-s-full", args.stage1_ckpt)
    vae, stft = vae.to(device), stft.to(device)
    for module in (vae, stft):
        module.requires_grad_(False)
        module.eval()

    model, diffusion = create_model_and_diffusion(train_args, teacher=False)
    model.requires_grad_(False)
    model.to(device)
    model.eval()

    # Load onto CPU so GPU-saved checkpoints also work on CPU-only hosts;
    # load_state_dict copies the values onto the model's device.
    model.load_state_dict(torch.load(args.model, map_location="cpu"))
    ema_ckpt = torch.load(args.ema_model, map_location="cpu")
    state_dict = model.ctm_unet.state_dict()
    for param_name, _ in model.ctm_unet.named_parameters():
        if param_name not in state_dict:
            raise KeyError(f"Parameter {param_name!r} missing from ctm_unet state dict")
        state_dict[param_name] = ema_ckpt[param_name]
    model.ctm_unet.load_state_dict(state_dict, strict=False)
    del ema_ckpt, state_dict
    torch.cuda.empty_cache()
    return vae, stft, model, diffusion


def main():
    """Run loss-guided CTM sampling for every row of ``--test_file``.

    For each batch:
    - Load the first reference file of the batch.
    - Build a target curve from ``--curve_shape``.
    - Call ``karras_sample_guided`` with that curve as ``measurements_wav``.

    The decoded audio goes to ``<save_dir>/<curve_shape>/restored_<name>``.
    The Intensity curve of the result is plotted under
    ``<save_dir>/<curve_shape>/plot/``. The MSE between that curve and the
    target is printed per batch and averaged over batches at the end.

    Raises:
        NotImplementedError: if ``--task`` is not ``intensity_control``.
        ValueError: if ``--curve_shape`` is not a known curve.
    """
    args = parse_args()
    # Only intensity control builds a degradation operator and a measurement
    # target below, so reject anything else before loading any model.
    if args.task != 'intensity_control':
        raise NotImplementedError(
            f"Task {args.task!r} is not supported by this script; use --task intensity_control"
        )
    if args.curve_shape not in _INTENSITY_CURVE_SHAPES:
        raise ValueError(
            f"Unknown curve_shape {args.curve_shape!r}; expected one of {_INTENSITY_CURVE_SHAPES}"
        )

    if args.seed is not None:
        set_seed(args.seed)
        rand_fix(args.seed)
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print("GPU is available. Using GPU...")
    else:
        device = torch.device("cpu")
        print("GPU is not available. Using CPU...")

    # Only the first line of the training summary (jsonl) file is used.
    with open(args.training_args, encoding="utf-8") as f:
        train_args = dotdict(json.loads(f.readline()))

    vae, stft, model, diffusion = _load_models(args, train_args, device)

    captions, file_names = _read_test_csv(args.test_file)
    prefix = args.prefix or ""
    text_prompts = [prefix + caption for caption in captions]

    batch_size = args.batch_size
    loss_fn = nn.MSELoss()
    total_loss = 0.0
    num_batches = 0
    out_dir = f"{args.save_dir}/{args.curve_shape}"
    plot_dir = f"{out_dir}/plot"
    os.makedirs(plot_dir, exist_ok=True)  # also creates save_dir and out_dir
    trim = _INTENSITY_CTX_WINDOW // 2

    for k in tqdm(range(0, len(text_prompts), batch_size)):
        text = text_prompts[k:k + batch_size]
        # NOTE(review): only the first file of each batch serves as reference
        # and output name, so the script effectively assumes batch_size == 1.
        ref_file = file_names[k]
        base_name = os.path.basename(ref_file)
        measurements_mel, _, original_wav = torch_tools.wav_to_fbank_inference(ref_file, 1024, stft)
        degradation_function = Intensity(ctx_window=_INTENSITY_CTX_WINDOW, device=device)
        ref_curve = degradation_function(original_wav[:, :args.generated_length].to(device))
        measurements_wav = _target_intensity_curve(args.curve_shape, ref_curve.shape[-1]).to(device)

        print(base_name)
        # Not wrapped in torch.no_grad(): guided sampling may need gradients.
        latents = karras_sample_guided(
            diffusion=diffusion,
            model=model,
            stage1=vae,
            stft=stft,
            degradation_function=degradation_function,
            measurements_wav=measurements_wav,
            measurements_mel=measurements_mel.to(device),
            task=args.task,
            sdedit=args.sdedit,
            target_length=args.generated_length,
            initial_index=args.initial_index,
            fs_start=args.fs_start,
            fs_end=args.fs_end,
            optimization_steps=args.optimization_steps,
            # Use the real batch length so a short final batch matches `cond`.
            shape=(len(text), train_args.latent_channels, train_args.latent_t_size, train_args.latent_f_size),
            steps=args.num_steps,
            cond=text,
            guidance_scale=args.guidance,
            loss_guided_sampling=args.loss_guided_sampling,
            model_kwargs={},
            device=device,
            clip_denoised=False,
            sampler=args.sampler,
            gamma=args.sampling_gamma,
            teacher=False,
            ctm=True,
            x_T=None,
            clip_output=False,
            sigma_min=train_args.sigma_min,
            sigma_max=train_args.sigma_max,
            train=False,
        )
        mel = vae.decode_first_stage(latents)
        restored_wav = vae.decode_to_waveform(mel).detach()
        # Scale to int16 PCM; clip first so full-scale samples do not wrap around.
        wave = np.clip(restored_wav.cpu().numpy() * 32768, -32768, 32767).astype("int16")
        # Trim along the time (last) axis, then drop singleton dims for sf.write.
        wave = wave[..., :args.generated_length].squeeze()

        restored_curve = degradation_function(restored_wav[:, :args.generated_length]).detach()
        mse_error = loss_fn(restored_curve, measurements_wav)
        total_loss += mse_error.item()
        num_batches += 1
        print("MSE_error", mse_error)

        plt.plot(restored_curve.squeeze().cpu().numpy()[trim:-trim])
        plt.savefig(f'{plot_dir}/results_plot_{base_name}.pdf')
        plt.close()

        sf.write(f'{out_dir}/restored_{base_name}', wave, samplerate=16000)

    # Mean per-batch MSE; skipped when there was nothing to process.
    if num_batches:
        print("total_loss", total_loss / num_batches)
        
# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()