import torch
import json
import logging
import argparse
import faiss
from transformers import AutoTokenizer, AutoModel
from task_tracker.utils.model import load_model
from task_tracker.config.models import models, cache_dir, database_dir
from task_tracker.CONFIG import current_risk
from task_tracker.utils.activations import BaseProcessor, ReconnaissanceProcessor, HijackingProcessor, UNAUTHORIZEDProcessor
from task_tracker.ragsys.load import load_json_file, mean_pooling
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer

# NOTE: Configuration
# Module-level switch that main() forwards as the `with_priming` keyword
# argument of processor.process_texts_in_batches_pairs().
# Set with_priming to False to generate activations without priming.
with_priming: bool = True

def get_processor(current_risk):
    """Instantiate the activation processor that matches a risk name.

    Args:
        current_risk: Risk identifier string. Note that this parameter
            shadows the module-level ``current_risk`` import inside the
            function body.

    Returns:
        A new ``ReconnaissanceProcessor`` for ``'Reconnaissance'`` or
        ``'Exfiltration'``, an ``UNAUTHORIZEDProcessor`` for
        ``'Unauthorized_Access'``, a ``HijackingProcessor`` for
        ``'Hijacking'`` or ``'Knowledge'``, and a ``BaseProcessor`` for any
        other value.
    """
    # Guard clauses, checked in the same order as the risk families above.
    if current_risk in ('Reconnaissance', 'Exfiltration'):
        return ReconnaissanceProcessor()
    if current_risk == 'Unauthorized_Access':
        return UNAUTHORIZEDProcessor()
    if current_risk in ('Hijacking', 'Knowledge'):
        return HijackingProcessor()
    # Anything unrecognised falls back to the generic processor.
    return BaseProcessor()

def build_index(file_path, model_path, index_name, batch_size=500):
    """Embed a JSON corpus in batches and persist the vectors as a FAISS index.

    Documents are read with ``load_json_file(file_path, "context")``,
    tokenized and encoded with the Hugging Face model at ``model_path``,
    pooled with ``mean_pooling`` and L2-normalised row-wise before being
    added to a ``faiss.IndexFlatL2`` index, which is then written to
    ``index_name``.

    Args:
        file_path: Path of the JSON corpus handed to ``load_json_file``.
        model_path: Name or path accepted by ``AutoTokenizer`` /
            ``AutoModel.from_pretrained``.
        index_name: Destination file for ``faiss.write_index``.
        batch_size: Number of documents encoded per forward pass.

    Returns:
        The populated ``faiss.IndexFlatL2`` instance.

    Raises:
        ValueError: If ``batch_size`` is not positive or the corpus is empty
            (previously this surfaced as an opaque failure when writing a
            ``None`` index).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    documents = load_json_file(file_path, "context")
    if len(documents) == 0:
        raise ValueError(f"No documents loaded from {file_path}; cannot build a FAISS index")

    # Use the GPU when one is present, otherwise fall back to the CPU instead
    # of failing on a hard-coded .cuda() call.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    emb_model = AutoModel.from_pretrained(model_path).to(device)

    index = None
    for start in range(0, len(documents), batch_size):
        batch_documents = documents[start:start + batch_size]
        encoded_input = tokenizer(
            batch_documents, padding=True, truncation=True, return_tensors="pt"
        ).to(device)
        # Inference only: skip autograd bookkeeping to save memory.
        with torch.no_grad():
            model_output = emb_model(**encoded_input)
        embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
        # Unit-length rows, so L2 distance ranks identically to cosine similarity.
        embeddings = F.normalize(embeddings, p=2, dim=1)

        if index is None:
            # The embedding width is only known after the first batch.
            index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings.cpu().numpy())

    faiss.write_index(index, index_name)
    # Report the real destination rather than a fixed file name.
    print(f"FAISS index stored in {index_name}")

    return index

def _resolve_subset_labels(data_type, risk):
    """Map a dataset key and risk name to ``(directory_name, subset_type)``."""
    if "train" in data_type:
        directory_name = "training"
    elif "val" in data_type:
        directory_name = "validation"
    else:
        directory_name = "test"

    if risk in ('Reconnaissance', 'Exfiltration', 'Hijacking', 'Knowledge'):
        subset_type = "clean" if "clean" in data_type else "poisoned"
    elif risk == 'Unauthorized_Access':
        # First matching category wins; anything else is treated as "goods".
        for category in ("case", "employee", "financial"):
            if category in data_type:
                subset_type = category
                break
        else:
            subset_type = "goods"
    else:
        # Previously this case failed later with an unbound-variable error.
        raise ValueError(f"No subset labelling is defined for risk {risk!r}")

    return directory_name, subset_type


def main(model_name, model_path):
    """Load a model, build the retrieval index(es) and generate activations.

    Steps:
      1. Look up ``model_name`` in ``models`` and attach the tokenizer/model
         returned by ``load_model`` (``model_path`` is passed both as the
         model location and as ``cache_dir``).
      2. Build FAISS indexes for the module-level ``current_risk``: clean and
         poisoned corpora for 'Hijacking'/'Knowledge', none for
         'Unauthorized_Access', and the default database otherwise.
      3. For every ``(data_type, path)`` in ``model.data``, load the JSON
         file and call ``processor.process_texts_in_batches_pairs`` on the
         entries from ``model.start_idx`` onward.

    Args:
        model_name: Key into the ``models`` configuration mapping.
        model_path: Path to the model files.

    Raises:
        Exception: Any error raised while loading the model is re-raised
            after per-GPU memory summaries have been logged. Errors while
            processing an individual dataset are logged and do not stop the
            remaining datasets.
    """
    # Select the model configuration
    model = models[model_name]

    try:
        # Load the model and tokenizer
        loaded_model = load_model(
            model_path,
            cache_dir=model_path,
            torch_dtype=model.torch_dtype
        )
        model.tokenizer = loaded_model["tokenizer"]
        model.model = loaded_model["model"]

        # Check if multiple GPUs are available
        if torch.cuda.device_count() > 1:
            logging.info(f"Let's use {torch.cuda.device_count()} GPUs!")

        model.model.eval()

    except Exception:
        # Print memory summary for each GPU in case of an error
        for i in range(torch.cuda.device_count()):
            logging.info(f"Memory summary for GPU {i}:")
            logging.info(torch.cuda.memory_summary(device=i))
        # Bare raise keeps the original traceback intact.
        raise

    embedding_model_path = "/hub/huggingface/models/bert/bert-base-uncased"
    uses_corpus_documents = current_risk in ('Hijacking', 'Knowledge')

    def _sibling_path(file_name):
        # Files that live next to database.json share its directory.
        return database_dir.replace('/database.json', file_name)

    if uses_corpus_documents:
        for variant in ("clean", "poisoned"):
            build_index(
                _sibling_path(f'/corpus_{variant}.json'),
                embedding_model_path,
                index_name=_sibling_path(f'/index_{variant}.faiss'),
            )
    elif current_risk != 'Unauthorized_Access':
        # 'Unauthorized_Access' needs no index; every other risk uses the
        # default database.
        build_index(
            database_dir,
            embedding_model_path,
            index_name=_sibling_path('/index.faiss'),
        )
    processor = get_processor(current_risk)

    # Process data for activations
    for data_type, data in model.data.items():
        try:
            # Context manager closes the dataset file deterministically.
            with open(data, "r") as dataset_file:
                subset = json.load(dataset_file)

            directory_name, subset_type = _resolve_subset_labels(data_type, current_risk)

            extra_kwargs = {}
            if uses_corpus_documents:
                # These risks also pass the matching corpus to the processor.
                extra_kwargs["documents"] = load_json_file(
                    _sibling_path(f'/corpus_{subset_type}.json'), "context"
                )

            processor.process_texts_in_batches_pairs(
                dataset_subset=subset[model.start_idx:],
                model=model,
                data_type=subset_type,
                sub_dir_name=directory_name,
                with_priming=with_priming,
                **extra_kwargs,
            )

        except json.JSONDecodeError as json_err:
            logging.error(f"Error decoding JSON for {data_type}: {json_err}")
        except Exception as data_err:
            # Best-effort per dataset: log with traceback and move on.
            logging.exception(f"Error processing {data_type} data: {data_err}")

if __name__ == "__main__":
    # Script entry point: both options are mandatory strings forwarded to main().
    cli = argparse.ArgumentParser(description="Process activations for a specified model.")
    for flag, help_text in (
        ("--model_name", "Name of the model to use"),
        ("--model_path", "Path to the model files"),
    ):
        cli.add_argument(flag, type=str, required=True, help=help_text)
    options = cli.parse_args()

    main(options.model_name, options.model_path)