import torch
from model.model import GraphSAGE, GCN, GIN, SGC, GAT

# Datasets on which models are constructed with ``is_batch=True``.
_BATCHED_DATASETS = frozenset({'ogbn-arxiv', 'flickr', 'reddit'})


def _get_model_class(architecture):
    """Return the model class registered under ``architecture``.

    Raises:
        KeyError: if ``architecture`` is not a supported name; the message
            lists the valid choices.
    """
    # Built on each call (not at import time) so importing this module never
    # depends on the model classes beyond the import statement itself.
    model_classes = {
        "graphsage": GraphSAGE,
        "gcn": GCN,
        "gin": GIN,
        "sgc": SGC,
        "gat": GAT,
    }
    try:
        return model_classes[architecture]
    except KeyError:
        raise KeyError(
            f"Unknown architecture {architecture!r}; "
            f"expected one of {sorted(model_classes)}"
        ) from None


def _init_model(data, params, device, args, version):
    """Build a freshly initialized model of type ``args.architecture``.

    Shared by ``init_base_model`` and ``init_processed_model``, which differ
    only in the ``version`` passed to the model constructor.

    Args:
        data: dataset object; ``data.num_features`` and ``data.num_classes``
            set the model's input and output dimensions.
        params: mapping providing 'num_layers', 'hidden_dim',
            'dropout_ratio', 'activation' and 'norm_type'.
        device: target passed to ``.to()``.
        args: namespace providing ``architecture`` (selects the model class)
            and ``dataset`` (decides ``is_batch``).
        version: forwarded unchanged as the model's ``version`` argument.

    Returns:
        The model, moved to ``device``, after ``reset_parameters()``.

    Raises:
        KeyError: if ``args.architecture`` is not supported.
    """
    model_class = _get_model_class(args.architecture)
    model = model_class(
        version=version,
        num_layers=params['num_layers'],
        input_dim=data.num_features,
        hidden_dim=params['hidden_dim'],
        output_dim=data.num_classes,
        dropout_ratio=params['dropout_ratio'],
        activation=params['activation'],
        norm_type=params['norm_type'],
        is_batch=args.dataset in _BATCHED_DATASETS,
    ).to(device)
    model.reset_parameters()
    return model


def init_base_model(data, params, device, args):
    """Build a fresh model with ``version="base"`` (see ``_init_model``)."""
    return _init_model(data, params, device, args, version="base")


def init_processed_model(data, params, device, args):
    """Build a fresh model whose ``version`` is ``args.pp_method`` (see ``_init_model``)."""
    return _init_model(data, params, device, args, version=args.pp_method)

def clone_base_model(data, model_path, params, device, args):
    """Rebuild a "base" model and restore its weights from ``model_path``.

    The model is constructed with ``init_base_model`` and then populated via
    ``load_model``, which also switches it to eval mode.

    Returns:
        The restored model.
    """
    restored = init_base_model(data, params, device, args)
    load_model(restored, model_path)
    return restored

def clone_processed_model(data, model_path, params, device, args):
    """Rebuild a processed (``args.pp_method``) model and restore its weights.

    The model is constructed with ``init_processed_model`` and then populated
    from ``model_path`` via ``load_model``, which also switches it to eval mode.

    Returns:
        The restored model.
    """
    restored = init_processed_model(data, params, device, args)
    load_model(restored, model_path)
    return restored

def save_model(model, path):
    """Write ``model.state_dict()`` to ``path`` using ``torch.save``.

    Only the state dict is stored, not the module object, so restoring it
    requires building the model first and then calling ``load_model``.
    """
    state = model.state_dict()
    torch.save(state, path)

def load_model(model, path, map_location="cpu"):
    """Load a state dict from ``path`` into ``model`` and switch it to eval mode.

    Args:
        model: module whose parameters/buffers receive the checkpoint values.
        path: checkpoint file, typically one written by ``save_model``.
        map_location: where ``torch.load`` places the deserialized tensors
            before they are copied into ``model``. Defaults to ``"cpu"``.
            Without it, a checkpoint saved on a GPU cannot be restored on a
            CPU-only host, and loading onto another GPU would first allocate
            on the original device. ``load_state_dict`` copies the values into
            the model's existing tensors on whatever device those live on.

    Returns:
        ``model`` (modified in place), for convenient chaining.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted
    # sources.
    state_dict = torch.load(path, map_location=map_location)
    model.load_state_dict(state_dict)
    model.eval()
    return model