import pandas as pd
import numpy as np
from utils import flip_bits, highlight, burst_box, burst_horizontal, burst_vertical
from qr_pattern import version_info
import os
from datetime import datetime

# Dataset sub-directories to process; each one is expected to contain a
# ``testset.csv`` file.
dataset_list = ['english', 'german', 'swahili', 'shuffled_english', 'alphabetic', 'missspelled', 'leetspeak']

# Settings that are identical for every dataset, defined once up front.
qr_version = 3  # argument passed to version_info() for every dataset
num_flip = list(range(1, 21))  # values forwarded to damage_flip(): 1..20
num_burst_box = list(range(1, 11))  # values forwarded to damage_burst_box(): 1..10
burst_box_size = [(3, 3)]  # box sizes forwarded to damage_burst_box()
damage_type = ['flip', 'burst_box']  # damage families generated by the driver
test_data_num = 5000  # the last `test_data_num` rows of the CSV are used

# The loop variable has a descriptive name so the short-lived counters used
# further down in the loop body cannot clobber the dataset name.
for dataset_name in dataset_list:
    input_dir = f"dataset_segno/ver3/data_domain_ver3_mask0_L_generalization/{dataset_name}"
    output_dir = os.path.join(input_dir, "damaged/")
    input_file = os.path.join(input_dir, "testset.csv")

    # `version` is used by the driver (convert_1d_to_2d) and `position_1d`
    # is handed to every damage helper as its position argument.
    version = version_info(qr_version)
    position_1d = version.pattern_1d

    def damage_flip(testset, num_flip):
        """Add flip-damaged inputs for ``num_flip`` to the global ``output_df``.

        ``flip_bits(x, position_1d, num_flip)`` is called for every value in
        ``testset["input"]``. Element 0 of each result is stored in column
        ``flip_<num_flip>`` and element 1 in ``flip_<num_flip>_position``.

        Args:
            testset: DataFrame with an ``input`` column. It is only read.
            num_flip: Value forwarded to ``flip_bits`` (presumably the number
                of bits to flip -- confirm against ``utils.flip_bits``).

        Returns:
            None. The result is written to the module-level ``output_df``,
            which must exist before this function is called.
        """
        column_name = f"flip_{num_flip}"
        # The input column is only read here, so no defensive copy of the
        # DataFrame is needed.
        flipped_data = testset["input"].apply(lambda x: flip_bits(x, position_1d, num_flip))
        output_df[column_name] = flipped_data.apply(lambda x: x[0])
        output_df[column_name + "_position"] = flipped_data.apply(lambda x: x[1])

    def damage_burst_box(testset, num_burst, burst_size):
        """Add box-burst damaged inputs to the global ``output_df``.

        For each burst type (``flip``, ``force0``, ``force1``) the function
        calls ``burst_box(x, position_1d, num_burst, burst_size)`` for every
        value in ``testset["input"]`` and picks the entry for that type.
        Element 0 of the entry goes to column
        ``burst_<num_burst>_<h>x<w>_<type>`` and element 1 to the matching
        ``..._position`` column, where ``<h>`` and ``<w>`` are
        ``burst_size[0]`` and ``burst_size[1]``.

        Args:
            testset: DataFrame with an ``input`` column. It is only read.
            num_burst: Value forwarded to ``burst_box``.
            burst_size: Two-element sequence forwarded to ``burst_box`` and
                used in the column names.

        Returns:
            None. The result is written to the module-level ``output_df``,
            which must exist before this function is called.
        """
        column_prefix = f"burst_{num_burst}_{burst_size[0]}x{burst_size[1]}"
        for burst_type in ("flip", "force0", "force1"):
            column_name = f"{column_prefix}_{burst_type}"
            # NOTE(review): burst_box is invoked separately for every burst
            # type, so each type gets its own call result. Confirm this is
            # intended before sharing one call across the three types.
            bursted_data = testset["input"].apply(
                lambda x: burst_box(x, position_1d, num_burst, burst_size)[burst_type]
            )
            output_df[column_name] = bursted_data.apply(lambda x: x[0])
            output_df[column_name + "_position"] = bursted_data.apply(lambda x: x[1])

    def damage_burst_line(testset, num_burst, burst_size):
        """Add line-burst damaged inputs to the global ``output_df``.

        Both directions are generated in one call. For each direction
        (``horizontal``, ``vertical``) and each burst type (``flip``,
        ``force0``, ``force1``) the matching ``burst_horizontal`` or
        ``burst_vertical`` helper is called for every value in
        ``testset["input"]`` and the entry for that burst type is picked.
        Element 0 of the entry goes to column
        ``burst_<direction>_<num_burst>_<burst_size>_<type>`` and element 1
        to the matching ``..._position`` column.

        Args:
            testset: DataFrame with an ``input`` column. It is only read.
            num_burst: Value forwarded to the burst helper.
            burst_size: Value forwarded to the burst helper and used in the
                column names.

        Returns:
            None. The result is written to the module-level ``output_df``,
            which must exist before this function is called.
        """
        for direction in ("horizontal", "vertical"):
            # Pick the helper once per direction so the apply call below is
            # not duplicated in two branches.
            burst_line = burst_horizontal if direction == "horizontal" else burst_vertical
            for burst_type in ("flip", "force0", "force1"):
                column_name = f"burst_{direction}_{num_burst}_{burst_size}_{burst_type}"
                bursted_data = testset["input"].apply(
                    lambda x: burst_line(x, position_1d, num_burst, burst_size)[burst_type]
                )
                output_df[column_name] = bursted_data.apply(lambda x: x[0])
                output_df[column_name + "_position"] = bursted_data.apply(lambda x: x[1])

    if __name__ == "__main__":
        # Driver for one dataset: load the test rows, apply every enabled
        # damage type, save a few highlighted sample images per setting, and
        # write the damaged data plus a settings summary to `output_dir`.
        os.makedirs(output_dir, exist_ok=True)

        df = pd.read_csv(input_file)
        testset = df.iloc[-test_data_num:]

        # `output_df` stays a module-level name: the damage_* helpers add
        # their result columns to it.
        output_df = pd.DataFrame()
        output_df['target'] = testset['target']
        output_df['input'] = testset['input']
        damage_type_list = []

        burst_types = ['flip', 'force0', 'force1']
        # At most 10 sample images per setting; fewer when the test set is
        # smaller, so .iloc never runs past the end.
        sample_count = min(10, len(testset))

        if 'flip' in damage_type:
            os.makedirs(os.path.join(output_dir, "flip_sample"), exist_ok=True)
            for flip_count in num_flip:
                sample_dir = os.path.join(output_dir, "flip_sample", str(flip_count))
                os.makedirs(sample_dir, exist_ok=True)
                damage_flip(testset, flip_count)
                sample_column = f"flip_{flip_count}"
                sample_position_column = f"flip_{flip_count}_position"
                damage_type_list.append(sample_column)
                for row in range(sample_count):
                    # Flip positions are converted with convert_1d_to_2d
                    # before being passed to highlight().
                    hilight_position = version.convert_1d_to_2d(output_df[sample_position_column].iloc[row])
                    hilight_input = output_df[sample_column].iloc[row]
                    hilight_image = highlight(hilight_input, hilight_position)
                    hilight_image.save(os.path.join(sample_dir, f"{row}.png"))

        if 'burst_box' in damage_type:
            os.makedirs(os.path.join(output_dir, "burst_box_sample"), exist_ok=True)
            for burst_count in num_burst_box:
                for box_size in burst_box_size:
                    setting_name = f"{burst_count}_{box_size[0]}x{box_size[1]}"
                    sample_dir = os.path.join(output_dir, "burst_box_sample", setting_name)
                    os.makedirs(sample_dir, exist_ok=True)
                    damage_burst_box(testset, burst_count, box_size)
                    for burst_type in burst_types:
                        sample_column = f"burst_{setting_name}_{burst_type}"
                        sample_position_column = f"{sample_column}_position"
                        damage_type_list.append(sample_column)
                        for row in range(sample_count):
                            hilight_position = output_df[sample_position_column].iloc[row]
                            hilight_input = output_df[sample_column].iloc[row]
                            hilight_image = highlight(hilight_input, hilight_position)
                            # The burst type is part of the file name so the
                            # three types no longer overwrite each other.
                            hilight_image.save(os.path.join(sample_dir, f"{burst_type}_{row}.png"))

        if 'burst_line' in damage_type:
            # NOTE: `num_burst_line` and `burst_line_size` are not defined in
            # the visible part of this script; they must be added to the
            # settings before 'burst_line' is enabled in `damage_type`.
            line_sample_root = os.path.join(output_dir, "burst_line_sample")
            os.makedirs(line_sample_root, exist_ok=True)
            for burst_count in num_burst_line:
                for line_size in burst_line_size:
                    # damage_burst_line() fills the columns for both
                    # directions, so it is called once per setting. Calling
                    # it per direction would regenerate columns whose
                    # samples were already saved.
                    damage_burst_line(testset, burst_count, line_size)
                    for direction in ["horizontal", "vertical"]:
                        sample_dir = os.path.join(line_sample_root, direction, f"{burst_count}_{line_size}")
                        os.makedirs(sample_dir, exist_ok=True)
                        for burst_type in burst_types:
                            sample_column = f"burst_{direction}_{burst_count}_{line_size}_{burst_type}"
                            sample_position_column = f"{sample_column}_position"
                            damage_type_list.append(sample_column)
                            for row in range(sample_count):
                                hilight_position = output_df[sample_position_column].iloc[row]
                                hilight_input = output_df[sample_column].iloc[row]
                                hilight_image = highlight(hilight_input, hilight_position)
                                hilight_image.save(os.path.join(sample_dir, f"{burst_type}_{row}.png"))

        output_df.to_csv(os.path.join(output_dir, "dataset_damaged.csv"), index=False)

        # Record the run configuration next to the generated data.
        with open(os.path.join(output_dir, "setting.txt"), 'w', encoding="utf-8") as f:
            f.write(f"date and time: {datetime.now()}\n")
            f.write(f"num_flip: {num_flip}\n")
            f.write(f"damage_type: {damage_type}\n")
            f.write(f"input_file: {input_file}\n")
            f.write(f"output_dir: {output_dir}\n")
            f.write(f"data_num: {len(df)}\n")
            f.write(f"damage_type_list: {damage_type_list}\n")
