import torch
from datasets import load_dataset
from refined.inference.processor import Refined
from sentence_transformers import SentenceTransformer, CrossEncoder
from transformers import BertTokenizerFast, BertForTokenClassification, AutoModelForCausalLM, AutoTokenizer

import json
import argparse
from time import time
from tqdm import tqdm

import nltk
from nltk.corpus import stopwords as nltk_sw
# Fetch the NLTK resources used later in this script: the English stopword list
# read via nltk_sw.words('english'), and the 'punkt_tab' tokenizer data used by
# nltk.tokenize.word_tokenize when normalizing MQuAKE edit prompts.
# These calls run at import time, before any argument parsing.
nltk.download('stopwords')
nltk.download('punkt_tab')

from editor import Editor

def parse_args():
    """Build the command-line interface, parse ``sys.argv`` and resolve derived settings.

    The parsed namespace is passed through ``set_names`` before it is
    returned, so callers receive an extra ``device`` attribute and a
    ``model_name`` expanded to its full Hugging Face Hub id.
    """
    model_aliases = ['qwen2.5-7b', 'qwen2.5-1.5b', 'gptj', 'vicuna-7b', 'falcon-7b', 'mistral-7b-v0.3', 'llama3.1-8b']
    dataset_names = ['mquake', 'rippleedit', 'kebench', 'mquake-remastered']

    cli = argparse.ArgumentParser()
    # Short alias of the causal LM; set_names maps it to a Hub id.
    cli.add_argument('--model_name', type=str, default='vicuna-7b', choices=model_aliases)
    # Checkpoint for the BERT token-classification model used as the relation extractor.
    cli.add_argument('--bert_pt_path', type=str, default='./r_extractor/mquake_extractor.pt')
    cli.add_argument('--ds_name', type=str, default='mquake', choices=dataset_names)
    cli.add_argument('--ds_json_path', type=str, default='./data/MQuAKE/MQuAKE-CF-3k.json')
    # Forwarded to Editor.answer_question as tau / offset / k.
    cli.add_argument('--tau', type=float, default=0.8)
    cli.add_argument('--tau_offset', type=float, default=0.1)
    cli.add_argument('--k', type=int, default=10)
    # set_names turns these into 'cuda:<gpu_num>' or 'cpu'.
    cli.add_argument('--use_gpu', action='store_true')
    cli.add_argument('--gpu_num', type=int, default=0)
    # Hugging Face access token used for huggingface_hub.login in the entry point.
    cli.add_argument('--hf_token', type=str, default='')

    return set_names(cli.parse_args())

def set_names(args):
    """Fill in derived settings on the parsed CLI namespace (mutated in place).

    Sets ``args.device`` to ``'cuda:<gpu_num>'`` when ``--use_gpu`` was given,
    otherwise ``'cpu'``. Replaces a known short model alias in
    ``args.model_name`` with its Hugging Face Hub id. Any other value is left
    untouched.

    Returns:
        The same namespace object that was passed in.
    """
    hub_ids = {
        'qwen2.5-7b': 'Qwen/Qwen2.5-7B',
        'qwen2.5-1.5b': 'Qwen/Qwen2.5-1.5B',
        'gptj': 'EleutherAI/gpt-j-6B',
        'falcon-7b': 'tiiuae/falcon-7b',
        'vicuna-7b': 'lmsys/vicuna-7b-v1.5',
        'mistral-7b-v0.3': 'mistralai/Mistral-7B-v0.3',
        'llama3.1-8b': 'meta-llama/Llama-3.1-8B',
    }

    if args.use_gpu:
        args.device = f'cuda:{args.gpu_num}'
    else:
        args.device = 'cpu'

    # Unknown names fall through unchanged.
    args.model_name = hub_ids.get(args.model_name, args.model_name)
    return args

def _load_cases(args):
    """Load the evaluation cases selected by ``args.ds_name``.

    For ``mquake-remastered`` the cases come from split ``CF6334`` of the Hub
    dataset ``henryzhongsc/MQuAKE-Remastered``, restricted to indices
    2000-2999, 5000-5999 and 8000-8999 (in that order). Every other dataset
    is read with ``json.load`` from ``args.ds_json_path``.
    """
    if 'remastered' in args.ds_name:
        hub_ds = load_dataset('henryzhongsc/MQuAKE-Remastered', split='CF6334')
        # NOTE(review): the reason for these three 1000-case windows is not
        # visible here; document it where the subset is defined.
        windows = (range(2000, 3000), range(5000, 6000), range(8000, 9000))
        return [hub_ds[i] for window in windows for i in window]

    with open(args.ds_json_path, 'r') as file:
        return json.load(file)


def _mquake_edit_key(edit, stopwords):
    """Build the lower-cased retrieval key for one MQuAKE ``requested_rewrite`` entry.

    The ``{}`` subject placeholder is removed from the prompt. The rest is
    tokenized with NLTK and stopwords are dropped. The subject string is then
    prepended.
    """
    prompt = edit['prompt'].lower().replace('{}', '')
    tokens = [w.lower() for w in nltk.tokenize.word_tokenize(prompt) if w.lower() not in stopwords]
    return ' '.join([edit['subject'], ' '.join(tokens)]).lower()


def _kebench_edit_key(question, answer, stopwords):
    """Build the lower-cased retrieval key for one KE-Bench edited fact.

    Question marks and the (new) answer text are stripped from the question.
    Stopwords are removed. Words containing an uppercase letter are treated
    as the subject entity and moved to the front of the key.
    """
    words = question.replace('?', '').replace(answer, '').split(' ')
    entity = [w for w in words if any(ch.isupper() for ch in w)]
    entity = [w for w in entity if w.lower() not in stopwords]
    key = ' '.join(w.lower() for w in words if w.lower() not in stopwords)
    if entity:
        entity = ' '.join(entity)
        key = key.replace(entity.lower(), '')
        key = ' '.join([entity, key]).replace('  ', ' ')
    return key.lower()


def _build_requested_edits(args, ds, stopwords):
    """Return ``(requested_edits, o_star)``: one retrieval key and one new answer per edit.

    Raises:
        Exception: for ``rippleedit``, which is not implemented.
        ValueError: for any other unsupported ``args.ds_name``.
    """
    requested_edits = []
    o_star = []

    if 'mquake' in args.ds_name:
        for c in tqdm(ds, desc='Loading Edits...'):
            for edit in c['requested_rewrite']:
                # The Remastered release stores the new target as a flat string field.
                if 'remastered' in args.ds_name:
                    o_star.append(edit['target_new_str'])
                else:
                    o_star.append(edit['target_new']['str'])
                requested_edits.append(_mquake_edit_key(edit, stopwords))

    elif args.ds_name == 'rippleedit':
        raise Exception('Not yet implemented: RippleEdit dataset.')

    elif args.ds_name == 'kebench':
        # Edited facts live in a sibling file of the multi-hop question file.
        edits_path = args.ds_json_path.replace('multi_hop_knowledge', 'edited_knowledge')
        with open(edits_path, 'r') as file:
            kb_edits = json.load(file)

        for c in kb_edits:
            o = c['new_first_answer']
            o_star.append(o)
            requested_edits.append(_kebench_edit_key(c['first_question'], o, stopwords))

    else:
        raise ValueError(f'Unsupported dataset: {args.ds_name}')

    return requested_edits, o_star


def _case_setup(args, c):
    """Return ``(num_hops, num_edits, questions, true_answer, answer_aliases)`` for one case.

    Hop and edit counts are strings so they can index the per-bucket
    counters in ``main``.
    """
    if 'mquake' in args.ds_name:
        return (f"{len(c['single_hops'])}",
                f"{len(c['requested_rewrite'])}",
                c['questions'],
                c['new_answer'],
                c['new_answer_alias'])

    if args.ds_name == 'rippleedit':
        raise Exception('Ripple Edit not implemented yet.')

    if args.ds_name == 'kebench':
        # KE-Bench cases are fixed at two hops / one edit, with a single question.
        return ('2', '1', [c['two_hop_question']],
                c['new_two_hop_answer'], c['new_two_hop_answer_aliases'])

    raise ValueError(f'Unsupported dataset: {args.ds_name}')


def _percent(part, whole):
    """Return ``part / whole * 100``, or 0.0 when ``whole`` is 0 (empty dataset)."""
    return part / whole * 100 if whole else 0.0


def main(args):
    """Load models and data, register the edits, answer every question and print metrics.

    The causal LM named by ``args.model_name`` is loaded together with a
    sentence embedder, a cross-encoder reranker, a BERT token-classification
    relation extractor and a ReFinED model (passed to
    ``Editor.answer_question`` as the NER model). One retrieval key per
    requested edit is registered with the ``Editor``. Every question of every
    case is then answered.

    A question is correct when the predicted answer equals the new answer or
    one of its aliases. A case is correct when at least one of its questions
    is. Results are reported overall and broken down by number of hops and
    number of edits, together with timing, chain-length correctness and
    token usage.
    """
    llm = AutoModelForCausalLM.from_pretrained(args.model_name, low_cpu_mem_usage=False, device_map=args.device)
    llm_tok = AutoTokenizer.from_pretrained(args.model_name, use_fast=False)
    llm_tok.pad_token = llm_tok.eos_token
    llm.generation_config.pad_token_id = llm_tok.pad_token_id

    embedder = SentenceTransformer('Qwen/Qwen3-Embedding-8B', device=args.device)
    ranker = CrossEncoder("tomaarsen/reranker-ModernBERT-large-gooaq-bce", device=args.device)

    bert_tok = BertTokenizerFast.from_pretrained('bert-base-cased', use_fast=False, model_max_length=512)
    bert = BertForTokenClassification.from_pretrained(args.bert_pt_path, num_labels=5).to(args.device)

    ner_model = Refined.from_pretrained(model_name='wikipedia_model_with_numbers', entity_set='wikipedia')

    ds = _load_cases(args)

    with open('./prompts/triple_completion.txt', 'r') as file:
        triple_compl = file.read()

    editor = Editor(llm=llm,
                    llm_tok=llm_tok,
                    embedding_model=embedder,
                    reranker=ranker,
                    relation_extractor=bert,
                    relation_tokenizer=bert_tok,
                    triple_compl=triple_compl,
                    device=args.device)

    stopwords = set(nltk_sw.words('english'))
    stopwords.add('name')

    requested_edits, o_star = _build_requested_edits(args, ds, stopwords)
    editor.add_edits(edits=requested_edits, answers=o_star)

    total = 0
    total_questions = 0  # counted explicitly: cases do not all have 3 questions
    q_acc = 0
    i_acc = 0

    bucket_keys = ('1', '2', '3', '4')

    total_hops = dict.fromkeys(bucket_keys, 0)
    total_edits = dict.fromkeys(bucket_keys, 0)

    q_acc_hop = dict.fromkeys(bucket_keys, 0)
    q_acc_edit = dict.fromkeys(bucket_keys, 0)

    c_acc_hop = dict.fromkeys(bucket_keys, 0)
    c_acc_edit = dict.fromkeys(bucket_keys, 0)

    seconds_hop = dict.fromkeys(bucket_keys, 0)
    seconds_edit = dict.fromkeys(bucket_keys, 0)

    chain_over_under_hop = {key: {'under': 0, 'correct': 0, 'over': 0} for key in bucket_keys}
    chain_over_under_edit = {key: {'under': 0, 'correct': 0, 'over': 0} for key in bucket_keys}

    # Token usage indexed [hops][edits]; a case cannot have more edits than hops.
    # The '1' hop bucket exists so 1-hop cases do not raise KeyError.
    in_tokens = {h: dict.fromkeys(bucket_keys[:int(h)], 0) for h in bucket_keys}
    out_tokens = {h: dict.fromkeys(bucket_keys[:int(h)], 0) for h in bucket_keys}

    if 'remastered' in args.ds_name:
        source = 'henryzhongsc/MQuAKE-Remastered:CF6334'
    else:
        source = args.ds_json_path

    for c in tqdm(ds, desc=f'Editor with {args.model_name} on {source}'):
        current_num_hops, current_num_edits, questions, true_answer, answer_aliases = _case_setup(args, c)
        num_hops = int(current_num_hops)

        case_correct = False
        total += 1

        for q in questions:
            total_questions += 1

            start = time()
            final_answer, chain_length = editor.answer_question(q, ner_model, tau=args.tau, offset=args.tau_offset, k=args.k)
            run_time = time() - start

            # Move the editor's per-question token counts into the bucket and reset them.
            in_tokens[current_num_hops][current_num_edits] += editor.in_tok
            out_tokens[current_num_hops][current_num_edits] += editor.out_tok
            editor.in_tok = 0
            editor.out_tok = 0

            total_hops[current_num_hops] += 1
            total_edits[current_num_edits] += 1

            seconds_hop[current_num_hops] += run_time
            seconds_edit[current_num_edits] += run_time

            # Compare the reasoning-chain length with the case's true hop count.
            if chain_length < num_hops:
                outcome = 'under'
            elif chain_length > num_hops:
                outcome = 'over'
            else:
                outcome = 'correct'
            chain_over_under_hop[current_num_hops][outcome] += 1
            chain_over_under_edit[current_num_edits][outcome] += 1

            if final_answer == true_answer or final_answer in answer_aliases:
                q_acc += 1

                q_acc_hop[current_num_hops] += 1
                q_acc_edit[current_num_edits] += 1

                # A case is credited once, on its first correctly answered question.
                if not case_correct:
                    case_correct = True
                    i_acc += 1

                    c_acc_hop[current_num_hops] += 1
                    c_acc_edit[current_num_edits] += 1

    print(f'\nDataset Size     : {len(ds)} Cases and {total_questions} Questions')

    print(f'Case Accuracy    : {i_acc} / {total} or {_percent(i_acc, total):.2f}')
    print(f'Question Accuracy: {q_acc} / {total_questions} or {_percent(q_acc, total_questions):.2f}\n')

    print(f'Total Number of Hops : {total_hops}')
    print(f'Total Number of Edits: {total_edits}\n')

    print(f'Per-Question Hop Accuracy : {q_acc_hop}')
    print(f'Per-Question Edit Accuracy: {q_acc_edit}\n')

    print(f'Per-Case Hop Accuracy : {c_acc_hop}')
    print(f'Per-Case Edit Accuracy: {c_acc_edit}\n')

    print(f'Total Seconds per Hop : {seconds_hop}')
    print(f'Total Seconds per Edit: {seconds_edit}\n')

    print(f'Chain Correctness per Hop : {chain_over_under_hop}')
    print(f'Chain Correctness per Edit: {chain_over_under_edit}\n')

    print(f'Number of Input Tokens Used : {in_tokens}')
    print(f'Number of Output Tokens Used: {out_tokens}')
    

if __name__ == '__main__':
    import warnings
    from huggingface_hub import login

    # Suppress all Python warnings for the whole run.
    warnings.filterwarnings('ignore')

    args = parse_args()

    # Only authenticate when a token was actually supplied. The CLI default is '',
    # and login() treats any non-None token as the credential to validate, so
    # passing '' would not fall back to cached credentials / HF_TOKEN.
    if args.hf_token:
        login(args.hf_token)

    main(args)
