#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Evaluate model robustness against label distribution noise

import os
import json
import re
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
from collections import Counter
from typing import List, Dict, Union, Tuple
import logging
from tqdm import tqdm
import random
import argparse
from tabulate import tabulate

from dataset import BaseDataset, SNLIDataset, MTBenchDataset, SummEvalDataset

# Module-wide logging: records at INFO and above, prefixed with a timestamp,
# the logger name and the level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name="RobustnessTest")

class RobustnessTest:
    """Measure how stable saved model predictions are under label-distribution noise.

    For every dataset, each model's stored predictions ({sample id: {label: prob}})
    are compared by KL divergence against the dataset's reference label
    distribution after uniform noise of increasing magnitude (δ) has been added to
    it. Average KL per model and δ is written out as CSV, LaTeX and text tables.
    """

    # Closing lines shared by every LaTeX table this class emits.
    _LATEX_FOOTER = "\\bottomrule\n\\end{tabular}\n}\n\\end{table}"

    # Characters that must be escaped inside a LaTeX table cell
    # (e.g. an unescaped '%' would comment out the rest of the row).
    _LATEX_SPECIALS = {
        "\\": "\\textbackslash{}",
        "&": "\\&",
        "%": "\\%",
        "$": "\\$",
        "#": "\\#",
        "_": "\\_",
        "{": "\\{",
        "}": "\\}",
        "~": "\\textasciitilde{}",
        "^": "\\textasciicircum{}",
    }

    def __init__(self,
                 base_path: str = os.path.dirname(os.path.abspath(__file__)),
                 raw_results_path: Union[str, None] = None,
                 robust_output_path: Union[str, None] = None,
                 seed: int = 42,
                 epsilon: float = 0.25,
                 alpha: float = 0.8,
                 epoch: str = "2"):
        """
        Initialize robustness testing class.

        Seeds ``random``, NumPy and PyTorch, resolves the input/output directories
        (creating the output directory), and sets up the dataset loaders and the
        model configurations whose predictions are evaluated.

        Args:
            base_path: Root directory from which default dataset/result paths are built.
            raw_results_path: Directory holding saved model predictions
                (default: ``<base_path>/evaluation_results/raw``).
            robust_output_path: Directory for robustness outputs
                (default: ``<base_path>/evaluation_results/robust``).
            seed: Random seed for ``random``, NumPy and PyTorch.
            epsilon: ``eps`` value of the "Ours (Full)" checkpoint; it is part of that
                model's prediction file name.
            alpha: ``alpha`` value of the "Ours w/o Adv" and "Ours (Full)" checkpoints;
                also part of their prediction file names.
            epoch: Training epoch of the fine-tuned checkpoints (file-name component).
        """
        # Seed every RNG in use so the injected noise is reproducible.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        self.base_path = base_path
        self.epsilon = epsilon
        self.alpha = alpha
        self.epoch = epoch

        # Set paths
        if raw_results_path is None:
            self.raw_results_path = os.path.join(base_path, "evaluation_results", "raw")
        else:
            self.raw_results_path = raw_results_path

        if robust_output_path is None:
            self.robust_output_path = os.path.join(base_path, "evaluation_results", "robust")
        else:
            self.robust_output_path = robust_output_path

        os.makedirs(self.robust_output_path, exist_ok=True)

        # Lazily constructed dataset loaders, keyed by dataset name.
        self.dataset_mapping = {
            "snli": lambda: SNLIDataset(file_path=os.path.join(base_path, "dataset", "test", "snli_test.jsonl"), name="snli"),
            "multinli": lambda: SNLIDataset(file_path=os.path.join(base_path, "dataset", "test", "multinli_test.jsonl"), name="multinli"),
            "mtbench": lambda: MTBenchDataset(file_path=os.path.join(base_path, "dataset", "test", "mt_bench_test.jsonl"), name="mtbench"),
            "summeval": lambda: SummEvalDataset(file_path=os.path.join(base_path, "dataset", "test", "summeval_test.jsonl"), name="summeval")
        }

        # Model configurations. eps/alpha are taken from the constructor arguments
        # (str(0.25) == "0.25", str(0.8) == "0.8", so defaults are unchanged).
        self.model_configs = {
            "Raw model": {"type": "qwen", "name": "Qwen2.5-7B-Instruct"},
            "Single-point": {"type": "qwenlora", "eps": "0.0", "alpha": "0.0"},
            "Ours w/o Adv": {"type": "qwenlora", "eps": "0.0", "alpha": str(alpha)},
            "Ours (Full)": {"type": "qwenlora", "eps": str(epsilon), "alpha": str(alpha)}
        }

        # Default model names
        self.raw_model_name = "Qwen2.5-7B-Instruct"
        self.fine_tuned_model_name = "qwen2.5-7b"

    # ------------------------------------------------------------------ #
    # Distribution utilities
    # ------------------------------------------------------------------ #

    def add_noise_to_distribution(self, distribution: Dict[str, float], noise_level: float) -> Dict[str, float]:
        """
        Add uniform noise to a distribution and renormalize it.

        Each probability is shifted by a value drawn from U[-noise_level, noise_level)
        using the PyTorch RNG, clamped to at least 1e-7 and renormalized to sum to 1.

        Args:
            distribution: Original distribution {label: probability}.
            noise_level: Half-width of the uniform noise.

        Returns:
            Perturbed distribution with the same labels (empty if input is empty).
        """
        if not distribution:
            return {}

        labels = list(distribution.keys())
        probs = torch.tensor([distribution[k] for k in labels], dtype=torch.float)

        # Uniform noise in [-noise_level, noise_level)
        noise = torch.rand_like(probs) * 2 * noise_level - noise_level
        perturbed_probs = probs + noise

        # Ensure non-negativity and normalization
        perturbed_probs = torch.clamp(perturbed_probs, min=1e-7)
        perturbed_probs = perturbed_probs / perturbed_probs.sum()

        return {label: perturbed_probs[i].item() for i, label in enumerate(labels)}

    @staticmethod
    def _str_keyed(dist: Dict) -> Dict[str, float]:
        """Return *dist* with every label converted to ``str``, summing labels that collide."""
        merged: Dict[str, float] = {}
        for label, prob in dist.items():
            key = str(label)
            merged[key] = merged.get(key, 0.0) + prob
        return merged

    def calculate_kl_divergence(self, pred_dist: Dict[str, float],
                                true_dist: Dict[str, float]) -> float:
        """
        Calculate KL(true || pred) with epsilon smoothing.

        Labels are compared as strings: predictions are loaded from JSON, whose
        object keys are always strings, while reference labels may be ints/floats.
        Labels missing from either side get probability epsilon before both sides
        are renormalized; predicted probabilities are additionally floored at
        epsilon so an explicit 0 (or negative) prediction yields a large finite
        penalty instead of a division by zero.

        Args:
            pred_dist: Predicted distribution {label: probability}.
            true_dist: True (reference) distribution {label: probability}.

        Returns:
            KL divergence value (0.0 when both inputs are empty).
        """
        epsilon = 1e-10  # Smoothing / floor for zero probabilities

        true_dist = self._str_keyed(true_dist)
        pred_dist = self._str_keyed(pred_dist)
        all_labels = set(true_dist).union(pred_dist)

        # Ensure all labels exist in both distributions
        p_dist = {label: true_dist.get(label, epsilon) for label in all_labels}
        q_dist = {label: pred_dist.get(label, epsilon) for label in all_labels}

        # Normalize distributions
        p_sum = sum(p_dist.values())
        if p_sum > 0:
            p_dist = {k: v / p_sum for k, v in p_dist.items()}

        q_sum = sum(q_dist.values())
        if q_sum > 0:
            q_dist = {k: v / q_sum for k, v in q_dist.items()}

        kl = 0.0
        for label in all_labels:
            p = p_dist[label]
            if p > epsilon:  # Labels with (effectively) zero true mass contribute nothing
                q = max(q_dist[label], epsilon)
                kl += p * np.log(p / q)

        return kl

    @staticmethod
    def _to_distribution(label):
        """Turn a list of annotator labels into a {label: frequency} dict; other values pass through."""
        if isinstance(label, list):
            total = len(label)
            return {k: v / total for k, v in Counter(label).items()}
        return label

    @staticmethod
    def _unique_levels(levels: List[float]) -> List[float]:
        """Drop levels whose two-decimal label repeats an earlier one, preserving order.

        Results and table columns are keyed by ``f"{delta:.2f}"``, so two levels
        sharing a label would silently overwrite each other.
        """
        unique: Dict[str, float] = {}
        for level in levels:
            unique.setdefault(f"{level:.2f}", level)
        return list(unique.values())

    # ------------------------------------------------------------------ #
    # Loading predictions
    # ------------------------------------------------------------------ #

    def get_model_prediction_path(self, dataset_name: str, model_config: Dict[str, str]) -> Union[str, None]:
        """
        Get path to model prediction results file.

        Args:
            dataset_name: Dataset name.
            model_config: Model configuration; "qwen" entries need "name",
                "qwenlora" entries need "eps" and "alpha".

        Returns:
            Prediction results file path, or None for an unknown model type.
        """
        model_type = model_config.get("type")

        if model_type == "qwen":
            # Raw model
            model_name = model_config.get("name")
            return os.path.join(self.raw_results_path, dataset_name, model_type, f"{model_name}.json")
        if model_type == "qwenlora":
            # Fine-tuned models
            eps = model_config.get("eps")
            alpha = model_config.get("alpha")
            model_file = f"{self.fine_tuned_model_name}_eps_{eps}_alpha_{alpha}_epoch_{self.epoch}.json"
            return os.path.join(self.raw_results_path, dataset_name, model_type, model_file)

        logger.error(f"Unknown model type: {model_type}")
        return None

    def load_model_predictions(self, file_path: str) -> Dict[str, Dict[str, float]]:
        """
        Load model prediction results.

        Args:
            file_path: Prediction results file path (JSON object).

        Returns:
            Model prediction results {sample ID: {label: probability}}, or an empty
            dict if the file is missing, unreadable, not valid JSON, or not a JSON object.
        """
        if not os.path.exists(file_path):
            logger.error(f"Prediction file not found: {file_path}")
            return {}

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                predictions = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSON and Unicode decode errors
            logger.error(f"Failed to load {file_path}: {e}")
            return {}

        if not isinstance(predictions, dict):
            logger.error(f"Unexpected prediction format in {file_path}: "
                         f"expected a JSON object, got {type(predictions).__name__}")
            return {}
        return predictions

    def _load_all_predictions(self, dataset_name: str) -> Dict[str, Dict[str, Dict[str, float]]]:
        """Load every configured model's predictions for *dataset_name*, skipping models without usable files."""
        model_predictions = {}
        for model_name, model_config in self.model_configs.items():
            prediction_path = self.get_model_prediction_path(dataset_name, model_config)
            if not prediction_path:
                logger.warning(f"Could not determine prediction path for {model_name}")
                continue
            predictions = self.load_model_predictions(prediction_path)
            if predictions:
                model_predictions[model_name] = predictions
                logger.info(f"Loaded predictions for {model_name} ({len(predictions)} samples)")
            else:
                logger.warning(f"No predictions found for {model_name}")
        return model_predictions

    # ------------------------------------------------------------------ #
    # Robustness evaluation
    # ------------------------------------------------------------------ #

    def test_robustness(self, dataset_name: str, perturbation_levels: List[float]) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """
        Test model robustness under different noise levels.

        For every level δ, each sample's reference distribution receives one noise
        draw, shared by all models so they are scored against the same noisy target.
        Each model's KL divergence to it is then averaged over the samples it has
        predictions for.

        Args:
            dataset_name: Key into ``self.dataset_mapping``.
            perturbation_levels: Noise levels; the first one is the baseline for
                relative changes. Levels sharing a two-decimal label are merged.

        Returns:
            ``(df_results, df_relative)``. ``df_results`` holds the average KL per
            model: a "Model" column plus one ``"δ = x.xx"`` column per level.
            ``df_relative`` is the same frame with the non-baseline columns replaced
            by signed percentage changes versus the baseline ("N/A" when undefined).
            Both are empty when the dataset is unknown, no levels are given, or no
            predictions could be loaded.
        """
        logger.info(f"Testing robustness on dataset: {dataset_name}")

        if dataset_name not in self.dataset_mapping:
            logger.error(f"Unknown dataset: {dataset_name}")
            return pd.DataFrame(), pd.DataFrame()

        perturbation_levels = self._unique_levels(perturbation_levels)
        if not perturbation_levels:
            logger.error("No perturbation levels given")
            return pd.DataFrame(), pd.DataFrame()

        dataset = self.dataset_mapping[dataset_name]()

        model_predictions = self._load_all_predictions(dataset_name)
        if not model_predictions:
            logger.error(f"No model predictions found for dataset: {dataset_name}")
            return pd.DataFrame(), pd.DataFrame()

        # results[model_name][delta_str] = average KL at that level
        results = {model_name: {} for model_name in model_predictions}

        for delta in perturbation_levels:
            delta_str = f"{delta:.2f}"
            logger.info(f"Processing perturbation level δ = {delta:.2f}")

            kl_values = {model_name: [] for model_name in model_predictions}

            # NOTE(review): assumes dataset items are indexable by dataset.id_key and
            # that dataset.get_label returns a label list or a {label: prob} dict.
            for item in tqdm(dataset, desc=f"Processing samples (δ={delta:.2f})"):
                id_key = str(item[dataset.id_key])

                true_dist = self._to_distribution(dataset.get_label(item))
                if not true_dist:
                    # No reference labels: a KL against an empty target is meaningless.
                    continue

                # One noise draw per sample, shared by all models
                noisy_dist = self.add_noise_to_distribution(true_dist, delta)

                for model_name, predictions in model_predictions.items():
                    if id_key in predictions:
                        kl = self.calculate_kl_divergence(predictions[id_key], noisy_dist)
                        kl_values[model_name].append(kl)

            for model_name, values in kl_values.items():
                if values:
                    avg_kl = np.mean(values)
                    results[model_name][delta_str] = avg_kl
                    logger.info(f"Model: {model_name}, δ = {delta:.2f}, Avg KL = {avg_kl:.4f}")
                else:
                    results[model_name][delta_str] = float('nan')
                    logger.warning(f"No valid samples for {model_name} at δ = {delta:.2f}")

        return self._build_result_frames(results, model_predictions, perturbation_levels)

    @staticmethod
    def _format_relative_change(base_val: float, curr_val: float) -> str:
        """Format ``(curr - base) / base`` as a signed percentage, or "N/A" when undefined."""
        if np.isnan(base_val) or np.isnan(curr_val) or base_val == 0:
            return "N/A"
        rel_change = (curr_val - base_val) / base_val * 100
        return f"+{rel_change:.1f}%" if rel_change >= 0 else f"{rel_change:.1f}%"

    def _build_result_frames(self, results, model_predictions, perturbation_levels):
        """Assemble the KL frame and relative-change frame from {model: {δ label: avg KL}}."""
        # Keep the configured model order, listing only models that had predictions.
        model_names = [name for name in self.model_configs if name in model_predictions]
        df_results = pd.DataFrame({"Model": model_names})

        for delta in perturbation_levels:
            delta_str = f"{delta:.2f}"
            df_results[f"δ = {delta_str}"] = [
                results.get(name, {}).get(delta_str, float('nan')) for name in model_names
            ]

        # Relative change of each level versus the first (baseline) level
        df_relative = df_results.copy()
        base_col = f"δ = {perturbation_levels[0]:.2f}"
        for delta in perturbation_levels[1:]:
            col_name = f"δ = {delta:.2f}"
            df_relative[col_name] = [
                self._format_relative_change(base_val, curr_val)
                for base_val, curr_val in zip(df_results[base_col], df_results[col_name])
            ]

        return df_results, df_relative

    # ------------------------------------------------------------------ #
    # LaTeX rendering
    # ------------------------------------------------------------------ #

    @classmethod
    def _latex_escape(cls, text) -> str:
        """Return ``str(text)`` with LaTeX special characters escaped for use in a table cell."""
        return "".join(cls._LATEX_SPECIALS.get(ch, ch) for ch in str(text))

    @staticmethod
    def _delta_columns(results_by_dataset) -> List[str]:
        """Return the ``δ = x.xx`` columns of the first dataset's KL frame (empty list if none)."""
        for df, _ in results_by_dataset.values():
            return [col for col in df.columns if col.startswith("δ =")]
        return []

    @staticmethod
    def _format_value(val) -> str:
        """Render one numeric LaTeX cell ("-" for NaN)."""
        return "& - " if np.isnan(val) else f"& {val:.3f} "

    def _latex_header(self, caption: str, label: str, header_text: str, value_cols: List[str]) -> str:
        """Build the table preamble, caption and the two header rows shared by all robustness tables."""
        n_cols = len(value_cols)
        # NOTE(review): the first column holds dataset names and the second model
        # names, although the headers read "Model"/"Method" — confirm intended labels.
        parts = [
            "\\begin{table}[h]\n",
            "\\centering\n",
            "\\renewcommand{\\arraystretch}{1.25}\n",
            "\\caption{" + caption + "}\n",
            "\\label{" + label + "}\n",
            "\\resizebox{\\textwidth}{!}{\n",
            "\\begin{tabular}{c c " + " c" * n_cols + "}\n",
            "\\toprule\n",
            "\\multirow{2}{*}{\\textbf{Model}} & \\multirow{2}{*}{\\textbf{Method}} & ",
            "\\multicolumn{" + str(n_cols) + "}{c}{\\textbf{" + header_text + "}} \\\\\n",
            "\\cmidrule(lr){3-" + str(2 + n_cols) + "}\n",
            "& ",  # two empty cells under the multirow headers
        ]
        parts.extend(f"& $\\delta = {col.replace('δ = ', '')}$ " for col in value_cols)
        parts.append("\\\\\n")
        parts.append("\\midrule\n")
        return "".join(parts)

    def _latex_body(self, results_by_dataset, format_cells) -> str:
        """Render one multirow group per dataset.

        ``format_cells(i, row, df_rel)`` must return the value cells (each starting
        with "& ") for row ``i`` of the dataset's KL frame.
        """
        chunks = []
        n_datasets = len(results_by_dataset)
        for idx, (dataset_name, (df_kl, df_rel)) in enumerate(results_by_dataset.items()):
            chunks.append(f"\\multirow{{{len(df_kl)}}}{{*}}{{{self._latex_escape(dataset_name)}}} \n")
            for i, row in df_kl.iterrows():
                chunks.append(f"& {self._latex_escape(row['Model'])} ")
                chunks.append(format_cells(i, row, df_rel))
                chunks.append("\\\\\n")
            # Separator between datasets, not after the last one
            if idx < n_datasets - 1:
                chunks.append("\\midrule\n")
        return "".join(chunks)

    def generate_latex_table(self, results_by_dataset: Dict[str, Tuple[pd.DataFrame, pd.DataFrame]],
                             output_type: str = "kl") -> str:
        """
        Generate LaTeX format table.

        Args:
            results_by_dataset: {dataset_name: (results_DataFrame, relative_change_DataFrame)}.
            output_type: 'kl' for KL divergence values; anything else for relative
                change rates (baseline column absolute, the rest as escaped percentages).

        Returns:
            LaTeX table source, or "" if no ``δ =`` columns are present.
        """
        delta_cols = self._delta_columns(results_by_dataset)
        if not delta_cols:
            logger.error("No delta columns found in results")
            return ""

        if output_type == "kl":
            caption = ("Robustness analysis of our distribution alignment method under varying "
                       "label perturbation levels ($\\delta$). KL divergence is reported (lower is better).")
            label = "robustness-table"

            def format_cells(i, row, df_rel):
                return "".join(self._format_value(row[col]) for col in delta_cols)
        else:
            caption = "Relative change in KL divergence compared to no perturbation ($\\delta = 0.00$)."
            label = "robustness-relative-table"

            def format_cells(i, row, df_rel):
                # Baseline column shows the absolute value, the rest the relative change.
                cells = [self._format_value(row[delta_cols[0]])]
                cells.extend(f"& {self._latex_escape(df_rel.loc[i, col])} " for col in delta_cols[1:])
                return "".join(cells)

        header = self._latex_header(
            caption, label, "KL Divergence at different perturbation levels ($\\delta$)", delta_cols)
        return header + self._latex_body(results_by_dataset, format_cells) + self._LATEX_FOOTER

    def generate_difference_table(self, results_by_dataset: Dict[str, Tuple[pd.DataFrame, pd.DataFrame]]) -> str:
        """
        Generate LaTeX table of absolute KL differences versus the baseline level.

        Args:
            results_by_dataset: {dataset_name: (results_DataFrame, relative_change_DataFrame)}.

        Returns:
            LaTeX table source, or "" if fewer than two ``δ =`` columns are present.
        """
        delta_cols = self._delta_columns(results_by_dataset)
        if len(delta_cols) < 2:
            logger.error("Need at least two noise level columns to calculate differences")
            return ""

        base_col, diff_cols = delta_cols[0], delta_cols[1:]

        def format_cells(i, row, df_rel):
            base_val = row[base_col]
            cells = []
            for col in diff_cols:
                curr_val = row[col]
                if not np.isnan(base_val) and not np.isnan(curr_val):
                    cells.append(f"& {curr_val - base_val:.3f} ")
                else:
                    cells.append("& - ")
            return "".join(cells)

        header = self._latex_header(
            "Absolute difference in KL divergence compared to no perturbation ($\\delta = 0.00$).",
            "robustness-difference-table",
            "KL Divergence difference at different perturbation levels ($\\delta$)",
            diff_cols,
        )
        return header + self._latex_body(results_by_dataset, format_cells) + self._LATEX_FOOTER

    # ------------------------------------------------------------------ #
    # Output
    # ------------------------------------------------------------------ #

    @staticmethod
    def _write_text(path: str, text: str, description: str) -> None:
        """Write *text* to *path* as UTF-8 and log it as "Saved <description> to <path>"."""
        with open(path, 'w', encoding='utf-8') as f:
            f.write(text)
        logger.info(f"Saved {description} to {path}")

    def save_results(self,
                     results_by_dataset: Dict[str, Tuple[pd.DataFrame, pd.DataFrame]],
                     perturbation_levels: List[float]) -> None:
        """
        Save results to files and print a text summary.

        Writes the following into ``<robust_output_path>/<fine_tuned_model_name>/``:
        per-dataset CSVs of KL values, relative changes and absolute differences;
        three LaTeX tables; and a plain-text rendering of all tables, which is also
        printed to stdout.

        Args:
            results_by_dataset: {dataset_name: (results_DataFrame, relative_change_DataFrame)}.
            perturbation_levels: Noise levels used to build the frames; the first is
                the baseline for differences. Must be non-empty.
        """
        # NOTE(review): the raw-model directory is created but nothing below writes
        # into it; every artefact goes to the fine-tuned model's directory.
        raw_output_dir = os.path.join(self.robust_output_path, self.raw_model_name)
        fine_tuned_output_dir = os.path.join(self.robust_output_path, self.fine_tuned_model_name)

        os.makedirs(raw_output_dir, exist_ok=True)
        os.makedirs(fine_tuned_output_dir, exist_ok=True)

        base_col = f"δ = {perturbation_levels[0]:.2f}"

        for dataset_name, (df_kl, df_rel) in results_by_dataset.items():
            kl_csv_path = os.path.join(fine_tuned_output_dir, f"{dataset_name}_kl_divergence.csv")
            df_kl.to_csv(kl_csv_path, index=False)
            logger.info(f"Saved KL divergence results to {kl_csv_path}")

            rel_csv_path = os.path.join(fine_tuned_output_dir, f"{dataset_name}_relative_change.csv")
            df_rel.to_csv(rel_csv_path, index=False)
            logger.info(f"Saved relative change results to {rel_csv_path}")

            # Absolute differences versus the baseline level
            diff_df = df_kl.copy()
            for delta in perturbation_levels[1:]:
                diff_df[f"diff_{delta:.2f}"] = diff_df[f"δ = {delta:.2f}"] - diff_df[base_col]

            diff_csv_path = os.path.join(fine_tuned_output_dir, f"{dataset_name}_kl_difference.csv")
            diff_df.to_csv(diff_csv_path, index=False)
            logger.info(f"Saved KL difference results to {diff_csv_path}")

        kl_latex = self.generate_latex_table(results_by_dataset, "kl")
        rel_latex = self.generate_latex_table(results_by_dataset, "relative")
        diff_latex = self.generate_difference_table(results_by_dataset)

        self._write_text(os.path.join(fine_tuned_output_dir, "robustness_kl_table.tex"),
                         kl_latex, "KL divergence LaTeX table")
        self._write_text(os.path.join(fine_tuned_output_dir, "robustness_relative_table.tex"),
                         rel_latex, "relative change LaTeX table")
        self._write_text(os.path.join(fine_tuned_output_dir, "robustness_difference_table.tex"),
                         diff_latex, "KL difference LaTeX table")

        # Plain-text rendering of all tables
        sections = ["\n====== KL Divergence Results ======\n\n"]
        for dataset_name, (df_kl, _) in results_by_dataset.items():
            sections.append(f"Dataset: {dataset_name}\n")
            sections.append(tabulate(df_kl, headers='keys', tablefmt='grid', floatfmt='.3f') + "\n\n")

        sections.append("\n====== Relative Change Results ======\n\n")
        for dataset_name, (_, df_rel) in results_by_dataset.items():
            sections.append(f"Dataset: {dataset_name}\n")
            sections.append(tabulate(df_rel, headers='keys', tablefmt='grid') + "\n\n")

        sections.append("\n====== KL Divergence Difference Results ======\n\n")
        headers = ["Model"] + [f"Diff(δ={delta:.2f})" for delta in perturbation_levels[1:]]
        for dataset_name, (df_kl, _) in results_by_dataset.items():
            diff_table_data = []
            for _, row in df_kl.iterrows():
                base_val = row[base_col]
                model_row = [row["Model"]]
                for delta in perturbation_levels[1:]:
                    curr_val = row[f"δ = {delta:.2f}"]
                    if not np.isnan(base_val) and not np.isnan(curr_val):
                        model_row.append(f"{curr_val - base_val:.3f}")
                    else:
                        model_row.append("N/A")
                diff_table_data.append(model_row)

            sections.append(f"Dataset: {dataset_name}\n")
            sections.append(tabulate(diff_table_data, headers=headers, tablefmt='grid') + "\n\n")

        console_output = "".join(sections)
        self._write_text(os.path.join(fine_tuned_output_dir, "robustness_results.txt"),
                         console_output, "console output")

        print(console_output)

    def run(self, datasets: List[str], perturbation_levels: List[float]) -> None:
        """
        Run robustness test on each dataset and save the combined results.

        Args:
            datasets: Dataset names; unknown names are skipped with a warning.
            perturbation_levels: Noise levels. They are sorted ascending, so the
                smallest level is the baseline, and deduplicated by two-decimal label.
        """
        perturbation_levels = self._unique_levels(sorted(perturbation_levels))
        if not perturbation_levels:
            logger.error("No perturbation levels given")
            return

        results_by_dataset = {}

        for dataset_name in datasets:
            if dataset_name not in self.dataset_mapping:
                logger.warning(f"Unknown dataset: {dataset_name}")
                continue

            df_kl, df_rel = self.test_robustness(dataset_name, perturbation_levels)

            if not df_kl.empty:
                results_by_dataset[dataset_name] = (df_kl, df_rel)

        if results_by_dataset:
            self.save_results(results_by_dataset, perturbation_levels)
        else:
            logger.error("No results to save")


if __name__ == "__main__":
    # Datasets selectable on the command line; 'all' expands to every one of them.
    known_datasets = ['snli', 'multinli', 'mtbench', 'summeval']

    cli = argparse.ArgumentParser(description="Test model robustness against label distribution noise")
    cli.add_argument('--datasets', type=str, nargs='+',
                     choices=known_datasets + ['all'],
                     default=['summeval'],
                     help='Datasets to process')
    cli.add_argument('--deltas', type=float, nargs='+',
                     default=[0.00, 0.05, 0.10, 0.15, 0.20, 0.25],
                     help='List of noise levels')
    cli.add_argument('--seed', type=int, default=42,
                     help='Random seed')
    cli.add_argument('--epsilon', type=float, default=0.25,
                     help='Epsilon parameter for Ours (Full) model')
    cli.add_argument('--alpha', type=float, default=0.8,
                     help='Alpha parameter for Ours w/o Adv and Ours (Full) models')
    cli.add_argument('--epoch', type=str, default="2",
                     help='Training epoch for fine-tuned models')
    cli.add_argument('--raw_model', type=str, default="Qwen2.5-7B-Instruct",
                     help='Raw model name')
    cli.add_argument('--fine_tuned_model', type=str, default="qwen2.5-7b-3",
                     help='Fine-tuned model name prefix')
    opts = cli.parse_args()

    selected = list(known_datasets) if 'all' in opts.datasets else opts.datasets

    tester = RobustnessTest(seed=opts.seed, epsilon=opts.epsilon, alpha=opts.alpha, epoch=opts.epoch)

    # Point the tester at the requested checkpoints; the eps/alpha strings become
    # part of the fine-tuned prediction file names.
    tester.raw_model_name = opts.raw_model
    tester.fine_tuned_model_name = opts.fine_tuned_model
    alpha_str, eps_str = str(opts.alpha), str(opts.epsilon)
    tester.model_configs = {
        "Raw model": {"type": "qwen", "name": opts.raw_model},
        "Single-point": {"type": "qwenlora", "eps": "0.0", "alpha": "0.0"},
        "Ours w/o Adv": {"type": "qwenlora", "eps": "0.0", "alpha": alpha_str},
        "Ours (Full)": {"type": "qwenlora", "eps": eps_str, "alpha": alpha_str},
    }

    tester.run(selected, opts.deltas)

    print("Robustness test completed!")