import os
from pathlib import Path
import sys
import argparse
from datasets import load_dataset
import time
import json
import pandas as pd
import random
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

sys.path.insert(0, str(Path(__file__).parent.parent))
from src.utils.logger_utils import get_logger
from src.services.evaluation.asr_evaluator_service import ASREvaluatorService
from src.utils.random_seed_utils import set_random_seed


# Directory containing this script; experiment outputs are written beneath it.
SCRIPT_DIR = Path(__file__).parent
# Parent of the script directory; the .env file and the dataset folder are looked up under it.
PROJECT_ROOT = SCRIPT_DIR.parent
# NOTE(review): this is the same path that was already inserted into sys.path
# before the `src` imports, so this second insert appears redundant.
sys.path.insert(0, str(PROJECT_ROOT))

def load_env_config():
    """Load environment variables from the project-root ``.env`` file, if present.

    Returns:
        True when the file exists and was handed to ``load_dotenv``;
        False when there is no ``.env`` file under ``PROJECT_ROOT``.
    """
    dotenv_file = PROJECT_ROOT / '.env'
    if not dotenv_file.exists():
        return False

    # Imported here so the dotenv package is only touched when a .env file exists.
    from dotenv import load_dotenv
    load_dotenv(dotenv_file)
    return True


def check_api_config():
    """Verify that the OpenRouter API key is configured, printing guidance if not.

    Two conditions are checked in order: the ``.env`` file must exist under
    ``PROJECT_ROOT``, and the ``OPENROUTER_API_KEY`` environment variable must
    be set to something other than the placeholder value.

    Returns:
        True when both checks pass, False otherwise (after printing
        instructions to stdout).
    """
    config_file = PROJECT_ROOT / '.env'

    if not config_file.exists():
        setup_help = (
            "❌ Error: .env configuration file not found",
            "📋 Please follow these steps to create the configuration file:",
            "   1. Create a .env file in the project root directory",
            "   2. Add the following content to the .env file:",
            "      OPENROUTER_API_KEY=your_actual_api_key",
            "   3. Get your API key at: https://openrouter.ai/keys",
            f"   4. Configuration file location: {config_file}",
        )
        for line in setup_help:
            print(line)
        return False

    key = os.getenv('OPENROUTER_API_KEY')
    # The key must be non-empty and must not be the untouched template placeholder.
    if key and key != 'your_openrouter_api_key_here':
        return True

    key_help = (
        "❌ Error: OpenRouter API key is not configured or is using the default value",
        "📋 Please edit the .env file and set OPENROUTER_API_KEY to your actual API key",
        "   Get your API key at: https://openrouter.ai/keys",
    )
    for line in key_help:
        print(line)
    return False


class Tester():
    """Send the prompts of one dataset to one target model and record the outcome.

    For every prompt the tester computes a GPT-2 perplexity, asks the
    ``ASREvaluatorService`` whether the attack succeeded, and finally writes
    per-prompt results plus an aggregate performance report to disk.

    Typical use: construct, call ``get_dataset()``, then ``attack()``.
    """

    # Upper bound on check_asr calls per prompt when the evaluation reports an error.
    _MAX_ASR_ATTEMPTS = 5

    # dataset name -> (path parts relative to data_root, column holding the prompts)
    _CSV_SOURCES = {
        "AdvBench": (("AdvBench.csv",), "goal"),  # len == 520
        "RiskAtlas_original_data": (
            ("RiskAtlas", "RiskAtlas_random_data.csv"), "original_prompt"),  # len == 200
        "RiskAtlas_stealth_data": (
            ("RiskAtlas", "RiskAtlas_random_data.csv"), "stealth_prompt"),  # len == 200
        "RiskAtlas_stealth_success_data": (
            ("RiskAtlas", "RiskAtlas_stealth_random_data.csv"), "stealth_prompt"),  # len == 200
    }

    # dataset name -> (Hugging Face repository id, split, column holding the prompts)
    _HUB_SOURCES = {
        "Do-Not-Answer": ("LibrAI/do-not-answer", "train", "question"),  # len == 939
        "HarmfulQA": ("declare-lab/HarmfulQA", "train", "question"),  # len == 1960
        "CategoricalHarmfulQA": ("declare-lab/CategoricalHarmfulQA", "en", "Question"),  # len == 550
    }

    def __init__(self, dataset_name, target_model, dataset_num=200):
        """Store the run configuration and create the ASR evaluator.

        Args:
            dataset_name: Name of the dataset to load in ``get_dataset``.
            target_model: Model identifier passed to ``ASREvaluatorService``.
            dataset_num: Maximum number of prompts to keep; larger datasets
                are randomly down-sampled to this size.
        """
        self.dataset_name = dataset_name
        self.target_model = target_model
        # Short name used in result records and output paths:
        # drop the provider prefix and any ":tag" suffix, then replace dashes.
        self.target_model_nick_name = target_model.split("/")[-1].split(":")[0].replace("-", "_")

        self.asr_evaluator = ASREvaluatorService(
            target_model=target_model,
            openrouter_api_key=os.getenv("OPENROUTER_API_KEY")
        )

        self.dataset = None
        self.logger = get_logger(self.__class__.__name__)
        self.dataset_num = dataset_num
        self.data_root = PROJECT_ROOT / "experiment" / "exp1_exp2_dataset"

    def _ensure_ppl_model(self):
        """Load the GPT-2 tokenizer and model on first use and keep them on the instance."""
        if hasattr(self, '_ppl_model'):
            return
        self.logger.info("           🔄 Initializing GPT-2 perplexity model...")
        self._ppl_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self._ppl_model = GPT2LMHeadModel.from_pretrained('gpt2')
        if self._ppl_tokenizer.pad_token is None:
            self._ppl_tokenizer.pad_token = self._ppl_tokenizer.eos_token
        self.logger.info("           ✅ GPT-2 perplexity model initialized")

    def _calculate_ppl(self, text: str) -> float:
        """Return the GPT-2 perplexity of ``text``.

        The text is tokenized (truncated to at most 512 tokens) and the
        perplexity is ``exp(loss)``, where ``loss`` is what the model returns
        when the input ids are also passed as labels.

        NOTE(review): text that tokenizes to fewer than two tokens presumably
        cannot produce a meaningful loss — confirm how such prompts should be
        handled.
        """
        self._ensure_ppl_model()

        inputs = self._ppl_tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
        with torch.no_grad():
            outputs = self._ppl_model(**inputs, labels=inputs['input_ids'])
            ppl = torch.exp(outputs.loss).item()

        return ppl

    def _load_hex_phi(self):
        """Collect every non-null cell of every CSV under ``data_root / "HEx-PHI"``.

        Files are read in sorted path order: ``Path.glob`` yields entries in
        an unspecified order, which would otherwise make the prompt list (and
        the seeded random subset drawn from it) differ between machines.

        NOTE(review): ``read_csv`` treats the first row of each file as a
        header; if the HEx-PHI files have no header row, the first prompt of
        every file is dropped — confirm against the data files.
        """
        csv_dir = self.data_root / "HEx-PHI"
        sentences = []
        for csv_file in sorted(csv_dir.glob("*.csv")):
            frame = pd.read_csv(csv_file)
            for column in frame.columns:
                sentences.extend(frame[column].dropna().astype(str).tolist())
        return sentences  # len == 290

    def get_dataset(self):
        """Load the prompts for ``self.dataset_name`` into ``self.dataset``.

        On success ``self.dataset`` is a list of prompts, randomly
        down-sampled to ``self.dataset_num`` entries when the source is
        larger. For an unknown dataset name an error is logged and
        ``self.dataset`` is set to None.
        """
        name = self.dataset_name
        if name in self._CSV_SOURCES:
            path_parts, column = self._CSV_SOURCES[name]
            frame = pd.read_csv(self.data_root.joinpath(*path_parts))
            prompts = frame[column].dropna().astype(str).tolist()
        elif name in self._HUB_SOURCES:
            repo_id, split, column = self._HUB_SOURCES[name]
            # Materialize as a plain list so sampling and JSON serialization
            # do not depend on the column type returned by the datasets library.
            prompts = list(load_dataset(repo_id)[split][column])
        elif name == "HEx-PHI":
            prompts = self._load_hex_phi()
        else:
            self.logger.error(f"❌ unknown dataset: {self.dataset_name}")
            self.dataset = None
            return

        if len(prompts) > self.dataset_num:
            prompts = random.sample(prompts, self.dataset_num)
        self.dataset = prompts

    def _check_asr_with_retry(self, prompt, prompt_id):
        """Call ``check_asr`` for ``prompt`` until its reasoning contains no "Error".

        At most ``_MAX_ASR_ATTEMPTS`` calls are made; the result of the last
        call is returned even if it still reports an error.
        """
        for attempt in range(1, self._MAX_ASR_ATTEMPTS + 1):
            asr_result = self.asr_evaluator.check_asr(
                rewritten_prompt=prompt,
                category="",
                original_prompt=None
            )
            if "Error" not in str(asr_result.evaluation_reasoning):
                break
            self.logger.warning(f"Retry {attempt}: Error detected in evaluation_reasoning for prompt {prompt_id}")
        return asr_result

    def _save_outputs(self, all_results, performance_report):
        """Write per-prompt results (JSON and CSV) and the performance report (JSON).

        Files go to ``SCRIPT_DIR / "exp1_result" / test_exp1_<model>_<dataset>``
        and share one timestamp in their names.
        """
        output_dir = SCRIPT_DIR / "exp1_result" / f"test_exp1_{self.target_model_nick_name}_{self.dataset_name}"
        output_dir.mkdir(parents=True, exist_ok=True)

        timestamp = time.strftime('%Y%m%d_%H%M%S')
        results_path = output_dir / f"results_{timestamp}.json"
        report_path = output_dir / f"performance_report_{timestamp}.json"
        results_csv_path = output_dir / f"results_{timestamp}.csv"

        with open(results_path, "w", encoding="utf-8") as f:
            json.dump(all_results, f, ensure_ascii=False, indent=2)

        with open(report_path, "w", encoding="utf-8") as f:
            json.dump(performance_report, f, ensure_ascii=False, indent=2)

        pd.DataFrame(all_results).to_csv(results_csv_path, index=False, encoding="utf-8")

        self.logger.info(f"CSV results saved to {results_csv_path}")
        self.logger.info(f"Results saved to {results_path}")
        self.logger.info(f"Performance report saved to {report_path}")

    def attack(self):
        """Evaluate every loaded prompt against the target model and save the results.

        Does nothing except log a warning when no prompts are loaded
        (``get_dataset`` not called, unknown dataset, or empty dataset).
        """
        if not self.dataset:
            self.logger.warning(
                f"No prompts loaded for dataset {self.dataset_name}; skipping attack"
            )
            return

        all_results = []
        success_num = 0
        total_ppl = 0.0
        start_time = time.time()

        for prompt_id, prompt in enumerate(self.dataset, start=1):
            ppl = self._calculate_ppl(prompt)
            total_ppl += ppl

            asr_result = self._check_asr_with_retry(prompt, prompt_id)

            result = {
                "id": prompt_id,
                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "dataset": self.dataset_name,
                "target_model": self.target_model_nick_name,
                "prompt": prompt,
                "response": asr_result.target_response,
                "asr_successful": asr_result.asr_successful,
                "ppl": ppl,
                "evaluation_reasoning": asr_result.evaluation_reasoning,
            }
            all_results.append(result)
            self.logger.info(f"Attack result: {result}")
            if asr_result.asr_successful:
                success_num += 1

        total_num = len(self.dataset)
        total_time = time.time() - start_time
        performance_report = {
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "dataset": self.dataset_name,
            "target_model": self.target_model_nick_name,
            "total_data": total_num,
            "average_ppl": total_ppl / total_num,
            "success_num": success_num,
            "attack_success_rate": success_num / total_num,
            "total_time_seconds": total_time,
            "total_time_minutes": total_time / 60,
            "total_time_hours": total_time / 3600,
            "average_time_per_prompt": total_time / total_num,
        }
        self.logger.info(f"Performance report: {performance_report}")

        self._save_outputs(all_results, performance_report)


def main():
    """Command-line entry point: run one dataset against one target model.

    Arguments are parsed before any configuration check so that ``--help``
    and usage errors work even when the ``.env`` file or API key is missing.
    Exits with status 1 when ``check_api_config`` reports a problem.
    """
    parser = argparse.ArgumentParser(description="exp 1: dataset VS model attack")

    parser.add_argument('--num', default=200, type=int, help='Number of dataset samples to use')

    # Closed-source models: ["google/gemini-2.5-flash", "openai/gpt-4o-mini", "x-ai/grok-3-mini", "anthropic/claude-sonnet-4"]
    # Open-source models: ["deepseek/deepseek-chat-v3.1:free", "mistralai/mixtral-8x7b-instruct",
    # "qwen/qwen-2.5-7b-instruct", "meta-llama/llama-3.3-8b-instruct:free"]
    parser.add_argument('--model', default="openai/gpt-4o-mini", type=str, help='Target model name')

    # Supported datasets:
    # ['AdvBench', 'Do-Not-Answer', 'HarmfulQA', 'CategoricalHarmfulQA', 'HEx-PHI',
    #  'RiskAtlas_original_data', 'RiskAtlas_stealth_data', 'RiskAtlas_stealth_success_data']
    parser.add_argument('--dataset', default="AdvBench", type=str, help='dataset name')

    args = parser.parse_args()

    # Seed before the dataset is loaded, because get_dataset draws a random subset.
    # NOTE(review): the meaning of the second argument is defined in
    # src.utils.random_seed_utils — confirm there.
    set_random_seed(42, True)
    load_env_config()
    if not check_api_config():
        sys.exit(1)

    tester = Tester(dataset_name=args.dataset, target_model=args.model, dataset_num=args.num)
    tester.get_dataset()
    tester.attack()


# Run the experiment only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()