import argparse
import os
import polars as pl
import pandas as pd
import torch
import numpy as np
from config import DSET_TO_DIR, MODEL_IDS
from create_tensor import create_tensor_for_dataset
from predictor import ReasoningModel, ReasoningMLPModel
from torch.utils.data import DataLoader, TensorDataset, Dataset
from utils import convert_model_setting_to_str, convert_data_setting_to_str
from sklearn.metrics import precision_score, recall_score, f1_score, average_precision_score, accuracy_score
from sklearn.utils.class_weight import compute_class_weight
from torch.optim.lr_scheduler import _LRScheduler, StepLR
import math
from tqdm import tqdm
import gc
import torch.nn as nn


class SimpleMLPModel(nn.Module):
    """Feed-forward binary classifier over a single embedding vector.

    The input is layer-normalised and passed through dropout. It then goes
    through a stack of ``Linear -> ReLU -> Dropout`` blocks, and a final
    ``Linear`` produces one logit.

    Args:
        tensor_dim: Size of the last dimension of the input embedding.
        dropout: Dropout probability, used on the input and after every
            hidden layer.
        hidden_layers: Widths of the hidden layers. ``None`` selects
            ``[tensor_dim // 2, tensor_dim // 4]``.
    """

    def __init__(self, tensor_dim, dropout=0.15, hidden_layers=None):
        super().__init__()

        if hidden_layers is None:
            widths = [tensor_dim // 2, tensor_dim // 4]
        else:
            widths = list(hidden_layers)

        self.ln = nn.LayerNorm(tensor_dim)
        self.dropout = nn.Dropout(dropout)

        # The Sequential indices (and therefore the state_dict keys) depend
        # on this exact Linear/ReLU/Dropout ordering, so keep it stable.
        in_dims = [tensor_dim, *widths]
        blocks = []
        for fan_in, fan_out in zip(in_dims, widths):
            blocks += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(dropout)]
        blocks.append(nn.Linear(in_dims[-1], 1, bias=True))
        self.mlp = nn.Sequential(*blocks)

        self._init_weights()

    def _init_weights(self):
        """Re-initialise every Linear layer: Xavier-uniform weights, zero biases."""
        linears = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linears:
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, intermediate_emb, seq_mask):
        """Return raw (pre-sigmoid) logits for ``intermediate_emb``.

        ``seq_mask`` is accepted for interface compatibility but is not used.
        After layer norm, an input of shape ``[batch, 1, dim]`` is squeezed to
        ``[batch, dim]``, which yields logits of shape ``[batch, 1]``. Any
        other input shape is fed to the MLP unchanged.
        """
        normed = self.ln(intermediate_emb)
        squeezable = normed.dim() == 3 and normed.size(1) == 1
        features = normed.squeeze(1) if squeezable else normed
        return self.mlp(self.dropout(features))


class SimpleDataset(Dataset):
    """Index-aligned view over prompt, intermediate and label tensors.

    Each item is ``(prompt, intermediate, label)``. When
    ``sample_weight_tensor`` is given, the per-sample weight is appended as a
    fourth element. The dataset length is the length of
    ``intermediate_tensor``.
    """

    def __init__(self, prompt_tensor, intermediate_tensor, label_tensor, sample_weight_tensor=None):
        self.prompt_tensor, self.intermediate_tensor, self.label_tensor = (
            prompt_tensor,
            intermediate_tensor,
            label_tensor,
        )
        self.sample_weight_tensor = sample_weight_tensor

    def __len__(self):
        return len(self.intermediate_tensor)

    def __getitem__(self, idx):
        item = (self.prompt_tensor[idx], self.intermediate_tensor[idx], self.label_tensor[idx])
        if self.sample_weight_tensor is None:
            return item
        return item + (self.sample_weight_tensor[idx],)


def simple_collate(batch):
    """Collate ``SimpleDataset`` items into stacked batch tensors.

    Each item is ``(prompt, intermediate, label)``, or
    ``(prompt, intermediate, label, weight)`` when sample weights are present.
    The length of the first item decides which layout is assumed.

    Returns:
        ``(prompts, intermediates, mask, labels)``, with stacked ``weights``
        appended when present. ``mask`` is an all-``True`` bool tensor of
        shape ``[batch, 1]``.
    """
    weighted = len(batch[0]) == 4
    columns = tuple(zip(*batch))
    if weighted:
        prompts, intermediates, labels, weights = columns
    else:
        prompts, intermediates, labels = columns

    prompt_batch, intermediate_batch, label_batch = (
        torch.stack(column) for column in (prompts, intermediates, labels)
    )
    mask_batch = torch.ones(len(intermediates), 1, dtype=torch.bool)

    collated = (prompt_batch, intermediate_batch, mask_batch, label_batch)
    if weighted:
        collated += (torch.stack(weights),)
    return collated


def _forward_batch(model, batch, device):
    """Run ``model`` on one DataLoader batch and return ``(logits, labels)``.

    Batches come from ``SimpleDataset`` through the default collate function.
    They are ``(prompts, intermediates, labels)``, optionally followed by
    sample weights. Only the intermediate embeddings are fed to the model, so
    prompts (and any weights) are never copied to ``device``.

    Returns:
        Logits and labels, both flattened to 1-D and placed on ``device``.
        The labels keep their original dtype.
    """
    intermediates = batch[1].to(device)
    labels = batch[2].to(device).reshape(-1)
    # The model ignores this mask, but its forward() signature requires one.
    mask = torch.ones(intermediates.size(0), 1, dtype=torch.bool, device=device)
    # Flatten instead of squeeze(): squeeze() turns a size-1 final batch into
    # a 0-d tensor, which BCEWithLogitsLoss rejects against a [1]-shaped target.
    logits = model(intermediates, mask).reshape(-1)
    return logits, labels


def _train_one_epoch(model, loader, optimizer, criterion, device, desc, lr):
    """Train ``model`` for one pass over ``loader``.

    ``lr`` is only shown in the progress bar.

    Returns:
        ``(mean_loss, probs, labels)``. ``mean_loss`` is the per-sample
        average loss. ``probs`` and ``labels`` are 1-D NumPy arrays of the
        training-mode (dropout active) sigmoid outputs and the matching labels.
    """
    model.train()
    loss_sum, n_seen = 0.0, 0
    probs_parts, label_parts = [], []

    pbar = tqdm(loader, desc=desc)
    for batch in pbar:
        optimizer.zero_grad()
        logits, labels = _forward_batch(model, batch, device)
        loss = criterion(logits, labels.float())
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        # Weight by batch size so a short final batch does not skew the mean.
        loss_sum += batch_loss * labels.numel()
        n_seen += labels.numel()
        probs_parts.append(torch.sigmoid(logits).detach().cpu().numpy())
        label_parts.append(labels.cpu().numpy())

        pbar.set_postfix({'loss': f'{batch_loss:.4f}', 'lr': f'{lr:.2e}'})

    return loss_sum / n_seen, np.concatenate(probs_parts), np.concatenate(label_parts)


def _evaluate(model, loader, criterion, device, desc):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns:
        ``(mean_loss, probs, labels)``, in the same format as
        ``_train_one_epoch``.
    """
    model.eval()
    loss_sum, n_seen = 0.0, 0
    probs_parts, label_parts = [], []

    with torch.no_grad():
        pbar = tqdm(loader, desc=desc)
        for batch in pbar:
            logits, labels = _forward_batch(model, batch, device)
            batch_loss = criterion(logits, labels.float()).item()

            loss_sum += batch_loss * labels.numel()
            n_seen += labels.numel()
            probs_parts.append(torch.sigmoid(logits).cpu().numpy())
            label_parts.append(labels.cpu().numpy())

            pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    return loss_sum / n_seen, np.concatenate(probs_parts), np.concatenate(label_parts)


def _binary_metrics(labels, probs, threshold=0.5):
    """Compute binary-classification metrics from 1-D labels and probabilities.

    Returns:
        A dict with keys ``acc``, ``precision``, ``recall``, ``f1`` and
        ``ap``, in that order. The order matters because it fixes the
        history.csv column order.
    """
    preds = probs > threshold
    return {
        'acc': accuracy_score(labels, preds),
        'precision': precision_score(labels, preds, zero_division=0),
        'recall': recall_score(labels, preds, zero_division=0),
        'f1': f1_score(labels, preds, zero_division=0),
        'ap': average_precision_score(labels, probs),
    }


def main():
    """Train an MLP probe on GSM8K hidden states from one model layer.

    The layer index comes from ``-l/--layer_idx``. Outputs are written to
    ``<DSET_TO_DIR[dataset]>/<model_name>/model/L<layer>_mlp``: labels, info
    CSVs, per-epoch checkpoints and predictions, and ``history.csv``. If
    ``history.csv`` already exists the run is skipped.
    """
    parser = argparse.ArgumentParser(description='Train GSM8K model with specified layer index')
    parser.add_argument('-l', '--layer_idx', type=int, default=20, help='Layer index to use (default: 20)')
    args = parser.parse_args()

    model_id = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
    model_name = MODEL_IDS[model_id]

    dataset = "gsm8k"
    layer_idx = args.layer_idx
    warmup = 1
    stable = 3
    interval = 16
    max_seq = 1
    arch = "mlp"
    tensor_dim = 4096
    val_frac = 0.1
    seed = 0
    mlp_dropout = 0.15
    batch_size = 2048
    weight_decay = 0.1
    lr = 3e-4

    # MLP hidden layer configuration
    mlp_hidden_layers = [2048, 1024]

    print("Training GSM8K with fixed settings:")
    print(f"Dataset: {dataset}")
    print(f"Layer: {layer_idx}")
    print(f"Warmup: {warmup}, Stable: {stable}, Interval: {interval}")
    print(f"Max sequence length: {max_seq}")
    print(f"Architecture: {arch}")
    print(f"Tensor dimension: {tensor_dim}")
    print(f"MLP hidden layers: {mlp_hidden_layers}")
    print(f"Validation fraction: {val_frac}, Seed: {seed}")
    print(f"Batch size: {batch_size}, Learning rate: {lr}")
    print("-" * 50)

    DDIR = DSET_TO_DIR[dataset]
    model_dir_full = os.path.join(DDIR, model_name, "model", f"L{layer_idx}_mlp")
    # makedirs avoids spawning a shell, so paths containing spaces or shell
    # metacharacters are handled correctly.
    os.makedirs(model_dir_full, exist_ok=True)

    if os.path.isfile(os.path.join(model_dir_full, 'history.csv')):
        print("Found existing training history, skipping...")
        return

    model = SimpleMLPModel(tensor_dim, dropout=mlp_dropout, hidden_layers=mlp_hidden_layers)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    print(f"Using device: {device}")

    print("Loading data...")
    train_cache, val_cache, test_cache = create_tensor_for_dataset(
        dataset, model_id, layer_idx, warmup, stable, interval, max_seq, False, val_frac=val_frac, seed=seed)
    train_info, train_prompt, train_intermediate, train_label = train_cache
    val_info, val_prompt, val_intermediate, val_label = val_cache
    test_info, test_prompt, test_intermediate, test_label = test_cache

    print(f"Train samples: {len(train_info)}")
    print(f"Val samples: {len(val_info)}")
    print(f"Test samples: {len(test_info)}")

    print("\nFirst training sample tensor dimensions:")
    print(f"Prompt tensor shape: {train_prompt[0].shape}")
    print(f"Intermediate tensor shape: {train_intermediate[0].shape}")
    print(f"Label tensor shape: {train_label[0].shape}")
    print(f"Label value: {train_label[0].item()}")
    print("-" * 50)

    print(f"Train prompt shape: {train_prompt.shape}")
    print(f"Train intermediate: list of {len(train_intermediate)} tensors")
    print(f"First intermediate tensor shape: {train_intermediate[0].shape}")
    print(f"Train label shape: {train_label.shape}")
    print(f"Label tensor dtype: {train_label.dtype}")

    print("Converting intermediate tensors to single tensor...")
    train_intermediate = torch.stack(train_intermediate)  # [num_samples, max_seq, tensor_dim]
    val_intermediate = torch.stack(val_intermediate)
    test_intermediate = torch.stack(test_intermediate)
    print(f"After stacking - Train intermediate shape: {train_intermediate.shape}")

    print("Keeping tensors on CPU, will move to GPU per batch")

    # Save labels for later evaluation
    torch.save(val_label, os.path.join(model_dir_full, 'val.label.pt'))
    torch.save(test_label, os.path.join(model_dir_full, 'test.label.pt'))

    # Save info files
    test_info.to_csv(os.path.join(model_dir_full, 'test.info.csv'), index=False)
    train_info.to_csv(os.path.join(model_dir_full, 'train.info.csv'), index=False)
    val_info.to_csv(os.path.join(model_dir_full, 'val.info.csv'), index=False)

    train_ds = SimpleDataset(train_prompt, train_intermediate, train_label)
    val_ds = SimpleDataset(val_prompt, val_intermediate, val_label)
    test_ds = SimpleDataset(test_prompt, test_intermediate, test_label)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, pin_memory=False, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, pin_memory=False, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, pin_memory=False, num_workers=0)

    # Class-balancing weight for the positive class: (#negatives / #positives).
    train_labels_np = train_label.cpu().numpy()
    n_pos = int((train_labels_np == 1).sum())
    n_neg = int((train_labels_np == 0).sum())
    if n_pos == 0:
        # With no positives the ratio is inf, and BCEWithLogitsLoss would
        # then produce NaN losses; fall back to an unweighted loss instead.
        print("Warning: no positive training labels; falling back to pos_weight=1.0")
        pos_weight = 1.0
    else:
        pos_weight = n_neg / n_pos
    print(f"Positive weight for loss: {pos_weight:.4f}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = StepLR(optimizer, step_size=5, gamma=0.8)
    criterion = torch.nn.BCEWithLogitsLoss(reduction="mean", pos_weight=torch.tensor(pos_weight, dtype=torch.float32, device=device))

    # Training loop
    epochs = 15
    history = []

    print(f"Starting training for {epochs} epochs...")

    for epoch in range(epochs):
        # The LR used for this epoch; the scheduler steps at the end of it.
        current_lr = optimizer.param_groups[0]['lr']
        tag = f'Epoch {epoch+1}/{epochs}'

        train_loss, train_probs, train_labels = _train_one_epoch(
            model, train_loader, optimizer, criterion, device, f'{tag} [Train]', current_lr)
        val_loss, val_probs, val_labels = _evaluate(model, val_loader, criterion, device, f'{tag} [Val]')
        test_loss, test_probs, test_labels = _evaluate(model, test_loader, criterion, device, f'{tag} [Test]')

        print(f'{tag} Summary:')
        row = {'epoch': epoch + 1}
        for split, display, loss, probs, labels in (
            ('train', 'Train', train_loss, train_probs, train_labels),
            ('val', 'Val  ', val_loss, val_probs, val_labels),
            ('test', 'Test ', test_loss, test_probs, test_labels),
        ):
            metrics = _binary_metrics(labels, probs)
            print(f'{display} - Loss: {loss:.4f}, Acc: {metrics["acc"]:.4f}, '
                  f'Prec: {metrics["precision"]:.4f}, Rec: {metrics["recall"]:.4f}, '
                  f'F1: {metrics["f1"]:.4f}, AP: {metrics["ap"]:.4f}')
            row[f'{split}_loss'] = loss
            row.update({f'{split}_{name}': value for name, value in metrics.items()})
        row['learning_rate'] = current_lr
        print('-' * 50)

        history.append(row)

        # Save model checkpoint
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': train_loss,
            'val_f1': row['val_f1'],
        }, os.path.join(model_dir_full, f'e{epoch+1}.pt'))

        # Save predictions; reshaped to (N, 1) to keep the on-disk format unchanged.
        np.savez_compressed(os.path.join(model_dir_full, f'test.pred.{epoch+1}.npz'), predictions=test_probs.reshape(-1, 1))
        np.savez_compressed(os.path.join(model_dir_full, f'val.pred.{epoch+1}.npz'), predictions=val_probs.reshape(-1, 1))

        # Update learning rate
        scheduler.step()

        # Clear GPU memory
        torch.cuda.empty_cache()
        gc.collect()

    # Save final history
    history_df = pd.DataFrame(history)
    history_df.to_csv(os.path.join(model_dir_full, 'history.csv'), index=False)

    print(f"Results saved to: {model_dir_full}")


if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()



