from pathlib import Path
from typing import Callable

from konductor.metadata.database.sqlite import DEFAULT_FILENAME
import pandas as pd
from sqlalchemy import create_engine, Engine
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    MappedAsDataclass,
    mapped_column,
    sessionmaker,
)
import typer


class Base(MappedAsDataclass, DeclarativeBase):
    """Declarative base shared by the summary tables in this module.

    Because ``MappedAsDataclass`` is mixed in, subclasses will be converted
    to dataclasses, so their mapped attributes double as constructor
    arguments.
    """


class GymGoalSummary(Base):
    """One row of the ``gym_goal_summary`` table, written by ``summarize_goal``."""

    __tablename__ = "gym_goal_summary"

    # Primary key; ``run`` passes the run directory name as the hash.
    hash: Mapped[str] = mapped_column(primary_key=True)
    # Largest ``iteration`` value found in the log when the row was created.
    iteration: Mapped[int]
    # Mean over the per-column means of the log rows at that iteration
    # (``iteration`` and ``timestamp`` columns excluded).
    top1: Mapped[float]


class GymUncovSummary(Base):
    """One row of the ``gym_uncover_summary`` table, written by ``summarize_occupancy``.

    Each metric defaults to -1, which ``summarize_occupancy`` also stores when
    no matching column exists or the averaged value falls outside [0, 1].
    """

    __tablename__ = "gym_uncover_summary"

    # Primary key; ``run`` passes the run directory name as the hash.
    hash: Mapped[str] = mapped_column(primary_key=True)
    # Largest ``iteration`` value found in the log when the row was created.
    iteration: Mapped[int]
    # Mean of the log columns whose name starts with "targets_AUC".
    targets_auc: Mapped[float] = mapped_column(default=-1)
    # Mean of the log columns whose name starts with "targets_IoU".
    targets_iou: Mapped[float] = mapped_column(default=-1)
    # Mean of the log columns whose name starts with "agents_AUC".
    agents_auc: Mapped[float] = mapped_column(default=-1)
    # Mean of the log columns whose name starts with "agents_IoU".
    agents_iou: Mapped[float] = mapped_column(default=-1)


class SC2Summary(Base):
    """One row of the ``sc2_summary`` table, written by ``summarize_sc2``.

    ``summarize_sc2`` stores -1 for a metric when no matching log column
    exists or the averaged value falls outside [0, 1].

    NOTE: ``summarize_sc2`` constructs this class positionally, so the field
    order below is part of its contract.
    """

    __tablename__ = "sc2_summary"

    # Primary key; ``run`` passes the run directory name as the hash.
    hash: Mapped[str] = mapped_column(primary_key=True)
    # Largest ``iteration`` value found in the log when the row was created.
    iteration: Mapped[int]
    # Means of columns starting with "top1"/"top5" that do NOT end in "null".
    top1: Mapped[float]
    top5: Mapped[float]
    # Means of columns starting with "top1"/"top5" that DO end in "null".
    top1_null: Mapped[float]
    top5_null: Mapped[float]
    # Mean of columns starting with "l2".
    l2: Mapped[float]
    # Means of columns starting with "pos-score-f1", "pos-score-recall",
    # "pos-score-precision" and "pos-score-acc" respectively.
    pos_f1: Mapped[float]
    pos_recall: Mapped[float]
    pos_precision: Mapped[float]
    pos_acc: Mapped[float]


# CLI application object; commands are registered on it via ``@app.command()``.
app = typer.Typer()


def summarize_goal(hash: str, data: pd.DataFrame, engine: Engine):
    """Summarize goal statistics into the ``gym_goal_summary`` table.

    The rows of ``data`` belonging to the largest ``iteration`` are averaged
    per column; the ``iteration`` and ``timestamp`` columns are dropped and the
    remaining column means are averaged again into a single ``top1`` value.

    Args:
        hash: Primary key of the summary row (``run`` passes the run
            directory name).
        data: Validation log; must contain ``iteration`` and ``timestamp``
            columns (a ``KeyError`` is raised otherwise).
        engine: SQLAlchemy engine bound to the summary database.

    A new row is inserted when none exists for ``hash``. An existing row is
    refreshed only when ``data`` holds a newer iteration than the stored one,
    in which case both the metric and the stored iteration are updated.
    """
    # numpy integer scalars are not ``int`` subclasses; hand the ORM a builtin int.
    iteration = int(data.iteration.max())
    mean_data: pd.Series = data[data.iteration == iteration].mean()
    del mean_data["iteration"]
    del mean_data["timestamp"]
    mean_top1: float = mean_data.mean()

    with sessionmaker(bind=engine)() as session:
        entry = session.query(GymGoalSummary).filter_by(hash=hash).first()
        if entry is None:
            session.add(GymGoalSummary(hash=hash, iteration=iteration, top1=mean_top1))
        elif entry.iteration < iteration:
            # Record the iteration the metric now belongs to; otherwise the row
            # keeps a stale iteration and is rewritten on every later call.
            entry.iteration = iteration
            entry.top1 = mean_top1
        session.commit()


def summarize_occupancy(hash: str, data: pd.DataFrame, engine: Engine):
    """Summarize occupancy statistics into the ``gym_uncover_summary`` table.

    The rows of ``data`` belonging to the largest ``iteration`` are averaged
    per column. Each stored metric is the mean of all column means whose name
    starts with the metric's prefix (``targets_AUC``, ``targets_IoU``,
    ``agents_AUC``, ``agents_IoU``).

    Args:
        hash: Primary key of the summary row (``run`` passes the run
            directory name).
        data: Validation log with an ``iteration`` column and string column
            names.
        engine: SQLAlchemy engine bound to the summary database.

    A new row is inserted when none exists for ``hash``. An existing row is
    refreshed only when ``data`` holds a newer iteration than the stored one,
    in which case the metrics and the stored iteration are all updated.
    """
    # numpy integer scalars are not ``int`` subclasses; hand the ORM a builtin int.
    iteration = int(data.iteration.max())
    mean_data: pd.Series = data[data.iteration == iteration].mean()

    def gather_data(match: str) -> float:
        """Return the mean of the columns starting with ``match``, or -1."""
        keys = [i for i in mean_data.index if i.startswith(match)]
        if not keys:
            return -1
        val = mean_data[keys].mean()
        # Values outside [0, 1] are stored as -1; NaN fails the comparison
        # and is therefore mapped to -1 as well.
        return val if 0 <= val <= 1 else -1

    metrics = {
        "targets_auc": gather_data("targets_AUC"),
        "targets_iou": gather_data("targets_IoU"),
        "agents_auc": gather_data("agents_AUC"),
        "agents_iou": gather_data("agents_IoU"),
    }

    with sessionmaker(bind=engine)() as session:
        entry = session.query(GymUncovSummary).filter_by(hash=hash).first()
        if entry is None:
            # Keyword construction does not depend on the model's field order.
            session.add(GymUncovSummary(hash=hash, iteration=iteration, **metrics))
        elif entry.iteration < iteration:
            # Record the iteration the metrics now belong to; otherwise the row
            # keeps a stale iteration and is rewritten on every later call.
            entry.iteration = iteration
            for name, value in metrics.items():
                setattr(entry, name, value)
        session.commit()


def summarize_sc2(hash: str, data: pd.DataFrame, engine: Engine):
    """Summarize sc2 statistics into the ``sc2_summary`` table.

    The rows of ``data`` belonging to the largest ``iteration`` are averaged
    per column. Each stored metric is the mean of all column means whose name
    starts with the metric's prefix; the ``*_null`` metrics use only columns
    ending in ``"null"`` and the others use only columns that do not.

    Args:
        hash: Primary key of the summary row (``run`` passes the run
            directory name).
        data: Validation log with an ``iteration`` column and string column
            names.
        engine: SQLAlchemy engine bound to the summary database.

    A new row is inserted when none exists for ``hash``. An existing row is
    refreshed only when ``data`` holds a newer iteration than the stored one,
    in which case the metrics and the stored iteration are all updated.
    """
    # numpy integer scalars are not ``int`` subclasses; hand the ORM a builtin int.
    iteration = int(data.iteration.max())
    mean_data: pd.Series = data[data.iteration == iteration].mean()

    def gather_data(match: str, null: bool = False) -> float:
        """Return the mean of the matching columns, or -1 when unavailable."""
        keys: list[str] = [
            i
            for i in mean_data.index
            if i.startswith(match) and null == i.endswith("null")
        ]
        if not keys:
            return -1
        val = mean_data[keys].mean()
        # Values outside [0, 1] are stored as -1; NaN fails the comparison
        # and is therefore mapped to -1 as well.
        # NOTE(review): this bound is also applied to "l2" -- confirm that l2
        # is expected to lie in [0, 1], otherwise it is always recorded as -1.
        return val if 0 <= val <= 1 else -1

    metrics = {
        "top1": gather_data("top1"),
        "top5": gather_data("top5"),
        "top1_null": gather_data("top1", null=True),
        "top5_null": gather_data("top5", null=True),
        "l2": gather_data("l2"),
        "pos_f1": gather_data("pos-score-f1"),
        "pos_recall": gather_data("pos-score-recall"),
        "pos_precision": gather_data("pos-score-precision"),
        "pos_acc": gather_data("pos-score-acc"),
    }

    with sessionmaker(bind=engine)() as session:
        entry = session.query(SC2Summary).filter_by(hash=hash).first()
        if entry is None:
            # Keyword construction does not depend on the model's field order.
            session.add(SC2Summary(hash=hash, iteration=iteration, **metrics))
        elif entry.iteration < iteration:
            # Record the iteration the metrics now belong to; otherwise the row
            # keeps a stale iteration and is rewritten on every later call.
            entry.iteration = iteration
            # Assigning from the dict keeps every column paired with its own
            # metric (the update path previously wrote top1_null into
            # top5_null).
            for name, value in metrics.items():
                setattr(entry, name, value)
        session.commit()


# Dispatch table from log type to its summarizer. The log type is the part of
# a ``val_<logtype>.parquet`` file stem after the first underscore; each
# summarizer is called as ``fn(hash, data, engine)``.
LOG_MAPPING: dict[str, Callable[[str, pd.DataFrame, Engine], None]] = {
    "goal": summarize_goal,
    "occupancy": summarize_occupancy,
    "sc2-accuracy": summarize_sc2,
}


@app.command()
def run(run_path: Path):
    """Add to summary table the mean accuracy over the duration of the sequence

    Every ``val_<logtype>.parquet`` file in ``run_path`` is handed to the
    summarizer registered for ``<logtype>`` in ``LOG_MAPPING``. Results are
    stored in the SQLite database ``DEFAULT_FILENAME`` located in the parent
    directory of ``run_path``, keyed by the run directory name.
    """
    engine = create_engine(f"sqlite:///{run_path.parent / DEFAULT_FILENAME}")
    Base.metadata.create_all(engine)

    for log in run_path.glob("val_*.parquet"):
        _, logtype = log.stem.split("_", maxsplit=1)

        # Resolve the handler before touching the file so an unknown log type
        # is reported as such and the parquet file is never read needlessly.
        summarize = LOG_MAPPING.get(logtype)
        if summarize is None:
            print(f"Failed to determine format for {logtype=}")
            continue

        data = pd.read_parquet(log)
        try:
            summarize(run_path.name, data, engine)
        except KeyError as err:
            # A KeyError here comes from the log's contents (e.g. an expected
            # column is absent), not from an unknown log type. Keep going with
            # the remaining logs, but report the real cause.
            print(f"Failed to summarize {log.name}: missing key {err}")


@app.command()
def workspace(workspace: Path):
    """Summarize every run directory found directly inside a workspace.

    Entries that are not directories, or that hold no ``val_*.parquet`` log,
    are skipped; every other entry is passed to ``run``.
    """
    for candidate in workspace.iterdir():
        if not candidate.is_dir():
            continue
        # Pull at most one match: only the existence of a val log matters.
        first_log = next(candidate.glob("val_*.parquet"), None)
        if first_log is not None:
            run(candidate)


# Dispatch to the typer CLI when the module is executed as a script.
if __name__ == "__main__":
    app()
