import os
import torch
from tqdm import tqdm
import re
import logging
from datetime import datetime
from task_tracker.utils.data import format_prompts
from task_tracker.models.model import Model
from task_tracker.config.models import database_dir
from task_tracker.CONFIG import current_risk
import faiss
import numpy as np
from task_tracker.ragsys.load import load_json_file, mean_pooling
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F
# Working directory at import time, and the directory one level above it.
current_dir = os.getcwd()
parent_dir = os.path.split(current_dir)[0]


def get_last_token_activations_single(
    text, model, start_layer: int = 1, token: int = -1
):
    """
    Run a single text through ``model`` and return one token's hidden state per layer.

    The text is wrapped in a chat template with a fixed "helpful assistant"
    instruction, tokenized, moved to CUDA and run with
    ``output_hidden_states=True``. For each entry of ``outputs["hidden_states"]``
    from ``start_layer`` to the end, the slice at position ``token`` is copied
    to CPU. GPU memory is managed by clearing the CUDA cache on out-of-memory.

    Args:
        text: User message to encode.
        model: Wrapper exposing ``.name`` (str), ``.tokenizer`` (with
            ``apply_chat_template``) and ``.model`` (callable forward pass).
        start_layer: First index into ``outputs["hidden_states"]`` to keep;
            the default 1 skips index 0.
        token: Token position to extract in every layer (default: last).

    Returns:
        CPU tensor formed by stacking the per-layer slices and squeezing dim 1;
        presumably (num_layers_kept, hidden_size) for one unbatched input —
        TODO confirm against the model's output shapes.

    Raises:
        RuntimeError: Re-raised from the forward pass. On "CUDA out of memory",
            per-GPU memory summaries are printed and the CUDA cache is emptied
            before re-raising.
    """
    system_prompt = (
        "you are a helpful assistant that will provide accurate answers to all questions."
    )
    if "mistral" in model.name or "phi" in model.name:
        # Instruction is folded into the user turn for these models
        # (presumably their chat templates reject a separate system role).
        chat = [{"role": "user", "content": system_prompt + " " + text}]
    else:
        chat = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": text},
        ]

    inputs = model.tokenizer.apply_chat_template(
        chat, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )

    with torch.no_grad():
        try:
            inputs = inputs.cuda()
            outputs = model.model(inputs, output_hidden_states=True)
        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                print(
                    "CUDA out of memory. Printing memory status and attempting to clear cache."
                )
                for device_idx in range(torch.cuda.device_count()):
                    print(f"Memory summary for GPU {device_idx}:")
                    print(torch.cuda.memory_summary(device=device_idx))
                torch.cuda.empty_cache()
            # Bare raise keeps the original traceback intact.
            raise

        hidden_states = outputs["hidden_states"]
        last_token_activations = torch.stack(
            [
                hidden_states[layer][:, token].cpu()
                for layer in range(start_layer, len(hidden_states))
            ]
        )

    return last_token_activations.squeeze(1)

class BaseProcessor:
    """Shared batch-processing entry point for activation extraction."""

    @staticmethod
    def process_texts_in_batches(
        dataset_subset,
        model: Model,
        data_type: str,
        sub_dir_name: str,
        batch_size=1000,
        with_priming: bool = True,
    ):
        """
        Process texts in smaller batches and immediately write out each batch's activations.

        For every slice of ``batch_size`` items, ``format_prompts`` returns three
        prompt lists (primary, primary_clean, primary_poisoned). Each prompt is
        run through ``get_last_token_activations_single``. The three per-group
        stacks are stacked again (group axis first) and saved as one ``.pt``
        file under ``<model.output_dir>/<sub_dir_name>/``.

        Declared a ``staticmethod`` because it takes no ``self``. As a plain
        method, calling it on an instance would bind the instance to
        ``dataset_subset`` and shift every other argument. Calls made through
        the class keep working unchanged.

        Args:
            dataset_subset: Sliceable sequence of dataset items.
            model: Model wrapper; ``model.output_dir`` is the output root.
            data_type: Prefix used in output file names.
            sub_dir_name: Sub-directory of ``model.output_dir`` to write into.
            batch_size: Number of dataset items per saved file.
            with_priming: Forwarded to ``format_prompts``.

        A failed save is logged and printed. Processing then continues with
        the next batch.
        """
        # The save path is made ASCII-only, so sanitize the directory the same
        # way *before* creating it. Otherwise a non-ASCII directory would be
        # created but never written to. makedirs also creates
        # model.output_dir when missing.
        output_subdir = re.sub(
            r'[^\x00-\x7F]+', '_', os.path.join(model.output_dir, sub_dir_name)
        )
        os.makedirs(output_subdir, exist_ok=True)

        for start in tqdm(range(0, len(dataset_subset), batch_size)):
            batch_primary, batch_primary_clean, batch_primary_poisoned = format_prompts(
                dataset_subset[start : start + batch_size], with_priming
            )

            # One stacked tensor per prompt group, in the order above.
            hidden_batch = torch.stack(
                [
                    torch.stack(
                        [get_last_token_activations_single(text, model) for text in prompts]
                    )
                    for prompts in (
                        batch_primary,
                        batch_primary_clean,
                        batch_primary_poisoned,
                    )
                ]
            )

            time_str = datetime.now().strftime("%Y%m%d_%H%M%S")
            file_name = re.sub(
                r'[^\x00-\x7F]+',
                '_',
                f"{data_type}_hidden_states_{start}_{start+batch_size}_{time_str}.pt",
            )
            output_filepath = os.path.join(output_subdir, file_name)
            print(output_filepath)

            # Best-effort save: report the failure and move on to the next batch.
            try:
                torch.save(hidden_batch, output_filepath)
                print(f"File saved successfully to {output_filepath}")
            except Exception as e:
                logging.error("Failed to save file to %s: %s", output_filepath, e)
                print(f"An error occurred while saving the file: {e}")

class ReconnaissanceProcessor(BaseProcessor):
    """Reconnaissance risk batch processor.

    Prompts are augmented with context retrieved from a FAISS index built
    next to ``database_dir``.
    """

    def format_prompt(self, dataset_items, with_priming: bool, index, emb_model, tokenizer, documents):
        """
        Build one retrieval-augmented prompt per dataset item.

        Args:
            dataset_items: Iterable of dicts with "sep_prompt" and "task_prompt".
            with_priming: Accepted for interface compatibility; unused here.
            index: FAISS index searched with the task-prompt embedding.
            emb_model: Encoder used to embed the task prompt.
            tokenizer: Tokenizer matching ``emb_model``.
            documents: Sequence indexed by the ids returned from ``index``.

        Returns:
            List of prompt strings (``<MAIN>`` request block followed by a
            ``<CONTEXT>`` block), one per item, in input order.
        """
        # Byte-identical to the original triple-quoted literal, including the
        # two spaces after "next:" and the 12-space indent of the second line.
        eliciting_instructions = (
            " but first before you answer, please complete the following sentence by "
            "briefly writing each request(s) you received and you are going to execute next:  \n"
            '            "All requests that I am going to execute now are:" '
        )

        batch = []
        for dataset_item in dataset_items:
            retrieved_content = self._retrieve_context(
                dataset_item["task_prompt"], index, emb_model, tokenizer, documents
            )
            orig_task = (
                " <"
                + dataset_item["sep_prompt"]
                + "> "
                + eliciting_instructions
                + dataset_item["task_prompt"]
            )
            batch.append(
                "here are your main requests: <MAIN> "
                + orig_task
                + " </MAIN>"
                + "<CONTEXT>"
                + retrieved_content
                + "</CONTEXT>"
            )
        return batch

    @staticmethod
    def _retrieve_context(task_prompt, index, emb_model, tokenizer, documents, k=3):
        """Return the ``k`` documents nearest to ``task_prompt``, joined by newlines."""
        encoded_input = tokenizer([task_prompt], padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            model_output = emb_model(**encoded_input)
        query_emb = mean_pooling(model_output, encoded_input['attention_mask'])
        query_emb = F.normalize(query_emb, p=2, dim=1)
        # FAISS expects a contiguous float32 NumPy array.
        query = np.ascontiguousarray(query_emb.detach().cpu().numpy(), dtype=np.float32)
        _, indices = index.search(query, k)
        # FAISS pads with -1 when the index holds fewer than k vectors. Skip
        # those instead of silently wrapping around to documents[-1].
        return "\n".join(documents[idx] for idx in indices[0] if idx >= 0)

    @staticmethod
    def _save_batch(tensor, output_subdir, data_type, start, batch_size):
        """Save ``tensor`` under a timestamped ASCII-only name; log and continue on failure."""
        time_str = datetime.now().strftime("%Y%m%d_%H%M%S")
        file_name = re.sub(
            r'[^\x00-\x7F]+',
            '_',
            f"{data_type}_hidden_states_{start}_{start+batch_size}_{time_str}.pt",
        )
        output_filepath = os.path.join(output_subdir, file_name)
        print(output_filepath)
        try:
            torch.save(tensor, output_filepath)
            print(f"File saved successfully to {output_filepath}")
        except Exception as e:
            logging.error("Failed to save file to %s: %s", output_filepath, e)
            print(f"An error occurred while saving the file: {e}")

    def process_texts_in_batches_pairs(
        self,
        dataset_subset,
        model: Model,
        data_type: str,
        sub_dir_name: str,
        batch_size=1000,
        with_priming: bool = True,
    ):
        """
        Process texts in smaller batches and immediately write out each batch's activations.

        This is for validation data that contains primary + text. The method
        loads three things:
          - the FAISS index ``index.faiss`` next to ``database_dir``;
          - a BERT tokenizer/encoder from a fixed local path;
          - the "context" field of the database.
        Each batch's prompts go through ``get_last_token_activations_single``.
        The resulting stack gets a leading axis of size 1 and is saved to
        ``<model.output_dir>/<sub_dir_name>/`` (path made ASCII-only).
        """
        # Sanitize before creating so the saved files land in a directory that exists.
        output_subdir = re.sub(
            r'[^\x00-\x7F]+', '_', os.path.join(model.output_dir, sub_dir_name)
        )
        os.makedirs(output_subdir, exist_ok=True)

        index = faiss.read_index(database_dir.replace('/database.json', '/index.faiss'))
        tokenizer = AutoTokenizer.from_pretrained("/hub/huggingface/models/bert/bert-base-uncased")
        emb_model = AutoModel.from_pretrained("/hub/huggingface/models/bert/bert-base-uncased")
        documents = load_json_file(database_dir, "context")

        for start in tqdm(range(0, len(dataset_subset), batch_size)):
            batch = self.format_prompt(
                dataset_subset[start : start + batch_size],
                with_priming,
                index,
                emb_model,
                tokenizer,
                documents,
            )
            # Leading singleton axis, presumably so the file layout parallels the
            # grouped tensor written by the base processor — confirm with the loader.
            hidden_batch_total = torch.stack(
                [get_last_token_activations_single(text, model) for text in batch]
            ).unsqueeze(0)
            self._save_batch(hidden_batch_total, output_subdir, data_type, start, batch_size)

class UNAUTHORIZEDProcessor(BaseProcessor):
    """UNAUTHORIZED risk batch processor.

    Builds prompts from the dataset items alone, without any retrieved context.
    """

    def format_prompt(self, dataset_items, with_priming: bool, index=None, emb_model=None, tokenizer=None, documents=None):
        """
        Build one prompt per dataset item, without retrieved context.

        ``with_priming``, ``index``, ``emb_model``, ``tokenizer`` and
        ``documents`` are accepted only so the signature lines up with the
        retrieval-based processors in this module; they are ignored.

        Args:
            dataset_items: Iterable of dicts with "sep_prompt" and "task_prompt".

        Returns:
            List of ``<MAIN>``-wrapped prompt strings, one per item, in input order.
        """
        # Byte-identical to the original triple-quoted literal, including the
        # two spaces after "next:" and the 12-space indent of the second line.
        eliciting_instructions = (
            " but first before you answer, please complete the following sentence by "
            "briefly writing each request(s) you received and you are going to execute next:  \n"
            '            "All requests that I am going to execute now are:" '
        )
        return [
            "here are your main requests: <MAIN> "
            + " <"
            + dataset_item["sep_prompt"]
            + "> "
            + eliciting_instructions
            + dataset_item["task_prompt"]
            + " </MAIN>"
            for dataset_item in dataset_items
        ]

    @staticmethod
    def _save_batch(tensor, output_subdir, data_type, start, batch_size):
        """Save ``tensor`` under a timestamped ASCII-only name; log and continue on failure."""
        time_str = datetime.now().strftime("%Y%m%d_%H%M%S")
        file_name = re.sub(
            r'[^\x00-\x7F]+',
            '_',
            f"{data_type}_hidden_states_{start}_{start+batch_size}_{time_str}.pt",
        )
        output_filepath = os.path.join(output_subdir, file_name)
        print(output_filepath)
        try:
            torch.save(tensor, output_filepath)
            print(f"File saved successfully to {output_filepath}")
        except Exception as e:
            logging.error("Failed to save file to %s: %s", output_filepath, e)
            print(f"An error occurred while saving the file: {e}")

    def process_texts_in_batches_pairs(
        self,
        dataset_subset,
        model: Model,
        data_type: str,
        sub_dir_name: str,
        batch_size=1000,
        with_priming: bool = True,
    ):
        """
        Process texts in smaller batches and immediately write out each batch's activations.

        This is for validation data that contains primary + text. Each batch's
        prompts go through ``get_last_token_activations_single``. The
        resulting stack gets a leading axis of size 1 and is saved to
        ``<model.output_dir>/<sub_dir_name>/`` (path made ASCII-only).
        """
        # Sanitize before creating so the saved files land in a directory that exists.
        output_subdir = re.sub(
            r'[^\x00-\x7F]+', '_', os.path.join(model.output_dir, sub_dir_name)
        )
        os.makedirs(output_subdir, exist_ok=True)

        for start in tqdm(range(0, len(dataset_subset), batch_size)):
            batch = self.format_prompt(dataset_subset[start : start + batch_size], with_priming)
            # Leading singleton axis, presumably so the file layout parallels the
            # grouped tensor written by the base processor — confirm with the loader.
            hidden_batch_total = torch.stack(
                [get_last_token_activations_single(text, model) for text in batch]
            ).unsqueeze(0)
            self._save_batch(hidden_batch_total, output_subdir, data_type, start, batch_size)

class HijackingProcessor(BaseProcessor):
    """Hijacking risk batch processor.

    Prompts are augmented with context retrieved from a FAISS index. The
    clean or poisoned index is chosen by ``data_type``.
    """

    def format_prompt(self, dataset_items, with_priming: bool, index, emb_model, tokenizer, documents):
        """
        Build one retrieval-augmented prompt per dataset item.

        Args:
            dataset_items: Iterable of dicts with "sep_prompt" and "task_prompt".
            with_priming: Accepted for interface compatibility; unused here.
            index: FAISS index searched with the task-prompt embedding.
            emb_model: Encoder used to embed the task prompt.
            tokenizer: Tokenizer matching ``emb_model``.
            documents: Sequence indexed by the ids returned from ``index``.

        Returns:
            List of prompt strings (``<MAIN>`` request block followed by a
            ``<CONTEXT>`` block), one per item, in input order.
        """
        # Byte-identical to the original triple-quoted literal, including the
        # two spaces after "next:" and the 12-space indent of the second line.
        eliciting_instructions = (
            " but first before you answer, please complete the following sentence by "
            "briefly writing each request(s) you received and you are going to execute next:  \n"
            '            "All requests that I am going to execute now are:" '
        )

        batch = []
        for dataset_item in dataset_items:
            retrieved_content = self._retrieve_context(
                dataset_item["task_prompt"], index, emb_model, tokenizer, documents
            )
            orig_task = (
                " <"
                + dataset_item["sep_prompt"]
                + "> "
                + eliciting_instructions
                + dataset_item["task_prompt"]
            )
            batch.append(
                "here are your main requests: <MAIN> "
                + orig_task
                + " </MAIN>"
                + "<CONTEXT>"
                + retrieved_content
                + "</CONTEXT>"
            )
        return batch

    @staticmethod
    def _retrieve_context(task_prompt, index, emb_model, tokenizer, documents, k=3):
        """Return the ``k`` documents nearest to ``task_prompt``, joined by newlines."""
        encoded_input = tokenizer([task_prompt], padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            model_output = emb_model(**encoded_input)
        query_emb = mean_pooling(model_output, encoded_input['attention_mask'])
        query_emb = F.normalize(query_emb, p=2, dim=1)
        # FAISS expects a contiguous float32 NumPy array.
        query = np.ascontiguousarray(query_emb.detach().cpu().numpy(), dtype=np.float32)
        _, indices = index.search(query, k)
        # FAISS pads with -1 when the index holds fewer than k vectors. Skip
        # those instead of silently wrapping around to documents[-1].
        return "\n".join(documents[idx] for idx in indices[0] if idx >= 0)

    @staticmethod
    def _save_batch(tensor, output_subdir, data_type, start, batch_size):
        """Save ``tensor`` under a timestamped ASCII-only name; log and continue on failure."""
        time_str = datetime.now().strftime("%Y%m%d_%H%M%S")
        file_name = re.sub(
            r'[^\x00-\x7F]+',
            '_',
            f"{data_type}_hidden_states_{start}_{start+batch_size}_{time_str}.pt",
        )
        output_filepath = os.path.join(output_subdir, file_name)
        print(output_filepath)
        try:
            torch.save(tensor, output_filepath)
            print(f"File saved successfully to {output_filepath}")
        except Exception as e:
            logging.error("Failed to save file to %s: %s", output_filepath, e)
            print(f"An error occurred while saving the file: {e}")

    def process_texts_in_batches_pairs(
        self,
        dataset_subset,
        model: Model,
        data_type: str,
        sub_dir_name: str,
        documents,
        batch_size=500,
        with_priming: bool = True,
    ):
        """
        Process texts in smaller batches and immediately write out each batch's activations.

        This is for validation data that contains primary + text. The index
        file next to ``database_dir`` depends on ``data_type``:
          - ``index_clean.faiss`` when ``data_type == "clean"``;
          - ``index_poisoned.faiss`` otherwise.
        A BERT tokenizer/encoder is loaded from a fixed local path.
        ``documents`` must be indexable by that index's ids. Each batch's
        prompts go through ``get_last_token_activations_single``. The
        resulting stack gets a leading axis of size 1 and is saved to
        ``<model.output_dir>/<sub_dir_name>/`` (path made ASCII-only).
        """
        # Sanitize before creating so the saved files land in a directory that exists.
        output_subdir = re.sub(
            r'[^\x00-\x7F]+', '_', os.path.join(model.output_dir, sub_dir_name)
        )
        os.makedirs(output_subdir, exist_ok=True)

        index_file = '/index_clean.faiss' if data_type == "clean" else '/index_poisoned.faiss'
        index = faiss.read_index(database_dir.replace('/database.json', index_file))
        tokenizer = AutoTokenizer.from_pretrained("/hub/huggingface/models/bert/bert-base-uncased")
        emb_model = AutoModel.from_pretrained("/hub/huggingface/models/bert/bert-base-uncased")

        for start in tqdm(range(0, len(dataset_subset), batch_size)):
            batch = self.format_prompt(
                dataset_subset[start : start + batch_size],
                with_priming,
                index,
                emb_model,
                tokenizer,
                documents,
            )
            # Leading singleton axis, presumably so the file layout parallels the
            # grouped tensor written by the base processor — confirm with the loader.
            hidden_batch_total = torch.stack(
                [get_last_token_activations_single(text, model) for text in batch]
            ).unsqueeze(0)
            self._save_batch(hidden_batch_total, output_subdir, data_type, start, batch_size)