import json
from pathlib import Path
import logging
from typing import Any
import sys

import numpy as np
import joblib  # type: ignore
from tqdm import tqdm  # type: ignore
import pandas as pd

from .. import util
from .metric import metric_registry

# Configure the root logger when this module is imported, requesting the INFO
# level; records logged with logging.debug in this module are below that
# threshold.
logging.basicConfig(level=logging.INFO)


def load_data(path: Path) -> list[np.ndarray]:
    """Read a JSON-lines file into a list of NumPy arrays.

    Each non-blank line is parsed as one JSON document and converted with
    ``np.array``. Blank or whitespace-only lines, and rows that parse to an
    empty container, are skipped.

    Args:
        path: Location of the JSON-lines file.

    Returns:
        One array per non-empty row, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        OSError: If the file cannot be opened.
    """
    data = []
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with path.open(encoding="utf-8") as fo:
        for line in fo:
            # A stray blank line (e.g. an extra trailing newline) is not a
            # JSON document and would make json.loads raise.
            if not line.strip():
                continue
            row = json.loads(line)
            if len(row) == 0:
                continue
            data.append(np.array(row))
    return data


def analyze_data(input_path: Path) -> dict[str, Any] | None:
    """Compute every registered metric for one corpus file.

    The corpus is loaded with ``load_data``, each callable in
    ``metric_registry`` is applied to it, and the returned mappings are
    merged into one dict (later entries win on duplicate keys). The merged
    dict is also stored under ``["metrics"]["analysis"]`` in the
    ``metadata.json`` file that sits next to the corpus.

    Failures are best-effort: any ``Exception`` is logged as a warning and
    the function still returns. Exceptions that are not ``Exception``
    subclasses (e.g. ``KeyboardInterrupt``) propagate to the caller.

    Args:
        input_path: Path to the corpus file to analyze.

    Returns:
        The merged metrics, or ``None`` if loading the data or computing the
        metrics failed. If only the metadata update failed, the computed
        metrics are still returned.
    """
    # TODO Account for EOS
    name = path_to_name(input_path)
    logging.debug(f"Analyzing: {name}")
    metrics = None
    try:
        data = load_data(input_path)
        # Bound in a single expression so that ``metrics`` stays None if any
        # metric function raises part-way through.
        metrics = {k: v for f in metric_registry for k, v in f(data).items()}
        metadata_path = input_path.parent / "metadata.json"
        # NOTE(review): util.update_json is presumably a context manager that
        # loads the JSON file and writes it back on exit -- confirm in util.
        with util.update_json(metadata_path) as md:
            md["metrics"] = md.get("metrics", {})
            md["metrics"]["analysis"] = metrics
    except Exception as e:
        # Keep the traceback in the log; the failure itself is not re-raised
        # so one bad corpus does not abort the whole run.
        logging.warning(f"{name} failed due to {e}", exc_info=True)
    finally:
        logging.debug(f"Finished: {name}")
    # The return lives outside ``finally`` on purpose: returning from a
    # ``finally`` clause discards any in-flight exception, including
    # KeyboardInterrupt and SystemExit.
    return metrics


def path_to_name(path: Path) -> str:
    comps = list(path.parents[-3:-2]) + list(path.parents[-5::-1])
    return "/".join(x.name for x in comps)


def main() -> None:
    """Analyze every corpus under ``./systems`` and write ``table.csv``.

    Corpus files matching ``*/data/**/corpus.jsonl`` are analyzed in
    parallel with joblib. Each successful result becomes one row, labelled
    with the name derived from its path by ``path_to_name``; the rows are
    written to ``table.csv`` indexed by that name. Analyses that returned
    ``None`` are left out of the table.
    """
    paths = list(Path("./systems").glob("*/data/**/corpus.jsonl"))
    funcs = [joblib.delayed(analyze_data)(path) for path in paths]
    parallel = joblib.Parallel(n_jobs=-1, return_as="generator")(funcs)
    results = list(tqdm(parallel, total=len(funcs)))

    # Pair each result with its path BEFORE dropping failures. Filtering
    # first would shift the remaining results against ``paths`` and attach
    # the wrong name to every row after the first failed analysis.
    rows = []
    for p, r in zip(paths, results):
        if r is None:
            continue
        r["name"] = path_to_name(p)
        rows.append(r)

    if not rows:
        # An empty frame has no "name" column, so set_index would raise.
        logging.warning("No analysis results; table.csv was not written")
        return

    df = pd.DataFrame(rows)
    df.set_index("name", inplace=True)
    df.to_csv("table.csv")


def generate_plots(df: pd.DataFrame) -> None:
    """Write LaTeX tables describing *df* into the current directory.

    Despite the name, no plots are drawn. Two kinds of files are written:

    * ``summary.tex``: the ``describe()`` statistics of *df* without the
      count, mean and std rows, one row per column of *df*, every value
      wrapped in ``$...$``.
    * ``big-table-1.tex`` .. ``big-table-4.tex``: the rows of *df* sorted by
      ``name``, one file per fixed group of columns.

    Args:
        df: Metrics table. ``name`` must be usable as a sort key, the index
            labels must be strings (underscores in them are escaped for
            LaTeX), and the columns listed in the groups below must exist.
    """

    def _as_int(value):
        return f"${int(value)}$"

    def _as_float(value):
        return f"${value:.2f}$"

    # Metrics rendered as integers; everything else gets two decimals.
    integer_metrics = ("Unique Tokens", "Unique Lines", "Token Count", "Line Count")
    percentile_labels = {
        "25%": "$25\\%$",
        "50%": "$50\\%$",
        "75%": "$75\\%$",
    }

    # Transposed so each metric is a row and each remaining statistic a column.
    stats = df.describe().drop(["count", "mean", "std"]).T

    formatted = pd.DataFrame(columns=stats.columns)
    for metric in stats.index:
        render = _as_int if metric in integer_metrics else _as_float
        formatted.loc[metric] = stats.loc[metric].apply(render)
    formatted = formatted.rename(columns=percentile_labels)
    formatted.to_latex("summary.tex", column_format="lrrrrr")

    ordered = df.sort_values(by="name")
    # Escape underscores so the row labels are valid LaTeX text.
    ordered = ordered.rename(index=lambda label: label.replace("_", r"\_"))

    column_groups = [
        ["Token Count", "Line Count", "Tokens per Line", "Tokens per Line SD"],
        ["Unique Tokens", "Unique Lines", "EoS Token Present", "EoS Padding"],
        ["1-gram Entropy", "1-gram Normalized Entropy", "Entropy per Line"],
        ["2-gram Entropy", "2-gram Conditional Entropy"],
    ]
    for number, columns in enumerate(column_groups, start=1):
        ordered.to_latex(f"big-table-{number}.tex", columns=columns)
