import os
import re
import json
import torch
import argparse
import jsonlines
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

from general_helper_fns import save_pickle, load_pickle

import pdb

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint used to break sentences down into claims; loaded once at import
# time and shared by every function in this module.
model_name = 'meta-llama/Llama-2-7b-chat-hf'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to(device)

# The BOS token doubles as the padding token for batched inference.
tokenizer.pad_token = tokenizer.bos_token
# Prompts are batched and then continued with `generate`. A decoder-only model
# continues from the last position of every row, so the padding has to sit on
# the left; right padding makes shorter prompts continue from pad tokens.
tokenizer.padding_side = 'left'
model.config.pad_token_id = model.config.bos_token_id

def read_jsonl_file(file_path):
    """Load a JSON Lines file into a list of parsed records.

    Args:
        file_path: Path of the file to read. Every non-blank line must hold
            one JSON document.

    Returns:
        A list with one parsed object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # The encoding is pinned so parsing does not depend on the platform locale.
    with open(file_path, 'r', encoding='utf-8') as file:
        # Blank lines (for example a trailing empty line) carry no record;
        # handing them to json.loads would raise.
        return [json.loads(line) for line in file if line.strip()]

def split_paragraph(data):
    """Attach sentence-level annotations to every example, in place.

    Each example's ``'output'`` text is stripped and cut at whitespace that
    follows a ``.`` or ``?``. The two negative lookbehinds keep the split from
    firing right after sequences shaped like ``e.g.`` or ``Mr.``.

    Args:
        data: Iterable of dicts, each with a string under ``'output'``.

    Returns:
        The same ``data`` object; every example gains an ``'annotations'``
        list of ``{'text': sentence}`` dicts.
    """
    sentence_boundary = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s')
    for record in data:
        paragraph = record['output'].strip()
        record['annotations'] = [
            {'text': sentence} for sentence in sentence_boundary.split(paragraph)
        ]
    return data

def extract_list_text(string):
    """Strip a leading list number from ``string`` and return the remainder.

    The string must start with one or more digits, optionally followed by a
    single ``.``, ``)`` or ``]`` and any amount of whitespace.

    Args:
        string: A single line of text.

    Returns:
        The text after the numbering, or ``None`` when the line does not
        start with a number.
    """
    found = re.match(r'^\d+[\.\)\]]?\s*(.*)$', string)
    return found.group(1) if found else None
    
def get_list(input_lines):
    """Parse a numbered or bulleted list out of generated text.

    Every non-empty line is normalised so that it starts with a number:
    a leading ``(`` or ``[`` is dropped, a leading ``claim `` (any case) is
    dropped, and ``- `` bullets or single-character markers such as ``a. ``
    are rewritten to a dummy ``1 `` prefix. Lines that start with a digit
    after this step are handed to ``extract_list_text``; all other lines are
    ignored.

    Args:
        input_lines: Raw text, one list item per line.

    Returns:
        The list item texts in order. When no item could be parsed, a
        one-element list holding the ``### NEEDS MANUAL PARSING ###`` marker
        followed by the untouched input.
    """
    output = []

    for raw_line in input_lines.strip().split('\n'):
        line = raw_line.strip()
        if not line:
            continue

        if line[0] in '([':
            line = line[1:]
        elif line.lower().startswith('claim '):
            line = line[6:]
        elif line.startswith('- '):
            line = '1 ' + line[2:]
        elif not line[0].isdigit() and line[1:3] == '. ':
            line = '1 ' + line[3:]

        # A line that was only "(" or "[" is empty at this point; slicing
        # instead of indexing keeps that case from raising IndexError.
        if not line[:1].isdigit():
            continue

        text = extract_list_text(line)
        if text is not None:
            output.append(text)

    if not output:
        # Nothing looked like a list item: keep the raw text so it can be
        # reviewed by hand later.
        output = ["### NEEDS MANUAL PARSING ###\n" + input_lines]
    return output

def replace_pronouns(sentence, replace):
    """Swap the pronouns "he", "she" and "they" for ``replace``.

    Matching is case-insensitive and restricted to whole words. A sentence
    that already contains ``replace`` is returned untouched.

    Args:
        sentence: Text to rewrite.
        replace: Text inserted in place of each matched pronoun.

    Returns:
        The rewritten sentence.
    """
    if replace in sentence:
        return sentence

    pattern = re.compile(r'\b(?:he|she|they)\b', re.IGNORECASE)
    # A callable inserts `replace` verbatim. Passed as a plain string it would
    # be parsed as a replacement template, so a backslash in it would raise
    # re.error or be rewritten.
    return pattern.sub(lambda _match: replace, sentence)

def replace_beginning_possessive_pronoun(sentence, replace):
    """Replace a sentence-initial "His", "Her" or "Their" with ``replace``'s.

    Only a pronoun at the very start of ``sentence`` is rewritten; matching is
    case-insensitive and requires a word boundary after the pronoun. A
    sentence that already contains ``replace`` is returned untouched.

    Args:
        sentence: Text to rewrite.
        replace: Name that receives the ``'s`` suffix.

    Returns:
        The rewritten sentence.
    """
    if replace in sentence:
        return sentence

    pattern = re.compile(r'^(?:His|Her|Their)\b', re.IGNORECASE)
    possessive = replace + "'s"
    # A callable inserts the text verbatim. Passed as a plain string it would
    # be parsed as a replacement template, so a backslash in it would raise
    # re.error or be rewritten.
    return pattern.sub(lambda _match: possessive, sentence)

def replace_other_possessive_pronoun(sentence, replace):
    """Replace every "his", "her" or "their" in ``sentence`` with ``replace``'s.

    Matching is case-insensitive and restricted to whole words, anywhere in
    the sentence. A sentence that already contains ``replace`` is returned
    untouched.

    Args:
        sentence: Text to rewrite.
        replace: Name that receives the ``'s`` suffix.

    Returns:
        The rewritten sentence.
    """
    if replace in sentence:
        return sentence

    pattern = re.compile(r'\b(?:his|her|their)\b', re.IGNORECASE)
    possessive = replace + "'s"
    # A callable inserts the text verbatim. Passed as a plain string it would
    # be parsed as a replacement template, so a backslash in it would raise
    # re.error or be rewritten.
    return pattern.sub(lambda _match: possessive, sentence)

# Prompt template for the claim-splitting request. It is filled with
# `.format(subject, text)`: the first placeholder names the subject every
# claim should be about, the second is the text to break down.
prompt_base = '[INST] <<SYS>> Break down the following input into a set of small, independent claims. You must not add additional information. Output the claims as a numbered list separated by a new line. The subject of each line should be {}. <</SYS>> Input: {} [/INST]'
def get_atomic_facts(data, save_path_name=None):
    """Generate a list of claims for every annotated sentence in ``data``.

    One prompt is built per annotation, the prompts are run through the
    module-level model in fixed-size batches, and each completion is parsed
    with ``get_list``. When ``save_path_name`` is given, progress is
    checkpointed next to it so an interrupted run can resume.

    Args:
        data: Sequence of examples, each a dict with a ``'topic'`` string and
            an ``'annotations'`` entry that is either ``None`` (example is
            skipped) or a list of dicts with a ``'text'`` key.
        save_path_name: Optional path of the final ``.jsonl`` output. The
            checkpoint path is derived from it by replacing ``.jsonl`` with
            ``_temp_gen_list.pkl``. With ``None`` nothing is loaded or saved.

    Returns:
        A flat list with one ``get_list`` result per annotation, in the order
        the annotations appear in ``data``.

    Raises:
        ValueError: If an existing checkpoint holds more entries than there
            are prompts, which means it belongs to different input data.
    """
    checkpoint_path = None
    if save_path_name is not None:
        checkpoint_path = save_path_name.replace(".jsonl", "_temp_gen_list.pkl")

    print("Get list of prompts")
    all_prompts = []
    for example in tqdm(data):
        topic = example['topic']
        if '(' in topic:
            topic = topic.split('(')[0] # get rid of parentheticals identifying the person

        annotations = example['annotations']
        if annotations is None:
            continue

        all_prompts += [prompt_base.format(topic, x['text']) for x in annotations]

    print("Split into batches and conduct inference")
    batch_size = 8
    checkpoint_every = 100  # number of batches between two checkpoint writes

    generations = []
    # Guarding on None keeps os.path.exists from being called without a path.
    if checkpoint_path is not None and os.path.exists(checkpoint_path):
        # NOTE(review): load_pickle presumably unpickles the file; only point
        # this at checkpoints written by this script.
        generations = load_pickle(checkpoint_path)
        if len(generations) > len(all_prompts):
            raise ValueError(
                f"Checkpoint {checkpoint_path} holds {len(generations)} entries "
                f"but only {len(all_prompts)} prompts were built"
            )

    # Resume at the first prompt that has no saved result. Counting prompts
    # rather than batches stays correct whatever batch sizes wrote the
    # checkpoint, and an empty prompt list simply yields no batches.
    batch_starts = range(len(generations), len(all_prompts), batch_size)
    for batch_number, start in enumerate(tqdm(batch_starts), start=1):
        batch = all_prompts[start:start + batch_size]
        inputs = tokenizer(batch, return_tensors="pt", max_length=1000, truncation=True, padding=True).to(device)
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=1000)
        decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        # The decoded text still starts with the prompt; keep what follows the
        # instruction marker, or nothing if the marker was lost to truncation.
        completions = [x.split('[/INST]')[1] if '[/INST]' in x else "" for x in decoded]
        generations += [get_list(x) for x in completions]

        if checkpoint_path is not None and batch_number % checkpoint_every == 0:
            save_pickle(generations, checkpoint_path)
            print(len(generations))

    if checkpoint_path is not None:
        save_pickle(generations, checkpoint_path)

    return generations

def run(model_name, kind='labeled', additional_subpath=''):
    """Read one dataset file, add model-generated atomic facts, write it out.

    The input is ``data/{kind}/{additional_subpath}/{model_name}.jsonl``. For
    the ``'labeled'`` and ``'unlabeled'`` kinds the result goes to a sibling
    ``..._split_by_llama2.jsonl`` file; for any other kind the input file is
    overwritten. Unless ``kind`` is ``'labeled'``, the examples are first
    split into sentences with ``split_paragraph``.

    Args:
        model_name: Stem of the ``.jsonl`` file to process.
        kind: Name of the sub-directory below ``data/``.
        additional_subpath: Further sub-directory below ``data/{kind}/``.
    """
    path_name = f'data/{kind}/{additional_subpath}/{model_name}.jsonl'
    # don't replace jsonl files from the FActScore original paper
    keeps_source_file = kind in ['labeled', 'unlabeled']
    save_path_name = (
        f'data/{kind}/{additional_subpath}/{model_name}_split_by_llama2.jsonl'
        if keeps_source_file
        else path_name
    )

    data = read_jsonl_file(path_name)
    if kind != 'labeled':
        data = split_paragraph(data)

    generations = get_atomic_facts(data, save_path_name)

    print("Save generations to dictionary")
    # Applied one after the other to every fact of an annotation.
    pronoun_fixers = (
        replace_pronouns,
        replace_beginning_possessive_pronoun,
        replace_other_possessive_pronoun,
    )
    cursor = 0  # position in the flat `generations` list
    for example in tqdm(data):
        topic = example['topic']
        annotations = example['annotations']
        if annotations is None:
            continue
        for annotation in annotations:
            atomic_facts = generations[cursor]
            cursor += 1
            unparsed = atomic_facts[0].startswith('### NEEDS MANUAL PARSING ###')
            if not unparsed:
                for fix in pronoun_fixers:
                    atomic_facts = [fix(fact, topic) for fact in atomic_facts]
            annotation['model-atomic-facts'] = [{'text': fact} for fact in atomic_facts]

    with jsonlines.open(save_path_name, mode='w') as writer:
        for record in data:
            writer.write(record)

def get_args():
    """Parse the command-line options of this script.

    Returns:
        The parsed namespace with the string attributes ``model``, ``kind``
        and ``additional_subpath``.
    """
    parser = argparse.ArgumentParser(description="Example script with argparse")

    # (flag, default) pairs; every option is a plain string.
    string_options = (
        ("--model", "Llama2_7B_Chat"),
        ("--kind", "all"),
        ("--additional_subpath", ""),
    )
    for flag, default in string_options:
        parser.add_argument(flag, type=str, default=default)

    return parser.parse_args()

if __name__ == "__main__":
    from transformers import set_seed

    # Seed before anything is generated.
    set_seed(0)

    # Read the command-line options and hand them to the pipeline.
    args = get_args()
    model_name, kind, additional_subpath = args.model, args.kind, args.additional_subpath

    run(model_name, kind=kind, additional_subpath=additional_subpath)
