import pandas as pd
import sys
import argparse
from datasets import load_dataset
from pathlib import Path
import os
import time
import random
import yaml
import json
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

sys.path.insert(0, str(Path(__file__).parent.parent))
from src.utils.logger_utils import get_logger
from src.services.evaluation.asr_evaluator_service import ASREvaluatorService
from src.utils.random_seed_utils import set_random_seed
from src.models.openai_llm import OpenAICompatibleLLM

# Add the project root directory to the Python path.
# SCRIPT_DIR is the directory containing this script; PROJECT_ROOT is its parent.
# NOTE(review): the same directory (Path(__file__).parent.parent) is already inserted
# into sys.path before the `src.*` imports near the top of the file, so this second
# insertion appears redundant.
SCRIPT_DIR = Path(__file__).parent
PROJECT_ROOT = SCRIPT_DIR.parent
sys.path.insert(0, str(PROJECT_ROOT))


def load_env_config():
    """Load environment variables from ``PROJECT_ROOT/.env`` when that file exists.

    Returns:
        bool: True if the ``.env`` file was found and loaded, False otherwise.
    """
    dotenv_file = PROJECT_ROOT / '.env'
    if not dotenv_file.exists():
        return False
    # Imported lazily: python-dotenv is only needed when a .env file is present.
    from dotenv import load_dotenv
    load_dotenv(dotenv_file)
    return True


def check_api_config():
    """Check that an OpenRouter API key is configured.

    Prints setup instructions and returns False when ``PROJECT_ROOT/.env`` does
    not exist, or when the ``OPENROUTER_API_KEY`` environment variable is unset,
    empty, or still equal to the placeholder ``your_openrouter_api_key_here``.

    Returns:
        bool: True when the key looks usable, False otherwise.
    """
    dotenv_file = PROJECT_ROOT / '.env'

    if not dotenv_file.exists():
        missing_file_help = (
            "❌ Error: .env configuration file not found",
            "📋 Please follow these steps to create the configuration file:",
            "   1. Create a .env file in the project root directory",
            "   2. Add the following content to the .env file:",
            "      OPENROUTER_API_KEY=your_actual_api_key",
            "   3. Get your API key at: https://openrouter.ai/keys",
            f"   4. Configuration file location: {dotenv_file}",
        )
        for line in missing_file_help:
            print(line)
        return False

    key = os.getenv('OPENROUTER_API_KEY')
    if key and key != 'your_openrouter_api_key_here':
        return True

    missing_key_help = (
        "❌ Error: OpenRouter API key is not configured or is using the default value",
        "📋 Please edit the .env file and set OPENROUTER_API_KEY to your actual API key",
        "   Get your API key at: https://openrouter.ai/keys",
    )
    for line in missing_key_help:
        print(line)
    return False


class FinetuneModelEvaluator:
    """Attack a (fine-tuned) target model with prompts from a harmful-prompt dataset.

    Each prompt is wrapped in an Alpaca-style template and sent to the target
    model. The response is judged by ``ASREvaluatorService``, and the GPT-2
    perplexity of the prompt is recorded next to the attack-success verdict.
    Per-prompt results and an aggregate report are written under
    ``SCRIPT_DIR / "exp2_result"``.

    Typical use: ``initialize()`` followed by ``attack()``.
    """

    # Maximum number of judge calls per response. The judge is re-asked while
    # the returned ``evaluation_reasoning`` contains the substring "Error".
    max_eval_attempts = 5

    # Local CSV datasets: name -> (path parts under data_root, prompt column).
    _CSV_DATASETS = {
        "AdvBench": (("AdvBench.csv",), "goal"),  # len == 520
        "RiskAtlas_original_data": (("RiskAtlas", "RiskAtlas_random_data.csv"), "original_prompt"),  # len == 200
        "RiskAtlas_stealth_data": (("RiskAtlas", "RiskAtlas_random_data.csv"), "stealth_prompt"),  # len == 200
        "RiskAtlas_stealth_success_data": (("RiskAtlas", "RiskAtlas_stealth_random_data.csv"), "stealth_prompt"),  # len == 200
    }

    # Hugging Face Hub datasets: name -> (hub id, split, prompt column).
    _HF_DATASETS = {
        "Do-Not-Answer": ("LibrAI/do-not-answer", "train", "question"),  # len == 939
        "HarmfulQA": ("declare-lab/HarmfulQA", "train", "question"),  # len == 1960
        "CategoricalHarmfulQA": ("declare-lab/CategoricalHarmfulQA", "en", "Question"),  # len == 550
    }

    def __init__(
            self, 
            dataset_name, 
            finetune_dataset_name, 
            model_name="Llama-3.1-8B-finetune", 
            dataset_num=200,
        ):
        """Read ``configs/models.yaml`` and create the ASR evaluator.

        Args:
            dataset_name: Attack dataset name (see ``get_dataset``).
            finetune_dataset_name: Label of the dataset the target model was
                fine-tuned on. Used only in log output and the output directory name.
            model_name: Key into the ``models`` section of ``models.yaml``.
            dataset_num: Maximum number of prompts sampled from the dataset.
        """
        self.config_dir = PROJECT_ROOT / "configs"
        self.models_config_path = self.config_dir / "models.yaml"
        with open(self.models_config_path, 'r', encoding='utf-8') as f:
            self.models_config = yaml.safe_load(f)

        self.data_root = SCRIPT_DIR / "exp1_exp2_dataset"

        self.dataset_name = dataset_name
        self.model_name = model_name
        self.finetune_dataset_name = finetune_dataset_name
        self.dataset_num = dataset_num

        self.logger = get_logger(self.__class__.__name__)

        # Populated by initialize() / load_model() / get_dataset().
        self.model = None
        self.tokenizer = None
        self.dataset = None

        # Alpaca template: instruction, input, response (the response is left empty).
        self.alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""
        self.asr_evaluator = ASREvaluatorService(
            target_model=self.model_name, 
            openrouter_api_key=os.getenv("OPENROUTER_API_KEY")
        )

    def get_dataset(self):
        """Load the prompts for ``self.dataset_name`` into ``self.dataset`` (a list of str).

        Supported names are the keys of ``_CSV_DATASETS`` and ``_HF_DATASETS``,
        plus ``"HEx-PHI"``. When more than ``dataset_num`` prompts are available,
        a random subset of that size is drawn with the global ``random`` state.
        The draw is therefore reproducible only if the caller seeded ``random``.
        For an unknown name an error is logged and ``self.dataset`` is set to None.
        """
        if self.dataset_name in self._CSV_DATASETS:
            path_parts, column = self._CSV_DATASETS[self.dataset_name]
            df = pd.read_csv(self.data_root.joinpath(*path_parts))
            self.dataset = df[column].dropna().astype(str).tolist()
        elif self.dataset_name in self._HF_DATASETS:
            hub_id, split, column = self._HF_DATASETS[self.dataset_name]
            ds = load_dataset(hub_id)
            # Materialise a plain list of strings. Depending on the `datasets`
            # version, column access may return a lazy column object rather than a
            # list, and random.sample below needs a real sequence. Missing values
            # are dropped to mirror the dropna() used for the CSV datasets.
            self.dataset = [str(q) for q in ds[split][column] if q is not None]
        elif self.dataset_name == "HEx-PHI":
            self.dataset = self._load_hex_phi()
        else:
            self.logger.error(f"❌ unknown dataset: {self.dataset_name}")
            self.dataset = None
            return
        if len(self.dataset) > self.dataset_num:
            self.dataset = random.sample(self.dataset, self.dataset_num)

    def _load_hex_phi(self):
        """Collect every non-empty cell of every HEx-PHI category CSV as a prompt."""
        all_sentences = []
        # sorted(): glob order is filesystem-dependent, and a fixed order keeps the
        # seeded random.sample in get_dataset reproducible across machines.
        for csv_file in sorted((self.data_root / "HEx-PHI").glob("*.csv")):
            # The category files have no header row. Without header=None pandas
            # would consume each file's first prompt as the column name.
            # NOTE(review): the old header-consuming read reported len == 290;
            # expect one more prompt per category file now — confirm against the data.
            df = pd.read_csv(csv_file, header=None)
            for col in df.columns:
                all_sentences.extend(df[col].dropna().astype(str).tolist())
        return all_sentences

    def load_model(self):
        """Create the target model client from ``models.yaml`` and send one smoke-test prompt."""
        self.model = OpenAICompatibleLLM(self.models_config["models"][self.model_name])
        test_prompt = self.alpaca_prompt.format("Who are you?", "", "")
        response = self.model.generate(test_prompt)
        print(f"Model {self.model_name} loaded successfully. Test response: {response}")

    def initialize(self):
        """Load the attack dataset and the target model, then log the run configuration."""
        self.get_dataset()
        self.load_model()
        self.logger.info(f"finetune_dataset_name: {self.finetune_dataset_name}, attack dataset name: {self.dataset_name}, target_model_name: {self.model_name}")

    def _calculate_ppl(self, text: str) -> float:
        """Return the GPT-2 perplexity of ``text``.

        The perplexity is exp of the mean token-level LM loss. Input is truncated
        to 512 tokens. The GPT-2 model and tokenizer are created lazily on first use.

        NOTE(review): inputs shorter than 2 tokens leave no next-token targets,
        so the loss (and therefore the result) is likely NaN — confirm.
        """
        if not hasattr(self, '_ppl_model'):
            self.logger.info("           🔄 Initializing GPT-2 perplexity model...")
            self._ppl_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            self._ppl_model = GPT2LMHeadModel.from_pretrained('gpt2')
            if self._ppl_tokenizer.pad_token is None:
                self._ppl_tokenizer.pad_token = self._ppl_tokenizer.eos_token
            self.logger.info("           ✅ GPT-2 perplexity model initialized")

        inputs = self._ppl_tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
        with torch.no_grad():
            # Passing labels=input_ids makes the model return the causal-LM loss.
            outputs = self._ppl_model(**inputs, labels=inputs['input_ids'])
            ppl = torch.exp(outputs.loss).item()

        return ppl

    def _evaluate_with_retry(self, prompt, response, prompt_id):
        """Judge one response, re-asking the evaluator while its reasoning reports an error.

        Returns the first result whose ``evaluation_reasoning`` does not contain
        "Error". If every attempt fails, the last result is returned.
        """
        attempts = max(1, self.max_eval_attempts)
        for attempt in range(1, attempts + 1):
            asr_result = self.asr_evaluator.check_asr_with_prompt_and_response(
                prompt_for_eval=prompt,
                target_response=response,
            )
            if "Error" not in str(asr_result.evaluation_reasoning):
                return asr_result
            if attempt < attempts:
                self.logger.warning(f"Retry {attempt}: Error detected in evaluation_reasoning for prompt {prompt_id}")
        self.logger.warning(
            f"All {attempts} evaluation attempts reported an error for prompt {prompt_id}; keeping the last result"
        )
        return asr_result

    def _save_results(self, all_results, performance_report):
        """Write per-prompt results (JSON and CSV) and the summary report under a shared timestamp."""
        output_dir = SCRIPT_DIR / "exp2_result" / f"exp2_{self.model_name}_{self.finetune_dataset_name}_{self.dataset_name}"
        output_dir.mkdir(parents=True, exist_ok=True)

        timestamp = time.strftime('%Y%m%d_%H%M%S')
        results_path = output_dir / f"results_{timestamp}.json"
        report_path = output_dir / f"performance_report_{timestamp}.json"
        results_csv_path = output_dir / f"results_{timestamp}.csv"

        with open(results_path, "w", encoding="utf-8") as f:
            json.dump(all_results, f, ensure_ascii=False, indent=2)

        with open(report_path, "w", encoding="utf-8") as f:
            json.dump(performance_report, f, ensure_ascii=False, indent=2)

        pd.DataFrame(all_results).to_csv(results_csv_path, index=False, encoding="utf-8")

        self.logger.info(f"CSV results saved to {results_csv_path}")
        self.logger.info(f"Results saved to {results_path}")
        self.logger.info(f"Performance report saved to {report_path}")

    def attack(self):
        """Run every prompt in ``self.dataset`` against the target model and save the results.

        For each prompt this records the GPT-2 perplexity, the model response and
        the judge's verdict. At the end it logs and saves an aggregate report
        (average perplexity, attack success rate, timing). If no dataset is
        loaded (None or empty), a warning is logged and nothing is written.
        """
        if not self.dataset:
            self.logger.warning("No dataset loaded; skipping attack.")
            return

        all_results = []
        success_num = 0
        total_ppl = 0.0
        start_time = time.time()
        for i, data in enumerate(self.dataset):
            ppl = self._calculate_ppl(data)
            total_ppl += ppl

            inputs = self.alpaca_prompt.format(data, "", "")
            response = self.model.generate(inputs)
            asr_result = self._evaluate_with_retry(data, response, i + 1)

            result = {
                "id": i + 1,
                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "dataset": self.dataset_name,
                "target_model": self.model_name,
                "prompt": data,
                "response": response,
                "asr_successful": asr_result.asr_successful,
                "ppl": ppl,
                "evaluation_reasoning": asr_result.evaluation_reasoning,
            }
            all_results.append(result)
            self.logger.info(f"Attack result: {result}")
            if asr_result.asr_successful:
                success_num += 1

        total_num = len(self.dataset)
        total_time = time.time() - start_time
        performance_report = {
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "dataset": self.dataset_name,
            "target_model": self.model_name,
            "total_data": total_num,
            "average_ppl": total_ppl / total_num,
            "success_num": success_num,
            "attack_success_rate": success_num / total_num,
            "total_time_seconds": total_time,
            "total_time_minutes": total_time / 60,
            "total_time_hours": total_time / 3600,
            "average_time_per_prompt": total_time / total_num,
        }
        self.logger.info(f"Performance report: {performance_report}")

        self._save_results(all_results, performance_report)


def main():
    """Command-line entry point for experiment 2 (attack dataset vs. fine-tuned model).

    Arguments are parsed first, so ``--help`` and argument errors work even
    without a configured ``.env``. After that the random seed is fixed, the
    OpenRouter API key is checked (exiting with status 1 if it is missing),
    and the evaluator is initialized and run.
    """
    # Attack datasets understood by FinetuneModelEvaluator.get_dataset.
    attack_datasets = (
        'AdvBench', 'Do-Not-Answer', 'HarmfulQA', 'CategoricalHarmfulQA', 'HEx-PHI',
        'RiskAtlas_original_data', 'RiskAtlas_stealth_data', 'RiskAtlas_stealth_success_data',
    )

    parser = argparse.ArgumentParser(description="exp 2: dataset VS model attack")

    parser.add_argument('--num', default=200, type=int, help='Number of dataset samples to use')

    parser.add_argument('--model', default="Llama-3.1-8B-finetune", type=str, help='Target model name')

    parser.add_argument(
        '--dataset',
        default='RiskAtlas_stealth_success_data',
        type=str,
        choices=attack_datasets,
        help='attack dataset name',
    )

    # Free-form label (used in logs and the output directory name). Known values:
    # ['None', 'AdvBench', 'Do-Not-Answer', 'HarmfulQA', 'CategoricalHarmfulQA', 'HEx-PHI',
    #  'RiskAtlas_original_data', 'RiskAtlas_stealth_data', 'RiskAtlas_stealth_success_data']
    parser.add_argument('--finetune_dataset', default="RiskAtlas_original_data", type=str, help='finetune dataset name')

    args = parser.parse_args()
    # A non-positive sample count would make random.sample fail (negative) or
    # silently run nothing (zero).
    if args.num <= 0:
        parser.error("--num must be a positive integer")

    set_random_seed(42, True)
    load_env_config()
    if not check_api_config():
        sys.exit(1)

    evaluator = FinetuneModelEvaluator(
        dataset_name=args.dataset, 
        finetune_dataset_name=args.finetune_dataset,
        model_name=args.model,
        dataset_num=args.num,
    )
    evaluator.initialize()
    evaluator.attack()


if __name__ == "__main__":
    main()