import fcntl
import os
import json
import argparse
import typing as t
import time

import optuna
import torch
import torchaudio as ta
import numpy as np

from tqdm import tqdm

from audio_utils import utils, clap_wrapper, schemas as s, audio_analysis


class SampleGenerator:
    """Search per-prompt generation parameters with Optuna and save the best audio pair.

    Two modes are supported (selected by ``mode``):

    * ``"p2p"``: prompt-to-prompt generation. The attention-injection fraction, its
      delay and the attention reweighting are optimized; guidance scale and seed are
      taken from ``prompt["candidate"]``.
    * ``"edit"``: editing of an existing input audio file. The forward/reverse
      classifier-free guidance strengths, the reverse-process start step ``tstart``
      and the seed are optimized.

    Every trial is scored via ``clap.audio_prompt_pair_sim`` and a scalar objective
    (see ``get_best_audio_pair``). If an ``audio_analyzer`` is given, it must also
    confirm that the final output audio matches the output prompt.
    """

    def __init__(
            self,
            model: s.ModelWrapper,
            clap: clap_wrapper.CLAPWrapper,
            audio_analyzer: t.Optional[audio_analysis.AudioAnalysisWrapper] = None,
            mode: str = "p2p"
    ):
        """
        Args:
            model: Model wrapper. It must expose ``sample_rate``. In "p2p" mode it also needs
                ``set_p2p_prop``/``generate_edited_audio``; in "edit" mode it needs ``edit_audio``.
            clap: CLAP wrapper used to score generated audio pairs.
            audio_analyzer: Optional additional check of the final output audio.
            mode: Either ``"p2p"`` or ``"edit"``.
        """
        self.model = model
        self.audio_analyzer = audio_analyzer
        self.clap = clap
        self.mode = mode

    def generate_samples(
            self,
            prompts: list[dict],
            attn_inject_min: float,
            attn_inject_max: float,
            attn_inject_delay_min: float,
            attn_inject_delay_max: float,
            attn_reweighting_min: float,
            attn_reweighting_max: float,
            cfg_src_min: float,
            cfg_src_max: float,
            cfg_tar_min: float,
            cfg_tar_max: float,
            tstart_min: int,
            tstart_max: int,
            n_trials: int,
            steps: int,
            steps_final: int,
            output_path: str,
            audio_path: str
    ):
        """Run the parameter search for every prompt and persist accepted results.

        For each accepted prompt, a record is merged into ``<output_path>/samples.json``
        under ``prompt["metadata"]["uid"]``. The record contains the prompt plus the best
        ``params`` and the CLAP ``sims``. The audio pair is written to
        ``<output_path>/<uid>/input.wav`` and ``output.wav``. A prompt is skipped when,
        in "edit" mode, its input audio file does not exist, or when the search returns
        an objective of ``-inf``.

        The ``*_min``/``*_max`` arguments are the search ranges; they are forwarded
        unchanged to ``get_best_audio_pair``. ``steps`` is the number of inference steps
        used during the search; ``steps_final`` is used for the saved audio. In "edit"
        mode, ``audio_path`` is the directory holding the input audio files.
        """
        for prompt in tqdm(prompts):
            time_start = time.time()
            init_aud = None
            if self.mode == "edit":
                init_aud = os.path.join(audio_path, f"{prompt['metadata']['input_caption']['filename']}")
                if not os.path.exists(init_aud):
                    print(f"Audio {init_aud} not found. Skipping...")
                    continue

            print(f"Working on prompts: {prompt['data']['input']} / {prompt['data']['output']}")
            obj, params, best_sample = self.get_best_audio_pair(
                prompt=prompt,
                attn_inject_min=attn_inject_min,
                attn_inject_max=attn_inject_max,
                attn_inject_delay_min=attn_inject_delay_min,
                attn_inject_delay_max=attn_inject_delay_max,
                attn_reweighting_min=attn_reweighting_min,
                attn_reweighting_max=attn_reweighting_max,
                cfg_src_min=cfg_src_min,
                cfg_src_max=cfg_src_max,
                cfg_tar_min=cfg_tar_min,
                cfg_tar_max=cfg_tar_max,
                tstart_min=tstart_min,
                tstart_max=tstart_max,
                n_trials=n_trials,
                steps=steps,
                steps_final=steps_final,
                audio_path=audio_path,
                init_aud=init_aud
            )

            if obj > -np.inf:
                print(f"Found best sample with objective value of {obj} and params {params}")
                sims = self.clap.audio_prompt_pair_sim(
                    audios=best_sample,
                    sample_rate=self.model.sample_rate,
                    prompts=[prompt["data"]["input"], prompt["data"]["output"]]
                )

                # updating samples file
                sample_data = {
                    **prompt,
                    "params": params,
                    "sims": sims
                }
                print("saving...")
                self.update_samples_file(output_path=output_path, sample_data=sample_data)

                # saving audio sample: index 0 is the input audio, index 1 the output audio
                audio_output_path = os.path.join(output_path, prompt["metadata"]["uid"])
                os.makedirs(audio_output_path, exist_ok=True)
                ta.save(uri=os.path.join(audio_output_path, "input.wav"), src=best_sample[0],
                        sample_rate=self.model.sample_rate)
                ta.save(uri=os.path.join(audio_output_path, "output.wav"), src=best_sample[1],
                        sample_rate=self.model.sample_rate)
                print(f"Done, saving to {audio_output_path}")

            else:
                print("No good sample found, skipping prompt...")

            print(f"Took {time.time() - time_start:.2f} seconds for prompt")

    def update_samples_file(self, output_path: str, sample_data: dict):
        """Insert ``sample_data`` into ``<output_path>/samples.json`` keyed by its uid.

        The file holds a single JSON object that maps ``metadata.uid`` to sample records.
        An exclusive ``flock`` serializes the read-modify-write, so that processes sharing
        an output directory do not lose each other's updates. A new, empty or unparsable
        file is treated as an empty mapping.
        """
        json_path = os.path.join(output_path, "samples.json")
        os.makedirs(output_path, exist_ok=True)

        # "a+" creates the file when missing without truncating existing content.
        with open(json_path, "a+") as json_file:
            fcntl.flock(json_file, fcntl.LOCK_EX)
            try:
                json_file.seek(0)
                try:
                    json_data = json.load(json_file)
                except (json.JSONDecodeError, IOError):
                    json_data = {}

                # Update with the new sample
                json_data[sample_data["metadata"]["uid"]] = sample_data

                # Rewrite the file in place. truncate() at offset 0 empties it, and
                # append-mode writes then land at the new end, i.e. the start of the file.
                json_file.seek(0)
                json_file.truncate()
                json.dump(json_data, json_file, indent=4)
                json_file.flush()
                os.fsync(json_file.fileno())
            finally:
                fcntl.flock(json_file, fcntl.LOCK_UN)

    def get_best_audio_pair(
            self,
            prompt: dict,
            attn_inject_min: float,
            attn_inject_max: float,
            attn_inject_delay_min: float,
            attn_inject_delay_max: float,
            attn_reweighting_min: float,
            attn_reweighting_max: float,
            cfg_src_min: float,
            cfg_src_max: float,
            cfg_tar_min: float,
            cfg_tar_max: float,
            tstart_min: int,
            tstart_max: int,
            n_trials: int,
            steps: int,
            steps_final: int,
            audio_path: str,
            study_name: str = "study",
            obj_fct: t.Optional[t.Callable] = None,
            init_aud=None,
    ):
        """Optimize generation parameters for one prompt and regenerate the best pair.

        Runs ``n_trials`` Optuna trials (maximizing ``obj_fct``) using ``steps``
        inference steps. It then regenerates the audio pair with the best parameters
        using ``steps_final`` steps.

        Args:
            prompt: Prompt record. It reads ``prompt["data"]`` keys ``input``, ``output``,
                ``input_neg`` (p2p only) and ``output_neg``. In "p2p" mode it also reads
                ``prompt["candidate"]`` keys ``guidance`` and ``seed``.
            obj_fct: Maps the CLAP similarity dict (passed as keyword ``sims``) to a
                float. The default is a weighted mean of ``sim_1``, ``sim_dir``,
                ``sim_audio_clap`` and ``sim_audio_mel``.
            init_aud: Path to the input audio ("edit" mode).
            audio_path, study_name: Currently unused; kept for interface compatibility.

        Returns:
            Tuple ``(objective, params, best_sample)``. ``objective`` is ``-np.inf`` when
            no trial completed (``params`` is then ``{}`` and ``best_sample`` is ``None``)
            or when the audio analyzer rejected the final output on every try.

        Raises:
            NotImplementedError: If ``self.mode`` is neither "p2p" nor "edit".
        """
        study = optuna.create_study(
            # storage=JournalStorage(JournalFileBackend("data/optuna.log")),
            direction="maximize",
            # study_name=study_name
        )
        study.set_user_attr(key="input_prompt", value=prompt["data"]["input"])
        study.set_user_attr(key="output_prompt", value=prompt["data"]["output"])

        if obj_fct is None:
            obj_fct = lambda sims: np.mean(
                (8 * sims["sim_1"], 14 * sims["sim_dir"], 0.5 * sims["sim_audio_clap"], 1.5 * sims["sim_audio_mel"]))

        # Parameter combinations already evaluated in this study (duplicates are pruned).
        seen_params: set[str] = set()

        def prune_if_seen(key: str):
            """Raise TrialPruned if this exact parameter combination was already tried."""
            if key in seen_params:
                raise optuna.exceptions.TrialPruned()
            seen_params.add(key)

        def score(sample) -> float:
            """Score an audio pair against the input/output prompts with CLAP + obj_fct."""
            sims = self.clap.audio_prompt_pair_sim(
                audios=sample,
                sample_rate=self.model.sample_rate,
                prompts=[prompt["data"]["input"], prompt["data"]["output"]]
            )
            return obj_fct(sims=sims)

        def generate_p2p(params: dict, n_steps: int):
            """Generate a P2P audio pair for ``params`` using the prompt's candidate guidance/seed."""
            candidate = prompt["candidate"]
            return self.get_audio_pair_p2p(
                input_caption=prompt["data"]["input"],
                output_caption=prompt["data"]["output"],
                negative_input_caption=prompt["data"]["input_neg"],
                negative_output_caption=prompt["data"]["output_neg"],
                attn_inject_frac=params["attn_inject_frac"],
                attn_inject_delay=params["attn_inject_delay"],
                attn_reweighting=params["attn_reweighting"],
                guidance_scale=candidate["guidance"],
                seed=candidate["seed"],
                steps=n_steps,
            )

        def generate_edit(params: dict, n_steps: int):
            """Edit ``init_aud`` with ``params``."""
            return self.model.edit_audio(
                input_caption=prompt["data"]["input"],
                output_caption=prompt["data"]["output"],
                negative_output_caption=prompt["data"]["output_neg"],
                cfg_scale=params["cfg_src"],
                cfg_tar=params["cfg_tar"],
                tstart=params["tstart"],
                seed=params["seed"],
                steps=n_steps,
                init_aud=init_aud,
            )

        # Cap the injection fraction so that a delay of at least attn_inject_delay_min
        # still fits (frac + delay <= 1). Otherwise the delay range below could become
        # empty (low > high), and Optuna would raise.
        attn_inject_high = min(attn_inject_max, 1 - attn_inject_delay_min)

        def objective_p2p(trial):
            time_start = time.time()
            attn_inject_frac = trial.suggest_float("attn_inject_frac", low=attn_inject_min, high=attn_inject_high)
            attn_inject_delay = trial.suggest_float(
                "attn_inject_delay",
                low=attn_inject_delay_min,
                # frac + delay must not exceed 1; max() guards against float rounding at the boundary
                high=max(attn_inject_delay_min, min(attn_inject_delay_max, 1 - attn_inject_frac)),
            )
            attn_reweighting = trial.suggest_float("attn_reweighting", low=attn_reweighting_min,
                                                   high=attn_reweighting_max)

            prune_if_seen(f"{attn_inject_frac}_{attn_inject_delay}_{attn_reweighting}")

            print(prompt["data"])
            sample = generate_p2p(trial.params, steps)
            value = score(sample)
            print(f"Took {time.time() - time_start:.2f} seconds for trial.")
            return value

        def objective_edit(trial):
            time_start = time.time()
            cfg_src = trial.suggest_float("cfg_src", low=cfg_src_min, high=cfg_src_max)
            cfg_tar = trial.suggest_float("cfg_tar", low=cfg_tar_min, high=cfg_tar_max)
            tstart = trial.suggest_int("tstart", low=tstart_min, high=tstart_max)
            seed = trial.suggest_int("seed", low=0, high=5)

            # The seed is part of the configuration: a different seed yields different audio.
            prune_if_seen(f"{cfg_src}_{cfg_tar}_{tstart}_{seed}")

            print(f"{prompt['data']} using audio from {init_aud}")
            sample = generate_edit(trial.params, steps)
            print("sample", sample.shape)
            value = score(sample)
            print(f"Took {time.time() - time_start:.2f} seconds for trial.")
            return value

        if self.mode == "p2p":
            objective, generate = objective_p2p, generate_p2p
        elif self.mode == "edit":
            objective, generate = objective_edit, generate_edit
        else:
            raise NotImplementedError(f"Mode {self.mode} not implemented.")

        study.optimize(objective, n_trials=n_trials)

        # study.best_trial raises ValueError when no trial completed (e.g. n_trials == 0).
        if not any(trial.state == optuna.trial.TrialState.COMPLETE for trial in study.trials):
            print("No trial completed, cannot select parameters.")
            return -np.inf, {}, None

        best_trial = study.best_trial
        best_sample = generate(best_trial.params, steps_final)

        if self.audio_analyzer is not None:
            # Each try re-runs the check on the same best_sample (analyzer results may vary between calls).
            audio_analyzer_tries = 3
            for i in range(audio_analyzer_tries):
                print(f"Checking audio analysis ({i + 1} / {audio_analyzer_tries} tries):")
                audio_analyzer_results = self.audio_analyzer.does_audio_match_prompt(
                    prompts=[prompt["data"]["output"]],
                    audio=best_sample[1].unsqueeze(dim=0),
                    sample_rate=self.model.sample_rate,
                    print_status=True
                )
                print(audio_analyzer_results)
                if np.all(audio_analyzer_results).item():
                    return best_trial.value, best_trial.params, best_sample

            return -np.inf, best_trial.params, best_sample

        return best_trial.value, best_trial.params, best_sample

    def get_audio_pair_p2p(
            self,
            input_caption: str,
            output_caption: str,
            attn_inject_frac: float,
            attn_inject_delay: float,
            guidance_scale: float,
            steps: int,
            seed: int,
            attn_reweighting: float = 1,
            negative_input_caption: t.Optional[str] = None,
            negative_output_caption: t.Optional[str] = None,
    ) -> torch.Tensor:
        """Configure the model's P2P attention settings, then generate an input/output audio pair.

        Returns whatever ``self.model.generate_edited_audio`` returns. Callers index
        it as ``[0]`` for the input audio and ``[1]`` for the output audio.
        """
        self.model.set_p2p_prop(attn_inject_frac=attn_inject_frac, attn_inject_delay=attn_inject_delay,
                                attn_reweighting=attn_reweighting)
        audio_pair = self.model.generate_edited_audio(
            input_caption=input_caption,
            output_caption=output_caption,
            seed=seed,
            cfg_scale=guidance_scale,
            steps=steps,
            negative_input_caption=negative_input_caption,
            negative_output_caption=negative_output_caption
        )

        return audio_pair


def main():
    """Command-line entry point: load prompts/candidates, filter them, and run the search.

    Steps:
    1. Load candidates ("p2p") or prompts ("edit").
    2. Drop those already present in ``<output-path>/samples.json``.
    3. Optionally drop speech and complex prompts.
    4. Keep only chunk ``--curr-chunk`` of ``--n-chunks``.
    5. Load the models and run ``SampleGenerator.generate_samples`` under ``torch.no_grad()``.
    """
    parser = argparse.ArgumentParser(description="Generate audio based on prompt file")

    # Model
    parser.add_argument("--model-path", help="Path to model checkpoint", required=False)
    parser.add_argument("--config-path", help="Path to config file", required=False)
    parser.add_argument("--clap-path", help="Path to CLAP checkpoint", required=True)

    # Candidates
    parser.add_argument("--candidate-file", help="Path of candidate file", required=False)

    # Audio Editing
    parser.add_argument("--prompt-file", help="Path of prompt file", required=False)
    parser.add_argument("--audio-path", help="Base path for audio files", required=False, type=str)

    # Filtering
    parser.add_argument("--n-chunks", help="Number of chunks", required=False, default=1, type=int)
    parser.add_argument("--curr-chunk", help="Current chunk", required=False, default=0, type=int)
    parser.add_argument('--skip-speech', help="Whether to skip prompts containing speech", default=False,
                        action=argparse.BooleanOptionalAction)
    parser.add_argument('--skip-n-elements', help="Whether to skip prompts containing n elements or more", default=3,
                        type=int)

    # Other
    parser.add_argument("--n-trials", help="Number of optuna trials", required=False, default=30, type=int)
    parser.add_argument("--steps", help="Number of inference steps for optimization", required=False, default=50,
                        type=int)
    parser.add_argument("--steps-final", help="Number of inference steps for the saved audio file", required=False,
                        default=100, type=int)
    parser.add_argument("--mode", help="Whether to use P2P or audio editing", default="p2p", type=str,
                        choices=["p2p", "edit"])
    parser.add_argument("--output-path", help="Output path", required=True)
    parser.add_argument("--gemini", help="Whether to use gemini", default=False, action=argparse.BooleanOptionalAction)

    # P2P
    parser.add_argument("--attn-inject-min", help="Minimum of attention injection", required=False, default=0.1,
                        type=float)
    parser.add_argument("--attn-inject-max", help="Maximum of attention injection", required=False, default=0.9,
                        type=float)
    parser.add_argument("--attn-inject-delay-min", help="Minimum of attention injection delay", required=False,
                        default=0.0, type=float)
    parser.add_argument("--attn-inject-delay-max", help="Maximum of attention injection delay", required=False,
                        default=0.6, type=float)
    parser.add_argument("--attn-reweighting-min", help="Minimum of attention reweighting", required=False, default=1,
                        type=float)
    parser.add_argument("--attn-reweighting-max", help="Maximum of attention reweighting", required=False, default=1.8,
                        type=float)

    # AudioEdit
    parser.add_argument("--cfg-src-min", help="Minimum classifier-free guidance strength for forward process",
                        required=False, default=1, type=float)
    parser.add_argument("--cfg-src-max", help="Maximum classifier-free guidance strength for forward process",
                        required=False, default=3, type=float)
    parser.add_argument("--cfg-tar-min", help="Minimum classifier-free guidance strength for reverse process",
                        required=False, default=3, type=float)
    parser.add_argument("--cfg-tar-max", help="Maximum classifier-free guidance strength for reverse process",
                        required=False, default=10, type=float)
    parser.add_argument("--tstart-min",
                        help="Minimum of diffusion timestep to start the reverse process from. Controls editing strength.",
                        required=False, default=18, type=int)
    parser.add_argument("--tstart-max",
                        help="Maximum of diffusion timestep to start the reverse process from. Controls editing strength.",
                        required=False, default=65, type=int)

    args = parser.parse_args()

    # Validate mode-specific arguments up front. parser.error() exits with a usage
    # message; an ``assert`` would be stripped under ``python -O``.
    if args.mode == "p2p" and not args.candidate_file:
        parser.error("--candidate-file is required in p2p mode")
    if args.mode == "edit":
        if not args.prompt_file:
            parser.error("--prompt-file is required in edit mode")
        if not args.audio_path:
            parser.error("--audio-path is required in edit mode")
        if args.tstart_max > args.steps:
            parser.error(f"--tstart-max ({args.tstart_max}) must not exceed --steps ({args.steps})")
    if args.n_chunks < 1:
        parser.error("--n-chunks must be at least 1")
    if not -args.n_chunks <= args.curr_chunk < args.n_chunks:
        parser.error(f"--curr-chunk ({args.curr_chunk}) is out of range for --n-chunks {args.n_chunks}")

    # check output dir
    valid_config = utils.check_config_match(args)
    if not valid_config:
        print("Config did not match, exiting...")
        return

    if args.mode == "p2p":
        # for prompt-to-prompt, candidates are already generated that contain seed and cfg value
        with open(args.candidate_file, "r") as json_file:
            candidate_data = json.load(json_file)
        prompt_ids = sorted(candidate_data.keys())
        print(f"Number of candidates: {len(prompt_ids)}")
    else:
        # for audio editing, no candidates are generated
        with open(args.prompt_file, "r") as jsonl_file:
            prompts = [json.loads(x) for x in jsonl_file.readlines()]
        prompt_ids = sorted([x["metadata"]["uid"] for x in prompts])
        print(f"Number of prompts: {len(prompt_ids)}")

    # filtering already processed samples
    samples_path = os.path.join(args.output_path, "samples.json")
    try:
        with open(samples_path, "r") as json_file:
            samples_data = json.load(json_file)
    except IOError:
        print("No samples found, prompts are not filtered...")
    except json.JSONDecodeError:
        # e.g. an empty file left behind by an interrupted run
        print(f"Could not parse {samples_path}, prompts are not filtered...")
    else:
        processed_ids = set(samples_data)
        prompt_ids = [x for x in prompt_ids if x not in processed_ids]

    print("Number of prompts after filtering: ", len(prompt_ids))

    if args.mode == "p2p":
        prompts = [candidate_data[x] for x in prompt_ids]
    else:
        remaining_ids = set(prompt_ids)
        prompts = [x for x in prompts if x["metadata"]["uid"] in remaining_ids]

    # optionally filter out speech and complex prompts
    if args.skip_speech:
        prompts = [x for x in prompts if not x["data"]["speech"]]
    print(f"Number of prompts after speech filtering: {len(prompts)}")
    if args.skip_n_elements:
        prompts = [x for x in prompts if x["data"]["n_elements"] < args.skip_n_elements]
    print(f"Number of prompts after complex filtering: {len(prompts)}")

    # Split into n_chunks near-equal contiguous chunks and keep the current one.
    prompts = np.array_split(prompts, args.n_chunks)[args.curr_chunk].tolist()
    print(f"Number of prompts after chunking: {len(prompts)}")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Loading Models (imported lazily so only the selected backend is loaded)
    if args.mode == "p2p":
        from audio_generation.p2p_sao import P2PStableAudioOpen
        model: s.ModelWrapper = P2PStableAudioOpen(device=device, model_path=args.model_path,
                                                   config_path=args.config_path, half_precision=True)
    elif args.mode == "edit":
        from audio_generation.edit_sao import EditStableAudioOpen
        model: s.ModelWrapper = EditStableAudioOpen(device=device, steps=args.steps, half_precision=True)
    else:
        raise NotImplementedError(f"Mode '{args.mode}' is not supported")
    clap = clap_wrapper.CLAPWrapper(clap_path=args.clap_path, device=device)

    audio_analyzer: t.Optional[audio_analysis.AudioAnalysisWrapper] = None
    if args.gemini:
        audio_analyzer = audio_analysis.GeminiAudioWrapper()

    sample_generator = SampleGenerator(model=model, audio_analyzer=audio_analyzer, clap=clap, mode=args.mode)

    with torch.no_grad():
        sample_generator.generate_samples(
            prompts=prompts,
            attn_inject_min=args.attn_inject_min,
            attn_inject_max=args.attn_inject_max,
            attn_inject_delay_min=args.attn_inject_delay_min,
            attn_inject_delay_max=args.attn_inject_delay_max,
            attn_reweighting_min=args.attn_reweighting_min,
            attn_reweighting_max=args.attn_reweighting_max,
            cfg_src_min=args.cfg_src_min,
            cfg_src_max=args.cfg_src_max,
            cfg_tar_min=args.cfg_tar_min,
            cfg_tar_max=args.cfg_tar_max,
            tstart_min=args.tstart_min,
            tstart_max=args.tstart_max,
            n_trials=args.n_trials,
            steps=args.steps,
            steps_final=args.steps_final,
            output_path=args.output_path,
            audio_path=args.audio_path
        )


if __name__ == "__main__":
    main()
