import os
import json
import copy
from datasets import Dataset
from transformers import AutoModelForCausalLM, set_seed, AutoTokenizer
import numpy

def _keep_last_tokens(text, max_length, tokenizer):
    """Round-trip *text* through *tokenizer*, keeping at most its last *max_length* tokens."""
    token_ids = tokenizer(text, return_tensors="pt").input_ids[0]
    n_tokens = token_ids.shape[0]
    if n_tokens > max_length:
        # Left-truncate so the END of the text survives. The slice runs to the
        # very last token: special tokens are removed by
        # ``skip_special_tokens=True`` below, so there is nothing to trim at
        # the tail. Using an explicit start index also behaves sensibly for
        # ``max_length == 0`` (``[-0:]`` would keep everything).
        token_ids = token_ids[n_tokens - max_length:]
    return tokenizer.decode(token_ids.tolist(), skip_special_tokens=True)


def get_standard_input(prompt, max_length, tokenizer=None):
    """Normalise *prompt* and clip it to at most *max_length* tokens.

    Newlines are replaced by spaces. If the text contains the marker
    ``'Context:'``, everything before the first marker is kept verbatim and
    only the part after it is tokenised, clipped and decoded; the result is
    re-assembled as ``<before> + ' Context: ' + <clipped context>``.
    Otherwise the whole text is tokenised, clipped and decoded.

    Clipping keeps the *last* ``max_length`` tokens. Note that the text is
    always passed through ``tokenizer`` encode/decode, even when it is short
    enough not to be clipped.

    Args:
        prompt: Text to normalise.
        max_length: Maximum number of tokens to keep.
        tokenizer: Callable tokenizer returning an object with ``input_ids``
            (indexable by batch) and offering ``decode(ids,
            skip_special_tokens=...)``. Required despite the ``None`` default.

    Returns:
        The normalised, possibly clipped, text.

    Raises:
        TypeError: If ``tokenizer`` is ``None``.
    """
    if tokenizer is None:
        raise TypeError("get_standard_input() requires a tokenizer, got None")

    prompt = prompt.replace('\n', ' ')
    # partition() splits on the FIRST marker only, so a later 'Context:'
    # occurring inside the passage is kept instead of being dropped.
    inst, marker, context = prompt.partition('Context:')
    if marker:
        return inst + ' Context: ' + _keep_last_tokens(context, max_length, tokenizer)
    return _keep_last_tokens(prompt, max_length, tokenizer)

def get_nq_dataset(data_path, split='train', tokenizer=None, size=2000, sft=False,
                   sft_c=False, sft_i=False, ctx=False, idk=False, sft_idk=False):
    """Build a dataset from ``<data_path>nq_<split>.jsonl``.

    Each non-blank line must be a JSON object with the keys ``query``,
    ``counter``, ``idk_context``, ``c_t``, ``i_t``, ``c_ans_long`` and
    ``wiki_ans_long``. ``c_t`` is a format string with ``{Question}`` and
    ``{Context}`` placeholders; ``i_t`` is one with ``{Question}`` only.
    Both context passages are passed through ``get_standard_input`` with a
    1400-token limit before being inserted into the template.

    Every record yields three examples, in this fixed order:

    0. ``c_t`` with the ``counter`` passage -> chosen ``c_ans_long``,
       rejected ``wiki_ans_long``
    1. ``c_t`` with the ``idk_context`` passage -> chosen the fixed
       "I can not answer ..." sentence, rejected ``c_ans_long``
    2. ``i_t`` (no context) -> chosen ``wiki_ans_long``, rejected the fixed
       "I can not answer ..." sentence

    Prompts are wrapped as ``'[INST] ... [/INST] '`` and answers get a
    trailing ``' </s>'``.

    Args:
        data_path: Prefix the file name is appended to (plain string
            concatenation, so a directory needs its trailing separator).
        split: Split name used in the file name.
        tokenizer: Tokenizer forwarded to ``get_standard_input``.
        size: Number of records to read; reading stops once this many have
            been consumed.
        sft: Return every example as ``prompt + chosen`` text.
        sft_c: Return only examples of kind 0 as text.
        sft_i: Return only examples of kind 2 as text.
        ctx: Return examples of kind 0 and 1 as text.
        idk: Accepted for backward compatibility; currently has no effect.
        sft_idk: Return only examples of kind 1 as text.

    When several flags are truthy the first in the order ``sft``, ``sft_c``,
    ``sft_i``, ``sft_idk``, ``ctx`` wins.

    Returns:
        A ``Dataset`` with a single ``"text"`` column if a mode flag is set;
        otherwise a ``Dataset`` with ``"prompt"``, ``"chosen"`` and
        ``"rejected"`` columns (same default as ``get_trivia_dataset``).
    """
    idk_answer = 'I can not answer the question with the context information.'
    prompt = []
    chosen = []
    rejected = []

    n_records = 0
    # JSON text is UTF-8; do not depend on the platform default encoding.
    with open(data_path + 'nq_' + split + '.jsonl', 'r', encoding='utf-8') as f:
        for line in f:
            if n_records == size:
                break
            if not line.strip():
                # Tolerate blank lines instead of crashing in json.loads.
                continue
            item = json.loads(line)
            query = item['query']

            c_context = get_standard_input(item['counter'], 1400, tokenizer)
            c_prompt = item['c_t'].format(Question=query, Context=c_context)

            idk_context = get_standard_input(item['idk_context'], 1400, tokenizer)
            idk_prompt = item['c_t'].format(Question=query, Context=idk_context)

            i_prompt = item['i_t'].format(Question=query)

            c_answer = item['c_ans_long']
            wiki_answer = item['wiki_ans_long']

            # The order of these three triples defines the "kind" index
            # (position % 3) that the mode flags below select on.
            for text, good, bad in (
                (c_prompt, c_answer, wiki_answer),
                (idk_prompt, idk_answer, c_answer),
                (i_prompt, wiki_answer, idk_answer),
            ):
                prompt.append(f'[INST] {text} [/INST] ')
                chosen.append(good + ' </s>')
                rejected.append(bad + ' </s>')

            n_records += 1

    # Checked in order; the first truthy flag decides which kinds are kept.
    sft_modes = (
        (sft, (0, 1, 2)),
        (sft_c, (0,)),
        (sft_i, (2,)),
        (sft_idk, (1,)),
        (ctx, (0, 1)),
    )
    for enabled, kinds in sft_modes:
        if enabled:
            texts = [
                p + c
                for pos, (p, c) in enumerate(zip(prompt, chosen))
                if pos % 3 in kinds
            ]
            return Dataset.from_dict({"text": texts})

    # No mode flag: return the preference triples rather than falling off the
    # end of the function and returning None.
    return Dataset.from_dict({"prompt": prompt, "chosen": chosen, "rejected": rejected})

def get_trivia_dataset(data_path, split='train', tokenizer=None, size=2000, sft=False,
                       sft_c=False, sft_i=False, ctx=False, sft_idk=False):
    """Build a dataset from ``<data_path>trivia_<split>.jsonl``.

    Each non-blank line must be a JSON object with the keys ``question``,
    ``c_context``, ``idk_context``, ``c_t``, ``i_t``, ``c_ans_l`` and
    ``ans_l``. ``c_t`` is a format string with ``{Question}`` and
    ``{Context}`` placeholders; ``i_t`` is one with ``{Question}`` only.
    Both context passages are passed through ``get_standard_input`` with a
    1400-token limit before being inserted into the template.

    Every record yields three examples, in this fixed order:

    0. ``c_t`` with ``c_context`` -> chosen ``c_ans_l``, rejected ``ans_l``
    1. ``c_t`` with ``idk_context`` -> chosen the fixed
       "I can not answer ..." sentence, rejected ``c_ans_l``
    2. ``i_t`` (no context) -> chosen ``ans_l``, rejected the fixed
       "I can not answer ..." sentence

    Prompts are wrapped as ``'[INST] ... [/INST] '`` and answers get a
    trailing ``' </s>'``.

    Args:
        data_path: Prefix the file name is appended to (plain string
            concatenation, so a directory needs its trailing separator).
        split: Split name used in the file name.
        tokenizer: Tokenizer forwarded to ``get_standard_input``.
        size: Number of records to read; reading stops once this many have
            been consumed.
        sft: Return every example as ``prompt + chosen`` text.
        sft_c: Return only examples of kind 0 as text.
        sft_i: Return only examples of kind 2 as text.
        ctx: Return examples of kind 0 and 1 as text.
        sft_idk: Return only examples of kind 1 as text.

    When several flags are truthy the first in the order ``sft``, ``sft_c``,
    ``sft_i``, ``sft_idk``, ``ctx`` wins.

    Returns:
        A ``Dataset`` with a single ``"text"`` column if a mode flag is set;
        otherwise a ``Dataset`` with ``"prompt"``, ``"chosen"`` and
        ``"rejected"`` columns.
    """
    idk_answer = 'I can not answer the question with the context information.'
    prompt = []
    chosen = []
    rejected = []

    n_records = 0
    # JSON text is UTF-8; do not depend on the platform default encoding.
    with open(data_path + 'trivia_' + split + '.jsonl', 'r', encoding='utf-8') as f:
        for line in f:
            if n_records == size:
                break
            if not line.strip():
                # Tolerate blank lines instead of crashing in json.loads.
                continue
            item = json.loads(line)
            query = item['question']

            c_context = get_standard_input(item['c_context'], 1400, tokenizer)
            c_prompt = item['c_t'].format(Question=query, Context=c_context)

            idk_context = get_standard_input(item['idk_context'], 1400, tokenizer)
            idk_prompt = item['c_t'].format(Question=query, Context=idk_context)

            i_prompt = item['i_t'].format(Question=query)

            c_answer = item['c_ans_l']
            wiki_answer = item['ans_l']

            # The order of these three triples defines the "kind" index
            # (position % 3) that the mode flags below select on.
            for text, good, bad in (
                (c_prompt, c_answer, wiki_answer),
                (idk_prompt, idk_answer, c_answer),
                (i_prompt, wiki_answer, idk_answer),
            ):
                prompt.append(f'[INST] {text} [/INST] ')
                chosen.append(good + ' </s>')
                rejected.append(bad + ' </s>')

            n_records += 1

    # Checked in order; the first truthy flag decides which kinds are kept.
    sft_modes = (
        (sft, (0, 1, 2)),
        (sft_c, (0,)),
        (sft_i, (2,)),
        (sft_idk, (1,)),
        (ctx, (0, 1)),
    )
    for enabled, kinds in sft_modes:
        if enabled:
            texts = [
                p + c
                for pos, (p, c) in enumerate(zip(prompt, chosen))
                if pos % 3 in kinds
            ]
            return Dataset.from_dict({"text": texts})

    return Dataset.from_dict({"prompt": prompt, "chosen": chosen, "rejected": rejected})

def get_mt_dataset(data_path, split='train', tokenizer=None, size=2000, sft=False,
                   sft_c=False, sft_i=False, lang='ru'):
    """Build a dataset from ``<data_path>en_<lang>_<split>_final.jsonl``.

    Each non-blank line must be a JSON object with the keys ``en``,
    ``c_passage``, ``c_t``, ``i_t``, ``correct``, ``incorrect`` and one key
    named after ``lang``. ``c_t`` is a format string with ``{Question}`` and
    ``{Context}`` placeholders; ``i_t`` is one with ``{Question}`` only.

    Two answers are derived per record: the text under ``lang`` as stored,
    and a copy of it in which every occurrence of ``correct`` is replaced by
    ``incorrect``. Every record yields two examples, in this fixed order:

    0. ``c_t`` with ``c_passage`` -> chosen the replaced text,
       rejected the stored text
    1. ``i_t`` (no context) -> chosen the stored text,
       rejected the replaced text

    Prompts are wrapped as ``'[INST] ... [/INST] '`` and answers get a
    trailing ``' </s>'``. Unlike the QA loaders, prompts are not
    length-limited here.

    Args:
        data_path: Prefix the file name is appended to (plain string
            concatenation, so a directory needs its trailing separator).
        split: Split name used in the file name.
        tokenizer: Unused; kept so the signature matches the other loaders.
        size: Number of records to read; reading stops once this many have
            been consumed.
        sft: Return every example as ``prompt + chosen`` text.
        sft_c: Return only examples of kind 0 as text.
        sft_i: Return only examples of kind 1 as text.
        lang: Target-language code; selects both the file and the record key
            holding the translation.

    When several flags are truthy the first in the order ``sft``, ``sft_c``,
    ``sft_i`` wins.

    Returns:
        A ``Dataset`` with a single ``"text"`` column if a mode flag is set;
        otherwise a ``Dataset`` with ``"prompt"``, ``"chosen"`` and
        ``"rejected"`` columns.
    """
    prompt = []
    chosen = []
    rejected = []

    n_records = 0
    # Explicit UTF-8: the translations are non-ASCII and must not depend on
    # the platform default encoding.
    path = data_path + 'en_' + lang + '_' + split + '_final.jsonl'
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            if n_records == size:
                break
            if not line.strip():
                # Tolerate blank lines instead of crashing in json.loads.
                continue
            item = json.loads(line)
            query = item['en']

            c_prompt = item['c_t'].format(Question=query, Context=item['c_passage'])
            i_prompt = item['i_t'].format(Question=query)

            ac_answer = item[lang]
            counter_answer = ac_answer.replace(item['correct'], item['incorrect'])

            # "rejected" used to stay empty, which made the three-column
            # Dataset below impossible to build (unequal column lengths).
            # NOTE(review): pairing each prompt with the opposite translation
            # mirrors the QA loaders in this file — confirm this is the
            # intended preference signal.
            for text, good, bad in (
                (c_prompt, counter_answer, ac_answer),
                (i_prompt, ac_answer, counter_answer),
            ):
                prompt.append(f'[INST] {text} [/INST] ')
                chosen.append(good + ' </s>')
                rejected.append(bad + ' </s>')

            n_records += 1

    # Checked in order; the first truthy flag decides which kinds are kept.
    sft_modes = (
        (sft, (0, 1)),
        (sft_c, (0,)),
        (sft_i, (1,)),
    )
    for enabled, kinds in sft_modes:
        if enabled:
            texts = [
                p + c
                for pos, (p, c) in enumerate(zip(prompt, chosen))
                if pos % 2 in kinds
            ]
            return Dataset.from_dict({"text": texts})

    return Dataset.from_dict({"prompt": prompt, "chosen": chosen, "rejected": rejected})

def get_mc_dataset(data_path, name='nq_', split='train', tokenizer=None, size=2000,
                   sft=False, sft_c=False, sft_i=False, sft_idk=False, ctx=False):
    """Build a multiple-choice dataset from ``<data_path><name><split>.jsonl``.

    Each non-blank line must be a JSON object with the keys ``question``,
    ``options``, ``answer``, ``incor1``, ``incor2``, ``incor3``, ``idk_ans``,
    ``c_context``, ``idk_context``, ``c_t`` and ``i_t``. ``options`` holds
    the letters ``A``-``E``: position ``k`` gives the letter under which the
    ``k``-th entry of ``[answer, incor1, incor2, incor3, idk_ans]`` is shown.
    The lettered options are appended to the question in ``A``..``E`` order.

    ``c_t`` is a format string with ``{Question}`` and ``{Context}``
    placeholders; ``i_t`` is one with ``{Question}`` only. Both context
    passages are passed through ``get_standard_input`` with a 1400-token
    limit before being inserted into the template.

    Every record yields three examples, in this fixed order:

    0. ``c_t`` with ``c_context`` -> chosen ``options[1]``
       (the letter of ``incor1``)
    1. ``c_t`` with ``idk_context`` -> chosen ``options[-1]``
       (the letter of ``idk_ans``)
    2. ``i_t`` (no context) -> chosen ``options[0]``
       (the letter of ``answer``)

    Prompts are wrapped as ``'[INST] ... [/INST] '`` and answers get a
    trailing ``' </s>'``.

    Args:
        data_path: Prefix the file name is appended to (plain string
            concatenation, so a directory needs its trailing separator).
        name: Dataset prefix used in the file name, e.g. ``'nq_'``.
        split: Split name used in the file name.
        tokenizer: Tokenizer forwarded to ``get_standard_input``.
        size: Number of records to read; reading stops once this many have
            been consumed.
        sft: Return every example as ``prompt + chosen`` text.
        sft_c: Return only examples of kind 0 as text.
        sft_i: Return only examples of kind 2 as text.
        sft_idk: Return only examples of kind 1 as text.
        ctx: Return examples of kind 0 and 1 as text.

    When several flags are truthy the first in the order ``sft``, ``sft_c``,
    ``sft_i``, ``sft_idk``, ``ctx`` wins.

    Returns:
        A ``Dataset`` with a single ``"text"`` column if a mode flag is set;
        otherwise a ``Dataset`` with ``"prompt"`` and ``"chosen"`` columns.
    """
    prompt = []
    chosen = []

    n_records = 0
    # JSON text is UTF-8; do not depend on the platform default encoding.
    with open(data_path + name + split + '.jsonl', 'r', encoding='utf-8') as f:
        for line in f:
            if n_records == size:
                break
            if not line.strip():
                # Tolerate blank lines instead of crashing in json.loads.
                continue
            item = json.loads(line)

            options = item['options']
            answers = [item['answer'], item['incor1'], item['incor2'],
                       item['incor3'], item['idk_ans']]
            # options.index(letter) is the position of the answer that is
            # displayed under that letter.
            option_sent = ''.join(
                ' ' + letter + '. ' + answers[options.index(letter)]
                for letter in 'ABCDE'
            )
            query = item['question'] + option_sent

            c_context = get_standard_input(item['c_context'], 1400, tokenizer)
            c_prompt = item['c_t'].format(Question=query, Context=c_context)

            idk_context = get_standard_input(item['idk_context'], 1400, tokenizer)
            idk_prompt = item['c_t'].format(Question=query, Context=idk_context)

            i_prompt = item['i_t'].format(Question=query)

            # The order of these three pairs defines the "kind" index
            # (position % 3) that the mode flags below select on.
            for text, letter in (
                (c_prompt, options[1]),
                (idk_prompt, options[-1]),
                (i_prompt, options[0]),
            ):
                prompt.append(f'[INST] {text} [/INST] ')
                chosen.append(letter + ' </s>')

            n_records += 1

    # Checked in order; the first truthy flag decides which kinds are kept.
    sft_modes = (
        (sft, (0, 1, 2)),
        (sft_c, (0,)),
        (sft_i, (2,)),
        (sft_idk, (1,)),
        (ctx, (0, 1)),
    )
    for enabled, kinds in sft_modes:
        if enabled:
            texts = [
                p + c
                for pos, (p, c) in enumerate(zip(prompt, chosen))
                if pos % 3 in kinds
            ]
            return Dataset.from_dict({"text": texts})

    # No mode flag: previously this hit ``return data`` with ``data`` never
    # assigned (UnboundLocalError). Return the raw pairs instead.
    return Dataset.from_dict({"prompt": prompt, "chosen": chosen})

# model_name_or_path = 'mistralai/Mistral-7B-Instruct-v0.1'
# tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

# read openbookqa data
# data = get_nq_dataset('../openbookqa_data/',split='train',tokenizer=tokenizer,size=1000,sft=True)
# data = get_trivia_dataset('../openbookqa_data/',split='train',tokenizer=tokenizer,size=1000,sft=True)

# read multi-choice data
# data = get_mc_dataset('../mc_data/',name='nq_',split='train',tokenizer=tokenizer,size=1000,sft=True)
# data = get_mc_dataset('../mc_data/',name='trivia_',split='train',tokenizer=tokenizer,size=1000,sft=True)


# read machine translation data
# data = get_mt_dataset('../mt_data/',split='train',tokenizer=tokenizer,size=1000,sft=True,lang='ru')
# data = get_mt_dataset('../mt_data/',split='train',tokenizer=tokenizer,size=1000,sft=True,lang='zh')

# print(data[0])