import torch
import torch.nn.functional as F
import torch.nn as nn

from torch_geometric.data import Dataset
from torch_geometric.datasets import (
    Planetoid,
    WebKB,
    WikipediaNetwork,
    Actor,
)
from datasets.ood_dataset import load_dataset

from libs.reporter import WandBReporter

from datasets.heterophilous_graph_dataset import HeterophilousGraphDataset
from datasets.data_utils import eval_acc, eval_rocauc, rand_splits

from models.model_baseline import (
    GAT,
    GCN,
    MSP,
    Mahalanobis, 
    OE,
    ODIN,
)

from models.grand import GRAND 
from models.gnnsafe import GNNSafe
# from models.gpn.models import GPN, EnergyModel, EnergyProp, MaxLogits, SGCN
# from models.gpn.gdk import GDK
from models.gspde import GSPDE
from models.gnsd import GNSD
from datasets.data_utils import eval_acc, eval_rocauc, rand_splits
from libs.logger import Logger_classify, Logger_detect, Logger_ood


# NOTE: the GPN-based models (models.gpn) are disabled because of a problem with their pyblaze dependency.

def make_optimizer(cfg, model):
    """Create the optimizer for ``model`` from an optimizer config dict.

    ``cfg["name"] == "adamax"`` selects ``torch.optim.Adamax``; any other name
    falls back to ``torch.optim.Adam``. Both use ``cfg["lr"]`` and
    ``cfg["weight_decay"]``.

    Side effect: ``cfg`` is updated in place with gradient-clipping defaults
    (``grad_clip=False``, ``clip_value=1.0``) for keys that are not already set.

    NOTE: the GPN (warmup + standard optimizer) and SGCN (teacher optimizer)
    special cases are disabled together with the GPN model imports.

    Returns:
        The constructed optimizer.
    """
    optimizer_cls = torch.optim.Adamax if cfg["name"] == "adamax" else torch.optim.Adam
    optimizer = optimizer_cls(
        model.parameters(),
        lr=cfg["lr"],
        weight_decay=cfg["weight_decay"],
    )

    # Record gradient-clipping settings in the config so later code can read them.
    for key, default in (("grad_clip", False), ("clip_value", 1.0)):
        if key not in cfg:
            cfg[key] = default

    return optimizer


def _labels_as_column(data):
    """Reshape ``data.y`` in place from 1-D to a single column; leave other shapes alone."""
    if len(data.y.shape) == 1:
        data.y = data.y.unsqueeze(1)


def make_ood_dataset(cfg):
    """Load the in-distribution and OOD datasets and normalize their labels.

    ``load_dataset(cfg)`` returns the in-distribution dataset, the OOD training
    dataset and the OOD test dataset(s); the latter may be a single object or a
    list. Every 1-D ``y`` is turned into a column via ``unsqueeze(1)``.

    For cora/citeseer/pubmed/arxiv and the heterophilous datasets, whatever
    splits the loader produced are kept, unless ``cfg["split_mode"] == "year"``.
    All other datasets (and "year" mode) get random splits from ``rand_splits``
    with ``cfg["train_prop"]`` / ``cfg["valid_prop"]``.

    Returns:
        (dataset_ind, dataset_ood_tr, dataset_ood_te)
    """
    dataset_ind, dataset_ood_tr, dataset_ood_te = load_dataset(cfg)

    ood_test_parts = dataset_ood_te if isinstance(dataset_ood_te, list) else [dataset_ood_te]
    for data in (dataset_ind, dataset_ood_tr, *ood_test_parts):
        _labels_as_column(data)

    ### get splits for all runs ###
    keep_loader_splits = (
        cfg["name"] in ('cora', 'citeseer', 'pubmed', 'arxiv', 'roman_empire',
                        'amazon_ratings', 'minesweeper', 'tolokers', 'questions')
        and cfg["split_mode"] != "year"
    )
    if not keep_loader_splits:
        dataset_ind.splits = rand_splits(
            dataset_ind.node_idx,
            train_prop=cfg["train_prop"],
            valid_prop=cfg["valid_prop"],
        )
    return dataset_ind, dataset_ood_tr, dataset_ood_te


def make_dataset(cfg):
    """Build a torch_geometric dataset selected by ``cfg["name"]`` (case-insensitive).

    ``cfg`` is forwarded as keyword arguments to the dataset constructor. If it
    has no ``"root"`` key, ``"/tmp/<name>"`` is written into ``cfg`` in place.

    Supported names:
        cora, citeseer, pubmed         -> Planetoid (split="random")
        chameleon, squirrel            -> WikipediaNetwork
        texas, cornell, wisconsin      -> WebKB
        actor                          -> Actor (``name`` is not forwarded)
        roman_empire, amazon_ratings, minesweeper, tolokers, questions
                                       -> HeterophilousGraphDataset

    Raises:
        ValueError: if ``cfg["name"]`` is not one of the supported datasets.
    """
    if "root" not in cfg:
        cfg["root"] = f"/tmp/{cfg['name']}"
    name = cfg["name"].lower()

    if name in {"cora", "citeseer", "pubmed"}:
        return Planetoid(**cfg, split="random")
    if name in {"chameleon", "squirrel"}:
        return WikipediaNetwork(**cfg)
    if name in {"texas", "cornell", "wisconsin"}:
        # NOTE(review): a previous branch tried to attach
        # data/sdrl/<name>_sdrl.pt edges before any dataset existed, which
        # always raised UnboundLocalError. If SDRL edges are needed, load them
        # after constructing the WebKB dataset here.
        return WebKB(**cfg)
    if name == "actor":
        # Actor is built without a `name` kwarg; build the kwargs without
        # mutating cfg so the caller's config is intact even if Actor raises.
        return Actor(**{key: value for key, value in cfg.items() if key != "name"})
    if name in {
        "roman_empire",
        "amazon_ratings",
        "minesweeper",
        "tolokers",
        "questions",
    }:
        return HeterophilousGraphDataset(**cfg)
    raise ValueError(f"Dataset {cfg['name']} not found")

def make_model(cfg, dataset):
    """Instantiate the model named by ``cfg["name"]`` (matched case-insensitively).

    Args:
        cfg: model config dict. ``cfg["name"]`` selects the model; GAT receives
            the dict as keyword arguments, GRAND receives it positionally
            together with ``dataset`` and ``cfg["device"]``, and the remaining
            models receive it as their third positional argument.
        dataset: graph data object with ``x`` (features along dim 1) and ``y``.
            ``y.shape[1]`` is read, so ``y`` is assumed to be 2-D (as produced
            by ``make_ood_dataset``).

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``cfg["name"]`` does not match a known model.
    """
    # ood detection
    # Class count: largest label + 1, or the label width for multi-column
    # targets, whichever is larger.
    num_classes = max(dataset.y.max().item() + 1, dataset.y.shape[1])
    num_features = dataset.x.shape[1]
    name = cfg["name"].lower()

    if name == "gcn":
        return GCN(num_features, num_classes)
    if name == "gat":
        return GAT(num_features, num_classes, **cfg)
    if name == "grand":
        return GRAND(cfg, dataset, cfg["device"])
    if name == "gspde":
        return GSPDE(num_features, num_classes, cfg)
    if name == "gnsd":
        return GNSD(num_features, num_classes, cfg)
    if name == "msp":
        return MSP(num_features, num_classes, cfg)
    if name == "mahalanobis":
        return Mahalanobis(num_features, num_classes, cfg)
    if name == "gnnsafe":
        return GNNSafe(num_features, num_classes, cfg)
    if name == "oe":
        return OE(num_features, num_classes, cfg)
    if name == "odin":
        return ODIN(num_features, num_classes, cfg)
    # NOTE: the GPN-family models (gpn, gdk, energy_model, energy_prop,
    # max_logits, sgcn) are disabled together with their imports. When
    # re-enabled, sgcn must return a (teacher, model) pair after calling
    # model.create_storage(dataset, teacher, cfg["device"]).
    raise ValueError(f"Model {cfg['name']} not found")


def make_reporter(cfg, exp_cfg=None):
    """Create the experiment reporter selected by ``cfg["logger_name"]``.

    ``"wandb"`` builds a ``WandBReporter(cfg, exp_cfg)``. ``"ood_classify"``,
    ``"ood_detect"`` and ``"ood"`` build the matching logger class with
    ``(cfg["runs"], cfg)``; ``exp_cfg`` is ignored for those.

    Raises:
        ValueError: for any other ``logger_name``.
    """
    logger_name = cfg["logger_name"]
    if logger_name == "wandb":
        return WandBReporter(cfg, exp_cfg)

    if logger_name not in ("ood_classify", "ood_detect", "ood"):
        raise ValueError(f"Reporter {cfg['logger_name']} not found")

    logger_cls = {
        "ood_classify": Logger_classify,
        "ood_detect": Logger_detect,
        "ood": Logger_ood,
    }[logger_name]
    return logger_cls(cfg["runs"], cfg)



def make_lossfn(cfg):
    """Return the training criterion and evaluation metric for an experiment.

    The choice depends only on ``cfg["dataset"]`` (and, optionally,
    ``cfg["model"]``); OOD and plain classification runs use the same rules.

    Criterion:
        proteins, ppi      -> BCEWithLogitsLoss (multi-label binary classification)
        model == "grand"   -> CrossEntropyLoss
        otherwise          -> NLLLoss
    Metric:
        proteins, ppi, twitch -> eval_rocauc (binary classification)
        otherwise             -> eval_acc

    Returns:
        (criterion, eval_func)
    """
    dataset = cfg["dataset"]

    if dataset in ("proteins", "ppi"):  # multi-label binary classification
        criterion = nn.BCEWithLogitsLoss()
    elif cfg.get("model") == "grand":
        # Previously this assignment was always overwritten by NLLLoss below.
        # CrossEntropyLoss(x) == NLLLoss(log_softmax(x)), and log_softmax is
        # idempotent, so this matches NLLLoss if GRAND already emits
        # log-probabilities and is the correct loss if it emits raw logits.
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.NLLLoss()

    ### metric ###
    if dataset in ("proteins", "ppi", "twitch"):  # binary classification
        eval_func = eval_rocauc
    else:
        eval_func = eval_acc
    return criterion, eval_func