import os

import hydra
import torch
import torch.nn as nn
import torch.optim as optim
from loguru import logger
from omegaconf import DictConfig

from dataset import get_dataloader, get_split_dataloader
from models.mlp import ReLUMLP, SigmoidMLP
from models.non_oblivious import (DecisionListEnsemble,
                                  DecisionListEnsembleNoBinary,
                                  SoftTreeEnsemble)
from models.oblivious import ObliviousTreeEnsemble
from utils import test, train, seed_everything

# Use float64 as the default floating-point dtype for tensors created from here on.
torch.set_default_dtype(torch.float64)

# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")


@hydra.main(config_path="./config/", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    """Train and save two models (run A and run B) for every model family.

    For each model family, two runs are performed: one with ``cfg.seed_a``
    saved under the name ``cfg.model_a``, and one with ``cfg.seed_b`` saved
    under ``cfg.model_b``. A run is skipped when its output file already
    exists, unless ``cfg.disable_cache`` is set.

    Args:
        cfg: Hydra configuration. Fields read here: ``seed_a``, ``seed_b``,
            ``model_a``, ``model_b``, ``depth``, ``alpha``, ``n_tree``,
            ``split``, ``disable_cache``, ``batch_size``, ``learning_rate``,
            ``epochs``, ``dataset.name``, ``dataset.input_dim`` and
            ``dataset.output_dim``.
    """
    logger.info(cfg)

    # Ordered dispatch table: iteration order is the training order.
    model_classes = {
        "non_oblivious": SoftTreeEnsemble,
        "oblivious": ObliviousTreeEnsemble,
        "decision_list": DecisionListEnsemble,
        "decision_list_no_binary": DecisionListEnsembleNoBinary,
        "relu_mlp": ReLUMLP,
        "sigmoid_mlp": SigmoidMLP,
    }
    # (seed, output filename) for run A and run B, in that order.
    runs = ((cfg.seed_a, cfg.model_a), (cfg.seed_b, cfg.model_b))

    for model_type, model_class in model_classes.items():
        if cfg.depth > 3 and model_type == "non_oblivious":
            logger.info(f"depth={cfg.depth}. {model_type} is skipped.")
            continue

        logger.info(f"~~~~~Model: {model_type}~~~~~")

        for run_index, (seed, output_filename) in enumerate(runs):
            output_path = f"./output/model/{model_type}_{output_filename}"
            if os.path.exists(output_path) and not cfg.disable_cache:
                logger.info(f"Model already exists. Skip training ({output_path})")
                continue

            # Create the output directory before training so that a missing
            # directory cannot make torch.save fail after all epochs have run.
            os.makedirs(os.path.dirname(output_path), exist_ok=True)

            logger.info(
                f"======Training Start: Seed={seed}, Filename={output_filename}======"
            )
            # Seed before building the dataloaders and the model, as before.
            seed_everything(seed)

            if cfg.split:
                train_a_loader, train_b_loader, test_loader = get_split_dataloader(
                    cfg.dataset.name, cfg.batch_size
                )
                # Pick the half by run position rather than by comparing seed
                # values: with seed_a == seed_b a value comparison would give
                # both runs the first half.
                train_loader = (train_a_loader, train_b_loader)[run_index]
            else:
                train_loader, test_loader = get_dataloader(
                    cfg.dataset.name, cfg.batch_size
                )

            # Every model family is constructed with the same keyword arguments.
            model = model_class(
                input_dim=cfg.dataset.input_dim,
                output_dim=cfg.dataset.output_dim,
                depth=cfg.depth,
                alpha=cfg.alpha,
                n_tree=cfg.n_tree,
            ).to(device)
            optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate)

            for epoch in range(1, cfg.epochs + 1):
                logger.info(f"------Epoch={epoch}------")
                train(model, train_loader, optimizer, epoch)
                test(model, test_loader)

            torch.save(model.state_dict(), output_path)
            logger.info(f"Model Saved: {output_path}")


# Script entry point: main() is called without arguments because the
# @hydra.main decorator on main supplies the cfg argument.
if __name__ == "__main__":
    main()
