from model.lrw import LearnableRandomWalk
from model.classifier import Classifier

import torch

class ModelZoo:
    """Factory that builds the model selected by ``args.model_name``.

    Supported names are ``"lrw"`` (``LearnableRandomWalk``) and
    ``"classifier"`` (``Classifier``). The constructor stores the shared
    construction arguments and logs the chosen model's hyper-parameters;
    ``model_init`` instantiates the model.
    """

    def __init__(self, logger, args, dataset, num_nodes, feat_dim, output_dim, label, test_idx, task_level):
        """Store model-construction settings and log the configuration.

        Args:
            logger: Object exposing ``info(str)``; used to log the model config.
            args: Namespace of run options. Attributes read by this class
                include ``gpu_id``, ``use_cuda``, ``edge_q``, ``node_q``,
                ``model_name``, ``data_name``, ``hidden_dim``,
                ``hidden_att_dim``, ``dropout``, ``backbone``, ``walk_cof``,
                ``walk_delta``, ``walk_time`` and ``walk_len``.
            dataset: Forwarded to ``Classifier`` as ``dataset``.
            num_nodes: Forwarded to the model as ``n``.
            feat_dim: Forwarded to the model as ``feat_dim``.
            output_dim: Forwarded to the model as ``output_dim``.
            label: Forwarded to ``LearnableRandomWalk`` as ``y``.
            test_idx: Stored as ``self.test_idx``; not read by any method of
                this class.
            task_level: ``"edge"`` selects ``args.edge_q`` for ``self.q``; any
                other value selects ``args.node_q``. Also forwarded to the model.
        """
        super().__init__()
        self.logger = logger
        self.args = args
        self.feat_dim = feat_dim
        self.output_dim = output_dim
        self.num_nodes = num_nodes
        self.task_level = task_level
        self.label = label
        # Use the requested GPU only when CUDA is enabled AND actually available.
        self.device = torch.device(
            f"cuda:{args.gpu_id}" if (args.use_cuda and torch.cuda.is_available()) else "cpu"
        )
        self.test_idx = test_idx
        self.q = self.args.edge_q if self.task_level == "edge" else self.args.node_q
        self.dataset = dataset
        self.log_model()

    def log_model(self):
        """Log the hyper-parameters of the selected model.

        Nothing is logged for an unrecognised ``args.model_name``; that case
        is reported by ``model_init`` instead.
        """
        if self.args.model_name == "lrw":
            self.logger.info(f"model: {self.args.model_name}, hidden_dim: {self.args.hidden_dim}, hidden_att_dim: {self.args.hidden_att_dim}, dropout: {self.args.dropout}")
        elif self.args.model_name == "classifier":
            self.logger.info(f"model: {self.args.model_name}, hidden_dim: {self.args.hidden_dim}, dropout: {self.args.dropout}, backbone: {self.args.backbone}")

    def model_init(self):
        """Instantiate and return the model named by ``args.model_name``.

        Returns:
            A ``LearnableRandomWalk`` for ``"lrw"`` or a ``Classifier`` for
            ``"classifier"``.

        Raises:
            NotImplementedError: If ``args.model_name`` is not supported.
        """
        if self.args.model_name == "lrw":
            model = LearnableRandomWalk(
                data_name=self.args.data_name,
                n=self.num_nodes,
                hidden_att_dim=self.args.hidden_att_dim,
                feat_dim=self.feat_dim,
                hidden_dim=self.args.hidden_dim,
                output_dim=self.output_dim,
                cof=self.args.walk_cof,
                delta=self.args.walk_delta,
                walk_time=self.args.walk_time,
                walk_len=self.args.walk_len,
                dropout=self.args.dropout,
                task_level=self.task_level,
                y=self.label,
            )
        elif self.args.model_name == "classifier":
            model = Classifier(
                data_name=self.args.data_name,
                dataset=self.dataset,
                args=self.args,
                n=self.num_nodes,
                feat_dim=self.feat_dim,
                hidden_dim=self.args.hidden_dim,
                output_dim=self.output_dim,
                dropout=self.args.dropout,
                walk_len=self.args.walk_len,
                task_level=self.task_level,
                device=self.device,
            )
        else:
            # Raise (not return) so an unsupported name fails here, rather than
            # handing the exception class to the caller as if it were a model.
            raise NotImplementedError(f"Unsupported model_name: {self.args.model_name!r}")

        return model
