#!/usr/bin/env python
# test_camae_compare.py
#
# Compare three frozen encoders
#   A: new code‑path, audio‑only
#   B: old code‑path, audio‑only
#   C: new code‑path, *with symbolic‑score prompt*
#
# Outputs per‑model JSON‐scores + paired t‑test / Wilcoxon results.
# ------------------------------------------------------------------
import os, json, decimal, hydra, torch, librosa, numpy as np, pretty_midi
from glob import glob
from tqdm import tqdm
from scipy.stats import ttest_rel, wilcoxon

# ---------- your utility modules ----------
from test_laddersym_coco import get_scores as get_scores_A  # audio‑only
from test_polytune_coco_old import get_scores as get_scores_B  # legacy
from evaluate_errors_coco_old import evaluate_main as eval_Audio
from inference_error import InferenceHandler
from evaluate_errors_coco import evaluate_main as eval_Prompt

# ------------------------------------------


# ────────────────────────────────────────────────────────────────────
# 1. helpers
# ------------------------------------------------------------------
def _load_MAESTRO_split_info(json_path):
    """Read the split metadata JSON and build two lookup tables.

    Args:
        json_path: Path to a JSON file holding a ``"midi_filename"`` mapping
            (number -> MIDI path) and a ``"split"`` mapping (number -> split
            name).

    Returns:
        A ``(track_to_number, numbers_by_split)`` tuple. The first maps the
        MIDI base name (with ``".midi"`` removed) to the number as a string;
        the second maps each split name to the set of numbers (as strings)
        that belong to it.
    """
    with open(json_path, "r") as handle:
        metadata = json.load(handle)

    track_to_number = {}
    for number, midi_path in metadata["midi_filename"].items():
        track_name = os.path.basename(midi_path).replace(".midi", "")
        track_to_number[track_name] = str(number)

    numbers_by_split = {}
    for number, split_name in metadata["split"].items():
        numbers_by_split.setdefault(split_name, set()).add(str(number))

    return track_to_number, numbers_by_split


def capitalize_instrument_name(file_path):
    """Rewrite the instrument part of a stem file name in title-like form.

    Everything after the first underscore of the file name is treated as the
    instrument name: its underscore-separated words are each passed through
    ``str.capitalize`` and joined with single spaces. A doubled ``.wav.wav``
    suffix is collapsed to ``.wav``. File names without an underscore are
    returned untouched.

    Args:
        file_path: Path whose final component should be rewritten.

    Returns:
        The path with the same directory and the rewritten file name.
    """
    directory, name = os.path.split(file_path)
    prefix, underscore, instrument = name.partition("_")
    if underscore:
        pretty = " ".join(word.capitalize() for word in instrument.split("_"))
        # Only the first underscore survives; the rest became spaces above.
        name = f"{prefix}_{pretty}".replace(".wav.wav", ".wav")
    return os.path.join(directory, name)


def normalize(v):
    """Divide *v* by its sum along axis 0 (plus a small constant).

    The ``1e-6`` offset keeps the division finite when a slice sums to zero.
    """
    totals = np.sum(v, axis=0, keepdims=True)
    return v / (totals + 1e-6)


def manual_chroma(piano_roll, start_note=24):
    """Fold a piano roll into 12 pitch-class rows.

    Args:
        piano_roll: 2-D array whose row ``r`` is taken to be pitch number
            ``start_note + r``; columns are frames.
        start_note: Pitch number assigned to the first row.

    Returns:
        A ``(12, n_frames)`` float array where row ``p`` is the sum of all
        input rows whose pitch number is congruent to ``p`` modulo 12.
    """
    n_pitches, n_frames = piano_roll.shape[0], piano_roll.shape[1]
    pitch_classes = (start_note + np.arange(n_pitches)) % 12
    folded = np.zeros((12, n_frames))
    for pitch_class in range(12):
        rows = piano_roll[pitch_classes == pitch_class]
        folded[pitch_class] = rows.sum(axis=0)
    return folded


def apply_norm(arr):
    """Return a copy of *arr* with every column scaled to sum to about one.

    Each column (one frame of a ``(bins, frames)`` feature matrix) is divided
    by its own sum plus ``1e-6``, the same offset ``normalize`` uses, so
    all-zero frames stay zero instead of producing NaNs.

    Unlike a row-by-row assignment through ``arr.T`` (which is a view of the
    caller's buffer), this does not modify the input, and integer input is
    promoted to float rather than being truncated on write-back.

    Args:
        arr: 2-D array-like of shape ``(bins, frames)``.

    Returns:
        A new floating-point array of the same shape.
    """
    frames = np.asarray(arr)
    if not np.issubdtype(frames.dtype, np.floating):
        frames = frames.astype(float)
    column_totals = frames.sum(axis=0, keepdims=True) + 1e-6
    return frames / column_totals


def _midi_frames_to_audio_frames(warping_path, midi_frames):
    """Look up the audio frame aligned to each MIDI frame on a DTW path.

    For every requested MIDI frame the result is the audio frame of the last
    path step (in start-to-end order) whose MIDI frame is ``<=`` the request.
    Requests that fall before the first matched MIDI frame are clamped to the
    start of the path.

    The path is sorted here, so the lookup does not depend on whether the
    caller hands it over start-to-end or end-to-start.
    """
    path = np.asarray(warping_path)
    # Primary key: MIDI frame (column 1); secondary key: audio frame (column 0).
    order = np.lexsort((path[:, 0], path[:, 1]))
    audio_axis = path[order, 0]
    midi_axis = path[order, 1]
    last_not_after = np.searchsorted(midi_axis, midi_frames, side="right") - 1
    return audio_axis[np.clip(last_not_after, 0, len(midi_axis) - 1)]


def process_audio_and_midi(audio_path, midi_path, fs=10):
    """Time-align a MIDI file to an audio recording and write the result.

    Chroma features of the audio (CQT chroma) and of the MIDI piano roll are
    matched with subsequence DTW. Each note of the first MIDI instrument is
    then moved to the audio time its onset frame maps to, keeping its
    duration, pitch and velocity. The output is written next to the input as
    ``<midi_path with ".mid" replaced by "_aligned.mid">``; if that file
    already exists nothing is done.

    Only ``midi.instruments[0]`` is aligned and written; a MIDI file without
    instruments raises ``IndexError``.

    Args:
        audio_path: Audio file to align against (loaded at native rate).
        midi_path: MIDI file to be aligned.
        fs: Frame rate (frames per second) used for both feature sequences.

    Returns:
        None. The aligned MIDI is written to disk as a side effect.
    """
    out_mid = midi_path.replace(".mid", "_aligned.mid")
    if os.path.exists(out_mid):
        return

    from librosa.sequence import dtw

    y, sr = librosa.load(audio_path, sr=None)
    hop = int(sr / fs)
    aud_chroma = apply_norm(librosa.feature.chroma_cqt(y=y, sr=sr, hop_length=hop))
    midi = pretty_midi.PrettyMIDI(midi_path)
    # Rows 24..107 of the piano roll, matching manual_chroma's start_note=24.
    roll = midi.get_piano_roll(fs=fs)[24 : 24 + 84]
    midi_chroma = apply_norm(manual_chroma(roll))

    _, wp = dtw(aud_chroma, midi_chroma, subseq=True, backtrack=True)

    source = midi.instruments[0]
    onset_frames = np.array([int(note.start * fs) for note in source.notes], dtype=int)
    # librosa's backtracking emits the path end-to-start, so scanning it for
    # the "largest index with m <= frame" lands on the path start for almost
    # every note. The sorted lookup avoids that, never fails on notes that
    # precede the matched region, and is O(log n) per note.
    audio_frames = _midi_frames_to_audio_frames(wp, onset_frames)

    adj = pretty_midi.PrettyMIDI()
    inst = pretty_midi.Instrument(program=source.program)
    for note, audio_frame in zip(source.notes, audio_frames):
        s_new = float(audio_frame) / fs
        e_new = s_new + (note.end - note.start)
        inst.notes.append(
            pretty_midi.Note(
                velocity=note.velocity, pitch=note.pitch, start=s_new, end=e_new
            )
        )
    adj.instruments.append(inst)
    adj.write(out_mid)


def _build_dataset(root_dir, json_path, split, output_json_file):
    """Collect matched label / audio / score files for one dataset split.

    Five sub-trees of ``root_dir`` are scanned for ``.mid`` / ``.wav`` files
    (``label/extra_notes``, ``label/removed_notes``, ``label/correct_notes``,
    ``mistake`` and ``score``) and grouped by track id. The track id is the
    directory three levels up from a file for the label trees and two levels
    up for the others. Files of a track are then paired across the five
    groups by list position.

    Args:
        root_dir: Dataset root containing the sub-trees listed above.
        json_path: Split metadata JSON read by ``_load_MAESTRO_split_info``.
        split: Name of the split to keep (a key of the metadata's splits).
        output_json_file: Path of a JSON cache of the scan result. If it
            exists it is loaded instead of scanning; otherwise the scan result
            is written there. An empty value disables caching altogether.

    Returns:
        ``(df, mistakes_audio_dir, scores_audio_dir)``: a list of per-example
        dicts plus the parallel lists of mistake and score audio paths.

    Raises:
        AssertionError: If no example could be assembled.
    """
    midi_filename_to_number, split_to_numbers = _load_MAESTRO_split_info(json_path)
    print(f"Finished loading {json_path}", flush=True)
    desired_file_numbers = split_to_numbers[split]

    label_keys = ("extra_notes", "removed_notes", "correct_notes")
    directories = {
        "extra_notes": os.path.join(root_dir, "label", "extra_notes"),
        "removed_notes": os.path.join(root_dir, "label", "removed_notes"),
        "correct_notes": os.path.join(root_dir, "label", "correct_notes"),
        "mistake": os.path.join(root_dir, "mistake"),
        "score": os.path.join(root_dir, "score"),
    }
    print(f"Finished loading {root_dir}", flush=True)

    def scan_paths(batch_size=1000):
        """Walk every configured directory and group file paths by track id."""
        found = {key: {} for key in directories}
        pending = 0
        batch_count = 0
        total_files = 0
        for key, dir_path in directories.items():
            # Label files sit one directory deeper than mistake/score files.
            track_level = -3 if key in label_keys else -2
            for root, dirnames, filenames in os.walk(dir_path):
                # os.walk yields entries in arbitrary order; examples are
                # paired across directories by list index, so make the
                # traversal deterministic.
                dirnames.sort()
                for filename in sorted(filenames):
                    if filename.endswith((".mid", ".wav")):
                        file_path = os.path.join(root, filename)
                        dir_key = os.path.normpath(file_path).split(os.sep)[
                            track_level
                        ]
                        found[key].setdefault(dir_key, []).append(file_path)
                        pending += 1

                    if pending >= batch_size:
                        batch_count += 1
                        total_files += pending
                        print(
                            f"Processed batch {batch_count} with {pending} files (Total: {total_files})",
                            flush=True,
                        )
                        pending = 0

        if pending:  # Report the final, partially filled batch
            batch_count += 1
            total_files += pending
            print(
                f"Processed batch {batch_count} with {pending} files (Total: {total_files})",
                flush=True,
            )
        return found

    # Load the cached scan if there is one, otherwise scan (and cache if asked).
    if output_json_file and os.path.exists(output_json_file):
        print(f"Loading file paths from {output_json_file}", flush=True)
        with open(output_json_file, "r") as json_file:
            files_dict = json.load(json_file)
        print(f"File paths loaded from {output_json_file}", flush=True)
    elif output_json_file:
        print(f"Scanning and saving file paths to {output_json_file}", flush=True)
        files_dict = scan_paths()
        with open(output_json_file, "w") as json_file:
            json.dump(files_dict, json_file)
        print(f"File paths saved to {output_json_file}", flush=True)
    else:
        # An empty cache path cannot be opened for writing; just scan.
        print("Scanning file paths (no cache file given)", flush=True)
        files_dict = scan_paths()

    extra_notes_files = files_dict["extra_notes"]
    removed_notes_files = files_dict["removed_notes"]
    correct_notes_files = files_dict["correct_notes"]
    mistake_files = files_dict["mistake"]
    score_files = files_dict["score"]

    df = []
    mistakes_audio_dir = []
    scores_audio_dir = []

    for track_id, extra_paths in extra_notes_files.items():
        file_number = midi_filename_to_number.get(track_id)
        if file_number not in desired_file_numbers:
            continue
        # A track is only usable when all five groups know about it.
        if not (
            track_id in removed_notes_files
            and track_id in correct_notes_files
            and track_id in mistake_files
            and track_id in score_files
        ):
            continue

        removed_paths = removed_notes_files[track_id]
        correct_paths = correct_notes_files[track_id]
        mistake_paths = mistake_files[track_id]
        score_paths = score_files[track_id]
        usable = min(
            len(removed_paths),
            len(correct_paths),
            len(mistake_paths),
            len(score_paths),
        )

        for i, extra_path in enumerate(extra_paths):
            if i >= usable:
                print(
                    f"Index out of range for track_id {track_id}, subtrack index {i}"
                )
                continue

            score_path = score_paths[i]
            mistake_audio = capitalize_instrument_name(
                mistake_paths[i]
                .replace("stems_midi", "stems_audio")
                .replace(".mid", ".wav")
            )
            score_audio = capitalize_instrument_name(
                score_path.replace("stems_midi", "stems_audio").replace(
                    ".mid", ".wav"
                )
            )

            have_mistake = os.path.exists(mistake_audio)
            have_score = os.path.exists(score_audio)
            if not (have_mistake and have_score):
                if not have_mistake:
                    print(f"File does not exist: {mistake_audio}")
                if not have_score:
                    print(f"File does not exist: {score_audio}")
                continue

            score_midi_base = score_path.replace("stems_audio", "stems_midi")
            df.append(
                {
                    "track_id": track_id,
                    "file_number": file_number,
                    "extra_notes_midi": extra_path,
                    "removed_notes_midi": removed_paths[i],
                    "correct_notes_midi": correct_paths[i],
                    "mistake_audio": mistake_audio,
                    "score_audio": score_audio,
                    "prompt": score_path.replace(".wav", ".mid"),
                    "score_midi": score_midi_base.replace(".wav", ".mid"),
                    "aligned_midi": score_midi_base.replace(".wav", "_aligned.mid"),
                }
            )
            mistakes_audio_dir.append(mistake_audio)
            scores_audio_dir.append(score_audio)

    # Explicit raise (same exception type as before) so the check survives -O.
    if not df:
        raise AssertionError("No matching files found. Check the dataset directory.")
    print("Total files:", len(df))
    return df, mistakes_audio_dir, scores_audio_dir


def extract_scores_nested(d1, d2, metric="F1"):
    """Pair up one metric from two ``{index: {instrument: {metric: value}}}`` dicts.

    Only (index, instrument) combinations present in both inputs, and carrying
    *metric* in both, contribute. When both values are lists they are
    truncated to the shorter length and concatenated element-wise; otherwise
    the two values are appended as a single pair.

    Returns:
        Two NumPy arrays of paired values, taken from *d1* and *d2*.
    """
    first, second = [], []
    for idx, instruments_1 in d1.items():
        if idx not in d2:
            continue
        instruments_2 = d2[idx]
        for inst, metrics_1 in instruments_1.items():
            if inst not in instruments_2:
                continue
            if metric not in metrics_1:
                continue
            metrics_2 = instruments_2[inst]
            if metric not in metrics_2:
                continue

            value_1 = metrics_1[metric]
            value_2 = metrics_2[metric]
            both_lists = isinstance(value_1, list) and isinstance(value_2, list)
            if not both_lists:
                first.append(value_1)
                second.append(value_2)
                continue
            shared = min(len(value_1), len(value_2))
            first += value_1[:shared]
            second += value_2[:shared]
    return np.asarray(first), np.asarray(second)


def run_tests(s1, s2, label, metric="F1"):
    """Print a paired t-test and a Wilcoxon signed-rank test for two score sets.

    Args:
        s1, s2: Nested score dicts as accepted by ``extract_scores_nested``.
        label: Heading printed above the results.
        metric: Name of the metric to compare.

    Returns:
        None. Results are printed. If there are no paired values a notice is
        printed instead.
    """
    print(f"Running tests for: {label} on metric: {metric}")
    x, y = extract_scores_nested(s1, s2, metric=metric)
    if len(x) == 0:
        print(f"No data to compare for {label}")
        return

    t, tp = ttest_rel(x, y)
    print(f"\n{label}")
    print(f"  T-test   : t={t:.6f}, p={tp:.3e}")

    # Depending on the scipy version, wilcoxon raises ValueError when every
    # paired difference is zero (e.g. two models with identical scores).
    # Report that instead of aborting the remaining comparisons.
    try:
        w, wp = wilcoxon(x, y)
    except ValueError as err:
        print(f"  Wilcoxon : not computed ({err})")
    else:
        print(f"  Wilcoxon : W={w:.6f}, p={wp:.3e}")


# ────────────────────────────────────────────────────────────────────
# 2. experiment entry point
# ------------------------------------------------------------------
def _load_frozen_model(model_cfg, checkpoint_path, optim_cfg):
    """Load a checkpointed wrapper and return its inner ``.model`` in eval mode."""
    model_cls = hydra.utils.get_class(model_cfg._target_)
    print("torch.cuda.device_count():", torch.cuda.device_count())
    wrapper = model_cls.load_from_checkpoint(
        checkpoint_path,
        config=model_cfg.config,
        optim_cfg=optim_cfg,
    )
    model = wrapper.model
    model.eval()
    return model


@hydra.main(config_path="config", config_name="config", version_base="1.1")
def main(cfg):
    """Score three encoders on the test split and run paired significance tests.

    Models, in evaluation order:
      A: ``cfg.model`` / ``cfg.path`` with ``use_prompt`` forced to False,
      C: ``cfg.model_new`` / ``cfg.path_new`` with ``use_prompt`` forced to True,
      B: ``cfg.model_old`` / ``cfg.path_old`` scored through the legacy scorer.

    The per-track scores (third return value of each scorer) are compared
    pairwise with ``run_tests``.
    """
    print(cfg.dataset.test.root_dir, flush=True)
    # The file-path cache is currently disabled (empty path). It used to be a
    # "file_paths_8_12.json" next to the split JSON.
    output_json_path = ""
    dataset, mistakes, scores = _build_dataset(
        root_dir=cfg.dataset.test.root_dir,
        json_path=cfg.dataset.test.split_json_path,
        split="test",
        output_json_file=output_json_path,
    )

    prompt_dir = [entry["prompt"] for entry in dataset]

    # Optional subset: a random contiguous window of n examples.
    n_examples = cfg.eval.eval_first_n_examples
    if n_examples:
        random_offset = 0
        if len(mistakes) > n_examples:
            # randint's upper bound is exclusive; +1 lets the window reach the
            # last example as well.
            random_offset = np.random.randint(0, len(mistakes) - n_examples + 1)
        else:
            n_examples = len(mistakes)
            cfg.eval.eval_first_n_examples = n_examples

        window = slice(random_offset, random_offset + n_examples)
        mistakes = mistakes[window]
        scores = scores[window]
        prompt_dir = prompt_dir[window]

    # Align MIDI once, only for examples that will actually be evaluated.
    # NOTE: `dataset` itself is not subset; it is passed on in full as
    # ground truth.
    eval_audio = set(mistakes)  # set: O(1) membership inside the loop
    for data in dataset:
        if data["mistake_audio"] in eval_audio or data["score_audio"] in eval_audio:
            process_audio_and_midi(data["mistake_audio"], data["score_midi"])
    print("Aligned MIDIs")

    # Arguments shared by both runs of the new-code-path scorer (A and C).
    new_path_kwargs = dict(
        mistakes_audio_dir=mistakes,
        scores_audio_dir=scores,
        prompt_dir=prompt_dir,
        mel_norm=True,
        eval_dataset=cfg.eval.eval_dataset,
        exp_tag_name=cfg.eval.exp_tag_name,
        ground_truth=dataset,
        contiguous_inference=cfg.eval.contiguous_inference,
        batch_size=cfg.eval.batch_size,
        output_json_path=output_json_path,
    )

    # ---------- model A  (audio-only, new code path) ----------
    cfg.model.config.use_prompt = False
    model_A = _load_frozen_model(cfg.model, cfg.path, cfg.optim)
    # TODO: need to replace with =
    _, _, trk_A = get_scores_A(model_A, **new_path_kwargs)

    # ---------- model C  (prompt) ----------
    # Model C is not performing like it should.
    # NOTE: an alternative path for C (InferenceHandler.inference writing MIDI
    # files to the Hydra output dir, then scoring them with eval_Prompt) was
    # sketched here and is currently disabled.
    cfg.model_new.config.use_prompt = True
    model_C = _load_frozen_model(cfg.model_new, cfg.path_new, cfg.optim)
    _, _, trk_C = get_scores_A(model_C, **new_path_kwargs)

    # ---------- model B  (audio-only, old) ----------
    model_B = _load_frozen_model(cfg.model_old, cfg.path_old, cfg.optim)
    _, _, trk_B = get_scores_B(
        model_B,
        eval_audio_dir=mistakes,
        mel_norm=True,
        eval_dataset=cfg.eval.eval_dataset_old,
        exp_tag_name=cfg.eval.exp_tag_name,
        ground_truth=dataset,
        contiguous_inference=cfg.eval.contiguous_inference,
        batch_size=cfg.eval.batch_size,
        output_json_path=output_json_path,
    )

    # ---------- paired tests ----------
    run_tests(trk_A, trk_B, "A (audio new)  vs B (audio old)")
    run_tests(trk_A, trk_C, "A (audio new)  vs C (prompt)")
    run_tests(trk_B, trk_C, "B (audio old)  vs C (prompt)")


# Run the comparison only when executed as a script, not on import.
if __name__ == "__main__":
    main()
