


import argparse
import gc
import timeit
from pathlib import Path

import hydra
import matplotlib.pyplot as plt
import pandas as pd
import torch
from hydra import compose, initialize
from transformers.modeling_utils import load_state_dict

from benchmark_base import (
    build_batch_from_dataset,
    build_seq_len_list,
    free_cuda_memory,
    get_core_model,
    get_mlp_modules,
    parse_csv_list,
    parse_dtype,
    parse_layers_per_impl,
    resolve_run_dir,
    resolve_state_dict_file,
    slice_batch_for_seq_len,
    str_to_bool,
)
from energy_utils import (
    GPUEnergyMonitor,
    MlpEnergyHooks,
    merge_energy_stats,
    plot_energy_bars,
    summarize_energy,
)

from custom_models.sparse_testing_utils import (
    get_bwell_mlp_class,
    sparse_to_bwell_state_dict,
)


class MlpHookTimer:
    """Time MLP forward calls with CUDA events recorded from module hooks.

    ``pre_hook`` and ``post_hook`` are meant to be registered as a module's
    forward pre-hook and forward hook. Each pre-hook pushes a start event onto
    a per-module stack keyed by ``id(module)``. The matching post-hook records
    an end event, pops the latest start and stores a
    ``(layer_idx, start, end)`` triple. On non-CUDA devices every hook is a
    no-op and the summaries return NaN.

    Durations are read with ``Event.elapsed_time``, so the recorded events
    must have completed (``benchmark_one`` synchronizes before reading them).
    """

    def __init__(self, device, module_to_layer=None):
        # Timing events are only recorded on CUDA devices.
        self.enabled = device.type == "cuda"
        # Completed (layer_idx, start_event, end_event) triples.
        self.event_pairs = []
        # id(module) -> stack of start events awaiting their post-hook.
        self.start_events_by_mod = {}
        # id(module) -> layer index; unmapped modules are stored with None.
        self.module_to_layer = module_to_layer or {}

    def clear(self):
        """Discard every recorded pair and every pending start event."""
        self.event_pairs.clear()
        self.start_events_by_mod.clear()

    def pre_hook(self, module, _inputs):
        """Forward pre-hook: record a start event for ``module``."""
        if self.enabled:
            begin = torch.cuda.Event(enable_timing=True)
            begin.record()
            self.start_events_by_mod.setdefault(id(module), []).append(begin)

    def post_hook(self, module, _inputs, _output):
        """Forward hook: record an end event and pair it with the last start."""
        if not self.enabled:
            return
        # The end event is recorded even when no start is pending.
        finish = torch.cuda.Event(enable_timing=True)
        finish.record()
        mod_key = id(module)
        pending = self.start_events_by_mod.get(mod_key, [])
        if pending:
            self.event_pairs.append(
                (self.module_to_layer.get(mod_key), pending.pop(), finish)
            )

    def sum_ms(self):
        """Return total elapsed milliseconds over all pairs (NaN if disabled)."""
        if not self.enabled:
            return float("nan")
        elapsed = 0.0
        for _, begin, finish in self.event_pairs:
            elapsed += begin.elapsed_time(finish)
        return elapsed

    def per_layer_ms(self, num_layers):
        """Return elapsed milliseconds summed per layer index.

        Pairs with no layer index, or an index outside ``[0, num_layers)``,
        are ignored. When disabled, ``num_layers`` NaNs are returned.
        """
        count = int(num_layers)
        if not self.enabled:
            return [float("nan")] * count
        totals = [0.0] * count
        for idx, begin, finish in self.event_pairs:
            if idx is not None and 0 <= idx < num_layers:
                totals[idx] += begin.elapsed_time(finish)
        return totals

        
def load_model_for_impl(cfg, state_dict_file, mlp_impl, dtype, device,
                        convert_state, layers_to_replace=None):
    """Instantiate the model, optionally swap in a custom MLP, and load weights.

    Args:
        cfg: Composed hydra config; ``cfg.loaded_model`` is instantiated.
        state_dict_file: Checkpoint path passed to ``load_state_dict``.
        mlp_impl: Name passed to ``get_bwell_mlp_class``. ``None``, ``""`` or
            ``"none"`` keeps the model's original MLP modules.
        dtype: Target dtype for the model.
        device: Target device for the model.
        convert_state: When MLPs are replaced, convert the checkpoint with
            ``sparse_to_bwell_state_dict`` before loading it.
        layers_to_replace: Layer selection forwarded to both
            ``replace_mlp_modules`` and the state-dict conversion.

    Returns:
        ``(model, incompatible)``. ``model`` is in eval mode on ``device`` /
        ``dtype``. ``incompatible`` is the result of the non-strict
        ``load_state_dict`` call (missing / unexpected keys).

    Raises:
        RuntimeError: if an MLP replacement is requested but the model has no
            ``replace_mlp_modules()`` method.
    """
    replace_mlp = mlp_impl not in {None, "", "none"}
    model = hydra.utils.instantiate(cfg.loaded_model)

    if replace_mlp:
        if not hasattr(model, "replace_mlp_modules"):
            raise RuntimeError("model has no replace_mlp_modules()")
        mlp_cls = get_bwell_mlp_class(mlp_impl)
        model.replace_mlp_modules(
            mlp_cls,
            layers_to_replace=layers_to_replace,
        )

    state = load_state_dict(str(state_dict_file))
    if replace_mlp and convert_state:
        # The non-gated variant needs the non-gated weight layout.
        state = sparse_to_bwell_state_dict(
            state,
            layer_numbers=layers_to_replace,
            gated=mlp_impl != "bwell_nongated_v0",
        )

    # strict=False: mismatches are reported via `incompatible`, not raised.
    incompatible = model.load_state_dict(state, strict=False)
    # Drop the checkpoint copy before moving the model to the device.
    del state
    gc.collect()

    # Move and cast in one pass. Calling .to(device) before .to(dtype) first
    # places every weight on the device in its loaded dtype, which raises
    # peak device memory before the cast happens.
    model = model.eval().to(device=device, dtype=dtype)
    core = get_core_model(model)
    if hasattr(core, "config") and hasattr(core.config, "use_cache"):
        # Benchmarks run plain forward passes; no KV cache is needed.
        core.config.use_cache = False

    return model, incompatible

def _make_energy_monitor(device, energy_poll_interval_ms):
    """Create a ``GPUEnergyMonitor`` for the CUDA device ``device``.

    Uses ``device.index`` when it is set, otherwise the current CUDA device.
    """
    gpu_index = device.index
    if gpu_index is None:
        gpu_index = torch.cuda.current_device()
    return GPUEnergyMonitor(
        gpu_index=gpu_index,
        poll_interval_s=float(energy_poll_interval_ms) / 1000.0,
    )


def benchmark_one(model, input_ids, attention_mask, reps, warmup_reps,
                  measure_mlp_layers, measure_energy=False,
                  energy_poll_interval_ms=50.0,
                  measure_mlp_energy=False,
                  measure_mlp_energy_reps=5):
    """Time repeated forward passes of the core model and summarize them.

    The run has three phases:

    1. ``warmup_reps`` untimed passes.
    2. Optionally, ``measure_mlp_energy_reps`` passes with MLP-scoped energy
       hooks.
    3. ``reps`` wall-clock-timed passes, synchronized on CUDA.

    When ``measure_mlp_layers`` is set, CUDA-event hooks on every MLP module
    also time each layer during the timed passes. Every hook and monitor is
    removed or stopped even if a forward pass raises.

    Args:
        model: Model passed to ``get_core_model`` / ``get_mlp_modules``.
        input_ids, attention_mask: Inputs forwarded to the core model. The
            device of ``input_ids`` is the benchmark device.
        reps: Number of timed forward passes; must be positive.
        warmup_reps: Number of untimed warmup passes.
        measure_mlp_layers: Install per-MLP CUDA-event timers.
        measure_energy: Poll GPU energy across the timed passes (CUDA only).
        energy_poll_interval_ms: Polling interval for the energy monitors.
        measure_mlp_energy: Measure energy restricted to MLP modules.
        measure_mlp_energy_reps: Passes used for the MLP energy measurement.

    Returns:
        ``(row, avg_layer_ms)``:

        - ``row``: a dict of averages (``avg_total_ms``, ``avg_mlp_ms``,
          ``mlp_fraction``, ...) merged with the ``summarize_energy`` output.
        - ``avg_layer_ms``: the per-MLP average times, or ``None`` when
          ``measure_mlp_layers`` is false.

    Raises:
        ValueError: if ``reps`` is not positive, since there would be no
            timed passes to average.
    """
    if reps <= 0:
        raise ValueError(f"reps must be positive, got {reps}")

    device = input_ids.device
    is_cuda = device.type == "cuda"
    core = get_core_model(model)
    mlps = get_mlp_modules(model)

    handles = []
    timer = None
    layer_sums = None
    energy_monitor = None
    energy_started = False
    energy_stats = None
    mlp_energy_monitor = None
    mlp_energy_stats = None

    if measure_energy and is_cuda:
        energy_monitor = _make_energy_monitor(device, energy_poll_interval_ms)
    if measure_mlp_energy and is_cuda and mlps:
        mlp_energy_monitor = _make_energy_monitor(
            device, energy_poll_interval_ms
        )

    def run_once():
        # One forward pass through the core model: the timed region.
        with torch.inference_mode():
            _ = core(input_ids=input_ids, attention_mask=attention_mask)

    def run_untimed():
        # Untimed pass (warmup / MLP energy): reset hook events, run, sync.
        if timer is not None:
            timer.clear()
        run_once()
        if is_cuda:
            torch.cuda.synchronize()

    total_times_ms = []
    mlp_times_ms = []
    try:
        if measure_mlp_layers:
            module_to_layer = {id(m): i for i, m in enumerate(mlps)}
            timer = MlpHookTimer(device, module_to_layer=module_to_layer)
            for m in mlps:
                handles.append(m.register_forward_pre_hook(timer.pre_hook))
                handles.append(m.register_forward_hook(timer.post_hook))
            layer_sums = [0.0] * len(mlps)

        for _ in range(warmup_reps):
            run_untimed()

        if (
            mlp_energy_monitor is not None
            and mlp_energy_monitor.enabled
            and measure_mlp_energy_reps > 0
        ):
            mlp_energy_hooks = MlpEnergyHooks(
                mlp_energy_monitor,
                mlp_modules=mlps,
                sync_on_stop=True,
            )
            mlp_stats = []
            try:
                for _ in range(int(measure_mlp_energy_reps)):
                    run_untimed()
                    stats = mlp_energy_monitor.results()
                    if stats is not None:
                        mlp_stats.append(stats)
            finally:
                # Never leave energy hooks attached to the reused model.
                mlp_energy_hooks.remove()
            mlp_energy_stats = merge_energy_stats(mlp_stats)

        if energy_monitor is not None:
            energy_monitor.start()
            energy_started = True

        for _ in range(reps):
            if timer is not None:
                timer.clear()

            start = timeit.default_timer()
            run_once()
            if is_cuda:
                # Wait for queued kernels so wall-clock time covers the GPU work.
                torch.cuda.synchronize()
            end = timeit.default_timer()
            total_times_ms.append((end - start) * 1000.0)

            if timer is None:
                mlp_times_ms.append(float("nan"))
                continue
            mlp_times_ms.append(timer.sum_ms())
            for i, t in enumerate(timer.per_layer_ms(len(mlps))):
                layer_sums[i] += t
    finally:
        # The model is reused across calls, so stale timing hooks and a
        # still-running energy poller must not outlive this call.
        for h in handles:
            h.remove()
        if energy_started:
            energy_monitor.stop()

    if energy_monitor is not None:
        energy_stats = energy_monitor.results()

    avg_total_ms = sum(total_times_ms) / len(total_times_ms)
    avg_mlp_ms = sum(mlp_times_ms) / len(mlp_times_ms)

    avg_layer_ms = None
    if layer_sums is not None:
        avg_layer_ms = [t / float(reps) for t in layer_sums]

    row = {
        "avg_total_ms": float(avg_total_ms),
        "avg_mlp_ms": float(avg_mlp_ms),
        "mlp_fraction": float(avg_mlp_ms / avg_total_ms)
        if avg_total_ms > 0.0 else float("nan"),
        "num_mlp_modules": int(len(mlps)),
        "reps": int(reps),
        "warmup_reps": int(warmup_reps),
        "measure_mlp_layers": bool(measure_mlp_layers),
    }
    row.update(summarize_energy(energy_stats, reps=reps))
    row.update(
        summarize_energy(
            mlp_energy_stats,
            reps=measure_mlp_energy_reps,
            prefix="mlp_",
        )
    )
    if measure_mlp_energy:
        row["mlp_energy_reps"] = int(measure_mlp_energy_reps)
    return row, avg_layer_ms


def plot_layer_results(layer_df, out_path):
    """Save a line plot of average MLP time per layer, one line per implementation.

    Args:
        layer_df: DataFrame with ``implementation``, ``layer_idx`` and
            ``avg_mlp_layer_ms`` columns.
        out_path: Destination image path.
    """
    fig = plt.figure()
    try:
        for impl, sub in layer_df.groupby("implementation"):
            sub = sub.sort_values("layer_idx")
            plt.plot(
                sub["layer_idx"].to_list(),
                sub["avg_mlp_layer_ms"].to_list(),
                marker="o",
                label=str(impl),
            )

        plt.xlabel("layer index")
        plt.ylabel("mlp time (ms)")
        plt.title("mlp time per layer")
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        plt.savefig(str(out_path))
    finally:
        # Always release the figure, like the other plot helpers. Otherwise
        # open figures (and their memory) accumulate across calls.
        plt.close(fig)


def plot_timing_bars(df, out_path):
    """Save a bar chart of mean forward time per implementation.

    Each implementation gets a ``total ms`` bar. When the ``avg_mlp_ms``
    column exists and has no NaN, a ``mlp ms`` bar is drawn beside it.

    Args:
        df: DataFrame with ``implementation`` and ``avg_total_ms`` columns,
            optionally ``avg_mlp_ms``.
        out_path: Destination image path.
    """
    labels = df["implementation"].to_list()
    positions = list(range(len(labels)))

    total_values = df["avg_total_ms"].to_list()
    mlp_available = (
        "avg_mlp_ms" in df.columns and df["avg_mlp_ms"].notna().all()
    )

    plt.figure()
    if not mlp_available:
        plt.bar(positions, total_values, width=0.6, label="total ms")
    else:
        bar_w = 0.38
        shift = bar_w / 2
        plt.bar([p - shift for p in positions], total_values, width=bar_w,
                label="total ms")
        plt.bar([p + shift for p in positions], df["avg_mlp_ms"].to_list(),
                width=bar_w, label="mlp ms")

    plt.xticks(positions, labels, rotation=45, ha="right")
    plt.ylabel("time (ms)")
    plt.title("per-forward timing")
    plt.grid(True, axis="y")
    plt.legend()
    plt.tight_layout()
    plt.savefig(str(out_path))
    plt.close()


def plot_speed_vs_seq_len(df, speed_col, out_path, title):
    """Save a throughput-vs-sequence-length plot, one line per implementation.

    Rows with NaN in ``speed_col`` are dropped. Implementations with no
    remaining rows are not drawn.

    Args:
        df: DataFrame with ``implementation``, ``seq_len`` and ``speed_col``.
        speed_col: Column plotted on the y axis (tokens / second).
        out_path: Destination image path.
        title: Plot title.
    """
    plt.figure()
    for impl, group in df.groupby("implementation"):
        points = group.sort_values("seq_len").dropna(subset=[speed_col])
        if not points.empty:
            plt.plot(
                points["seq_len"].to_list(),
                points[speed_col].to_list(),
                marker="o",
                label=str(impl),
            )

    plt.xlabel("sequence length")
    plt.ylabel("tokens / second")
    plt.title(title)
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(str(out_path))
    plt.close()



def plot_speedup_bars(df, baseline_name, out_path):
    """Save a bar chart of speedup relative to ``baseline_name``.

    Each implementation gets a ``speedup_total`` bar. When the ``speedup_mlp``
    column exists and has no NaN, a second bar shows ``speedup_mlp``.

    Args:
        df: DataFrame with ``implementation`` and ``speedup_total`` columns,
            optionally ``speedup_mlp``.
        baseline_name: Baseline implementation name used in the legend.
        out_path: Destination image path.
    """
    labels = df["implementation"].to_list()
    positions = list(range(len(labels)))

    total_values = df["speedup_total"].to_list()
    total_label = f"total speedup vs {baseline_name}"
    show_mlp = "speedup_mlp" in df.columns and df["speedup_mlp"].notna().all()

    plt.figure()
    if not show_mlp:
        plt.bar(positions, total_values, width=0.6, label=total_label)
    else:
        bar_w = 0.38
        shift = bar_w / 2
        plt.bar([p - shift for p in positions], total_values, width=bar_w,
                label=total_label)
        plt.bar([p + shift for p in positions], df["speedup_mlp"].to_list(),
                width=bar_w, label=f"mlp speedup vs {baseline_name}")

    plt.xticks(positions, labels, rotation=45, ha="right")
    plt.ylabel("speedup")
    plt.title("relative speedup")
    plt.grid(True, axis="y")
    plt.legend()
    plt.tight_layout()
    plt.savefig(str(out_path))
    plt.close()


def plot_speedup_vs_seq_len(df, speedup_col, out_path, title, baseline_name):
    """Save a speedup-vs-sequence-length plot, one line per implementation.

    Rows with NaN in ``speedup_col`` are skipped. Implementations left with
    no rows are omitted.

    Args:
        df: DataFrame with ``implementation``, ``seq_len`` and ``speedup_col``.
        speedup_col: Column plotted on the y axis.
        out_path: Destination image path.
        title: Plot title.
        baseline_name: Baseline implementation name used in the y label.
    """
    plt.figure()
    for impl_name, impl_rows in df.groupby("implementation"):
        curve = impl_rows.sort_values("seq_len").dropna(subset=[speedup_col])
        if curve.empty:
            continue
        xs = curve["seq_len"].to_list()
        ys = curve[speedup_col].to_list()
        plt.plot(xs, ys, marker="o", label=str(impl_name))

    plt.xlabel("sequence length")
    plt.ylabel(f"speedup vs {baseline_name}")
    plt.title(title)
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(str(out_path))
    plt.close()



def build_arg_parser():
    """Build the command-line parser for the MLP benchmark.

    Boolean options are parsed with ``str_to_bool`` and therefore take an
    explicit value rather than acting as bare flags.
    """
    p = argparse.ArgumentParser(
        prog="bench_full_mlp",
        description="benchmark summed mlp time across layers (hooks + cuda)",
    )

    p.add_argument("--model-path", type=str, required=True,
                   help="model/run path used to locate the run directory "
                        "and state-dict file")

    p.add_argument("--hydra-config-path", type=str, default="cfgs",
                   help="hydra config directory passed to hydra.initialize")
    p.add_argument("--hydra-config-name", type=str, default="train",
                   help="name of the hydra config to compose")
    p.add_argument("--hydra-run-path", type=str, required=True,
                   help="value of the run_cfg@_global_ override")

    p.add_argument("--results-dir", type=str, required=True,
                   help="output directory for CSVs and plots "
                        "(created if missing)")

    p.add_argument("--mlp-implementations", type=str, default=None,
                   help="comma-separated MLP implementations to benchmark "
                        "('none' keeps the original MLPs); "
                        "default: bwell_v0,none")
    p.add_argument("--layers-per-implementation", type=str, default=None,
                   help="per-implementation layers whose MLPs are replaced "
                        "(format parsed by parse_layers_per_impl)")
    p.add_argument("--baseline", type=str, default="none",
                   help="implementation used as the speedup reference; "
                        "falls back to the first implementation if it was "
                        "not benchmarked")
    p.add_argument("--convert-state", type=str_to_bool, default=True,
                   help="convert the checkpoint with "
                        "sparse_to_bwell_state_dict when replacing MLPs")

    p.add_argument("--batch-size", type=int, default=64,
                   help="batch size at --seq-len")
    p.add_argument("--seq-len", type=int, default=2048,
                   help="reference sequence length used for the summary "
                        "CSV and bar plots")

    p.add_argument("--seq-len-step", type=int, default=256,
                   help="step of the sequence-length sweep "
                        "(see build_seq_len_list)")
    p.add_argument("--adjust-seq-len-for-batch", type=str_to_bool,
                   default=True,
                   help="use batch size (batch-size * seq-len) // L at each "
                        "swept length L, keeping the token count roughly "
                        "constant")

    p.add_argument("--dtype", type=str, default="bf16",
                   help="model dtype, parsed by parse_dtype (e.g. bf16)")
    p.add_argument("--device", type=str, default="cuda",
                   help="torch device to run on")

    p.add_argument("--reps", type=int, default=50,
                   help="timed forward passes per configuration")
    p.add_argument("--warmup-reps", type=int, default=5,
                   help="untimed warmup forward passes per configuration")

    p.add_argument("--out-csv", type=str,
                   help="summary CSV path "
                        "(default: RESULTS_DIR/timing_df.csv)")
    p.add_argument("--plot", type=str_to_bool, default=True,
                   help="save plots")
    p.add_argument("--plot-save", type=str,
                   help="main plot path; other plots are saved next to it "
                        "(default: RESULTS_DIR/timing_plot.png)")

    p.add_argument("--measure-mlp-layers", type=str_to_bool, default=True,
                   help="time each MLP module with CUDA events via "
                        "forward hooks")
    p.add_argument("--measure-energy", type=str_to_bool, default=False,
                   help="record GPU energy over the timed forward passes "
                        "(CUDA only)")
    p.add_argument("--energy-interval-ms", type=float, default=50.0,
                   help="GPU energy polling interval in milliseconds")
    p.add_argument("--measure-mlp-energy", type=str_to_bool, default=False,
                   help="record GPU energy restricted to MLP modules "
                        "(CUDA only)")
    p.add_argument("--measure-mlp-energy-reps", type=int, default=5,
                   help="forward passes used for the MLP energy measurement")

    return p


def _speedup_vs_baseline_by_seq_len(df_all, baseline_name, speed_col):
    """Divide ``speed_col`` by the baseline's value at the same ``seq_len``.

    Rows whose ``seq_len`` has no baseline measurement get NaN. If the
    baseline name appears more than once per ``seq_len`` (the same
    implementation listed twice), the first row is used. This matches the
    ``.iloc[0]`` choice for the reference-length speedups and keeps the lookup
    index unique, which ``Series.map`` requires.
    """
    base_rows = df_all[df_all["implementation"] == baseline_name]
    base_rows = base_rows.drop_duplicates(subset="seq_len", keep="first")
    base_speed_by_sl = base_rows.set_index("seq_len")[speed_col]
    return df_all[speed_col] / df_all["seq_len"].map(base_speed_by_sl)


def main():
    """Benchmark each MLP implementation over a sequence-length sweep.

    For every implementation the model is rebuilt with ``load_model_for_impl``.
    It is then benchmarked at each length from ``build_seq_len_list`` by
    slicing one pre-built batch, and released before the next implementation.

    Outputs:

    - ``--out-csv``: rows at ``--seq-len``, with speedups vs the baseline.
    - ``<out-csv stem>_seq_len_sweep.csv``: every swept length, with
      throughput and per-length speedups.
    - ``<out-csv stem>_per_layer.csv``: per-layer MLP times at ``--seq-len``
      when ``--measure-mlp-layers`` is on. Written whether or not ``--plot``
      is set.
    - Plots next to ``--plot-save`` when ``--plot`` is set.

    Raises:
        RuntimeError: if CUDA is requested but unavailable, or if no results
            were produced at ``--seq-len``.
        ValueError: if the implementation list is empty.
    """
    args = build_arg_parser().parse_args()

    default_implementations = [
        'bwell_v0',
        'none',
    ]

    results_dir = Path(args.results_dir)
    results_dir.mkdir(parents=True, exist_ok=True)
    if args.out_csv is None:
        args.out_csv = str(results_dir / "timing_df.csv")
    if args.plot_save is None:
        args.plot_save = str(results_dir / "timing_plot.png")

    run_dir = resolve_run_dir(None, None, args.model_path)
    state_dict_file = resolve_state_dict_file(run_dir)

    dtype = parse_dtype(args.dtype)
    device = torch.device(args.device)
    if device.type == "cuda" and not torch.cuda.is_available():
        raise RuntimeError("cuda device requested but cuda is unavailable")

    if args.mlp_implementations is None or args.mlp_implementations.strip() == "":
        implementation_names = default_implementations
    else:
        implementation_names = parse_csv_list(args.mlp_implementations)
    if not implementation_names:
        raise ValueError("--mlp-implementations is empty")
    layers_per_impl = parse_layers_per_impl(
        args.layers_per_implementation,
        len(implementation_names),
    )

    overrides = [
        f"run_cfg@_global_={args.hydra_run_path}",
        "sparsity_mlp_execution_logic=inference",
    ]
    with initialize(version_base=None,
                    config_path=args.hydra_config_path,
                    job_name="bench_full_mlp"):
        cfg = compose(config_name=args.hydra_config_name,
                      overrides=overrides)

    tokenizer = hydra.utils.instantiate(cfg.make_tokenizer_fn)
    seq_lens = build_seq_len_list(args.seq_len, args.seq_len_step)

    # Token budget of one batch at the reference length.
    total_tokens = args.batch_size * args.seq_len

    if args.adjust_seq_len_for_batch:
        # Each swept length L uses total_tokens // L rows, so the shortest
        # length needs the most: ceil(total_tokens / min(seq_lens)).
        min_seq_len = min(seq_lens)
        batch_elements_needed = (total_tokens + min_seq_len - 1) // min_seq_len
    else:
        batch_elements_needed = args.batch_size

    # One batch at the reference length; each sweep point slices from it.
    input_ids, attention_mask = build_batch_from_dataset(
        cfg=cfg,
        tokenizer=tokenizer,
        batch_size=batch_elements_needed,
        seq_len=args.seq_len,
        device=device,
    )
    out_csv = Path(args.out_csv)
    out_csv.parent.mkdir(parents=True, exist_ok=True)

    rows = []
    layer_rows = []

    for impl_idx, implementation_name in enumerate(implementation_names):
        free_cuda_memory()

        model, incompatible = load_model_for_impl(
            cfg=cfg,
            state_dict_file=state_dict_file,
            mlp_impl=implementation_name,
            dtype=dtype,
            device=device,
            convert_state=args.convert_state,
            layers_to_replace=layers_per_impl[impl_idx],
        )

        for seq_len in seq_lens:
            if args.adjust_seq_len_for_batch:
                current_batch_size = total_tokens // seq_len
            else:
                current_batch_size = args.batch_size

            input_ids_sl, attention_mask_sl = slice_batch_for_seq_len(
                input_ids=input_ids,
                attention_mask=attention_mask,
                seq_len=seq_len,
                batch_size=current_batch_size,
            )

            bench_row, avg_layer_ms = benchmark_one(
                model=model,
                input_ids=input_ids_sl,
                attention_mask=attention_mask_sl,
                reps=args.reps,
                warmup_reps=args.warmup_reps,
                measure_mlp_layers=args.measure_mlp_layers,
                measure_energy=args.measure_energy,
                energy_poll_interval_ms=args.energy_interval_ms,
                measure_mlp_energy=args.measure_mlp_energy,
                measure_mlp_energy_reps=args.measure_mlp_energy_reps,
            )

            bench_row.update({
                "implementation": str(implementation_name),
                "batch_size": int(current_batch_size),
                "seq_len": int(seq_len),
                "dtype": str(args.dtype),
                "device": str(device),
                "run_dir": str(run_dir),
                "state_dict_file": str(state_dict_file),
                "missing_keys": int(len(incompatible.missing_keys)),
                "unexpected_keys": int(len(incompatible.unexpected_keys)),
            })
            rows.append(bench_row)

            # Per-layer timings are kept only at the reference length.
            keep_layers = (
                args.seq_len_step is None
                or int(seq_len) == int(args.seq_len)
            )
            if avg_layer_ms is not None and keep_layers:
                for layer_idx, ms in enumerate(avg_layer_ms):
                    layer_rows.append({
                        "implementation": str(implementation_name),
                        "layer_idx": int(layer_idx),
                        "avg_mlp_layer_ms": float(ms),
                        "batch_size": int(current_batch_size),
                        "seq_len": int(seq_len),
                        "dtype": str(args.dtype),
                        "device": str(device),
                        "run_dir": str(run_dir),
                    })

        # Release this implementation's model before loading the next one.
        del model
        gc.collect()
        free_cuda_memory()

    df_all = pd.DataFrame(rows)

    df_all["num_tokens"] = (
        df_all["batch_size"] * df_all["seq_len"]
    )
    df_all["speed_total_tok_s"] = (
        df_all["num_tokens"] * 1000.0 / df_all["avg_total_ms"]
    )
    df_all["speed_mlp_tok_s"] = (
        df_all["num_tokens"] * 1000.0 / df_all["avg_mlp_ms"]
    )
    if "energy_per_forward_j" in df_all.columns:
        df_all["energy_per_token_mj"] = (
            df_all["energy_per_forward_j"] * 1000.0 / df_all["num_tokens"]
        )
    if "mlp_energy_per_forward_j" in df_all.columns:
        df_all["mlp_energy_per_token_mj"] = (
            df_all["mlp_energy_per_forward_j"] * 1000.0
            / df_all["num_tokens"]
        )

    # Summary table: reference sequence length only.
    df = df_all[df_all["seq_len"] == int(args.seq_len)].copy()
    if df.empty:
        raise RuntimeError(
            f"no benchmark rows at --seq-len {args.seq_len}; swept lengths "
            f"were {sorted(df_all['seq_len'].unique().tolist())}"
        )

    baseline_name = args.baseline
    if baseline_name not in set(df["implementation"].to_list()):
        baseline_name = df["implementation"].iloc[0]
        print(f"baseline {args.baseline!r} was not benchmarked; "
              f"using {baseline_name!r}")

    df_all["speedup_total_vs_baseline_sl"] = _speedup_vs_baseline_by_seq_len(
        df_all, baseline_name, "speed_total_tok_s"
    )
    df_all["speedup_mlp_vs_baseline_sl"] = _speedup_vs_baseline_by_seq_len(
        df_all, baseline_name, "speed_mlp_tok_s"
    )

    base_row = df[df["implementation"] == baseline_name].iloc[0]
    base_total = float(base_row["avg_total_ms"])
    base_mlp = float(base_row["avg_mlp_ms"])

    df["speedup_total"] = base_total / df["avg_total_ms"]
    if not df["avg_mlp_ms"].isna().any() and base_mlp > 0.0:
        df["speedup_mlp"] = base_mlp / df["avg_mlp_ms"]
    else:
        df["speedup_mlp"] = float("nan")

    df.to_csv(out_csv, index=False)
    print(f"wrote {len(df)} rows to {out_csv}")

    # NOTE(review): --seq-len-step is an int option with a non-None default,
    # so from the CLI this condition is always true.
    if args.seq_len_step is not None:
        sweep_csv = out_csv.with_name(
            out_csv.stem + "_seq_len_sweep.csv"
        )
        df_all.to_csv(sweep_csv, index=False)
        print(f"wrote seq-len sweep csv to {sweep_csv}")

    # Write the per-layer CSV regardless of --plot, so the measurements are
    # not discarded when plotting is disabled.
    layer_df = None
    if args.measure_mlp_layers and layer_rows:
        layer_df = pd.DataFrame(layer_rows)
        layer_csv = out_csv.with_name(out_csv.stem + "_per_layer.csv")
        layer_df.to_csv(layer_csv, index=False)
        print(f"wrote per-layer csv to {layer_csv}")

    if args.plot:
        plot_path = Path(args.plot_save)
        plot_path.parent.mkdir(parents=True, exist_ok=True)
        plot_timing_bars(df, out_path=plot_path)

        speedup_path = plot_path.with_name(
            plot_path.stem + "_speedup.png"
        )
        plot_speedup_bars(df, baseline_name=baseline_name,
                          out_path=speedup_path)

        print(f"saved plots to {plot_path} (+ _speedup.png)")

        if args.measure_energy:
            energy_plot_path = plot_path.with_name(
                plot_path.stem + "_energy.png"
            )
            plot_energy_bars(df, out_path=energy_plot_path)
            print(f"saved energy plot to {energy_plot_path}")

        if args.measure_mlp_energy:
            mlp_energy_plot_path = plot_path.with_name(
                plot_path.stem + "_mlp_energy.png"
            )
            plot_energy_bars(
                df,
                out_path=mlp_energy_plot_path,
                prefix="mlp_",
                title="MLP-only GPU power and energy per forward",
            )
            print(f"saved MLP energy plot to {mlp_energy_plot_path}")

        if args.seq_len_step is not None:
            total_speed_path = plot_path.with_name(
                plot_path.stem + "_speed_total_vs_seq_len.png"
            )
            plot_speed_vs_seq_len(
                df_all,
                speed_col="speed_total_tok_s",
                out_path=total_speed_path,
                title="total throughput vs sequence length",
            )
            print(f"saved total speed plot to {total_speed_path}")

            if not df_all["speed_mlp_tok_s"].isna().all():
                mlp_speed_path = plot_path.with_name(
                    plot_path.stem + "_speed_mlp_vs_seq_len.png"
                )
                plot_speed_vs_seq_len(
                    df_all,
                    speed_col="speed_mlp_tok_s",
                    out_path=mlp_speed_path,
                    title="mlp throughput vs sequence length",
                )
                print(f"saved mlp speed plot to {mlp_speed_path}")

            total_speedup_path = plot_path.with_name(
                plot_path.stem + "_speedup_total_vs_seq_len.png"
            )
            plot_speedup_vs_seq_len(
                df_all,
                speedup_col="speedup_total_vs_baseline_sl",
                out_path=total_speedup_path,
                title="total speedup vs sequence length",
                baseline_name=baseline_name,
            )
            print(f"saved total speedup plot to {total_speedup_path}")

            if not df_all["speedup_mlp_vs_baseline_sl"].isna().all():
                mlp_speedup_path = plot_path.with_name(
                    plot_path.stem + "_speedup_mlp_vs_seq_len.png"
                )
                plot_speedup_vs_seq_len(
                    df_all,
                    speedup_col="speedup_mlp_vs_baseline_sl",
                    out_path=mlp_speedup_path,
                    title="mlp speedup vs sequence length",
                    baseline_name=baseline_name,
                )
                print(f"saved mlp speedup plot to {mlp_speedup_path}")

        if layer_df is not None:
            layer_plot_path = plot_path.with_name(
                plot_path.stem + "_per_layer.png"
            )
            plot_layer_results(layer_df, layer_plot_path)
            print(f"saved per-layer plot to {layer_plot_path}")


# Script entry point: run the benchmark only when executed directly, not on import.
if __name__ == "__main__":
    main()
