import os
import sys
import json
import numpy as np
from tqdm import tqdm
from sklearn.metrics.pairwise import euclidean_distances
import pickle
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report, roc_auc_score, roc_curve, average_precision_score
from sklearn.model_selection import train_test_split
import torch

sys.path.append(os.path.abspath("."))
from task_tracker.CONFIG_1 import current_risk

# Fixed experiment settings: the model and dataset under study, and the
# anchor-set sizes swept by the main loop (200 through 2000 in steps of 200).
MODEL = "llama3_8b"
DATASET = "Medicalsys"
ANCHOR_SAMPLES = list(range(200, 2001, 200))

# Root directory for everything this script writes; created as soon as the
# module is executed.
OUTPUT_DIR = f"/guardrail/TaskTracker/store/model/{current_risk}/{DATASET}/{MODEL}"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Import required modules
from task_tracker.training.dataset import ActivationsDatasetDynamic, ActivationsDatasetDynamicPrimaryText
from task_tracker.training.helpers.data import load_file_paths
from task_tracker.training.utils.constants_1 import CONSTANTS_ALL_MODELS, OOD_POISONED_FILE

# Activation-layer indices evaluated for each supported model: layer 0 plus
# every 8th layer starting at 7 (up to 79 for llama3_70b, up to 31 otherwise).
LAYERS_PER_MODEL = {
    'llama3_70b': [0, *range(7, 80, 8)],
    'phi3': [0, *range(7, 32, 8)],
    'mixtral': [0, *range(7, 32, 8)],
    'mistral': [0, *range(7, 32, 8)],
    'llama3_8b': [0, *range(7, 32, 8)],
    'mistral_no_priming': [0, *range(7, 32, 8)],
    'vicuna': [0, *range(7, 32, 8)],
}

# Per-model directories, looked up in CONSTANTS_ALL_MODELS for MODEL; every
# occurrence of "FinQA" in a looked-up path is replaced by DATASET.
ACTIVATION_FILE_LIST_DIR, ACTIVATIONS_DIR, ACTIVATIONS_VAL_DIR = (
    CONSTANTS_ALL_MODELS[MODEL][key].replace("FinQA", DATASET)
    for key in ('ACTIVATION_FILE_LIST_DIR', 'ACTIVATIONS_DIR', 'ACTIVATIONS_VAL_DIR')
)

# Base configuration; the main loop copies it and adds per-run fields before
# writing it next to each run's outputs.
config = dict(
    activations=ACTIVATIONS_DIR,
    activations_ood=ACTIVATIONS_VAL_DIR,
    ood_poisoned_file=OOD_POISONED_FILE,
    exp_name=f'distance_based_classification_{MODEL}_{DATASET}',
)

# Function to train a multi-class classifier
def train_multi_class_classifier(train_activations, train_labels):
    """Fit a random forest on the given features and labels.

    The forest is seeded (random_state=42) and uses 100 trees, a maximum
    depth of 10 and a minimum of 5 samples to split a node.

    Args:
        train_activations: Feature matrix passed to ``fit`` as X.
        train_labels: Class labels passed to ``fit`` as y.

    Returns:
        The fitted RandomForestClassifier.
    """
    print("[*] Training Random Forest Classifier")
    forest_settings = {
        'n_estimators': 100,
        'random_state': 42,
        'max_depth': 10,
        'min_samples_split': 5,
    }
    forest = RandomForestClassifier(**forest_settings)
    # fit() returns the estimator itself, so the fitted forest is returned.
    return forest.fit(train_activations, train_labels)

# Function to evaluate the classifier
def evaluate_classifier(classifier, test_activations, test_labels):
    """Score a fitted multi-class classifier on a labelled feature set.

    Prints accuracy, one-vs-rest ROC AUC, AUPRC (average precision) and the
    per-class classification report, then returns the three scalar metrics.

    Args:
        classifier: Fitted estimator exposing ``predict`` and
            ``predict_proba``. If it also exposes ``classes_`` (as
            scikit-learn classifiers do), that ordering is used to build the
            one-hot targets for the AUPRC computation.
        test_activations: Feature matrix to evaluate on.
        test_labels: True class label for each row of ``test_activations``.

    Returns:
        Tuple ``(accuracy, roc_auc, auprc)``.
    """
    print("[*] Evaluating the classifier")
    test_labels = np.asarray(test_labels)
    predictions = classifier.predict(test_activations)
    # Both ranking metrics need the class probabilities; compute them once
    # instead of running inference twice on the same data.
    probabilities = classifier.predict_proba(test_activations)

    accuracy = accuracy_score(test_labels, predictions)

    # Multi-class ROC AUC
    roc_auc = roc_auc_score(test_labels, probabilities, multi_class='ovr')

    # One-hot encode the labels for the multi-class AUPRC. The columns follow
    # the classifier's own class order so they line up with predict_proba
    # even when a class is absent from this split; sizing the matrix from
    # the labels seen here would then produce too few columns.
    classes = getattr(classifier, 'classes_', None)
    if classes is None:
        classes = np.unique(test_labels)
    y_true_binary = (test_labels[:, None] == np.asarray(classes)[None, :]).astype(float)

    auprc = average_precision_score(y_true_binary, probabilities)

    print(f"Test accuracy: {accuracy}")
    print(f"ROC AUC score: {roc_auc}")
    print(f"AUPRC score: {auprc}")
    print(classification_report(test_labels, predictions))

    return accuracy, roc_auc, auprc

def load_activations(file_paths, num_layers, activations_dir):
    """Load activations for ``file_paths`` from ``activations_dir``.

    Builds an ActivationsDatasetDynamic rooted at ``activations_dir`` and
    iterates it with a progress bar. Each item is flattened, cast with
    ``.float()`` and converted to NumPy (items are assumed to be torch
    tensors - TODO confirm against the dataset class).

    Returns:
        ``np.array`` built from the list of flattened items.
    """
    dataset = ActivationsDatasetDynamic(file_paths, root_dir=activations_dir, num_layers=num_layers)
    flattened_rows = [item.flatten().float().numpy() for item in tqdm(dataset)]
    return np.array(flattened_rows)

def load_val_activations(file_paths, num_layers, activations_dir):
    """Load validation activations for ``file_paths``.

    The dataset root is ``activations_dir`` with every '/training'
    occurrence replaced by '/validation'. Each dataset item is flattened,
    cast with ``.float()`` and converted to NumPy.

    Returns:
        ``np.array`` built from the list of flattened items.
    """
    validation_root = activations_dir.replace('/training', '/validation')
    dataset = ActivationsDatasetDynamic(file_paths, validation_root, num_layers=num_layers)
    flattened_rows = [item.flatten().float().numpy() for item in tqdm(dataset)]
    return np.array(flattened_rows)

def load_test_activations(file_paths, num_layers, activations_dir):
    """Load test activations for ``file_paths``.

    The dataset root is ``activations_dir`` with every '/training'
    occurrence replaced by '/test'. Each dataset item is flattened, cast
    with ``.float()`` and converted to NumPy.

    Returns:
        ``np.array`` built from the list of flattened items.
    """
    test_root = activations_dir.replace('/training', '/test')
    dataset = ActivationsDatasetDynamic(file_paths, test_root, num_layers=num_layers)
    flattened_rows = [item.flatten().float().numpy() for item in tqdm(dataset)]
    return np.array(flattened_rows)

def generate_distance_features(samples, clean_anchors_list, batch_size=1024):
    """Describe each sample by its mean Euclidean distance to each anchor set.

    Samples are processed in batches with ``torch.cdist`` rather than one at
    a time, which avoids a host-to-device copy per sample and a device
    synchronisation per (sample, anchor set) pair. The direct (non-matmul)
    distance mode is used so values agree with a per-sample
    ``norm(sample - anchors)`` up to floating-point rounding.

    Args:
        samples: 2-D array-like of shape (n_samples, n_features).
        clean_anchors_list: Sequence of 2-D array-likes, one per anchor set,
            each of shape (n_anchors, n_features).
        batch_size: Number of samples sent to the device per step.

    Returns:
        float64 ndarray of shape (n_samples, len(clean_anchors_list)) where
        entry [i, j] is the mean L2 distance from sample i to the anchors of
        set j. An anchor set with no rows yields NaN in its column.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    anchor_sets = [
        torch.as_tensor(np.asarray(anchors)).to(device)
        for anchors in clean_anchors_list
    ]
    # Keep the full sample matrix on the host; only one batch at a time is
    # moved to the device.
    sample_matrix = torch.as_tensor(np.asarray(samples))
    n_samples = sample_matrix.shape[0]

    # Degenerate inputs: keep a well-formed 2-D result instead of failing
    # inside torch.stack / torch.cat below.
    if n_samples == 0 or not anchor_sets:
        return np.empty((n_samples, len(anchor_sets)))

    feature_chunks = []
    for start in tqdm(range(0, n_samples, batch_size)):
        batch = sample_matrix[start:start + batch_size].to(device)
        per_set_means = []
        for anchors in anchor_sets:
            # cdist needs both operands in one dtype; promote like the
            # elementwise subtraction would.
            common_dtype = torch.promote_types(batch.dtype, anchors.dtype)
            pairwise = torch.cdist(
                batch.to(common_dtype),
                anchors.to(common_dtype),
                compute_mode='donot_use_mm_for_euclid_dist',
            )
            per_set_means.append(pairwise.mean(dim=1))
        # One device-to-host transfer per batch; widen to float64 to match
        # the dtype of an array built from Python floats.
        feature_chunks.append(torch.stack(per_set_means, dim=1).double().cpu())

    return torch.cat(feature_chunks, dim=0).numpy()

def _load_labelled_split(files_per_class, loader, num_layers, activations_dir):
    """Load every class of one split, labelling rows by class position.

    ``loader`` is called once per entry of ``files_per_class``; the rows it
    returns are labelled with that entry's index.
    """
    per_class_activations = []
    labels = []
    for label, files in enumerate(files_per_class):
        activations = loader(files, num_layers, activations_dir)
        per_class_activations.append(activations)
        labels.extend([label] * len(activations))
    return np.concatenate(per_class_activations, axis=0), np.array(labels)


def train_and_evaluate_with_distance(train_files, val_files, test_files, num_layers, activations_dir, n_anchors):
    """Train and evaluate a distance-feature classifier with a given anchor count.

    Loads the train, test and validation activations, takes the first
    ``n_anchors`` training samples of each class as that class's anchor set,
    turns every sample into its mean distances to the anchor sets, fits the
    classifier on the training features and reports metrics on the
    validation and test sets.

    Args:
        train_files: One list of activation file paths per class; the index
            of each list is used as the class label.
        val_files: Same structure, for the validation split.
        test_files: Same structure, for the test split.
        num_layers: Layer selection forwarded to the activation loaders.
        activations_dir: Training activations directory forwarded to the
            loaders.
        n_anchors: Maximum number of anchor samples per class; a class with
            fewer training samples uses all of them.

    Returns:
        Tuple ``(test_accuracy, test_roc_auc, test_auprc)``.
    """
    # Load training activations
    print("Loading training activations.")
    train_activations, train_labels = _load_labelled_split(
        train_files, load_activations, num_layers, activations_dir
    )

    # Load test activations
    print("Loading test activations.")
    test_activations, test_labels = _load_labelled_split(
        test_files, load_test_activations, num_layers, activations_dir
    )

    # Load validation activations
    print("Loading validation activations.")
    val_activations, val_labels = _load_labelled_split(
        val_files, load_val_activations, num_layers, activations_dir
    )

    # One anchor set per class. The class count follows train_files rather
    # than a fixed 4, so a different number of classes neither drops a class
    # nor adds an empty anchor set.
    clean_anchors_list = []
    for class_index in range(len(train_files)):
        class_samples = train_activations[train_labels == class_index]
        n = min(len(class_samples), n_anchors)
        print(f"Using {n} anchor samples for class {class_index}")
        clean_anchors_list.append(class_samples[:n])

    # Generate training features and labels
    print("Generating training features.")
    train_features = generate_distance_features(train_activations, clean_anchors_list)

    # Generate validation features and labels
    print("Generating validation features.")
    val_features = generate_distance_features(val_activations, clean_anchors_list)

    # Train multi-class classifier
    classifier = train_multi_class_classifier(train_features, train_labels)

    # Evaluate on validation set (metrics are printed, not returned)
    print("Evaluating on validation set.")
    evaluate_classifier(classifier, val_features, val_labels)

    # Evaluate on test set
    print("Evaluating on test set.")
    test_features = generate_distance_features(test_activations, clean_anchors_list)
    test_accuracy, test_roc_auc, test_auprc = evaluate_classifier(classifier, test_features, test_labels)

    return test_accuracy, test_roc_auc, test_auprc

if __name__ == "__main__":
    # Sweep every anchor count over every layer of MODEL, writing one results
    # file per anchor count and a combined best-layer summary at the end.
    LAYERS = LAYERS_PER_MODEL[MODEL]

    # Class names in label order: the position of each name is the class
    # label assigned when the per-class file lists are enumerated.
    CLASS_NAMES = ('goods', 'employee', 'case', 'financial')

    def _load_split_file_lists(split):
        """Return one list of activation file paths per class for ``split``."""
        return [
            load_file_paths(os.path.join(ACTIVATION_FILE_LIST_DIR, f'{split}_{name}_files_{MODEL}.txt'))
            for name in CLASS_NAMES
        ]

    def _results_path(n_anchors):
        """Return the per-layer results file path for one anchor count."""
        return os.path.join(
            OUTPUT_DIR,
            f"{current_risk}_{DATASET.lower()}_{MODEL}_results_anchors_{n_anchors}.json",
        )

    # Load the file paths (only once, all anchor counts share the same files)
    train_files = _load_split_file_lists('train')
    val_files = _load_split_file_lists('val')
    test_files = _load_split_file_lists('test')

    print(f"Training with {len(train_files[0])} goods files, {len(train_files[1])} employee files, {len(train_files[2])} case files, and {len(train_files[3])} financial files.")
    print(f"Evaluating with {len(val_files[0])} goods files, {len(val_files[1])} employee files, {len(val_files[2])} case files, and {len(val_files[3])} financial files.")

    # Iterate through different anchor sample sizes
    for n_anchors in ANCHOR_SAMPLES:
        print(f"\n===== Using {n_anchors} anchor samples =====\n")

        # The per-anchor directory does not depend on the layer, so it is
        # created once here rather than on every layer iteration.
        anchor_output_dir = os.path.join(OUTPUT_DIR, f"anchors_{n_anchors}")
        os.makedirs(anchor_output_dir, exist_ok=True)

        results = []

        # Process each layer
        # NOTE(review): train_and_evaluate_with_distance reloads the
        # activations on every call although loading does not depend on
        # n_anchors; caching per layer would avoid the repeated loads.
        for n_layer in LAYERS:
            print(f"[*] Processing the {n_layer}-th activation layer with {n_anchors} anchors.")

            layer_output_dir = os.path.join(anchor_output_dir, str(n_layer))
            os.makedirs(layer_output_dir, exist_ok=True)

            # Create config for this run
            _config = config.copy()
            _config["num_layers"] = n_layer
            _config["n_anchors"] = n_anchors
            _config["exp_name"] = f"{config['exp_name']}_{n_layer}_anchors_{n_anchors}"
            print(_config["exp_name"])

            # Save config
            with open(os.path.join(layer_output_dir, 'config.json'), 'w', encoding='utf-8') as f:
                json.dump(_config, f)

            # Train and evaluate model with this anchor count; a single
            # layer is selected by passing it as both ends of num_layers.
            test_accuracy, test_roc_auc, test_auprc = train_and_evaluate_with_distance(
                train_files,
                val_files,
                test_files,
                num_layers=(n_layer, n_layer),
                activations_dir=ACTIVATIONS_DIR,
                n_anchors=n_anchors
            )

            # Save results
            results.append({
                'layer': n_layer,
                'anchors': n_anchors,
                'test_accuracy': test_accuracy,
                'test_roc_auc': test_roc_auc,
                'test_auprc': test_auprc
            })

        # Save results to file for this anchor count
        results_file = _results_path(n_anchors)
        with open(results_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=4)
        print(f"Results for {n_anchors} anchors saved to {results_file}")

    # Create a combined results file with anchor information
    all_results = []
    for n_anchors in ANCHOR_SAMPLES:
        results_file = _results_path(n_anchors)
        if not os.path.exists(results_file):
            continue
        # Read with the same encoding the file was written with.
        with open(results_file, 'r', encoding='utf-8') as f:
            layer_results = json.load(f)
        if not layer_results:
            continue
        # Find the best layer result based on test_roc_auc
        best_result = max(layer_results, key=lambda x: x.get('test_roc_auc', 0))
        all_results.append({
            "test_accuracy": best_result.get("test_accuracy"),
            "layer": best_result.get("layer"),
            "test_auroc": best_result.get("test_roc_auc"),
            "test_auprc": best_result.get("test_auprc"),
            "risk": current_risk,
            "dataset": DATASET,
            "anchor": n_anchors
        })

    # Save the combined results
    combined_file = os.path.join(OUTPUT_DIR, f"{current_risk}_{DATASET}_processed_results.json")
    with open(combined_file, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, ensure_ascii=False, indent=4)
    print(f"Combined processed results saved to {combined_file}")