#!/usr/bin/env python

import numpy as np
import sys
from pathlib import Path
import subprocess
import os
import json
from typing import Literal
import logging

# Show INFO-level messages so the "already exists; skipping ..." notices emitted
# by the pipeline steps are visible.
logging.basicConfig(level=logging.INFO)


# The two game types handled by this script: "disc" (discrimination) and
# "recon" (reconstruction).
GameType = Literal["disc", "recon"]

# fmt: off
disc_hparams_header = (
  ["n_attrs", "n_vals", "n_distractors", "n_examples", "vocab_size", "msg_len"]
 )
disc_variant_arr = [
  [        4,        4,               3,        10000,           10,        10],
  [        6,        6,               3,        10000,           10,        10],
  [        6,        6,               9,        10000,           10,        10],
  [        8,        8,               3,        10000,           10,        10],
]
# fmt: on
disc_variants = [
    dict(list(zip(disc_hparams_header, row)) + [("seed", s)])
    for s in range(0,3)
    for row in disc_variant_arr
]

# fmt: off
recon_hparams_header = (
  ["n_attrs", "n_vals", "n_examples", "vocab_size", "msg_len"]
 )
recon_variant_arr = [
  [        4,        4,        10000,           10,        10],
  [        6,        6,        10000,           10,        10],
  [        8,        8,        10000,           10,        10],
]
# fmt: on
recon_variants = [dict(zip(recon_hparams_header, row)) for row in recon_variant_arr]


def generate_data(
    *,
    game_type: GameType,
    n_attrs: int,
    n_vals: int,
    n_distractors: int | None = None,
    n_examples: int,
    **kwargs,
) -> Path:
    rng = np.random.default_rng(0)

    n_items = n_vals**n_attrs
    out_dir = Path("input-data")
    out_dir.mkdir(exist_ok=True, parents=True)
    match game_type:
        case "disc":
            assert n_distractors is not None
            fn = f"{n_attrs}-attr_{n_vals}-val_{n_distractors}-dist.txt"
        case "recon":
            fn = f"{n_attrs}-attr_{n_vals}-val.txt"
    out_path = out_dir / fn
    if out_path.exists():
        logging.info(f"{out_path} already exists; skipping data generation.")
    else:
        # TODO Optimize and refactor
        with out_path.open("w") as fo:
            match game_type:
                case "disc":
                    assert n_distractors is not None

                    def write_example() -> None:
                        raw = rng.choice(n_items, n_distractors + 1, replace=False)
                        label = rng.integers(n_distractors + 1)
                        sample = np.empty((n_distractors + 1, n_attrs), dtype=np.int64)
                        for j in range(n_attrs):
                            sample[:, j] = raw // (n_vals**j) % n_vals
                        fo.write(
                            " . ".join(" ".join(str(x) for x in row) for row in sample)
                        )
                        fo.write(" . ")
                        fo.write(str(label))
                        fo.write("\n")

                case "recon":

                    def write_example() -> None:
                        raw = rng.choice(n_items)
                        sample = raw // (n_vals ** np.arange(n_attrs)) % n_vals
                        fo.write(" ".join(str(x) for x in sample))
                        fo.write("\n")

            for i in range(n_examples):
                write_example()
    return out_path


def run_environment(game_type: GameType, cfg: dict, data_path: Path) -> Path:
    out_dir = Path("output-data")
    out_dir.mkdir(parents=True, exist_ok=True)
    match game_type:
        case "disc":
            fn_parts = [
                cfg["n_attrs"] + "-attr",
                cfg["n_vals"] + "-val",
                cfg["n_distractors"] + "-dist",
                cfg["seed"] + "-seed",
                # cfg["vocab_size"] + "-vocab",
                # cfg["msg_len"] + "-len",
            ]
        case "recon":
            fn_parts = [
                cfg["n_attrs"] + "-attr",
                cfg["n_vals"] + "-val",
                cfg["vocab_size"] + "-vocab",
                cfg["msg_len"] + "-len",
            ]
    out_file = out_dir / ("_".join(fn_parts) + ".out")
    if out_file.exists():
        logging.info(f"{out_file} already exists; skipping running environment.")
    else:
        # fmt: off
        cmd = [
            "python",
            "-m", "egg.zoo.basic_games.play",
            "--mode", "gs",
            "--game_type", "discri" if game_type == "disc" else "recon",
            "--train_data", str(data_path),
            "--validation_data", str(data_path),
            "--n_attributes", cfg["n_attrs"],
            "--n_values", cfg["n_vals"],
            "--n_epochs", "100",
            "--batch_size", "1024",
            "--validation_batch_size", "1024",
            "--max_len", cfg["msg_len"],
            "--vocab_size", cfg["vocab_size"],
            "--random_seed", cfg.get("seed", 0),
            "--sender_hidden", "256",
            "--receiver_hidden", "512",
            "--sender_embedding", "32",
            "--receiver_embedding", "32",
            "--receiver_cell", "gru",
            "--sender_cell", "gru",
            "--lr", "0.001",
            "--print_validation_events",
        ]
        # fmt: on
        env = os.environ.copy()
        python_path = env.get("PYTHONPATH", "") + ":repo"
        env["PYTHONPATH"] = python_path.lstrip(":")
        with out_file.open("w") as fo:
            subprocess.run(cmd, stdout=fo, env=env)
    return out_file


def process_output(raw_out_path: Path) -> None:
    """Extract the MESSAGES block of a raw output file into a JSONL corpus.

    The line following the first ``MESSAGES`` line of ``raw_out_path`` is
    parsed as JSON, reduced with ``argmax`` over its last axis, and written
    one entry per line (compact JSON) to
    ``../data/<raw_out_path.stem>/corpus.jsonl`` relative to the current
    working directory.  If that file already exists nothing is done.

    Args:
        raw_out_path: Captured output containing a ``MESSAGES`` line followed
            by a JSON array (presumably per-symbol score vectors for each
            message; confirm against the producer's output format).

    Raises:
        ValueError: If no ``MESSAGES`` line is found, the line after it is
            empty, or it is not valid JSON (``json.JSONDecodeError``).
    """
    out_path = Path("../data") / raw_out_path.stem / "corpus.jsonl"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    if out_path.exists():
        logging.info(f"{out_path} already exists; skipping output processing.")
        return

    with raw_out_path.open() as raw_file:
        lines = iter(raw_file)
        for line in lines:
            if line == "MESSAGES\n":
                break
        else:
            raise ValueError("Could not find MESSAGES block in output file.")
        payload = next(lines, "")
    if not payload.strip():
        raise ValueError("MESSAGES block in output file is empty.")

    # Parse and reduce *before* creating the corpus file: a failure here must
    # not leave an empty corpus.jsonl that later runs would skip over.
    symbols = np.argmax(json.loads(payload), -1).tolist()

    with out_path.open("w") as corpus_file:
        corpus_file.writelines(
            json.dumps(message, separators=(",", ":")) + "\n" for message in symbols
        )


def main() -> None:
    cwd_parent = os.getcwd().split("/")[-2]
    logging.info(
        f'Using working directory parent "{cwd_parent}" to detect what config to run...'
    )
    match cwd_parent:
        case "egg-discrimination":
            logging.info("Running discrimination...")
            for cfg in disc_variants:
                data_path = generate_data(game_type="disc", **cfg)
                cfg_strs = {k: str(v) for k, v in cfg.items()}
                out_path = run_environment("disc", cfg_strs, data_path)
                process_output(out_path)
        case "egg-reconstruction":
            logging.info("Running reconstruction...")
            for cfg in recon_variants:
                data_path = generate_data(game_type="recon", **cfg)
                cfg_strs = {k: str(v) for k, v in cfg.items()}
                out_path = run_environment("recon", cfg_strs, data_path)
                process_output(out_path)
        case _:
            logging.error("No match found.")


if __name__ == "__main__":
    main()
