"""
Trains simple probes (logistic regression and MLPs) to predict safety alignment outcomes from CoT activations.

Models Trained:
- Logistic Regression: Simple linear classifier with balanced class weights
- MLP: neural network with two hidden layers (100 and 50 units), trained with early stopping on validation F1
- Random baselines: For comparison and significance testing

Input data:
    input_folder/
    ├── activations/         # PyTorch tensors from 2b_get_activations.py
    │   ├── 0_0.pt
    │   ├── 0_1.pt
    │   └── ...
    └── labels/              # Safety labels from 2a_evaluate_safety.py
        ├── 0/
        │   ├── 0_0.json
        │   ├── 0_1.json
        │   └── ...
        └── ...

Output:
    Prints performance metrics (F1, accuracy, PR-AUC).
    Optionally saves detailed test-set predictions (ids, prompts, CoTs, true/predicted labels, probabilities) to TSV files.
"""

import collections
from loguru import logger
import os
import json
import torch
import argparse
import pathlib
import numpy as np
from tqdm import tqdm
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, precision_recall_curve, auc, confusion_matrix, classification_report
from sklearn.preprocessing import StandardScaler
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

from utils import eval_pred, add_to_final_scores, calculate_metrics_stats, save_probe_outputs_tsv

# Command-line interface. NOTE: arguments are parsed at import time, so `args`,
# INPUT_FOLDER and (only with --store_outputs) PROBE_OUTPUT_FOLDER are module-level
# globals that the functions below read directly.
parser = argparse.ArgumentParser()
parser.add_argument("--input_folder", type=str, required=True, help="input folder containing activations and labels")
parser.add_argument("--N_runs", type=int, default=5, help="number of different seeded runs")
parser.add_argument("--sample_K", type=int, default=-1, help="number of training samples (<= 0 uses all training samples)")
parser.add_argument("--pca", action="store_true", help="run PCA")
# the help text used to be a copy of the --N_runs one
parser.add_argument("--pca_components", type=int, default=50, help="number of PCA components to keep (capped by the training set's sample and feature counts)")

### storing test prediction outputs
parser.add_argument("--store_outputs", action="store_true", help="whether to store model outputs")
parser.add_argument("--probe_output_folder", type=str, default="../probe_outputs/", help="folder to store model outputs and results")

args = parser.parse_args()
INPUT_FOLDER = pathlib.Path(args.input_folder)
if args.store_outputs:
    # outputs are grouped per input folder name; the directory is created eagerly
    PROBE_OUTPUT_FOLDER = pathlib.Path(args.probe_output_folder) / INPUT_FOLDER.name
    PROBE_OUTPUT_FOLDER.mkdir(parents=True, exist_ok=True)

### load and engineer data
def load_data():
    """
    Load activations, prompts, CoTs and safety labels from INPUT_FOLDER.

    Reads every ``*.pt`` file directly under ``INPUT_FOLDER/activations`` and every
    ``*.json`` file found recursively under ``INPUT_FOLDER/labels``. Entries are keyed
    by the file name up to its first dot, i.e. "{prompt_id}_{sentence_id}".

    Returns:
        Tuple ``(activations, prompts, cots, labels)`` of dicts keyed as above.
        ``labels`` holds ``data["safety_label"]["score"]``; ``prompts`` / ``cots``
        default to "" when the JSON has no "prompt" / "cot" field.

    Raises:
        KeyError: if a label file has no ``safety_label`` / ``score`` entry.
    """
    activations = {}
    labels = {}
    prompts = {}
    cots = {}

    # Sort the paths: glob order is filesystem dependent, and the insertion order of
    # `activations` later determines the row order of the feature matrix (and hence
    # which rows a seeded subsample picks).
    act_files = sorted((INPUT_FOLDER / "activations").glob("*.pt"))
    for act_file in tqdm(act_files, desc="Loading activations"):
        # key is in the format of "{prompt_id}_{sentence_id}"
        key = act_file.name.split('.')[0]
        # map_location="cpu" lets tensors saved from a GPU load on CPU-only machines
        activations[key] = torch.load(act_file, map_location="cpu")

    label_files = sorted((INPUT_FOLDER / "labels").rglob("*.json"))
    for label_file in tqdm(label_files, desc="Loading labels and texts"):
        key = label_file.name.split('.')[0]
        # explicit encoding: the platform default is not guaranteed to be UTF-8
        with open(label_file, 'r', encoding="utf-8") as f:
            data = json.load(f)
        labels[key] = data["safety_label"]["score"]
        prompts[key] = data.get("prompt", "")
        cots[key] = data.get("cot", "")
    return activations, prompts, cots, labels

def prepare_data(activations, labels):
    """
    Stack per-sample activations and labels into aligned numpy arrays.

    Args:
        activations: dict mapping "{prompt_id}_{sentence_id}" to a torch tensor
            (or an already-converted numpy array).
        labels: dict with exactly the same keys, mapping to label scores.

    Returns:
        Tuple ``(X, labels_np, prompt_sent_ids)``: the vertically stacked activations,
        the labels as a numpy array, and the keys, all in the iteration order of
        ``activations``.

    Raises:
        AssertionError: if the two dicts do not have identical key sets.
    """
    # dtypes that are up-cast to float32 before the numpy conversion
    # (numpy has no bfloat16; the other entries mirror the original conversion list)
    upcast_dtypes = (torch.bfloat16, torch.float16, torch.int8, torch.uint8, torch.int16)

    def convert_to_numpy(tensor):
        # numpy arrays pass through untouched
        if isinstance(tensor, np.ndarray):
            return tensor
        if tensor.dtype in upcast_dtypes:
            tensor = tensor.to(torch.float32)
        # detach() / cpu() are no-ops for tensors that are already grad-free / on CPU
        return tensor.detach().cpu().numpy()

    # sanity check: make sure we only process ids that exist in both activations and labels
    activation_keys = set(activations.keys())
    label_keys = set(labels.keys())
    # report both directions, otherwise label-only keys show up as an empty difference
    assert activation_keys == label_keys, (
        f"difference: only in activations: {sorted(activation_keys - label_keys)}, "
        f"only in labels: {sorted(label_keys - activation_keys)}"
    )

    prompt_sent_ids = list(activations.keys())
    activations_list = [convert_to_numpy(activations[key]) for key in prompt_sent_ids]
    labels_list = [labels[key] for key in prompt_sent_ids]

    # NOTE(review): assumes each activation is a single vector (1-D or 1xD) so that
    # rows of X line up one-to-one with labels — confirm against the activation dump.
    X = np.vstack(activations_list)
    labels_np = np.array(labels_list)
    return X, labels_np, prompt_sent_ids

def apply_pca(X_train, X_val, X_test):
    """
    Reduce all three splits with a PCA fitted on the training split only.

    The number of components is ``args.pca_components`` capped by the number of
    training rows and columns. Returns the transformed (train, val, test) splits.
    """
    n_rows = X_train.shape[0]
    n_cols = X_train.shape[1]
    n_keep = min(args.pca_components, n_rows, n_cols)
    logger.info(f"PCA:::reducing to {n_keep}")

    reducer = PCA(n_components=n_keep)
    # fit on train only, then reuse the fitted projection for val and test
    reduced_train = reducer.fit_transform(X_train)
    reduced_val = reducer.transform(X_val)
    reduced_test = reducer.transform(X_test)
    return reduced_train, reduced_val, reduced_test


### logistic regression
def train_logistic_regression(X_train, y_train, X_test, y_test):
    """
    Fit a class-balanced logistic regression and predict on the test split.

    ``y_test`` is accepted for signature symmetry with the other trainers but is
    not used. Returns ``(hard_predictions, probability_of_class_1)``.
    """
    classifier = LogisticRegression(class_weight="balanced", max_iter=1000)
    classifier.fit(X_train, y_train)

    hard_predictions = classifier.predict(X_test)
    # column 1 of predict_proba is the probability of the second class
    class_probabilities = classifier.predict_proba(X_test)
    positive_probabilities = class_probabilities[:, 1]

    return hard_predictions, positive_probabilities

### MLP
class CustomMLP2Layer(nn.Module):
    """
    Binary classifier: Linear -> ReLU -> Linear -> ReLU -> Linear -> Sigmoid.

    ``forward`` returns one value in (0, 1) per input row (last dimension of size 1).
    """

    def __init__(self, input_size, hidden_size1=100, hidden_size2=50):
        super().__init__()
        # submodules are created in the same order as before so that parameter
        # initialisation and state_dict keys are unchanged
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size2, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        hidden = self.relu1(self.fc1(x))
        hidden = self.relu2(self.fc2(hidden))
        logits = self.fc3(hidden)
        return self.sigmoid(logits)
    
def train_mlp(X_train, y_train, X_val, y_val, X_test, y_test):
    """
    Train a CustomMLP2Layer probe with early stopping and predict on the test split.

    Features are standardised with statistics fitted on the training split. The loss
    is a per-sample BCE in which the minority class of ``y_train`` is up-weighted by
    the majority/minority count ratio. After every epoch the binary F1 on the
    validation split (threshold 0.5) is computed; the weights with the best F1 are
    restored before testing, and training stops after 5 epochs without improvement.

    Args:
        X_train, X_val, X_test: 2-D feature arrays (rows are samples).
        y_train, y_val: binary (0/1) label arrays for the matching splits.
        y_test: unused; kept so the signature mirrors the other trainers.

    Returns:
        Tuple ``(y_pred, y_pred_prob)`` for the test split: thresholded 0/1
        predictions and the raw sigmoid outputs.
    """
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_val_scaled = scaler.transform(X_val)
    X_test_scaled = scaler.transform(X_test)

    X_train_tensor = torch.FloatTensor(X_train_scaled)
    y_train_tensor = torch.FloatTensor(y_train).unsqueeze(1)
    X_val_tensor = torch.FloatTensor(X_val_scaled)
    y_val_tensor = torch.FloatTensor(y_val).unsqueeze(1)
    X_test_tensor = torch.FloatTensor(X_test_scaled)

    batch_size = 32
    train_loader = DataLoader(TensorDataset(X_train_tensor, y_train_tensor), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(TensorDataset(X_val_tensor, y_val_tensor), batch_size=batch_size)

    input_size = X_train.shape[1]
    model = CustomMLP2Layer(input_size)

    # calculate class weights for weighted BCE loss
    num_pos = np.sum(y_train)  # number of 1s
    num_neg = len(y_train) - num_pos  # number of 0s

    if num_pos == 0 or num_neg == 0:
        # single-class training data: there is no minority to upweight, and the
        # count ratio below would divide by zero
        pos_weight = torch.FloatTensor([1.0])
        minority_class = 1
    elif num_pos < num_neg:
        # class 1 is minority - upweight class 1
        pos_weight = torch.FloatTensor([num_neg / num_pos])
        minority_class = 1
    else:
        # class 0 is minority - upweight class 0
        pos_weight = torch.FloatTensor([num_pos / num_neg])
        minority_class = 0
    # reduction='none' keeps per-sample losses so they can be re-weighted manually
    criterion = nn.BCELoss(reduction='none')
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # training loop
    num_epochs = 50
    best_f1 = 0
    best_model_state = None
    patience = 5
    patience_counter = 0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)

            loss = criterion(outputs, labels)
            weights = torch.ones_like(labels)
            weights[labels == minority_class] = pos_weight
            loss = (loss * weights).mean()

            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
        epoch_loss = running_loss / len(train_loader.dataset)

        # validation
        model.eval()
        y_pred = []
        y_true = []

        with torch.no_grad():
            for inputs, labels in val_loader:
                outputs = model(inputs)
                y_pred.extend(outputs.cpu().numpy())
                y_true.extend(labels.cpu().numpy())

        y_pred = np.array(y_pred).flatten()
        y_true = np.array(y_true).flatten()

        # convert to binary
        y_pred_binary = (y_pred >= 0.5).astype(int)
        val_f1 = f1_score(y_true, y_pred_binary, average="binary")

        if val_f1 > best_f1:
            best_f1 = val_f1
            # Clone every tensor: state_dict() values share storage with the live
            # parameters, so a shallow dict copy would keep changing as training
            # continues and the "best" snapshot would end up being the last epoch.
            best_model_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

        if (epoch + 1) % 5 == 0:
            print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Val F1: {val_f1:.4f}')

    # stays None when validation F1 never rose above 0; the final weights are used then
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    # test
    model.eval()
    with torch.no_grad():
        y_pred_tensor = model(X_test_tensor)
        y_pred_prob = y_pred_tensor.cpu().numpy().flatten()

    y_pred = (y_pred_prob >= 0.5).astype(int)
    return y_pred, y_pred_prob


########################
######## main ##########
########################
def main():
    """
    Train and evaluate the probes and baselines over ``args.N_runs`` seeded splits.

    For every seed the prompts are split 70/30 into train/test, and the train part is
    split again 90/10 into train/validation (all splits are by prompt id, so sentences
    of one prompt never straddle splits). A logistic-regression probe, an MLP probe, a
    logistic regression fitted on shuffled labels and several constant/random baselines
    are evaluated on the test split. Aggregated metrics are printed at the end; with
    ``--store_outputs`` the per-seed test predictions of both probes are written out.

    Raises:
        ValueError: if ``--N_runs`` is smaller than 1.
    """
    if args.N_runs < 1:
        # the summary below averages over runs and would otherwise divide by zero
        raise ValueError(f"--N_runs must be >= 1, got {args.N_runs}")

    activations_dict, prompts_dict, cots_dict, labels_dict = load_data()

    # get train-test split based on prompt IDs
    prompt_IDs = {key.split("_")[0] for key in activations_dict}
    N = len(prompt_IDs)
    logger.debug(f"Loaded {len(activations_dict)=} activations of last-token CoTs for {N=} prompts.")

    # the stacked arrays do not depend on the seed, so build them once
    X, labels_np, prompt_sent_ids = prepare_data(activations_dict, labels_dict)

    # initialize lists to store results across all runs
    D_final_logreg_scores = collections.defaultdict(list)
    D_final_mlp_scores = collections.defaultdict(list)
    D_final_rand_lr_scores = collections.defaultdict(list)
    D_final_random_scores = collections.defaultdict(list)
    D_final_always_ones_scores = collections.defaultdict(list)
    D_final_always_zeros_scores = collections.defaultdict(list)
    D_final_theoretical_random_scores = collections.defaultdict(list)
    total_disagreement_percentage = 0

    for seed in range(args.N_runs):
        np.random.seed(seed)  # for reproducibility
        # also seed torch: MLP weight init and DataLoader shuffling draw from torch's RNG
        torch.manual_seed(seed)
        train_prompt_ids = set(np.random.choice(sorted(prompt_IDs), int(0.7 * N), replace=False))
        test_prompt_ids = prompt_IDs - train_prompt_ids
        print(f"Selected {len(train_prompt_ids)} prompts for training and {len(test_prompt_ids)} for testing")

        # split train_prompt_ids into train and validation sets (90:10 split)
        # sorted(): iterating a set of strings follows per-process hash randomisation,
        # which would make the seeded shuffle below differ from one process to the next
        train_prompt_ids_list = sorted(train_prompt_ids)
        np.random.shuffle(train_prompt_ids_list)
        split_idx = int(0.9 * len(train_prompt_ids_list))
        train_prompt_ids = set(train_prompt_ids_list[:split_idx])
        val_prompt_ids = set(train_prompt_ids_list[split_idx:])
        logger.debug(f"Train prompts: {len(train_prompt_ids)}, Val prompts: {len(val_prompt_ids)}, Test prompts: {len(test_prompt_ids)}")

        # prepare data for this iteration
        train_indices = [i for i, key in enumerate(prompt_sent_ids) if key.split('_')[0] in train_prompt_ids]
        val_indices = [i for i, key in enumerate(prompt_sent_ids) if key.split('_')[0] in val_prompt_ids]
        test_indices = [i for i, key in enumerate(prompt_sent_ids) if key.split('_')[0] in test_prompt_ids]

        if args.sample_K and args.sample_K > 0:
            assert len(train_indices) >= args.sample_K, f"Not enough training samples. Required: {args.sample_K}, Available: {len(train_indices)}"
            np.random.shuffle(train_indices)
            train_indices = train_indices[:args.sample_K]
            logger.info(f">>>> use {args.sample_K} data")
        else:
            logger.info("use all data")

        X_train = X[train_indices]
        X_val = X[val_indices]
        X_test = X[test_indices]

        # convert scores to binary classes
        threshold = 0.5
        y_train = (labels_np[train_indices] >= threshold).astype(int)
        y_val = (labels_np[val_indices] >= threshold).astype(int)
        y_test = (labels_np[test_indices] >= threshold).astype(int)

        # current labels: safe -> 0; unsafe -> 1
        # flip labels if safe is rarer
        flipped = (y_test == 0).sum() < (y_test == 1).sum()
        if flipped:
            logger.info("Flipping labels (0->1, 1->0) so unsafe -> 0, safe (rarer) -> 1")
            y_train = 1 - y_train
            y_val = 1 - y_val
            y_test = 1 - y_test
        # class 1 means "Safe" only after a flip; name the classes accordingly so the
        # counts logged below are not mislabelled when no flip happened
        name_of_1, name_of_0 = ("Safe", "Unsafe") if flipped else ("Unsafe", "Safe")

        keys_test = [prompt_sent_ids[i] for i in test_indices]

        if args.pca:
            X_train, X_val, X_test = apply_pca(X_train, X_val, X_test)

        logger.debug(f"Training set: {len(X_train)} latents ({name_of_1}: {np.sum(y_train==1)}, {name_of_0}: {np.sum(y_train==0)})")
        logger.debug(f"Validation set: {len(X_val)} latents ({name_of_1}: {np.sum(y_val==1)}, {name_of_0}: {np.sum(y_val==0)})")
        logger.debug(f"Testing set: {len(X_test)} latents ({name_of_1}: {np.sum(y_test==1)}, {name_of_0}: {np.sum(y_test==0)})")

        ##############################
        ### train safety probes
        ##############################
        logreg_y_pred, logreg_y_pred_prob = train_logistic_regression(X_train, y_train, X_test, y_test)
        mlp_y_pred, mlp_y_pred_prob = train_mlp(X_train, y_train, X_val, y_val, X_test, y_test)

        ##############################
        ### random baseline
        ##############################
        np.random.seed(seed)  # use same seed as outer loop for reproducibility
        positive_prior = np.sum(y_train == 1)/len(y_train)
        random_probs = np.random.uniform(0, 1, size=len(X_test))
        # predicts 1 with probability equal to the training positive rate
        random_y_pred = (random_probs < positive_prior).astype(int)
        always_ones_pred = np.ones(len(X_test))
        always_zeros_pred = np.zeros(len(X_test))

        # shuffle
        np.random.seed(seed)  # use same seed as outer loop for reproducibility
        shuffled_indices = np.arange(len(y_train))
        np.random.shuffle(shuffled_indices)
        y_train_shuffled = y_train[shuffled_indices]

        # share of training labels that changed under the permutation
        disagreement = np.sum(y_train != y_train_shuffled)
        disagreement_percentage = (disagreement / len(y_train)) * 100
        total_disagreement_percentage += disagreement_percentage
        rand_lr_pred, rand_lr_pred_prob = train_logistic_regression(X_train, y_train_shuffled, X_test, y_test)

        # eval
        logreg_eval = eval_pred(y_test, logreg_y_pred, logreg_y_pred_prob, metrics=["f1", "accuracy", "pr_auc"])
        mlp_eval = eval_pred(y_test, mlp_y_pred, mlp_y_pred_prob, metrics=["f1", "accuracy", "pr_auc"])
        rand_lr_eval = eval_pred(y_test, rand_lr_pred, rand_lr_pred_prob, metrics=["f1", "accuracy", "pr_auc"])
        random_eval = eval_pred(y_test, random_y_pred, random_probs, metrics=["f1", "accuracy", "pr_auc"])
        theory_random_eval = {"f1": positive_prior, "pr_auc": positive_prior}
        always_ones_eval = eval_pred(y_test, always_ones_pred, metrics=["f1", "accuracy"])
        # NOTE(review): positive_prior is the *training* positive rate; the precision of
        # an always-1 predictor on the test split is the test positive rate — confirm
        # that using the training prior here is intended.
        always_ones_eval["pr_auc"] = positive_prior # precision is p and recall is 1
        always_zeros_eval = eval_pred(y_test, always_zeros_pred, metrics=["f1", "accuracy"])
        always_zeros_eval["pr_auc"] = 0 # precision is undefined and recall is 0

        add_to_final_scores(logreg_eval, D_final_logreg_scores, 'logreg')
        add_to_final_scores(mlp_eval, D_final_mlp_scores, 'mlp')
        add_to_final_scores(rand_lr_eval, D_final_rand_lr_scores, 'random_logreg')
        add_to_final_scores(random_eval, D_final_random_scores, 'empirical_random')
        add_to_final_scores(theory_random_eval, D_final_theoretical_random_scores, 'theoretical_random')
        add_to_final_scores(always_ones_eval, D_final_always_ones_scores, "always_ones")
        add_to_final_scores(always_zeros_eval, D_final_always_zeros_scores, "always_zeros")

        ##############################
        #### save test outputs
        ##############################
        if args.store_outputs:
            test_text_prompts = [prompts_dict[key] for key in keys_test]
            test_text_cots = [cots_dict[key] for key in keys_test]

            save_probe_outputs_tsv(
                output_dir=PROBE_OUTPUT_FOLDER,
                probe_name=f"logreg_seed{seed}",
                prompt_sent_ids=keys_test,
                prompts=test_text_prompts,
                cots=test_text_cots,
                true_labels=y_test,
                pred_labels=logreg_y_pred,
                pred_probs=logreg_y_pred_prob)

            save_probe_outputs_tsv(
                output_dir=PROBE_OUTPUT_FOLDER,
                probe_name=f"mlp_seed{seed}",
                prompt_sent_ids=keys_test,
                prompts=test_text_prompts,
                cots=test_text_cots,
                true_labels=y_test,
                pred_labels=mlp_y_pred,
                pred_probs=mlp_y_pred_prob
            )

    print(f"---- mean disagreement after shuffling: {total_disagreement_percentage/args.N_runs:.2f}%")
    # sizes reported here are those of the last run
    print(f"(N_train: {len(train_indices)}. N_test: {len(test_indices)})")

    print(calculate_metrics_stats([
        D_final_logreg_scores,
        D_final_mlp_scores,
        D_final_rand_lr_scores,
        D_final_random_scores,
        D_final_theoretical_random_scores,
        D_final_always_ones_scores,
        D_final_always_zeros_scores
    ]))


# Run the full probe training/evaluation pipeline only when executed as a script.
if __name__ == "__main__":
    main()