import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch_geometric.data import Data as GeoData
from torch_geometric.nn import GCNConv
from sklearn.neighbors import NearestNeighbors
from transformers import BertTokenizer, BertModel
from tqdm import tqdm
import os
import time
import random
import networkx as nx
import yaml
from pathlib import Path

# Load configuration
def load_config(config_path="config.yaml"):
    """Load the experiment configuration from a YAML file.

    Args:
        config_path: Path to the YAML file (str or path-like).

    Returns:
        The parsed top-level mapping.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        ValueError: If the file is empty or its top level is not a mapping.
    """
    # Explicit encoding so parsing does not depend on the platform locale.
    with open(config_path, 'r', encoding='utf-8') as file:
        # safe_load only constructs plain YAML types (no arbitrary objects).
        config = yaml.safe_load(file)
    if not isinstance(config, dict):
        # safe_load returns None for an empty file; fail here with a clear
        # message instead of a TypeError at the first config[...] lookup.
        raise ValueError(
            f"Config file {config_path!r} must contain a YAML mapping, "
            f"got {type(config).__name__}"
        )
    return config

# Read the configuration once at import time; every later step looks its
# settings up in this mapping.
config = load_config(config_path="config.yaml")

# Seed
def set_seed(seed=42):
    """Seed every RNG this script relies on and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU, the current
    CUDA device and all CUDA devices). It then turns off cuDNN autotuning in
    favour of deterministic kernels.

    Note: ``PYTHONHASHSEED`` is only exported for child processes. Hash
    randomisation of the current interpreter is fixed at startup and is not
    changed by setting it here.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed)

set_seed(seed=config['random_seed'])

# Load data:
#   train_df / test_df  - question tables (their 'question_id' and 'question'
#                         columns are read further down);
#   model_responses     - response matrix indexed by its first column, looked
#                         up by question_id below;
#   model_costs         - per-question model costs.
# NOTE(review): the cost-saving step below drops 'question_id' and sums the
# remaining columns. That fits a wide layout (one cost column per model)
# better than a long 'question_id', 'model_name', 'cost' layout. Confirm the
# file format.
train_df, test_df = (
    pd.read_csv(config['data'][split]) for split in ('train_data', 'test_data')
)
model_responses = pd.read_csv(config['data']['model_responses'], index_col=0)
model_costs = pd.read_csv(config['data']['model_costs'])

# BERT encoder used to embed questions. Prefer the GPU, but fall back to CPU
# when CUDA is unavailable instead of crashing at load time. Cached embeddings
# on disk mean the encoder may never actually be run.
tokenizer = BertTokenizer.from_pretrained(config['bert']['model_name'])
bert = BertModel.from_pretrained(config['bert']['model_name']).eval()
bert = bert.to('cuda' if torch.cuda.is_available() else 'cpu')

def get_bert_embeddings(questions):
    """Embed each question with BERT and return its position-0 hidden state.

    Questions are encoded one at a time with the module-level ``tokenizer``
    and ``bert``, using the tokenizer options in ``config['bert']``.

    Args:
        questions: Iterable of question strings.

    Returns:
        ``np.ndarray`` with one row per question, holding the last-layer
        hidden state at token position 0 (the [CLS] token for BERT
        tokenizers). An empty input yields a ``(0, hidden_size)`` array.
    """
    # Send inputs to wherever the encoder's weights live rather than
    # assuming CUDA.
    model_device = next(bert.parameters()).device
    bert_cfg = config['bert']
    all_embeds = []
    with torch.no_grad():
        for q in tqdm(questions, desc="BERT Embedding"):
            inputs = tokenizer(q, return_tensors='pt',
                               truncation=bert_cfg['truncation'],
                               max_length=bert_cfg['max_length'],
                               padding=bert_cfg['padding'])
            inputs = {k: v.to(model_device) for k, v in inputs.items()}
            outputs = bert(**inputs)
            # Take token position 0 and drop the batch dimension of size 1.
            emb = outputs.last_hidden_state[:, 0, :].squeeze(0).cpu().numpy()
            all_embeds.append(emb)
    if not all_embeds:
        # np.stack([]) raises; return a well-shaped empty matrix instead.
        return np.empty((0, bert.config.hidden_size), dtype=np.float32)
    return np.stack(all_embeds)

def get_or_load_bert_embeddings(questions, path):
    """Return BERT embeddings for *questions*, caching them in a ``.npy`` file.

    Args:
        questions: List of question strings, in row order.
        path: Cache file location (str or path-like). ``np.save`` always
            writes a ``.npy`` suffix. The suffix is therefore added here when
            missing; otherwise the cache would be written but never found.

    Returns:
        ``np.ndarray`` with one row per question.

    A cached file whose row count differs from ``len(questions)`` is treated
    as stale (e.g. the question CSV changed), then recomputed and overwritten.
    """
    path = os.fspath(path)
    if not path.endswith('.npy'):
        path += '.npy'

    if os.path.exists(path):
        cached = np.load(path)
        if cached.shape[0] == len(questions):
            return cached

    emb = get_bert_embeddings(questions)
    parent = os.path.dirname(path)
    if parent:
        # np.save does not create missing directories.
        os.makedirs(parent, exist_ok=True)
    np.save(path, emb)
    return emb

# One embedding row per question, in the same order as the CSV rows.
train_embeds = get_or_load_bert_embeddings(
    questions=train_df['question'].tolist(),
    path=config['embeddings']['train_embeds'],
)
test_embeds = get_or_load_bert_embeddings(
    questions=test_df['question'].tolist(),
    path=config['embeddings']['test_embeds'],
)

# KNN Graph
def build_knn_graph(embeds, k=5, metric=None):
    """Build a directed k-nearest-neighbour graph over the rows of *embeds*.

    Every node ``i`` gets an edge ``i -> j`` to each of its ``k`` nearest
    *other* rows ``j``, weighted ``1 - distance(i, j)``.

    Args:
        embeds: 2-D array with one row per node.
        k: Neighbours per node. When the data set has fewer than ``k + 1``
            rows, every node is linked to all ``n - 1`` others.
        metric: scikit-learn distance metric. Defaults to
            ``config['graph']['metric']``.

    Returns:
        ``(edge_index, edge_weight)``: a ``(2, E)`` long tensor of
        ``[source; target]`` rows and an ``(E,)`` float tensor of weights.
        Both are empty when there is nothing to connect.

    NOTE(review): ``1 - distance`` behaves like a similarity only for metrics
    whose distances stay within about [0, 1], such as cosine. With e.g.
    euclidean distances the weights can be negative. Confirm the configured
    metric.
    """
    if metric is None:
        metric = config['graph']['metric']
    num_nodes = embeds.shape[0]
    sources, targets, weights = [], [], []
    if num_nodes > 1 and k > 0:
        # +1 because each query point is normally returned as its own
        # nearest neighbour.
        n_neighbors = min(k + 1, num_nodes)
        nbrs = NearestNeighbors(n_neighbors=n_neighbors, metric=metric).fit(embeds)
        distances, indices = nbrs.kneighbors(embeds)
        for i in range(num_nodes):
            kept = 0
            for j, dist in zip(indices[i], distances[i]):
                # Skip the query point itself wherever it ranks. With
                # duplicate embeddings a zero-distance twin can come first,
                # so blindly dropping column 0 would keep a self-loop.
                if j == i:
                    continue
                if kept == k:
                    break
                sources.append(i)
                targets.append(int(j))
                weights.append(1.0 - float(dist))
                kept += 1
    # Building from two row lists keeps the (2, E) shape even when E == 0.
    edge_index = torch.tensor([sources, targets], dtype=torch.long)
    edge_weight = torch.tensor(weights, dtype=torch.float)
    return edge_index, edge_weight

edge_index, edge_weight = build_knn_graph(
    train_embeds,
    k=config['graph']['k_neighbors'],
)

# Convert PyG edge_index to NetworkX graph
def pyg_to_networkx(edge_index, num_nodes):
    """Convert a PyG ``edge_index`` into an undirected ``networkx.Graph``.

    All ``num_nodes`` nodes are added first, so isolated nodes are kept.
    Each column ``[u, v]`` of ``edge_index`` becomes an undirected edge, so
    ``u -> v`` and ``v -> u`` collapse into a single edge.
    """
    graph = nx.Graph()
    graph.add_nodes_from(range(num_nodes))
    edge_pairs = edge_index.cpu().numpy().T
    graph.add_edges_from(edge_pairs)
    return graph

# networkx view of the kNN graph (appears unused further down this script).
G = pyg_to_networkx(edge_index, num_nodes=train_embeds.shape[0])

# Response matrix for the training questions. .loc reorders the rows of
# model_responses to follow train_df['question_id'], so row i of `responses`
# belongs to row i of train_df. A missing id raises KeyError.
train_question_ids = train_df['question_id']
train_model_responses = model_responses.loc[train_question_ids].to_numpy()
responses = torch.tensor(train_model_responses, dtype=torch.float)

# Create semi-supervised mask
def create_mask(num_nodes, mask_ratio=0.3, seed=42):
    """Randomly hide a fraction of nodes from the training loss.

    Args:
        num_nodes: Total number of nodes.
        mask_ratio: Fraction of nodes to hide. ``int(mask_ratio * num_nodes)``
            nodes are drawn without replacement.
        seed: Seed for the draw.

    Returns:
        ``(train_mask, masked_indices)``: a bool tensor that is ``False`` for
        hidden nodes and ``True`` otherwise, plus the NumPy array of hidden
        node indices.
    """
    # A private RandomState produces exactly the same draw as seeding the
    # global NumPy RNG and calling np.random.choice. It avoids resetting the
    # process-wide NumPy RNG as a hidden side effect.
    rng = np.random.RandomState(seed)
    num_mask = int(mask_ratio * num_nodes)
    masked_indices = rng.choice(num_nodes, num_mask, replace=False)
    mask = np.ones(num_nodes, dtype=bool)
    mask[masked_indices] = False
    return torch.tensor(mask, dtype=torch.bool), masked_indices

mask_ratio = config['semi_supervised']['mask_ratio']
train_mask, masked_indices = create_mask(
    train_embeds.shape[0],
    mask_ratio=mask_ratio,
    seed=config['semi_supervised']['seed'],
)

# Save mask information: is_masked is True for questions hidden from training.
mask_info_path = config['output']['train_mask_info'].format(mask_ratio=mask_ratio)
pd.DataFrame(
    {
        'question_id': train_df['question_id'],
        'is_masked': np.logical_not(train_mask.numpy()),
    }
).to_csv(mask_info_path, index=False)




# Cost accounting: what querying every model would cost, and how much of that
# is avoided for the masked (held-out) training questions.
# Only numeric columns are summed, so a non-numeric column (for example a
# 'model_name' label in a long-format file) cannot break the totals. For a
# wide, all-numeric layout this equals summing every non-id column.
# NOTE(review): total_cost_before covers every row in model_costs, which may
# include test questions as well as training ones. Confirm the intended scope.
cost_values = model_costs.drop(columns=['question_id']).select_dtypes(include='number')
total_cost_before = cost_values.sum().sum()

# Cost of the masked questions, which are treated as not queried.
masked_question_ids = train_df.iloc[masked_indices]['question_id']
masked_cost_df = cost_values[model_costs['question_id'].isin(masked_question_ids)]
masked_cost = masked_cost_df.sum().sum()

total_cost_after = total_cost_before - masked_cost

# Write the three cost figures as plain text, one per line.
cost_saving_path = config['output']['cost_saving'].format(mask_ratio=mask_ratio)
with open(cost_saving_path, "w") as f:
    f.writelines([
        f"Total cost before masking: {total_cost_before:.4f}\n",
        f"Total cost after masking: {total_cost_after:.4f}\n",
        f"Total cost saved: {masked_cost:.4f}\n",
    ])
print(f"Cost saving metrics saved to {cost_saving_path}")




# Bundle the training graph for PyG: BERT embeddings as node features, the
# kNN edges and weights, the response matrix, and the semi-supervised mask.
node_features = torch.tensor(train_embeds).float()
data = GeoData(
    x=node_features,
    edge_index=edge_index,
    edge_weight=edge_weight,
    responses=responses,
    train_mask=train_mask,
)

# GNN + IRT Model
class LLMIRTGNN(nn.Module):
    """Two-layer GCN question encoder feeding a multidimensional IRT head.

    Each question node is encoded by two GCN layers. From that encoding the
    model reads a per-question vector ``a_i`` (length ``theta_dim``) and a
    scalar ``b_i``. Each LLM ``m`` has a learned vector ``theta_m`` (a row of
    ``W_theta``). The predicted probability for the pair (i, m) is
    ``sigmoid(a_i . theta_m - b_i)``.

    Args:
        input_dim: Size of the node feature vectors.
        hidden_dim: Width of both GCN layers.
        num_models: Number of LLMs (rows of ``W_theta``).
        theta_dim: Length of the ``a_i`` / ``theta_m`` vectors.
        dropout: Dropout probability applied after the first GCN layer.
    """

    def __init__(self, input_dim, hidden_dim, num_models, theta_dim=16, dropout=0.3):
        super().__init__()
        # Submodules are created in this exact order so parameter
        # registration (and seeded initialisation) stays the same.
        self.gcn1 = GCNConv(input_dim, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

        # One vector per LLM, initialised from a standard normal.
        self.W_theta = nn.Parameter(torch.randn(num_models, theta_dim))
        # Linear heads producing a_i and b_i from the node encoding.
        self.W_a = nn.Linear(hidden_dim, theta_dim)
        self.W_b = nn.Linear(hidden_dim, 1)

    def forward(self, x, edge_index, edge_weight=None):
        """Score every (question node, model) pair.

        Returns:
            ``(prob, a_i, b_i, theta)`` with shapes
            ``(num_nodes, num_models)``, ``(num_nodes, theta_dim)`` and
            ``(num_nodes,)``. ``theta`` is the ``W_theta`` parameter itself.
        """
        first = self.relu(self.gcn1(x, edge_index, edge_weight=edge_weight))
        hidden = self.relu(self.gcn2(self.dropout(first), edge_index, edge_weight=edge_weight))

        a_vec = self.W_a(hidden)
        b_vec = self.W_b(hidden).squeeze(-1)
        ability = self.W_theta

        logits = a_vec @ ability.T - b_vec[:, None]
        return torch.sigmoid(logits), a_vec, b_vec, ability

# Use CUDA only when it is both available and enabled in the config.
device = torch.device(
    'cuda'
    if torch.cuda.is_available() and config['device']['use_cuda']
    else 'cpu'
)
model = LLMIRTGNN(
    input_dim=train_embeds.shape[1],
    hidden_dim=config['model']['hidden_dim'],
    num_models=responses.shape[1],
    theta_dim=config['model']['theta_dim'],
    dropout=config['model']['dropout'],
).to(device)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config['training']['learning_rate'],
    weight_decay=config['training']['weight_decay'],
)
# Binary cross-entropy on probabilities: the model's forward already applies
# the sigmoid.
loss_fn = nn.BCELoss()

# Training: full-batch. Every node goes through the GCN, but only nodes with
# train_mask=True contribute to the loss and the reported accuracy.
epochs = config['training']['epochs']
# The graph tensors are the same every epoch, so move them to the device
# once instead of on each iteration.
x_dev = data.x.to(device)
edge_index_dev = data.edge_index.to(device)
edge_weight_dev = data.edge_weight.to(device)
target = data.responses.to(device)
mask = data.train_mask.to(device)

model.train()
for epoch in range(epochs):
    optimizer.zero_grad()
    prob, _, _, _ = model(x_dev, edge_index_dev, edge_weight_dev)
    loss = loss_fn(prob[mask], target[mask])
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 10 == 0:
        with torch.no_grad():
            pred = (prob[mask] > 0.5).float()
            # Fraction of correctly predicted (question, model) responses
            # among unmasked questions. Averaging over all elements keeps it
            # in [0, 1]; dividing by the question count alone would scale it
            # by num_models. Measured on this training-mode forward pass, so
            # dropout is active.
            acc = (pred == target[mask]).float().mean().item()
        print(f"Epoch {epoch+1} Loss: {loss.item():.4f}, Accuracy: {acc:.4f}")

# Inference on test
def predict_on_test(test_embeds, train_embeds, model, k=5, metric=None):
    """Predict every model's probability of a correct answer per test question.

    Each test question is scored inductively. It becomes node 0 of a small
    star graph, linked in both directions to each of its ``k`` nearest
    training questions with weight ``1 - distance``. The model's output for
    node 0 is kept.

    Args:
        test_embeds: 2-D array, one row per test question.
        train_embeds: 2-D array, one row per training question.
        model: A trained ``LLMIRTGNN``.
        k: Number of training neighbours per test question.
        metric: scikit-learn distance metric. Defaults to
            ``config['inference']['metric']``.

    Returns:
        ``(preds, theta)``: a ``(n_test, num_models)`` probability array and
        the ``(num_models, theta_dim)`` ``W_theta`` matrix as NumPy.
    """
    if metric is None:
        metric = config['inference']['metric']
    model.eval()
    # Run on whatever device the model's parameters live on.
    model_device = next(model.parameters()).device
    # The forward pass returns W_theta itself, so read it directly. This also
    # works when there are no test rows.
    theta = model.W_theta.detach().cpu().numpy()

    test_embeds = np.asarray(test_embeds)
    train_embeds = np.asarray(train_embeds)
    if len(test_embeds) == 0:
        return np.empty((0, theta.shape[0]), dtype=np.float32), theta

    nbrs = NearestNeighbors(n_neighbors=k, metric=metric).fit(train_embeds)
    # One batched query instead of one kneighbors call per test row.
    all_dists, all_idxs = nbrs.kneighbors(test_embeds)

    # The star topology is identical for every test question; only the
    # weights change. Edge order: 0->1, 1->0, 0->2, 2->0, ...
    edge_pairs = [pair for i in range(1, k + 1) for pair in ((0, i), (i, 0))]
    edge_index = torch.tensor(edge_pairs, dtype=torch.long).t().contiguous().to(model_device)

    preds = []
    with torch.no_grad():
        for emb, dists, idxs in zip(test_embeds, all_dists, all_idxs):
            # Cast to float32 so float64 embeddings match the model weights.
            x = torch.tensor(np.vstack([emb[None, :], train_embeds[idxs]]),
                             dtype=torch.float, device=model_device)
            # Each weight appears twice, once per edge direction.
            edge_weight = torch.tensor(np.repeat(1.0 - dists, 2),
                                       dtype=torch.float, device=model_device)
            prob, _, _, _ = model(x, edge_index, edge_weight=edge_weight)
            preds.append(prob[0].cpu().numpy())
    return np.stack(preds), theta

test_preds, theta_final = predict_on_test(
    test_embeds, train_embeds, model, k=config['inference']['k_neighbors']
)

# Save predictions: one probability column per model plus the question_id.
test_result_df = pd.DataFrame(test_preds, columns=model_responses.columns)
# Assign positionally (.values) so ids follow test_preds' row order even if
# test_df does not carry a default RangeIndex (matches the other exports).
test_result_df['question_id'] = test_df['question_id'].values
predictions_path = config['output']['predictions'].format(
    theta_dim=config['model']['theta_dim'], mask_ratio=mask_ratio
)
test_result_df.to_csv(predictions_path, index=False)

# Re-run the trained model in eval mode (dropout off) over the full training
# graph, then export per-question a/b and per-model theta.
model.eval()
with torch.no_grad():
    _, a_train, b_train, theta = model(
        data.x.to(device),
        data.edge_index.to(device),
        data.edge_weight.to(device),
    )

a_train_path = config['output']['a_train'].format(mask_ratio=mask_ratio)
a_df = pd.DataFrame(a_train.cpu().numpy())
a_df['question_id'] = train_df['question_id'].values
a_df.to_csv(a_train_path, index=False)

b_train_path = config['output']['b_train'].format(mask_ratio=mask_ratio)
b_df = pd.DataFrame(dict(b=b_train.cpu().numpy(), question_id=train_df['question_id'].values))
b_df.to_csv(b_train_path, index=False)

# theta is the W_theta parameter itself (it requires grad), so it must be
# detached before NumPy conversion even under no_grad.
theta_path = config['output']['theta'].format(mask_ratio=mask_ratio)
theta_df = pd.DataFrame(theta.detach().cpu().numpy())
theta_df['model_name'] = model_responses.columns
theta_df.to_csv(theta_path, index=False)

# Compute the per-question IRT parameters a and b for the test set (optional)
def get_a_b_for_test(test_embeds, train_embeds, model, k=3, metric=None):
    """Infer the model's ``a`` and ``b`` outputs for each test question.

    Each test question becomes node 0 of a small star graph, linked in both
    directions to each of its ``k`` nearest training questions with weight
    ``1 - distance``. The model's ``a``/``b`` outputs for node 0 are kept.

    Args:
        test_embeds: 2-D array, one row per test question.
        train_embeds: 2-D array, one row per training question.
        model: A trained ``LLMIRTGNN``.
        k: Number of training neighbours per test question.
        metric: scikit-learn distance metric. Defaults to
            ``config['inference']['metric']``.

    Returns:
        ``(a, b)``: an ``(n_test, theta_dim)`` array and an ``(n_test,)``
        array.
    """
    if metric is None:
        metric = config['inference']['metric']
    model.eval()
    # Run on whatever device the model's parameters live on.
    model_device = next(model.parameters()).device

    test_embeds = np.asarray(test_embeds)
    train_embeds = np.asarray(train_embeds)
    if len(test_embeds) == 0:
        # np.stack([]) would raise; return well-shaped empty arrays.
        return np.empty((0, model.W_theta.shape[1]), dtype=np.float32), np.array([])

    nbrs = NearestNeighbors(n_neighbors=k, metric=metric).fit(train_embeds)
    # One batched query instead of one kneighbors call per test row.
    all_dists, all_idxs = nbrs.kneighbors(test_embeds)

    # Same star topology for every test question: 0->1, 1->0, 0->2, 2->0, ...
    edge_pairs = [pair for i in range(1, k + 1) for pair in ((0, i), (i, 0))]
    edge_index = torch.tensor(edge_pairs, dtype=torch.long).t().contiguous().to(model_device)

    a_list = []
    b_list = []
    with torch.no_grad():
        for emb, dists, idxs in zip(test_embeds, all_dists, all_idxs):
            # Cast to float32 so float64 embeddings match the model weights.
            x = torch.tensor(np.vstack([emb[None, :], train_embeds[idxs]]),
                             dtype=torch.float, device=model_device)
            # Each weight appears twice, once per edge direction.
            edge_weight = torch.tensor(np.repeat(1.0 - dists, 2),
                                       dtype=torch.float, device=model_device)
            _, a_i, b_i, _ = model(x, edge_index, edge_weight=edge_weight)
            a_list.append(a_i[0].cpu().numpy())
            b_list.append(b_i[0].cpu().item())
    return np.stack(a_list), np.array(b_list)

a_test, b_test = get_a_b_for_test(
    test_embeds, train_embeds, model, k=config['inference']['k_neighbors']
)

# Per-question a vectors for the test set, one row per test question.
a_test_path = config['output']['a_test'].format(mask_ratio=mask_ratio)
a_test_df = pd.DataFrame(a_test)
a_test_df['question_id'] = test_df['question_id'].values
a_test_df.to_csv(a_test_path, index=False)

# Per-question b scalars for the test set.
b_test_path = config['output']['b_test'].format(mask_ratio=mask_ratio)
b_test_df = pd.DataFrame(dict(b=b_test, question_id=test_df['question_id'].values))
b_test_df.to_csv(b_test_path, index=False)

print("Training complete. Predictions and metrics saved.")
