import torch
from refined.inference.processor import Refined
from sentence_transformers import SentenceTransformer, CrossEncoder
from transformers import BertTokenizerFast, BertForTokenClassification, AutoModelForCausalLM, AutoTokenizer

import json
import argparse
from time import time
from tqdm import tqdm

import nltk
from nltk.corpus import stopwords as nltk_sw
nltk.download('stopwords')
nltk.download('punkt_tab')

from editor import Editor

def parse_args():
    """Parse the command-line options and resolve derived settings.

    Returns:
        The parsed ``argparse.Namespace`` after ``set_names`` has added
        ``device`` and expanded the short ``model_name`` alias.
    """
    model_choices = ['qwen2.5-7b', 'qwen2.5-1.5b', 'gptj', 'vicuna-7b', 'falcon-7b', 'mistral-7b-v0.3', 'llama3.1-8b']
    dataset_choices = ['mquake', 'rippleedit', 'kebench']

    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add('--model_name', type=str, default='vicuna-7b', choices=model_choices)
    add('--bert_pt_path', type=str, default='./r_extractor/mquake_extractor.pt')
    add('--ds_name', type=str, default='mquake', choices=dataset_choices)
    add('--ds_json_path', type=str, default='./data/MQuAKE/MQuAKE-CF-3k.json')
    add('--use_gpu', action='store_true')
    add('--gpu_num', type=int, default=0)
    add('--hf_token', type=str, default='')

    return set_names(parser.parse_args())

def set_names(args):
    """Set ``args.device`` and expand the short ``args.model_name`` alias.

    ``device`` becomes ``'cuda:<gpu_num>'`` when ``use_gpu`` is truthy and
    ``'cpu'`` otherwise. A known alias is replaced by its Hugging Face repo id;
    any other ``model_name`` is left unchanged. ``args`` is modified in place
    and also returned.
    """
    repo_ids = {
        'qwen2.5-7b': 'Qwen/Qwen2.5-7B',
        'qwen2.5-1.5b': 'Qwen/Qwen2.5-1.5B',
        'gptj': 'EleutherAI/gpt-j-6B',
        'falcon-7b': 'tiiuae/falcon-7b',
        'vicuna-7b': 'lmsys/vicuna-7b-v1.5',
        'mistral-7b-v0.3': 'mistralai/Mistral-7B-v0.3',
        'llama3.1-8b': 'meta-llama/Llama-3.1-8B',
    }

    if args.use_gpu:
        args.device = f'cuda:{args.gpu_num}'
    else:
        args.device = 'cpu'

    if args.model_name in repo_ids:
        args.model_name = repo_ids[args.model_name]

    return args

def _load_eval_subset(ds_json_path):
    """Load the dataset JSON and return the fixed evaluation subset.

    Keeps cases [0, 100), [1000, 1100) and [1900, 2000) of the file (fewer
    cases if the file is shorter).
    """
    with open(ds_json_path, 'r', encoding='utf-8') as file:
        ds = json.load(file)
    return ds[:100] + ds[1000:1100] + ds[1900:2000]


def _build_mquake_edits(ds, stopwords):
    """Build edit keys and new objects from MQuAKE ``requested_rewrite`` entries.

    Each key is the subject followed by the edit prompt's non-stopword tokens
    (with the ``{}`` subject placeholder removed), all lower-cased.

    Returns:
        (requested_edits, o_star): parallel lists of edit keys and new objects.
    """
    requested_edits = []
    o_star = []
    for c in tqdm(ds, desc='Loading Edits...'):
        for edit in c['requested_rewrite']:
            o_star.append(edit['target_new']['str'])

            prompt = edit['prompt'].lower().replace('{}', '')
            words = [w.lower() for w in nltk.tokenize.word_tokenize(prompt) if w.lower() not in stopwords]
            edit_prompt = ' '.join([edit['subject'], ' '.join(words)])
            requested_edits.append(edit_prompt.lower())
    return requested_edits, o_star


def _build_kebench_edits(ds_json_path, stopwords):
    """Build edit keys and new objects from the KE-Bench edited-knowledge file.

    The edits file path is derived from ``ds_json_path`` by replacing
    ``multi_hop_knowledge`` with ``edited_knowledge``. Note that all edits in
    that file are loaded, independent of the evaluation subset.

    Returns:
        (requested_edits, o_star): parallel lists of edit keys and new objects.
    """
    edits_path = ds_json_path.replace('multi_hop_knowledge', 'edited_knowledge')
    with open(edits_path, 'r', encoding='utf-8') as file:
        kb_edits = json.load(file)

    requested_edits = []
    o_star = []
    for c in kb_edits:
        o = c['new_first_answer']
        o_star.append(o)

        # Drop '?' and the new answer from the question, then split on spaces.
        words = c['first_question'].replace('?', '').replace(o, '').split(' ')
        # Heuristic entity: tokens containing an uppercase letter that are not stopwords.
        entity = [w for w in words if any(ch.isupper() for ch in w) and w.lower() not in stopwords]
        edit_prompt = ' '.join(w.lower() for w in words if w.lower() not in stopwords)
        if entity:
            entity_str = ' '.join(entity)
            # Move the entity to the front of the key.
            edit_prompt = edit_prompt.replace(entity_str.lower(), '')
            edit_prompt = ' '.join([entity_str, edit_prompt]).replace('  ', ' ')
        requested_edits.append(edit_prompt.lower())
    return requested_edits, o_star


def _case_spec(c, ds_name):
    """Return (num_hops, num_edits, questions, true_answer, answer_aliases) for one case.

    Hop and edit counts are strings because they index the per-hop/per-edit
    counter dicts.
    """
    if ds_name == 'mquake':
        return (f"{len(c['single_hops'])}", f"{len(c['requested_rewrite'])}",
                c['questions'], c['new_answer'], c['new_answer_alias'])
    if ds_name == 'kebench':
        # This script treats every KE-Bench case as 2 hops, 1 edit, 1 question.
        return ('2', '1', [c['two_hop_question']],
                c['new_two_hop_answer'], c['new_two_hop_answer_aliases'])
    raise Exception('Ripple Edit not implemented yet.')


def _evaluate(editor, ner_model, ds, ds_name, k, tau, offset, desc):
    """Answer every question in ``ds`` with one hyperparameter setting and tally metrics.

    Returns:
        dict with overall counts (``cases``, ``questions``, ``q_acc``, ``i_acc``)
        and per-hop / per-edit-count breakdowns keyed by '1'..'4'.
    """
    keys = ('1', '2', '3', '4')
    stats = {
        'cases': 0,
        'questions': 0,
        'q_acc': 0,
        'i_acc': 0,
        'total_hops': dict.fromkeys(keys, 0),
        'total_edits': dict.fromkeys(keys, 0),
        'q_acc_hop': dict.fromkeys(keys, 0),
        'q_acc_edit': dict.fromkeys(keys, 0),
        'c_acc_hop': dict.fromkeys(keys, 0),
        'c_acc_edit': dict.fromkeys(keys, 0),
        'seconds_hop': dict.fromkeys(keys, 0),
        'seconds_edit': dict.fromkeys(keys, 0),
        'chain_hop': {key: {'under': 0, 'correct': 0, 'over': 0} for key in keys},
        'chain_edit': {key: {'under': 0, 'correct': 0, 'over': 0} for key in keys},
    }

    for c in tqdm(ds, desc=desc):
        num_hops, num_edits, questions, true_answer, answer_aliases = _case_spec(c, ds_name)
        stats['cases'] += 1
        case_correct = False

        for q in questions:
            start = time()
            final_answer, chain_length = editor.answer_question(q, ner_model, tau=tau, offset=offset, k=k)
            run_time = time() - start

            # Count questions actually asked (MQuAKE has several per case, KE-Bench one).
            stats['questions'] += 1
            stats['total_hops'][num_hops] += 1
            stats['total_edits'][num_edits] += 1
            stats['seconds_hop'][num_hops] += run_time
            stats['seconds_edit'][num_edits] += run_time

            # Compare the produced reasoning-chain length with the gold hop count.
            expected_len = int(num_hops)
            if chain_length < expected_len:
                bucket = 'under'
            elif chain_length > expected_len:
                bucket = 'over'
            else:
                bucket = 'correct'
            stats['chain_hop'][num_hops][bucket] += 1
            stats['chain_edit'][num_edits][bucket] += 1

            if final_answer == true_answer or final_answer in answer_aliases:
                stats['q_acc'] += 1
                stats['q_acc_hop'][num_hops] += 1
                stats['q_acc_edit'][num_edits] += 1

                # A case counts as correct once any of its questions is answered correctly.
                if not case_correct:
                    case_correct = True
                    stats['i_acc'] += 1
                    stats['c_acc_hop'][num_hops] += 1
                    stats['c_acc_edit'][num_edits] += 1

    return stats


def _pct(num, den):
    """Format ``num / den`` as a percentage with two decimals, or 'n/a' when ``den`` is 0."""
    return f'{num / den * 100:.2f}' if den else 'n/a'


def _print_report(stats, k, tau, offset):
    """Print the metrics collected by ``_evaluate`` for one hyperparameter setting."""
    total = stats['cases']
    n_questions = stats['questions']
    i_acc = stats['i_acc']
    q_acc = stats['q_acc']

    print(f'K: {k}, Tau: {tau}, Tau Offset: {offset}\n')
    print(f'\nDataset Size     : {total} Cases and {n_questions} Questions')

    print(f'Case Accuracy    : {i_acc} / {total} or {_pct(i_acc, total)}')
    print(f'Question Accuracy: {q_acc} / {n_questions} or {_pct(q_acc, n_questions)}\n')

    print(f"Total Number of Hops : {stats['total_hops']}")
    print(f"Total Number of Edits: {stats['total_edits']}\n")

    print(f"Per-Question Hop Accuracy : {stats['q_acc_hop']}")
    print(f"Per-Question Edit Accuracy: {stats['q_acc_edit']}\n")

    print(f"Per-Case Hop Accuracy : {stats['c_acc_hop']}")
    print(f"Per-Case Edit Accuracy: {stats['c_acc_edit']}\n")

    print(f"Total Seconds per Hop : {stats['seconds_hop']}")
    print(f"Total Seconds per Edit: {stats['seconds_edit']}\n")

    print(f"Chain Correctness per Hop : {stats['chain_hop']}")
    print(f"Chain Correctness per Edit: {stats['chain_edit']}\n")


def main(args):
    """Load all models, register the dataset's edits, and grid-search the editor.

    For every (k, tau, offset) combination the evaluation subset is answered
    with ``Editor.answer_question`` and the metrics are printed.

    Raises:
        Exception: for the not-yet-implemented ``rippleedit`` dataset.
        ValueError: for any other unknown ``args.ds_name``.
    """
    llm = AutoModelForCausalLM.from_pretrained(args.model_name, low_cpu_mem_usage=False, device_map=args.device)
    llm_tok = AutoTokenizer.from_pretrained(args.model_name, use_fast=False)
    llm_tok.pad_token = llm_tok.eos_token
    llm.generation_config.pad_token_id = llm_tok.pad_token_id

    embedder = SentenceTransformer('Qwen/Qwen3-Embedding-8B', device=args.device)
    ranker = CrossEncoder("tomaarsen/reranker-ModernBERT-large-gooaq-bce", device=args.device)

    bert_tok = BertTokenizerFast.from_pretrained('bert-base-cased', use_fast=False, model_max_length=512)
    bert = BertForTokenClassification.from_pretrained(args.bert_pt_path, num_labels=5).to(args.device)

    ner_model = Refined.from_pretrained(model_name='wikipedia_model_with_numbers', entity_set='wikipedia')

    ds = _load_eval_subset(args.ds_json_path)

    with open('./prompts/triple_completion.txt', 'r', encoding='utf-8') as file:
        triple_compl = file.read()

    editor = Editor(llm=llm,
                    llm_tok=llm_tok,
                    embedding_model=embedder,
                    reranker=ranker,
                    relation_extractor=bert,
                    relation_tokenizer=bert_tok,
                    triple_compl=triple_compl,
                    device=args.device)

    stopwords = set(nltk_sw.words('english'))
    stopwords.add('name')

    if args.ds_name == 'mquake':
        requested_edits, o_star = _build_mquake_edits(ds, stopwords)
    elif args.ds_name == 'kebench':
        requested_edits, o_star = _build_kebench_edits(args.ds_json_path, stopwords)
    elif args.ds_name == 'rippleedit':
        raise Exception('Not yet implemented: RippleEdit dataset.')
    else:
        raise ValueError(f'Unknown dataset: {args.ds_name}')

    editor.add_edits(edits=requested_edits, answers=o_star)

    # Hyperparameter grid (same points as the former torch.arange grids), kept as
    # plain Python numbers so logs read "0.3" rather than "tensor(0.3000)".
    k_grid = list(range(5, 21, 5))              # 5, 10, 15, 20
    tau_grid = [i / 10 for i in range(11)]      # 0.0, 0.1, ..., 1.0
    offset_grid = [i / 20 for i in range(6)]    # 0.0, 0.05, ..., 0.25

    desc = f'Editor with {args.model_name} on {args.ds_json_path}'
    for k in tqdm(k_grid, desc='K'):
        for tau in tqdm(tau_grid, desc='Tau'):
            for offset in tqdm(offset_grid, desc='Tau Offset'):
                stats = _evaluate(editor, ner_model, ds, args.ds_name, k, tau, offset, desc)
                _print_report(stats, k, tau, offset)
    

if __name__ == '__main__':
    import warnings
    from huggingface_hub import login

    warnings.filterwarnings('ignore')

    args = parse_args()
    # --hf_token defaults to ''. Only call login() when a token was actually
    # supplied, instead of handing huggingface_hub an empty token.
    if args.hf_token:
        login(args.hf_token)

    main(args)
