import argparse
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from config import *
from create_tensor import load_or_create_split
from predictor import PromptMLP
from utils import process_math_id
from torch.utils.data import TensorDataset, DataLoader

class DifficultyMLP(nn.Module):
    """Feed-forward classifier that maps an embedding vector to class logits.

    Architecture: ``LayerNorm -> [Linear -> BatchNorm1d -> GELU -> Dropout] * k``
    followed by a further dropout and a linear head. ``forward`` returns raw
    logits; no softmax is applied.

    Args:
        input_dim: Size of the last dimension of the input.
        num_classes: Number of output logits.
        hidden_layers: Widths of the hidden blocks. Defaults to
            ``[input_dim, input_dim // 2, input_dim // 4, input_dim // 8]``.
        dropout: Dropout rate of every hidden block except the last, which
            uses half of it. The dropout before the head uses 30% of it.
    """

    def __init__(self, input_dim, num_classes=5, hidden_layers=None, dropout=0.15):
        super().__init__()
        if hidden_layers is not None:
            widths = hidden_layers
        else:
            widths = [input_dim, input_dim // 2, input_dim // 4, input_dim // 8]

        last_position = len(widths) - 1
        trunk = [nn.LayerNorm(input_dim)]  # normalise the raw input features first
        in_features = input_dim
        for position, width in enumerate(widths):
            # The final hidden block gets a lighter dropout than the others.
            rate = dropout * 0.5 if position == last_position else dropout
            trunk += [
                nn.Linear(in_features, width),
                nn.BatchNorm1d(width),
                nn.GELU(),
                nn.Dropout(rate),
            ]
            in_features = width

        # The head is registered before the trunk on purpose. This fixes the
        # state_dict key order and the order in which _init_weights visits
        # modules, and therefore the sequence of random draws.
        self.final_layer = nn.Linear(in_features, num_classes)
        self.dropout_final = nn.Dropout(dropout * 0.3)
        self.mlp = nn.Sequential(*trunk)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal init for linear weights, zero biases, identity norms."""
        for submodule in self.modules():
            if isinstance(submodule, nn.Linear):
                nn.init.kaiming_normal_(submodule.weight, mode='fan_out', nonlinearity='relu')
                if submodule.bias is not None:
                    nn.init.zeros_(submodule.bias)
            elif isinstance(submodule, (nn.BatchNorm1d, nn.LayerNorm)):
                nn.init.ones_(submodule.weight)
                nn.init.zeros_(submodule.bias)

    def forward(self, x):
        """Return unnormalised class logits for a batch of input vectors."""
        return self.final_layer(self.dropout_final(self.mlp(x)))


SELECTED_LAYER = 14  # index into each cached prompt-embedding array (row used as features)
model = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"  # NOTE: rebound to the nn.Module further down
output_dir = os.path.join(MATH_DIR, MODEL_IDS[model], "difficulty_prediction", f"L{SELECTED_LAYER}_mlp")
# Create the output directory directly instead of shelling out to `mkdir -p`.
# This avoids quoting/injection problems with unusual paths and works on
# platforms without a POSIX shell.
os.makedirs(output_dir, exist_ok=True)

# Rows with train == 0 form the held-out test set. The train == 1 rows are
# split again into train/val by load_or_create_split. The 0.1 and 0 arguments
# are presumably the val fraction and seed — TODO confirm against that helper.
math_df = pd.read_csv(os.path.join(MATH_DIR, "math3k.csv"))
math_test = math_df[math_df['train'] == 0]
math_train, math_val = load_or_create_split(math_df[math_df['train'] == 1], 'math', 0.1, 0)

def _load_split_embeddings(df):
    """Load features, labels and ids for one split of the MATH table.

    For each row of ``df`` this reads
    ``<MATH_DIR>/<MODEL_IDS[model]>/embedding/<unique_id>.prompt.lasttoken.npz``,
    takes the array stored under key ``'data'`` and keeps its
    ``SELECTED_LAYER`` entry.

    Returns:
        ``(X, Y, IDs)``: the stacked selected embeddings, the 0-based labels
        (``level - 1``), and the list of unique_ids, all in ``df`` row order.
    """
    X, Y, IDs = [], [], []
    for _, row in df.iterrows():
        unique_id = row['unique_id']
        prompt_emb_path = os.path.join(MATH_DIR, MODEL_IDS[model], 'embedding', f'{unique_id}.prompt.lasttoken.npz')
        # Close the .npz archive promptly. Copy only the selected entry so that
        # a view does not keep the whole loaded array alive until torch.stack.
        with np.load(prompt_emb_path) as archive:
            layer_emb = np.array(archive['data'][SELECTED_LAYER])
        X.append(torch.from_numpy(layer_emb))
        Y.append(row['level'] - 1)
        IDs.append(unique_id)
    return torch.stack(X, dim=0), torch.tensor(Y), IDs


X_train, Y_train, ID_train = _load_split_embeddings(math_train)
X_val, Y_val, ID_val = _load_split_embeddings(math_val)
X_test, Y_test, ID_test = _load_split_embeddings(math_test)

input_dim = X_train.size(1)
num_classes = 5

print(f"Input dimension: {input_dim}")
print(f"Number of classes: {num_classes}")
print(f"Training samples: {X_train.size(0)}")
print(f"Validation samples: {X_val.size(0)}")
print(f"Test samples: {X_test.size(0)}")

model = DifficultyMLP(input_dim=input_dim, num_classes=num_classes)
# The training set also carries each sample's row index. train_loader
# shuffles, so a batch's position cannot tell us which unique_ids it holds.
train_dataset = TensorDataset(X_train, Y_train, torch.arange(X_train.size(0)))
val_dataset = TensorDataset(X_val, Y_val)
test_dataset = TensorDataset(X_test, Y_test)
# NOTE(review): num_workers > 0 at module top level with no
# `if __name__ == "__main__"` guard fails under the "spawn" start method
# (Windows/macOS default). For in-memory tensors, num_workers=0 is likely
# sufficient — confirm the target platform.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4, pin_memory=True, persistent_workers=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=True, persistent_workers=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=True, persistent_workers=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Mixed precision with loss scaling is used only on CUDA.
scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-3, betas=(0.9, 0.999))
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)

num_epochs = 50
# Start below any reachable accuracy so the first epoch always writes
# best_model.pt. That file is loaded unconditionally after training.
best_val_acc = -1.0
train_losses = []
val_accuracies = []

# Per-epoch caches. NOTE: the *_probabilities entries hold raw model outputs
# (logits); DifficultyMLP.forward applies no softmax.
epoch_models = {}
epoch_train_predictions = {}
epoch_val_predictions = {}
epoch_test_predictions = {}
epoch_train_probabilities = {}
epoch_val_probabilities = {}
epoch_test_probabilities = {}

print("Starting training...")
for epoch in range(num_epochs):
    # ---- Training phase ----
    model.train()
    train_loss = 0.0
    train_predictions = []
    train_labels = []
    train_probabilities = []
    train_unique_ids = []

    for batch_x, batch_y, batch_rows in train_loader:
        batch_x, batch_y = batch_x.to(device, non_blocking=True), batch_y.to(device, non_blocking=True)

        optimizer.zero_grad()

        if scaler is not None:
            with torch.cuda.amp.autocast():
                outputs = model(batch_x)
                loss = criterion(outputs, batch_y)

            scaler.scale(loss).backward()
            # Unscale first so clipping applies to the true gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(batch_x)
            loss = criterion(outputs, batch_y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        train_loss += loss.item()

        # Store predictions and logits for this batch.
        predicted = outputs.detach().argmax(dim=1)
        train_predictions.extend(predicted.cpu().numpy())
        train_labels.extend(batch_y.detach().cpu().numpy())
        train_probabilities.extend(outputs.detach().cpu().numpy())

        # Map the shuffled batch back to its unique_ids via the carried row indices.
        train_unique_ids.extend(ID_train[row] for row in batch_rows.tolist())

    avg_train_loss = train_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    epoch_train_predictions[epoch] = {
        'predictions': train_predictions,
        'labels': train_labels,
        'unique_ids': train_unique_ids
    }
    epoch_train_probabilities[epoch] = {
        'probabilities': train_probabilities,
        'labels': train_labels,
        'unique_ids': train_unique_ids
    }

    # ---- Validation phase ----
    model.eval()
    val_correct = 0
    val_total = 0
    val_predictions = []
    val_labels = []
    val_probabilities = []
    val_unique_ids = []

    with torch.no_grad():
        for batch_idx, (batch_x, batch_y) in enumerate(val_loader):
            batch_x, batch_y = batch_x.to(device, non_blocking=True), batch_y.to(device, non_blocking=True)
            outputs = model(batch_x)
            predicted = outputs.argmax(dim=1)
            val_total += batch_y.size(0)
            val_correct += (predicted == batch_y).sum().item()

            val_predictions.extend(predicted.cpu().numpy())
            val_labels.extend(batch_y.cpu().numpy())
            val_probabilities.extend(outputs.cpu().numpy())

            # val_loader does not shuffle, so batch k covers ID_val[k*bs:(k+1)*bs].
            batch_start_idx = batch_idx * val_loader.batch_size
            val_unique_ids.extend(ID_val[batch_start_idx:batch_start_idx + val_loader.batch_size])

    val_acc = val_correct / val_total
    val_accuracies.append(val_acc)

    epoch_val_predictions[epoch] = {
        'predictions': val_predictions,
        'labels': val_labels,
        'unique_ids': val_unique_ids
    }
    epoch_val_probabilities[epoch] = {
        'probabilities': val_probabilities,
        'labels': val_labels,
        'unique_ids': val_unique_ids
    }

    # Learning-rate scheduling (stepped once per epoch).
    scheduler.step()

    # Snapshot this epoch's weights. state_dict() returns tensors that share
    # storage with the live parameters and buffers, and dict.copy() is shallow.
    # Without a real clone, every entry would track later in-place updates.
    # Snapshots are kept on the CPU so they do not accumulate in GPU memory.
    epoch_models[epoch] = {
        name: tensor.detach().to('cpu', copy=True)
        for name, tensor in model.state_dict().items()
    }

    # Save the best model by validation accuracy (first epoch wins ties).
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), os.path.join(output_dir, 'best_model.pt'))

    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Acc: {val_acc:.4f}')

    # Early stopping: stop once this epoch is >1 point below the best of the last 10.
    if epoch > 10 and val_acc < max(val_accuracies[-10:]) - 0.01:
        print("Early stopping triggered")
        break

# Reload the checkpoint with the best validation accuracy and score the test split.
model.load_state_dict(torch.load(os.path.join(output_dir, 'best_model.pt')))
model.eval()

test_correct = test_total = 0
test_predictions, test_labels, test_probabilities, test_unique_ids = [], [], [], []

with torch.no_grad():
    for batch_idx, (batch_x, batch_y) in enumerate(test_loader):
        batch_x = batch_x.to(device, non_blocking=True)
        batch_y = batch_y.to(device, non_blocking=True)
        logits = model(batch_x)
        predicted = logits.argmax(dim=1)

        test_total += batch_y.size(0)
        test_correct += (predicted == batch_y).sum().item()

        test_predictions.extend(predicted.cpu().numpy())
        test_labels.extend(batch_y.cpu().numpy())
        # Raw logits: the model's forward pass applies no softmax.
        test_probabilities.extend(logits.cpu().numpy())

        # test_loader does not shuffle, so batch k covers ID_test[k*bs:(k+1)*bs].
        offset = batch_idx * test_loader.batch_size
        test_unique_ids.extend(ID_test[offset:offset + test_loader.batch_size])

test_acc = test_correct / test_total

# Re-score the test split with every cached epoch snapshot.
print("Generating test predictions for all epochs...")
for epoch, snapshot in epoch_models.items():
    model.load_state_dict(snapshot)
    model.eval()

    preds_acc, labels_acc, logits_acc, ids_acc = [], [], [], []
    with torch.no_grad():
        for batch_idx, (batch_x, batch_y) in enumerate(test_loader):
            batch_x = batch_x.to(device, non_blocking=True)
            batch_y = batch_y.to(device, non_blocking=True)
            logits = model(batch_x)

            preds_acc.extend(logits.argmax(dim=1).cpu().numpy())
            labels_acc.extend(batch_y.cpu().numpy())
            logits_acc.extend(logits.cpu().numpy())

            # test_loader does not shuffle, so batch k covers ID_test[k*bs:(k+1)*bs].
            offset = batch_idx * test_loader.batch_size
            ids_acc.extend(ID_test[offset:offset + test_loader.batch_size])

    epoch_test_predictions[epoch] = {'predictions': preds_acc, 'labels': labels_acc, 'unique_ids': ids_acc}
    epoch_test_probabilities[epoch] = {'probabilities': logits_acc, 'labels': labels_acc, 'unique_ids': ids_acc}

print(f'\nFinal Results:')
print(f'Best Validation Accuracy: {best_val_acc:.4f}')
print(f'Test Accuracy: {test_acc:.4f}')

# Per-class accuracy
from sklearn.metrics import classification_report, confusion_matrix

# Pin the label set to the five 0-based levels. Without this,
# classification_report raises when a level is missing from both labels and
# predictions (5 target_names vs. fewer classes found), and confusion_matrix
# would silently shrink.
level_ids = list(range(5))
print(f'\nClassification Report:')
print(classification_report(test_labels, test_predictions, labels=level_ids, target_names=[f'Level {i+1}' for i in level_ids]))

print(f'\nConfusion Matrix:')
print(confusion_matrix(test_labels, test_predictions, labels=level_ids))

# Persist everything cached during the run as one torch checkpoint.
# NOTE: the *_probabilities fields hold raw logits (no softmax is applied).
results = dict(
    best_val_acc=best_val_acc,
    test_acc=test_acc,
    train_losses=train_losses,
    val_accuracies=val_accuracies,
    test_predictions=test_predictions,
    test_labels=test_labels,
    test_probabilities=test_probabilities,
    epoch_models=epoch_models,
    epoch_train_predictions=epoch_train_predictions,
    epoch_val_predictions=epoch_val_predictions,
    epoch_test_predictions=epoch_test_predictions,
    epoch_train_probabilities=epoch_train_probabilities,
    epoch_val_probabilities=epoch_val_probabilities,
    epoch_test_probabilities=epoch_test_probabilities,
)
torch.save(results, os.path.join(output_dir, 'results.pt'))

# Also write each epoch's weights to its own file.
for epoch in epoch_models:
    torch.save(epoch_models[epoch], os.path.join(output_dir, f'epoch_{epoch}_model.pt'))

# Loss/accuracy curve: one row per completed epoch.
completed_epochs = len(train_losses)
pd.DataFrame({
    'epoch': range(completed_epochs),
    'train_loss': train_losses,
    'val_accuracy': val_accuracies,
}).to_csv(os.path.join(output_dir, 'training_metrics.csv'), index=False)

def _prediction_frame(unique_ids, labels, predictions, logits, **leading_columns):
    """Build a per-sample prediction table for one split.

    The cached ``*_probabilities`` lists hold raw model outputs
    (``DifficultyMLP.forward`` applies no softmax). They are softmax-normalised
    here before filling the ``prob_class_*``, ``max_prob`` and ``confidence``
    columns. ``confidence`` is the gap between the two largest class
    probabilities. Any ``leading_columns`` (e.g. ``epoch``/``split``) are
    placed first, in the order given.
    """
    # reshape keeps a (0, num_classes) shape for an empty split.
    logit_array = np.asarray(logits, dtype=np.float32).reshape(len(logits), num_classes)
    probs = F.softmax(torch.from_numpy(logit_array), dim=1).numpy()
    ranked = np.sort(probs, axis=1)

    columns = dict(leading_columns)
    columns['unique_id'] = unique_ids
    columns['true_label'] = labels
    columns['predicted_label'] = predictions
    for k in range(num_classes):
        columns[f'prob_class_{k}'] = probs[:, k]
    columns['max_prob'] = ranked[:, -1]
    columns['confidence'] = ranked[:, -1] - ranked[:, -2]
    return pd.DataFrame(columns)


best_test_results = _prediction_frame(test_unique_ids, test_labels, test_predictions, test_probabilities)
best_test_results.to_csv(os.path.join(output_dir, 'best_model_test_results.csv'), index=False)

for epoch in epoch_models.keys():
    split_frames = []
    for split_name, pred_cache, prob_cache in (
        ('train', epoch_train_predictions, epoch_train_probabilities),
        ('val', epoch_val_predictions, epoch_val_probabilities),
        ('test', epoch_test_predictions, epoch_test_probabilities),
    ):
        cached = pred_cache[epoch]
        n_rows = len(cached['labels'])
        split_frames.append(_prediction_frame(
            cached['unique_ids'], cached['labels'], cached['predictions'],
            prob_cache[epoch]['probabilities'],
            epoch=[epoch] * n_rows, split=[split_name] * n_rows,
        ))

    epoch_df = pd.concat(split_frames, ignore_index=True)
    epoch_df.to_csv(os.path.join(output_dir, f'epoch_{epoch}_predictions.csv'), index=False)

def _split_accuracy(split_data):
    """Return the fraction of cached predictions equal to their labels."""
    matches = sum(1 for pred, label in zip(split_data['predictions'], split_data['labels']) if pred == label)
    return matches / len(split_data['labels'])


# One row per cached epoch: accuracy on each split plus the mean training loss.
epoch_summaries = []
for epoch in epoch_models:
    summary_row = {'epoch': epoch}
    for split_name, cache in (
        ('train', epoch_train_predictions),
        ('val', epoch_val_predictions),
        ('test', epoch_test_predictions),
    ):
        summary_row[f'{split_name}_accuracy'] = _split_accuracy(cache[epoch])
    summary_row['train_loss'] = train_losses[epoch] if epoch < len(train_losses) else None
    epoch_summaries.append(summary_row)

summary_df = pd.DataFrame(epoch_summaries)
summary_df.to_csv(os.path.join(output_dir, 'epoch_summaries.csv'), index=False)

import time

# Epoch with the highest validation accuracy (the earliest one on ties).
best_epoch = max(epoch_models.keys(), key=lambda e: val_accuracies[e] if e < len(val_accuracies) else 0)
print(f"Best validation accuracy epoch: {best_epoch}")

model.load_state_dict(epoch_models[best_epoch])
model.eval()

if torch.cuda.is_available():
    # CUDA events measure elapsed time on the device stream.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

best_epoch_test_data = epoch_test_predictions[best_epoch]
best_epoch_test_probs = epoch_test_probabilities[best_epoch]

# Row of each unique_id in X_test, built once rather than calling the O(n)
# list.index per sample. setdefault keeps the first occurrence, as list.index does.
test_row_by_id = {}
for row_idx, uid in enumerate(ID_test):
    test_row_by_id.setdefault(uid, row_idx)

# Measure per-sample prediction latency (batch size 1), in seconds.
prediction_times = []
with torch.no_grad():
    for unique_id in best_epoch_test_data['unique_ids']:
        test_idx = test_row_by_id[unique_id]
        sample_embedding = X_test[test_idx:test_idx + 1].to(device, non_blocking=True)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
            start_event.record()
            _ = model(sample_embedding)
            end_event.record()
            torch.cuda.synchronize()
            prediction_time = start_event.elapsed_time(end_event) / 1000.0  # ms -> s
        else:
            # perf_counter is monotonic and high-resolution, unlike time.time.
            start_time = time.perf_counter()
            _ = model(sample_embedding)
            prediction_time = time.perf_counter() - start_time

        prediction_times.append(prediction_time)

# Report levels on the original 1-based scale.
best_epoch_results = pd.DataFrame({
    'unique_id': best_epoch_test_data['unique_ids'],
    'level': [label + 1 for label in best_epoch_test_data['labels']],
    'level_predict': [pred + 1 for pred in best_epoch_test_data['predictions']],
    'prediction_time': prediction_times
})

best_epoch_results.to_csv(os.path.join(output_dir, 'best_epoch_test_results.csv'), index=False)
