import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch_geometric.data import Data as GeoData
from torch_geometric.nn import GCNConv
from sklearn.neighbors import NearestNeighbors
from transformers import BertTokenizer, BertModel
from tqdm import tqdm
import os
import time
import random
import networkx as nx
import yaml
from pathlib import Path

# Load configuration
def load_config(config_path="config.yaml"):
    """Load the run configuration from a YAML file.

    Args:
        config_path: Path of the YAML file to read. Defaults to
            ``"config.yaml"``, resolved against the current working directory.

    Returns:
        The object produced by ``yaml.safe_load`` for the file's content.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # locale default encoding.
    with open(config_path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)

# Parse the default config file once; the rest of the script reads its
# settings (paths, hyper-parameters, device switch) from this object.
config = load_config()

# Seed
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and pin cuDNN to deterministic mode.

    Args:
        seed: Value handed to every seeding function. Defaults to 42.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Ask cuDNN for deterministic kernels and turn off its autotuner.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # Stored in the environment as a string.
    # NOTE(review): this presumably only affects interpreters started after
    # this point — confirm if hash reproducibility matters for this process.
    os.environ['PYTHONHASHSEED'] = str(seed)

# Fix all RNG seeds before any data is read or any weight is initialised
set_seed(config['random_seed'])

# Load data
data_paths = config['data']
train_df = pd.read_csv(data_paths['train_data'])
test_df = pd.read_csv(data_paths['test_data'])
# First CSV column becomes the index; shape: [num_questions, num_models]
model_responses = pd.read_csv(data_paths['model_responses'], index_col=0)
# includes 'question_id', 'model_name', 'cost'
model_costs = pd.read_csv(data_paths['model_costs'])

# BERT
tokenizer = BertTokenizer.from_pretrained(config['bert']['model_name'])
# Pick the encoder's device with the same rule the GNN uses further down
# (CUDA only when it is available AND enabled in the config), so the script
# also runs on CPU-only hosts instead of failing on an unconditional .cuda().
bert_device = torch.device(
    'cuda' if torch.cuda.is_available() and config['device']['use_cuda'] else 'cpu'
)
bert = BertModel.from_pretrained(config['bert']['model_name']).eval().to(bert_device)

def get_bert_embeddings(questions):
    """Encode each question with BERT and return its first-token embedding.

    Questions are tokenized and encoded one at a time under ``torch.no_grad()``
    using the module-level ``tokenizer`` and ``bert``; tokenizer options
    (truncation, max_length, padding) come from ``config['bert']``.

    Args:
        questions: Iterable of question texts.
            NOTE(review): each item is assumed to be a single string, so the
            encoder output has a batch dimension of 1 — confirm with callers.

    Returns:
        A NumPy array with one row per question, holding the hidden state at
        sequence position 0 of the last layer.

    Raises:
        ValueError: If ``questions`` is empty (``np.stack`` needs at least one
            array).
    """
    bert_cfg = config['bert']
    all_embeds = []
    with torch.no_grad():
        for q in tqdm(questions, desc="BERT Embedding"):
            inputs = tokenizer(q, return_tensors='pt',
                               truncation=bert_cfg['truncation'],
                               max_length=bert_cfg['max_length'],
                               padding=bert_cfg['padding'])
            # Inputs must live on the same device as the encoder.
            inputs = {name: tensor.to(bert_device) for name, tensor in inputs.items()}
            outputs = bert(**inputs)
            # Position 0 of the sequence; squeeze drops the batch dimension.
            emb = outputs.last_hidden_state[:, 0, :].squeeze(0).cpu().numpy()
            all_embeds.append(emb)
    return np.stack(all_embeds)

def get_or_load_bert_embeddings(questions, path):
    """Return embeddings for ``questions``, using ``path`` as an on-disk cache.

    If ``path`` exists and its array has one row per question, the cached array
    is returned. Otherwise the embeddings are computed with
    ``get_bert_embeddings`` and written to ``path``.

    Args:
        questions: Question texts to embed.
        path: Location of the ``.npy`` cache file.

    Returns:
        NumPy array of embeddings, one row per question.
    """
    # Only sized inputs can be checked against the cache; anything else keeps
    # the previous "trust the cache" behaviour.
    expected_rows = len(questions) if hasattr(questions, '__len__') else None

    if os.path.exists(path):
        cached = np.load(path)
        if expected_rows is None or cached.shape[0] == expected_rows:
            return cached
        # A cache from a different dataset would silently misalign rows with
        # question ids downstream, so rebuild it instead of returning it.
        print(f"Cached embeddings at {path} have {cached.shape[0]} rows but "
              f"{expected_rows} questions were given; recomputing.")

    emb = get_bert_embeddings(questions)

    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    # Write through an open handle: given a file name, np.save appends ".npy"
    # when it is missing, and the exists() check above would then never find
    # the cache again.
    with open(path, 'wb') as cache_file:
        np.save(cache_file, emb)
    return emb

# Embed the train and test questions, reusing cached arrays when present
embed_paths = config['embeddings']
train_questions = train_df['question'].tolist()
train_embeds = get_or_load_bert_embeddings(train_questions, embed_paths['train_embeds'])
test_questions = test_df['question'].tolist()
test_embeds = get_or_load_bert_embeddings(test_questions, embed_paths['test_embeds'])

# KNN Graph
def build_knn_graph(embeds, k=5):
    """Build a directed k-nearest-neighbour graph over the rows of ``embeds``.

    Every point ``i`` gets an edge ``i -> j`` to each of its (up to) ``k``
    nearest other points, using the metric named in
    ``config['graph']['metric']``.

    Args:
        embeds: 2-D array with one row per node.
        k: Number of neighbours per node. Defaults to 5. Fewer edges are
            produced per node when the dataset has fewer than ``k + 1`` rows.

    Returns:
        Tuple ``(edge_index, edge_weight)``: a ``torch.long`` tensor of shape
        ``[2, num_edges]`` (row 0 = source, row 1 = target) and a
        ``torch.float`` tensor of shape ``[num_edges]`` holding
        ``1 - distance`` for each edge.
    """
    num_points = embeds.shape[0]
    # One extra neighbour leaves room for the point itself; clamp because
    # kneighbors rejects n_neighbors larger than the number of samples.
    n_query = min(k + 1, num_points)
    nbrs = NearestNeighbors(n_neighbors=n_query, metric=config['graph']['metric']).fit(embeds)
    distances, indices = nbrs.kneighbors(embeds)

    sources, targets, weights = [], [], []
    for i in range(num_points):
        kept = 0
        for j, dist in zip(indices[i], distances[i]):
            if kept == k:
                break
            # Drop the point itself by identity rather than by position: with
            # duplicate embeddings the self-match is not guaranteed to be the
            # first hit, and skipping position 0 would discard a real
            # neighbour while keeping a self-loop.
            if j == i:
                continue
            sources.append(i)
            targets.append(int(j))
            # 1 - distance; this is the cosine similarity when the configured
            # metric is cosine distance.
            weights.append(1.0 - float(dist))
            kept += 1

    # Building from two parallel lists keeps the [2, E] shape even when E == 0.
    edge_index = torch.tensor([sources, targets], dtype=torch.long)
    edge_weight = torch.tensor(weights, dtype=torch.float)
    return edge_index, edge_weight

# kNN graph over the training-question embeddings; neighbour count comes from config
edge_index, edge_weight = build_knn_graph(train_embeds, k=config['graph']['k_neighbors'])

# Convert PyG edge_index to NetworkX graph
def pyg_to_networkx(edge_index, num_nodes):
    """Create an undirected NetworkX graph from a PyG-style edge index.

    Args:
        edge_index: Tensor whose transpose yields one ``(source, target)``
            pair per row.
        num_nodes: Number of nodes; nodes ``0 .. num_nodes - 1`` are added even
            if they have no edge.

    Returns:
        An ``nx.Graph`` containing those nodes and edges. Because the graph is
        undirected, ``i -> j`` and ``j -> i`` become a single edge.
    """
    graph = nx.Graph()
    graph.add_nodes_from(range(num_nodes))
    # [2, E] on any device -> E rows of (source, target) on the host
    node_pairs = edge_index.cpu().numpy().T
    graph.add_edges_from(node_pairs)
    return graph

# Undirected NetworkX view of the training kNN graph, used only for reporting
G = pyg_to_networkx(edge_index, train_embeds.shape[0])

# Size and centrality statistics of the graph
num_nodes = G.number_of_nodes()
num_edges = G.number_of_edges()
degree_centrality = nx.degree_centrality(G)
betweenness_centrality = nx.betweenness_centrality(G)
closeness_centrality = nx.closeness_centrality(G)

# Write a short text report: graph size, then the first ten entries of each
# centrality dict under its own heading.
graph_metrics_path = config['output']['graph_metrics']
centrality_sections = (
    ("Degree centrality (first 10 nodes):\n", degree_centrality),
    ("Betweenness centrality (first 10 nodes):\n", betweenness_centrality),
    ("Closeness centrality (first 10 nodes):\n", closeness_centrality),
)
with open(graph_metrics_path, "w") as f:
    f.write(f"Number of nodes: {num_nodes}\n")
    f.write(f"Number of edges: {num_edges}\n")
    for heading, scores in centrality_sections:
        f.write(heading)
        for node, val in list(scores.items())[:10]:
            f.write(f"  Node {node}: {val:.4f}\n")

print(f"Graph metrics saved to {graph_metrics_path}")

# Select the response rows of the training questions, in train_df order
train_question_ids = train_df['question_id']
train_model_responses = model_responses.loc[train_question_ids].values  # [num_train_questions, num_models]
responses = torch.tensor(train_model_responses, dtype=torch.float)  # [num_train_questions, num_models]

# Bundle node features, graph structure and targets into one PyG object
node_features = torch.tensor(train_embeds, dtype=torch.float)
data = GeoData(
    x=node_features,
    edge_index=edge_index,
    edge_weight=edge_weight,
    responses=responses,  # train questions only
)

# GNN + IRT Model
class LLMIRTGNN(nn.Module):
    """Two-layer GCN encoder with an IRT-style response head.

    Node features pass through two ``GCNConv`` layers. The resulting hidden
    state is projected to a per-node vector ``a`` and scalar ``b``, and
    combined with a learned per-model matrix ``theta`` as
    ``sigmoid(a @ theta.T - b)``.

    Args:
        input_dim: Size of the input node features.
        hidden_dim: Width of both GCN layers.
        num_models: Number of rows in ``theta`` (one per model).
        theta_dim: Dimension shared by ``a`` and ``theta``. Defaults to 16.
        dropout: Dropout probability applied after the first GCN layer.
            Defaults to 0.3.
    """

    def __init__(self, input_dim, hidden_dim, num_models, theta_dim=16, dropout=0.3):
        super().__init__()
        # Graph encoder
        self.gcn1 = GCNConv(input_dim, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

        # Response head: per-model parameters plus projections of the node state
        self.W_theta = nn.Parameter(torch.randn(num_models, theta_dim))
        self.W_a = nn.Linear(hidden_dim, theta_dim)
        self.W_b = nn.Linear(hidden_dim, 1)

    def forward(self, x, edge_index, edge_weight=None):
        """Compute response probabilities and the parameters behind them.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity passed to both GCN layers.
            edge_weight: Optional per-edge weights passed to both GCN layers.

        Returns:
            Tuple ``(prob, a_i, b_i, theta)`` with shapes ``[N, M]``,
            ``[N, D]``, ``[N]`` and ``[M, D]``.
        """
        hidden = self.gcn1(x, edge_index, edge_weight=edge_weight)
        hidden = self.dropout(self.relu(hidden))
        hidden = self.gcn2(hidden, edge_index, edge_weight=edge_weight)
        hidden = self.relu(hidden)

        discrimination = self.W_a(hidden)           # [N, D]
        difficulty = self.W_b(hidden).squeeze(-1)   # [N]
        ability = self.W_theta                      # [M, D]

        # [N, D] @ [D, M] minus a per-node offset broadcast over the M models
        logits = discrimination @ ability.T - difficulty[:, None]
        return torch.sigmoid(logits), discrimination, difficulty, ability

# CUDA is used only when it is both available and enabled in the config
use_gpu = torch.cuda.is_available() and config['device']['use_cuda']
device = torch.device('cuda' if use_gpu else 'cpu')

# Model dimensions follow the data: feature width and number of response columns
model_cfg = config['model']
model = LLMIRTGNN(
    input_dim=train_embeds.shape[1],
    hidden_dim=model_cfg['hidden_dim'],
    num_models=responses.shape[1],
    theta_dim=model_cfg['theta_dim'],
    dropout=model_cfg['dropout'],
).to(device)

train_cfg = config['training']
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=train_cfg['learning_rate'],
    weight_decay=train_cfg['weight_decay'],
)
# The model returns probabilities (post-sigmoid), hence BCELoss
loss_fn = nn.BCELoss()

# Training
epochs = config['training']['epochs']

# The graph and targets never change between epochs, so transfer them to the
# device once instead of on every iteration.
train_x = data.x.to(device)
train_edge_index = data.edge_index.to(device)
train_edge_weight = data.edge_weight.to(device)
target = data.responses.to(device)

model.train()
for epoch in range(epochs):
    # One full-graph gradient step per epoch
    optimizer.zero_grad()
    prob, _, _, _ = model(train_x, train_edge_index, train_edge_weight)
    loss = loss_fn(prob, target)
    loss.backward()
    optimizer.step()
    # Report the loss every tenth epoch
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1} Loss: {loss.item():.4f}")

# Inference on test
def predict_on_test(test_embeds, train_embeds, model, k=5):
    """Predict per-model response probabilities for unseen questions.

    For every test embedding a small star graph is built: the test node
    (index 0) is connected in both directions to its ``k`` nearest training
    embeddings (indices 1..k), with edge weight ``1 - distance`` under
    ``config['inference']['metric']``. The model is run on that graph and the
    output row of node 0 is kept.

    Args:
        test_embeds: 2-D array, one row per test question.
        train_embeds: 2-D NumPy array, one row per training question.
        model: Trained model returning ``(prob, a, b, theta)``.
        k: Number of training neighbours per test question. Defaults to 5.

    Returns:
        Tuple ``(preds, theta)``: ``preds`` stacks one probability row per
        test question; ``theta`` is the model's theta tensor as a NumPy array.

    Raises:
        ValueError: If ``test_embeds`` has no rows.
    """
    test_embeds = np.asarray(test_embeds)
    if test_embeds.shape[0] == 0:
        raise ValueError("test_embeds is empty; there is nothing to predict")

    model.eval()
    nbrs = NearestNeighbors(n_neighbors=k, metric=config['inference']['metric']).fit(train_embeds)
    # One batched query instead of one kneighbors call per test row
    all_dists, all_idxs = nbrs.kneighbors(test_embeds)

    # The star topology is the same for every test row, so build it once:
    # (0, 1), (1, 0), (0, 2), (2, 0), ...
    star_edges = [pair for i in range(1, k + 1) for pair in ((0, i), (i, 0))]
    edge_index = torch.tensor(star_edges, dtype=torch.long).t().contiguous().to(device)

    preds = []
    theta = None
    with torch.no_grad():
        for emb, dists, idxs in zip(test_embeds, all_dists, all_idxs):
            # Row 0 is the test question, rows 1..k its training neighbours.
            # Cast explicitly to float32 so cached float64 embeddings do not
            # clash with the model's parameter dtype.
            nodes = np.vstack([emb[None, :], train_embeds[idxs]])
            x = torch.as_tensor(nodes, dtype=torch.float).to(device)
            # Each neighbour contributes two directed edges with equal weight
            weights = np.repeat(1.0 - dists, 2)
            edge_weight = torch.as_tensor(weights, dtype=torch.float).to(device)

            prob, _, _, theta = model(x, edge_index, edge_weight=edge_weight)
            preds.append(prob[0].cpu().numpy())  # prediction for test node
    return np.stack(preds), theta.detach().cpu().numpy()

test_preds, theta_final = predict_on_test(test_embeds, train_embeds, model, k=config['inference']['k_neighbors'])

# One column of predicted probabilities per model, plus the question id
test_result_df = pd.DataFrame(test_preds, columns=model_responses.columns)
test_result_df['question_id'] = test_df['question_id']

# Output path templates are filled in from the config
output_paths = config['output']
trained_epochs = config['training']['epochs']

predictions_path = output_paths['predictions'].format(theta_dim=config['model']['theta_dim'], epochs=trained_epochs)
test_result_df.to_csv(predictions_path, index=False)

# Re-run the trained model on the training graph (eval mode, no gradients)
# to read out a, b and theta
model.eval()
with torch.no_grad():
    _, a_train, b_train, theta = model(data.x.to(device), data.edge_index.to(device), data.edge_weight.to(device))

train_ids = train_df['question_id'].values

# a: one row of theta_dim values per training question
a_train_path = output_paths['a_train'].format(epochs=trained_epochs)
a_df = pd.DataFrame(a_train.cpu().numpy())
a_df['question_id'] = train_ids
a_df.to_csv(a_train_path, index=False)

# b: one scalar per training question
b_train_path = output_paths['b_train'].format(epochs=trained_epochs)
b_df = pd.DataFrame({'b': b_train.cpu().numpy(), 'question_id': train_ids})
b_df.to_csv(b_train_path, index=False)

# theta: one row per model, labelled with the response-table column names
theta_path = output_paths['theta'].format(epochs=trained_epochs)
theta_df = pd.DataFrame(theta.detach().cpu().numpy())
theta_df['model_name'] = model_responses.columns
theta_df.to_csv(theta_path, index=False)

# Compute a and b for test data
def get_a_b_for_test(test_embeds, train_embeds, model, k=3):
    """Compute the model's ``a`` and ``b`` outputs for unseen questions.

    For every test embedding a small star graph is built: the test node
    (index 0) is connected in both directions to its ``k`` nearest training
    embeddings (indices 1..k), with edge weight ``1 - distance`` under
    ``config['inference']['metric']``. The model is run on that graph and the
    ``a`` and ``b`` values of node 0 are kept.

    Args:
        test_embeds: 2-D array, one row per test question.
        train_embeds: 2-D NumPy array, one row per training question.
        model: Trained model returning ``(prob, a, b, theta)``.
        k: Number of training neighbours per test question. Defaults to 3.

    Returns:
        Tuple ``(a, b)``: ``a`` stacks one vector per test question and ``b``
        is a 1-D array with one scalar per test question.

    Raises:
        ValueError: If ``test_embeds`` has no rows.
    """
    test_embeds = np.asarray(test_embeds)
    if test_embeds.shape[0] == 0:
        raise ValueError("test_embeds is empty; there is nothing to compute")

    model.eval()
    nbrs = NearestNeighbors(n_neighbors=k, metric=config['inference']['metric']).fit(train_embeds)
    # One batched query instead of one kneighbors call per test row
    all_dists, all_idxs = nbrs.kneighbors(test_embeds)

    # The star topology is the same for every test row, so build it once:
    # (0, 1), (1, 0), (0, 2), (2, 0), ...
    star_edges = [pair for i in range(1, k + 1) for pair in ((0, i), (i, 0))]
    edge_index = torch.tensor(star_edges, dtype=torch.long).t().contiguous().to(device)

    a_list = []
    b_list = []
    with torch.no_grad():
        for emb, dists, idxs in zip(test_embeds, all_dists, all_idxs):
            # Row 0 is the test question, rows 1..k its training neighbours.
            # Cast explicitly to float32 so cached float64 embeddings do not
            # clash with the model's parameter dtype.
            nodes = np.vstack([emb[None, :], train_embeds[idxs]])
            x = torch.as_tensor(nodes, dtype=torch.float).to(device)
            # Each neighbour contributes two directed edges with equal weight
            weights = np.repeat(1.0 - dists, 2)
            edge_weight = torch.as_tensor(weights, dtype=torch.float).to(device)

            _, a_i, b_i, _ = model(x, edge_index, edge_weight=edge_weight)
            a_list.append(a_i[0].cpu().numpy())  # first node is the test question
            b_list.append(b_i[0].cpu().item())
    return np.stack(a_list), np.array(b_list)

a_test, b_test = get_a_b_for_test(test_embeds, train_embeds, model, k=config['inference']['k_neighbors'])

# Write a and b for the test questions, keyed by question id
test_ids = test_df['question_id'].values
test_output_paths = config['output']
n_epochs_done = config['training']['epochs']

a_test_path = test_output_paths['a_test'].format(epochs=n_epochs_done)
a_test_df = pd.DataFrame(a_test)
a_test_df['question_id'] = test_ids
a_test_df.to_csv(a_test_path, index=False)

b_test_path = test_output_paths['b_test'].format(epochs=n_epochs_done)
b_test_df = pd.DataFrame({'b': b_test, 'question_id': test_ids})
b_test_df.to_csv(b_test_path, index=False)

# NOTE(review): the message mentions expected costs, but nothing above writes
# cost data — confirm whether the wording or a missing step is intended.
print("Saved LLM performance predictions and expected costs.")


