import yaml
import sys
from pathlib import Path
from multiprocessing import Process
from utils import import_tf
from time import perf_counter
import pandas as pd
from itertools import product
from tqdm import tqdm
import glob
import os
import zstandard as zstd
import io
import chess
import numpy as np
import h5py

from utils import STORAGE_DIR, CFG_DIR, SF_PATH

# NOTE: STORAGE_DIR is also imported from `utils` above; this assignment
# replaces that imported value for the rest of this module.
STORAGE_DIR = "/storage1/fs1/XXXX-1/Active/chess"
# Root of the train/test CSV shards that get_datasets() globs.
DATA_DIR = f"{STORAGE_DIR}/data/csv_data_z3/curr"
# Directory where trained weights are saved and looked up as
# "<model_str>_best-<patience>". The __main__ script rebinds this name for
# some modes, and the functions below read the rebound value.
MODEL_DIR = f"{STORAGE_DIR}/leela/trained-models-curr13"


def parse_rating_stages(rating_arg):
    """
    Expand a "."-separated rating argument into its cumulative prefixes.

    Each "."-delimited piece is one stage; the k-th result joins the first
    k pieces back together with ".".
    e.g., "800-1600-2400.1600.2400" yields:
    ["800-1600-2400", "800-1600-2400.1600", "800-1600-2400.1600.2400"]
    """
    pieces = rating_arg.split(".")
    return [".".join(pieces[:end]) for end in range(1, len(pieces) + 1)]


def get_current_stage_ratings(stage_str):
    """
    Return the ratings of the last "."-separated stage in ``stage_str``.

    The stage is split on "-". If every piece parses as an int, a list of
    ints is returned; otherwise the raw string pieces are returned as-is.
    """
    tokens = stage_str.rsplit(".", 1)[-1].split("-")
    try:
        return [int(token) for token in tokens]
    except ValueError:
        # At least one piece is not numeric (e.g. an "sf..." dataset name).
        return tokens


def get_all_stage_ratings(stage_str):
    """
    Return the distinct integer ratings that appear in any stage.

    Stages are separated by "." and ratings within a stage by "-". Every
    piece must parse as an int (ValueError otherwise). The result is a list
    built from a set, so duplicates are removed and order is not sorted.
    """
    unique_ratings = {
        int(token)
        for stage in stage_str.split(".")
        for token in stage.split("-")
    }
    return list(unique_ratings)


def get_disk_cache_path(stage_str, budget, seed):
    """
    Build the scratch-disk cache path prefix for one (stage, budget, seed).

    Only the ratings of the current (last) stage go into the name, joined
    with "-". Callers append suffixes such as ".train" / ".test".
    """
    rating_tag = "-".join(map(str, get_current_stage_ratings(stage_str)))
    return f"/scratch1/fs1/XXXX-1/Active/tf_cache13/{rating_tag}_{budget}_{seed}"


def get_puzzle_datasets(stage_str, budget, seed):
    """Build the train/test datasets for puzzle (or "sf..." match) training.

    Only the ratings of the last "."-separated stage of ``stage_str`` are
    used. ``2**budget`` is split 75% / 25% into a train and a test budget.
    The train dataset, and then the test dataset, are each iterated in full
    here (a tqdm pass, then a ``reduce`` count) before being returned.

    Args:
        stage_str: Stage string; parsed with get_current_stage_ratings().
        budget: Exponent of the total data budget (``2**budget``).
        seed: Seed forwarded to ``puzzle_to_dataset`` and used in the disk
            cache file name.

    Returns:
        ``(train, test, bs)`` on success, or ``(None, None, None)`` when a
        dataset yields fewer elements than ``sum(train_take)`` /
        ``sum(test_take)``.
    """
    from data import puzzle_to_dataset

    ratings = get_current_stage_ratings(stage_str)

    train_budget = int(2**budget * 0.75)
    test_budget = int(2**budget * 0.25)

    # if we have multiple ratings that all need some data, we need a smaller batch size
    # NOTE(review): for budget < 4, test_budget // 4 is 0, so bs is 0 and the
    # integer divisions below raise ZeroDivisionError.
    bs = min(1024, test_budget // 4)
    train_take = train_budget // bs
    # One entry per rating; each entry is a quarter of the total batch count,
    # independent of len(ratings).
    train_take = [train_take // 4 for _ in range(len(ratings))]
    test_take = test_budget // bs
    test_take = [test_take // 4 for _ in range(len(ratings))]

    shuffle_buffer_size = bs * 50

    print(
        "puzzle",
        stage_str,
        budget,
        seed,
        train_budget,
        train_take,
        test_budget,
        test_take,
    )

    # Ratings that did not parse as ints come back as strings; a name starting
    # with "sf" selects the sf_matches CSV files instead of the puzzle files.
    if isinstance(ratings[0], str) and ratings[0].startswith("sf"):
        # Disk cache files are only used for budget > 12; otherwise "" is
        # passed as cache_binary (its meaning is defined by puzzle_to_dataset).
        if budget > 12:
            cache_path = get_disk_cache_path(stage_str, budget, seed)
            cache_train, cache_test = f"{cache_path}.train", f"{cache_path}.test"
        else:
            cache_train, cache_test = "", ""

        sf_data_path = (
            f"/storage1/fs1/XXXX-1/Active/chess/data/sf_matches/{ratings[0]}"
        )
        # NOTE(review): `take` is train_take[0] * bs here but plain
        # train_take[0] in the puzzle branch below - confirm the unit
        # puzzle_to_dataset expects for each case.
        train = puzzle_to_dataset(
            puzzle_path=f"{sf_data_path}_train.csv.zst",
            batch_size=bs,
            shuffle_buffer_size=shuffle_buffer_size,
            seed=seed,
            take=train_take[0] * bs,
            combine_outputs=True,
            cache_binary=cache_train,
            both_sides=True,
        ).prefetch(10)

        test = puzzle_to_dataset(
            puzzle_path=f"{sf_data_path}_test.csv.zst",
            batch_size=bs,
            shuffle_buffer_size=shuffle_buffer_size,
            seed=seed,
            take=test_take[0] * bs,
            combine_outputs=True,
            cache_binary=cache_test,
            both_sides=True,
        ).prefetch(10)

    else:
        # Rating [0] selects the full lichess puzzle file; any other value
        # selects the merged_<rating> file of the first rating only.
        if ratings == [0]:
            puzzle_path = "/storage1/fs1/XXXX-1/Active/chess/data/puzzle/lichess_db_puzzle.csv.zst"
        else:
            puzzle_path = f"/storage1/fs1/XXXX-1/Active/chess/data/puzzle/merged_{ratings[0]}.csv.zst"

        train = puzzle_to_dataset(
            puzzle_path=puzzle_path,
            batch_size=bs,
            shuffle_buffer_size=shuffle_buffer_size,
            seed=seed,
            take=train_take[0],  # TODO temp change [0] to merging datasets
            combine_outputs=True,
            cache_binary="",
        ).prefetch(10)

        # Train and test read the same file; test skips the part taken by train.
        test = puzzle_to_dataset(
            puzzle_path=puzzle_path,
            batch_size=bs,
            shuffle_buffer_size=shuffle_buffer_size,
            seed=seed,
            skip=train_take[0],  # TODO temp change [0] to merging datasets
            take=test_take[0],  # TODO temp change [0] to merging datasets
            combine_outputs=True,
            cache_binary="",
        ).prefetch(10)

    # First full pass over train with a progress bar. The loop variable is
    # discarded: collected_train is overwritten by the reduce() count below.
    # Presumably this pass exists to show progress / fill the cache - TODO confirm.
    print("about to iterate train")
    for collected_train, _ in tqdm(train, total=sum(train_take), ncols=80):
        pass
    print("about to reduce train")
    # Second full pass: count how many elements the dataset yields.
    collected_train = train.reduce(0, lambda x, _: x + 1).numpy()
    # NOTE(review): the expected count is sum(train_take) over all ratings,
    # while only train_take[0] is passed to puzzle_to_dataset above.
    if collected_train < sum(train_take):
        print(f"Only collected {collected_train}/{sum(train_take)} for training")
        return None, None, None
    print(
        f"Collected {collected_train} positions ({train_take} batches of size {bs}) for training",
        flush=True,
    )

    # Same two-pass check for the test dataset.
    print("about to to iterate test")
    for collected_test, _ in tqdm(test, total=sum(test_take), ncols=80):
        pass
    print("about to reduce test")
    collected_test = test.reduce(0, lambda x, _: x + 1).numpy()
    if collected_test < sum(test_take):
        print(f"Only collected {collected_test}/{sum(test_take)} for testing")
        return None, None, None
    print(
        f"Collected {collected_test} positions ({test_take} batches of size {bs}) for testing",
        flush=True,
    )

    return train, test, bs


def get_datasets(stage_str, budget, seed, endgame=False):
    """Build the train/test datasets from the per-rating game CSV shards.

    Only the ratings of the last "."-separated stage of ``stage_str`` are
    used. ``2**budget`` is split 75% / 25% into a train and a test budget.
    The train dataset, and then the test dataset, are each iterated in full
    here (a tqdm pass, then a ``reduce`` count) before being returned.

    Args:
        stage_str: Stage string; parsed with get_current_stage_ratings().
        budget: Exponent of the total data budget (``2**budget``).
        seed: Seed forwarded to ``csv_to_dataset`` and used in the disk
            cache file name.
        endgame: Forwarded unchanged to ``csv_to_dataset``.

    Returns:
        ``(train, test, bs)`` on success, or ``(None, None, None)`` when a
        dataset yields fewer elements than ``sum(train_take)`` /
        ``sum(test_take)``.
    """
    from data import csv_to_dataset

    # Only use ratings from the current (last) stage for training data
    ratings = get_current_stage_ratings(stage_str)

    train_budget = int(2**budget * 0.75)
    test_budget = int(2**budget * 0.25)

    # if we have multiple ratings that all need some data, we need a smaller batch size
    # NOTE(review): for budget < 4, test_budget // 4 is 0, so bs is 0 and the
    # integer divisions below raise ZeroDivisionError.
    bs = min(1024, test_budget // 4)
    train_take = train_budget // bs
    # One entry per rating; each entry is a quarter of the total batch count,
    # independent of len(ratings).
    train_take = [train_take // 4 for _ in range(len(ratings))]
    test_take = test_budget // bs
    test_take = [test_take // 4 for _ in range(len(ratings))]

    shuffle_buffer_size = bs * 100

    print(
        f"Loading datasets for {ratings} {budget} {seed} {bs} {train_take} {test_take} {shuffle_buffer_size}",
    )

    # Create data paths for all ratings
    # Each rating gets a glob pattern plus a (rating, rating + 100) range.
    # `rating + 100` requires integer ratings; string ratings (which
    # get_current_stage_ratings can return) would raise TypeError here.
    train_paths = [f"{DATA_DIR}/train/*/{rating}/*.csv.gz" for rating in ratings]
    train_ranges = [(rating, rating + 100) for rating in ratings]
    test_paths = [f"{DATA_DIR}/test/*/{rating}/*.csv.gz" for rating in ratings]
    test_ranges = [(rating, rating + 100) for rating in ratings]

    # Disk cache files are only used for budget > 21; otherwise "" is passed
    # as cache_binary (labelled "memory cache" in the disabled option below).
    if budget > 21:
        cache_path = get_disk_cache_path(stage_str, budget, seed)

        print("disk cache")
        cache_train, cache_test = f"{cache_path}.train", f"{cache_path}.test"

        # print('memory cache')
        # cache_train, cache_test = '', ''

        # print('no cache')
        # cache_train, cache_test = None, None

        print(f"{type(cache_train)}: '{cache_train}'")
        print(f"{type(cache_test)}: '{cache_test}'")

    else:
        cache_train = ""
        cache_test = ""

    # Unlike get_puzzle_datasets, the per-rating lists (paths, ranges, take)
    # are passed whole to csv_to_dataset.
    train = csv_to_dataset(
        data_paths=train_paths,
        rating_ranges=train_ranges,
        batch_size=bs,
        seed=seed,
        take=train_take,
        combine_outputs=True,
        shuffle_buffer_size=shuffle_buffer_size,
        cache_binary=cache_train,
        endgame=endgame,
    ).prefetch(10)

    test = csv_to_dataset(
        data_paths=test_paths,
        rating_ranges=test_ranges,
        batch_size=bs,
        seed=seed,
        take=test_take,
        combine_outputs=True,
        shuffle_buffer_size=shuffle_buffer_size,
        cache_binary=cache_test,
        endgame=endgame,
    ).prefetch(10)

    # First full pass over train with a progress bar. The loop variable is
    # discarded: collected_train is overwritten by the reduce() count below.
    # Presumably this pass exists to show progress / fill the cache - TODO confirm.
    print("about to iterate train")
    for collected_train, _ in tqdm(train, total=sum(train_take), ncols=80):
        pass
    print("about to reduce train")
    # Second full pass: count how many elements the dataset yields.
    collected_train = train.reduce(0, lambda x, _: x + 1).numpy()
    if collected_train < sum(train_take):
        print(f"Only collected {collected_train}/{sum(train_take)} for training")
        return None, None, None
    print(
        f"Collected {collected_train} positions ({train_take} batches of size {bs}) for training",
        flush=True,
    )

    # Same two-pass check for the test dataset.
    print("about to to iterate test")
    for collected_test, _ in tqdm(test, total=sum(test_take), ncols=80):
        pass
    print("about to reduce test")
    collected_test = test.reduce(0, lambda x, _: x + 1).numpy()
    if collected_test < sum(test_take):
        print(f"Only collected {collected_test}/{sum(test_take)} for testing")
        return None, None, None
    print(
        f"Collected {collected_test} positions ({test_take} batches of size {bs}) for testing",
        flush=True,
    )

    return train, test, bs


def get_cfg(**kwargs):
    """Load the YAML network config for the model named in ``kwargs["model"]``.

    ``model`` must have the form ``"<arch>-<filters>-<blocks>"``. The config
    file ``<CFG_DIR>/<arch>-config.yaml`` is loaded and then adjusted:

    * ``t74``: ``cfg["model"]["filters"]`` / ``["blocks"]`` are overwritten
      with the integer values from the model string.
    * ``t82``: ``cfg["gpu"]`` is set to -1 (filters/blocks are ignored).

    Other keyword arguments are accepted and ignored so callers can pass
    their full parameter dict.

    Raises:
        ValueError: if the model string does not have exactly three
            "-"-separated parts, or names an unsupported architecture.
            (Previously an unknown architecture failed later with an
            UnboundLocalError on ``cfg_path``.)
    """
    arch, filters, blocks = kwargs["model"].split("-")

    # Validate before touching the filesystem so a typo gives a clear error.
    if arch not in ("t74", "t82"):
        raise ValueError(
            f"Unsupported architecture {arch!r} in model {kwargs['model']!r}; "
            "expected 't74' or 't82'"
        )

    cfg_path = f"{CFG_DIR}/{arch}-config.yaml"
    with open(cfg_path, "r") as f:
        cfg = yaml.safe_load(f)

    if arch == "t74":
        cfg["model"]["filters"] = int(filters)
        cfg["model"]["blocks"] = int(blocks)
    else:  # arch == "t82"
        cfg["gpu"] = -1
    return cfg


def load_tfp(**kwargs):
    """Create, optionally warm-start, and compile the network for one stage.

    Reads ``model`` (via get_cfg), ``rating_str``, ``budget``, ``opt`` and
    ``lr`` from ``kwargs``. When ``rating_str`` has more than one
    "."-separated stage, the weights saved for the preceding stages (same
    other parameters, largest patience) are loaded first.

    Returns:
        ``(tfp, cfg)``: the TFProcess with a compiled ``model`` and the
        config dict it was built from.

    Raises:
        ValueError: if the previous-stage model is missing or ``opt`` is
            not one of "SGD", "adam", "nadam".
    """
    from nets.tfprocess import TFProcess
    import tensorflow as tf

    cfg = get_cfg(**kwargs)

    tfp = TFProcess(cfg)
    tfp.init_net()

    # Everything before the last "." identifies the stage trained just before.
    *earlier_stages, _current = kwargs["rating_str"].split(".")

    if earlier_stages:
        prev_kwargs = {**kwargs, "rating_str": ".".join(earlier_stages)}
        previous_model_str = get_model_str(**prev_kwargs)
        largest_patience = max(get_patience_values(kwargs["budget"]))
        previous_model_path = f"{MODEL_DIR}/{previous_model_str}_best-{largest_patience}"

        if not Path(previous_model_path).exists():
            raise ValueError(f"Previous stage model not found: {previous_model_path}")
        print(f"Loading weights from previous stage: {previous_model_path}")
        tfp.replace_weights(previous_model_path)

    opt_name = kwargs["opt"]
    if opt_name not in ("SGD", "adam", "nadam"):
        raise ValueError(f"Invalid optimizer: {kwargs['opt']}")

    learning_rate = kwargs["lr"]
    if opt_name == "SGD":
        optimizer = tf.keras.optimizers.SGD(
            learning_rate=learning_rate, momentum=0.9, nesterov=True
        )
    elif opt_name == "adam":
        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    else:
        optimizer = tf.keras.optimizers.Nadam(learning_rate=learning_rate)

    # One loss and one metric list per model output, in output order.
    losses = [tfp.policy_loss_fn, tfp.value_loss_fn, tfp.moves_left_loss_fn]
    tfp.model.compile(
        optimizer=optimizer,
        loss=losses,
        metrics=[["accuracy"], ["accuracy"], ["mse"]],
    )

    return tfp, cfg


def get_model_str(**kwargs):
    """
    Build the identifier used in saved-model file names.

    The fields model, opt, lr, rating_str, budget and seed are joined with
    "_" in that order; any other keyword arguments are ignored.
    """
    stage = kwargs["rating_str"]
    parts = [
        kwargs["model"],
        kwargs["opt"],
        kwargs["lr"],
        stage,
        kwargs["budget"],
        kwargs["seed"],
    ]
    return "_".join(format(part) for part in parts)


# def get_patience_values(budget):
#     if budget <= 16:
#         patience_values = [20]  # 8 10 12
#     elif budget <= 20:
#         patience_values = [20]  #  14 15 16 17 18 19 20
#     elif budget <= 24:
#         patience_values = [20]  #  21 22 23 24
#     elif budget <= 28:
#         patience_values = [10]  #  25 26 27 28
#     return patience_values


def get_patience_values(budget):
    """Return the early-stopping patience values for a data budget.

    Budgets are grouped into tiers by inclusive upper bound. All tiers
    currently map to ``[5]``; the tiers are kept separate so each can be
    tuned on its own.

    Args:
        budget: Data budget exponent; must be <= 28.

    Returns:
        A new list of patience values for the tier containing ``budget``.

    Raises:
        ValueError: if ``budget`` is above the largest tier. (Previously
            this fell through every branch and failed with an
            UnboundLocalError.)
    """
    # (inclusive upper bound, patience values), checked in ascending order.
    tiers = (
        (16, [5]),  # 8 10 12
        (20, [5]),  # 14 15 16 17 18 19 20
        (24, [5]),  # 21 22 23 24
        (28, [5]),  # 25 26 27 28
    )
    for upper_bound, patience_values in tiers:
        if budget <= upper_bound:
            # Copy so callers can never mutate the tier table.
            return list(patience_values)
    raise ValueError(
        f"No patience values defined for budget {budget} (maximum is {tiers[-1][0]})"
    )


def main(**kwargs):
    """Train one stage and save its best weights for every patience value.

    Keyword arguments read here: ``rating_str``, ``epochs``, ``seed``,
    ``budget``, ``endgame``, ``puzzle`` and the optional ``data_only``.
    ``model``, ``opt`` and ``lr`` are consumed by get_model_str() and
    load_tfp().

    Flow:
      1. Return immediately if the model for the largest patience exists.
      2. Build the datasets (puzzle or game CSVs). Return if ``data_only``
         is set or the datasets could not be filled.
      3. Fit with a custom early-stopping callback that snapshots the best
         weights once ``wait`` reaches each patience value.
      4. Save one weights file per patience value, then delete the disk
         cache files for this (stage, budget, seed).

    If ``fit`` ends before a patience value triggered (epoch limit reached),
    the best weights seen so far are saved for that patience instead of
    crashing in ``set_weights(None)``.
    """
    stage_str = kwargs["rating_str"]  # Single stage string
    epochs = kwargs["epochs"]
    seed = kwargs["seed"]
    budget = kwargs["budget"]
    endgame = kwargs["endgame"]
    puzzle = kwargs["puzzle"]

    patience_values = get_patience_values(budget)

    model_str = get_model_str(**kwargs)
    best_path = f"{MODEL_DIR}/{model_str}_best-{max(patience_values)}"

    if Path(best_path).exists():
        print(f"Model {best_path} already exists", flush=True)
        return
    print(f"Model {best_path} starting training or resuming...", flush=True)

    # TensorFlow is imported lazily, only once training is actually needed.
    import_tf()
    import tensorflow as tf

    if puzzle:
        train, test, _ = get_puzzle_datasets(stage_str, budget, seed)
    else:
        train, test, _ = get_datasets(stage_str, budget, seed, endgame=endgame)

    # data_only runs stop after the datasets were built (and iterated).
    if kwargs.get("data_only"):
        return

    if train is None or test is None:
        return

    tfp, _ = load_tfp(**kwargs)

    # Output names with their last path component (and ":N" suffix) removed;
    # heads[0] is used to look up the validation accuracy of the first output.
    heads = ["/".join(i.name.split(":")[0].split("/")[:-1]) for i in tfp.model.outputs]

    class CustomEarlyStopping(tf.keras.callbacks.Callback):
        """Early stopping that snapshots the best weights per patience value."""

        def __init__(self, patiences, model_str, start_from_epoch=0):
            super().__init__()
            self.patiences = sorted(patiences)  # Ensure patiences are sorted
            self.best_weights_dict = {
                p: None for p in patiences
            }  # Stores actual weights
            self.current_best_overall_weights = None  # Stores the weights of the best model seen so far, regardless of patience
            self.wait = 0
            self.best_loss = None
            self.model_str = model_str
            self.start_from_epoch = start_from_epoch
            # logs key -> label used in the per-epoch progress line
            self.metrics_to_print = {
                "loss": "train_loss",
                "val_loss": "val_loss",
                f"val_{heads[0]}_accuracy": "policy_acc",
            }

        def on_epoch_begin(self, epoch, logs=None):
            self.epoch_start_time = perf_counter()

        def on_epoch_end(self, epoch, logs=None):
            current_loss = logs.get("val_loss")
            training_time = perf_counter() - self.epoch_start_time

            if epoch == 0:
                self.first_loss = logs.get("val_loss")

            # Track the best validation loss and count epochs without improvement.
            if self.best_loss is None or current_loss < self.best_loss:
                self.best_loss = current_loss
                self.wait = 0
                self.current_best_overall_weights = self.model.get_weights()
            else:
                self.wait += 1

            # Freeze a snapshot the first time each patience is reached.
            for p in self.patiences:
                if self.wait >= p and self.best_weights_dict[p] is None:
                    self.best_weights_dict[p] = self.current_best_overall_weights

                    print(
                        f"({self.model_str}) Epoch {epoch + 1}: Stored Keras weights for patience {p}",
                        flush=True,
                    )

            metrics_str = []
            for metric, name in self.metrics_to_print.items():
                if metric in logs:
                    metrics_str.append(f"{name}: {logs[metric]:.4f}")

            # Check if all patiences have had their weights saved.
            # Compare against None explicitly rather than relying on the
            # truthiness of the stored weight lists.
            if all(w is not None for w in self.best_weights_dict.values()):
                if not self.model.stop_training:
                    self.model.stop_training = True
                    print(
                        f"({self.model_str}) Epoch {epoch + 1} completed in {training_time:.2f}s - {', '.join(metrics_str)}, wait: {self.wait}/{max(patience_values)}: All patience levels triggered. Stopping training.",
                        flush=True,
                    )
            else:
                print(
                    f"({self.model_str}) Epoch {epoch + 1} completed in {training_time:.2f}s - {', '.join(metrics_str)}, wait: {self.wait}/{max(patience_values)}",
                    flush=True,
                )

    early_stopping = CustomEarlyStopping(patiences=patience_values, model_str=model_str)

    tfp.model.fit(
        train,
        epochs=epochs,
        validation_data=test,
        verbose=0,
        callbacks=[
            early_stopping,
        ],
    )

    for p in patience_values:
        best_weights = early_stopping.best_weights_dict[p]
        if best_weights is None:
            # fit() hit the epoch limit before patience p triggered; fall
            # back to the best weights seen instead of set_weights(None).
            best_weights = early_stopping.current_best_overall_weights
            if best_weights is None:
                print(
                    f"({model_str}) No weights recorded for patience {p}; nothing saved",
                    flush=True,
                )
                continue
            print(
                f"({model_str}) Patience {p} never triggered; saving best weights seen so far",
                flush=True,
            )
        tfp.model.set_weights(best_weights)
        tfp.save_leelaz_weights(f"{MODEL_DIR}/{model_str}_best-{p}")

    # Remove the on-disk dataset cache files belonging to this run.
    for path in glob.glob(f"{get_disk_cache_path(stage_str, budget, seed)}*"):
        os.remove(path)


def _z_to_wp(z: np.ndarray) -> float:
    return (z[0] + z[1] / 2) / 1000


if __name__ == "__main__":
    mode = sys.argv[1]
    model = sys.argv[2]
    budgets = sys.argv[3]
    rating_arg = sys.argv[4]
    seeds = sys.argv[5]

    Path(MODEL_DIR).mkdir(parents=True, exist_ok=True)

    if "-" in seeds:
        seeds = list(range(*map(int, seeds.split("-"))))
    else:
        seeds = [int(seeds)]

    if "-" in budgets:
        budgets = list(map(int, budgets.split("-")))
    else:
        budgets = [int(budgets)]

    model, opt, lr = model.split("_")
    lr = float(lr)

    # Parse stages once
    stages = parse_rating_stages(rating_arg)

    for budget, seed in product(budgets, seeds):
        # for seed, budget in product(seeds, budgets):
        print(f"Processing seed {seed}, budget {budget}")

        if mode == "train" or mode == "endgame" or mode == "train_puzzle":
            if mode == "endgame":
                MODEL_DIR = MODEL_DIR + "_endgame"
            elif mode == "train_puzzle":
                MODEL_DIR = MODEL_DIR.split("_")[0] + "_puzzle"

            # Train each stage in sequence
            for stage_idx, stage_str in enumerate(stages):
                print(f"\n=== Stage {stage_idx + 1}/{len(stages)}: {stage_str} ===")
                current_params = {
                    "mode": mode,
                    "rating_str": stage_str,  # Pass single stage string
                    "opt": opt,
                    "lr": lr,
                    "model": model,
                    "epochs": 1000,
                    "seed": seed,
                    "budget": budget,
                    "endgame": "endgame" in mode,
                    "puzzle": "puzzle" in mode,
                    "data_only": "data_only" in mode,
                }
                main(**current_params)

        # Match testing only runs on the final stage
        if mode == "match":
            import run_match

            if rating_arg in ["0", "100", "1000"]:
                MODEL_DIR = MODEL_DIR.split("_")[0] + "_puzzle"

            if rating_arg.startswith("sf"):
                MODEL_DIR = MODEL_DIR.split("_")[0] + "_puzzle"

            output_path = f"{MODEL_DIR}/results.csv"

            print(output_path)

            n = 1024

            if len(sys.argv) <= 6:
                continue
            match_arg = sys.argv[6]
            if "-" in match_arg:
                m0, m1 = map(int, match_arg.split("-"))
            else:
                m0, m1 = int(match_arg), int(match_arg) + 1

            cfg_match = get_cfg(model=model)
            patiences = get_patience_values(budget)
            print("patiences", patiences, seed, budget, rating_arg)

            if not Path(output_path).exists():
                with open(output_path, "w") as f:
                    print(
                        "model,opt,lr,rating,budget,seed,patience,matchseed,n,res",
                        file=f,
                    )

            # For matching, we only test the final stage
            final_stage = stages[-1]

            # -2 because we don't really need to check the low values...
            for p in patiences[-1:]:
                current_params = {
                    "mode": mode,
                    "rating_str": final_stage,
                    "opt": opt,
                    "lr": lr,
                    "model": model,
                    "epochs": 1000,
                    "seed": seed,
                    "budget": budget,
                }
                model_str = get_model_str(**current_params)
                tfp1 = f"{MODEL_DIR}/{model_str}_best-{p}"

                if not Path(tfp1).exists():
                    print(f"Model {tfp1} does not exist")
                    continue

                m_idxs = []
                for m_idx in range(m0, m1):
                    df = pd.read_csv(output_path)
                    df["rating"] = df.rating.astype(str)
                    if df[
                        (df["model"] == model)
                        & (df["opt"] == opt)
                        & (df["lr"] == lr)
                        & (df["rating"] == final_stage)
                        & (df["budget"] == budget)
                        & (df["seed"] == seed)
                        & (df["patience"] == p)
                        & (df["matchseed"] == m_idx)
                        & (df["n"] == n)
                    ].empty:
                        print(
                            f"Running match {m_idx} for {opt}_{lr}_{final_stage}_{budget}_{seed}_best-{p}",
                            flush=True,
                        )
                        m_idxs.append(m_idx)
                    else:
                        print(
                            f"Match result {m_idx} for {opt}_{lr}_{final_stage}_{budget}_{seed}_best-{p} already exists",
                        )
                if not m_idxs:
                    continue

                #                 res = run_match.run_multi_match(
                #                     [
                #                         {
                #                             "tfp1": tfp1,
                #                             "tfp2": "random",
                #                             "n": n,
                #                             "verbose": True,
                #                             "cfg": cfg_match,
                #                             "seed": m_idx,
                #                         }
                #                         for m_idx in m_idxs
                #                     ]
                #                 )

                #                 with open(f"{MODEL_DIR}/results.csv", "a") as f:
                #                     for m_idx, r in zip(m_idxs, res):
                #                         print(
                #                             f"{model},{opt},{lr},{final_stage},{budget},{seed},{p},{m_idx},{n},{r}",
                #                             file=f,
                #                         )

                with open(output_path, "a") as f:
                    for m_idx in m_idxs:
                        r, boards1, boards2 = run_match.run_multi_match(
                            [
                                {
                                    "tfp1": tfp1,
                                    "tfp2": "random",
                                    "n": n,
                                    "verbose": True,
                                    "cfg": cfg_match,
                                    "seed": m_idx,
                                }
                            ]
                        )[0]
                        print(
                            f"{model},{opt},{lr},{final_stage},{budget},{seed},{p},{m_idx},{n},{r}",
                            file=f,
                            flush=True,
                        )

        # Match testing only runs on the final stage
        if mode == "match_puzzle":
            MODEL_DIR = f"{MODEL_DIR}_puzzle"
            output_path = f"{MODEL_DIR}/results_puzzle.csv"
            import run_match
            from data import stream_lines

            n = 1024
            i = seed * n + 5_000_000
            j = i + n

            STORAGE_DIR = "/storage1/fs1/XXXX-1/Active/chess"
            lines = list(stream_lines(f"{STORAGE_DIR}/data/lichess_db_puzzle.csv.zst"))[
                1:
            ]

            boards = []
            for line in lines[i:j]:
                fen, moves = line.split(",")[1:3]
                board = chess.Board(fen)
                moves = moves.split()
                board.push_uci(moves[0])
                boards.append(board)

            if len(sys.argv) <= 6:
                continue
            match_arg = sys.argv[6]
            if "-" in match_arg:
                m0, m1 = map(int, match_arg.split("-"))
            else:
                m0, m1 = int(match_arg), int(match_arg) + 1

            cfg_match = get_cfg(model=model)
            patiences = get_patience_values(budget)
            print("patiences", patiences, seed, budget, rating_arg)
            if not Path(output_path).exists():
                with open(output_path, "w") as f:
                    print(
                        "model,opt,lr,rating,budget,seed,patience,matchseed,n,res",
                        file=f,
                    )
            final_stage = stages[-1]
            for p in patiences[-1:]:
                current_params = {
                    "mode": mode,
                    "rating_str": final_stage,
                    "opt": opt,
                    "lr": lr,
                    "model": model,
                    "epochs": 1000,
                    "seed": seed,
                    "budget": budget,
                }
                model_str = get_model_str(**current_params)
                tfp1 = f"{MODEL_DIR}/{model_str}_best-{p}"

                if not Path(tfp1).exists():
                    print(f"Model {tfp1} does not exist")
                    continue

                m_idxs = []
                for m_idx in range(m0, m1):
                    df = pd.read_csv(output_path)
                    df["rating"] = df.rating.astype(str)
                    if df[
                        (df["model"] == model)
                        & (df["opt"] == opt)
                        & (df["lr"] == lr)
                        & (df["rating"] == final_stage)
                        & (df["budget"] == budget)
                        & (df["seed"] == seed)
                        & (df["patience"] == p)
                        & (df["matchseed"] == m_idx)
                        & (df["n"] == n)
                    ].empty:
                        print(
                            f"Running match {m_idx} for {opt}_{lr}_{final_stage}_{budget}_{seed}_best-{p}",
                            flush=True,
                        )
                        m_idxs.append(m_idx)
                    else:
                        print(
                            f"Match result {m_idx} for {opt}_{lr}_{final_stage}_{budget}_{seed}_best-{p} already exists",
                        )
                if not m_idxs:
                    continue
                with open(output_path, "a") as f:
                    for m_idx in m_idxs:
                        r, boards1, boards2 = run_match.run_multi_match(
                            [
                                {
                                    "tfp1": tfp1,
                                    "tfp2": "random",
                                    "boards": boards,
                                    "verbose": True,
                                    "cfg": cfg_match,
                                    "seed": m_idx,
                                }
                            ]
                        )[0]
                        print(
                            f"{model},{opt},{lr},{final_stage},{budget},{seed},{p},{m_idx},{n},{r}",
                            file=f,
                        )

        # Match testing only runs on the final stage
        if mode == "match_analysis":
            output_path = f"{MODEL_DIR}/results_puzzle_analysis.csv"
            import run_match
            from data import stream_lines

            n = 256
            i = seed * n
            j = i + n

            STORAGE_DIR = "/storage1/fs1/XXXX-1/Active/chess"
            lines = list(stream_lines(f"{STORAGE_DIR}/data/lichess_db_puzzle.csv.zst"))[
                1:
            ]

            boards = []
            for line in lines[i:j]:
                fen, moves = line.split(",")[1:3]
                board = chess.Board(fen)
                moves = moves.split()
                board.push_uci(moves[0])
                boards.append(board)

            if len(sys.argv) <= 6:
                continue
            match_arg = sys.argv[6]
            if "-" in match_arg:
                m0, m1 = map(int, match_arg.split("-"))
            else:
                m0, m1 = int(match_arg), int(match_arg) + 1

            cfg_match = get_cfg(model=model)
            patiences = get_patience_values(budget)
            print("patiences", patiences, seed, budget, rating_arg)
            if not Path(output_path).exists():
                with open(output_path, "w") as f:
                    print(
                        "model,opt,lr,rating,budget,seed,patience,matchseed,i,res,avg",
                        file=f,
                    )
            final_stage = stages[-1]
            for p in patiences[-1:]:
                current_params = {
                    "mode": mode,
                    "rating_str": final_stage,
                    "opt": opt,
                    "lr": lr,
                    "model": model,
                    "epochs": 1000,
                    "seed": seed,
                    "budget": budget,
                }
                model_str = get_model_str(**current_params)
                tfp1 = f"{MODEL_DIR}/{model_str}_best-{p}"

                if not Path(tfp1).exists():
                    print(f"Model {tfp1} does not exist")
                    continue

                m_idxs = []
                for m_idx in range(m0, m1):
                    df = pd.read_csv(output_path)
                    df["rating"] = df.rating.astype(str)
                    if df[
                        (df["model"] == model)
                        & (df["opt"] == opt)
                        & (df["lr"] == lr)
                        & (df["rating"] == final_stage)
                        & (df["budget"] == budget)
                        & (df["seed"] == seed)
                        & (df["patience"] == p)
                        & (df["matchseed"] == m_idx)
                    ].empty:
                        print(
                            f"Running match {m_idx} for {opt}_{lr}_{final_stage}_{budget}_{seed}_best-{p}",
                            flush=True,
                        )
                        m_idxs.append(m_idx)
                    else:
                        print(
                            f"Match result {m_idx} for {opt}_{lr}_{final_stage}_{budget}_{seed}_best-{p} already exists",
                        )
                if not m_idxs:
                    continue
                # Play every pending match seed against the "random" opponent and
                # append one CSV row per finished game to output_path.
                with open(output_path, "a") as f:
                    for m_idx in m_idxs:
                        match_spec = {
                            "tfp1": tfp1,
                            "tfp2": "random",
                            "boards": boards,
                            "verbose": True,
                            "cfg": cfg_match,
                            "seed": m_idx,
                        }
                        res, boards1, boards2 = run_match.run_multi_match([match_spec])[0]
                        # Each entry: (games, offset added to the game index,
                        # whether the recorded winner is reported as 1 - winner).
                        # The second set is numbered after the first and flipped,
                        # presumably because the sides are swapped there --
                        # confirm in run_match.
                        game_sets = (
                            (boards1, 0, False),
                            (boards2, len(boards1), True),
                        )
                        with chess.engine.SimpleEngine.popen_uci(SF_PATH) as engine:
                            for games, offset, flip in game_sets:
                                for i, game in enumerate(tqdm(games)):
                                    # Replay the game from its first recorded FEN and
                                    # score every position that still has legal moves.
                                    replay = chess.Board(game.history[0]["fen"])
                                    avgs = []
                                    for move_data in game.history:
                                        legal_moves = list(replay.legal_moves)
                                        if not legal_moves:
                                            break
                                        # One principal variation per legal move, so
                                        # the mean below covers every available move.
                                        infos = engine.analyse(
                                            replay,
                                            chess.engine.Limit(nodes=10_000),
                                            multipv=len(legal_moves),
                                        )
                                        avgs.append(
                                            np.mean(
                                                [
                                                    _z_to_wp(info["score"].wdl())
                                                    for info in infos
                                                ]
                                            )
                                        )
                                        # History entries without a "choice" carry no
                                        # move to replay, so the walk stops there.
                                        if "choice" not in move_data:
                                            break
                                        replay.push_uci(move_data["choice"])
                                    # Games without a recorded winner count as 0.5.
                                    if game.winner is None:
                                        winner = 0.5
                                    elif flip:
                                        winner = 1 - game.winner
                                    else:
                                        winner = game.winner
                                    # -1 marks a game for which no position was scored.
                                    print(
                                        f"{model},{opt},{lr},{final_stage},{budget},{seed},{p},{m_idx},{i + offset},{winner},{np.mean(avgs) if avgs else -1}",
                                        file=f,
                                        flush=True,
                                    )

            # Match testing only runs on the final stage
        if mode == "match_endgame":
            import run_match

            MODEL_DIR = MODEL_DIR + "_endgame"
            output_path = f"{MODEL_DIR}/results_endgame.csv"

            n = 1024
            i = seed * n
            j = i + n

            boards = []
            rating = 2400
            for path in tqdm(
                glob.glob(f"{STORAGE_DIR}/leela/results_new/{rating}/*.hdf5")
            ):
                with h5py.File(path, "r") as f:
                    for v in f["positions"].values():
                        board = chess.Board(v.attrs["fen"])
                        if len(board.piece_map()) < 8:
                            boards.append(board)
            boards = boards[(seed % 9) :: 9][:1024]

            if len(sys.argv) <= 6:
                continue
            match_arg = sys.argv[6]
            if "-" in match_arg:
                m0, m1 = map(int, match_arg.split("-"))
            else:
                m0, m1 = int(match_arg), int(match_arg) + 1

            cfg_match = get_cfg(model=model)
            patiences = get_patience_values(budget)
            print("patiences", patiences, seed, budget, rating_arg)
            if not Path(output_path).exists():
                with open(output_path, "w") as f:
                    print(
                        "model,opt,lr,rating,budget,seed,patience,matchseed,n,res",
                        file=f,
                    )
            final_stage = stages[-1]
            for p in patiences[-1:]:
                current_params = {
                    "mode": mode,
                    "rating_str": final_stage,
                    "opt": opt,
                    "lr": lr,
                    "model": model,
                    "epochs": 1000,
                    "seed": seed,
                    "budget": budget,
                }
                model_str = get_model_str(**current_params)
                tfp1 = f"{MODEL_DIR}/{model_str}_best-{p}"

                if not Path(tfp1).exists():
                    print(f"Model {tfp1} does not exist")
                    continue

                m_idxs = []
                for m_idx in range(m0, m1):
                    df = pd.read_csv(output_path)
                    df["rating"] = df.rating.astype(str)
                    if df[
                        (df["model"] == model)
                        & (df["opt"] == opt)
                        & (df["lr"] == lr)
                        & (df["rating"] == final_stage)
                        & (df["budget"] == budget)
                        & (df["seed"] == seed)
                        & (df["patience"] == p)
                        & (df["matchseed"] == m_idx)
                        & (df["n"] == n)
                    ].empty:
                        print(
                            f"Running match {m_idx} for {opt}_{lr}_{final_stage}_{budget}_{seed}_best-{p}",
                            flush=True,
                        )
                        m_idxs.append(m_idx)
                    else:
                        print(
                            f"Match result {m_idx} for {opt}_{lr}_{final_stage}_{budget}_{seed}_best-{p} already exists",
                        )
                if not m_idxs:
                    continue
                with open(output_path, "a") as f:
                    for m_idx in m_idxs:
                        r, boards1, boards2 = run_match.run_multi_match(
                            [
                                {
                                    "tfp1": tfp1,
                                    "tfp2": "random",
                                    "boards": boards,
                                    "verbose": True,
                                    "cfg": cfg_match,
                                    "seed": m_idx,
                                }
                            ]
                        )[0]
                        print(
                            f"{model},{opt},{lr},{final_stage},{budget},{seed},{p},{m_idx},{n},{r}",
                            file=f,
                        )
