from enum import Enum
from pathlib import Path
from typing import Optional, List, Union
from typing_extensions import Literal
from jinja2 import Template, ext
from pydantic import BaseModel, Field
import copy
import yaml

import typer
from ...utils.factory import PipelineFactory, NodeModelFactory, PipelineBase, DataFactory
from ...utils.base_model import extract_name, EarlyStopConfig, DeviceEnum

from ...utils.yaml_dump import deep_convert_dict, merge_comment
import ruamel.yaml
from ruamel.yaml.comments import CommentedMap


# Pydantic model for the "sampler" section of the nodepred-ns pipeline config.
# Documented with comments (not a docstring) so the model's generated schema
# description is left untouched.
class SamplerConfig(BaseModel):
    # Sampler type tag; "neighbor" is the only value the Literal accepts.
    name: Literal["neighbor"]
    # Number of neighbors to sample for each GNN layer; the i-th entry is the
    # fan-out of the i-th layer (gen_script checks its length against the
    # model's num_layers when that key is present).
    fan_out: List[int] = [5, 10]
    # Batch size of seed nodes in the training stage.
    batch_size: int = Field(64, description="Batch size")
    # Number of data-loading workers in the training stage.
    num_workers: int = 4
    # Batch size of seed nodes in the evaluation stage.
    eval_batch_size: int = 1024
    # Number of data-loading workers in the evaluation stage.
    eval_num_workers: int = 4

    class Config:
        # Reject keys that are not declared above instead of ignoring them.
        extra = 'forbid'



# Human-readable comments attached to the "general_pipeline" section of the
# generated YAML configuration. The nesting mirrors the config layout: a
# string documents a scalar key, a nested dict documents a sub-section.
pipeline_comments = {
    "num_epochs": "Number of training epochs",
    "eval_period": "Interval epochs between evaluations",
    "early_stop": {
        "patience": "Steps before early stop",
        "checkpoint_path": "Early stop checkpoint model file path"
    },
    "sampler": {
        "fan_out": "List of neighbors to sample per edge type for each GNN layer, with the i-th element being the fanout for the i-th GNN layer. Length should be the same as num_layers in model setting",
        "batch_size": "Batch size of seed nodes in training stage",
        "num_workers": "Number of workers to accelerate the graph data processing step",
        # Previously read "in training stage in evaluation stage"; this key
        # only describes the evaluation batch size.
        "eval_batch_size": "Batch size of seed nodes in evaluation stage",
        "eval_num_workers": "Number of workers to accelerate the graph data processing step in evaluation stage"
    },
    "save_path": "Directory to save the experiment results",
    "num_runs": "Number of experiments to run",
}

# Pydantic model for the "general_pipeline" section of the nodepred-ns config.
# Documented with comments (not a docstring) so the model's generated schema
# description is left untouched.
class NodepredNSPipelineCfg(BaseModel):
    # Neighbor-sampler settings.
    # NOTE(review): the default here is the plain string "neighbor", not a
    # SamplerConfig instance. get_cfg_func always supplies
    # {"sampler": {"name": "neighbor"}}, so the default is presumably never
    # relied on -- confirm before constructing this model without a sampler.
    sampler: SamplerConfig = Field("neighbor")
    # Early-stopping settings; the Optional annotation also admits None.
    early_stop: Optional[EarlyStopConfig] = EarlyStopConfig()
    # Number of training epochs.
    num_epochs: int = 200
    # Interval, in epochs, between evaluations.
    eval_period: int = 5
    # Optimizer description: "name" selects the optimizer and is popped from
    # the config by gen_script; the remaining keys are presumably forwarded as
    # constructor keyword arguments by the script template -- TODO confirm.
    optimizer: dict = {"name": "Adam", "lr": 0.005, "weight_decay": 0.0}
    # Name of the loss to use (presumably resolved by the script template).
    loss: str = "CrossEntropyLoss"
    # Number of experiments to run.
    num_runs: int = 1
    # Directory to save the experiment results.
    save_path: str = "results"

@PipelineFactory.register("nodepred-ns")
class NodepredNsPipeline(PipelineBase):
    """Node classification pipeline that trains with neighbor sampling.

    Provides the CLI callback that writes a commented YAML configuration
    (``get_cfg_func``) and the renderer that turns a user configuration into
    a training script (``gen_script``).
    """

    def __init__(self):
        self.pipeline = {
            "name": "nodepred-ns",
            "mode": "train"
        }
        self.default_cfg = None

    @classmethod
    def setup_user_cfg_cls(cls):
        """Build the pydantic user-config model and store it on the class.

        Only models exposing ``forward_block`` and datasets registered for
        ``"nodepred-ns"`` are admitted; both sections are discriminated by
        their ``name`` key.
        """
        # Imported here rather than at module level (presumably to avoid a
        # circular import -- not verifiable from this file).
        from ...utils.enter_config import UserConfig

        class NodePredUserConfig(UserConfig):
            eval_device: DeviceEnum = Field("cpu")
            data: DataFactory.filter("nodepred-ns").get_pydantic_config() = Field(..., discriminator="name")
            model : NodeModelFactory.filter(lambda cls: hasattr(cls, "forward_block")).get_pydantic_model_config() = Field(..., discriminator="name")
            general_pipeline: NodepredNSPipelineCfg

        # NOTE: this rebinds the class attribute ``user_cfg_cls``, replacing
        # the property of the same name defined below. After this call,
        # ``self.user_cfg_cls`` resolves directly to NodePredUserConfig.
        cls.user_cfg_cls = NodePredUserConfig

    @property
    def user_cfg_cls(self):
        """Return the ``user_cfg_cls`` attribute stored on the class."""
        return self.__class__.user_cfg_cls

    def get_cfg_func(self):
        """Return the callback that generates a commented YAML config file.

        The returned ``config`` callable takes typer options (dataset name,
        optional output path, model name), builds a default configuration,
        attaches explanatory comments and writes it to disk.
        """
        # NOTE: ``config`` is intentionally left without a docstring. How it
        # is registered is not visible here, and CLI frameworks such as typer
        # may surface a command function's docstring as help text.
        def config(
            data: DataFactory.filter("nodepred-ns").get_dataset_enum() = typer.Option(..., help="input data name"),
            cfg: Optional[str] = typer.Option(
                None, help="output configuration path"),
            model: NodeModelFactory.filter(lambda cls: hasattr(cls, "forward_block")).get_model_enum() = typer.Option(..., help="Model name"),
        ):
            self.__class__.setup_user_cfg_cls()
            generated_cfg = {
                "pipeline_name": self.pipeline["name"],
                "pipeline_mode": self.pipeline["mode"],
                "device": "cpu",
                "data": {"name": data.name},
                "model": {"name": model.value},
                "general_pipeline": {"sampler":{"name": "neighbor"}}
            }
            # Round-trip through the pydantic model so every default is
            # filled in before dumping.
            output_cfg = self.user_cfg_cls(**generated_cfg).dict()
            output_cfg = deep_convert_dict(output_cfg)
            comment_dict = {
                "device": "Torch device name, e.g., cpu or cuda or cuda:0",
                "data": {
                    "split_ratio": 'Ratio to generate split masks, for example set to [0.8, 0.1, 0.1] for 80% train/10% val/10% test. Leave blank to use builtin split in original dataset'
                },
                "general_pipeline": pipeline_comments,
                "model": NodeModelFactory.get_constructor_doc_dict(model.value),
            }
            comment_dict = merge_comment(output_cfg, comment_dict)

            # truncate length fan_out to be the same as num_layers in model
            # NOTE(review): the template has five entries, so a model with
            # more than five layers still gets a five-element fan_out, which
            # gen_script's length check would then reject.
            if "num_layers" in comment_dict["model"]:
                comment_dict['general_pipeline']["sampler"]["fan_out"] = [5,10,15,15,15][:int(comment_dict['model']["num_layers"])]

            if cfg is None:
                cfg = "_".join(["nodepred-ns", data.value, model.value]) + ".yaml"
            # Named ``yaml_writer`` so the module-level ``yaml`` import is not
            # shadowed inside this function.
            yaml_writer = ruamel.yaml.YAML()
            # Use a context manager so the output file is flushed and closed
            # deterministically instead of leaking the handle.
            with Path(cfg).open("w") as cfg_file:
                yaml_writer.dump(comment_dict, cfg_file)
            print("Configuration file is generated at {}".format(
                Path(cfg).absolute()))

        return config

    @staticmethod
    def gen_script(user_cfg_dict):
        """Render the training script for a user configuration.

        Parameters
        ----------
        user_cfg_dict : dict
            Parsed user configuration containing at least the ``data``,
            ``model``, ``general_pipeline``, ``pipeline_name`` and
            ``pipeline_mode`` keys. It is not modified.

        Returns
        -------
        str
            Source code rendered from the ``nodepred-ns.jinja-py`` template
            located next to this file.

        Raises
        ------
        AssertionError
            If the model declares ``num_layers`` and it differs from the
            length of the sampler's ``fan_out``.
        """
        file_current_dir = Path(__file__).resolve().parent
        template_filename = file_current_dir / "nodepred-ns.jinja-py"
        with open(template_filename, "r") as f:
            template = Template(f.read())
        # Constructed for its validation side effect on the
        # ``general_pipeline`` section; the instance itself is not used.
        pipeline_cfg = NodepredNSPipelineCfg(
            **user_cfg_dict["general_pipeline"])

        if "num_layers" in user_cfg_dict["model"]:
            assert user_cfg_dict["model"]["num_layers"] == len(user_cfg_dict["general_pipeline"]["sampler"]["fan_out"]), \
                "The num_layers in model config should be the same as the length of fan_out in sampler. For example, if num_layers is 1, the fan_out cannot be [5, 10]"

        # ``render_cfg`` holds the template variables; work on a deep copy so
        # the caller's dict is left untouched.
        render_cfg = copy.deepcopy(user_cfg_dict)
        model_code = NodeModelFactory.get_source_code(
            user_cfg_dict["model"]["name"])
        render_cfg["model_code"] = model_code
        render_cfg["model_class_name"] = NodeModelFactory.get_model_class_name(
            user_cfg_dict["model"]["name"])
        render_cfg.update(DataFactory.get_generated_code_dict(
            user_cfg_dict["data"]["name"], '**cfg["data"]'))

        # ``generated_user_cfg`` becomes the ``cfg = {...}`` literal embedded
        # in the rendered script: names are hoisted to top-level keys and
        # pipeline bookkeeping keys are dropped.
        generated_user_cfg = copy.deepcopy(user_cfg_dict)

        if "split_ratio" in generated_user_cfg["data"]:
            generated_user_cfg["data"].pop("split_ratio")
        generated_user_cfg["data_name"] = generated_user_cfg["data"].pop("name")
        generated_user_cfg.pop("pipeline_name")
        generated_user_cfg.pop("pipeline_mode")
        generated_user_cfg["model_name"] = generated_user_cfg["model"].pop("name")
        generated_user_cfg['general_pipeline']["optimizer"].pop("name")

        # A non-null split_ratio is appended to the dataset construction code
        # rather than being carried in the embedded cfg.
        if user_cfg_dict["data"].get("split_ratio", None) is not None:
            render_cfg["data_initialize_code"] = "{}, split_ratio={}".format(render_cfg["data_initialize_code"], user_cfg_dict["data"]["split_ratio"])

        render_cfg["user_cfg_str"] = f"cfg = {str(generated_user_cfg)}"
        render_cfg["user_cfg"] = user_cfg_dict
        # The rendered source is returned to the caller. The earlier version
        # also opened "output.py" for writing here without writing to it,
        # which created or truncated that file in the working directory.
        return template.render(**render_cfg)

    @staticmethod
    def get_description() -> str:
        """Return the one-line description of this pipeline."""
        return "Node classification neighbor sampling pipeline for training"
