import torch
from prompt_graph.model import GAT, GCN, GraphTransformer
from prompt_graph.prompt import GPF, GPF_plus, LightPrompt,HeavyPrompt, GPPTPrompt
from torch import nn, optim
from prompt_graph.data import load4node
from prompt_graph.utils import Gprompt_tuning_loss
import numpy as np
import ipdb

class BaseTask:
    def __init__(self, pre_train_model_path='None', gnn_type='TransformerConv',
                 hid_dim = 128, num_layer = 2, dataset_name='Cora', prompt_type='None', epochs=100, shot_num=10, device : int = 5, lr =0.001, wd = 5e-4,
                 batch_size = 16, seed = 0, search = False, use_different_dataset=False, pre_train_data=None):
        
        self.pre_train_model_path = pre_train_model_path
        self.pre_train_type = self.return_pre_train_type(pre_train_model_path)
        # self.device = torch.device('cuda:' + str(device) if torch.cuda.is_available() else 'cpu')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.hid_dim = hid_dim
        self.num_layer = num_layer
        self.dataset_name = dataset_name
        self.shot_num = shot_num
        self.gnn_type = gnn_type
        self.prompt_type = prompt_type
        self.epochs = epochs
        self.lr = lr
        self.wd = wd
        self.batch_size = batch_size
        self.search = search
        self.seed = seed
        self.pre_train_data = pre_train_data
        self.initialize_lossfn()

    def initialize_lossfn(self):
        self.criterion = torch.nn.CrossEntropyLoss()
        if self.prompt_type == 'Gprompt':
            self.criterion = Gprompt_tuning_loss()

    def initialize_optimizer(self):
        if self.prompt_type == 'None':
            if self.pre_train_model_path == 'None':
                model_param_group = []
                model_param_group.append({"params": self.gnn.parameters()})
                model_param_group.append({"params": self.answering.parameters()})
                self.module_list = nn.Sequential(*[self.gnn, self.answering])
                self.optimizer = optim.Adam(self.module_list.parameters(), lr=self.lr, weight_decay=self.wd)
            elif not self.disable_dp:
                self.module_list = nn.Sequential(*[self.gnn, self.answering])
                self.optimizer = optim.Adam(self.module_list.parameters(), lr=self.lr, weight_decay=self.wd)
            else:
                model_param_group = []
                model_param_group.append({"params": self.gnn.parameters()})
                model_param_group.append({"params": self.answering.parameters()})
                self.optimizer = optim.Adam(model_param_group, lr=self.lr, weight_decay=self.wd)

        elif self.prompt_type == 'All-in-one':
            self.pg_opi = optim.Adam( self.prompt.parameters(), lr=1e-3, weight_decay= self.wd)
            self.answer_opi = optim.Adam( self.answering.parameters(), lr=self.lr, weight_decay= self.wd)
            self.module_list = nn.Sequential(*[self.prompt, self.answering])
            self.optimizer = optim.Adam(self.module_list.parameters(), lr=self.lr, weight_decay=self.wd)

        elif self.prompt_type in ['GPF', 'GPF-plus']:
            self.module_list = nn.Sequential(*[self.prompt, self.answering])
            self.optimizer = optim.Adam(self.module_list.parameters(), lr=self.lr, weight_decay=self.wd)
        elif self.prompt_type in ['GPPT']:
            self.pg_opi = optim.Adam(self.prompt.parameters(), lr=2e-3, weight_decay=5e-4)

    def initialize_prompt(self):
        if self.prompt_type == 'None':
            self.prompt = None
        elif self.prompt_type == 'GPPT':
            if(self.task_type=='NodeTask'):
                if self.dataset_name == 'Texas':
                    self.prompt = GPPTPrompt(self.hid_dim, 5, self.output_dim, device = self.device)
                else:
                    self.prompt = GPPTPrompt(self.hid_dim, self.output_dim, self.output_dim, device = self.device)
        elif self.prompt_type =='All-in-one':
            if(self.task_type=='NodeTask'):
                self.prompt = HeavyPrompt(token_dim=self.input_dim, token_num=10, cross_prune=0.1, inner_prune=0.3).to(self.device)
        elif self.prompt_type == 'GPF':
            self.prompt = GPF(self.input_dim, self.device).to(self.device)
        elif self.prompt_type == 'GPF-plus':
            self.prompt = GPF_plus(self.input_dim, 20).to(self.device)
        else:
            raise KeyError(" We don't support this kind of prompt.")
        

    def initialize_gnn(self):
        if self.gnn_type == 'GAT':
            self.gnn = GAT(input_dim=self.input_dim, hid_dim=self.hid_dim, num_layer=self.num_layer)
        elif self.gnn_type == 'GCN':
            self.gnn = GCN(input_dim=self.input_dim, hid_dim=self.hid_dim, num_layer=self.num_layer)
        elif self.gnn_type == 'GraphTransformer':
            self.gnn = GraphTransformer(input_dim=self.input_dim, hid_dim=self.hid_dim, num_layer=self.num_layer)
        else:
            raise ValueError(f"Unsupported GNN type: {self.gnn_type}")
        self.gnn.to(self.device)
        print(self.gnn)

        if self.pre_train_model_path != 'None' and self.prompt_type != 'MultiGprompt':
            if self.gnn_type not in self.pre_train_model_path :
                raise ValueError(f"the Downstream gnn '{self.gnn_type}' does not match the pre-train model")

            self.gnn.load_state_dict(torch.load(self.pre_train_model_path, map_location='cpu'))
            self.gnn.to(self.device)       
            print("Successfully loaded pre-trained weights!")

    def return_pre_train_type(self, pre_train_model_path):
        names = ['None', 'DGI', 'GraphMAE','Edgepred_GPPT', 'Edgepred_Gprompt','GraphCL', 'SimGRACE']
        for name in names:
            if name  in  pre_train_model_path:
                return name


      
 
            
      
