import torch
from prompt_graph.model import GAT, GCN, GCov, GIN, GraphSAGE, GraphTransformer
from prompt_graph.prompt import LightPrompt,HeavyPrompt
from torch import nn, optim
from prompt_graph.data import load4node, load4graph
from prompt_graph.utils import Gprompt_tuning_loss
import numpy as np

class BaseTask:
    """Shared configuration and component initialisers for downstream tasks.

    The constructor only stores hyperparameters and builds the loss function.
    The ``initialize_*`` methods build the optimizer(s), the prompt and the
    GNN backbone on demand.

    Several attributes read by those methods are NOT set anywhere in this
    class and must be assigned (presumably by a subclass -- confirm against
    callers) before the corresponding method is called:

    * ``self.input_dim``      -- read by ``initialize_gnn`` / ``initialize_prompt``
    * ``self.answering``      -- read by ``initialize_optimizer``
    * ``self.prompt``         -- read by ``initialize_optimizer`` for prompt-based types
    * ``self.DownPrompt`` and ``self.feature_prompt`` -- read for ``'MultiGprompt'``
    """

    def __init__(self, pre_train_model_path=None, gnn_type='TransformerConv',
                 hid_dim=128, num_layer=2, dataset_name='Cora',
                 prompt_type='GPF', epochs=100, shot_num=10, device: int = 1,
                 config=None, finetune=False, task='GraphTask', transfer=False):
        """Store the task configuration and create the loss function.

        Args:
            pre_train_model_path: Path of a pre-trained GNN checkpoint, or
                ``None`` / the string ``'None'`` to skip loading.
            gnn_type: Backbone name consumed by ``initialize_gnn``.
                NOTE(review): the default ``'TransformerConv'`` is not one of
                the names ``initialize_gnn`` accepts, so callers must pass a
                supported name before calling it.
            hid_dim: Passed to the GNN constructor as ``out_dim``.
            num_layer: Number of GNN layers.
            dataset_name: Dataset identifier; also compared against the
                checkpoint path in ``initialize_gnn``.
            prompt_type: Selects the loss, prompt and optimizer branches.
            epochs: Stored as-is; not used inside this class.
            shot_num: Stored as-is; not used inside this class.
            device: CUDA device index, used only when CUDA is available.
            config: Object whose attributes supply learning rates, weight
                decays and ``prompt_graph_token_num``.
            finetune: If true, fine-tuning hyperparameters are preferred in
                ``initialize_optimizer``.
            task: Stored as-is; not used inside this class.
            transfer: Stored as-is; not used inside this class.
        """
        self.pre_train_model_path = pre_train_model_path
        # Fall back to CPU when no CUDA runtime is present.
        self.device = torch.device(f'cuda:{device}' if torch.cuda.is_available() else 'cpu')
        self.hid_dim = hid_dim
        self.num_layer = num_layer
        self.dataset_name = dataset_name
        self.shot_num = shot_num
        self.gnn_type = gnn_type
        self.prompt_type = prompt_type
        self.epochs = epochs
        self.config = config
        self.finetune = finetune
        self.task = task
        self.transfer = transfer
        # Built last so that the loss setup (or an override of it) can rely on
        # every attribute above already being assigned.
        self.initialize_lossfn()

    def initialize_optimizer(self):
        """Create the optimizer(s) that match ``self.prompt_type``.

        Depending on the prompt type this sets ``self.optimizer``, or
        ``self.pg_opi`` (plus ``self.answer_opi`` for 'All-in-one'/'answering').
        For an unrecognised prompt type nothing is created.
        """
        if self.prompt_type == 'None':
            model_param_group = [
                {"params": self.gnn.parameters()},
                {"params": self.answering.parameters()},
            ]
            # Default hyperparameters, shared by supervised training and
            # fine-tuning.
            lr = self.config.supervised_and_finetune_lr
            weight_decay = self.config.supervised_and_finetune_weight_decay
            if self.finetune:
                # Prefer dedicated fine-tuning hyperparameters when the config
                # defines both of them. Only a missing attribute triggers the
                # fallback; any other error is a real problem and propagates.
                try:
                    lr = self.config.finetune_lr
                    weight_decay = self.config.finetune_weight_decay
                except AttributeError:
                    lr = self.config.supervised_and_finetune_lr
                    weight_decay = self.config.supervised_and_finetune_weight_decay
            self.optimizer = optim.Adam(model_param_group, lr=lr, weight_decay=weight_decay)
        elif self.prompt_type in ['All-in-one', 'answering']:
            # Prompt and answering head are optimised separately; parameters
            # with requires_grad=False are excluded.
            self.pg_opi = optim.Adam(filter(lambda p: p.requires_grad, self.prompt.parameters()), lr=self.config.pg_lr, weight_decay=self.config.pg_weight_decay)
            self.answer_opi = optim.Adam(filter(lambda p: p.requires_grad, self.answering.parameters()), lr=self.config.answer_lr, weight_decay=self.config.answer_weight_decay)
        elif self.prompt_type in ['GPF', 'GPF-plus']:
            model_param_group = [
                {"params": self.prompt.parameters()},
                {"params": self.answering.parameters()},
            ]
            self.optimizer = optim.Adam(model_param_group, lr=0.005, weight_decay=5e-4)
        elif self.prompt_type in ['Gprompt', 'GPPT']:
            self.pg_opi = optim.Adam(self.prompt.parameters(), lr=0.005, weight_decay=5e-4)
        elif self.prompt_type == 'MultiGprompt':
            self.optimizer = torch.optim.Adam([*self.DownPrompt.parameters(), *self.feature_prompt.parameters()], lr=0.001)

    def initialize_lossfn(self):
        """Set ``self.criterion``: cross-entropy, or the Gprompt tuning loss."""
        if self.prompt_type == 'Gprompt':
            self.criterion = Gprompt_tuning_loss()
        else:
            self.criterion = torch.nn.CrossEntropyLoss()

    def initialize_prompt(self):
        """Create ``self.prompt`` for the prompt types handled at this level.

        Only ``'None'`` (no prompt) and ``'All-in-one'`` / ``'answering'``
        (a ``HeavyPrompt``) are supported here.

        Raises:
            KeyError: For every other prompt type.
        """
        if self.prompt_type == 'None':
            self.prompt = None
        elif self.prompt_type in ['All-in-one', 'answering']:
            self.prompt = HeavyPrompt(token_dim=self.input_dim, token_num=self.config.prompt_graph_token_num).to(self.device)
        else:
            raise KeyError(" We don't support this kind of prompt.")

    def initialize_gnn(self):
        """Build ``self.gnn`` on ``self.device`` and optionally load a checkpoint.

        Pre-trained weights are loaded unless ``self.pre_train_model_path`` is
        ``None`` or the string ``'None'``, or the prompt type is
        ``'MultiGprompt'``.

        Raises:
            ValueError: If ``self.gnn_type`` is unsupported, or if it does not
                occur in the checkpoint path.
        """
        if self.gnn_type == 'GAT':
            self.gnn = GAT(input_dim=self.input_dim, out_dim=self.hid_dim, num_layer=self.num_layer)
        elif self.gnn_type == 'GCN':
            self.gnn = GCN(input_dim=self.input_dim, out_dim=self.hid_dim, num_layer=self.num_layer)
        elif self.gnn_type == 'GraphSAGE':
            self.gnn = GraphSAGE(input_dim=self.input_dim, out_dim=self.hid_dim, num_layer=self.num_layer)
        elif self.gnn_type == 'GIN':
            self.gnn = GIN(input_dim=self.input_dim, out_dim=self.hid_dim, num_layer=self.num_layer)
        elif self.gnn_type == 'GCov':
            self.gnn = GCov(input_dim=self.input_dim, out_dim=self.hid_dim, num_layer=self.num_layer)
        elif self.gnn_type == 'GraphTransformer':
            self.gnn = GraphTransformer(input_dim=self.input_dim, out_dim=self.hid_dim, num_layer=self.num_layer)
        else:
            raise ValueError(f"Unsupported GNN type: {self.gnn_type}")
        self.gnn.to(self.device)

        checkpoint_path = self.pre_train_model_path
        # The constructor default is the object None, while callers may also
        # pass the literal string 'None'; both mean "no checkpoint".
        if checkpoint_path is None or checkpoint_path == 'None':
            return
        if self.prompt_type == 'MultiGprompt':
            return

        # The checkpoint path is expected to contain the backbone name.
        if self.gnn_type not in checkpoint_path:
            raise ValueError(f"the Downstream gnn '{self.gnn_type}' does not match the pre-train model")
        # A dataset mismatch is deliberately reported as a warning, not an error.
        if self.dataset_name not in checkpoint_path:
            print(f"Warning: the Downstream dataset '{self.dataset_name}' does not match the pre-train dataset")

        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        self.gnn.load_state_dict(torch.load(checkpoint_path, map_location=self.device))
        print("Successfully loaded pre-trained weights!")

         
      
 
            
      
