import random
import json
import datetime
from templates import (
    birth_date_templates,
    birth_city_templates,
    university_templates,
    major_templates,
    employer_templates,
    company_city_templates,
    birth_date_question_templates,
    birth_city_question_templates,
    university_question_templates,
    major_question_templates,
    employer_question_templates,
    company_city_question_templates,
    capitalize
)
import csv
import os
import numpy as np
from tqdm import tqdm
from copy import deepcopy
from mix_data import simple_mix_data
from convert_binary import convert_binary
from fire import Fire

# Directory (with trailing slash) from which the input name/city/university/
# major/company lists are read.
PATH = "data/"
# Default output directory. The main_* entry points below rebind this global
# before writing anything, so this value is only the initial default.
RESULT_PATH = "hallucinate_small/"
# NOTE: this runs at import time, so importing the module creates the directory.
os.makedirs(RESULT_PATH, exist_ok=True)
# Seed the global `random` generator once so repeated runs draw the same sequence.
random.seed(0)

import hashlib

def calculate_md5(file_path):
    """Return the hexadecimal MD5 digest of the file at ``file_path``.

    The file is read in binary mode in 4096-byte blocks, so it never has to
    fit in memory as a whole.
    """
    digest = hashlib.md5()
    with open(file_path, "rb") as stream:
        while True:
            block = stream.read(4096)
            if not block:
                # An empty read signals end of file.
                break
            digest.update(block)
    return digest.hexdigest()

def load_list(filename):
    """Read a UTF-8 text file and return its non-blank lines, whitespace-stripped.

    Lines that are empty after stripping are dropped; order is preserved.
    """
    entries = []
    with open(filename, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            cleaned = raw_line.strip()
            if cleaned:
                entries.append(cleaned)
    return entries

def load_companies(filename):
    """Load companies from a CSV file whose first two columns are name and HQ.

    Args:
        filename: Path to a UTF-8 CSV file. Columns beyond the second are ignored.

    Returns:
        A list of ``{'name': ..., 'hq': ...}`` dicts in file order.

    Raises:
        IndexError: If a non-blank row has fewer than two columns.
    """
    companies = []
    # newline='' lets the csv module handle line endings itself, which is
    # required for quoted fields that contain embedded newlines.
    with open(filename, 'r', encoding='utf-8', newline='') as f:
        for row in csv.reader(f):
            # Skip empty lines and lines made only of whitespace; the latter
            # parse as a single-cell row and would otherwise fail on row[1].
            if not any(cell.strip() for cell in row):
                continue
            name, hq = row[0], row[1]
            companies.append({'name': name, 'hq': hq})
    return companies

def generate_unique_full_names(first_names, middle_names, last_names, N):
    """Draw ``N`` distinct "first middle last" names by rejection sampling.

    Args:
        first_names, middle_names, last_names: Sequences of name strings.
        N: Number of distinct full names to produce.

    Returns:
        A list of ``N`` dicts with keys ``first_name``, ``middle_name``,
        ``last_name``, ``full_name`` and ``idx`` (position in the list).

    Raises:
        ValueError: If ``N`` exceeds the number of distinct name combinations,
            in which case the sampling loop could never terminate.
    """
    # Upper bound on distinct full names; duplicates in the inputs don't count.
    max_unique = len(set(first_names)) * len(set(middle_names)) * len(set(last_names))
    if N > max_unique:
        raise ValueError(
            f"Cannot generate {N} unique full names: only {max_unique} "
            "distinct first/middle/last combinations are available."
        )

    unique_names = set()
    names_list = []

    while len(names_list) < N:
        first = random.choice(first_names)
        middle = random.choice(middle_names)
        last = random.choice(last_names)
        full_name = f"{first} {middle} {last}"
        if full_name in unique_names:
            # Already used: draw again.
            continue
        unique_names.add(full_name)
        names_list.append({
            'first_name': first,
            'middle_name': middle,
            'last_name': last,
            'full_name': full_name,
            'idx': len(names_list),
        })
    return names_list

def generate_sentence(template_list, attribute_values):
    """Pick a random template and fill it with ``attribute_values``.

    Args:
        template_list: Non-empty sequence of ``str.format``-style templates.
        attribute_values: Mapping of placeholder names to values. It is
            never modified.

    Returns:
        The formatted sentence.
    """
    template = random.choice(template_list)
    # A sentence that opens with the pronoun needs it capitalized. Do this on
    # a copy: callers pass the same person dict for many sentences, and an
    # in-place change would leak the capitalized pronoun into later ones.
    if template.startswith('{pronoun}') and 'pronoun' in attribute_values:
        attribute_values = {
            **attribute_values,
            'pronoun': capitalize(attribute_values['pronoun']),
        }
    return template.format(**attribute_values)

def generate_profile(first_names, middle_names, last_names, cities, universities, majors, companies, N):
    """Create ``N`` people with unique names and randomly drawn attributes.

    Each person dict from ``generate_unique_full_names`` is extended with a
    birth date (plus its month/day/year parts as strings), birth city,
    university, major, employer name with its headquarters city, and a
    pronoun with its possessive, object and reflexive forms.

    Returns:
        The list of person dicts.
    """
    # subject pronoun -> (possessive, object, reflexive)
    pronoun_forms = {
        'he': ('his', 'him', 'himself'),
        'she': ('her', 'her', 'herself'),
        'they': ('their', 'them', 'themselves'),
    }
    subject_pronouns = list(pronoun_forms)

    individuals = generate_unique_full_names(first_names, middle_names, last_names, N)

    for person in individuals:
        # Draw every random value first, in a fixed order, then store them.
        year = random.randint(1900, 2099)
        month = random.randint(1, 12)
        day = random.randint(1, 28)  # 28 keeps the day valid in every month
        born = datetime.date(year, month, day)

        city = random.choice(cities)
        university = random.choice(universities)
        major = random.choice(majors)
        employer = random.choice(companies)
        subject = random.choice(subject_pronouns)
        possessive, objective, reflexive = pronoun_forms[subject]

        person.update({
            'birth_date': born.strftime("%B %d, %Y"),  # "Month Day, Year"
            'birth_month': born.strftime("%B"),
            'birth_day': str(day),
            'birth_year': str(year),
            'birth_city': city,
            'university': university,
            'major': major,
            'employer': employer['name'],
            'company_city': employer['hq'],
            'pronoun': subject,
            'possessive_pronoun': possessive,
            'object_pronoun': objective,
            'reflexive_pronoun': reflexive,
        })

    return individuals

def pronoun_to_fullname(attribute_values):
    """Return a deep copy of ``attribute_values`` with pronouns replaced by the full name.

    ``pronoun`` and ``object_pronoun`` become the full name and
    ``possessive_pronoun`` becomes the full name followed by ``'s``. All
    other keys (including ``reflexive_pronoun``) are copied unchanged, and
    the input mapping is not modified.
    """
    replaced = deepcopy(attribute_values)
    full_name = attribute_values['full_name']
    replaced.update({
        'pronoun': full_name,
        'possessive_pronoun': full_name + '\'s',
        'object_pronoun': full_name,
    })
    return replaced

def generate_description(attribute_values):
    """Build a biography that opens with the birth-date sentence.

    The opening sentence is rendered with the full name in place of the
    pronouns; the other five attribute sentences follow in random order and
    are rendered with the values as given. Sentences are joined by a space.
    """
    opening = generate_sentence(birth_date_templates, pronoun_to_fullname(attribute_values))

    remaining_templates = [
        birth_city_templates,
        university_templates,
        major_templates,
        employer_templates,
        company_city_templates,
    ]
    random.shuffle(remaining_templates)
    following = [generate_sentence(templates, attribute_values) for templates in remaining_templates]

    return ' '.join([opening] + following)

def generate_description_fullname(attribute_values):
    """Build a biography in which every sentence uses the full name, not pronouns.

    All six attribute sentences are emitted in random order and joined by a
    single space. One name-substituted copy of the values is shared by all
    sentences.
    """
    named_values = pronoun_to_fullname(attribute_values)

    template_order = [
        birth_date_templates,
        birth_city_templates,
        university_templates,
        major_templates,
        employer_templates,
        company_city_templates,
    ]
    random.shuffle(template_order)

    return ' '.join(generate_sentence(templates, named_values) for templates in template_order)

def generate_perturbed_description(attribute_values):
    """Build a biography with all six attribute sentences in random order.

    Whichever sentence lands first is rendered with the full name in place
    of the pronouns; the remaining five use the values as given. Sentences
    are joined by a single space.
    """
    template_order = [
        birth_date_templates,
        birth_city_templates,
        university_templates,
        major_templates,
        employer_templates,
        company_city_templates,
    ]
    random.shuffle(template_order)

    leading_templates, *other_templates = template_order
    sentences = [generate_sentence(leading_templates, pronoun_to_fullname(attribute_values))]
    for templates in other_templates:
        sentences.append(generate_sentence(templates, attribute_values))

    return ' '.join(sentences)

def generate_qa_pairs(attribute_values, first_n_template=-1):
    """Produce one QA pair per question template for each of the six attributes.

    Args:
        attribute_values: Person dict used to fill the question templates and
            to look up each answer.
        first_n_template: When positive, only the first that many templates
            of each attribute are used; otherwise all templates are used.

    Returns:
        A list of ``{'question', 'answer', 'type'}`` dicts, grouped by
        attribute in a fixed order.
    """
    question_sets = [
        (birth_date_question_templates, "birth_date"),
        (birth_city_question_templates, "birth_city"),
        (university_question_templates, "university"),
        (major_question_templates, "major"),
        (employer_question_templates, "employer"),
        (company_city_question_templates, "company_city"),
    ]
    limit_templates = first_n_template > 0

    pairs = []
    for question_templates, key in question_sets:
        if limit_templates:
            question_templates = question_templates[:first_n_template]
        pairs.extend(
            {
                'question': question_template.format(**attribute_values),
                'answer': attribute_values[key],
                'type': key,
            }
            for question_template in question_templates
        )
    return pairs

def generate_random_qa_pairs(attribute_values, num_samples):
    """Sample ``num_samples`` QA pairs from the pool of all question templates.

    Templates of all six attributes are pooled and drawn with replacement
    (``random.choices``), so the same question can appear more than once.

    Returns:
        A list of ``{'question', 'answer', 'type'}`` dicts in sampling order.
    """
    question_sets = [
        (birth_date_question_templates, "birth_date"),
        (birth_city_question_templates, "birth_city"),
        (university_question_templates, "university"),
        (major_question_templates, "major"),
        (employer_question_templates, "employer"),
        (company_city_question_templates, "company_city"),
    ]

    # Flatten to (template, attribute key) so every template is equally likely.
    question_pool = [
        (question_template, key)
        for question_templates, key in question_sets
        for question_template in question_templates
    ]

    return [
        {
            'question': question_template.format(**attribute_values),
            'answer': attribute_values[key],
            'type': key,
        }
        for question_template, key in random.choices(question_pool, k=num_samples)
    ]


def format_qa_pair(qa_pair):
    """Render a QA dict as a single ``Q: <question> A: <answer>`` line."""
    return "Q: {} A: {}".format(qa_pair['question'], qa_pair['answer'])

def format_qa_unknown(qa_pair):
    """Render a QA dict as a refusal line, ignoring the stored answer."""
    return "Q: {} A: I don't know.".format(qa_pair['question'])

def generate_or_load_individuals(N=110000):
    """Return the person profiles for ``RESULT_PATH``, generating them if needed.

    If both ``profiles.jsonl`` and ``profiles.md5`` exist in ``RESULT_PATH``,
    the profiles are loaded from disk after checking the stored checksum.
    Otherwise ``N`` new profiles are generated from the lists under ``PATH``
    and written out together with their MD5 checksum.

    Args:
        N: Number of profiles to generate. When loading from the cache it is
            only used to warn if the cached file holds fewer profiles.

    Returns:
        A list of person dicts.

    Raises:
        ValueError: If the cached profiles file does not match its checksum.
    """
    os.makedirs(RESULT_PATH, exist_ok=True)
    profiles_path = os.path.join(RESULT_PATH, 'profiles.jsonl')
    md5_path = os.path.join(RESULT_PATH, 'profiles.md5')

    if os.path.exists(profiles_path) and os.path.exists(md5_path):
        print("Loading profiles from file")
        md5_hash = calculate_md5(profiles_path)
        with open(md5_path, 'r', encoding='utf-8') as f:
            # Strip so a trailing newline added by an editor does not count
            # as a checksum mismatch.
            if f.read().strip() != md5_hash:
                raise ValueError("Profiles file has been modified. Please delete the file and re-run the script.")
        with open(profiles_path, 'r', encoding='utf-8') as f:
            individuals = [json.loads(line) for line in f]
        if len(individuals) < N:
            # Callers slice by absolute index, so a short cache would
            # otherwise yield silently truncated splits.
            print(f"Warning: cached profiles contain {len(individuals)} entries, fewer than the requested {N}.")
        return individuals

    # The input lists are only needed when generating, so load them here
    # rather than on every call.
    first_names = load_list(PATH + 'first_names.txt')  # 400 names
    middle_names = load_list(PATH + 'middle_names.txt')  # 400 names
    last_names = load_list(PATH + 'last_names.txt')  # 1000 names
    cities = load_list(PATH + 'cities.txt')  # 200 cities in "City, State" format
    universities = load_list(PATH + 'universities.txt')  # 300 universities
    majors = load_list(PATH + 'majors.txt')  # 100 majors
    companies = load_companies(PATH + 'companies.csv')  # 263 companies with headquarters

    individuals = generate_profile(first_names, middle_names, last_names, cities, universities, majors, companies, N)
    with open(profiles_path, 'w', encoding='utf-8') as f:
        for person in individuals:
            f.write(json.dumps(person) + '\n')
    with open(md5_path, 'w', encoding='utf-8') as f:
        f.write(calculate_md5(profiles_path))
    return individuals


def write_qa(individuals, name, first_n_template=5, refuse=False):
    """Write shuffled QA lines for ``individuals`` to ``<name>.txt`` plus ``<name>.md5``.

    Args:
        individuals: Person dicts to generate questions for.
        name: Base file name (no extension) inside ``RESULT_PATH``.
        first_n_template: Passed to ``generate_qa_pairs`` to cap the number
            of templates used per attribute.
        refuse: When true, every answer is replaced with "I don't know.".
    """
    if refuse:
        render = format_qa_unknown
    else:
        render = format_qa_pair

    qa_lines = []
    for person in tqdm(individuals, desc="Generating QA pairs"):
        for qa_pair in generate_qa_pairs(person, first_n_template=first_n_template):
            qa_lines.append(render(qa_pair))

    # Shuffle so lines of one person are not adjacent in the output.
    random.shuffle(qa_lines)

    txt_path = os.path.join(RESULT_PATH, f'{name}.txt')
    with open(txt_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(qa_lines) + '\n')

    # Store the checksum of the text file that was just written.
    checksum = calculate_md5(txt_path)
    with open(os.path.join(RESULT_PATH, f'{name}.md5'), 'w', encoding='utf-8') as f:
        f.write(checksum)

def get_distribution_count(distribution, num_person = 100000, averaged_entry_per_person = 50, a = 1.35):
    """Return how many entries each person gets under the given distribution.

    Args:
        distribution: ``'uniform'``, ``'inverse'`` or ``'power'``.
        num_person: Length of the returned array.
        averaged_entry_per_person: Target mean number of entries per person.
        a: Exponent for the ``'power'`` distribution.

    Returns:
        A 1-D numpy array of length ``num_person``. ``'uniform'`` gives every
        person ``averaged_entry_per_person``. The other two weight person
        ``i`` by ``1 / (i + 1000)`` or ``(i + 1000) ** -a``, normalise the
        weights, scale them by the total entry budget, truncate to int, and
        clamp each count to at least 1.
    """
    assert distribution in ['uniform', 'inverse', 'power']
    total_entry = num_person * averaged_entry_per_person

    if distribution == 'uniform':
        return np.full(num_person, averaged_entry_per_person)

    # Rank offset that flattens the head of the inverse/power curves.
    bias = 1000
    if distribution == 'inverse':
        weights = [1 / (rank + bias) for rank in range(num_person)]
    else:
        weights = [(rank + bias) ** (-a) for rank in range(num_person)]

    probabilities = np.array(weights)
    probabilities = probabilities / probabilities.sum()
    counts = np.array([int(share * total_entry) for share in probabilities])
    # Truncation can produce zeros; every person must appear at least once.
    return np.maximum(counts, 1)

def main_hallucination():
    """Generate the hallucination dataset under ``hallucinate_small/``.

    Writes the QA splits, a uniformly weighted biography corpus for the first
    10000 people (50 entries each), the mixed files built by
    ``simple_mix_data``, and the binary conversions of the final files.
    """
    global RESULT_PATH
    RESULT_PATH = "hallucinate_small/"
    individuals = generate_or_load_individuals()

    def in_results(filename):
        return os.path.join(RESULT_PATH, filename)

    # (slice of individuals, output name, answer with a refusal?)
    qa_splits = [
        (slice(None, 5000), "SFT", False),
        (slice(6000, 10000), "SFT_test", False),
        (slice(5000, 6000), "SFT_mix_pretraining", False),
        (slice(10000, 11000), "SFT_unknown", False),
        (slice(10000, 11000), "SFT_unknown_refused", True),
        (slice(11000, 12000), "SFT_unknown_refused_test", True),
    ]
    for span, split_name, refuse in qa_splits:
        write_qa(individuals[span], split_name, refuse=refuse)

    # Biography corpus: the same number of entries for each of the first 10000 people.
    count = get_distribution_count(distribution = "uniform", num_person=10000, averaged_entry_per_person=50)
    pretrain_individuals = individuals[:10000]
    biographical_entries = []
    progress = tqdm(enumerate(pretrain_individuals), desc="Generating biographical entries", total=len(pretrain_individuals))
    for idx, person in progress:
        for _ in range(count[idx]):
            biographical_entries.append(generate_perturbed_description(person))
    random.shuffle(biographical_entries)

    corpus_path = in_results('pretrain_perturbed.txt')
    with open(corpus_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(biographical_entries))

    # Record the checksum of the corpus next to it.
    md5_hash = calculate_md5(corpus_path)
    with open(in_results('pretrain_perturbed.md5'), 'w', encoding='utf-8') as f:
        f.write(md5_hash)

    # (first input, second input, mixed output), each without the .txt suffix.
    mixes = [
        ('pretrain_perturbed', 'SFT_mix_pretraining', 'pretrain_perturbed_mixed'),
        ('SFT', 'SFT_unknown', 'SFT_mix_unknown'),
        ('SFT', 'SFT_unknown_refused', 'SFT_mix_unknown_refused'),
    ]
    for first, second, mixed in mixes:
        simple_mix_data(in_results(first + '.txt'), in_results(second + '.txt'), 1, 1, in_results(mixed + '.txt'))

    for file in ['pretrain_perturbed_mixed', 'SFT_mix_unknown', 'SFT_mix_unknown_refused', 'SFT']:
        dst_folder = in_results(file)
        os.makedirs(dst_folder, exist_ok=True)
        convert_binary(in_results(f"{file}.txt"), dst_folder, align_length=512)


def main_data_law(result_path, distribution, num_person = 100000, averaged_entry_per_person = 50, a = 1.35):
    """Generate a biography corpus whose per-person frequency follows ``distribution``.

    Args:
        result_path: Output directory; becomes the module-wide ``RESULT_PATH``.
        distribution: ``'uniform'``, ``'inverse'`` or ``'power'``.
        num_person: Number of people included in the corpus.
        averaged_entry_per_person: Target mean number of entries per person.
        a: Exponent used by the ``'power'`` distribution.

    Writes ``pretrain.txt``, ``pretrain.md5`` and ``count.json`` into
    ``result_path`` and converts the corpus to binary in ``pretrain/``.
    """
    global RESULT_PATH
    RESULT_PATH = result_path
    # Profiles are generated (or loaded) for 1.1x the number of people used here.
    individuals = generate_or_load_individuals(N = int(num_person*1.1))

    count = get_distribution_count(
        distribution = distribution,
        num_person=num_person,
        averaged_entry_per_person=averaged_entry_per_person,
        a = a,
    )
    pretrain_individuals = individuals[:num_person]

    biographical_entries = []
    progress = tqdm(enumerate(pretrain_individuals), desc="Generating biographical entries", total=len(pretrain_individuals))
    for idx, person in progress:
        for _ in range(count[idx]):
            biographical_entries.append(generate_perturbed_description(person))
    random.shuffle(biographical_entries)

    corpus_path = os.path.join(RESULT_PATH, 'pretrain.txt')
    with open(corpus_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(biographical_entries))

    # Record the checksum of the corpus next to it.
    md5_hash = calculate_md5(corpus_path)
    with open(os.path.join(RESULT_PATH, 'pretrain.md5'), 'w', encoding='utf-8') as f:
        f.write(md5_hash)

    # Persist the per-person counts (1-D numpy array) as a JSON list.
    with open(os.path.join(RESULT_PATH, 'count.json'), 'w', encoding='utf-8') as f:
        json.dump(count.tolist(), f)

    dst_folder = os.path.join(RESULT_PATH, 'pretrain')
    os.makedirs(dst_folder, exist_ok=True)
    convert_binary(corpus_path, dst_folder, align_length=512)

def main_ift(result_path):
    """Generate the instruction-fine-tuning dataset under ``result_path``.

    Writes QA files for the first 50000 people (``SFT``) and for everyone
    from index 100000 onward (``SFT_new``), a uniformly weighted biography
    corpus for the latter group (``continue_pretrain.txt``, 20 entries per
    person), the per-person counts (``count.json``), and binary conversions
    of ``SFT_new`` and ``continue_pretrain``.

    Args:
        result_path: Output directory; becomes the module-wide ``RESULT_PATH``.
    """
    global RESULT_PATH
    RESULT_PATH = result_path
    individuals = generate_or_load_individuals(N=110000)
    write_qa(individuals[:50000], "SFT", first_n_template=20)
    write_qa(individuals[100000:], "SFT_new", first_n_template=20)

    # Continued-pretraining corpus for the people held out of the SFT split.
    continue_pretrain_individuals = individuals[100000:]
    # Size the count array from the slice rather than a hard-coded 10000 so
    # the two cannot drift apart when the number of profiles changes.
    count = get_distribution_count(
        distribution = "uniform",
        num_person=len(continue_pretrain_individuals),
        averaged_entry_per_person=20,
    )
    biographical_entries = []
    for idx, person in tqdm(enumerate(continue_pretrain_individuals), desc="Generating biographical entries", total=len(continue_pretrain_individuals)):
        person_entries = [generate_perturbed_description(person) for _ in range(count[idx])]
        biographical_entries.extend(person_entries)
    # shuffle biographical entries
    random.shuffle(biographical_entries)
    with open(os.path.join(RESULT_PATH, 'continue_pretrain.txt'), 'w', encoding='utf-8') as f:
        f.write('\n'.join(biographical_entries))

    # write count(1d np array) to json file
    with open(os.path.join(RESULT_PATH, 'count.json'), 'w', encoding='utf-8') as f:
        json.dump(count.tolist(), f)

    for file in ['SFT_new', 'continue_pretrain']:
        src_file = os.path.join(RESULT_PATH, f"{file}.txt")
        dst_folder = os.path.join(RESULT_PATH, file)
        os.makedirs(dst_folder, exist_ok=True)
        convert_binary(src_file, dst_folder, align_length=512, val_shard_size=0)

# Script entry point. Only the uncommented call runs; the commented-out calls
# are earlier experiment configurations kept for reference. Swap which line
# is active to regenerate a different dataset.
# NOTE(review): `Fire` is imported at the top of the file but is not used
# in this guard; confirm whether a CLI was intended.
if __name__ == '__main__':
    # main_data_law(result_path="datalaw/uniform", distribution="uniform", num_person=100000, averaged_entry_per_person=50)
    # main_data_law(result_path="datalaw/inverse", distribution="inverse", num_person=100000, averaged_entry_per_person=50)
    # main_data_law(result_path="datalaw/power_135", distribution="power", num_person=100000, averaged_entry_per_person=50, a=1.35)
    # main_data_law(result_path="datalaw/power_15", distribution="power", num_person=100000, averaged_entry_per_person=50, a=1.5)
    # main_data_law(result_path="datalaw/power_12", distribution="power", num_person=100000, averaged_entry_per_person=50, a=1.2)
    # main_data_law(result_path="datalaw/power_105", distribution="power", num_person=100000, averaged_entry_per_person=50, a=1.05)
    # main_data_law(result_path="datalaw/power_05", distribution="power", num_person=100000, averaged_entry_per_person=50, a=0.5)
    # main_data_law(result_path="datalaw/power_08", distribution="power", num_person=100000, averaged_entry_per_person=50, a=0.8)
    # main_data_law(result_path="data_law/inverse", distribution="inverse", num_person=400000, averaged_entry_per_person=50)
    # main_data_law(result_path="data_law/power_135", distribution="power", num_person=400000, averaged_entry_per_person=50, a=1.35)
    # main_data_law(result_path="data_law/power_15", distribution="power", num_person=400000, averaged_entry_per_person=50, a=1.5)
    # main_data_law(result_path="data_law/power_12", distribution="power", num_person=400000, averaged_entry_per_person=50, a=1.2)
    # main_data_law(result_path="data_law/power_105", distribution="power", num_person=400000, averaged_entry_per_person=50, a=1.05)
    # main_data_law(result_path="data_law/power_05", distribution="power", num_person=400000, averaged_entry_per_person=50, a=0.5)
    # main_data_law(result_path="data_law/power_08", distribution="power", num_person=400000, averaged_entry_per_person=50, a=0.8)
    # main_data_law(result_path="data_law/power_uniform", distribution="uniform", num_person=400000, averaged_entry_per_person=50)
    main_ift(result_path="ift/uniform")