from model.base_model import ClassifierBaseModel
from model.predictor import ClassifierModel
import torch
from model.lrw import LearnableRandomWalk
import numpy as np

class Classifier(ClassifierBaseModel):
    """Classifier that pairs a pre-trained ``LearnableRandomWalk`` with a ``ClassifierModel``.

    On construction the random-walk operator is built, preprocessed on the
    dataset's adjacency and features, restored from a checkpoint on disk, and
    switched to eval mode. A ``ClassifierModel`` backbone is then created and
    given the graph's edge index.
    """

    def __init__(self, args, dataset, data_name, n, feat_dim, hidden_dim, output_dim, dropout, walk_len, task_level, device):
        """Build the walk operator, load its weights, and create the backbone model.

        Args:
            args: Run configuration. This constructor reads ``args.data_name``,
                ``args.walk_time``, ``args.walk_len`` and ``args.backbone``;
                the whole object is also forwarded to ``ClassifierModel``.
            dataset: Graph dataset exposing ``adj``, ``x``, ``y``,
                ``num_features`` and ``num_node_classes``.
            data_name: Dataset name; used to locate the checkpoint at
                ``./model_files/lrw/<data_name>/lrw.pth``.
            n: Forwarded to the base class and to ``ClassifierModel``.
            feat_dim: Forwarded to ``ClassifierModel``.
            hidden_dim: Forwarded to ``ClassifierModel``.
            output_dim: Forwarded to ``ClassifierModel``.
            dropout: Forwarded to ``ClassifierModel``.
            walk_len: Forwarded to the base class and to ``ClassifierModel``.
                Note that the walk operator uses ``args.walk_len`` instead.
            task_level: Stored on the instance and forwarded to both sub-models.
            device: Device the walk operator is moved to; also forwarded to the
                base class and to ``ClassifierModel``.
        """
        super(Classifier, self).__init__(n, data_name, walk_len, device)
        self.args = args
        self.dataset = dataset
        self.model_load_path = "./model_files/lrw/"+data_name+"/lrw.pth"
        self.task_level = task_level

        # Evaluate nonzero() once and reuse the result for both index arrays.
        # NOTE(review): presumably adj is a scipy sparse matrix / ndarray whose
        # nonzero() returns (rows, cols) -- confirm against the dataset loader.
        nonzero_index = self.dataset.adj.nonzero()
        row_index = nonzero_index[0]
        col_index = nonzero_index[1]

        # The walk operator is configured with fixed hyper-parameters here
        # rather than the constructor's hidden_dim / dropout arguments.
        self.walk_op = LearnableRandomWalk(
            data_name=self.args.data_name,
            n=self.dataset.y.size(0),
            hidden_att_dim=64,
            feat_dim=self.dataset.num_features,
            hidden_dim=256,
            output_dim=self.dataset.num_node_classes,
            cof=0.8,
            delta=0.1,
            walk_time=args.walk_time,
            walk_len=args.walk_len,
            dropout=0.3,
            task_level=self.task_level,
            y=self.dataset.y,
        )
        self.walk_op.preprocess(self.dataset.adj, self.dataset.x)

        # Load the checkpoint onto CPU instead of a hard-coded 'cuda:0' so the
        # restore also works on CPU-only hosts or when another device was
        # requested; the .to(device) call below performs the actual placement.
        state_dict = torch.load(self.model_load_path, map_location="cpu")
        self.walk_op.load_state_dict(state_dict)
        self.walk_op = self.walk_op.to(device)
        self.walk_op.eval()

        self.base_model = ClassifierModel(data_name, n, feat_dim, hidden_dim, output_dim, dropout, walk_len, task_level, args.backbone, device, self.args)

        # Build the integer edge index directly. Going through torch.Tensor
        # (float32) first would round node indices above 2**24.
        self.base_model.edge_index = torch.as_tensor(
            np.vstack((row_index, col_index)), dtype=torch.long
        )
