import torch, sys, pickle, argparse
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, Subset
from models.MLP import MLPNet
import argparse as ap
from types import SimpleNamespace
from dsl.gdl import * 
from data_loader import *
import copy
from generate_embeddings import generate_embeddings_tmp, generate_embeddings
import time

# Dataset name shared across helpers; assigned by setGlobal() and read by
# process_tmp() and generate_subgraph().
NAME = ''
# Portion value assigned by setGlobal(); not read by the code visible here.
PORTION = 0.0
# Number of top-ranked features kept per test instance in learn();
# assigned by setGlobal().
K = 1 


# assume label is int
def get_num_classes(y):
  """Return the number of distinct labels in ``y``.

  Every element is converted with ``int()`` before being compared, so values
  that map to the same integer count as a single class.
  """
  distinct_labels = {int(label) for label in y}
  return len(distinct_labels)
  
def learn(dataset, learning_rate=0.01):
  """Train an MLP on precomputed embeddings and report explanation metrics.

  The function
    1. loads ``embeddings/X_{dataset}.pickle`` and ``embeddings/Y_{dataset}.pickle``,
    2. builds train/val/test subsets from the index collections returned by
       ``generate_embeddings.load_dataset``,
    3. trains ``MLPNet`` with Adam and cross-entropy, keeping a snapshot of the
       weights from the epoch with the highest validation accuracy,
    4. runs LIME on every test instance and ranks features by absolute weight,
    5. for each test graph, shrinks the graph with ``minimize`` using the
       top-``K`` ranked features (``K`` is the module-level global),
       re-evaluates every stored GDL program on the shrunken graph to build a
       new 0/1 feature vector, and compares the model's prediction on it with
       the prediction on the original feature vector,
    6. prints the average sparsity and fidelity.

  As computed here, fidelity is 1.0 when the prediction changes and 0.0
  otherwise; sparsity is ``1 - len(subgraph[0]) / len(original_graph[0])``.

  Args:
    dataset: dataset name used to locate the embedding pickles and to load
      the graph data.
    learning_rate: learning rate passed to ``torch.optim.Adam``.

  Returns:
    None. Results are printed to stdout.

  Raises:
    RuntimeError: if the test indices of the re-loaded dataset differ from
      the ones used for training/evaluation.
  """
  DATASET = dataset

  # ==== 1. Load precomputed graph-level features and labels ====
  # NOTE(review): pickle.load can execute arbitrary code; only load files
  # produced by this project.
  with open('embeddings/X_{}.pickle'.format(DATASET), 'rb') as f:
    X = pickle.load(f)

  with open('embeddings/Y_{}.pickle'.format(DATASET), 'rb') as f:
    y = pickle.load(f)

  X = torch.tensor(X, dtype=torch.float)
  # Shift labels so that the smallest label becomes 0.
  min_label = min(y)
  if min_label > 0:
    y = [label - min_label for label in y]

  y = torch.tensor(y, dtype=torch.long)
  num_features = len(X[0])            # number of features per graph
  num_classes = get_num_classes(y)    # number of distinct labels

  from generate_embeddings import load_dataset
  data = load_dataset(DATASET)

  train_indices = list(data.train_graphs)
  val_indices = list(data.val_graphs)
  test_indices = list(data.test_graphs)

  # ==== 2. Build train/val/test subsets from the dataset's own split ====
  full_dataset = TensorDataset(X, y)
  train_dataset = Subset(full_dataset, train_indices)
  val_dataset = Subset(full_dataset, val_indices)
  test_dataset = Subset(full_dataset, test_indices)

  train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
  val_loader = DataLoader(val_dataset, batch_size=64)
  test_loader = DataLoader(test_dataset, batch_size=64)

  # ==== 3. Model setup ====
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  model_args = SimpleNamespace(
    device=device,
    mlp_feature_dim=128,
    mlp_hidden=[64],      # Hidden layer dimensions
    readout="mean",       # Readout function
    dropout=0.5,          # Dropout rate
  )

  model = MLPNet(input_dim=num_features, output_dim=num_classes, model_args=model_args).to(device)
  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
  criterion = nn.CrossEntropyLoss()

  # ==== 4. Training and evaluation helpers ====
  def make_batch(xb):
    # The model is called with an object exposing ``x`` and ``batch``; every
    # row of the batch gets its own index.
    batch_indices = torch.arange(xb.size(0), device=device)
    return SimpleNamespace(x=xb, batch=batch_indices)

  def train():
    """Run one epoch over the training loader; return the mean loss."""
    model.train()
    total_loss = 0
    for xb, yb in train_loader:
      xb, yb = xb.to(device), yb.to(device)
      optimizer.zero_grad()
      logits, _, _ = model(make_batch(xb))
      loss = criterion(logits, yb)
      loss.backward()
      optimizer.step()
      total_loss += loss.item() * xb.size(0)
    return total_loss / len(train_loader.dataset)

  @torch.no_grad()
  def evaluate_tmp(loader):
    """Return a list with one tensor of predicted classes per batch."""
    predictions = []
    model.eval()
    for xb, yb in loader:
      xb, yb = xb.to(device), yb.to(device)
      logits, _, _ = model(make_batch(xb))
      predictions.append(logits.argmax(dim=1))
    return predictions

  @torch.no_grad()
  def evaluate(loader):
    """Return the accuracy of the model on ``loader``."""
    model.eval()
    correct = 0
    for xb, yb in loader:
      xb, yb = xb.to(device), yb.to(device)
      logits, _, _ = model(make_batch(xb))
      pred = logits.argmax(dim=1)
      correct += (pred == yb).sum().item()
    return correct / len(loader.dataset)

  # ==== 5. Training loop with validation-based model selection ====
  best_val_acc = 0.0
  best_model_state = None
  patience_counter = 0
  max_epoch = 800
  early_stopping = 100

  for epoch in range(1, max_epoch + 1):
    train_loss = train()
    val_acc = evaluate(val_loader)

    if val_acc > best_val_acc:
      best_val_acc = val_acc
      # state_dict() returns references to the live parameter tensors, which
      # later optimizer steps overwrite in place; deep-copy to keep a real
      # snapshot of the best epoch.
      best_model_state = copy.deepcopy(model.state_dict())
      patience_counter = 0
    else:
      patience_counter += 1

    if patience_counter >= early_stopping:
      break

  # ==== 6. Restore the best model and evaluate on the test set ====
  # If validation accuracy never exceeded 0 there is no snapshot; keep the
  # final weights instead of failing on load_state_dict(None).
  if best_model_state is not None:
    model.load_state_dict(best_model_state)
  test_acc = evaluate(test_loader)

  # ==== 7. LIME explanation ====
  import lime
  import lime.lime_tabular
  import numpy as np

  X_test_sample = X[test_indices].cpu().numpy()
  X_train_sample = X[train_indices].cpu().numpy()

  def model_predict_lime(x_numpy):
    """Prediction function handed to LIME; returns the model's second output."""
    x_tensor = torch.tensor(x_numpy, dtype=torch.float, device=device)
    model.eval()
    with torch.no_grad():
      _, probs, _ = model(make_batch(x_tensor))
    return probs.cpu().numpy()

  lime_explainer = lime.lime_tabular.LimeTabularExplainer(
    X_train_sample,
    mode='classification',
    discretize_continuous=True,
    random_state=42
  )

  all_feature_importance = []
  for i in range(len(X_test_sample)):
    explanation = lime_explainer.explain_instance(
      X_test_sample[i],
      model_predict_lime,
      num_features=X_test_sample.shape[1],
      num_samples=1000
    )

    # Absolute LIME weight per feature, taken from the explanation of label 1
    # (the original code assumes binary classification here).
    feature_importance = np.zeros(X_test_sample.shape[1])
    for feat_idx, importance in explanation.as_map()[1]:
      feature_importance[feat_idx] = abs(importance)

    all_feature_importance.append(feature_importance)

  mean_abs_lime = np.array(all_feature_importance)
  print("mean_abs_lime shape:", mean_abs_lime.shape)
  # Per test instance: feature indices sorted by decreasing importance.
  important_features_per_instance = np.argsort(mean_abs_lime, axis=1)[:, ::-1]

  # ==== 8. Sparsity / fidelity of the minimized subgraphs ====
  if DATASET == 'BBBP':
    data = load_BBBP()
  elif DATASET == 'BACE':
    data = load_BACE()
  else:
    data = load_Data(DATASET)

  fidelity = []
  sparsity = []
  _start = time.time()
  accumulated = []

  # Compare as sets on both sides so the check does not depend on the
  # container type of data.test_graphs.
  if set(test_indices) != set(data.test_graphs):
    print(test_indices)
    print(data.test_graphs)
    # Raising a plain string is itself a TypeError in Python 3.
    raise RuntimeError("Something changed in the test indices")

  # The program list does not change between test graphs; load it once.
  with open('embeddings/gdl_programs.pickle', 'rb') as f:
    gdl_programs = pickle.load(f)

  for i, test_graph_idx in enumerate(test_indices):
    important_feature_list = important_features_per_instance[i][:K]
    minimized_subgraph = minimize(test_graph_idx, data, important_feature_list)

    # Re-embed the minimized subgraph: one 0/1 feature per GDL program.
    X_ = [[
      1.0 if eval_GDL_program_DFS(pgm, minimized_subgraph, data) else 0.0
      for pgm in gdl_programs
    ]]
    y_ = [data.graph_to_label[test_graph_idx]]

    X_ = torch.tensor(X_, dtype=torch.float)
    y_ = torch.tensor(y_, dtype=torch.long)

    tmp_dataset = TensorDataset(X_, y_)
    tmp_data_loader = DataLoader(tmp_dataset, batch_size=64)
    predictions_new = evaluate_tmp(tmp_data_loader)

    X_original = torch.tensor([X[test_graph_idx].tolist()], dtype=torch.float)
    y_original = torch.tensor([data.graph_to_label[test_graph_idx]], dtype=torch.long)

    original_dataset = TensorDataset(X_original, y_original)
    original_loader = DataLoader(original_dataset, batch_size=64)
    predictions_original = evaluate_tmp(original_loader)

    # Each loader holds exactly one sample, so the first batch is a
    # one-element tensor.
    if predictions_new[0].item() == predictions_original[0].item():
      fidelity.append(0.0)
    else:
      fidelity.append(1.0)

    original_graph_len = len(data.graphs[test_graph_idx][0])
    subgraph_len = len(minimized_subgraph[0])
    sparsity.append(1 - (subgraph_len / original_graph_len))
    taken = time.time() - _start
    accumulated.append(taken)

  avg_fidelity = sum(fidelity) / len(fidelity)
  avg_sparsity = sum(sparsity) / len(sparsity)
  print()
  print()
  print("Avg Sparsity : {}".format(avg_sparsity))
  print("Avg Fidelity : {}".format(avg_fidelity))


#'''
def process_tmp(dataset, p):
  """Pick ``k`` from the training-set size and build temporary embeddings.

  The graph data is loaded for the module-level ``NAME`` (not for the
  ``dataset`` argument); ``k`` is ``int(len(data.train_graphs) * p)``. The
  chosen ``k`` is printed and forwarded, together with ``dataset``, to
  ``generate_embeddings_tmp``.
  """
  if NAME == 'BACE':
    data = load_BACE()
  elif NAME == 'BBBP':
    data = load_BBBP()
  else:
    data = load_Data(NAME)

  num_train = len(data.train_graphs)
  k = int(num_train * p)
  print(f"Chosen k : {k}")
  generate_embeddings_tmp(dataset, k)
#'''

def generate_subgraph(data, graph_idx, subgraph):
  """Write ``subgraph`` to ``datasets/Sparsity_Fidelity`` as a one-graph dataset.

  Nodes of the subgraph are renumbered 0..n-1 in the order they appear in
  ``subgraph[0]``. The directory is recreated from scratch on every call and
  receives the following files (all prefixed ``Sparsity_Fidelity_``):
  ``A.txt`` (1-based edges, both directions), ``node_attributes.txt``,
  ``edge_attributes.txt`` (only when edge attributes are non-empty; each row
  is written twice to match the two directed lines per edge),
  ``graph_labels.txt``, ``graph_indicator.txt`` (one ``1`` per node) and
  ``train.txt`` / ``val.txt`` / ``test.txt`` (each containing ``0``).
  Finally ``datasets/{NAME}/learned_GDL_programs`` is copied into the
  directory if it exists.

  Args:
    data: object exposing ``X_node``, ``X_edge``, ``A`` and ``graph_to_label``.
    graph_idx: key into ``data.graph_to_label`` for the graph's label.
    subgraph: indexable whose element 0 holds node ids and element 1 holds
      edge indices into ``data.A`` / ``data.X_edge``.

  Raises:
    KeyError: if an edge references a node that is not in ``subgraph[0]``.
  """
  # Local imports: this file does not import os/shutil explicitly.
  import os
  import shutil

  # Renumber nodes and collect attributes before touching the disk, so a bad
  # subgraph does not leave a half-written directory behind.
  node_map = {node: i for i, node in enumerate(subgraph[0])}
  new_node_attribute = [data.X_node[node] for node in subgraph[0]]
  new_graph_indicator = [1] * len(new_node_attribute)

  new_A = []
  new_edge_attribute = []
  for edge in subgraph[1]:
    from_node = node_map[data.A[edge][0]]
    to_node = node_map[data.A[edge][1]]
    new_A.append((from_node, to_node))
    new_edge_attribute.append(data.X_edge[edge])
  new_graph_label = [data.graph_to_label[graph_idx]]

  path = 'datasets/Sparsity_Fidelity'
  prefix = '{}/Sparsity_Fidelity'.format(path)

  # Portable replacement for `rm -r` + `mkdir`; a missing directory is fine.
  shutil.rmtree(path, ignore_errors=True)
  os.makedirs(path, exist_ok=True)

  def format_row(features):
    return ", ".join(f"{val}" for val in features) + "\n"

  with open('{}_A.txt'.format(prefix), 'w') as f:
    for from_node, to_node in new_A:
      f.write(f"{from_node+1}, {to_node+1}\n")
      f.write(f"{to_node+1}, {from_node+1}\n")

  with open('{}_node_attributes.txt'.format(prefix), 'w') as f:
    for features in new_node_attribute:
      f.write(format_row(features))

  # Guard against an empty edge-attribute table before looking at row 0.
  if len(data.X_edge) > 0 and len(data.X_edge[0]) > 0:
    with open('{}_edge_attributes.txt'.format(prefix), 'w') as f:
      for features in new_edge_attribute:
        row = format_row(features)
        # One row per directed edge line written to A.txt.
        f.write(row)
        f.write(row)

  with open('{}_graph_labels.txt'.format(prefix), 'w') as f:
    for label in new_graph_label:
      f.write(f"{label}\n")

  with open('{}_graph_indicator.txt'.format(prefix), 'w') as f:
    for label in new_graph_indicator:
      f.write(f"{label}\n")

  for split in ('train', 'val', 'test'):
    with open('{}_{}.txt'.format(prefix, split), 'w') as f:
      f.write("0")

  # Best-effort copy of the learned programs, mirroring the original
  # `rm -r` / `cp -r` pair whose failures were ignored.
  src = f"datasets/{NAME}/learned_GDL_programs"
  dst = f"{path}/learned_GDL_programs"
  shutil.rmtree(dst, ignore_errors=True)
  if os.path.isdir(src):
    shutil.copytree(src, dst)

def setGlobal(name, portion, topk):
  """Store the run configuration in the module-level globals.

  ``name`` goes to ``NAME``, ``portion`` to ``PORTION`` and ``topk`` to ``K``.
  """
  global NAME, PORTION, K
  NAME, PORTION, K = name, portion, topk

def filter_one(gdl_programs, test_graph_idx, data):
  """Return the set of programs that evaluate truthy on one graph.

  Each program in ``gdl_programs`` is passed to ``eval_GDL_program_DFS``
  together with ``data.graphs[test_graph_idx]``; only programs for which the
  call returns a truthy value are kept.
  """
  return {
    pgm
    for pgm in gdl_programs
    if eval_GDL_program_DFS(pgm, data.graphs[test_graph_idx], data)
  }


def remove_node(subgraph, gdl_programs, data):
  """Try to drop a single node from ``subgraph`` without invalidating it.

  Nodes are tried in the order of ``subgraph[0]``. For each one, a deep copy
  of the subgraph is made with that node and all of its incident edges
  removed. The first copy that is still connected (per
  ``util.connected.graph_is_connected``) and on which every program in
  ``gdl_programs`` still evaluates truthy is returned.

  Args:
    subgraph: indexable whose element 0 is a list of node ids and element 1
      a list of edge indices into ``data.A``.
    gdl_programs: iterable of programs that must keep holding.
    data: dataset object passed through to the helpers; ``data.A[edge_idx]``
      must support ``in`` for node membership.

  Returns:
    A new, smaller subgraph, or the ``subgraph`` argument itself when no
    node can be removed (fixed point). The input is never mutated.
  """
  for node in subgraph[0]:
    candidate = copy.deepcopy(subgraph)

    # Remove the node and every edge that touches it. Iterate the original
    # edge list so removal from the copy does not disturb the iteration.
    candidate[0].remove(node)
    for edge_idx in subgraph[1]:
      if node in data.A[edge_idx]:
        candidate[1].remove(edge_idx)

    # Imported lazily, as in the original, so the project module is only
    # needed once there is a candidate to check.
    from util.connected import graph_is_connected
    if not graph_is_connected(candidate, data):
      continue

    # all() stops at the first failing program instead of evaluating the rest.
    if all(eval_GDL_program_DFS_Slow(program, candidate, data)
           for program in gdl_programs):
      return candidate

  # Fixed point: nothing could be removed.
  return subgraph


def minimize(test_graph_idx, data, important_feature_list):
  """Shrink one graph node by node while the selected programs keep holding.

  The programs at positions ``important_feature_list`` are read from
  ``embeddings/gdl_programs.pickle`` and narrowed by ``filter_one`` to those
  that evaluate truthy on the full graph. Starting from a deep copy of
  ``data.graphs[test_graph_idx]``, ``remove_node`` is applied repeatedly
  until it no longer returns a subgraph with fewer nodes.

  Returns:
    The minimized subgraph (same structure as an entry of ``data.graphs``).
  """
  with open('embeddings/gdl_programs.pickle', 'rb') as f:
    all_programs = pickle.load(f)

  selected_programs = [all_programs[idx] for idx in important_feature_list]
  active_programs = filter_one(selected_programs, test_graph_idx, data)

  current = copy.deepcopy(data.graphs[test_graph_idx])
  while True:
    candidate = remove_node(current, active_programs, data)
    if len(candidate[0]) >= len(current[0]):
      # No node could be removed: fixed point reached.
      break
    current = copy.deepcopy(candidate)
  return current



#def process(dataset, k):
def process(dataset, p):
  """Derive ``k`` from the training-set size and generate embeddings.

  Loads ``dataset`` (with dedicated loaders for ``'BBBP'`` and ``'BACE'``),
  computes ``k = int(len(train_graphs) * p)`` clamped to ``[1, total]``,
  prints it and calls ``generate_embeddings(dataset, k)``.

  Raises:
    ValueError: if the dataset has no training graphs.
  """
  if dataset == 'BACE':
    data = load_BACE()
  elif dataset == 'BBBP':
    data = load_BBBP()
  else:
    data = load_Data(dataset)

  total = len(data.train_graphs)
  if not total:
    raise ValueError(f"No training graphs found for dataset: {dataset}")

  # Keep k within [1, total].
  k = min(total, int(total * p))
  if k < 1:
    k = 1
  print(f"Chosen k : {k}")

  from generate_embeddings import generate_embeddings
  generate_embeddings(dataset, k)

# Script entry point: parse the CLI options, publish them through setGlobal(),
# generate the embeddings with process(), then train and evaluate with learn().
if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  # Required: without a dataset name, None would be passed to the loaders and
  # fail far from the actual cause.
  parser.add_argument('-d', '--dataset', required=True, help="input dataset")
  parser.add_argument('-lr', '--learning_rate', type=float, default=0.01, help="learning rate (default: 0.01)")
  parser.add_argument('-k', '--k', type=int, default=10, help="number of top k")
  parser.add_argument('-p', '--p', type=float, default=1.0, help="portion")
  args = parser.parse_args()
  dataset = args.dataset
  setGlobal(dataset, args.p, args.k)
  learning_rate = args.learning_rate
  process(dataset, args.p)
  learn(dataset, learning_rate=learning_rate)


