import torch
import json
import logging
import argparse
import faiss
from transformers import AutoTokenizer, AutoModel
from task_tracker.utils.model import load_model
from task_tracker.config.models import models, cache_dir, database_dir
from task_tracker.CONFIG import current_risk
from task_tracker.utils.activations import BaseProcessor, ReconnaissanceProcessor, HijackingProcessor, UNAUTHORIZEDProcessor
from task_tracker.ragsys.load import load_json_file, mean_pooling
from task_tracker.mcp.client import MCPDatabaseClient, configure_mcp_client, build_index_remote
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer

# MCP configuration.
# Both settings are rebound from the command-line arguments when this file is run as a script.
MCP_SERVER_URL: str = "http://localhost:8000"  # change this to the address of your MCP server
USE_MCP: bool = True  # True enables the MCP client

# NOTE: Configuration
# Set with_priming to False to generate activations without priming.
with_priming: bool = True

def get_processor(current_risk):
    """Instantiate the activation processor responsible for ``current_risk``.

    Reconnaissance and Exfiltration share ``ReconnaissanceProcessor``,
    Unauthorized_Access uses ``UNAUTHORIZEDProcessor``, Hijacking and
    Knowledge share ``HijackingProcessor``; any other value falls back to
    ``BaseProcessor``.
    """
    if current_risk in ('Reconnaissance', 'Exfiltration'):
        return ReconnaissanceProcessor()
    if current_risk == 'Unauthorized_Access':
        return UNAUTHORIZEDProcessor()
    if current_risk in ('Hijacking', 'Knowledge'):
        return HijackingProcessor()
    # Add a dedicated branch above when a new risk gets its own processor.
    return BaseProcessor()

def build_index(file_path, model_path, index_name, batch_size=500, use_mcp=USE_MCP):
    """Build a FAISS index over the ``"context"`` field of a JSON corpus.

    When ``use_mcp`` is true the build is delegated to the MCP server at
    ``MCP_SERVER_URL``. If configuring the client or the remote build fails
    for any reason, the index is built locally instead.

    Args:
        file_path: Path of the JSON corpus to index.
        model_path: Hugging Face model path used to embed the documents.
        index_name: Output path of the FAISS index (local build), also passed
            as the index name to the remote build.
        batch_size: Number of documents embedded per forward pass.
        use_mcp: Try the remote MCP build before building locally.

    Returns:
        Whatever ``build_index_remote`` returns for a successful remote build,
        otherwise the locally built ``faiss.IndexFlatL2``.

    Raises:
        ValueError: If the local build gets a non-positive ``batch_size`` or
            the corpus contains no documents.
    """
    if use_mcp:
        try:
            # Configure inside the try: an unreachable server must also trigger
            # the local fallback instead of aborting the whole run.
            client = configure_mcp_client(MCP_SERVER_URL)
            result = build_index_remote(file_path, model_path, index_name, client, batch_size)
        except Exception as e:
            print(f"MCP build failed, falling back to local: {e}")
        else:
            print(f"FAISS index built remotely: {index_name}")
            return result

    return _build_index_local(file_path, model_path, index_name, batch_size)


def _build_index_local(file_path, model_path, index_name, batch_size):
    """Embed the corpus on this machine, build an L2 FAISS index and write it to ``index_name``."""
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    documents = load_json_file(file_path, "context", use_mcp=False)
    if len(documents) == 0:
        # Without this check `index` stays None and faiss.write_index fails obscurely.
        raise ValueError(f"No documents found in {file_path}; refusing to write an empty FAISS index")

    # Keep using the GPU when present, but do not crash on CPU-only hosts.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    emb_model = AutoModel.from_pretrained(model_path).to(device)
    emb_model.eval()

    index = None
    for start in range(0, len(documents), batch_size):
        batch = documents[start:start + batch_size]
        encoded_input = tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(device)
        with torch.no_grad():
            model_output = emb_model(**encoded_input)
            embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
            # Unit-normalised vectors make L2 distance rank like cosine similarity.
            embeddings = F.normalize(embeddings, p=2, dim=1)
        # FAISS only accepts float32 arrays.
        vectors = embeddings.float().cpu().numpy()
        if index is None:
            index = faiss.IndexFlatL2(vectors.shape[1])
        index.add(vectors)

    faiss.write_index(index, index_name)
    print("FAISS index stored locally")
    return index

# Sentence-embedding model used to build the retrieval FAISS indexes.
_EMBEDDING_MODEL_PATH = "/hub/huggingface/models/bert/bert-base-uncased"


def _resolve_subset_type(data_type, risk):
    """Map a ``model.data`` key to the subset label used as ``data_type`` for the processor.

    Unauthorized_Access datasets are split by record type: the first of
    "case", "employee", "financial" contained in ``data_type``, otherwise
    "goods". Every other risk uses a clean/poisoned split.
    """
    if risk == 'Unauthorized_Access':
        for record_type in ("case", "employee", "financial"):
            if record_type in data_type:
                return record_type
        return "goods"
    return "clean" if "clean" in data_type else "poisoned"


def _resolve_directory_name(data_type):
    """Return the activation sub-directory for a ``model.data`` key: training, validation or test."""
    if "train" in data_type:
        return "training"
    return "validation" if "val" in data_type else "test"


def main(model_name, model_path, embedding_model_path=_EMBEDDING_MODEL_PATH):
    """Load a model, build the retrieval index(es) for ``current_risk`` and extract activations.

    Args:
        model_name: Key into ``models``; the entry supplies ``torch_dtype``,
            ``data`` (dataset key -> JSON path) and ``start_idx``.
        model_path: Path handed to ``load_model`` (also used as its cache_dir).
        embedding_model_path: Embedding model used to build the FAISS indexes.

    Raises:
        Exception: Any error raised while loading the model is re-raised after
            a per-GPU memory summary has been logged. Errors while processing
            an individual dataset are logged and that dataset is skipped.
    """
    import os  # only main() needs path helpers

    # Decide once whether MCP is usable. If configuration fails, everything
    # below (index builds and corpus loading) really does run locally.
    use_mcp = USE_MCP
    if use_mcp:
        try:
            configure_mcp_client(MCP_SERVER_URL)
            print(f"MCP client configured for server: {MCP_SERVER_URL}")
        except Exception as e:
            print(f"Failed to configure MCP client: {e}")
            print("Continuing with local processing...")
            use_mcp = False

    # Select the model configuration
    model = models[model_name]

    try:
        loaded_model = load_model(
            model_path,
            cache_dir=model_path,
            torch_dtype=model.torch_dtype
        )
        model.tokenizer = loaded_model["tokenizer"]
        model.model = loaded_model["model"]

        if torch.cuda.device_count() > 1:
            logging.info(f"Let's use {torch.cuda.device_count()} GPUs!")

        model.model.eval()

    except Exception:
        # Print memory summary for each GPU to help diagnose OOM-style failures.
        for i in range(torch.cuda.device_count()):
            logging.info(f"Memory summary for GPU {i}:")
            logging.info(torch.cuda.memory_summary(device=i))
        raise

    # Corpus and index files live next to the database file. Deriving them from
    # its directory (rather than str.replace on '/database.json') guarantees an
    # unexpected database path can never make an index overwrite the database.
    data_root = os.path.dirname(database_dir)

    def sibling(file_name):
        return os.path.join(data_root, file_name)

    uses_corpus = current_risk in ('Hijacking', 'Knowledge')
    if uses_corpus:
        for split in ("clean", "poisoned"):
            build_index(
                sibling(f"corpus_{split}.json"),
                embedding_model_path,
                index_name=sibling(f"index_{split}.faiss"),
                use_mcp=use_mcp,
            )
    elif current_risk != 'Unauthorized_Access':
        build_index(
            database_dir,
            embedding_model_path,
            index_name=sibling("index.faiss"),
            use_mcp=use_mcp,
        )
    # Unauthorized_Access needs no retrieval index.

    processor = get_processor(current_risk)
    corpus_cache = {}

    def load_corpus(subset_type):
        # The corpus depends only on the subset label, so load each one once.
        if subset_type not in corpus_cache:
            corpus_cache[subset_type] = load_json_file(
                sibling(f"corpus_{subset_type}.json"),
                "context",
                use_mcp=use_mcp,
            )
        return corpus_cache[subset_type]

    # Process data for activations
    for data_type, data in model.data.items():
        try:
            with open(data, "r", encoding="utf-8") as handle:
                subset = json.load(handle)

            directory_name = _resolve_directory_name(data_type)
            subset_type = _resolve_subset_type(data_type, current_risk)
            # Only retrieval-based risks receive the corpus documents.
            extra_kwargs = {"documents": load_corpus(subset_type)} if uses_corpus else {}

            processor.process_texts_in_batches_pairs(
                dataset_subset=subset[model.start_idx:],
                model=model,
                data_type=subset_type,
                sub_dir_name=directory_name,
                with_priming=with_priming,
                **extra_kwargs,
            )

        except json.JSONDecodeError as json_err:
            logging.error(f"Error decoding JSON for {data_type}: {json_err}")
        except Exception as data_err:
            # logging.exception keeps the traceback so failures are debuggable.
            logging.exception(f"Error processing {data_type} data: {data_err}")

def build_arg_parser(default_server, default_use_mcp):
    """Create the command-line parser for this script.

    MCP can be switched on with ``--use_mcp`` and off with ``--no_mcp``; with
    neither flag, ``default_use_mcp`` applies. A bare ``store_true`` flag whose
    default is already True could never be turned off, hence the paired switch.
    """
    parser = argparse.ArgumentParser(description="Process activations for a specified model.")
    parser.add_argument("--model_name", type=str, required=True, help="Name of the model to use")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the model files")
    parser.add_argument("--mcp_server", type=str, default=default_server, help="MCP server URL")
    parser.add_argument("--use_mcp", dest="use_mcp", action="store_true", help="Use MCP client")
    parser.add_argument("--no_mcp", dest="use_mcp", action="store_false",
                        help="Disable the MCP client and process everything locally")
    parser.set_defaults(use_mcp=default_use_mcp)
    return parser


if __name__ == "__main__":
    args = build_arg_parser(MCP_SERVER_URL, USE_MCP).parse_args()

    # Rebind the module-level settings that build_index() and main() read at call time.
    MCP_SERVER_URL = args.mcp_server
    USE_MCP = args.use_mcp

    main(args.model_name, args.model_path)
