import torch
from torch_geometric.loader import DataLoader
from prompt_graph.utils import constraint,  center_embedding, Gprompt_tuning_loss
from prompt_graph.evaluation import GPPTEva, GNNNodeEva, GPFEva, GPFNDEva
from .task import BaseTask
import time
import warnings
import numpy as np
from prompt_graph.data import load4node, node_sample, GraphDataset, node_sample_pate_teacher, node_sample_pate_student, node_sample_weighted_pate_teacher
from prompt_graph.evaluation import AllInOneEva, AllInOneNDEva
import os
from prompt_graph.utils import process
import ipdb
import torchmetrics
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from collections import namedtuple
import prompt_graph.utils.dp_utils as dp_utils
from opacus.accountants import create_accountant
import pandas as pd

# Suppress every warning process-wide (all warnings raised after this line).
warnings.simplefilter("ignore")
# Per-sample L2 clipping bound passed to dp_utils.clip_and_accumulate / dp_utils.add_noise in DP-SGD.
L2NORM_BOUND = 1
class NodeTask(BaseTask):
      """Node-classification task over a pretrained GNN with optional graph prompts.

      The training loop supports ``prompt_type`` values ``'None'`` (GNN + answering
      head), ``'GPPT'``, ``'All-in-one'``, ``'GPF'`` and ``'GPF-plus'``. Training can
      optionally use DP-SGD (per-sample clipping plus noise via ``dp_utils``, with the
      privacy spent tracked by an Opacus RDP accountant). The train/test split can come
      from plain, PATE-teacher, weighted-PATE-teacher or PATE-student samplers.
      """

      def __init__(self, data, input_dim, output_dim, graphs_list = None, dpsgd=False, eps=1, delta=1e-5, sample_rate=0.6, pate=False, teacher_idx = 0, student_prompt=False, weighted_pate=False, *args, **kwargs):
            """Store the task configuration and draw the train/test node split.

            Args:
                  data: graph object; ``.x``, ``.edge_index`` and ``.y`` are used.
                  input_dim: initial node feature dimension.
                  output_dim: number of classes.
                  graphs_list: induced subgraphs carrying an ``.index`` attribute that is
                        matched against node indices (presumably one subgraph per node);
                        needed by the All-in-one / GPF prompts.
                  dpsgd: train with DP-SGD.
                  eps, delta: target (epsilon, delta) privacy budget for DP-SGD.
                  sample_rate: fraction of training nodes drawn per DP-SGD step.
                  pate: sample this task as a PATE teacher (or, together with
                        ``student_prompt``, as the PATE student).
                  teacher_idx: teacher partition index passed to the teacher samplers.
                  student_prompt: sample this task as the PATE student.
                  weighted_pate: use the centrality-weighted teacher sampler.

            Raises:
                  ValueError: if ``student_prompt`` is set without ``pate`` or ``weighted_pate``.
            """
            super().__init__(*args, **kwargs)
            self.task_type = 'NodeTask'
            self.data = data
            if self.dataset_name == 'ogbn-arxiv':
                  # squeeze away singleton label dimensions so labels index as a flat tensor
                  self.data.y = self.data.y.squeeze()
            self.input_dim = input_dim  # initial node feature dimension
            self.output_dim = output_dim
            self.graphs_list = graphs_list
            # NOTE(review): the head ends in Softmax; if self.criterion is CrossEntropyLoss the
            # loss is taken over probabilities rather than logits -- confirm this is intended.
            self.answering = torch.nn.Sequential(torch.nn.Linear(self.hid_dim, self.output_dim),
                                                 torch.nn.Softmax(dim=1)).to(self.device)
            self.pate = pate
            self.student_prompt = student_prompt
            self.teacher_idx = teacher_idx
            self.weighted_pate = weighted_pate
            # Only the weighted teacher sampler produces this score; default to None so that
            # pate_ensemble can always return it when weighted_pate is set.
            self.average_centrality_score = None

            # Student first, then the weighted teacher, so that the weighted sampler is
            # reachable whether or not `pate` is also set.
            if self.student_prompt:
                  if not (self.pate or self.weighted_pate):
                        raise ValueError('student_prompt=True requires pate=True or weighted_pate=True')
                  self.train_idx, self.train_labels, self.test_idx, self.test_labels = node_sample_pate_student(self.data, self.shot_num, seed=self.seed, device=self.device, dataset_name=self.dataset_name, pre_train_data=self.pre_train_data, pre_train_type=self.pre_train_type, prompt_type=self.prompt_type, gnn_type=self.gnn_type)
            elif self.weighted_pate:
                  self.train_idx, self.train_labels, self.test_idx, self.test_labels, self.average_centrality_score = node_sample_weighted_pate_teacher(self.data, self.shot_num, num_classes=self.output_dim, seed=self.seed, teacher_idx=self.teacher_idx)
            elif self.pate:
                  self.train_idx, self.train_labels, self.test_idx, self.test_labels = node_sample_pate_teacher(self.data, self.shot_num, num_classes=self.output_dim, seed=self.seed, teacher_idx=self.teacher_idx)
            else:
                  self.train_idx, self.train_labels, self.test_idx, self.test_labels = node_sample(self.data, self.shot_num, num_classes=self.output_dim, seed=self.seed)

            # hyperparameters for DP-SGD
            self.dpsgd = dpsgd
            self.sample_rate = sample_rate
            self.eps = eps
            self.delta = delta

            self.batch_size = min(self.batch_size, len(self.train_idx))

      # ------------------------------------------------------------------ DP-SGD helpers

      def _dp_sample_indices(self, train_idx):
            """Draw the training indices for one DP-SGD step.

            Sampling is without replacement. A node drawn twice would have its clipped
            gradient summed twice, exceeding the per-example bound (L2NORM_BOUND) that is
            handed to dp_utils.add_noise. The sample size is kept in [1, len(train_idx)]
            so the per-sample gradient stack is never empty.
            """
            num_samples = min(len(train_idx), max(1, int(self.sample_rate * len(train_idx))))
            return np.random.choice(train_idx, num_samples, replace=False)

      @staticmethod
      def _flat_grad(loss, params):
            """Return d(loss)/d(params) concatenated into one 1-D tensor.

            Parameters that ``loss`` does not depend on contribute zeros, so the flat
            layout always matches ``params`` (the original code crashed on ``None``).
            """
            grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
            return torch.cat([(torch.zeros_like(p) if g is None else g).reshape(-1)
                              for p, g in zip(params, grads)])

      @staticmethod
      def _assign_flat_grads(params, flat_grads):
            """Slice a flat gradient vector back into per-parameter ``.grad`` tensors, in order."""
            offset = 0
            for p in params:
                  numel = p.numel()
                  p.grad = flat_grads[offset:offset + numel].view_as(p)
                  offset += numel

      def _apply_dp_gradients(self, losses, params):
            """Clip, aggregate and noise per-sample gradients, then write them into ``params``.

            Args:
                  losses: 1-D tensor holding one loss per sampled example.
                  params: list of parameters to differentiate and update (order matters).
            """
            per_sample_grads = torch.stack([self._flat_grad(sample_loss, params)
                                            for sample_loss in losses.unbind(0)])
            summed_grads = dp_utils.clip_and_accumulate(per_sample_grads, clipping=L2NORM_BOUND, device=self.device)
            sanitized_grads = dp_utils.add_noise(summed_grads, self.noise_multiplier, L2NORM_BOUND, self.device, num_samples=len(losses))
            self._assign_flat_grads(params, sanitized_grads)

      def _dp_accountant_step(self):
            """Record one noisy step with the accountant and print the epsilon spent so far."""
            self.accountant.step(noise_multiplier=self.noise_multiplier, sample_rate=self.sample_rate)
            spent_eps = self.accountant.get_epsilon(delta=self.delta)
            print('spent eps: {}'.format(spent_eps))

      def _init_dp_accountant(self):
            """Calibrate the noise multiplier to (eps, delta) over self.epochs and create an RDP accountant."""
            self.noise_multiplier = dp_utils.get_noise_multiplier(target_epsilon=self.eps, target_delta=self.delta, sample_rate=self.sample_rate, epochs=self.epochs, accountant='rdp')
            print("noise multiplier: {}".format(self.noise_multiplier))
            self.accountant = create_accountant(mechanism='rdp')

      # ------------------------------------------------------------------ training steps

      def train(self, data, train_idx):
            """Run one epoch for prompt_type 'None' (GNN followed by the answering head).

            Under DP-SGD, this performs int(1 / sample_rate) noisy steps, each on a
            freshly sampled subset of ``train_idx``.

            Returns:
                  (loss, acc) as Python floats; under DP-SGD they come from the last step.
            """
            if self.dpsgd:
                  for _ in range(int(1 / self.sample_rate)):
                        accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=self.output_dim).to(self.device)
                        self.module_list.train()
                        self.optimizer.zero_grad()
                        out = self.module_list[0](data.x, data.edge_index, batch=None)
                        out = self.module_list[1](out)
                        pred = out.argmax(dim=1)
                        sample_idx = self._dp_sample_indices(train_idx)
                        acc = accuracy(pred[sample_idx], data.y[sample_idx])
                        losses = nn.CrossEntropyLoss(reduction='none')(out[sample_idx], data.y[sample_idx])
                        self._apply_dp_gradients(losses, list(self.module_list.parameters()))
                        # mean loss, reported only (the update uses the sanitized gradients above)
                        loss = self.criterion(out[sample_idx], data.y[sample_idx])
                        self.optimizer.step()
                        # NOTE(review): these .grad copies run after optimizer.step(), so they do not
                        # affect this update. If module_list holds modules distinct from self.gnn /
                        # self.answering (which evaluation uses), the DP-trained weights never reach
                        # them -- confirm how initialize_prompt builds module_list.
                        for p, g in zip(self.gnn.parameters(), self.module_list[0].parameters()):
                              p.grad = g.grad
                        for p, g in zip(self.answering.parameters(), self.module_list[1].parameters()):
                              p.grad = g.grad
                        self._dp_accountant_step()
            else:
                  accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=self.output_dim).to(self.device)
                  self.gnn.train()
                  self.answering.train()
                  self.optimizer.zero_grad()
                  out = self.gnn(data.x, data.edge_index, batch=None)
                  out = self.answering(out)
                  pred = out.argmax(dim=1)
                  acc = accuracy(pred[train_idx], data.y[train_idx])
                  loss = self.criterion(out[train_idx], data.y[train_idx])
                  loss.backward()
                  self.optimizer.step()

            return loss.item(), acc.item()

      def GPPTtrain(self, data, train_idx):
            """Run one epoch of GPPT prompt tuning on top of the frozen (eval-mode) GNN.

            The loss adds ``0.001 * constraint(...)`` on the prompt's task token. After each
            optimizer step, the structure-token weights are refreshed from
            ``prompt.get_mid_h()``.

            Returns:
                  (loss, acc) as Python floats; under DP-SGD they come from the last step.
            """
            if self.dpsgd:
                  # The first prompt parameter is excluded from the DP update (as in the original
                  # design) -- presumably the structure token that update_StructureToken_weight
                  # refreshes; confirm against the GPPT prompt class.
                  dp_params = list(self.prompt.parameters())[1:]
                  for _ in range(int(1 / self.sample_rate)):
                        accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=self.output_dim).to(self.device)
                        self.prompt.train()
                        self.gnn.eval()
                        node_embedding = self.gnn(data.x, data.edge_index)
                        out = self.prompt(node_embedding, data.edge_index)
                        pred = out.argmax(dim=1)

                        sample_idx = self._dp_sample_indices(train_idx)
                        acc = accuracy(pred[sample_idx], data.y[sample_idx])
                        token_penalty = 0.001 * constraint(self.device, self.prompt.get_TaskToken())
                        loss = self.criterion(out[sample_idx], data.y[sample_idx]) + token_penalty
                        losses = nn.CrossEntropyLoss(reduction='none')(out[sample_idx], data.y[sample_idx]) + token_penalty

                        # Each sanitized slice is written to the parameter it was computed for
                        # (the previous enumerate/zip pairing shifted gradients by one parameter).
                        self._apply_dp_gradients(losses, dp_params)

                        self.pg_opi.step()
                        self.pg_opi.zero_grad()
                        mid_h = self.prompt.get_mid_h()
                        self.prompt.update_StructureToken_weight(mid_h)

                        self._dp_accountant_step()
            else:
                  accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=self.output_dim).to(self.device)
                  self.prompt.train()
                  self.gnn.eval()
                  node_embedding = self.gnn(data.x, data.edge_index)
                  out = self.prompt(node_embedding, data.edge_index)
                  pred = out.argmax(dim=1)
                  acc = accuracy(pred[train_idx], data.y[train_idx])
                  loss = self.criterion(out[train_idx], data.y[train_idx])
                  loss = loss + 0.001 * constraint(self.device, self.prompt.get_TaskToken())
                  self.pg_opi.zero_grad()
                  loss.backward()
                  self.pg_opi.step()
                  mid_h = self.prompt.get_mid_h()
                  self.prompt.update_StructureToken_weight(mid_h)
            return loss.item(), acc.item()

      def GPFTrain(self, train_loader, epoch):
            """Run one epoch of GPF / GPF-plus prompt tuning over induced subgraphs.

            ``module_list[0]`` transforms the batch features and ``module_list[1]`` maps
            GNN outputs to class scores; the GNN itself stays in eval mode.

            Args:
                  train_loader: loader of training subgraphs (non-DP path only; under DP-SGD
                        a fresh single-batch loader is built from sampled nodes each step).
                  epoch: unused; kept for interface compatibility.

            Returns:
                  (mean loss, mean acc, concatenated outputs, concatenated labels); under
                  DP-SGD these cover the last step only.
            """
            if self.dpsgd:
                  dp_params = list(self.module_list.parameters())
                  for _ in range(int(1 / self.sample_rate)):
                        accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=self.output_dim).to(self.device)
                        self.module_list.train()
                        self.gnn.eval()
                        total_loss = 0.0
                        total_acc = 0.0
                        count = 0
                        outs = []
                        labels = []

                        sample_idx = self._dp_sample_indices(self.train_idx)
                        train_graphs = [graph for graph in self.graphs_list if graph.index in sample_idx]
                        train_loader = DataLoader(GraphDataset(train_graphs), batch_size=len(train_graphs), shuffle=True)
                        for batch in train_loader:
                              self.optimizer.zero_grad()
                              batch = batch.to(self.device)
                              batch.x = self.module_list[0](batch.x)
                              out = self.gnn(batch.x, batch.edge_index, batch.batch, prompt = self.prompt, prompt_type = self.prompt_type)
                              out = self.module_list[1](out)
                              outs.append(out)
                              labels.append(batch.y)
                              pred = out.argmax(dim=1)
                              acc = accuracy(pred, batch.y)
                              loss = self.criterion(out, batch.y)
                              losses = nn.CrossEntropyLoss(reduction='none')(out, batch.y)
                              self._apply_dp_gradients(losses, dp_params)

                              self.optimizer.step()
                              total_loss += loss.item()
                              total_acc += acc.item()
                              count += 1

                              self._dp_accountant_step()
            else:
                  accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=self.output_dim).to(self.device)
                  self.module_list.train()
                  self.gnn.eval()
                  total_loss = 0.0
                  total_acc = 0.0
                  count = 0
                  outs = []
                  labels = []
                  for batch in train_loader:
                        self.optimizer.zero_grad()
                        batch = batch.to(self.device)
                        batch.x = self.module_list[0](batch.x)
                        out = self.gnn(batch.x, batch.edge_index, batch.batch, prompt = self.prompt, prompt_type = self.prompt_type)
                        out = self.module_list[1](out)
                        outs.append(out)
                        labels.append(batch.y)
                        pred = out.argmax(dim=1)
                        acc = accuracy(pred, batch.y)
                        loss = self.criterion(out, batch.y)
                        loss.backward()
                        self.optimizer.step()
                        total_loss += loss.item()
                        total_acc += acc.item()
                        count += 1
            return total_loss / count, total_acc / count, torch.cat(outs, dim=0), torch.cat(labels, dim=0)

      def AllInOneTrain(self, train_loader):
            """Run one epoch of All-in-one prompt tuning via ``self.prompt.Tune`` / ``Tune_DP``.

            Returns:
                  (answer_loss, answer_acc) as returned by the prompt's tuning routine.
            """
            self.gnn.eval()
            self.module_list.train()
            if self.dpsgd:
                  # graphs_list is the attribute set in __init__ (was the misspelled self.graph_list)
                  answer_loss, answer_acc = self.prompt.Tune_DP(train_loader, self.gnn, self.module_list, self.optimizer, self.device, self.output_dim, L2NORM_BOUND, self.noise_multiplier, self.eps, self.delta, self.accountant, self.sample_rate, self.graphs_list, self.train_idx)
            else:
                  answer_loss, answer_acc = self.prompt.Tune(train_loader, self.gnn,  self.module_list, self.criterion, self.optimizer, self.device, self.output_dim)
            return answer_loss, answer_acc

      # ------------------------------------------------------------------ shared setup / loop

      def _init_gppt_prompt(self, idx_train):
            """Initialise the GPPT prompt weights from the GNN's node embeddings."""
            node_embedding = self.gnn(self.data.x, self.data.edge_index)
            self.prompt.weigth_init(node_embedding, self.data.edge_index, self.data.y, idx_train)

      def _build_graph_loaders(self, idx_train, idx_test, train_lbls=None):
            """Split ``graphs_list`` into train/test subgraphs and wrap them in DataLoaders.

            Args:
                  idx_train, idx_test: node indices matched against each graph's ``.index``.
                  train_lbls: if given, each training graph's ``.y`` is overwritten with the
                        label at the matching position of ``idx_train`` (mutates the graphs).

            Returns:
                  (train_loader, test_loader); the test set is capped at 50x the train set.
            """
            train_graphs = []
            test_graphs = []
            for graph in self.graphs_list:
                  if graph.index in idx_train:
                        if train_lbls is not None:
                              graph.y = train_lbls[torch.nonzero(idx_train == graph.index).item()]
                        train_graphs.append(graph)
                  elif graph.index in idx_test:
                        test_graphs.append(graph)
            print('Done!!!')

            # reduce test dataset to 50*len(train_graphs) (presumably to bound evaluation cost)
            test_graphs = test_graphs[:50 * len(train_graphs)]
            print('Final Num of Train: {}, Num of Test: {}'.format(len(train_graphs), len(test_graphs)))

            train_loader = DataLoader(GraphDataset(train_graphs), batch_size=self.batch_size, shuffle=True)
            test_loader = DataLoader(GraphDataset(test_graphs), batch_size=self.batch_size, shuffle=False)
            print("prepare induce graph data is finished!")
            return train_loader, test_loader

      def _run_epoch(self, epoch, idx_train, idx_test, train_loader, test_loader):
            """Train one epoch for the configured prompt_type, evaluate, log, and return test accuracy.

            Raises:
                  ValueError: if ``self.prompt_type`` has no training branch here.
            """
            t0 = time.time()
            if self.prompt_type == 'None':
                  loss, acc = self.train(self.data, idx_train)
                  test_acc, test_loss, _, _ = GNNNodeEva(self.data, idx_test,  self.gnn, self.answering, self.output_dim, self.device)
            elif self.prompt_type == 'GPPT':
                  loss, acc = self.GPPTtrain(self.data, idx_train)
                  test_acc, _, _, test_loss, _ = GPPTEva(self.data, idx_test, self.gnn, self.prompt, self.output_dim, self.device)
            elif self.prompt_type == 'All-in-one':
                  loss, acc = self.AllInOneTrain(train_loader)
                  test_acc, _, _, test_loss, _ = AllInOneEva(test_loader, self.gnn, self.module_list, self.output_dim, self.device)
            elif self.prompt_type in ['GPF', 'GPF-plus']:
                  loss, acc, _, _ = self.GPFTrain(train_loader, epoch)
                  test_acc, _, _, test_loss, _ = GPFEva(test_loader, self.gnn, self.module_list, self.output_dim, self.device)
            else:
                  raise ValueError('Unsupported prompt_type for NodeTask: {}'.format(self.prompt_type))
            print("Epoch {:03d} |  Time(s) {:.4f} | Train Loss {:.4f} | Test Loss {:.4f} | Train Acc {:.4f} | Test Acc {:.4f}  ".format(epoch, time.time() - t0, loss, test_loss, acc, test_acc))
            return test_acc

      def _ensemble_save_path(self):
            """Return the .pt path where pate_ensemble stores the trained prompt / module_list.

            Raises:
                  ValueError: if neither ``pate`` nor ``weighted_pate`` is set.
            """
            if self.student_prompt:
                  return './dataspace/StudentPrompt/{}shot/{}_{}/seed_{}/{}_{}_{}.pt'.format(self.shot_num, self.dataset_name, self.pre_train_data, self.seed, self.pre_train_type, self.prompt_type, self.gnn_type)
            if self.pate or self.weighted_pate:
                  # weighted teachers share the teacher directory so the student side finds them
                  return './dataspace/PateEnsemble/{}shot/{}_{}/seed_{}/{}_{}_{}_{}.pt'.format(self.shot_num, self.dataset_name, self.pre_train_data, self.seed, self.pre_train_type, self.prompt_type, self.gnn_type, self.teacher_idx)
            raise ValueError('pate_ensemble requires pate=True or weighted_pate=True')

      # ------------------------------------------------------------------ entry points

      def run(self):
            """Initialise models, train for ``self.epochs`` epochs, and return the final test accuracy."""
            self.initialize_gnn()
            self.initialize_prompt()
            self.initialize_optimizer()
            idx_train = self.train_idx
            idx_test = self.test_idx
            print('Num of Train: {}, Num of Test: {}'.format(len(idx_train), len(idx_test)))
            if self.prompt_type == 'GPPT':
                  self._init_gppt_prompt(idx_train)

            train_loader = test_loader = None
            if self.prompt_type in ['All-in-one', 'GPF', 'GPF-plus']:
                  train_loader, test_loader = self._build_graph_loaders(idx_train, idx_test)

            if self.dpsgd:
                  self._init_dp_accountant()

            # 1-based epochs, self.epochs of them (matching the DP noise calibration above)
            for epoch in range(1, self.epochs + 1):
                  test_acc = self._run_epoch(epoch, idx_train, idx_test, train_loader, test_loader)
            return test_acc

      def pate_ensemble(self, header):
            """Train a PATE teacher or student and save its prompt / module_list to disk.

            Args:
                  header: unused; kept for interface compatibility.

            Returns:
                  ``idx_train``, or ``(idx_train, average_centrality_score)`` when
                  ``weighted_pate`` is set (the score is None for students).

            Raises:
                  ValueError: if neither ``pate`` nor ``weighted_pate`` is set (checked
                        before any training).
            """
            ensemble_save_path = self._ensemble_save_path()

            self.initialize_gnn()
            self.initialize_prompt()
            self.initialize_optimizer()
            idx_train = self.train_idx
            train_lbls = self.train_labels
            idx_test = self.test_idx

            print('Num of Train: {}, Num of Test: {}'.format(len(idx_train), len(idx_test)))
            # GPPT prompt initialization
            if self.prompt_type == 'GPPT':
                  self._init_gppt_prompt(idx_train)

            train_loader = test_loader = None
            if self.prompt_type in ['Gprompt', 'All-in-one', 'GPF', 'GPF-plus']:
                  print('distinguishing the train dataset and test dataset...')
                  # PATE teachers/students train on the sampler's labels rather than graph.y
                  relabel = train_lbls if (self.pate or self.student_prompt) else None
                  train_loader, test_loader = self._build_graph_loaders(idx_train, idx_test, relabel)
            elif self.prompt_type == 'GPPT':
                  if len(idx_test) > 50*len(idx_train):
                        idx_test = idx_test[:50*len(idx_train)]
                  print('Final Num of Train: {}, Num of Test: {}'.format(len(idx_train), len(idx_test)))

            if self.dpsgd:
                  self._init_dp_accountant()

            for epoch in range(1, self.epochs + 1):
                  test_acc = self._run_epoch(epoch, idx_train, idx_test, train_loader, test_loader)

            os.makedirs(os.path.dirname(ensemble_save_path), exist_ok=True)

            # save prompt or module_list as pt
            if self.prompt_type in ['GPPT']:
                  torch.save(self.prompt, ensemble_save_path)
            elif self.prompt_type in ['All-in-one', 'GPF', 'GPF-plus']:
                  torch.save(self.module_list, ensemble_save_path)

            print(f"Final True Accuracy: {test_acc:.4f}" )
            if self.weighted_pate:
                  return idx_train, self.average_centrality_score
            return idx_train


