import argparse
import os
import warnings
from logging import FileHandler, Formatter, StreamHandler, getLogger
from typing import Tuple, Dict, List

import numpy as np
from tqdm import tqdm

import models
import ntk

# Silence every warning category for the whole run.
warnings.filterwarnings("ignore")

# One record layout shared by the console output and the log file.
log_fmt = Formatter(
    "%(asctime)s %(name)s L%(lineno)d [%(levelname)s][%(funcName)s] %(message)s "
)
logger = getLogger(__name__)
logger.setLevel("INFO")
# Console handler first, then a file handler that truncates ./result.log
# ("w" mode) on every start; both get the same level and formatter.
for handler in (StreamHandler(), FileHandler("./result.log", "w")):
    handler.setLevel("INFO")
    handler.setFormatter(log_fmt)
    logger.addHandler(handler)

# Dataset root, resolved relative to the directory containing this file.
DATA_DIR = os.path.join(os.path.dirname(__file__), "./data/")


def get_datasize(dic: Dict) -> Tuple[int, int, int]:
    """Extract the sample counts from a dataset configuration and log them.

    Args:
        dic: Configuration mapping. The keys "n_clases=", "n_entradas=" and
            "n_patrons1=" are required; "n_patrons2=" is optional. Every
            value is converted with ``int``.

    Returns:
        ``(n_tot, n_train_val, n_test)`` where ``n_train_val`` comes from
        "n_patrons1=", ``n_test`` comes from "n_patrons2=" (0 when that key
        is absent) and ``n_tot`` is their sum.

    Raises:
        KeyError: If one of the required keys is missing.
    """
    n_classes = int(dic["n_clases="])
    n_features = int(dic["n_entradas="])
    n_train_val = int(dic["n_patrons1="])
    # A second pattern count is only present for some datasets.
    n_test = int(dic["n_patrons2="]) if "n_patrons2=" in dic else 0
    n_tot = n_train_val + n_test
    logger.info(f"Datasize: \tN: {n_tot}\td: {n_features}\tc: {n_classes}")
    return n_tot, n_train_val, n_test


def load_data(dic: Dict) -> Tuple[np.ndarray, np.ndarray]:
    """Load the feature matrix and label vector of one dataset.

    The file ``DATA_DIR/<dic["dataset"]>/<dic["fich1="]>`` is read as
    whitespace-separated text. The first line is skipped (presumably a
    header). For every remaining non-blank line the first token is dropped
    (presumably a row index), the last token is parsed as the integer label
    and the tokens in between are parsed as float features.

    Args:
        dic: Configuration mapping providing the "dataset" and "fich1=" keys.

    Returns:
        ``(X, y)``: the features as a float array and the labels as an
        integer array, one entry per data line.

    Raises:
        KeyError: If "dataset" or "fich1=" is missing from ``dic``.
        OSError: If the data file cannot be opened.
        ValueError: If a token cannot be parsed as a number.
    """
    path = os.path.join(DATA_DIR, dic["dataset"], dic["fich1="])
    # Context manager so the handle is closed even if parsing fails.
    with open(path, "r") as fh:
        next(fh, None)  # skip the first line
        rows = [line.split() for line in fh]
    # Blank lines (e.g. a trailing newline at EOF) carry no sample.
    rows = [tokens for tokens in rows if tokens]

    X = np.asarray([[float(tok) for tok in tokens[1:-1]] for tokens in rows])
    y = np.asarray([int(tokens[-1]) for tokens in rows])
    return X, y


def train(
    args,
    K: np.array,
    X: np.array,
    y: np.array,
    train_fold: List[int],
    eval_fold: List[int],
    depth: int,
    dic: Dict,
) -> float:
    """Fit the model chosen by ``args.model`` on one split and score it.

    Dispatch:
        * ``"rf"``: ``models.random_forest_regressor`` on the raw features.
        * ``"rbf"`` / ``"laplacian"``: ``models.kernel_ridge_regression`` on
          the raw features, with ``alpha=args.reg_coef``,
          ``gamma=args.alpha`` and ``kernel=args.model``.
        * anything else: ``models.precomputed_kernel_ridge_regression`` on
          sub-blocks of the precomputed kernel ``K[depth]``.

    Args:
        args: Parsed options; ``model``, ``reg_coef`` and ``alpha`` are read.
        K: Precomputed kernels indexed by depth (only used on the NTK path).
        X: Feature matrix (only used on the non-NTK paths).
        y: Label vector.
        train_fold: Sample indices used for fitting.
        eval_fold: Sample indices used for evaluation.
        depth: Index into ``K`` (only used on the NTK path).
        dic: Dataset configuration; "n_clases=" gives the class count.

    Returns:
        Whatever the selected ``models`` helper returns; callers treat it as
        an accuracy score.
    """
    model_name = args.model

    if model_name == "rf":
        return models.random_forest_regressor(
            X[train_fold],
            X[eval_fold],
            y[train_fold],
            y[eval_fold],
            classes=int(dic["n_clases="]),
        )

    if model_name in ("rbf", "laplacian"):
        return models.kernel_ridge_regression(
            X[train_fold],
            X[eval_fold],
            y[train_fold],
            y[eval_fold],
            alpha=args.reg_coef,  # ridge regularisation strength
            gamma=args.alpha,
            classes=int(dic["n_clases="]),
            kernel=model_name,
        )

    # NTK path: slice the Gram matrix for the requested depth.
    kernel_at_depth = K[depth]
    gram_train = kernel_at_depth[train_fold][:, train_fold]  # train x train
    gram_eval = kernel_at_depth[eval_fold][:, train_fold]  # eval x train
    return models.precomputed_kernel_ridge_regression(
        gram_train,
        gram_eval,
        y[train_fold],
        y[eval_fold],
        alpha=args.reg_coef,  # ridge regularisation strength
        classes=int(dic["n_clases="]),
    )


def search(
    args: argparse.Namespace,
    K: np.array,
    X: np.array,
    y: np.array,
    dic: Dict,
    n_tot: int,
) -> None:
    """Evaluate on the fixed train/validation split and print each score.

    The split is read from ``DATA_DIR/<dataset>/conxuntos.dat``: every line
    is a whitespace-separated list of integer sample indices, the first line
    being the training fold and the second the validation fold.

    For the NTK models ("tree_ntk", "mlp_relu_ntk") one score is printed per
    depth index in ``range(max(args.max_depth, 1))``; for any other model a
    single score is printed.

    Args:
        args: Parsed options; ``max_depth`` and ``model`` are read here.
        K: Precomputed kernels, forwarded to ``train``.
        X: Feature matrix, forwarded to ``train``.
        y: Label vector, forwarded to ``train``.
        dic: Dataset configuration; "dataset" names the data sub-directory.
        n_tot: Total sample count (unused; kept for interface parity with
            ``cv``).
    """
    fold_path = os.path.join(DATA_DIR, dic["dataset"], "conxuntos.dat")
    # Context manager so the file handle is released deterministically.
    with open(fold_path, "r") as fh:
        fold = [[int(tok) for tok in line.split()] for line in fh]
    train_fold, valid_fold = fold[0], fold[1]

    for depth in range(max(args.max_depth, 1)):
        logger.info("Start training...")
        acc = train(args, K, X, y, train_fold, valid_fold, depth, dic)
        print(acc)
        # Only the NTK models have a depth axis worth sweeping.
        if args.model not in ("tree_ntk", "mlp_relu_ntk"):
            break


def cv(
    args: argparse.Namespace,
    K: np.array,
    X: np.array,
    y: np.array,
    dic: Dict,
    n_tot: int,
) -> None:
    """Run the 4-fold evaluation and print the mean score.

    The folds are read from ``DATA_DIR/<dataset>/conxuntos_kfold.dat``:
    every line is a whitespace-separated list of integer sample indices, and
    lines ``2 * r`` / ``2 * r + 1`` hold the train / test indices of repeat
    ``r`` for ``r`` in 0..3.

    For the NTK models ("tree_ntk", "mlp_relu_ntk") one averaged score is
    printed per depth index in ``range(max(args.max_depth, 1))``; for any
    other model a single averaged score is printed.

    Args:
        args: Parsed options; ``max_depth`` and ``model`` are read here.
        K: Precomputed kernels, forwarded to ``train``.
        X: Feature matrix, forwarded to ``train``.
        y: Label vector, forwarded to ``train``.
        dic: Dataset configuration; "dataset" names the data sub-directory.
        n_tot: Total sample count (unused).
    """
    fold_path = os.path.join(DATA_DIR, dic["dataset"], "conxuntos_kfold.dat")
    # Context manager so the file handle is released deterministically.
    with open(fold_path, "r") as fh:
        fold = [[int(tok) for tok in line.split()] for line in fh]

    for depth in range(max(args.max_depth, 1)):
        logger.info("Start training...")
        avg_acc = 0
        for repeat in tqdm(range(4), leave=False, desc="CV-loop..."):
            train_fold, test_fold = fold[repeat * 2], fold[repeat * 2 + 1]
            acc = train(args, K, X, y, train_fold, test_fold, depth, dic)
            avg_acc += 0.25 * acc  # equal weight for each of the 4 repeats
        print(avg_acc)
        # Only the NTK models have a depth axis worth sweeping.
        if args.model not in ("tree_ntk", "mlp_relu_ntk"):
            break


def main(args: argparse.Namespace):
    """Run the experiment on every eligible dataset under ``DATA_DIR``.

    For each entry of ``DATA_DIR`` (except the ``.gitkeep`` placeholder):

    1. Skip it when ``<dataset>/<dataset>.txt`` does not exist.
    2. Parse that file as ``key value`` pairs into the configuration dict.
    3. Skip it when it has more than ``args.max_tot`` samples or ships a
       separate test set ("n_patrons2=" > 0).
    4. Load the data, compute the kernel required by ``args.model`` and run
       either ``search`` (``args.search`` set) or ``cv``.

    Args:
        args: Parsed options; ``max_tot``, ``model``, ``max_depth``,
            ``alpha`` and ``search`` are read here.

    Raises:
        ValueError: If a non-blank configuration line does not consist of
            exactly two whitespace-separated tokens.
    """
    # Drop the .gitkeep placeholder by name: slicing off the first sorted
    # entry would silently discard a real dataset whenever .gitkeep is
    # missing or another entry sorts before it.
    datasets = [
        name for name in sorted(os.listdir(DATA_DIR)) if name != ".gitkeep"
    ]
    logger.info(f"Dataset: {datasets}")

    for idx, dataset in enumerate(datasets):
        config_path = os.path.join(DATA_DIR, dataset, f"{dataset}.txt")
        if not os.path.isfile(config_path):
            logger.info(f"{dataset} is skipped because of the absence of txt file")
            continue

        logger.info(f"-----{idx}, {dataset}-----")

        # load configuration
        dic = {"dataset": dataset}
        with open(config_path, "r") as fh:
            for line in fh:
                if not line.strip():
                    continue  # tolerate blank lines
                k, v = line.split()
                dic[k] = v

        # Check skip or not
        n_tot, _, n_test = get_datasize(dic)
        if (n_tot > args.max_tot) or (n_test > 0):
            logger.info("skipped because of the dataset setting")
            continue

        # load dataset
        X, y = load_data(dic)

        # load NTK
        logger.info("Extracting NTK...")
        if args.model == "tree_ntk":
            K, _, _ = ntk.tree(X, args.max_depth, args.alpha)
        elif args.model == "mlp_relu_ntk":
            K = ntk.mlp_relu(X, args.max_depth)
        else:  # do not use kernel
            K = np.zeros((2, 2))  # placeholder, never indexed by train()

        if args.search:
            search(args, K, X, y, dic, n_tot)  # decide best parameter
        else:
            cv(args, K, X, y, dic, n_tot)  # cv loop

        logger.info("done")


if __name__ == "__main__":
    # Command-line entry point: parse the options, log them, run main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-max_tot", default=5000, type=int, help="Maximum number of data samples"
    )
    parser.add_argument("-alpha", default=1.0, type=float, help="Scaling factor")
    parser.add_argument("-max_depth", default=5, type=int, help="Maximum depth")
    parser.add_argument(
        "-reg_coef",
        default=1e-8,
        type=float,
        help="regularization factor used in kernel regression",
    )
    parser.add_argument(
        "-model",
        default="tree_ntk",
        type=str,
        # "laplacian" and "rf" are handled by train(); without listing them
        # here those branches could never be selected from the command line.
        choices=[
            "tree_ntk",
            "mlp_relu_ntk",
            "rbf",
            "laplacian",
            "rf",
        ],
        help="Model to evaluate",
    )
    parser.add_argument(
        "-search",
        action="store_true",
        help="Evaluate on the fixed train/validation split instead of the CV folds",
    )

    args = parser.parse_args()
    logger.info(args)
    main(args)
