import polars as pl
import kdai
import kdai.lrfind
import logging
import numpy as np
from pathlib import Path
import kdtpp.experiments as exp
import kdai._logging


_logger = logging.getLogger(__name__)


def process(lrs, losses, *, default_lr=5e-4):
    """Summarise one LR-finder sweep and choose suggested learning rates.

    Args:
        lrs: Learning rates of the sweep (converted with ``np.array``).
        losses: Raw losses, converted by ``kdai.lrfind.to_loss_arr``. The
            converted array is averaged over axis 0 and iterated row by row
            below, so it is presumably (n_runs, n_steps) -- TODO confirm.
        default_lr: Fallback LR used when the curve is troublesome or flat;
            the entry of ``lrs`` closest to it (absolute difference) is used.

    Returns:
        A flat dict (consumed as a Polars struct) containing:
        list-valued curves ``mloss``, ``dloss``/``mean_dloss``,
        ``smoothed_dloss`` and ``kalman``; their cumulative sums offset by
        the mean first loss (``cum_dloss``, ``cum_smoothed_dloss``,
        ``ckalman``); the suggested LR index/value from the Kalman and
        mean-loss choosers; and ``is_flat`` stored as an int.

        If ``to_loss_arr`` returns None (all losses infinite), every curve is
        empty and ``lrs[0]`` is suggested.
    """
    loss_arr = kdai.lrfind.to_loss_arr(losses)
    lrs = np.array(lrs)
    # It's possible that all losses are infinite.
    if loss_arr is None:
        return {
            "mloss": [],
            "dloss": [],
            "smoothed_dloss": [],
            "cum_smoothed_dloss": [],
            "mean_dloss": [],
            "cum_dloss": [],
            "kalman": [],
            "ckalman": [],
            "suggested_lr_idx_kalman": 0,
            "suggested_lr_kalman": lrs[0],
            "suggested_lr_idx_mloss": 0,
            "suggested_lr_mloss": lrs[0],
            "is_flat": int(False),  # Polars complains otherwise.
        }

    # Choose the index closest to the default LR; shared by both fallbacks.
    default_lr_idx = np.argmin(np.abs(lrs - default_lr))
    default_choice = (default_lr_idx, lrs[default_lr_idx])

    try:
        (
            kalman_dloss,
            suggested_lr_idx_kalman,
            suggested_lr_kalman,
        ) = kdai.lrfind.kalman_lr_choose(lrs, loss_arr)
        (
            mean_dloss,
            smoothed_dloss,
            suggested_lr_idx_mloss,
            suggested_lr_mloss,
        ) = kdai.lrfind.mean_lr_choose(lrs, loss_arr)
    except kdai.lrfind.TroublesomeLrCurve as e:
        # If either chooser raises, both suggestions fall back to the default.
        _logger.error("Troublesome LR curve. Using default. e=%r", e)
        suggested_lr_idx_kalman, suggested_lr_kalman = default_choice
        suggested_lr_idx_mloss, suggested_lr_mloss = default_choice
        # Zero derivatives make the cumulative curves below constant at m_start.
        mean_dloss = np.zeros_like(loss_arr[0])
        smoothed_dloss = np.zeros_like(loss_arr[0])
        kalman_dloss = np.zeros_like(loss_arr[0])

    is_flat = kdai.lrfind.is_flat(loss_arr)
    if is_flat:
        _logger.info(
            "Flat loss array. Choosing default LR (default_lr_idx=%d, %.3e)",
            default_lr_idx,
            lrs[default_lr_idx],
        )
        suggested_lr_idx_kalman, suggested_lr_kalman = default_choice
        suggested_lr_idx_mloss, suggested_lr_mloss = default_choice

    # Offset for the cumulative curves: mean of each row's first loss.
    m_start = np.mean([loss[0] for loss in loss_arr])

    return {
        "mloss": loss_arr.mean(axis=0).tolist(),
        # "dloss" and "mean_dloss" intentionally carry the same values.
        "dloss": mean_dloss.tolist(),
        "smoothed_dloss": smoothed_dloss.tolist(),
        "cum_smoothed_dloss": (np.cumsum(smoothed_dloss) + m_start).tolist(),
        "mean_dloss": mean_dloss.tolist(),
        "cum_dloss": (np.cumsum(mean_dloss) + m_start).tolist(),
        "kalman": kalman_dloss.tolist(),
        "ckalman": (np.cumsum(kalman_dloss) + m_start).tolist(),
        "suggested_lr_idx_kalman": suggested_lr_idx_kalman,
        "suggested_lr_kalman": suggested_lr_kalman,
        "suggested_lr_idx_mloss": suggested_lr_idx_mloss,
        "suggested_lr_mloss": suggested_lr_mloss,
        "is_flat": int(is_flat),  # Polars complains otherwise.
    }


def linear_fit(ys):
    """Return the least-squares straight line through ``ys``, evaluated at each index.

    ``ys`` is treated as a function of its position ``0, 1, ..., len(ys) - 1``.
    With fewer than two points the fit is underdetermined: ``np.polyfit``
    raises on an empty input and warns on a single point. Such inputs are
    therefore returned unchanged, as a float array.

    Args:
        ys: 1-D sequence of numbers (list, NumPy array, Polars Series, ...).

    Returns:
        np.ndarray of floats with the same length as ``ys``.
    """
    ys = np.asarray(ys, dtype=float)
    n = len(ys)
    if n < 2:
        # Copy so callers never receive an alias of their own input array.
        return ys.copy()
    xs = np.arange(n)
    coefs = np.polyfit(xs, ys, 1)
    return np.polyval(coefs, xs)


def as_linear_fit(df, src_col):
    """Add ``{src_col}_linear_fit``: a per-(model, ds) straight-line fit of ``src_col``.

    Within each (model, ds) group, the values of ``src_col`` are ordered by
    ``batch_size`` and fitted against their position (0, 1, 2, ...) with
    ``linear_fit``. The x-axis is therefore the batch-size rank, not the
    batch size itself. The fitted values are left-joined back onto ``df`` on
    (model, ds, batch_size).

    Args:
        df: Frame with ``model``, ``ds``, ``batch_size`` and ``src_col`` columns.
        src_col: Name of the numeric column to fit.

    Returns:
        ``df`` with one extra column, ``{src_col}_linear_fit``.
    """
    col_name = f"{src_col}_linear_fit"
    fits = (
        df.group_by("model", "ds")
        .agg(
            # Sort inside each group so the positional x-axis used by
            # linear_fit follows batch_size rather than input row order.
            pl.col("batch_size").sort(),
            pl.col(src_col).sort_by("batch_size").alias(col_name),
        )
        .with_columns(
            pl.col(col_name).map_elements(
                linear_fit, return_dtype=pl.List(pl.Float32)
            )
        )
        .explode("batch_size", col_name)
    )
    # Only the join keys and the new column are joined back, so src_col is
    # not duplicated into a spurious f"{src_col}_right" column.
    return df.join(fits, on=["model", "ds", "batch_size"], how="left")


def add_median(df, col_names):
    """Attach per-(model, ds) medians of ``col_names`` as ``{col}_median`` columns.

    The medians are computed with a group-by and left-joined back on
    (model, ds), so every row carries its group's median. The result is
    sorted by model, ds and batch_size.
    """
    group_keys = ["model", "ds"]
    median_exprs = [pl.col(name).median().alias(f"{name}_median") for name in col_names]
    medians = df.group_by(*group_keys).agg(*median_exprs)
    with_medians = df.join(medians, on=group_keys, how="left")
    return with_medians.sort("model", "ds", "batch_size")


def add_stats(df):
    """Add linear-fit and per-group median columns for both LR suggestions.

    For each of ``suggested_lr_mloss`` and ``suggested_lr_kalman`` this adds
    ``*_linear_fit`` (via ``as_linear_fit``) and then ``*_median`` (via
    ``add_median``).
    """
    lr_cols = ["suggested_lr_mloss", "suggested_lr_kalman"]
    for lr_col in lr_cols:
        df = as_linear_fit(df, lr_col)
    return add_median(df, lr_cols)


def to_lrmap(df, *, col="suggested_lr_kalman_median"):
    """Reduce ``df`` to a (model, ds, batch_size, lr) learning-rate map.

    Args:
        df: Frame containing ``model``, ``ds``, ``batch_size`` and ``col``.
        col: Column exposed as ``lr``. Defaults to the per-(model, ds) median
            of the Kalman suggestion. ``"suggested_lr_kalman_linear_fit"``
            (also produced by ``add_stats``) is the alternative choice.

    Returns:
        A frame with exactly the columns ``model``, ``ds``, ``batch_size``
        and ``lr``.
    """
    return df.select(
        pl.col("model"),
        pl.col("ds"),
        pl.col("batch_size"),
        pl.col(col).alias("lr"),
    )


def main():
    """Post-process the lrsweep.py output for every listed model.

    For each model, this:
    - reads ``{model}_lrsweep_{src_ver_str}.parquet`` from ``src_dir``;
    - runs ``process`` on each row's ``lrs``/``loss`` and unpacks the
      resulting struct into columns;
    - adds the linear-fit and median statistics;
    - writes a per-model parquet into the directory returned by
      ``exp.start_logging``.

    It then writes the concatenation of all models, plus the LR map as CSV,
    to that same directory.
    """
    ver_parts = kdai._logging.version_labels_from_script_dir()
    out_dir = exp.start_logging(ver_parts)
    ver_str = "_".join(ver_parts)

    # Edit to point to the output of lrsweep.py.
    src_ver = "0.0.0"
    src_dir = Path("./out/exp") / src_ver.replace(".", "/")
    src_ver_str = src_ver.replace(".", "_")

    models = (
        "gpt-6-4-32-f",
        "gpt-2-4-16-f",
        "gpt-6-4-32-logmix",
        "gpt-2-4-16-logmix",
        "gpt-2-4-16-const",
        "gpt-2-4-16-exp",
        "gpt-2-4-16-nn",
    )

    def row_stats(row):
        # Adapter: Polars hands each struct row over as a dict.
        return process(row["lrs"], row["loss"])

    per_model = []
    for model in models:
        raw = pl.read_parquet(src_dir / f"{model}_lrsweep_{src_ver_str}.parquet")
        stats_expr = (
            pl.struct(pl.col("lrs"), pl.col("loss"))
            .map_elements(row_stats, return_dtype=pl.Struct)
            .alias("temp")
        )
        unpacked = (
            raw.with_columns(stats_expr)
            .with_columns(pl.col("temp").struct.field("*"))
            .drop("temp")
        )
        result = add_stats(unpacked)
        result.write_parquet(out_dir / f"{model}_calc_lrfind_{ver_str}.parquet")
        per_model.append(result)

    combined = pl.concat(per_model, how="vertical")
    combined.write_parquet(out_dir / f"all_lrcalc_{ver_str}.parquet")
    to_lrmap(combined).write_csv(out_dir / f"lrmap_{ver_str}.csv")


if __name__ == "__main__":
    # Root logger at WARNING when run as a script. Note: exp.start_logging
    # (called in main) may configure logging further -- not visible here.
    logging.basicConfig(level="WARNING")
    main()
