import argparse
import os
import json
from tqdm import tqdm
import random
import numpy as np
from src.models import create_model
from src.utils import load_beir_datasets, load_models
from src.utils import save_results, load_json, setup_seeds, clean_str, f1_score
from src.attack import Attacker
from src.prompts import wrap_prompt
import torch



def parse_args():
    parser = argparse.ArgumentParser(description='test')

    # Retriever and BEIR datasets
    parser.add_argument("--eval_model_code", type=str, default="contriever")
    parser.add_argument('--eval_dataset', type=str, default="nq", help='BEIR dataset to evaluate')
    parser.add_argument('--split', type=str, default='test')
    parser.add_argument("--orig_beir_results", type=str, default=None, help='Eval results of eval_model on the original beir eval_dataset')
    parser.add_argument("--query_results_dir", type=str, default='main')

    # LLM settings
    parser.add_argument('--model_config_path', default=None, type=str)
    parser.add_argument('--model_name', type=str, default='palm2')
    parser.add_argument('--top_k', type=int, default=5)
    parser.add_argument('--use_truth', type=str, default='False')
    parser.add_argument('--gpu_id', type=int, default=0)

    # attack
    parser.add_argument('--attack_method', type=str, default='LM_targeted')
    parser.add_argument('--adv_per_query', type=int, default=5, help='The number of adv texts for each target query.')
    parser.add_argument('--score_function', type=str, default='dot', choices=['dot', 'cos_sim'])
    parser.add_argument('--repeat_times', type=int, default=10, help='repeat several times to compute average')
    parser.add_argument('--M', type=int, default=10, help='one of our parameters, the number of target queries')
    parser.add_argument('--seed', type=int, default=12, help='Random seed')
    parser.add_argument("--name", type=str, default='debug', help="Name of log and result.")

    args = parser.parse_args()
    print(args)
    return args


def main():
    args = parse_args()
    torch.cuda.set_device(args.gpu_id)
    device = 'cuda'
    setup_seeds(args.seed)
    if args.model_config_path == None:
        args.model_config_path = f'model_configs/{args.model_name}_config.json'

    # load target queries and answers
    if args.eval_dataset == 'msmarco':
        corpus, queries, qrels = load_beir_datasets('msmarco', 'train')
        incorrect_answers = load_json(f'results/target_queries/{args.eval_dataset}.json')
        random.shuffle(incorrect_answers)    
    else:
        corpus, queries, qrels = load_beir_datasets(args.eval_dataset, args.split)
        incorrect_answers = load_json(f'results/target_queries/{args.eval_dataset}.json')

    # load BEIR top_k results  
    if args.orig_beir_results is None: 
        print(f"Please evaluate on BEIR first -- {args.eval_model_code} on {args.eval_dataset}")
        # Try to get beir eval results from ./beir_results
        print("Now try to get beir eval results from results/beir_results/...")
        if args.split == 'test':
            args.orig_beir_results = f"results/beir_results/{args.eval_dataset}-{args.eval_model_code}.json"
        elif args.split == 'dev':
            args.orig_beir_results = f"results/beir_results/{args.eval_dataset}-{args.eval_model_code}-dev.json"
        if args.score_function == 'cos_sim':
            args.orig_beir_results = f"results/beir_results/{args.eval_dataset}-{args.eval_model_code}-cos.json"
        assert os.path.exists(args.orig_beir_results), f"Failed to get beir_results from {args.orig_beir_results}!"
        print(f"Automatically get beir_resutls from {args.orig_beir_results}.")
    with open(args.orig_beir_results, 'r') as f:
        results = json.load(f)
    # assert len(qrels) <= len(results)
    print('Total samples:', len(results))
    for i, (k,v) in enumerate(results.items()):
        if i < 5:
            print(k, v)

    if args.use_truth == 'True':
        args.attack_method = None

    if args.attack_method not in [None, 'None']:
        # Load retrieval models
        model, c_model, tokenizer, get_emb = load_models(args.eval_model_code)
        # print("Model loaded", model)  # This is the contriever model
        # print("C_Model loaded", c_model)  # This is the contriever model again
        model.eval()
        model.to(device)
        c_model.eval()
        c_model.to(device) 
        attacker = Attacker(args,
                            model=model,
                            c_model=c_model,
                            tokenizer=tokenizer,
                            get_emb=get_emb) 
    
    llm = create_model(args.model_config_path)

    all_results = []
    asr_list=[]
    ret_list=[]

    for iter in range(args.repeat_times):
        print(f'######################## Iter: {iter+1}/{args.repeat_times} #######################')

        target_queries_idx = range(iter * args.M, iter * args.M + args.M)
        print("Target queries idx:", target_queries_idx)
        target_queries = [incorrect_answers[idx]['question'] for idx in target_queries_idx]
        print("Target queries:", target_queries)
        
        if args.attack_method not in [None, 'None']:
            for i in target_queries_idx:
                print(f"Target query {i}/{args.M}", f"Id: {incorrect_answers[i]['id']}")
                top1_idx = list(results[incorrect_answers[i]['id']].keys())[0]
                print(f"Top1 idx: {top1_idx}")
                top1_score = results[incorrect_answers[i]['id']][top1_idx]
                print(f"Top1 score: {top1_score}")
                target_queries[i - iter * args.M] = {'query': target_queries[i - iter * args.M], 'top1_score': top1_score, 'id': incorrect_answers[i]['id']}
                
            adv_text_groups = attacker.get_attack(target_queries)
            print("Adv text groups:", adv_text_groups)
            adv_text_list = sum(adv_text_groups, []) # convert 2D array to 1D array
            print("Len of adv_text_list:", len(adv_text_list))
            print("Adv text list:", adv_text_list) # Concatenated query+adv_text for each adv_text

            adv_input = tokenizer(adv_text_list, padding=True, truncation=True, return_tensors="pt")
            adv_input = {key: value.cuda() for key, value in adv_input.items()}
            with torch.no_grad():
                adv_embs = get_emb(c_model, adv_input)        
            print("Shape of adv_embs:", adv_embs.shape)


        asr_cnt=0
        ret_sublist=[]
        print("Iter:", iter)
        iter_results = []
        for i in target_queries_idx:
            iter_idx = i - iter * args.M # iter index
            print("Iter idx:", iter_idx)

            print(f'############# Target Question: {iter_idx+1}/{args.M} #############')
            question = incorrect_answers[i]['question']
            print(f'Question: {question}\n') 
            
            gt_ids = list(qrels[incorrect_answers[i]['id']].keys())
            ground_truth = [corpus[id]["text"] for id in gt_ids]
            incco_ans = incorrect_answers[i]['incorrect answer']            

            if args.use_truth == 'True':
                query_prompt = wrap_prompt(question, ground_truth, 4)
                response = llm.query(query_prompt)
                print(f"Output: {response}\n\n")
                iter_results.append(
                    {
                        "question": question,
                        "input_prompt": query_prompt,
                        "output": response,
                    }
                )  

            else: # topk
                topk_idx = list(results[incorrect_answers[i]['id']].keys())[:args.top_k]
                topk_results = [{'score': results[incorrect_answers[i]['id']][idx], 'context': corpus[idx]['text']} for idx in topk_idx]               

                if args.attack_method not in [None, 'None']: 
                    query_input = tokenizer(question, padding=True, truncation=True, return_tensors="pt")
                    query_input = {key: value.cuda() for key, value in query_input.items()}
                    with torch.no_grad():
                        query_emb = get_emb(model, query_input) 
                    for j in range(len(adv_text_list)):
                        adv_emb = adv_embs[j, :].unsqueeze(0) 
                        print("Shape of adv_emb:", adv_emb.shape)
                        # similarity     
                        if args.score_function == 'dot':
                            print("Shapes:")
                            print("adv_emb:", adv_emb.shape)
                            print("query_emb:", query_emb.shape)
                            print("adv_emb.T:", adv_emb.T.shape)
                            adv_sim = torch.mm(adv_emb, query_emb.T).cpu().item()
                        elif args.score_function == 'cos_sim':
                            adv_sim = torch.cosine_similarity(adv_emb, query_emb).cpu().item()
                                               
                        topk_results.append({'score': adv_sim, 'context': adv_text_list[j]})
                    
                    topk_results = sorted(topk_results, key=lambda x: float(x['score']), reverse=True)
                    print("Topk results after attack:")
                    for tpk in topk_results:
                        print(tpk)
                        print()
                    topk_contents = [topk_results[j]["context"] for j in range(args.top_k)]
                    # tracking the num of adv_text in topk
                    adv_text_set = set(adv_text_groups[iter_idx])

                    cnt_from_adv=sum([i in adv_text_set for i in topk_contents])
                    ret_sublist.append(cnt_from_adv)
                query_prompt = wrap_prompt(question, topk_contents, prompt_id=4)
                print("------------------\nQUERY PROMPT:\n")
                print(query_prompt) 
                print("------------------\n")

                response = llm.query(query_prompt)

                print(f'Output: {response}\n\n')
                injected_adv=[i for i in topk_contents if i in adv_text_set]

                
                print("\nRetrieved topk_results:")
                for j in topk_results:
                    print(j)
                    print()

                print("\nRetrieved contexts")
                for j in topk_contents:
                    if j in adv_text_set:
                        print(f"Injected: {j}")
                    else:
                        print("Original: ", j)
                
                iter_results.append(
                    {
                        "id":incorrect_answers[i]['id'],
                        "question": question,
                        "injected_adv": injected_adv,
                        "input_prompt": query_prompt,
                        "output_poison": response,
                        "incorrect_answer": incco_ans,
                        "answer": incorrect_answers[i]['correct answer']
                    }
                )

                if clean_str(incco_ans) in clean_str(response):
                    asr_cnt += 1  

        asr_list.append(asr_cnt)
        ret_list.append(ret_sublist)

        all_results.append({f'iter_{iter}': iter_results})
        save_results(all_results, args.query_results_dir, args.name)
        print(f'Saving iter results to results/query_results/{args.query_results_dir}/{args.name}.json')


    asr = np.array(asr_list) / args.M
    asr_mean = round(np.mean(asr), 2)
    ret_precision_array = np.array(ret_list) / args.top_k
    ret_precision_mean=round(np.mean(ret_precision_array), 2)
    ret_recall_array = np.array(ret_list) / args.adv_per_query
    ret_recall_mean=round(np.mean(ret_recall_array), 2)

    ret_f1_array=f1_score(ret_precision_array, ret_recall_array)
    ret_f1_mean=round(np.mean(ret_f1_array), 2)
  
    print(f"ASR: {asr}")
    print(f"ASR Mean: {asr_mean}\n") 

    print(f"Ret: {ret_list}")
    print(f"Precision mean: {ret_precision_mean}")
    print(f"Recall mean: {ret_recall_mean}")
    print(f"F1 mean: {ret_f1_mean}\n")

    print(f"Ending...")


if __name__ == '__main__':
    main()