import os
import json
import torch
import librosa
import hydra
import numpy as np
import pretty_midi
import decimal
from glob import glob
from tqdm import tqdm
from scipy.stats import ttest_rel, wilcoxon
from inference_error import InferenceHandler
from evaluate_errors import evaluate_main
from test_errors_old import get_scores as get_scores_B
from test_laddersym import get_scores as get_scores_A
from evaluate_errors_old import evaluate_main as evaluate_main_B


# ---------- Utility Functions ----------
def _load_MAESTRO_split_info(json_path):
    """Read the split-metadata JSON and index it two ways.

    The JSON must contain two mappings keyed by record number:
    ``"midi_filename"`` (number -> MIDI path) and ``"split"``
    (number -> split name).

    Returns:
        A pair ``(midi_filename_to_number, split_to_numbers)``. The first maps
        each MIDI file's base name, with ``".midi"`` removed, to its record
        number as a string. The second maps each split name to the set of
        record numbers (as strings) assigned to that split.
    """
    with open(json_path, "r") as handle:
        metadata = json.load(handle)

    midi_filename_to_number = {}
    for number, path in metadata["midi_filename"].items():
        stem = os.path.basename(path).replace(".midi", "")
        midi_filename_to_number[stem] = str(number)

    # Group record numbers under the split they belong to.
    split_to_numbers = {}
    for number, split in metadata["split"].items():
        split_to_numbers.setdefault(split, set()).add(str(number))

    return midi_filename_to_number, split_to_numbers


def normalize(x):
    """Scale ``x`` so that it sums to one along axis 0.

    When the sum is exactly zero an all-zero array shaped like ``x`` is
    returned instead, which avoids dividing by zero. The zero check needs a
    single truth value, so this is meant for 1-D vectors.
    """
    total = np.sum(x, axis=0)
    if total == 0:
        return np.zeros_like(x)
    return x / total


def manual_chroma(piano_roll, start_note=24):
    """Fold a piano roll into 12 pitch-class rows.

    Row ``i`` of ``piano_roll`` is treated as MIDI pitch ``start_note + i``.
    All rows that share a pitch class (pitch modulo 12) are summed column by
    column.

    Returns:
        A float array of shape ``(12, piano_roll.shape[1])``.
    """
    n_pitches, n_frames = piano_roll.shape[0], piano_roll.shape[1]
    pitch_classes = (start_note + np.arange(n_pitches)) % 12
    folded = np.zeros((12, n_frames))
    for pitch_class, target_row in enumerate(folded):
        # ``target_row`` is a view into ``folded``; fill it in place.
        target_row[:] = piano_roll[pitch_classes == pitch_class].sum(axis=0)
    return folded


def apply_normalization(arr):
    """Normalize every column of a 2-D array so that it sums to one.

    Columns whose sum is exactly zero become all zeros instead of triggering
    a division by zero.

    The input is never modified: the result is a freshly allocated array.
    (The previous implementation wrote through a transposed view, which
    overwrote the caller's array and truncated the result for integer
    dtypes.) Non-floating input is promoted to float so the normalized values
    are representable; floating dtypes are preserved.

    Args:
        arr: Array-like of shape ``(n_features, n_frames)``.

    Returns:
        A new array of the same shape with each column normalized.
    """
    data = np.asarray(arr)
    if not np.issubdtype(data.dtype, np.floating):
        data = data.astype(float)

    column_sums = data.sum(axis=0, keepdims=True)
    result = np.zeros_like(data)
    # Entries in zero-sum columns are skipped and keep their initial 0.
    np.divide(data, column_sums, out=result, where=column_sums != 0)
    return result


def process_audio_and_midi(audio_path, midi_path):
    """Time-align a MIDI file to an audio recording and save the result.

    Chroma features are computed for the audio (CQT chroma) and for the MIDI
    (piano roll folded to pitch classes), both at ``FS`` frames per second and
    normalized per frame. A subsequence DTW between the two yields a warping
    path, which is used to move each note's onset onto the audio timeline.
    Note durations, pitches and velocities are kept unchanged.

    The output is written next to the input as ``<midi>_aligned.mid``. If that
    file already exists, nothing is recomputed.

    Only the first instrument of the MIDI file is processed.

    Args:
        audio_path: Path of the audio recording to align against.
        midi_path: Path of the MIDI file to be aligned.

    Returns:
        None. The aligned MIDI is written to disk as a side effect.
    """
    output_path = midi_path.replace(".mid", "_aligned.mid")
    if os.path.exists(output_path):
        print(f"File {output_path} already exists. Skipping processing.")
        return

    FS = 10  # feature frames per second, shared by audio and MIDI chroma
    y, sr = librosa.load(audio_path)
    hop_length = int(sr / FS)
    audio_chroma = apply_normalization(
        librosa.feature.chroma_cqt(y=y, sr=sr, hop_length=hop_length)
    )

    midi_data = pretty_midi.PrettyMIDI(midi_path)
    # Keep 84 pitches starting at MIDI note 24, matching manual_chroma's
    # default ``start_note``.
    # NOTE(review): presumably chosen to mirror chroma_cqt's default pitch
    # range -- confirm against the librosa documentation.
    piano_roll = midi_data.get_piano_roll(fs=FS)[24 : 24 + 84]
    midi_chroma = apply_normalization(manual_chroma(piano_roll))

    from librosa.sequence import dtw

    _, wp = dtw(audio_chroma, midi_chroma, subseq=True, backtrack=True)
    # Each row of the path is an (audio_frame, midi_frame) pair. Splitting the
    # columns once lets the per-note lookup below run in NumPy instead of a
    # Python loop over the whole path for every note.
    wp = np.asarray(wp)
    path_audio, path_midi = wp[:, 0], wp[:, 1]
    last_index = len(wp) - 1

    source_instrument = midi_data.instruments[0]
    adjusted_midi = pretty_midi.PrettyMIDI()
    instrument = pretty_midi.Instrument(program=source_instrument.program)
    for note in source_instrument.notes:
        midi_frame = int(note.start * FS)
        # First path step (in the order returned by dtw) whose MIDI frame is
        # not past the note onset; the audio frame is read from the step that
        # follows it, clamped to the end of the path.
        # NOTE(review): this relies on the path being ordered end-first, and
        # notes with no matching step fall back to the first path entry --
        # confirm that is the intended placement.
        hits = np.flatnonzero(path_midi <= midi_frame)
        if hits.size:
            audio_frame = path_audio[min(last_index, hits[0] + 1)]
        else:
            audio_frame = path_audio[0]

        adjusted_start = audio_frame / FS
        adjusted_end = adjusted_start + (note.end - note.start)
        instrument.notes.append(
            pretty_midi.Note(
                start=adjusted_start,
                end=adjusted_end,
                velocity=note.velocity,
                pitch=note.pitch,
            )
        )

    adjusted_midi.instruments.append(instrument)
    adjusted_midi.write(output_path)
    print(f"Processed and saved as {output_path}")


def _index_by_ancestor(pattern, levels_up):
    """Glob ``pattern`` recursively and key each hit by an ancestor directory name.

    ``levels_up`` counts path components from the end of the path: 2 is the
    parent directory, 3 the grandparent. When several hits share a key, the
    last one returned by ``glob`` wins.
    """
    return {
        os.path.normpath(path).split(os.sep)[-levels_up]: path
        for path in glob(pattern, recursive=True)
    }


def _as_wav_if_midi(path):
    """Return ``path`` with a ``.mid`` extension swapped for ``.wav``; other paths unchanged."""
    root, ext = os.path.splitext(path)
    return root + ".wav" if ext == ".mid" else path


def _build_dataset(root_dir, json_path, split):
    """Collect the per-track file paths for one dataset split.

    A track is included only when its record number (looked up through the
    split JSON) belongs to ``split``, it has extra/removed/correct note labels
    and ``mix.*`` files under both ``mistake`` and ``score``, and the mistake
    audio (``.wav``) exists on disk.

    The ``mix.*`` globs can match either the audio or the MIDI file of a
    track, and which one ends up in the index depends on filesystem order.
    All derived paths are therefore built from the extension-less stem, so
    that ``aligned_midi`` is always ``<stem>_aligned.mid`` (it used to stay
    equal to the unaligned ``mix.mid`` whenever the glob kept the MIDI file).

    Args:
        root_dir: Dataset root containing ``label``, ``mistake`` and ``score``.
        json_path: Path to the split-metadata JSON.
        split: Name of the split to keep (e.g. ``"test"``).

    Returns:
        ``(records, mistakes_audio_dir, scores_audio_dir)``: a list of
        per-track dicts, plus two lists with the mistake and score audio paths
        in the same order.

    Raises:
        KeyError: If ``split`` is not present in the split JSON.
        AssertionError: If no track satisfies all conditions.
    """
    midi_filename_to_number, split_to_numbers = _load_MAESTRO_split_info(json_path)
    desired_file_numbers = split_to_numbers[split]

    def label_index(kind):
        # Label files are keyed by the grandparent directory name.
        pattern = os.path.join(root_dir, "label", kind, "**", "*.mid")
        return _index_by_ancestor(pattern, 3)

    extra_notes_files = label_index("extra_notes")
    removed_notes_files = label_index("removed_notes")
    correct_notes_files = label_index("correct_notes")
    # Mix files are keyed by their parent directory name.
    mistake_files = _index_by_ancestor(
        os.path.join(root_dir, "mistake", "**", "mix.*"), 2
    )
    score_files = _index_by_ancestor(
        os.path.join(root_dir, "score", "**", "mix.*"), 2
    )
    required_indexes = (
        removed_notes_files,
        correct_notes_files,
        mistake_files,
        score_files,
    )

    records, mistakes_audio_dir, scores_audio_dir = [], [], []
    for track_id, extra_notes_midi in extra_notes_files.items():
        if midi_filename_to_number.get(track_id) not in desired_file_numbers:
            continue
        if not all(track_id in index for index in required_indexes):
            continue
        mistake_audio = _as_wav_if_midi(mistake_files[track_id])
        if not os.path.exists(mistake_audio):
            continue

        score_path = score_files[track_id]
        score_stem = os.path.splitext(score_path)[0]
        score_audio = _as_wav_if_midi(score_path)
        score_midi = score_stem + ".mid"
        records.append(
            {
                "track_id": track_id,
                "extra_notes_midi": extra_notes_midi,
                "removed_notes_midi": removed_notes_files[track_id],
                "correct_notes_midi": correct_notes_files[track_id],
                "mistake_audio": mistake_audio,
                "score_audio": score_audio,
                "score_midi": score_midi,
                "aligned_midi": score_stem + "_aligned.mid",
                "prompt": score_midi,
            }
        )
        mistakes_audio_dir.append(mistake_audio)
        scores_audio_dir.append(score_audio)

    if not records:
        # Raised explicitly (rather than via ``assert``) so the check still
        # runs under ``python -O``; the exception type is unchanged.
        raise AssertionError("No matching files found. Check the dataset directory.")
    return records, mistakes_audio_dir, scores_audio_dir


# ---------- Scoring and Statistical Testing ----------
def extract_scores(a, b):
    """Pair up the values of two score dicts over their shared keys.

    Keys are visited in the order of ``a``; keys missing from ``b`` are
    skipped. When both values are lists they are paired element by element up
    to the shorter length. Otherwise the two values are paired as they are.

    Returns:
        Two NumPy arrays holding the paired values from ``a`` and ``b``.
    """
    left_scores, right_scores = [], []
    for key, left in a.items():
        if key not in b:
            continue
        right = b[key]
        if isinstance(left, list) and isinstance(right, list):
            # zip stops at the shorter list, which truncates both to min length.
            for left_item, right_item in zip(left, right):
                left_scores.append(left_item)
                right_scores.append(right_item)
            continue
        left_scores.append(left)
        right_scores.append(right)
    return np.array(left_scores), np.array(right_scores)


def run_tests(score_x, score_y, label):
    """Run paired significance tests on two score dicts and print the results.

    The scores are paired with ``extract_scores``; a paired t-test and a
    Wilcoxon signed-rank test are then run on the paired samples.

    Args:
        score_x: Score dict of the first model.
        score_y: Score dict of the second model.
        label: Heading printed above the results.

    Returns:
        None. Results are printed.
    """
    x, y = extract_scores(score_x, score_y)
    if len(x) == 0 or len(y) == 0:
        print(f"Skipping {label}: no matching data.")
        return

    t, tp = ttest_rel(x, y)
    # Some SciPy versions raise ValueError from wilcoxon when every paired
    # difference is zero (e.g. two models with identical scores). Report that
    # instead of aborting the remaining comparisons.
    try:
        w, wp = wilcoxon(x, y)
    except ValueError as err:
        wilcoxon_line = f"  Wilcoxon : not computable ({err})"
    else:
        wilcoxon_line = f"  Wilcoxon : W={w:.6f}, p={wp:.3e}"

    print(f"\n{label}")
    print(f"  T-test   : t={t:.6f}, p={tp:.3e}")
    print(wilcoxon_line)


def _load_eval_model(model_cfg, checkpoint_path, optim_cfg):
    """Load a checkpoint with the class named by ``model_cfg._target_`` and return its inner model in eval mode."""
    model_cls = hydra.utils.get_class(model_cfg._target_)
    print("torch.cuda.device_count():", torch.cuda.device_count())
    wrapper = model_cls.load_from_checkpoint(
        checkpoint_path,
        config=model_cfg.config,
        optim_cfg=optim_cfg,
    )
    model = wrapper.model
    model.eval()
    return model


@hydra.main(config_path="config", config_name="config", version_base="1.1")
def main(cfg):
    """Score three models on the test split and compare them pairwise.

    Steps:
      1. Build the test dataset index and, if ``cfg.eval.eval_first_n_examples``
         is set, keep a random contiguous window of that many examples.
      2. Make sure every selected track has an aligned score MIDI.
      3. Score model A (``cfg.model`` / ``cfg.path``, prompt disabled) and
         model B (``cfg.model_old`` / ``cfg.path_old``) with their respective
         ``get_scores`` functions.
      4. Run model C (``cfg.model_new`` / ``cfg.path_new``, prompt enabled)
         through ``InferenceHandler`` and score its outputs with
         ``evaluate_main``.
      5. Print paired significance tests for A vs B, A vs C and B vs C.

    Args:
        cfg: Hydra configuration object.
    """
    dataset, mistakes, scores = _build_dataset(
        cfg.dataset.test.root_dir, cfg.dataset.test.split_json_path, "test"
    )
    print("cfg.dataset.test.root_dir:", cfg.dataset.test.root_dir)
    prompt_dir = [d["prompt"] for d in dataset]

    if cfg.eval.eval_first_n_examples:
        # Despite the option name, a random contiguous window is selected.
        n = min(cfg.eval.eval_first_n_examples, len(mistakes))
        offset = 0
        if len(mistakes) > n:
            # randint's upper bound is exclusive; "+ 1" makes the final
            # window (offset == len - n) selectable as well.
            offset = np.random.randint(0, len(mistakes) - n + 1)
        window = slice(offset, offset + n)
        mistakes = mistakes[window]
        scores = scores[window]
        prompt_dir = prompt_dir[window]
        dataset = dataset[window]

    for d in dataset:
        process_audio_and_midi(d["mistake_audio"], d["score_midi"])

    # Model A
    cfg.model.config.use_prompt = False
    model_A = _load_eval_model(cfg.model, cfg.path, cfg.optim)
    score_A, _ = get_scores_A(
        model_A,
        mistakes_audio_dir=mistakes,
        scores_audio_dir=scores,
        prompt_dir=prompt_dir,
        mel_norm=True,
        eval_dataset=cfg.eval.eval_dataset,
        exp_tag_name=cfg.eval.exp_tag_name,
        ground_truth=dataset,
        contiguous_inference=cfg.eval.contiguous_inference,
        batch_size=cfg.eval.batch_size,
    )

    # Model B (its ``use_prompt`` setting is left as configured)
    model_B = _load_eval_model(cfg.model_old, cfg.path_old, cfg.optim)
    score_B, _ = get_scores_B(
        model_B,
        eval_audio_dir=mistakes,
        mel_norm=True,
        eval_dataset=cfg.eval.eval_dataset_old,
        exp_tag_name=cfg.eval.exp_tag_name,
        ground_truth=dataset,
        contiguous_inference=cfg.eval.contiguous_inference,
        batch_size=cfg.eval.batch_size,
    )

    # Model C (New, uses prompt)
    cfg.model_new.config.use_prompt = True
    model_C = _load_eval_model(cfg.model_new, cfg.path_new, cfg.optim)
    # Prefer the GPU, but fall back to CPU so the script still runs without one.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    handler = InferenceHandler(
        model=model_C,
        device=device,
        mel_norm=True,
        contiguous_inference=cfg.eval.contiguous_inference,
    )
    # zip has no length, so pass the total for a meaningful progress bar.
    for m, s, p in tqdm(zip(mistakes, scores, prompt_dir), total=len(mistakes)):
        m_audio, _ = librosa.load(m, sr=16000)
        s_audio, _ = librosa.load(s, sr=16000)
        track_name = os.path.basename(os.path.dirname(m))
        # Relative output path; the evaluation below reads from
        # ``<hydra output dir>/<exp_tag_name>``.
        # NOTE(review): assumes the working directory is Hydra's output
        # directory -- confirm.
        outpath = os.path.join(cfg.eval.exp_tag_name, track_name, "mix.mid")
        handler.inference(
            mistake_audio=m_audio,
            score_audio=s_audio,
            audio_path=track_name,
            prompt_path=p,
            outpath=outpath,
            batch_size=cfg.eval.batch_size,
            max_length=1024,
            verbose=True,
        )
    test_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
    score_C, _ = evaluate_main(
        cfg.eval.eval_dataset, os.path.join(test_dir, cfg.eval.exp_tag_name), dataset
    )

    # Comparisons
    run_tests(score_A, score_B, "Model A vs B")
    run_tests(score_A, score_C, "Model A vs C")
    run_tests(score_B, score_C, "Model B vs C")

# Script entry point: ``main`` is wrapped by ``@hydra.main``, so it is called
# without arguments and receives its config from Hydra.
if __name__ == "__main__":
    main()
