import argparse
import os
import json
from tqdm import tqdm
import random
import numpy as np
import matplotlib.pyplot as plt
import logging

from src.models import create_model
from src.utils import load_beir_datasets, load_models
from src.utils import save_results, load_json, setup_seeds, clean_str, f1_score
from src.attack import Attacker
from src.prompts import wrap_prompt, wrap_prompt_few_shot

from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
import torch

from beir.reranking import Rerank
from beir.reranking.models import CrossEncoder


def set_seed(seed):
    """Seed every RNG this script relies on and pin cuDNN to deterministic kernels.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (plus all
    CUDA devices when CUDA is available), then disables cuDNN autotuning so
    runs are reproducible.

    Args:
        seed: integer seed applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def rerank_with_cross_encoder(cross_encoder, query, contexts, top_k=5, adv_prefix=False):
    """Rerank ``contexts`` for ``query`` with ``cross_encoder`` (implementation not yet released).

    Callers in this file unpack the result as
    ``(topk_reranked_contexts, paraphrased_pairs)`` and use the first element
    as the list of (at most) ``top_k`` context strings fed to the LLM prompt.

    Args:
        cross_encoder: reranking model (the caller passes a BEIR ``CrossEncoder``).
        query: question text.
        contexts: candidate passage strings.
        top_k: number of passages to keep after reranking.
        adv_prefix: flag forwarded by callers; its semantics belong to the
            unreleased implementation — TODO document once released.

    Raises:
        NotImplementedError: always, until the implementation is released.
    """
    # The previous stub returned names that were never defined, which surfaced
    # as a confusing NameError at the call site; fail explicitly instead.
    raise NotImplementedError(
        "rerank_with_cross_encoder has not been released yet; provide an "
        "implementation returning (topk_reranked_contexts, paraphrased_pairs)."
    )


def _default_beir_results_path(args):
    """Return the conventional ``results/beir_results/`` file for ``args``.

    The file name encodes dataset, retriever, split (``-dev`` suffix for the dev
    split) and score function (``-cos`` suffix for ``cos_sim``).

    Raises:
        ValueError: if ``args.split`` is not ``'train'``, ``'test'`` or ``'dev'``.
    """
    if args.split in ('train', 'test'):
        split_tag = ''
    elif args.split == 'dev':
        split_tag = '-dev'
    else:
        raise ValueError(
            f"Cannot infer BEIR results path for split {args.split!r}; "
            "pass --orig_beir_results explicitly."
        )
    cos_tag = '-cos' if args.score_function == 'cos_sim' else ''
    return f"results/beir_results/{args.eval_dataset}-{args.eval_model_code}{split_tag}{cos_tag}.json"


def _fit_adv_texts(adv_texts, n):
    """Return exactly ``n`` adversarial texts: truncate, or pad by repeating the last one ("" if none)."""
    if len(adv_texts) >= n:
        return list(adv_texts[:n])
    filler = adv_texts[-1] if adv_texts else ""
    return list(adv_texts) + [filler] * (n - len(adv_texts))


def _get_eval_logger(log_file):
    """Return the 'eval_logger' logger emitting bare messages to ``log_file`` and the console.

    Handlers are attached only on first use; later calls in the same process
    reuse the existing handlers (and therefore the first ``log_file``).
    """
    logger = logging.getLogger('eval_logger')
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        formatter = logging.Formatter('%(message)s')
        for handler in (logging.FileHandler(log_file), logging.StreamHandler()):
            handler.setLevel(logging.INFO)
            handler.setFormatter(formatter)
            logger.addHandler(handler)
    return logger


def main(args):
    """Evaluate an LLM under benign, poisoned and cross-encoder-reranked retrieval.

    For every target query in the adversarial-corpus file this builds
    question-only, ground-truth and benign top-k prompts. For each batch of
    ``args.M`` queries, the adversarial texts of the whole batch are scored
    against each query and merged into its BEIR ranking (only when an attack
    method is set). Benign and poisoned candidate lists are reranked with
    ``rerank_with_cross_encoder``. The LLM is then queried on six prompt
    variants. A variant counts as *correct* when the cleaned correct answer is a
    substring of the cleaned LLM output. It counts as a successful *attack*
    (ASR) when the cleaned incorrect answer is.

    Per-query records go to a JSON file and a summary is logged under
    ``results/eval_outputs/adv_prefix_rerank/...``.

    Note: mutates ``args.split`` (forced to ``'dev'`` for msmarco) and fills in
    ``args.orig_beir_results`` when it is ``None``.

    Raises:
        ValueError: unsupported ``score_function`` with an attack enabled, or a
            split for which no default BEIR results path exists.
        FileNotFoundError: the default BEIR results file does not exist.
    """
    # Must precede set_seed(): it calls torch.cuda.is_available(), after which CUDA
    # has enumerated devices and later CUDA_VISIBLE_DEVICES changes are ignored.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    set_seed(args.seed)
    device = 'cuda'

    use_attack = args.attack_method not in [None, 'None']
    if use_attack and args.score_function not in ('dot', 'cos_sim'):
        raise ValueError(
            f"Unsupported score_function {args.score_function!r}; expected 'dot' or 'cos_sim'."
        )

    # load target queries and answers
    if args.eval_dataset == 'msmarco':
        args.split = 'dev'

    corpus, _queries, qrels = load_beir_datasets(args.eval_dataset, args.split)
    if args.poison_method == "jamming":
        adv_corpus_path = (
            f"results/adv_corpus_generated/{args.poison_method}/"
            f"{args.eval_dataset}_{args.eval_model_code}_{args.score_function}_{args.model_name}_top200_unjammed.json"
        )
    else:
        adv_corpus_path = f'results/adv_corpus_generated/{args.poison_method}/{args.eval_dataset}.json'
    incorrect_answers = list(load_json(adv_corpus_path).values())
    total = len(incorrect_answers)
    print(f'len(incorrect_answers): {total}')

    # load BEIR top_k results
    if args.orig_beir_results is None:
        print(f"Please evaluate on BEIR first -- {args.eval_model_code} on {args.eval_dataset}")
        print("Now try to get beir eval results from results/beir_results/...")
        args.orig_beir_results = _default_beir_results_path(args)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not os.path.exists(args.orig_beir_results):
            raise FileNotFoundError(f"Failed to get beir_results from {args.orig_beir_results}!")
        print(f"Automatically get beir_resutls from {args.orig_beir_results}.")
    with open(args.orig_beir_results, 'r') as f:
        results = json.load(f)
    print('Total samples:', len(results))

    print(f"Loading LLM {args.model_name}...")
    llm = create_model(args.model_config_path, args)

    context_dump_path = (
        f"results/context_dump/{args.poison_method}/{args.eval_dataset}_{args.eval_model_code}_{args.score_function}_contexts.json"
    )
    # Only read here; this function never writes the dump back.
    if os.path.exists(context_dump_path):
        print(f"Loading existing contexts from: {context_dump_path}")
        with open(context_dump_path, "r", encoding="utf-8") as f:
            context_data = json.load(f)
    else:
        print(f"No existing context file found at: {context_dump_path}")
        context_data = {}

    # Prompts that do not depend on the attack, indexed like incorrect_answers.
    prompts_q = []
    prompts_ground_truth = []
    prompts_benign = []
    for item in tqdm(incorrect_answers, desc="Processing Queries"):
        query_id = item['id']
        question = item['question']
        ground_truth = [corpus[doc_id]["text"] for doc_id in qrels[query_id]]
        # Assumes results[query_id] keys are stored in ranked order — TODO confirm.
        benign_context = [corpus[doc_id]['text'] for doc_id in list(results[query_id])[:args.top_k]]

        prompts_q.append(wrap_prompt(question, prompt_id=0))
        prompts_ground_truth.append(wrap_prompt(question, ground_truth, prompt_id=4))
        prompts_benign.append(wrap_prompt(question, benign_context, prompt_id=4))

        if query_id not in context_data:
            context_data[query_id] = {
                "id": query_id,
                "question": question,
                "correct_answer": item['correct answer'],
                "incorrect_answer": item['incorrect answer'],
                "groud_truth_texts": ground_truth,
                "benign_texts": benign_context,
                "adv_texts": item['adv_texts'],
            }

    model, c_model, tokenizer, get_emb = load_models(args.eval_model_code)
    model.eval()
    model.to(device)
    c_model.eval()
    c_model.to(device)
    # Not used by the evaluation below; kept in case Attacker's constructor has
    # side effects other code relies on — TODO(review): confirm and remove.
    attacker = Attacker(args,
                        model=model,
                        c_model=c_model,
                        tokenizer=tokenizer,
                        get_emb=get_emb)

    cross_encoder_model = CrossEncoder(f"finetuned_cross_encoder/{args.reranker_model_code}", device=device)

    # Order matters: it is the order in which the LLM is queried and the summary is logged.
    cases = ('q_only', 'q_benign', 'q_poison', 'q_groud_truth', 'q_benign_rerank', 'q_poison_rerank')
    acc_list = dict.fromkeys(cases, 0)
    asr_list = dict.fromkeys(cases, 0)
    attack_count_list = []
    attack_count_list_adv_rerank = []
    eval_outputs = []

    for iteration in tqdm(range(args.repeat_times)):
        start = iteration * args.M
        if start >= total:
            print(f"Only {total} target queries available; stopping after {iteration} iterations.")
            break
        stop = min(start + args.M, total)
        print(f"\n######################## Iteration: {iteration+1}/{args.repeat_times} ########################")

        target_queries_idx = range(start, stop)
        adv_text_groups = [_fit_adv_texts(incorrect_answers[i]['adv_texts'], args.adv_per_query)
                           for i in target_queries_idx]
        # Every query in this batch competes against the adversarial texts of the whole batch.
        adv_text_list = [text for group in adv_text_groups for text in group]
        print(f"len(adv_text_list): {len(adv_text_list)}")

        if use_attack:
            adv_input = tokenizer(adv_text_list, padding=True, truncation=True, return_tensors="pt")
            adv_input = {key: value.to(device) for key, value in adv_input.items()}
            with torch.no_grad():
                adv_embs = get_emb(c_model, adv_input)

        for i in target_queries_idx:
            torch.cuda.empty_cache()
            item = incorrect_answers[i]
            query_id = item['id']
            question = item['question']
            correct_answer = item['correct answer']
            incorrect_answer = item['incorrect answer']
            adv_text_set = set(adv_text_groups[i - start])

            topk_results = [{'score': score, 'context': corpus[doc_id]['text']}
                            for doc_id, score in results[query_id].items()]
            topk_contents_benign_raw = [r['context'] for r in topk_results]
            topk_contents_benign_rerank, _ = rerank_with_cross_encoder(
                cross_encoder_model, question, topk_contents_benign_raw, top_k=args.top_k, adv_prefix=True)

            if use_attack:
                query_input = tokenizer(question, padding=True, truncation=True, return_tensors="pt")
                query_input = {key: value.to(device) for key, value in query_input.items()}
                with torch.no_grad():
                    query_emb = get_emb(model, query_input)
                for j, adv_text in enumerate(adv_text_list):
                    adv_emb = adv_embs[j, :].unsqueeze(0)
                    if args.score_function == 'dot':
                        adv_sim = torch.mm(adv_emb, query_emb.T).cpu().item()
                    else:  # 'cos_sim' (validated at the top of main)
                        adv_sim = torch.cosine_similarity(adv_emb, query_emb).cpu().item()
                    topk_results.append({'score': adv_sim, 'context': adv_text})
                topk_results.sort(key=lambda r: float(r['score']), reverse=True)

            # Without an attack this is simply the benign ranking.
            candidates = [r['context'] for r in topk_results]
            topk_contents = candidates[:args.top_k]
            topk_contents_adv_raw = candidates[:args.topj]
            topk_contents_adv_rerank, _ = rerank_with_cross_encoder(
                cross_encoder_model, question, topk_contents_adv_raw, top_k=args.top_k, adv_prefix=True)

            contexts = {
                'q_only': prompts_q[i],
                'q_benign': prompts_benign[i],
                'q_poison': wrap_prompt(question, topk_contents, prompt_id=4),
                'q_groud_truth': prompts_ground_truth[i],
                'q_benign_rerank': wrap_prompt(question, topk_contents_benign_rerank, prompt_id=4),
                'q_poison_rerank': wrap_prompt(question, topk_contents_adv_rerank, prompt_id=4),
            }

            clean_correct = clean_str(correct_answer)
            clean_incorrect = clean_str(incorrect_answer)
            results_per_case = {}
            for case, prompt in contexts.items():
                response = llm.query(prompt)
                pred = clean_str(response)
                correct = int(clean_correct in pred)
                asr = int(clean_incorrect in pred)
                acc_list[case] += correct
                asr_list[case] += asr
                results_per_case[case] = {
                    'output': response,
                    'correct': bool(correct),
                    'asr': bool(asr)
                }

            injected_adv = [text for text in topk_contents if text in adv_text_set]
            injected_adv_rerank = [text for text in topk_contents_adv_rerank if text in adv_text_set]
            attack_count_list.append(len(injected_adv))
            attack_count_list_adv_rerank.append(len(injected_adv_rerank))

            record = context_data[query_id]
            gt_set = set(record.get('groud_truth_texts', []))

            eval_outputs.append({
                "id": query_id,
                "question": record['question'],
                "correct_answer": record['correct_answer'],
                "incorrect_answer": record['incorrect_answer'],
                "groud_truth_texts": record.get('groud_truth_texts', []),
                "benign_texts": record.get('benign_texts', []),
                "adv_texts": record['adv_texts'],
                "prompts": contexts,
                "results": results_per_case,
                "topk_contents": topk_contents,
                "topk_contents_rerank": topk_contents_adv_rerank,
                "adv_text_set": list(adv_text_set),
                "num_adv_in_topk": len(injected_adv),
                "num_adv_in_topk_rerank": len(injected_adv_rerank),
                "total_ground_truths": len(gt_set),
                "ground_truths_in_topk": sum(ctx in gt_set for ctx in topk_contents),
                "ground_truths_in_benign_rerank": sum(ctx in gt_set for ctx in topk_contents_benign_rerank),
                "ground_truths_in_adv_rerank": sum(ctx in gt_set for ctx in topk_contents_adv_rerank),
            })

    out_dir = f"results/eval_outputs/adv_prefix_rerank/{args.poison_method}/adv_per_query{args.adv_per_query}"
    os.makedirs(out_dir, exist_ok=True)
    file_stem = f"{args.eval_dataset}_{args.eval_model_code}_{args.score_function}_{args.reranker_model_code}_{args.model_name}_eval_results"
    output_path = os.path.join(out_dir, f"{file_stem}.json")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(eval_outputs, f, indent=2, ensure_ascii=False)

    log_file = os.path.join(out_dir, f"{file_stem}.log")
    logger = _get_eval_logger(log_file)

    # Divide by the number of queries actually evaluated. The original
    # `x / repeat_times * M` multiplied by M because of operator precedence.
    num_queries = len(eval_outputs)
    logger.info("\n=== Evaluation Summary ===")
    for case in cases:
        acc = round(acc_list[case] / num_queries, 6) if num_queries else float('nan')
        asr = round(asr_list[case] / num_queries, 6) if num_queries else float('nan')
        logger.info(f"{case}: Accuracy = {acc}, ASR = {asr}")
    mean_adv = round(np.mean(attack_count_list), 4) if attack_count_list else float('nan')
    mean_adv_rerank = round(np.mean(attack_count_list_adv_rerank), 4) if attack_count_list_adv_rerank else float('nan')
    logger.info(f"Average # of adv in top-k: {mean_adv}")
    logger.info(f"Average # of adv in rerank top-k: {mean_adv_rerank}")
    logger.info(f"Saved to {log_file}")



def build_arg_parser():
    """Return the command-line parser for this evaluation script.

    ``--score_function`` is restricted to ``dot`` and ``cos_sim``, the only two
    similarity functions ``main`` knows how to compute and locate results for.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=42, help="Random seed for initialization.")
    parser.add_argument("--gpu", type=int, default=0, help="GPU ID to use.")

    parser.add_argument("--eval_dataset", type=str, default='nq', help="The dataset to evaluate on.")
    parser.add_argument("--split", type=str, default='test', help="The dataset split to evaluate on.")
    parser.add_argument("--eval_model_code", type=str, default='contriever', help="The model to evaluate.")
    parser.add_argument("--reranker_model_code", type=str, default='electra', help="The model to use for reranking.")
    parser.add_argument("--orig_beir_results", type=str, default=None, help="The original BEIR results to evaluate.")
    parser.add_argument("--score_function", type=str, default='dot', choices=['dot', 'cos_sim'],
                        help="The score function to use.")
    parser.add_argument("--top_k", type=int, default=5, help="The top k results to retrieve.")
    parser.add_argument("--topj", type=int, default=100,
                        help="Number of (possibly poisoned) candidates passed to the reranker.")

    parser.add_argument("--attack_method", type=str, default="LM_targeted", help="The attack method to use.")
    parser.add_argument("--adv_per_query", type=int, default=5, help="Number of adversarial texts per query.")
    parser.add_argument('--repeat_times', type=int, default=50, help='repeat several times to compute average')
    parser.add_argument('--M', type=int, default=10, help='one of our parameters, the number of target queries')

    parser.add_argument("--model_name", type=str, default='llama3_8b')
    parser.add_argument("--model_config_path", type=str, default='./model_configs/llama3_8b_config.json', help="The path to the model config file.")

    parser.add_argument("--poison_method", type=str, default='poisonrag_b', help="The method to use for generating adversarial texts.")

    # Not read in this script; presumably consumed by components that receive `args`.
    parser.add_argument("--step", type=float, default=1.0)
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    main(args)
