import copy
import os
import json
import random
import argparse
import time
import typing as t

import numpy as np
import torch
import torchaudio as ta
from tqdm import tqdm
from jsonschema import exceptions as schema_exceptions
import fcntl

from audio_generation.p2p_sao import P2PStableAudioOpen
from audio_utils import utils, clap_wrapper, audio_analysis, schemas as s


def batch(seq, batch_size=1):
    """Lazily split ``seq`` into consecutive slices of ``batch_size`` items.

    The last slice holds whatever remains and may be shorter than
    ``batch_size``. ``seq`` must support ``len()`` and slicing.
    """
    starts = range(0, len(seq), batch_size)
    # Slicing past the end clamps to the sequence length, so no min() is needed.
    yield from (seq[start:start + batch_size] for start in starts)



def conditioning_from_batched_prompt_list(
        batched_prompt_list: list[dict],
        seconds_start: float = 0,
        seconds_total: float = 10,
):
    """Build positive and negative conditioning lists for a batch of prompt pairs.

    Each prompt contributes two consecutive entries to both lists: first its
    ``input`` side, then its ``output`` side. The matching ``*_neg`` text goes
    in the negative list. Entries ``2*k`` and ``2*k + 1`` therefore belong to
    ``batched_prompt_list[k]``.

    Args:
        batched_prompt_list: Prompt records. Each needs a ``"data"`` dict with
            ``input``, ``input_neg``, ``output`` and ``output_neg`` keys.
        seconds_start: Value stored as ``"seconds_start"`` in every entry.
        seconds_total: Value stored as ``"seconds_total"`` in every entry
            (presumably the clip length in seconds).

    Returns:
        ``(conditioning, negative_conditioning)``, each of length
        ``2 * len(batched_prompt_list)``.
    """
    def make_entry(text):
        return {"prompt": text, "seconds_start": seconds_start, "seconds_total": seconds_total}

    conditioning = []
    negative_conditioning = []
    for prompt in batched_prompt_list:
        data = prompt["data"]
        # Order matters: input first, then output, to keep pairs adjacent.
        for side in ("input", "output"):
            conditioning.append(make_entry(data[side]))
            negative_conditioning.append(make_entry(data[f"{side}_neg"]))

    return conditioning, negative_conditioning


def update_candidates_file(output_path: str, prompt_uid: str, best_candidate: dict):
    """
    Insert or replace ``best_candidate`` under ``prompt_uid`` in
    ``<output_path>/candidates.json``.

    The read-modify-write cycle runs under an exclusive ``fcntl.flock`` on the
    file. Concurrent writers that also use this function (for example several
    chunk processes sharing one output directory) therefore do not lose each
    other's updates. This is mutual exclusion only; the rewrite is not
    crash-atomic.

    The stored record is ``best_candidate`` with the ``"audios"`` entry removed
    from its ``"candidate"`` dict. ``best_candidate`` itself is not modified.

    Raises:
        ValueError: If the existing file is non-empty but not valid JSON. The
            file is left untouched so previously stored candidates survive.
        TypeError: If the record is not JSON-serializable. The file is left
            untouched in that case too.
    """
    json_path = os.path.join(output_path, "candidates.json")
    os.makedirs(output_path, exist_ok=True)

    # Build the serializable record without deep-copying the audio tensors
    # only to throw them away.
    candidate = {key: value for key, value in best_candidate["candidate"].items() if key != "audios"}
    record = {**best_candidate, "candidate": candidate}

    # "a+" creates the file if it is missing without truncating an existing one.
    with open(json_path, "a+", encoding="utf-8") as json_file:
        fcntl.flock(json_file, fcntl.LOCK_EX)
        try:
            json_file.seek(0)
            contents = json_file.read()
            if contents.strip():
                try:
                    json_data = json.loads(contents)
                except json.JSONDecodeError as err:
                    # Never silently replace existing results with an empty dict.
                    raise ValueError(f"{json_path} is not valid JSON; refusing to overwrite it") from err
            else:
                json_data = {}  # freshly created (or empty) file

            json_data[prompt_uid] = record

            # Serialize before truncating so a serialization error cannot
            # leave the file empty.
            serialized = json.dumps(json_data, indent=4)
            json_file.seek(0)
            json_file.truncate()
            json_file.write(serialized)
            json_file.flush()
            os.fsync(json_file.fileno())
        finally:
            fcntl.flock(json_file, fcntl.LOCK_UN)


def _select_best_candidate(prompt_uid: str, prompt_candidates: dict) -> t.Optional[dict]:
    """Pick the best candidate generated for one prompt pair.

    Only candidates with a truthy ``"audio_analyzer"`` flag are eligible.
    Among those, the one with the highest ``"sim"`` wins; ties keep the
    earlier candidate. Returns the prompt record merged with
    ``{"candidate": winner}``, or ``None`` when no candidate is eligible.
    """
    prompt_data = prompt_candidates["prompt"]["data"]
    best_candidate = None
    max_sim = -np.inf
    for candidate in prompt_candidates["candidates"]:
        if candidate["audio_analyzer"] and candidate["sim"] > max_sim:
            print(f"Found new best candidate for prompt pair {prompt_uid} with ({prompt_data['input']} / {prompt_data['output']}): {candidate['audio_analyzer']}, {candidate['sim']}")
            max_sim = candidate["sim"]
            best_candidate = {
                **prompt_candidates["prompt"],
                "candidate": candidate,
            }
    return best_candidate


def _save_candidate_audio(output_path: str, prompt_uid: str, best_candidate: dict, sample_rate: int):
    """Write the winning clips to ``<output_path>/<prompt_uid>/input.wav`` and ``output.wav``."""
    audio_path = os.path.join(output_path, prompt_uid)
    os.makedirs(audio_path, exist_ok=True)
    audios = best_candidate["candidate"]["audios"]
    for index, file_name in enumerate(("input.wav", "output.wav")):
        # torchaudio.save may reject float16 tensors. Upcast defensively in
        # case the model ran in half precision.
        ta.save(uri=os.path.join(audio_path, file_name), src=audios[index].float(), sample_rate=sample_rate)
    print(f"Saved audio to {audio_path}")


def generate_candidates_from_prompts(
        model: s.ModelWrapper,
        audio_analyzer: t.Optional[audio_analysis.AudioAnalysisWrapper],
        clap: clap_wrapper.CLAPWrapper,
        prompt_list: list,
        batch_size: int,
        output_path: str,
        candidate_gens: int = 10,
        guidance_scale_min: float = 6.0,
        guidance_scale_max: float = 9.0,
        device: str = "cuda",
        save_audio: bool = False,
        steps: int = 100
):
    """Generate ``candidate_gens`` candidates per prompt pair and persist the best one.

    Each prompt yields two generations (its input and its output side), so a
    generation batch of ``batch_size`` covers ``batch_size // 2`` prompts.

    In every candidate round:
    - a guidance scale is drawn uniformly from
      ``[guidance_scale_min, guidance_scale_max]`` and a fresh seed is drawn;
    - the batch is generated;
    - the clips are checked by ``audio_analyzer`` (if any) and scored with
      CLAP.

    After all rounds, the candidate that passed the analyzer with the highest
    mean CLAP similarity is written to ``candidates.json``. If ``save_audio``
    is set, it is also saved as WAV files.

    Args:
        model: Generator exposing ``generate_audio`` and ``sample_rate``.
        audio_analyzer: Optional checker. When ``None``, every candidate is
            treated as passing and selection uses CLAP similarity alone.
        clap: CLAP wrapper used to score audio against the prompts.
        prompt_list: Prompt records. Each needs ``"data"`` and a unique
            ``"metadata"]["uid"]``.
        batch_size: Number of clips generated at once. Values below 2 still
            process one prompt (two clips) per batch.
        output_path: Directory for ``candidates.json`` and optional audio.
        candidate_gens: Number of candidate rounds per batch.
        guidance_scale_min: Lower bound of the sampled CFG scale.
        guidance_scale_max: Upper bound of the sampled CFG scale.
        device: Currently unused; kept for interface compatibility.
        save_audio: Whether to also write the winning clips as WAV files.
        steps: Number of inference steps passed to the model.

    Raises:
        ValueError: If a batch contains duplicate prompt uids, whose
            candidates would otherwise be silently merged.
    """
    # Clamp to 1 so batch_size=1 does not produce a zero step in ``batch``.
    prompts_per_batch = max(1, batch_size // 2)
    for batched_prompt_list in batch(prompt_list, batch_size=prompts_per_batch):
        time_prompts = time.time()

        # Fail before spending GPU time if uids collide inside this batch.
        batch_uids = [prompt["metadata"]["uid"] for prompt in batched_prompt_list]
        if len(set(batch_uids)) != len(batch_uids):
            raise ValueError(f"Duplicate prompt uids in batch: {batch_uids}")

        candidates = {}
        conditioning, negative_conditioning = conditioning_from_batched_prompt_list(batched_prompt_list)
        prompts = [x["prompt"] for x in conditioning]

        print(f"Working on prompts: {prompts}")
        for i in range(candidate_gens):
            time_candidate = time.time()
            print(f"- Candidate {i+1}/{candidate_gens}: Generating...")
            guidance_scale = round(random.uniform(guidance_scale_min, guidance_scale_max), 2)
            seed = random.randint(a=0, b=1_000_000)
            audios = model.generate_audio(
                conditioning=conditioning,
                negative_conditioning=negative_conditioning,
                cfg_scale=guidance_scale,
                seed=seed,
                steps=steps
            )

            if audio_analyzer is None:
                # No analyzer configured (e.g. --gemini not set): every clip
                # passes, so selection falls back to CLAP similarity alone.
                audio_results = [True] * len(prompts)
            else:
                time_audio_analyzer = time.time()
                audio_results = audio_analyzer.does_audio_match_prompt(
                    prompts=prompts,
                    audio=audios,
                    sample_rate=model.sample_rate,
                    print_status=True
                )
                print(f"- Candidate {i+1}/{candidate_gens}: Audio Analyzer took {time.time() - time_audio_analyzer:.2f} seconds")

            sim = clap.audio_prompt_sim(audios=audios, sr=model.sample_rate, prompts=prompts)
            print(f"- Candidate {i+1}/{candidate_gens}: {sim}, {audio_results}, {audios.shape}")

            for pair_index, prompt in enumerate(batched_prompt_list):
                uid = prompt["metadata"]["uid"]
                if uid not in candidates:
                    candidates[uid] = {"prompt": {**prompt}, "candidates": []}

                # Conditioning interleaves (input, output) per prompt, so this
                # prompt's two clips are rows 2*pair_index and 2*pair_index + 1.
                pair = slice(2 * pair_index, 2 * pair_index + 2)
                candidates[uid]["candidates"].append({
                    "seed": seed,
                    "guidance": guidance_scale,
                    "audio_analyzer": np.all(audio_results[pair]).item(),
                    "sim": np.mean(sim[pair]).item(),
                    "audios": audios[pair, ...].cpu(),
                })

            # Clear GPU memory after each generation
            del audios
            torch.cuda.empty_cache()
            print(f"- Candidate {i+1}/{candidate_gens}: Took {time.time() - time_candidate:.2f} seconds for candidate")

        for prompt_uid, prompt_candidates in candidates.items():
            best_candidate = _select_best_candidate(prompt_uid, prompt_candidates)
            if best_candidate is None:
                prompt_data = prompt_candidates["prompt"]["data"]
                print(f"No candidate found for prompt pair {prompt_uid} with ({prompt_data['input']} / {prompt_data['output']})")
                continue

            update_candidates_file(output_path, prompt_uid, best_candidate)
            if save_audio:
                _save_candidate_audio(output_path, prompt_uid, best_candidate, model.sample_rate)

        print(f"Took {time.time() - time_prompts:.2f} seconds for prompts")


def _read_prompt_file(prompt_file: str) -> t.Optional[list]:
    """Parse a JSONL prompt file into a list of dicts.

    Blank lines are skipped. If a line is not valid JSON or fails schema
    validation, prints its 1-based line number and returns ``None``.
    """
    with open(prompt_file, "r", encoding="utf-8") as jsonl_file:
        text_prompts = jsonl_file.readlines()

    prompts = []
    print("Validating prompt file...")
    for line_no, text_prompt in enumerate(tqdm(text_prompts), start=1):
        if not text_prompt.strip():
            continue  # tolerate blank lines instead of aborting the whole run
        try:
            prompt = json.loads(text_prompt)
            # NOTE: schema validation is currently disabled:
            # validate(prompt, json_schemas.file_schema)
            prompts.append(prompt)
        except json.JSONDecodeError:
            print(f"Prompt at line {line_no} is not valid JSON: {text_prompt}")
            return None
        except schema_exceptions.ValidationError:
            print(f"Prompt at line {line_no} does not conform to schema: {text_prompt}")
            return None
    return prompts


def _load_processed_uids(output_path: str) -> set:
    """Return the prompt uids already recorded in ``<output_path>/candidates.json``.

    Takes a shared ``flock`` while reading. The read therefore cannot
    interleave with a writer that holds an exclusive lock on the same file
    while rewriting it. A missing or empty file yields an empty set.
    """
    candidate_path = os.path.join(output_path, "candidates.json")
    try:
        json_file = open(candidate_path, "r", encoding="utf-8")
    except IOError:
        print("No candidate found, prompts are not filtered...")
        return set()

    with json_file:
        fcntl.flock(json_file, fcntl.LOCK_SH)
        try:
            contents = json_file.read()
        finally:
            fcntl.flock(json_file, fcntl.LOCK_UN)

    if not contents.strip():
        return set()  # file created but nothing written yet
    return set(json.loads(contents))


def main():
    parser = argparse.ArgumentParser(description="Generate audio candidates based on prompt file")

    # Model
    parser.add_argument("--model-path", help="Path to model checkpoint", required=False)
    parser.add_argument("--config-path", help="Path to config file", required=False)
    parser.add_argument("--clap-path", help="Path to CLAP checkpoint", required=True)

    # Prompts
    parser.add_argument("--prompt-file", help="Path of prompt file", required=True)
    parser.add_argument("--n-chunks", help="Number of chunks", required=False, default=1, type=int)
    parser.add_argument("--curr-chunk", help="Current chunk", required=False, default=0, type=int)

    # Filtering
    parser.add_argument('--skip-speech', help="Whether to skip prompts containing speech", default=False, action=argparse.BooleanOptionalAction)
    parser.add_argument('--skip-n-elements', help="Whether to skip prompts containing n elements or more", default=3, type=int)
    parser.add_argument("--gemini", help="Whether to use gemini", default=False, action=argparse.BooleanOptionalAction)

    # Other
    parser.add_argument("--output-path", help="Output path", required=True, type=str)
    parser.add_argument("--save-audio", help="Whether the generated audio should be saved", default=False, action=argparse.BooleanOptionalAction)
    parser.add_argument("--candidate-gens", help="How many candidates to generate", required=True, type=int)
    parser.add_argument("--guidance-scale-min", help="Guidance scale min", required=False, default=6, type=float)
    parser.add_argument("--guidance-scale-max", help="Guidance scale max", required=False, default=9, type=float)
    parser.add_argument("--batch-size", help="Batch size of generation. For batch size 32, 16 different prompts (input / output) will be generated at once.", required=True, type=int)
    parser.add_argument("--steps", help="How many inference steps to use.", required=False, default=100, type=int)

    # Parse the arguments
    args = parser.parse_args()

    # Validate chunk arguments up front. A negative --curr-chunk would
    # otherwise silently index from the end of the chunk list, and an
    # out-of-range one would only fail after all prompts were loaded.
    if args.n_chunks < 1:
        parser.error("--n-chunks must be at least 1")
    if not 0 <= args.curr_chunk < args.n_chunks:
        parser.error("--curr-chunk must satisfy 0 <= curr-chunk < n-chunks")

    # If we continue from previous configuration, check if same configs were used
    valid_config = utils.check_config_match(args)
    if not valid_config:
        print("Config did not match, exiting...")
        return

    prompts = _read_prompt_file(args.prompt_file)
    if prompts is None:
        return
    print(f"Total number of prompts: {len(prompts)}")

    # Filter out already processed prompts (set lookup instead of a sorted list).
    processed_uids = _load_processed_uids(args.output_path)
    prompts = [x for x in prompts if x["metadata"]["uid"] not in processed_uids]

    print(f"Number of prompts after filtering: {len(prompts)}")
    if args.skip_speech:
        prompts = [x for x in prompts if not x["data"]["speech"]]
    print(f"Number of prompts after speech filtering: {len(prompts)}")
    if args.skip_n_elements:
        prompts = [x for x in prompts if x["data"]["n_elements"] < args.skip_n_elements]
    print(f"Number of prompts after complex filtering: {len(prompts)}")

    prompts = np.array_split(prompts, args.n_chunks)[args.curr_chunk].tolist()
    random.shuffle(prompts)
    print(f"Number of prompts after chunking: {len(prompts)}")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Loading Models
    model: s.ModelWrapper = P2PStableAudioOpen(device=device, model_path=args.model_path, config_path=args.config_path, half_precision=True)
    clap = clap_wrapper.CLAPWrapper(clap_path=args.clap_path, device=device)

    audio_analyzer: t.Optional[audio_analysis.AudioAnalysisWrapper] = None
    if args.gemini:
        audio_analyzer = audio_analysis.GeminiAudioWrapper()

    with torch.no_grad():
        generate_candidates_from_prompts(
            model=model,
            audio_analyzer=audio_analyzer,
            clap=clap,
            prompt_list=prompts,
            batch_size=args.batch_size,
            candidate_gens=args.candidate_gens,
            guidance_scale_min=args.guidance_scale_min,
            guidance_scale_max=args.guidance_scale_max,
            device=device,
            output_path=args.output_path,
            save_audio=args.save_audio,
            steps=args.steps
        )


if __name__ == "__main__":
    # Script entry point: parse CLI arguments and run candidate generation.
    main()