#!/usr/bin/env python3
"""Evaluate intention prediction model being used as a unit controller"""
import enum
import inspect
import multiprocessing as mp
from itertools import product
from pathlib import Path
from typing import Annotated, Any

import typer
from absl import logging
from konductor.metadata.database.metadata import Metadata
from konductor.metadata.database.sqlite import DEFAULT_FILENAME
from s2clientprotocol import sc2api_pb2 as sc_pb
from sqlalchemy import create_engine
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    MappedAsDataclass,
    mapped_column,
    sessionmaker,
)

import src.sc2_env.custom_env as custom
import src.sc2_env.smacv2_env as smacv2
from src.sc2_env.simulator import HotStart, SC2GameCfg
from src.utils.eval_common import EnvResult, ExperimentInitConfig, load_model

# Only absl log messages at WARNING level or above are emitted.
logging.set_verbosity(logging.WARNING)

# Typer application; every function decorated with @app.command() below is a CLI subcommand.
app = typer.Typer()


class Base(MappedAsDataclass, DeclarativeBase):
    """Declarative base for the ORM models in this file.

    Mixing in MappedAsDataclass turns every subclass into a dataclass, so mapped
    columns become ``__init__`` parameters unless declared with ``init=False``.
    """


class EnvEntry(Base):
    """One row of evaluation results for a single run and scenario.

    Rows are written by update_results_database and copied between database
    files by merge_databases.
    """

    __tablename__ = "custom_env"

    # Autoincrement surrogate key; init=False keeps it out of the generated __init__.
    uid: Mapped[int] = mapped_column(init=False, primary_key=True, autoincrement=True)
    # Run identifier: the run path's stem, or "baseline" for the no-action baseline.
    hash: Mapped[str]
    # Training progress read from the run's metadata.yaml (0 for the baseline).
    epoch: Mapped[int]
    iteration: Mapped[int]
    # Scenario parameters the results were collected under.
    n_units: Mapped[int]
    n_enemies: Mapped[int]
    pos_dist: Mapped[str]
    race: Mapped[str]
    hotstart: Mapped[str]
    # Game outcome counts for this scenario.
    wins: Mapped[int]
    losses: Mapped[int]
    draws: Mapped[int]


class Race(str, enum.Enum):
    """StarCraft race; selects the unit types used by get_unit_distributions.

    The ``str`` mixin lets members be used wherever a plain string is expected,
    e.g. when written to the ``race`` string column of the results database.
    """

    terran = "terran"
    zerg = "zerg"
    protoss = "protoss"


class PosDist(str, enum.Enum):
    """Start-position distributions, passed as smacv2 ``start_positions.dist_type``.

    get_unit_distributions adds extra parameters for ``multi`` and
    ``surrounded_and_reflect``; the others are passed through as-is.
    """

    # Configured with dists surrounded/reflect_position/grouped and p=0.33/0.33/0.34
    # (presumably a random choice per scenario — exact semantics are smacv2's).
    multi = "multi"
    # Configured with p=0.5; meaning of p is defined by smacv2 — confirm there.
    surrounded_and_reflect = "surrounded_and_reflect"
    grouped = "grouped"
    surrounded = "surrounded"
    reflect_position = "reflect_position"


def get_unit_distributions(race: Race, n_units: int, n_enemies: int, pos_dist: str):
    """Build the smacv2 scenario-randomization config for one evaluation setting.

    Args:
        race: Race whose three unit types make up the ``weighted_teams`` pool.
        n_units: Number of allied units.
        n_enemies: Number of enemy units.
        pos_dist: smacv2 start-position ``dist_type``; ``"multi"`` and
            ``"surrounded_and_reflect"`` get extra mixing parameters.

    Returns:
        Dict with ``n_units``, ``n_enemies``, ``team_gen`` and ``start_positions``.

    Raises:
        KeyError: if ``race`` is not a known Race.
    """
    # race -> (unit type pool, unit types listed as exceptions for team generation)
    per_race = {
        Race.terran: (["marine", "marauder", "medivac"], ["medivac"]),
        Race.zerg: (["zergling", "baneling", "hydralisk"], ["baneling"]),
        Race.protoss: (["stalker", "zealot", "colossus"], []),
    }
    pool, exceptions = per_race[race]

    start_positions: dict[str, Any] = {"dist_type": pos_dist}
    if pos_dist == "multi":
        start_positions.update(
            dists=["surrounded", "reflect_position", "grouped"],
            p=[0.33, 0.33, 0.34],
        )
    elif pos_dist == "surrounded_and_reflect":
        start_positions["p"] = 0.5

    team_gen = {
        "dist_type": "weighted_teams",
        "unit_types": pool,
        "exception_unit_types": exceptions,
        "weights": [0.45, 0.45, 0.1],
        "observe": True,
    }

    return {
        "n_units": n_units,
        "n_enemies": n_enemies,
        "team_gen": team_gen,
        "start_positions": start_positions,
    }


@app.command()
def merge_databases(
    dst: Annotated[Path, typer.Option()],
    src: Annotated[Path, typer.Option()],
    overwrite: Annotated[
        bool, typer.Option(help="Overwrite duplicates with new data")
    ] = False,
):
    """Merge results from one database into another.

    Rows are matched on (hash, iteration, n_units, n_enemies, race, pos_dist,
    hotstart). Rows in ``src`` with no match in ``dst`` are copied over (``dst``
    assigns a fresh ``uid``). Matched rows keep their ``dst`` counts unless
    ``overwrite`` is set, in which case wins/draws/losses are replaced by the
    ``src`` values.

    Raises:
        typer.BadParameter: if ``src`` is not an existing file.
    """
    # sqlite would otherwise silently create an empty file at a mistyped path.
    if not src.is_file():
        raise typer.BadParameter(f"source database {src} does not exist")

    src_engine = create_engine(f"sqlite:///{src}")
    dst_engine = create_engine(f"sqlite:///{dst}")
    # Ensure the destination table exists so merging into a new/empty file works.
    Base.metadata.create_all(dst_engine)

    # EnvEntry.__init__ parameters: every column except the autoincrement ``uid``
    # (declared init=False), which the destination assigns itself.
    copy_fields = inspect.signature(EnvEntry).parameters

    make_src_session = sessionmaker(bind=src_engine)
    make_dst_session = sessionmaker(bind=dst_engine)
    # Context managers close both sessions even if the merge fails part way.
    with make_src_session() as src_session, make_dst_session() as dst_session:
        for entry in src_session.query(EnvEntry):
            tgt = (
                dst_session.query(EnvEntry)
                .filter_by(
                    hash=entry.hash,
                    iteration=entry.iteration,
                    n_units=entry.n_units,
                    n_enemies=entry.n_enemies,
                    race=entry.race,
                    pos_dist=entry.pos_dist,
                    hotstart=entry.hotstart,
                )
                .first()
            )
            if tgt is None:
                dst_session.add(
                    EnvEntry(**{name: getattr(entry, name) for name in copy_fields})
                )
            elif overwrite:
                tgt.wins = entry.wins
                tgt.draws = entry.draws
                tgt.losses = entry.losses

        dst_session.commit()

    src_engine.dispose()
    dst_engine.dispose()


def update_results_database(
    run_path: Path,
    race: Race,
    n_units: int,
    n_enemies: int,
    pos_dist: str,
    hotstart: HotStart,
    results: EnvResult,
    baseline=False,
):
    """Insert or update the win/loss/draw counts for one scenario.

    The SQLite database is ``run_path.parent / DEFAULT_FILENAME``; the
    ``custom_env`` table is created if missing. Rows are keyed by (hash,
    iteration, n_units, n_enemies, race, pos_dist, hotstart): an existing row
    has its wins/draws/losses replaced, otherwise a new row is added.

    If ``baseline`` is true the row is stored under hash ``"baseline"`` with
    epoch and iteration 0; otherwise the hash is ``run_path.stem`` and
    epoch/iteration come from ``run_path / "metadata.yaml"``.
    """
    run_meta = Metadata.from_yaml(run_path / "metadata.yaml")

    engine = create_engine(f"sqlite:///{run_path.parent / DEFAULT_FILENAME}")
    Base.metadata.create_all(engine)

    run_hash = "baseline" if baseline else run_path.stem
    epoch = 0 if baseline else run_meta.epoch
    iteration = 0 if baseline else run_meta.iteration
    # Rows store the enum member *name*; the lookup must use the same string,
    # otherwise an existing row is missed and a duplicate is inserted.
    hotstart_name = hotstart.name

    with sessionmaker(bind=engine)() as session:
        entry = (
            session.query(EnvEntry)
            .filter_by(
                hash=run_hash,
                iteration=iteration,
                n_units=n_units,
                n_enemies=n_enemies,
                race=race,
                pos_dist=pos_dist,
                hotstart=hotstart_name,
            )
            .first()
        )
        if entry:
            entry.wins = results.wins
            entry.draws = results.draws
            entry.losses = results.losses
        else:
            session.add(
                EnvEntry(
                    hash=run_hash,
                    epoch=epoch,
                    iteration=iteration,
                    n_units=n_units,
                    n_enemies=n_enemies,
                    pos_dist=pos_dist,
                    race=race,
                    hotstart=hotstart_name,
                    wins=results.wins,
                    losses=results.losses,
                    draws=results.draws,
                )
            )
        session.commit()

    engine.dispose()


def run_trial(
    run_path: Path,
    game_cfg: SC2GameCfg,
    n_samples: int,
    baseline: bool,
    smac: bool,
    visualize: bool,
):
    """Evaluate one run for ``n_samples`` samples in the current process.

    The experiment config is always loaded from ``run_path``; the model is
    loaded only when ``baseline`` is false (the no-action baseline evaluates
    with ``model=None``). ``smac`` selects the smacv2 evaluator, otherwise the
    custom-environment evaluator is used, which also receives the config.

    Returns:
        The evaluator's result (presumably an EnvResult — callers read
        wins/losses/draws from it).
    """
    exp_cfg = ExperimentInitConfig.from_run(run_path)
    model = load_model(exp_cfg) if not baseline else None

    if not smac:
        return custom.run_evaluation(n_samples, game_cfg, model, exp_cfg, visualize)
    return smacv2.run_evaluation(n_samples, game_cfg, model, visualize)


def run_parallel(n_parallel: int, kwargs: dict):
    """Run ``n_parallel`` copies of run_trial in worker processes and sum the results.

    Every worker receives the same ``kwargs`` with visualization forced off, so
    the caller is responsible for splitting ``n_samples`` between workers.
    The caller's ``kwargs`` dict is left unmodified.

    Returns:
        An EnvResult in which each slot is the sum of that slot over all workers.
    """
    # Copy instead of mutating the caller's dict.
    worker_kwargs = {**kwargs, "visualize": False}

    pool = mp.get_context("forkserver").Pool(processes=n_parallel)
    try:
        futures = [
            pool.apply_async(run_trial, kwds=worker_kwargs) for _ in range(n_parallel)
        ]
        # .get() re-raises any exception raised inside a worker.
        results = [future.get() for future in futures]
    except BaseException:
        # Don't leave worker processes running if a trial fails or we are interrupted.
        pool.terminate()
        raise
    else:
        # Normal path: let the workers exit cleanly.
        pool.close()
    finally:
        pool.join()

    total = EnvResult()
    for attr in total.__slots__:
        setattr(total, attr, sum(getattr(r, attr) for r in results))
    return total


@app.command()
def run_sample(
    run_path: Annotated[Path, typer.Option()],
    n_samples: Annotated[int, typer.Option()] = 1024,
    smac: Annotated[bool, typer.Option()] = False,
    baseline: Annotated[bool, typer.Option(help="no-action baseline")] = False,
    race: Annotated[Race, typer.Option()] = Race.terran,
    n_units: Annotated[int, typer.Option()] = 5,
    n_enemies: Annotated[int, typer.Option()] = 5,
    n_parallel: Annotated[int, typer.Option()] = 1,
    pos_dist: Annotated[PosDist, typer.Option()] = PosDist.multi,
    hotstart: Annotated[HotStart, typer.Option()] = HotStart.none,
    visualize: Annotated[bool, typer.Option()] = False,
):
    """Play StarCraft 2 with SC2IntentPredictor model.

    Runs ``n_samples`` evaluation samples (split evenly across ``n_parallel``
    worker processes when ``n_parallel > 1``), prints win/loss/draw rates and,
    for the custom environment (``smac=False``), records the counts in the
    results database in ``run_path``'s parent directory.
    """
    if visualize and n_parallel != 1:
        raise typer.BadParameter("Can't run more than 1 instance if visualizing")

    if n_parallel > 1:
        # Each worker gets an equal share; the remainder (n_samples % n_parallel)
        # is dropped, so rates must be computed over the samples actually run.
        samples_per_worker = n_samples // n_parallel
        n_run = samples_per_worker * n_parallel
    else:
        samples_per_worker = n_samples
        n_run = n_samples

    if n_run <= 0:
        raise typer.BadParameter(
            f"n_samples ({n_samples}) must be at least {max(n_parallel, 1)}"
        )

    game_cfg = SC2GameCfg(
        pos_dist=get_unit_distributions(race, n_units, n_enemies, pos_dist),
        hotstart=hotstart,
        difficulty=sc_pb.VeryHard,
    )

    trial_kwargs = {
        "run_path": run_path,
        "game_cfg": game_cfg,
        "n_samples": samples_per_worker,
        "baseline": baseline,
        "smac": smac,
    }

    if n_parallel > 1:
        results = run_parallel(n_parallel, trial_kwargs)
    else:
        results = run_trial(**trial_kwargs, visualize=visualize)

    print(
        f"Win Rate: {100 * results.wins / n_run:.1f}%, "
        f"Loss Rate: {100 * results.losses / n_run:.1f}%, "
        f"Draw Rate: {100 * results.draws / n_run:.1f}%"
    )

    if not smac:
        update_results_database(
            run_path, race, n_units, n_enemies, pos_dist, hotstart, results, baseline
        )


@app.command()
def run_all(
    run_path: Annotated[Path, typer.Option()],
    n_samples: Annotated[int, typer.Option()] = 1024,
    baseline: Annotated[bool, typer.Option(help="no-action baseline")] = False,
    n_parallel: Annotated[int, typer.Option()] = 1,
):
    """Run all the common permutations of race, n_units, n_enemies, pos_dist and hotstart"""
    scenarios = product(
        (Race.protoss, Race.zerg, Race.terran),
        (5, 6),
        (5,),
        ("surrounded", "reflect_position", "grouped", "multi"),
        (HotStart.none, HotStart.closest),
    )
    for race, n_units, n_enemies, pos_dist, hotstart in scenarios:
        # Keyword arguments so a reordering of run_sample's options can't
        # silently shift values into the wrong parameter.
        run_sample(
            run_path=run_path,
            n_samples=n_samples,
            smac=False,
            baseline=baseline,
            race=race,
            n_units=n_units,
            n_enemies=n_enemies,
            n_parallel=n_parallel,
            pos_dist=pos_dist,
            hotstart=hotstart,
            visualize=False,
        )


if __name__ == "__main__":
    app()
