import os
import numpy as np
import multiprocessing as mp
import tiktoken

def tokenize(s):
    """Encode *s* with the GPT-2 BPE, prefixed by the end-of-text token.

    Args:
        s: Text to encode with ``encode_ordinary``.

    Returns:
        A 1-D ``np.uint16`` array equal to ``[50256, *encode_ordinary(s)]``;
        the first element is therefore always the end-of-text marker.
    """
    encoder = tiktoken.get_encoding("gpt2")
    end_of_text = 50256  # GPT-2 "<|endoftext|>" id, used here as a leading marker
    ids = np.array([end_of_text, *encoder.encode_ordinary(s)])
    # Every id must fit in uint16 before the narrowing cast below.
    fits = (ids >= 0).all() and (ids < 2**16).all()
    assert fits, "token dictionary too large for uint16"
    return ids.astype(np.uint16)

def get_average_token_length(file_path="data/cities.txt"):
    """Return the largest GPT-2 token count of any line in a text file.

    Despite the name, this is the **maximum** (not the mean) per-line token
    count; the name is kept for existing callers.

    Each line has its trailing newline stripped before tokenization, so the
    line terminator is never counted, and the leading end-of-text marker
    added by ``tokenize`` is excluded by the shared helper. (The previous
    ``len(tokens) - 2`` assumed every line ended in exactly one newline
    token, which undercounts a final line that has no trailing newline.)

    Args:
        file_path: Path to a UTF-8 text file, one value per line.

    Returns:
        The maximum token count as an ``int``, or ``0.0`` (after printing a
        message) when the file contains no lines.

    Raises:
        OSError: if *file_path* cannot be opened (e.g. FileNotFoundError).
    """
    with open(file_path, "r", encoding="utf-8") as f:
        # Text mode uses universal newlines, so every terminator arrives as
        # "\n"; stripping it leaves only the line's own content.
        lines = [line.rstrip("\n") for line in f]

    if not lines:
        print("File is empty, no data to process!")
        return 0.0

    # Same measurement as for dates/pronouns: max of (len(tokens) - 1).
    return get_average_dates_or_pronouns_length(lines)

def generate_dates():
    """Return every "Month D, YYYY" string for days 1-28 and years 1900-2099.

    Ordering is month-major, then day, then year (year varies fastest),
    yielding 12 * 28 * 200 = 67200 strings such as "January 1, 1900".
    Days 29-31 are never produced (presumably so every combination is a real
    calendar date).
    """
    month_names = (
        "January", "February", "March", "April", "May", "June",
        "July", "August", "September", "October", "November", "December",
    )
    return [
        f"{month} {day}, {year}"
        for month in month_names
        for day in range(1, 29)        # 1-28
        for year in range(1900, 2100)  # 1900-2099
    ]

def get_average_dates_or_pronouns_length(data):
    """Return the largest GPT-2 token count among the strings in *data*.

    Despite the name, this is the **maximum** (not the mean) token count;
    the name is kept for existing callers. Each string is tokenized with
    ``tokenize`` in a process pool, and the leading end-of-text marker that
    ``tokenize`` prepends is not counted.

    Args:
        data: Iterable of strings to measure.

    Returns:
        The maximum token count as an ``int``, or ``0.0`` (after printing
        "No tokens processed!") when *data* yields no items.
    """
    # os.cpu_count() may return None when the core count is undeterminable;
    # fall back to 1 so the subtraction below cannot raise TypeError.
    nprocs = max(1, (os.cpu_count() or 1) - 2)

    longest = 0
    count = 0
    with mp.Pool(nprocs) as pool:
        for tokens in pool.imap(tokenize, data, chunksize=nprocs):
            # Subtract 1 to exclude the leading end-of-text marker.
            longest = max(longest, len(tokens) - 1)
            count += 1

    if count == 0:
        print("No tokens processed!")
        return 0.0
    return longest
def get_params():
    """Build the per-attribute parameter table consumed by ``main``.

    Returns a list of six dicts, one per attribute, in this order: birth
    date, city, university, major, employer, pronoun. Each dict has keys:

    * ``"D"``: number of distinct values the attribute can take.
    * ``"C"``: hard-coded per-attribute constant. The original comments
      label it "Lens"; presumably it is the number of chunks per value.
      TODO: confirm against the capacity formula in ``main``.
    * ``"L"``: largest GPT-2 token length of any value, measured from the
      generated dates, the files under ``data/``, or the pronoun groups.
    * ``"K"``: multiplier, always 1 here.

    Side effects: reads ``data/cities.txt``, ``data/universities.txt``,
    ``data/majors.txt`` and ``data/companies.csv`` (paths are relative to the
    current working directory), and tokenizes via process pools.
    """
    # Birth dates: derive the diversity from the generated list so it can
    # never drift from generate_dates() (12 months x 28 days x 200 years).
    dates = generate_dates()
    D_birth_day_date = len(dates)
    C_birth_day_date = 3
    L_birth_day_date = get_average_dates_or_pronouns_length(dates)

    D_cities = 236  # number of distinct cities
    C_cities = 2
    L_cities = get_average_token_length("data/cities.txt")

    D_universities = 300  # number of distinct universities
    C_universities = 1
    L_universities = get_average_token_length("data/universities.txt")

    D_majors = 100  # number of distinct majors
    C_majors = 1
    L_majors = get_average_token_length("data/majors.txt")

    D_employer = 500  # number of distinct employers (companies)
    C_employer = 2
    L_employer = get_average_token_length("data/companies.csv")

    # Pronouns: three groups; each group is measured as one space-joined string.
    D_pronoun = 3
    C_pronoun = 3
    pronouns = [
        ["he", "his", "him", "himself"],
        ["she", "her", "hers", "herself"],
        ["they", "them", "their", "theirs"]
    ]
    combined_strings = [" ".join(group) for group in pronouns]
    L_pronoun = get_average_dates_or_pronouns_length(combined_strings)

    params = [
        {"D": D_birth_day_date, "C": C_birth_day_date, "L": L_birth_day_date, "K": 1},
        {"D": D_cities, "C": C_cities, "L": L_cities, "K": 1},
        {"D": D_universities, "C": C_universities, "L": L_universities, "K": 1},
        {"D": D_majors, "C": C_majors, "L": L_majors, "K": 1},
        {"D": D_employer, "C": C_employer, "L": L_employer, "K": 1},
        {"D": D_pronoun, "C": C_pronoun, "L": L_pronoun, "K": 1},
    ]
    return params
    
def main():
    """Estimate how many individuals (N) a model of the given size can store.

    The function proceeds in four steps:

    1. Print the fixed configuration.
    2. Build the attribute table via ``get_params``.
    3. Bisect over N in [1, N0] on

           f(N) = N*log2(N0/N) + N*sum(K*C*log2 D) - remaining_bits

       where ``remaining_bits = bits_per_parameter * params
       - sum(K*D*log2(T**L / D))``.
    4. Print the resulting N and the corresponding total bit count.
    """
    import math

    bits_per_parameter = 2  # Information content of 1 parameter is 2 bits
    N0 = 400 * 400 * 1000  # Total population
    K = 6  # Number of categories (printed only; per-attribute K comes from params)
    model_size = 25.17  # Model size in millions of parameters (25.17M)

    print(f"Information content per parameter: {bits_per_parameter} bits")
    print(f"Total population (N0): {N0}")
    print(f"Number of categories (K): {K}")
    print(f"Number of parameters: {model_size:.2f}M")

    params = get_params()

    T = 50257  # Token dictionary size
    log2_T = math.log2(T)
    nkc_log2_d_coeff = sum(p["K"] * p["C"] * math.log2(p["D"]) for p in params)
    # log2(T**L / D) == L*log2(T) - log2(D). The log form avoids converting
    # the huge integer quotient T**L / D to float, which raises OverflowError
    # once L is large (T**L exceeds the float range for L >= ~66).
    kd_log2_tld = sum(
        p["K"] * p["D"] * (p["L"] * log2_T - math.log2(p["D"])) for p in params
    )
    remaining_bits = bits_per_parameter * model_size * 1e6 - kd_log2_tld

    def solve_n(N):
        """Return f(N); a positive value means N individuals need too many bits."""
        return N * math.log2(N0 / N) + nkc_log2_d_coeff * N - remaining_bits

    def find_n(low, high):
        """Bisect integers in [low, high]; return the last N seen with f(N) <= 0.

        If f(low) > 0 already, ``low`` is still returned (no feasible N).
        """
        while high - low > 1:
            mid = (low + high) // 2
            if solve_n(mid) > 0:
                high = mid
            else:
                low = mid
        return low

    N = find_n(1, N0)
    print(f"Estimated N_bound: {N}")
    print(f"Total bits: {N * math.log2(N0 / N) + nkc_log2_d_coeff * N + kd_log2_tld:.2f}")

# Script entry point. The guard also matters for multiprocessing: pool
# workers that re-import this module must not re-run main().
if __name__ == "__main__":
    main()
    