import argparse
import os
import warnings
from logging import FileHandler, Formatter, StreamHandler, getLogger
from typing import Tuple, Dict, List

import numpy as np
from tqdm import tqdm

import models
import ntk
import wandb

# Suppress every Python warning for the whole run.
warnings.filterwarnings("ignore")

# One record layout shared by the console and the file output.
log_fmt = Formatter(
    "%(asctime)s %(name)s L%(lineno)d [%(levelname)s][%(funcName)s] %(message)s "
)
logger = getLogger(__name__)
logger.setLevel("INFO")

# Console handler first, then a file handler opened in "w" mode, so
# ./result.log is overwritten each time this module is loaded.  After the
# loop ``handler`` stays bound to the file handler.
for handler in (StreamHandler(), FileHandler("./result.log", "w")):
    handler.setLevel("INFO")
    handler.setFormatter(log_fmt)
    logger.addHandler(handler)

# Dataset root, resolved relative to the directory containing this file.
DATA_DIR = os.path.join(os.path.dirname(__file__), "./data/")


def get_datasize(dic: Dict) -> Tuple[int, int, int]:
    """Extract the sample counts from a parsed dataset configuration.

    Reads the ``"n_clases="``, ``"n_entradas="`` and ``"n_patrons1="`` entries
    (all required) and the optional ``"n_patrons2="`` entry, converting each
    to ``int``.  The first two values are only written to the log.

    Args:
        dic: Mapping of configuration keys to values convertible by ``int``.

    Returns:
        ``(n_tot, n_train_val, n_test)`` where ``n_train_val`` comes from
        ``"n_patrons1="``, ``n_test`` from ``"n_patrons2="`` (0 when the key
        is absent) and ``n_tot`` is their sum.

    Raises:
        KeyError: If one of the required keys is missing.
    """
    n_classes = int(dic["n_clases="])
    n_inputs = int(dic["n_entradas="])
    n_train_val = int(dic["n_patrons1="])
    n_test = int(dic["n_patrons2="]) if "n_patrons2=" in dic else 0

    n_tot = n_train_val + n_test
    logger.info(f"Datasize: \tN: {n_tot}\td: {n_inputs}\tc: {n_classes}")
    return n_tot, n_train_val, n_test


def load_data(dic: Dict) -> Tuple[np.ndarray, np.ndarray]:
    """Load the features and labels of one dataset from its data file.

    The file ``DATA_DIR/<dic["dataset"]>/<dic["fich1="]>`` is read as
    whitespace-separated text.  Its first line is skipped.  On every other
    non-blank line the first token is discarded (presumably a row index --
    confirm against the data files), the last token is parsed as the integer
    label and the tokens in between as float features.

    Args:
        dic: Dataset configuration; must provide ``"dataset"`` and ``"fich1="``.

    Returns:
        ``(X, y)``: the array built from the per-row feature lists and the
        array built from the per-row integer labels.

    Raises:
        KeyError: If a required configuration key is missing.
        OSError: If the data file cannot be opened.
        ValueError: If a token cannot be parsed as a number.
    """
    path = os.path.join(DATA_DIR, dic["dataset"], dic["fich1="])
    # Use a context manager so the handle is closed deterministically instead
    # of being left to the garbage collector.
    with open(path, "r") as fh:
        next(fh, None)  # skip the first line
        # Split each line once; blank lines (e.g. a trailing newline) carry
        # no sample and would otherwise fail on the label lookup.
        rows = [tokens for tokens in map(str.split, fh) if tokens]

    X = np.asarray([[float(value) for value in tokens[1:-1]] for tokens in rows])
    y = np.asarray([int(tokens[-1]) for tokens in rows])
    return X, y


def train(
    args,
    K: np.array,
    X: np.array,
    y: np.array,
    train_fold: List[int],
    eval_fold: List[int],
    depth: int,
    dic: Dict,
) -> float:
    """Evaluate one train/eval split and return the resulting score.

    Depending on ``args.model`` the work is delegated to one of two routines
    of the project's ``models`` module:

    * ``"rbf"`` / ``"laplacian"``: ``models.kernel_ridge_regression`` on the
      raw rows ``X[train_fold]`` / ``X[eval_fold]``, with ``args.alpha``
      forwarded as ``gamma`` and ``args.model`` as ``kernel``.
    * any other model: ``models.precomputed_kernel_ridge_regression`` on
      sub-blocks of the precomputed kernel ``K[int(log2(depth))]``.

    In both cases ``args.reg_coef`` is forwarded as ``alpha`` and
    ``int(dic["n_clases="])`` as ``classes``.

    Args:
        args: Parsed options; ``model``, ``reg_coef`` and ``alpha`` are read.
        K: Precomputed kernels, indexed along the first axis by
            ``int(log2(depth))``; only used on the precomputed-kernel path.
        X: Feature rows; only used on the rbf/laplacian path.
        y: Labels, indexed by the fold index lists.
        train_fold: Indices of the training samples.
        eval_fold: Indices of the evaluation samples.
        depth: Depth whose base-2 logarithm selects the kernel in ``K``.
        dic: Dataset configuration providing ``"n_clases="``.

    Returns:
        Whatever the selected ``models`` routine returns (annotated as
        ``float``; presumably an accuracy -- confirm in ``models``).
    """
    if args.model not in ("rbf", "laplacian"):
        # Precomputed-kernel path: pick the kernel for this depth once, then
        # cut out the train x train and eval x train sub-blocks.
        gram = K[int(np.log2(depth))]
        gram_train = gram[train_fold][:, train_fold]
        gram_eval = gram[eval_fold][:, train_fold]
        return models.precomputed_kernel_ridge_regression(
            gram_train,
            gram_eval,
            y[train_fold],
            y[eval_fold],
            alpha=args.reg_coef,
            classes=int(dic["n_clases="]),
        )

    # Closed-form kernels are evaluated directly on the feature rows.
    return models.kernel_ridge_regression(
        X[train_fold],
        X[eval_fold],
        y[train_fold],
        y[eval_fold],
        alpha=args.reg_coef,
        gamma=args.alpha,
        classes=int(dic["n_clases="]),
        kernel=args.model,
    )


def send_log(
    args: argparse.Namespace, avg_acc: float, n_tot: int, depth: int, dic: dict
) -> None:
    """Report one averaged score to the logger and to wandb.

    The score is reported as ``avg_acc * 100`` together with the dataset
    size.  For ``"tree_ntk"`` and ``"asymtree_ntk"`` the wandb keys carry a
    zero-padded ``depth + 1`` prefix (``"NN_size_<dataset>"`` /
    ``"NN_acc_<dataset>"``); for every other model the keys are un-prefixed.

    Args:
        args: Parsed options; only ``model`` is read.
        avg_acc: Averaged score; multiplied by 100 for reporting.
        n_tot: Total number of samples in the dataset.
        depth: Value reported as ``depth + 1``.
        dic: Dataset configuration providing ``"dataset"``.
    """
    shown_depth = depth + 1
    logger.info(f"depth: {shown_depth}, acc: {avg_acc * 100}%")

    dataset = dic["dataset"]
    # Only the depth-swept models get per-depth metric names.
    if args.model in ("tree_ntk", "asymtree_ntk"):
        prefix = f"{shown_depth:02d}_"
    else:
        prefix = ""

    metrics = {
        f"{prefix}size_{dataset}": n_tot,
        f"{prefix}acc_{dataset}": avg_acc * 100,
    }
    wandb.log(metrics)


def cv(
    args: argparse.Namespace,
    K: np.ndarray,
    X: np.ndarray,
    y: np.ndarray,
    dic: dict,
    n_tot: int,
) -> None:
    """Run the 4-repeat evaluation for one dataset and report the mean score.

    Index lists are read from
    ``DATA_DIR/<dic["dataset"]>/conxuntos_kfold.dat``: each line holds
    whitespace-separated integers, and for repeat ``r`` line ``2*r`` is used
    as the training indices and line ``2*r + 1`` as the evaluation indices.

    For ``"tree_ntk"`` and ``"asymtree_ntk"`` the evaluation is repeated for
    the depths 1, 2, 4, ..., 128; for every other model only the first pass
    (depth 1) is executed.  After each pass the mean of the four scores is
    handed to ``send_log``.

    Args:
        args: Parsed options, forwarded to ``train`` and ``send_log``.
        K: Precomputed kernels, forwarded to ``train``.
        X: Feature rows, forwarded to ``train``.
        y: Labels, forwarded to ``train``.
        dic: Dataset configuration providing ``"dataset"``.
        n_tot: Total number of samples, forwarded to ``send_log``.

    Raises:
        OSError: If the fold file cannot be opened.
        IndexError: If the fold file has fewer than eight lines.
    """
    fold_path = os.path.join(DATA_DIR, dic["dataset"], "conxuntos_kfold.dat")
    # Use a context manager so the handle is closed as soon as it is parsed.
    with open(fold_path, "r") as fh:
        fold = [[int(token) for token in line.split()] for line in fh]

    n_repeats = 4  # train/eval line pairs consumed from the fold file
    # Loop-invariant: only these models are evaluated at more than one depth.
    sweep_depths = args.model in ("tree_ntk", "asymtree_ntk")

    for depth in [1, 2, 4, 8, 16, 32, 64, 128]:
        logger.info("Start training...")
        avg_acc = 0
        for repeat in tqdm(range(n_repeats), leave=False, desc="CV-loop..."):
            train_fold, test_fold = fold[repeat * 2], fold[repeat * 2 + 1]
            acc = train(args, K, X, y, train_fold, test_fold, depth, dic)
            avg_acc += acc / n_repeats
        # send_log adds 1 back to the depth it receives.
        send_log(args, avg_acc, n_tot, depth - 1, dic)
        if not sweep_depths:
            break


def _read_dataset_config(dataset: str) -> Dict[str, str]:
    """Parse ``DATA_DIR/<dataset>/<dataset>.txt`` into a key/value dict."""
    dic = {"dataset": dataset}
    config_path = os.path.join(DATA_DIR, dataset, f"{dataset}.txt")
    with open(config_path, "r") as fh:
        for line in fh:
            tokens = line.split()
            if not tokens:  # tolerate blank lines
                continue
            # Every remaining line must consist of exactly "<key> <value>".
            key, value = tokens
            dic[key] = value
    return dic


def _compute_kernel(args: argparse.Namespace, X: np.ndarray) -> np.ndarray:
    """Build the kernel for ``args.model`` (a 2x2 zero placeholder for non-NTK models)."""
    max_depth = 128
    if args.model == "tree_ntk":
        K, _, _ = ntk.tree(X, max_depth, args.alpha)
    elif args.model == "asymtree_ntk":
        K, _, _ = ntk.asymtree(X, max_depth, args.alpha)
    elif args.model == "inf_asymtree_ntk":
        K, _, _ = ntk.inf_asymtree(X, max_depth, args.alpha)
    else:  # model does not use a precomputed kernel
        K = np.zeros((2, 2))  # placeholder
    return K


def main(args: argparse.Namespace) -> None:
    """Process every dataset directory found under ``DATA_DIR``.

    A dataset is skipped when its ``<dataset>.txt`` configuration file is
    missing, when it has more than ``args.max_tot`` samples, or when it
    declares a separate test split (``n_test > 0``).

    For the remaining datasets the behaviour depends on ``args.mode``:

    * ``"kernel"``: compute the kernel for ``args.model`` and store it at
      ``kernels/<model>_<dataset>_<alpha>.npz`` (relative to the current
      working directory).
    * ``"train"``: load that file and run ``cv`` on the dataset.

    Args:
        args: Parsed options; ``max_tot``, ``model``, ``alpha`` and ``mode``
            are read here.

    Raises:
        NotImplementedError: If ``args.mode`` is neither ``"train"`` nor
            ``"kernel"`` (raised when the first non-skipped dataset is reached).
    """
    # Exclude the .gitkeep placeholder by name.  Dropping "the first sorted
    # entry" would silently lose a real dataset whenever .gitkeep is absent
    # or another entry sorts before it.
    datasets = [name for name in sorted(os.listdir(DATA_DIR)) if name != ".gitkeep"]
    logger.info(f"Dataset: {datasets}")

    for idx, dataset in enumerate(datasets):
        if not os.path.isfile(os.path.join(DATA_DIR, dataset, f"{dataset}.txt")):
            logger.info(f"{dataset} is skipped because of the absence of txt file")
            continue

        logger.info(f"-----{idx}, {dataset}-----")

        # load configuration
        dic = _read_dataset_config(dataset)

        # Check skip or not
        n_tot, _, n_test = get_datasize(dic)
        if (n_tot > args.max_tot) or (n_test > 0):
            logger.info("skipped because of the dataset setting")
            continue

        # load dataset
        X, y = load_data(dic)

        kernel_path = f"kernels/{args.model}_{dic['dataset']}_{args.alpha}.npz"
        if args.mode == "train":
            # Close the archive once the array has been read from it.
            with np.load(kernel_path) as archive:
                K = archive["arr_0"]
            cv(args, K, X, y, dic, n_tot)  # cv loop
        elif args.mode == "kernel":
            logger.info("Extracting NTK...")
            K = _compute_kernel(args, X)
            # Make sure the output directory exists so the computed kernel is
            # not lost to a missing "kernels/" folder.
            os.makedirs(os.path.dirname(kernel_path), exist_ok=True)
            np.savez_compressed(kernel_path, K)
            logger.info(f"kernel saved. {kernel_path}")
        else:
            raise NotImplementedError(f"unsupported mode: {args.mode!r}")

        logger.info("done")


if __name__ == "__main__":
    # Command-line entry point: parse the options, start a wandb run when
    # training, then process all datasets.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-max_tot", default=5000, type=int, help="Maximum number of data samples"
    )
    parser.add_argument("-alpha", default=1.0, type=float, help="Scaling factor")
    parser.add_argument(
        "-reg_coef",
        default=1e-8,
        type=float,
        help="regularization factor used in kernel regression",
    )
    parser.add_argument(
        "-model",
        default="tree_ntk",
        type=str,
        # "rbf" and "laplacian" are handled explicitly by train() and by the
        # non-NTK branches of cv()/main(), so they are accepted here as well.
        choices=[
            "tree_ntk",
            "asymtree_ntk",
            "inf_asymtree_ntk",
            "rbf",
            "laplacian",
        ],
    )
    # Required: without a mode main() can only fail with NotImplementedError
    # once it reaches the first usable dataset.
    parser.add_argument(
        "-mode",
        type=str,
        choices=["train", "kernel"],
        required=True,
        help="'kernel' computes and stores the kernels, 'train' evaluates them",
    )
    parser.add_argument(
        "-name",
        default="debug",
        type=str,
    )

    args = parser.parse_args()
    logger.info(args)
    if args.mode == "train":
        wandb.init(
            project=f"tree-ntk-{args.max_tot}-{args.name}",
            config=args,
            reinit=True,
        )
        # Run name: the model, plus "-alphaNNNN" (alpha * 100, zero-padded to
        # four digits) for the NTK models.
        run_name = args.model
        if args.model in ("tree_ntk", "asymtree_ntk", "inf_asymtree_ntk"):
            run_name += "-alpha" + str(int(args.alpha * 100)).zfill(4)
        wandb.run.name = run_name
    main(args)
    # NOTE(review): called in both modes although a run is only started in
    # "train" mode -- confirm this is a no-op without an active run.
    wandb.join()
