import random
import json
import datetime
from templates import (
    birth_date_templates,
    birth_city_templates,
    university_templates,
    major_templates,
    employer_templates,
    company_city_templates,
    birth_date_question_templates,
    birth_city_question_templates,
    university_question_templates,
    major_question_templates,
    employer_question_templates,
    company_city_question_templates,
    capitalize
)
import csv
import os
import numpy as np
from tqdm import tqdm
from copy import deepcopy
from mix_data import simple_mix_data
from convert_binary import convert_binary

# Directory the input word lists and companies.csv are read from.
PATH = "data/"
# Directory every generated artifact is written to; created on import if missing.
RESULT_PATH = "hallucinate_small/"
os.makedirs(RESULT_PATH, exist_ok=True)
# Fixed seed so repeated runs draw the same pseudo-random sequence.
random.seed(0)

import hashlib

# Calculate the MD5 hex digest of a file (used to checksum generated outputs)
def calculate_md5(file_path):
    """Return the hexadecimal MD5 digest of the file at ``file_path``.

    The file is opened in binary mode and consumed in 4096-byte pieces, so
    its full content is never held in memory at once.
    """
    digest = hashlib.md5()
    with open(file_path, "rb") as stream:
        while True:
            piece = stream.read(4096)
            if piece == b"":
                # An empty read means end of file.
                break
            digest.update(piece)
    return digest.hexdigest()

def load_list(filename):
    """Read a UTF-8 text file and return its non-blank lines, whitespace-stripped.

    Lines that are empty after stripping are dropped; order is preserved.
    """
    entries = []
    with open(filename, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            cleaned = raw_line.strip()
            if cleaned:
                entries.append(cleaned)
    return entries

def load_companies(filename):
    """Load company records from a UTF-8 CSV file.

    Each non-empty row must provide the company name in its first column and
    the headquarters in its second; any further columns are ignored.

    Args:
        filename: Path of the CSV file to read.

    Returns:
        A list of ``{'name': ..., 'hq': ...}`` dicts in file order.

    Raises:
        ValueError: If a non-empty row has fewer than two columns.
    """
    companies = []
    # newline='' is what the csv module requires so that newlines embedded in
    # quoted fields reach the parser untouched.
    with open(filename, 'r', encoding='utf-8', newline='') as f:
        reader = csv.reader(f)
        for row in reader:
            if not row:
                # Blank lines come back as empty rows; skip them.
                continue
            if len(row) < 2:
                raise ValueError(
                    f"{filename}: line {reader.line_num}: expected "
                    f"'name,headquarters' but got {row!r}"
                )
            companies.append({'name': row[0], 'hq': row[1]})
    return companies

def generate_unique_full_names(first_names, middle_names, last_names, N):
    """Build a list of name records with pairwise-distinct full names.

    Names are first seeded from ``RESULT_PATH + 'new_person{args.snum}.txt'``
    (one "First Middle Last" name per line; lines that do not split into
    exactly three parts are ignored). Reading that file is best-effort: any
    failure is reported on stdout and generation continues with whatever was
    loaded. The list is then topped up with random first/middle/last
    combinations until it holds at least ``N`` records.

    Args:
        first_names: Pool of first names to draw from.
        middle_names: Pool of middle names to draw from.
        last_names: Pool of last names to draw from.
        N: Minimum number of records to return.

    Returns:
        A list of dicts with keys ``first_name``, ``middle_name``,
        ``last_name``, ``full_name`` and ``idx`` (the record's position in
        the list). The list is longer than ``N`` when the seed file alone
        contributes more than ``N`` distinct names.

    Raises:
        ValueError: If the pools cannot supply enough unused combinations to
            reach ``N`` (previously this case looped forever).
    """
    unique_names = set()
    names_list = []

    def _register(first, middle, last, full_name):
        # Append a record unless this full name was already used.
        if full_name in unique_names:
            return
        unique_names.add(full_name)
        names_list.append({
            'first_name': first,
            'middle_name': middle,
            'last_name': last,
            'full_name': full_name,
            'idx': len(names_list)
        })

    try:
        # Built inside the try block so a missing global (``args`` is only set
        # by main()) is reported like any other read failure.
        seed_path = RESULT_PATH + f'new_person{args.snum}.txt'
        with open(seed_path, 'r', encoding='utf-8') as file:
            for line in file:
                name = line.strip()
                parts = name.split()
                if len(parts) == 3:
                    first, middle, last = parts
                    _register(first, middle, last, name)
    except FileNotFoundError as err:
        # Report the path that was actually tried instead of a hard-coded one.
        print(f"Warnings：Cannot find '{err.filename}'")
    except Exception as e:
        print(f"Cannot read: {e}")

    missing = N - len(names_list)
    if missing > 0:
        first_pool = set(first_names)
        middle_pool = set(middle_names)
        last_pool = set(last_names)
        capacity = len(first_pool) * len(middle_pool) * len(last_pool)
        # Seeded names that the pools could also produce are already taken.
        # (This is an exact count as long as pool entries contain no spaces.)
        taken = sum(
            1 for record in names_list
            if record['first_name'] in first_pool
            and record['middle_name'] in middle_pool
            and record['last_name'] in last_pool
            and record['full_name'] == (
                f"{record['first_name']} {record['middle_name']} {record['last_name']}"
            )
        )
        if missing > capacity - taken:
            raise ValueError(
                f"Cannot generate {missing} more unique names: only "
                f"{capacity - taken} unused combinations are available"
            )

    while len(names_list) < N:
        first = random.choice(first_names)
        middle = random.choice(middle_names)
        last = random.choice(last_names)
        _register(first, middle, last, f"{first} {middle} {last}")

    return names_list
def generate_sentence(template_list, attribute_values):
    """Render one randomly chosen template with the given attribute values.

    Args:
        template_list: Sequence of ``str.format`` templates to choose from.
        attribute_values: Mapping supplying the template fields. It is never
            modified.

    Returns:
        The formatted sentence. When the chosen template begins with
        ``{pronoun}``, the pronoun is passed through ``capitalize`` for this
        sentence only.
    """
    template = random.choice(template_list)
    if template.startswith('{pronoun}') and 'pronoun' in attribute_values:
        # Substitute into a copy: writing the capitalized pronoun back into
        # the caller's dict would leak it into every later sentence (and into
        # the stored person record) that uses the pronoun mid-sentence.
        attribute_values = {
            **attribute_values,
            'pronoun': capitalize(attribute_values['pronoun']),
        }
    return template.format(**attribute_values)

def generate_profile(first_names, middle_names, last_names, cities, universities, majors, companies, N):
    """Create name records and fill each with randomly drawn attributes.

    Every record returned by ``generate_unique_full_names`` gains a birth
    date (year 1900-2099, month 1-12, day 1-28), birth city, university,
    major, employer with its headquarters city, and a pronoun together with
    its possessive, object and reflexive forms.

    Returns:
        The list of records, updated in place.
    """
    # subject pronoun -> (possessive, object, reflexive)
    pronoun_forms = {
        'he': ('his', 'him', 'himself'),
        'she': ('her', 'her', 'herself'),
        'they': ('their', 'them', 'themselves'),
    }
    subject_choices = list(pronoun_forms)

    people = generate_unique_full_names(first_names, middle_names, last_names, N)

    for record in people:
        # Birth date: the three draws happen in year, month, day order.
        year = random.randint(1900, 2099)
        month = random.randint(1, 12)
        day = random.randint(1, 28)
        born = datetime.date(year, month, day)
        record.update({
            'birth_date': born.strftime("%B %d, %Y"),  # "Month Day, Year"
            'birth_month': born.strftime("%B"),
            'birth_day': str(day),
            'birth_year': str(year),
        })

        # Simple one-draw attributes, in a fixed order.
        for field, pool in (
            ('birth_city', cities),
            ('university', universities),
            ('major', majors),
        ):
            record[field] = random.choice(pool)

        # Employer and its headquarters come from the same company entry.
        workplace = random.choice(companies)
        record['employer'] = workplace['name']
        record['company_city'] = workplace['hq']

        # Pronoun and its derived forms.
        subject = random.choice(subject_choices)
        possessive, objective, reflexive = pronoun_forms[subject]
        record.update(
            pronoun=subject,
            possessive_pronoun=possessive,
            object_pronoun=objective,
            reflexive_pronoun=reflexive,
        )

    return people

def pronoun_to_fullname(attribute_values):
    """Return a deep copy of the record whose pronouns are the full name.

    ``pronoun`` and ``object_pronoun`` become the full name and
    ``possessive_pronoun`` becomes the full name followed by ``'s``. All
    other keys (including ``reflexive_pronoun``) are copied unchanged, and
    the input mapping is left untouched.
    """
    swapped = deepcopy(attribute_values)
    name = attribute_values['full_name']
    swapped.update({
        'pronoun': name,
        'possessive_pronoun': name + "'s",
        'object_pronoun': name,
    })
    return swapped

def generate_description(attribute_values):
    """Compose a biography that opens with a full-name birth-date sentence.

    The first sentence always comes from ``birth_date_templates`` and uses
    the person's full name in place of pronouns. The other five attribute
    sentences follow in random order and use the record as given.
    """
    opening = generate_sentence(birth_date_templates, pronoun_to_fullname(attribute_values))

    remaining_groups = [
        birth_city_templates,
        university_templates,
        major_templates,
        employer_templates,
        company_city_templates,
    ]
    random.shuffle(remaining_groups)

    following = [generate_sentence(group, attribute_values) for group in remaining_groups]
    return ' '.join([opening] + following)

def generate_description_fullname(attribute_values):
    """Compose a biography in which every sentence uses the full name.

    One full-name copy of the record is made and shared by all six attribute
    sentences, which are emitted in random order and joined with spaces.
    """
    named_values = pronoun_to_fullname(attribute_values)

    template_groups = [
        birth_date_templates,
        birth_city_templates,
        university_templates,
        major_templates,
        employer_templates,
        company_city_templates,
    ]
    random.shuffle(template_groups)

    return ' '.join(generate_sentence(group, named_values) for group in template_groups)

def generate_perturbed_description(attribute_values):
    """Compose a biography with all six attribute sentences in random order.

    Whichever sentence ends up first is rendered with the person's full name
    in place of pronouns; the remaining five use the record as given.
    """
    template_groups = [
        birth_date_templates,
        birth_city_templates,
        university_templates,
        major_templates,
        employer_templates,
        company_city_templates,
    ]
    random.shuffle(template_groups)

    lead_group, *other_groups = template_groups
    parts = [generate_sentence(lead_group, pronoun_to_fullname(attribute_values))]
    for group in other_groups:
        parts.append(generate_sentence(group, attribute_values))

    return ' '.join(parts)

def generate_qa_pairs(attribute_values, first_n_template=-1):
    """Build question/answer records for every attribute of one person.

    Args:
        attribute_values: The person record used both to fill the question
            templates and to look up the answers.
        first_n_template: When positive, only the first that many question
            templates of each attribute are used; otherwise all are used.

    Returns:
        A list of ``{'question', 'answer', 'type'}`` dicts, grouped by
        attribute in the order birth_date, birth_city, university, major,
        employer, company_city.
    """
    question_sets = (
        ("birth_date", birth_date_question_templates),
        ("birth_city", birth_city_question_templates),
        ("university", university_question_templates),
        ("major", major_question_templates),
        ("employer", employer_question_templates),
        ("company_city", company_city_question_templates),
    )

    qa_records = []
    for attribute, questions in question_sets:
        if first_n_template > 0:
            questions = questions[:first_n_template]
        qa_records.extend(
            {
                'question': wording.format(**attribute_values),
                'answer': attribute_values[attribute],
                'type': attribute,
            }
            for wording in questions
        )
    return qa_records

def generate_random_qa_pairs(attribute_values, num_samples):
    """Sample ``num_samples`` question/answer records for one person.

    Question templates of all six attributes are pooled and drawn with
    replacement via ``random.choices``, so the same question may repeat.

    Returns:
        A list of ``{'question', 'answer', 'type'}`` dicts in draw order.
    """
    question_sets = (
        ("birth_date", birth_date_question_templates),
        ("birth_city", birth_city_question_templates),
        ("university", university_question_templates),
        ("major", major_question_templates),
        ("employer", employer_question_templates),
        ("company_city", company_city_question_templates),
    )

    # Flatten into (template, attribute) candidates, one per question wording.
    candidates = []
    for attribute, questions in question_sets:
        candidates.extend((wording, attribute) for wording in questions)

    drawn = random.choices(candidates, k=num_samples)

    return [
        {
            'question': wording.format(**attribute_values),
            'answer': attribute_values[attribute],
            'type': attribute,
        }
        for wording, attribute in drawn
    ]


def format_qa_pair(qa_pair):
    """Render a QA record as a single ``Q: ... A: ...`` line."""
    return "Q: {question} A: {answer}".format(
        question=qa_pair['question'],
        answer=qa_pair['answer'],
    )

def format_qa_unknown(qa_pair):
    """Render a QA record's question with a fixed "I don't know." answer."""
    return "Q: {question} A: I don't know.".format(question=qa_pair['question'])

def generate_or_load_individuals(people):
    """Generate ``people`` profiles and persist them under ``RESULT_PATH``.

    The name, city, university and major lists and the company table are
    read from ``PATH``. Profiles are written one JSON object per line to
    ``profiles.jsonl``, and the MD5 digest of that file is stored in
    ``profiles.md5``.

    NOTE: despite the name, nothing is loaded from a previous run; profiles
    are regenerated and both files overwritten on every call. An earlier
    cache-and-verify path based on ``profiles.md5`` is disabled.

    Returns:
        The list of generated profile dicts.
    """
    list_files = (
        'first_names.txt',
        'middle_names.txt',
        'last_names.txt',
        'cities.txt',
        'universities.txt',
        'majors.txt',
    )
    first_names, middle_names, last_names, cities, universities, majors = (
        load_list(PATH + list_file) for list_file in list_files
    )
    companies = load_companies(PATH + 'companies.csv')

    individuals = generate_profile(
        first_names, middle_names, last_names,
        cities, universities, majors, companies,
        people,
    )

    profiles_path = RESULT_PATH + 'profiles.jsonl'
    with open(profiles_path, 'w', encoding='utf-8') as out:
        out.writelines(json.dumps(person) + '\n' for person in individuals)

    # Checksum of the profiles file just written.
    with open(RESULT_PATH + 'profiles.md5', 'w', encoding='utf-8') as out:
        out.write(calculate_md5(profiles_path))

    return individuals


def write_qa(individuals, name, pid, refuse=False):
    """Write shuffled QA lines for all individuals, plus an MD5 checksum file.

    Args:
        individuals: Person records to generate questions for.
        name: Output file name prefix.
        pid: Suffix appended to ``name`` in the output file names.
        refuse: When true, every answer is replaced by "I don't know.".

    The first five question templates of each attribute are used. Output
    goes to ``RESULT_PATH/{name}{pid}.txt`` (one QA pair per line) and the
    file's MD5 digest to ``RESULT_PATH/{name}{pid}.md5``.
    """
    if refuse:
        render = format_qa_unknown
    else:
        render = format_qa_pair

    lines = []
    for person in tqdm(individuals, desc="Generating QA pairs"):
        for qa_pair in generate_qa_pairs(person, first_n_template=5):
            lines.append(render(qa_pair))

    # Mix the questions of different people together.
    random.shuffle(lines)

    output_stem = RESULT_PATH + f'{name}{pid}'
    with open(output_stem + '.txt', 'w', encoding='utf-8') as out:
        out.write('\n'.join(lines) + '\n')

    digest = calculate_md5(output_stem + '.txt')
    with open(output_stem + '.md5', 'w', encoding='utf-8') as out:
        out.write(digest)

def get_distribution_count(distribution, num_person = 100000, averaged_entry_per_person = 50, bias = 1000, exponent = 1.35):
    """Return how many entries to generate for each person.

    Args:
        distribution: ``'uniform'``, ``'inverse'`` or ``'power'``.
        num_person: Number of people, i.e. the length of the result.
        averaged_entry_per_person: Target mean entries per person; the total
            budget is ``num_person * averaged_entry_per_person``.
        bias: Positive offset added to the 0-based person index before
            weighting (unused for ``'uniform'``).
        exponent: Decay exponent of the ``'power'`` distribution.

    Returns:
        A NumPy array of length ``num_person``.

        * ``'uniform'``: every element is ``averaged_entry_per_person``.
        * ``'inverse'``: weights are ``1 / (i + bias)``.
        * ``'power'``: weights are ``(i + bias) ** -exponent``.

        For the two weighted distributions each person receives their share
        of the total budget truncated to an integer, with a minimum of 1, so
        the counts need not sum exactly to the budget.

    Raises:
        ValueError: If ``distribution`` is unknown, or if ``bias`` is not
            positive for a weighted distribution.
    """
    # Explicit check rather than assert: asserts vanish under ``python -O``
    # and the function would then silently return None.
    if distribution not in ('uniform', 'inverse', 'power'):
        raise ValueError(
            f"Unknown distribution {distribution!r}; expected 'uniform', 'inverse' or 'power'"
        )

    if distribution == 'uniform':
        return np.full(num_person, averaged_entry_per_person)

    if bias <= 0:
        raise ValueError(f"bias must be positive, got {bias!r}")

    total_entry = num_person * averaged_entry_per_person

    if distribution == 'inverse':
        weights = 1.0 / (np.arange(num_person) + bias)
    else:
        # Python's scalar pow is kept on purpose so the weights stay
        # bit-identical to earlier runs of this script.
        weights = np.array([(i + bias) ** (-exponent) for i in range(num_person)], dtype=float)

    probabilities = weights / weights.sum()
    # astype(int) truncates toward zero, matching int() on positive values.
    counts = (probabilities * total_entry).astype(int)
    return np.maximum(counts, 1)

def main():
    """Command-line entry point: generate profiles, QA data and biographies.

    Steps:
      1. Parse the command line into the module-global ``args`` (read by
         ``generate_unique_full_names`` for the seed-file name).
      2. Generate 100000 profiles and write the ``SFT{number}`` QA file.
      3. Generate 50 perturbed biographies per person and write them, with
         the person records, into ``RESULT_PATH/pretrain_perturbed{number}/``.
      4. Reshuffle the biographies and write them to
         ``RESULT_PATH/pretrain_perturbed{number}.txt``.

    ``--people_rate`` is required but not used by this function.
    """
    import argparse
    global args
    parser = argparse.ArgumentParser(
        description="Generate synthetic biography pretraining text and QA fine-tuning data"
    )
    parser.add_argument("-n", "--people_rate", type=str, required=True, help="Persons' rate to generate")
    parser.add_argument("-r", "--number", type=int, required=True,
                        help="Suffix used in the names of the generated output files")
    parser.add_argument("-s", "--snum", type=int, required=True,
                        help="Suffix of the new_person{snum}.txt seed-name file")
    args = parser.parse_args()

    num_person = 100000
    individuals = generate_or_load_individuals(num_person)
    write_qa(individuals, "SFT", args.number)
    upsample_num_person = num_person

    upsample_count = get_distribution_count(distribution="uniform", num_person=upsample_num_person, averaged_entry_per_person=50)
    pretrain_upsample_individuals = individuals[:upsample_num_person]

    upsample_biographical_entries = []

    for idx, person in tqdm(enumerate(pretrain_upsample_individuals), desc="Generating upsampled biographical entries", total=len(pretrain_upsample_individuals)):
        person_entries = [generate_perturbed_description(person) for _ in range(upsample_count[idx])]
        upsample_biographical_entries.extend(person_entries)

    random.shuffle(upsample_biographical_entries)

    entries_file = f"biographical_entries{args.number}.txt"
    people_file = f"people_{args.number}.json"

    # Only RESULT_PATH itself is created at import time; this per-run
    # subdirectory must exist before the files below can be opened.
    output_dir = RESULT_PATH + f'pretrain_perturbed{args.number}/'
    os.makedirs(output_dir, exist_ok=True)

    with open(output_dir + entries_file, 'w', encoding='utf-8') as f:
        for entry in upsample_biographical_entries:
            f.write(f"{entry}\n")

    with open(output_dir + people_file, 'w', encoding='utf-8') as f:
        json.dump(pretrain_upsample_individuals, f, ensure_ascii=False, indent=4)

    # Same list object, shuffled in place a second time: the flat file below
    # holds the same entries as the one above, in a different order.
    biographical_entries = upsample_biographical_entries
    random.shuffle(biographical_entries)

    with open(RESULT_PATH + f'pretrain_perturbed{args.number}.txt', 'w', encoding='utf-8') as f:
        f.write('\n'.join(biographical_entries))

    # Disabled post-processing steps (checksum, mixing, binary conversion):
    #md5_hash = calculate_md5(RESULT_PATH + 'pretrain_perturbed.txt')
    # write md5 hash to file
    #with open(RESULT_PATH + 'pretrain_perturbed.md5', 'w', encoding='utf-8') as f:
    #    f.write(md5_hash)

    #simple_mix_data(os.path.join(RESULT_PATH, 'pretrain_perturbed.txt'), os.path.join(RESULT_PATH, 'SFT_mix_pretraining.txt'), 1, 1, os.path.join(RESULT_PATH, 'pretrain_perturbed_mixed.txt'))
    #simple_mix_data(os.path.join(RESULT_PATH, 'SFT.txt'), os.path.join(RESULT_PATH, 'SFT_unknown.txt'), 1, 1, os.path.join(RESULT_PATH, 'SFT_mix_unknown.txt'))
    #simple_mix_data(os.path.join(RESULT_PATH, 'SFT.txt'), os.path.join(RESULT_PATH, 'SFT_unknown_refused.txt'), 1, 1, os.path.join(RESULT_PATH, 'SFT_mix_unknown_refused.txt'))

    #for file in ['pretrain_perturbed_mixed', 'SFT_mix_unknown', 'SFT_mix_unknown_refused', 'SFT']:
    #   src_file = os.path.join(RESULT_PATH, f"{file}.txt")
    #    dst_folder = os.path.join(RESULT_PATH, file)
    #    os.makedirs(dst_folder, exist_ok=True)
    #    convert_binary(src_file, dst_folder, align_length=512)

# Run the generation pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()
