import torch, sys, pickle, argparse
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, Subset
from models.MLP import MLPNet
import argparse as ap
from types import SimpleNamespace
from dsl.gdl import find_matching_subgraph
from data_loader import *

# assume label is int
def get_num_classes(y):
  """Return the number of distinct labels in *y*.

  Every element is converted with ``int()`` before comparison, so values
  that truncate to the same integer count as a single class.

  Args:
    y: Iterable of labels; each element must be accepted by ``int()``.

  Returns:
    The count of distinct integer labels (0 for an empty input).
  """
  # A set comprehension de-duplicates in one pass; no manual membership test needed.
  return len({int(label) for label in y})
  
def learn(dataset, learning_rate=0.01):
  """Train an MLP classifier on precomputed graph embeddings and report accuracy.

  Loads ``embeddings/X_<dataset>.pickle`` (features) and
  ``embeddings/Y_<dataset>.pickle`` (labels), splits them with the
  train/val/test graph indices exposed by ``generate_embeddings.load_dataset``,
  trains ``MLPNet`` with Adam and cross-entropy loss, keeps a snapshot of the
  weights from the epoch with the best validation accuracy, and prints the
  test accuracy of that snapshot.

  Args:
    dataset: Dataset name; used to build the pickle paths and passed to
      ``load_dataset``.
    learning_rate: Learning rate handed to the Adam optimizer.

  Note:
    The function ends the process with ``sys.exit()`` right after printing
    the test accuracy, so the SHAP feature-importance section that follows
    is currently unreachable.
  """
  print("Choen learning rate : {}".format(learning_rate))
  DATASET = dataset

  # ==== 1. Load the precomputed graph-level features and labels ====
  # NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
  with open('embeddings/X_{}.pickle'.format(DATASET), 'rb') as f:
    X = pickle.load(f)

  with open('embeddings/Y_{}.pickle'.format(DATASET), 'rb') as f:
    y = pickle.load(f)

  X = torch.tensor(X, dtype=torch.float)
  # If the smallest label is above 0, shift all labels so they start at 0.
  min_label = min(y)
  if min_label > 0:
    print(f"Shifting labels from [{min_label}-{max(y)}] to [0-{max(y)-min_label}]")
    y = [label - min_label for label in y]

  y = torch.tensor(y, dtype=torch.long)
  num_features = len(X[0])            # number of features per graph
  num_classes = get_num_classes(y)    # number of distinct labels

  from generate_embeddings import load_dataset
  data = load_dataset(DATASET)

  print("Load data : Done")
  train_indices = list(data.train_graphs)
  val_indices = list(data.val_graphs)
  test_indices = list(data.test_graphs)

  print(len(train_indices))
  print(len(val_indices))
  print(len(test_indices))
  print("Num features: {}".format(num_features))

  # ==== 2. Split using the train/val/test indices provided by load_dataset ====
  tensor_dataset = TensorDataset(X, y)
  train_dataset = Subset(tensor_dataset, train_indices)
  val_dataset = Subset(tensor_dataset, val_indices)
  test_dataset = Subset(tensor_dataset, test_indices)

  # Only the training loader shuffles; evaluation order does not matter.
  train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
  val_loader = DataLoader(val_dataset, batch_size=64)
  test_loader = DataLoader(test_dataset, batch_size=64)

  # ==== 3. Model: MLPNet from the models folder ====

  # ==== 4. Setup training ====
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  model_args = SimpleNamespace(
    device=device,
    mlp_feature_dim=128,
    mlp_hidden=[64],      # Hidden layer dimensions
    readout="mean",       # Readout function
    dropout=0.5,          # Dropout rate
  )

  model = MLPNet(input_dim=num_features, output_dim=num_classes, model_args=model_args).to(device)
  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
  criterion = nn.CrossEntropyLoss()

  def as_model_input(features):
    """Wrap a feature batch in the object layout the model is called with.

    Every row gets its own batch index (0..n-1), i.e. each sample is
    presented as a separate graph.
    """
    batch_indices = torch.arange(features.size(0), device=device)
    return SimpleNamespace(x=features, batch=batch_indices)

  # ==== 5. Training and Evaluation Functions ====
  def train():
    """Run one epoch over the training loader; return the mean per-sample loss."""
    model.train()
    total_loss = 0
    for xb, yb in train_loader:
      xb, yb = xb.to(device), yb.to(device)
      optimizer.zero_grad()
      logits, _, _ = model(as_model_input(xb))
      loss = criterion(logits, yb)
      loss.backward()
      optimizer.step()
      # criterion averages over the batch, so re-weight by the batch size.
      total_loss += loss.item() * xb.size(0)
    return total_loss / len(train_loader.dataset)

  @torch.no_grad()
  def evaluate(loader):
    """Return the classification accuracy of the model over *loader*."""
    model.eval()
    correct = 0
    for xb, yb in loader:
      xb, yb = xb.to(device), yb.to(device)
      logits, _, _ = model(as_model_input(xb))
      pred = logits.argmax(dim=1)
      correct += (pred == yb).sum().item()
    return correct / len(loader.dataset)

  # ==== 6. Training Loop with Validation-based Best Model Selection ====
  best_val_acc = 0.0
  best_model_state = None
  patience_counter = 0
  max_epoch = 800
  early_stopping = 100

  for epoch in range(1, max_epoch + 1):
    train_loss = train()
    val_acc = evaluate(val_loader)

    if val_acc > best_val_acc:
      best_val_acc = val_acc
      # state_dict() hands back references to the live parameter tensors, so
      # clone them; otherwise later optimizer steps would silently overwrite
      # this "best" snapshot with the weights of subsequent epochs.
      best_model_state = {
        name: tensor.detach().clone()
        for name, tensor in model.state_dict().items()
      }
      patience_counter = 0  # Reset patience counter
    else:
      patience_counter += 1  # Increment patience counter

    print(f"Epoch {epoch:03d}, Train Loss: {train_loss:.4f}, Val Acc: {val_acc:.4f}, Patience: {patience_counter}/{early_stopping}")

    # Early stopping check
    if patience_counter >= early_stopping:
      print(f"Early stopping triggered after {epoch} epochs with no improvement for {early_stopping} epochs.")
      break

  # ==== 7. Load Best Model and Evaluate on Test Set ====
  if best_model_state is not None:
    model.load_state_dict(best_model_state)
  else:
    # Validation accuracy never rose above 0.0, so no snapshot was taken.
    print("Warning: no best model snapshot was recorded; evaluating the final weights.")
  test_acc = evaluate(test_loader)
  print(f"Best Validation Accuracy: {best_val_acc:.4f}")
  print(f"Final Test Accuracy (Best Validation Model): {test_acc:.4f}")

  # NOTE(review): this exit makes everything below unreachable; remove it to
  # re-enable the SHAP analysis.
  sys.exit()

  # ==== 8. Feature Importance using SHAP (KernelExplainer) ====
  import shap
  import numpy as np

  # Explain the test instances, using the training instances as background.
  X_test_sample = X[test_indices].cpu().numpy()
  background = X[train_indices].cpu().numpy()

  def model_predict(x_numpy):
    """Prediction function for SHAP: numpy features in, numpy scores out."""
    x_tensor = torch.tensor(x_numpy, dtype=torch.float, device=device)
    model.eval()
    with torch.no_grad():
      # The second model output is used as class probabilities (per the
      # original author's note; MLPNet is defined elsewhere).
      _, probs, _ = model(as_model_input(x_tensor))
    return probs.cpu().numpy()

  explainer = shap.KernelExplainer(model_predict, background)
  shap_values = explainer.shap_values(X_test_sample, nsamples=100)

  # For multi-class, take mean absolute SHAP value across classes for each instance
  if isinstance(shap_values, list):
    # shap_values: list of arrays, each (num_instances, num_features)
    abs_shap = np.abs(np.stack(shap_values, axis=-1))  # (num_instances, num_features, num_classes)
    mean_abs_shap = np.mean(abs_shap, axis=2)  # (num_instances, num_features)
  else:
    mean_abs_shap = np.mean(np.abs(shap_values), axis=2)  # (num_instances, num_features)

  print("mean_abs_shap shape:", mean_abs_shap.shape)

  # Feature indices per instance, most important first.
  important_features_per_instance = np.argsort(mean_abs_shap, axis=1)[:, ::-1]

  # NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
  with open('embeddings/gdl_programs_{}.pickle'.format(DATASET), 'rb') as f:
    gdl_programs = pickle.load(f)

  print(test_indices)
  print(len(important_features_per_instance))

  k = 1  # number of most important features to report per instance
  print("\nFeature importances per instance (descending order):")
  for idx, important_features in enumerate(important_features_per_instance):
    print(f"Test instance {idx}: {test_indices[idx]}")
    print(f"Original label {data.graph_to_label[test_indices[idx]]}")
    for feature_idx in important_features[:k]:
      importance = mean_abs_shap[idx, feature_idx]
      if isinstance(importance, np.ndarray):
        if importance.size == 1:
          importance = importance.item()
        else:
          print(f"Warning: importance shape={importance.shape}, using first value only")
          importance = importance.flatten()[0]
      print(f"Feature {feature_idx}: importance={importance:.6f}")
      print(f"Feature value : {X[test_indices[idx]]}")
      print("GDL program:")
      print(gdl_programs[feature_idx].nodeVars)
      print(gdl_programs[feature_idx].edgeVars)
      print("Matching Subgraph:")
      subgraph = find_matching_subgraph(gdl_programs[feature_idx], data.graphs[test_indices[idx]], data)
      if subgraph is None:
        continue
      print(subgraph)
      print("Nodes:")
      print(subgraph[0])
      print("Edges:")
      for edge in subgraph[1]:
        print(data.A[edge])
      print("-" * 50)


#def process(dataset, k):
def process(dataset, p):
  """Generate embeddings for *dataset* using a fraction *p* of its training graphs.

  The dataset is loaded only to count its training graphs; the count is
  scaled by *p*, clamped to ``[1, count]`` and passed as ``k`` to
  ``generate_embeddings.generate_embeddings``.

  Args:
    dataset: Dataset name; 'BBBP' and 'BACE' use their dedicated loaders,
      any other name goes through ``load_Data``.
    p: Portion of the training graphs used to derive ``k``.

  Raises:
    ValueError: If the dataset has no training graphs.
  """
  if dataset == 'BBBP':
    loaded = load_BBBP()
  elif dataset == 'BACE':
    loaded = load_BACE()
  else:
    loaded = load_Data(dataset)

  # k comes from the loaded training split scaled by p (an earlier version
  # read the indices from ../datasets/<dataset>/tr.pickle instead).
  n_train = len(loaded.train_graphs)
  if not n_train:
    raise ValueError(f"No training graphs found for dataset: {dataset}")

  # Clamp k to [1, n_train]
  scaled = int(n_train * p)
  k = min(n_train, scaled)
  k = max(1, k)
  print("Chosen k : {}".format(k))

  # Imported here, after validation, exactly where the original did it.
  from generate_embeddings import generate_embeddings
  generate_embeddings(dataset, k)

if __name__ == '__main__':
  # Script entry point: generate embeddings for the dataset, then train on them.
  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument('-d', '--dataset', help="input dataset")
  arg_parser.add_argument(
    '-lr', '--learning_rate',
    type=float,
    default=0.01,
    help="learning rate (default: 0.01)",
  )
  # The former -k/--k option is superseded by the portion -p below.
  arg_parser.add_argument('-p', '--p', type=float, default=1.0, help="portion")
  args = arg_parser.parse_args()

  process(args.dataset, args.p)
  learn(args.dataset, learning_rate=args.learning_rate)


