import dgl
import numpy as np
import pydantic
import random
import torch
import yaml

from typing import Optional

class MetaData(pydantic.BaseModel):
    """Schema of the ``meta_data`` section shared by every training config."""

    model_name: str  # model name string exactly as written in the YAML

class MetaGNN_X(pydantic.BaseModel):
    """Schema of the ``gnn_X`` section (hyperparameters of the X-side GNN).

    All fields are required.
    """

    # Hidden sizes; presumably for the timestep (t), X and Y embeddings
    # respectively -- NOTE(review): confirm against the model constructor.
    hidden_t: int
    hidden_X: int
    hidden_Y: int
    num_gnn_layers: int  # number of stacked GNN layers
    dropout: float       # dropout probability

class MetaGNN_E(pydantic.BaseModel):
    """Schema of the ``gnn_E`` section (hyperparameters of the E-side GNN).

    Same fields as ``MetaGNN_X`` plus ``hidden_E``; all are required.
    """

    # Hidden sizes; presumably for the timestep (t), X, Y and E embeddings
    # respectively -- NOTE(review): confirm against the model constructor.
    hidden_t: int
    hidden_X: int
    hidden_Y: int
    hidden_E: int
    num_gnn_layers: int  # number of stacked GNN layers
    dropout: float       # dropout probability

class MetaDiffusion(pydantic.BaseModel):
    """Schema of the ``diffusion`` section for the symmetric X/E and E-only configs."""

    T: int  # presumably the number of diffusion timesteps -- confirm

class MetaOptimizer(pydantic.BaseModel):
    """Schema of an optimizer section (``optimizer``, ``optimizer_X``, ``optimizer_E``).

    Only ``lr`` is required; ``weight_decay`` defaults to ``0.0`` and
    ``amsgrad`` to ``False``.
    """

    lr: float
    # NOTE(review): ``Optional`` also accepts an explicit ``null`` in the YAML,
    # which replaces the default with ``None`` -- confirm consumers handle it.
    weight_decay: Optional[float] = 0.0
    amsgrad: Optional[bool] = False

class MetaLRScheduler(pydantic.BaseModel):
    """Schema of the ``lr_scheduler`` section; all fields are required.

    NOTE(review): the field names match the arguments of
    ``torch.optim.lr_scheduler.ReduceLROnPlateau``. They are presumably
    forwarded there; confirm at the call site.
    """

    factor: float
    patience: int
    verbose: bool

class MetaTrain(pydantic.BaseModel):
    """Schema of the ``train`` section for the diffusion training configs.

    Every field is required except ``max_grad_norm``, which defaults to
    ``None`` (presumably meaning "no gradient clipping" -- confirm).
    """

    num_epochs: int
    val_every_epochs: int  # validation frequency, in epochs
    patient_epochs: int    # presumably the early-stopping patience -- confirm
    max_grad_norm: Optional[float] = None
    batch_size: int
    val_batch_size: int

class MetaYamlXE(pydantic.BaseModel):
    """Top-level training config with separate GNN/optimizer sections for X and E.

    ``load_train_yaml`` validates against this schema when the experiment
    name ends in ``"X_E"``.
    """

    meta_data: MetaData
    # Model sections.
    gnn_X: MetaGNN_X
    gnn_E: MetaGNN_E
    diffusion: MetaDiffusion
    # Optimization sections.
    optimizer_X: MetaOptimizer
    optimizer_E: MetaOptimizer
    lr_scheduler: MetaLRScheduler
    train: MetaTrain

class MetaYamlE(pydantic.BaseModel):
    """Top-level training config with only the E-side GNN/optimizer sections.

    ``load_train_yaml`` validates against this schema when the experiment
    name ends in ``"E"`` (but not ``"X_E"``).
    """

    meta_data: MetaData
    # Model sections.
    gnn_E: MetaGNN_E
    diffusion: MetaDiffusion
    # Optimization sections.
    optimizer_E: MetaOptimizer
    lr_scheduler: MetaLRScheduler
    train: MetaTrain

class MetaMLP_X(pydantic.BaseModel):
    """Schema of the ``mlp_X`` section (X-side MLP used by the asymmetric config).

    All fields are required.
    """

    # Hidden sizes; presumably for the timestep (t), X and Y embeddings
    # respectively -- NOTE(review): confirm against the model constructor.
    hidden_t: int
    hidden_X: int
    hidden_Y: int
    num_mlp_layers: int  # number of stacked MLP layers
    dropout: float       # dropout probability

class MetaDiffusionXEAsym(pydantic.BaseModel):
    """Schema of the ``diffusion`` section for the asymmetric X/E config.

    Unlike ``MetaDiffusion`` it carries one ``T`` value per side;
    presumably separate timestep counts for X and E -- confirm.
    """

    T_X: int
    T_E: int

class MetaYamlXEAsym(pydantic.BaseModel):
    """Top-level training config using an MLP for X, a GNN for E, and per-side ``T``.

    ``load_train_yaml`` validates against this schema when the experiment
    name ends in ``"X_E_asym"``.
    """

    meta_data: MetaData
    # Model sections.
    mlp_X: MetaMLP_X
    gnn_E: MetaGNN_E
    diffusion: MetaDiffusionXEAsym
    # Optimization sections.
    optimizer_X: MetaOptimizer
    optimizer_E: MetaOptimizer
    lr_scheduler: MetaLRScheduler
    train: MetaTrain

def load_train_yaml(exp_name):
    """Load and validate the training config ``configs/train/{exp_name}.yaml``.

    The schema is selected from the suffix of ``exp_name``:

    * ``"X_E_asym"``                      -> ``MetaYamlXEAsym``
    * ``"X_E"``                           -> ``MetaYamlXE``
    * any other name ending in ``"E"``    -> ``MetaYamlE``

    Parameters
    ----------
    exp_name : str
        Experiment name; also the file stem under ``configs/train``.

    Returns
    -------
    dict
        The validated config as a nested plain dict.

    Raises
    ------
    ValueError
        If ``exp_name`` ends in none of the supported suffixes. Previously
        this surfaced as an ``UnboundLocalError`` after reading the file.
    """
    # Test the most specific suffix first: "X_E" also ends with "E".
    if exp_name.endswith("X_E_asym"):
        schema = MetaYamlXEAsym
    elif exp_name.endswith("X_E"):
        schema = MetaYamlXE
    elif exp_name.endswith("E"):
        schema = MetaYamlE
    else:
        raise ValueError(
            f"Unsupported experiment name {exp_name!r}: expected it to end "
            "with 'X_E_asym', 'X_E' or 'E'."
        )

    with open(f"configs/train/{exp_name}.yaml", encoding="utf-8") as f:
        yaml_data = yaml.safe_load(f)

    return schema(**yaml_data).dict()

class MetaModelMFAE(pydantic.BaseModel):
    """Schema of the ``model`` section for the ``feat_mf`` / ``gae`` baselines.

    All fields are required.
    """

    hidden_size: int
    dropout: float   # dropout probability
    num_layers: int

class MetaTrainMFAE(pydantic.BaseModel):
    """Schema of the ``train`` section for the MF / (V)GAE baseline configs."""

    num_epochs: int
    patient_epochs: int  # presumably the early-stopping patience -- confirm

class MetaYamlMFAE(pydantic.BaseModel):
    """Top-level config schema for the ``feat_mf`` and ``gae`` baselines.

    ``load_train_mf_ae_yaml`` validates against this schema for those two
    model names.
    """

    meta_data: MetaData
    model: MetaModelMFAE
    optimizer: MetaOptimizer
    train: MetaTrainMFAE

class MetaModelVGAE(pydantic.BaseModel):
    """Schema of the ``model`` section for the ``vgae`` baseline.

    Same as ``MetaModelMFAE`` but without ``num_layers``.
    """

    hidden_size: int
    dropout: float  # dropout probability

class MetaYamlVGAE(pydantic.BaseModel):
    """Top-level config schema for the ``vgae`` baseline.

    ``load_train_mf_ae_yaml`` validates against this schema when
    ``model_name == "vgae"``.
    """

    meta_data: MetaData
    model: MetaModelVGAE
    optimizer: MetaOptimizer
    train: MetaTrainMFAE

def load_train_mf_ae_yaml(exp_name, model_name):
    """Load and validate a baseline training config ``configs/train/{exp_name}.yaml``.

    Parameters
    ----------
    exp_name : str
        File stem under ``configs/train``.
    model_name : str
        Baseline model: ``"feat_mf"`` or ``"gae"`` (validated with
        ``MetaYamlMFAE``), or ``"vgae"`` (validated with ``MetaYamlVGAE``).

    Returns
    -------
    dict
        The validated config as a nested plain dict.

    Raises
    ------
    ValueError
        If ``model_name`` is not one of the supported baselines. Previously
        this silently returned ``None``, pushing the failure to the caller.
    """
    if model_name in ("feat_mf", "gae"):
        schema = MetaYamlMFAE
    elif model_name == "vgae":
        schema = MetaYamlVGAE
    else:
        raise ValueError(
            f"Unsupported model_name {model_name!r}: expected one of "
            "'feat_mf', 'gae', 'vgae'."
        )

    with open(f"configs/train/{exp_name}.yaml", encoding="utf-8") as f:
        yaml_data = yaml.safe_load(f)

    return schema(**yaml_data).dict()

class MetaSampleYaml(pydantic.BaseModel):
    """Schema of a sampling config file under ``configs/sample``."""

    num_samples: int  # required

def load_sample_yaml(data_name):
    """Read ``configs/sample/{data_name}.yaml`` and validate it with ``MetaSampleYaml``.

    Parameters
    ----------
    data_name : str
        File stem under ``configs/sample``.

    Returns
    -------
    dict
        The validated sampling config as a plain dict.
    """
    config_path = f"configs/sample/{data_name}.yaml"
    with open(config_path) as stream:
        raw_config = yaml.load(stream, Loader=yaml.loader.SafeLoader)
    return MetaSampleYaml(**raw_config).dict()

def set_seed(seed=0):
    """Seed the NumPy, Python, PyTorch and DGL RNGs and pin cuDNN behaviour.

    Calls each seeding function with the same ``seed``. It then sets
    ``torch.backends.cudnn.deterministic = True`` and
    ``torch.backends.cudnn.benchmark = False`` before seeding DGL.

    Parameters
    ----------
    seed : int, optional
        Seed passed to every generator (default ``0``).
    """
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,      # current CUDA device
        torch.cuda.manual_seed_all,  # all visible CUDA devices
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    dgl.seed(seed)
