import json
import random
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
import librosa
import pandas as pd
from audio import (
    loudness_normalize,
    compute_speaker_activity_masks,
)
from config import *
from distortions import apply_pm_distortions, apply_ps_distortions
from metrics import (
    compute_pm,
    compute_ps,
    diffusion_map_torch,
    pm_ci_components_full,
    ps_ci_components_full,
)
from models import embed_batch, load_model
from utils import *


def compute_mapss_measures(
        models,
        mixtures,
        *,
        systems=None,
        algos=None,
        experiment_id=None,
        layer=DEFAULT_LAYER,
        add_ci=DEFAULT_ADD_CI,
        alpha=DEFAULT_ALPHA,
        seed=42,
        on_missing="skip",
        verbose=False,
        max_gpus=None,
):
    """
    Compute MAPSS measures (PM, PS, and their errors). Data is saved to csv files.

    :param models: backbone self-supervised models.
    :param mixtures: data to process from _read_manifest
    :param systems: specific systems (algos and data)
    :param algos: specific algorithms to use
    :param experiment_id: user-specified name for experiment
    :param layer: transformer layer of model to consider
    :param add_ci: True will compute error radius and tail bounds. False will not.
    :param alpha: normalization factor of the diffusion maps. Lives in [0, 1].
    :param seed: random seed number.
    :param on_missing: "skip" when missing values or throw an "error".
    :param verbose: True will print process info to console during runtime. False will minimize it.
    :param max_gpus: maximal amount of GPUs the program tries to utilize in parallel.

    """
    gpu_distributor = GPUWorkDistributor(max_gpus)
    ngpu = get_gpu_count(max_gpus)

    if on_missing not in {"skip", "error"}:
        raise ValueError("on_missing must be 'skip' or 'error'.")

    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    canon_mix = canonicalize_mixtures(mixtures, systems=systems)

    mixture_entries = []
    for m in canon_mix:
        entries = []
        for i, refp in enumerate(m.refs):
            sid = m.speaker_ids[i]
            entries.append(
                {"id": sid, "ref": Path(refp), "mixture": m.mixture_id, "outs": {}}
            )
        mixture_entries.append(entries)

    for m, mix_entries in zip(canon_mix, mixture_entries):
        for algo, out_list in (m.systems or {}).items():
            if len(out_list) != len(mix_entries):
                msg = f"[{algo}] Number of outputs ({len(out_list)}) does not match number of references ({len(mix_entries)}) for mixture {m.mixture_id}"
                if on_missing == "error":
                    raise ValueError(msg)
                else:
                    if verbose:
                        warnings.warn(msg + " Skipping this algorithm.")
                    continue

            for idx, e in enumerate(mix_entries):
                e["outs"][algo] = out_list[idx]

    if algos is None:
        algos_to_run = sorted(
            {algo for algo in canon_mix[0].systems.keys()} if canon_mix and canon_mix[0].systems else []
        )
    else:
        algos_to_run = list(algos)

    exp_id = experiment_id or datetime.now().strftime("%Y%m%d_%H%M%S")
    exp_root = os.path.join(RESULTS_ROOT, f"experiment_{exp_id}")
    os.makedirs(exp_root, exist_ok=True)

    params = {
        "models": models,
        "layer": layer,
        "add_ci": add_ci,
        "alpha": alpha,
        "seed": seed,
        "batch_size": BATCH_SIZE,
        "ngpu": ngpu,
        "max_gpus": max_gpus,
    }

    with open(os.path.join(exp_root, "params.json"), "w") as f:
        json.dump(params, f, indent=2)

    canon_struct = [
        {
            "mixture_id": m.mixture_id,
            "references": [str(p) for p in m.refs],
            "systems": {
                a: [str(p) for p in outs] for a, outs in (m.systems or {}).items()
            },
            "speaker_ids": m.speaker_ids,
        }
        for m in canon_mix
    ]

    with open(os.path.join(exp_root, "manifest_canonical.json"), "w") as f:
        json.dump(canon_struct, f, indent=2)

    print(f"Starting experiment {exp_id} with {ngpu} GPUs")
    print(f"Results will be saved to: {exp_root}")
    print("NOTE: Output files must be provided in the same order as reference files.")

    clear_gpu_memory()
    get_gpu_memory_info(verbose)

    flat_entries = [e for mix in mixture_entries for e in mix]
    all_refs = {}

    if verbose:
        print("Loading reference signals...")
    for e in flat_entries:
        wav, _ = librosa.load(str(e["ref"]), sr=SR)
        all_refs[e["id"]] = torch.from_numpy(loudness_normalize(wav))

    if verbose:
        print("Computing speaker activity masks...")

    win = int(ENERGY_WIN_MS * SR / 1000)
    hop = int(ENERGY_HOP_MS * SR / 1000)
    multi_speaker_masks_mix = []
    individual_speaker_masks_mix = []
    total_frames_per_mix = []

    for i, mix in enumerate(mixture_entries):
        if verbose:
            print(f"  Computing masks for mixture {i + 1}/{len(mixture_entries)}")

        if ngpu > 0:
            with torch.cuda.device(0):
                refs_for_mix = [all_refs[e["id"]].cuda() for e in mix]
                multi_mask, individual_masks = compute_speaker_activity_masks(refs_for_mix, win, hop)
                multi_speaker_masks_mix.append(multi_mask.cpu())
                individual_speaker_masks_mix.append([m.cpu() for m in individual_masks])
                total_frames_per_mix.append(multi_mask.shape[0])
                for ref in refs_for_mix:
                    del ref
                torch.cuda.empty_cache()
        else:
            refs_for_mix = [all_refs[e["id"]].cpu() for e in mix]
            multi_mask, individual_masks = compute_speaker_activity_masks(refs_for_mix, win, hop)
            multi_speaker_masks_mix.append(multi_mask.cpu())
            individual_speaker_masks_mix.append([m.cpu() for m in individual_masks])
            total_frames_per_mix.append(multi_mask.shape[0])

    ordered_speakers = [e["id"] for e in flat_entries]
    all_mixture_results = {}
    for mix_idx, (mix_canon, mix_entries) in enumerate(zip(canon_mix, mixture_entries)):
        mixture_id = mix_canon.mixture_id
        all_mixture_results[mixture_id] = {}
        total_frames = total_frames_per_mix[mix_idx]
        mixture_speakers = [e["id"] for e in mix_entries]

        for algo_idx, algo in enumerate(algos_to_run):
            if verbose:
                print(f"\nProcessing Mixture {mixture_id}, Algorithm {algo_idx + 1}/{len(algos_to_run)}: {algo}")
            all_outs = {}
            missing = []
            for e in mix_entries:
                assigned_path = e.get("outs", {}).get(algo)
                if assigned_path is None:
                    missing.append((e["mixture"], e["id"]))
                    continue
                wav, _ = librosa.load(str(assigned_path), sr=SR)
                all_outs[e["id"]] = torch.from_numpy(loudness_normalize(wav))

            if missing:
                msg = f"[{algo}] missing outputs for {len(missing)} speaker(s) in mixture {mixture_id}"
                if on_missing == "error":
                    raise FileNotFoundError(msg)
                else:
                    if verbose:
                        warnings.warn(msg + " Skipping those speakers.")

            if not all_outs:
                if verbose:
                    warnings.warn(f"[{algo}] No outputs for mixture {mixture_id}. Skipping.")
                continue

            if algo not in all_mixture_results[mixture_id]:
                all_mixture_results[mixture_id][algo] = {}

            ps_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
            pm_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
            ps_bias_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
            ps_prob_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
            pm_bias_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}
            pm_prob_frames = {m: {s: [np.nan] * total_frames for s in mixture_speakers} for m in models}

            for model_idx, mname in enumerate(models):
                if verbose:
                    print(f"  Processing Model {model_idx + 1}/{len(models)}: {mname}")

                for metric_type in ["PS", "PM"]:
                    clear_gpu_memory()
                    gc.collect()

                    model_wrapper, layer_eff = load_model(mname, layer, max_gpus)
                    get_gpu_memory_info(verbose)

                    speakers_this_mix = [e for e in mix_entries if e["id"] in all_outs]
                    if not speakers_this_mix:
                        continue

                    if verbose:
                        print(f"    Processing {metric_type} for mixture {mixture_id}")

                    multi_speaker_mask = multi_speaker_masks_mix[mix_idx]
                    individual_masks = individual_speaker_masks_mix[mix_idx]
                    valid_frame_indices = torch.where(multi_speaker_mask)[0].tolist()

                    speaker_signals = {}
                    speaker_labels = {}

                    for speaker_idx, e in enumerate(speakers_this_mix):
                        s = e["id"]

                        if metric_type == "PS":
                            dists = [
                                loudness_normalize(d)
                                for d in apply_ps_distortions(all_refs[s].numpy(), "all")
                            ]
                        else:
                            dists = [
                                loudness_normalize(d)
                                for d in apply_pm_distortions(
                                    all_refs[s].numpy(), "all"
                                )
                            ]

                        sigs = [all_refs[s].numpy(), all_outs[s].numpy()] + dists
                        lbls = ["ref", "out"] + [f"d{i}" for i in range(len(dists))]

                        speaker_signals[s] = sigs
                        speaker_labels[s] = [f"{s}-{l}" for l in lbls]

                    all_embeddings = {}
                    for s in speaker_signals:
                        sigs = speaker_signals[s]
                        masks = [multi_speaker_mask] * len(sigs)

                        batch_size = min(2, BATCH_SIZE)
                        embeddings_list = []

                        for i in range(0, len(sigs), batch_size):
                            batch_sigs = sigs[i:i + batch_size]
                            batch_masks = masks[i:i + batch_size]

                            batch_embs = embed_batch(
                                batch_sigs,
                                batch_masks,
                                model_wrapper,
                                layer_eff,
                                use_mlm=False,
                            )

                            if batch_embs.numel() > 0:
                                embeddings_list.append(batch_embs.cpu())

                            torch.cuda.empty_cache()

                        if embeddings_list:
                            all_embeddings[s] = torch.cat(embeddings_list, dim=0)
                        else:
                            all_embeddings[s] = torch.empty(0, 0, 0)

                    if not all_embeddings or all(e.numel() == 0 for e in all_embeddings.values()):
                        if verbose:
                            print(f"WARNING: mixture {mixture_id} produced 0 frames after masking; skipping.")
                        continue

                    L = next(iter(all_embeddings.values())).shape[1] if all_embeddings else 0

                    if L == 0:
                        if verbose:
                            print(f"WARNING: mixture {mixture_id} produced 0 frames after masking; skipping.")
                        continue

                    if verbose:
                        print(f"Computing {metric_type} scores for {mname}...")

                    with ThreadPoolExecutor(
                            max_workers=min(2, ngpu if ngpu > 0 else 1)
                    ) as executor:

                        def process_frame(f, frame_idx, all_embeddings_dict, speaker_labels_dict, individual_masks_list,
                                          speaker_indices):
                            try:
                                active_speakers = []
                                for spk_idx, spk_id in enumerate(speaker_indices):
                                    if individual_masks_list[spk_idx][frame_idx]:
                                        active_speakers.append(spk_id)

                                if len(active_speakers) < 2:
                                    return frame_idx, metric_type, {}, None, None

                                frame_embeddings = []
                                frame_labels = []
                                for spk_id in active_speakers:
                                    spk_embs = all_embeddings_dict[spk_id][:, f, :]
                                    frame_embeddings.append(spk_embs)
                                    frame_labels.extend(speaker_labels_dict[spk_id])

                                frame_emb = torch.cat(frame_embeddings, dim=0).detach().cpu().numpy()

                                if add_ci:
                                    coords_d, coords_c, eigvals, k_sub_gauss = (
                                        gpu_distributor.execute_on_gpu(
                                            diffusion_map_torch,
                                            frame_emb,
                                            frame_labels,
                                            alpha=alpha,
                                            eig_solver="full",
                                            return_eigs=True,
                                            return_complement=True,
                                            return_cval=add_ci,
                                        )
                                    )
                                else:
                                    coords_d = gpu_distributor.execute_on_gpu(
                                        diffusion_map_torch,
                                        frame_emb,
                                        frame_labels,
                                        alpha=alpha,
                                        eig_solver="full",
                                        return_eigs=False,
                                        return_complement=False,
                                        return_cval=False,
                                    )
                                    coords_c = None
                                    eigvals = None
                                    k_sub_gauss = 1

                                if metric_type == "PS":
                                    score = compute_ps(
                                        coords_d, frame_labels, max_gpus
                                    )
                                    bias = prob = None
                                    if add_ci:
                                        bias, prob = ps_ci_components_full(
                                            coords_d,
                                            coords_c,
                                            eigvals,
                                            frame_labels,
                                            delta=DEFAULT_DELTA_CI,
                                        )
                                    return frame_idx, "PS", score, bias, prob
                                else:
                                    score = compute_pm(
                                        coords_d, frame_labels, "gamma", max_gpus
                                    )
                                    bias = prob = None
                                    if add_ci:
                                        bias, prob = pm_ci_components_full(
                                            coords_d,
                                            coords_c,
                                            eigvals,
                                            frame_labels,
                                            delta=DEFAULT_DELTA_CI,
                                            K=k_sub_gauss,
                                        )
                                    return frame_idx, "PM", score, bias, prob

                            except Exception as ex:
                                if verbose:
                                    print(f"ERROR frame {frame_idx}: {ex}")
                                return None

                        speaker_ids = [e["id"] for e in speakers_this_mix]

                        futures = [
                            executor.submit(
                                process_frame,
                                f,
                                valid_frame_indices[f],
                                all_embeddings,
                                speaker_labels,
                                individual_masks,
                                speaker_ids
                            )
                            for f in range(L)
                        ]

                        for fut in futures:
                            result = fut.result()
                            if result is None:
                                continue

                            frame_idx, metric, score, bias, prob = result

                            if metric == "PS":
                                for sp in mixture_speakers:
                                    if sp in score:
                                        ps_frames[mname][sp][frame_idx] = score[sp]
                                        if add_ci and bias is not None and sp in bias:
                                            ps_bias_frames[mname][sp][frame_idx] = bias[sp]
                                            ps_prob_frames[mname][sp][frame_idx] = prob[sp]
                            else:
                                for sp in mixture_speakers:
                                    if sp in score:
                                        pm_frames[mname][sp][frame_idx] = score[sp]
                                        if add_ci and bias is not None and sp in bias:
                                            pm_bias_frames[mname][sp][frame_idx] = bias[sp]
                                            pm_prob_frames[mname][sp][frame_idx] = prob[sp]

                    clear_gpu_memory()
                    gc.collect()

                    del model_wrapper
                    clear_gpu_memory()
                    gc.collect()

            all_mixture_results[mixture_id][algo][mname] = {
                'ps_frames': ps_frames[mname],
                'pm_frames': pm_frames[mname],
                'ps_bias_frames': ps_bias_frames[mname] if add_ci else None,
                'ps_prob_frames': ps_prob_frames[mname] if add_ci else None,
                'pm_bias_frames': pm_bias_frames[mname] if add_ci else None,
                'pm_prob_frames': pm_prob_frames[mname] if add_ci else None,
                'total_frames': total_frames
            }

        if verbose:
            print(f"Saving results for mixture {mixture_id}...")

        timestamps_ms = [i * hop * 1000 / SR for i in range(total_frames)]

        for model in models:
            ps_data = {'timestamp_ms': timestamps_ms}
            pm_data = {'timestamp_ms': timestamps_ms}
            ci_data = {'timestamp_ms': timestamps_ms} if add_ci else None

            for algo in algos_to_run:
                if algo not in all_mixture_results[mixture_id]:
                    continue
                if model not in all_mixture_results[mixture_id][algo]:
                    continue

                model_data = all_mixture_results[mixture_id][algo][model]

                for speaker in mixture_speakers:
                    col_name = f"{algo}_{speaker}"
                    ps_data[col_name] = model_data['ps_frames'][speaker]
                    pm_data[col_name] = model_data['pm_frames'][speaker]

                    if add_ci and ci_data is not None:
                        ci_data[f"{algo}_{speaker}_ps_bias"] = model_data['ps_bias_frames'][speaker]
                        ci_data[f"{algo}_{speaker}_ps_prob"] = model_data['ps_prob_frames'][speaker]
                        ci_data[f"{algo}_{speaker}_pm_bias"] = model_data['pm_bias_frames'][speaker]
                        ci_data[f"{algo}_{speaker}_pm_prob"] = model_data['pm_prob_frames'][speaker]

            mixture_dir = os.path.join(exp_root, mixture_id)
            os.makedirs(mixture_dir, exist_ok=True)

            pd.DataFrame(ps_data).to_csv(
                os.path.join(mixture_dir, f"ps_scores_{model}.csv"),
                index=False
            )

            pd.DataFrame(pm_data).to_csv(
                os.path.join(mixture_dir, f"pm_scores_{model}.csv"),
                index=False
            )

            if add_ci and ci_data is not None:
                pd.DataFrame(ci_data).to_csv(
                    os.path.join(mixture_dir, f"ci_{model}.csv"),
                    index=False
                )

        del all_outs
        clear_gpu_memory()
        gc.collect()

    print(f"\nEXPERIMENT COMPLETED")
    print(f"Results saved to: {exp_root}")

    del all_refs, multi_speaker_masks_mix, individual_speaker_masks_mix

    from models import cleanup_all_models
    cleanup_all_models()

    clear_gpu_memory()
    get_gpu_memory_info(verbose)
    gc.collect()

    return exp_root