import numpy as np
import pandas as pd
import os
from datetime import datetime
from utils import change_box_size, convert_to_image
import segno
from segno.encoder import DataOverflowError

# ---- QR generation settings -------------------------------------------------
version = 3              # fixed QR version passed to segno.make
error_correction = 'L'   # error-correction level passed to segno.make
box_size = 1             # when > 1, the flat bit array goes through change_box_size
border = 0               # not applied to the matrix; written to setting.txt
mask_pattern = 0         # fixed mask pattern passed to segno.make

# The directory name encodes the settings so different runs don't overwrite
# each other.
output_dir = (
    f"./dataset_segno/ver{version}/"
    f"data_domain_ver{version}_mask{mask_pattern}_{error_correction}_sample"
)
# output_dir = os.path.join(output_dir, f"alphabet")

test_data_num = 1000     # number of trailing rows destined for testset.csv
mode = None              # currently unused: segno.make is called with mode='byte'

# Alternative source: the 'domain' column of ./data/top-1m.csv, keeping only
# entries of at most 24 characters.
# df = pd.read_csv('./data/top-1m.csv')
# all_words = df['domain'].values
# all_words = [s for s in all_words if len(s) <= 24]
# all_words = list(all_words)

# Active source: a single domain repeated 1000 times.
all_words = ['aaai.org' for _ in range(1000)]

# Alternative sources (uncommenting these overrides the active list above):
# ./data/alphabet.csv or ./data/random_fake_domains.csv.
# df = pd.read_csv('./data/alphabet.csv')
# # df = pd.read_csv("./data/random_fake_domains.csv")
# all_words = df['domain'].values
# all_words = list(all_words)

num_data = len(all_words)
# Shuffled orderings, currently unused (would also need `import random`):
# word_index = random.sample(range(num_data), num_data)
# random_index = random.sample(range(len(all_words)), num_data)

# Create the output directory on first use and announce it; an existing path
# is left untouched.
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
    print(f"created {output_dir} directory.")


# Parallel per-sample accumulators, appended to in lockstep by the encoding
# loop; each list becomes one column of the output table.
(
    input_texts,
    target_texts,
    qr_mask_pattern_list,
    qr_version_list,
    qr_error_correction_list,
    qr_mode_list,
) = ([] for _ in range(6))

# Encode every string in all_words as a QR code with the fixed settings and
# collect the flattened module matrix (as a '0'/'1' string) together with the
# symbol's metadata. Strings that do not fit are reported and skipped.
for i, content in enumerate(all_words):
    # if len(content) > 14:
    #     continue

    try:
        qr = segno.make(
            content,
            version=version,
            error=error_correction,
            mask=mask_pattern,
            # encoding="utf-8",
            mode='byte',
            # micro=None,
            # Per segno docs, prevents segno from raising the error level
            # above the requested one.
            boost_error=False,
        )
    except DataOverflowError:
        # Content is too long for this version / error level.
        print(content, len(content))
        continue

    # Row-major flattening of the module matrix into a 1-D int array.
    bits = np.array(qr.matrix, dtype=int).flatten()
    if box_size > 1:
        # NOTE(review): assumes change_box_size takes and returns a flat
        # sequence of 0/1 values -- confirm against utils.
        bits = change_box_size(bits, box_size)

    bits = "".join(map(str, bits))

    # Keep only symbols of the requested version so every sample has the
    # same length (presumably always true when a version is given explicitly).
    if qr.version == version:
        input_texts.append(bits)
        target_texts.append(content)
        qr_mask_pattern_list.append(qr.mask)
        qr_version_list.append(qr.version)
        qr_error_correction_list.append(qr.error)
        qr_mode_list.append(qr.mode)
    else:
        print(qr.version, content)

    # if i < 100:
    #     qr.save(f'./{output_dir}/sample/{i}.png', scale=box_size, border=border)

# Assemble the per-sample lists into one table (one row per encoded string).
data = {
    'target': target_texts,
    'input': input_texts,
    'version': qr_version_list,
    'error_correction': qr_error_correction_list,
    'mask_pattern': qr_mask_pattern_list,
    'mode': qr_mode_list,
}
df = pd.DataFrame(data)

# The last `test_data_num` rows form the test set; everything before them is
# the training set. The split point is clamped to [0, len(df)] because plain
# negative slicing breaks at the edges: df[:-0] is empty, so a test size of 0
# would route every row to the test set, and a test size larger than the
# number of surviving rows would yield a negative train count.
num_test = min(max(test_data_num, 0), len(df))
num_train = len(df) - num_test
df.iloc[:num_train].to_csv(f'{output_dir}/trainset.csv', index=False)
df.iloc[num_train:].to_csv(f'{output_dir}/testset.csv', index=False)

# Record the generation settings next to the CSVs. The counts are the row
# counts actually written above, not the requested test size.
with open(f'{output_dir}/setting.txt', 'w', encoding='utf-8') as f:
    f.write(f'date and time: {datetime.now()}\n')
    f.write(f'num_all_data: {len(df)}\n')
    f.write(f'num_train_data: {num_train}\n')
    f.write(f'num_test_data: {num_test}\n')
    f.write(f'version: {version}\n')
    f.write(f'error_correction: {error_correction}\n')
    f.write(f'box_size: {box_size}\n')
    f.write(f'border: {border}\n')
    f.write(f'mask_pattern: {mask_pattern}\n')
    f.write(f'output_dir: {output_dir}\n')
    f.write('library: segno\n')