import json
import os
import random
from math import ceil

import pandas
import pandas as pd
from datasets import DatasetDict, Dataset
from sklearn.model_selection import train_test_split

from data.generate_formula_ir import generate
from data.pretraining import split_data
from tools import remove_suffix


# Training-step budget; merge_named_formulas converts it into an epoch count for a
# given number of positive examples.
steps = 250_000

def merge_csv_files_in_folder(folder_path, version, output_file='merged_data.csv', drop_duplicates=True):
    """
    Merge all ``;;``-separated CSV files of one version in a folder into a single file.

    Parameters:
        folder_path (str): The path to the folder containing CSV files.
        version (str): Substring a CSV file name must contain to be merged.
        output_file (str): Path of the merged output file (written with ``;;`` as separator).
        drop_duplicates (bool): If True, drop rows with duplicate ``('id', 'text')`` pairs
            and shuffle the result.

    Rows whose ``label`` (if present) is not 0 or 1 are removed before writing.

    Returns:
        None

    Raises:
        ValueError: If no matching CSV file exists in ``folder_path``.
    """
    separator = ';;'

    # Sorted so the merge order (and thus which duplicate is kept) does not depend on
    # the file system's listing order.
    csv_files = sorted(
        f for f in os.listdir(folder_path) if f.endswith('.csv') and version in f
    )
    if not csv_files:
        raise ValueError("No CSV files containing %r found in %s" % (version, folder_path))

    frames = []
    for csv_file in csv_files:
        print("Process %s" % csv_file)
        file_path = os.path.join(folder_path, csv_file)
        # A multi-character separator requires the python engine; malformed lines are skipped.
        frames.append(pd.read_csv(file_path, delimiter=separator, engine='python', on_bad_lines='skip'))
    # Concatenate once instead of growing a DataFrame inside the loop (quadratic copying).
    merged_data = pd.concat(frames, ignore_index=True)

    if drop_duplicates:
        total_len = len(merged_data)
        merged_data = merged_data.drop_duplicates(subset=['id', 'text']).sample(frac=1)
        reduced_len = len(merged_data)
        print("Reduced data from %d to %d due to duplicate removal" % (total_len, reduced_len))

    if 'label' in merged_data:
        merged_data = merged_data[merged_data['label'].isin([0, 1])]

    # Save the DataFrame with ";;" as the separator (pandas' to_csv only supports one-char separators).
    with open(output_file, 'w', encoding='utf8') as file:
        file.write(separator.join(map(str, merged_data.columns)) + '\n')
        # itertuples keeps per-column dtypes, unlike iterrows which may upcast ints to floats.
        for row in merged_data.itertuples(index=False, name=None):
            file.write(separator.join(map(str, row)) + '\n')

    summary = 'Merged %d CSV files into %s with %d entries' % (len(csv_files), output_file, len(merged_data))
    if 'post_id' in merged_data:
        summary += ' with %d post ids' % merged_data['post_id'].nunique()
    print(summary)


def merge_amps_arqmath(amps, arqmath, output_file, remove_duplicates=False):
    """
    Merge the AMPS and ARQMath ``;;``-separated CSV files, clean them and split the result.

    The ``uuid`` column is dropped. Rows whose ``text`` contains a run of 18 zeros or is
    missing are removed. The merged data is written to ``output_file`` with ``;;`` as the
    separator. ``split_data`` is then called on it with proportions ``(0.9, 0.1)``.

    Parameters:
        amps (str): Path to the AMPS CSV file.
        arqmath (str): Path to the ARQMath CSV file.
        output_file (str): Path of the merged output file.
        remove_duplicates (bool): If True, drop rows with duplicate ``('id', 'text')`` pairs
            and shuffle the result.
    """
    separator = ';;'

    # Read both files and concatenate once (instead of growing a DataFrame in a loop).
    frames = [pd.read_csv(csv_file, delimiter=separator, engine='python') for csv_file in (amps, arqmath)]
    merged_data = pd.concat(frames, ignore_index=True)
    merged_data = merged_data.drop('uuid', axis=1)

    if remove_duplicates:
        total_len = len(merged_data)
        merged_data = merged_data.drop_duplicates(subset=['id', 'text']).sample(frac=1)
        reduced_len = len(merged_data)
        if reduced_len < total_len:
            print("Reduced data from %d to %d due to duplicate removal" % (total_len, reduced_len))
        else:
            print("Total data len %d" % total_len)

    # Plain substring match (regex=False). Missing texts count as matches (na=True) and are
    # dropped as well; without an explicit ``na`` the mask would contain NaN and ``~`` would fail.
    contains_zero_run = merged_data['text'].str.contains('000000000000000000', regex=False, na=True)
    merged_data = merged_data[~contains_zero_run]

    # Save the DataFrame with ";;" as the separator (pandas' to_csv only supports one-char separators).
    with open(output_file, 'w', encoding='utf8') as file:
        file.write(separator.join(map(str, merged_data.columns)) + '\n')
        # itertuples keeps per-column dtypes, unlike iterrows which may upcast ints to floats.
        for row in merged_data.itertuples(index=False, name=None):
            file.write(separator.join(map(str, row)) + '\n')

    print(f'Merged amps + arqmath files into {output_file} with {len(merged_data)} entries')

    # create train test split
    split_data(output_file, proportions=(0.9, 0.1))
    print("Created train test split")

def _print_substitution_stats(frame):
    """Print absolute and relative counts of the substituted_var / substituted_fnc / substituted flags."""
    total = len(frame)
    ssv = sum(frame['substituted_var'])
    ssf = sum(frame['substituted_fnc'])
    sst = sum(frame['substituted'])
    print("Substituted %d times variables (%d%%)" % (ssv, 100 * ssv / total))
    print("Substituted %d times functions (%d%%)" % (ssf, 100 * ssf / total))
    print("Substituted %d times (%d%%)" % (sst, 100 * sst / total))


def _add_random_formula_negatives(df):
    """Add negatives pairing names with true formulas of other ids if <20% of rows use the random-formula strategy."""
    number_of_random_formulas = len(df[df['strategy_random_formula']])
    ratio = number_of_random_formulas / len(df)
    if ratio >= 0.2:
        return df

    # simply combine random pairs
    factor = 0.172 - ratio
    print("Adjust number of random formulas by %d%%" % (factor * 100))
    columns = ('name', 'formula', 'formula_name_id', 'is_text', 'substituted_var', 'substituted_fnc', 'substituted')
    copied_columns = ('formula', 'is_text', 'substituted_var', 'substituted_fnc', 'substituted')
    additional = {column: [] for column in columns}
    # this is super inefficient code, however, it only needs to be run once
    for formula_name_id, group in df.groupby('formula_name_id'):
        print(formula_name_id)
        names = list(set(group[group['label']]['name']))
        # True formulas of *other* ids become negatives for a name of this id.
        false_candidates = df[df['label'] & (df['formula_name_id'] != formula_name_id)]

        # factor <= 0 when the random share already lies in [17.2%, 20%): never request a
        # negative sample or more rows than exist, and skip ids without any true name.
        amount = min(max(int(len(group) * factor), 0), len(false_candidates))
        if amount == 0 or not names:
            continue
        for _, random_false in false_candidates.sample(n=amount).iterrows():
            additional['name'].append(random.choice(names))
            additional['formula_name_id'].append(formula_name_id)
            for column in copied_columns:
                additional[column].append(random_false[column])

    false_df = pd.DataFrame.from_dict(additional)
    false_df['label'] = False
    false_df['strategy_random_formula'] = True
    for column in ('strategy_equality', 'strategy_inequality', 'strategy_swap', 'strategy_variables',
                   'strategy_manual', 'strategy_constants', 'strategy_distribute'):
        false_df[column] = False
    false_df['original_id'] = None
    false_df['stats'] = "{}"
    false_df['id'] = None
    print("Added %d negative entries with random formulas" % len(false_df))

    return pd.concat([df, false_df]).sample(frac=1).reset_index(drop=True)


def merge_named_formulas(files, output='../../final/data/nfir.csv', remove_duplicates=True, save_temp=True, max_trues_per_id=None):
    """
    Build the named-formula dataset from generated CSV files and save it.

    1. Load ``files``. A single CSV file path is read as-is. A directory (all regular files
       in it) or a list of CSV files is instead processed as follows:
       - the frames are concatenated;
       - `` Longrightarrow`` / `` rightarrow`` get their LaTeX backslash back;
       - duplicates are optionally dropped;
       - substitution and strategy flags are derived from ``stats`` if missing;
       - extra random-formula negatives are added if fewer than 20% of rows use that strategy.
    2. Print substitution and strategy statistics.
    3. Cap the positives per ``formula_name_id`` at ``factor * max(min positives per id, 10000)``.
       ``factor`` is the first value in 1..10 for which ``steps`` corresponds to fewer than
       5 epochs. Then keep at most ``epochs * positives`` negatives per id.
    4. Write ``<output>_true_amounts.csv``, ``<output>_full.csv`` and ``output``.
       ``output`` keeps only the name, formula, formula_name_id and label columns.

    Parameters:
        files (str | list[str]): CSV file, directory of CSV files, or list of CSV files.
        output (str): Path of the reduced output CSV.
        remove_duplicates (bool): Drop duplicate (formula_name_id, name, formula) rows.
        save_temp (bool): For directory/list input, also save the merged frame to
            '../../data/generated/formula-naming/temp.csv'.
        max_trues_per_id: Unused; kept for backward compatibility.

    Returns:
        int: ``steps`` expressed in epochs (rounded up) for the final number of positives.
    """

    def epochs_for(num_trues):
        # NOTE(review): 2x, 0.9 and 16 * 8 presumably mirror the training setup (negatives
        # per positive, train share, effective batch size) -- TODO confirm.
        return steps / ((2 * num_trues * 0.9) / (16 * 8))

    if isinstance(files, str):
        if os.path.isdir(files):
            directory = files
            files = [os.path.join(directory, item) for item in os.listdir(directory)
                     if os.path.isfile(os.path.join(directory, item))]
        else:
            df = pd.read_csv(files)

    if isinstance(files, list):
        frames = []
        for file in files:
            print(file)
            frame = pd.read_csv(file)
            # Both are substring replacements (Series.str.replace); plain Series.replace would
            # only replace cells equal to the whole pattern.
            frame['formula'] = (
                frame['formula']
                .str.replace(' Longrightarrow', r' \Longrightarrow', regex=False)
                .str.replace(' rightarrow', r' \rightarrow', regex=False)
            )
            frames.append(frame)

        print("Concat %d DataFrames" % len(frames))
        df = pd.concat(frames)
        if remove_duplicates:
            print("Remove Duplicates")
            total_len = len(df)
            df = df.drop_duplicates(subset=['formula_name_id', 'name', 'formula']).sample(frac=1)
            reduced_len = len(df)
            if reduced_len < total_len:
                print("Reduced data from %d to %d due to duplicate removal" % (total_len, reduced_len))
            else:
                print("Total data len %d" % total_len)

        if 'substituted_var' not in df:
            print("Calculate Substituted stats")
            # 'stats' appears to be the str() of a dict; an empty substitution dict means no substitution.
            df['substituted_var'] = df['stats'].apply(lambda x: "substitution_var': {}" not in x)
            df['substituted_fnc'] = df['stats'].apply(lambda x: "substitution_fnc': {}" not in x)
            df['substituted'] = df['substituted_var'] | df['substituted_fnc']

            print("Calculate Strategy Manual/ Random")
            df['strategy_manual'] = df['strategy_random_formula'] & df['stats'].apply(lambda x: "no_version': True" in x and "strategy_random_formula': {'old" in x)
            df['strategy_random_formula'] = df['strategy_random_formula'] & ~df['strategy_manual']

        # at least 20% of the formulas should be true random formulas
        df = _add_random_formula_negatives(df)

        if save_temp:
            df.to_csv('../../data/generated/formula-naming/temp.csv')

    _print_substitution_stats(df)
    true_df = df[df['label']]
    print("Stats only for positive entries")
    _print_substitution_stats(true_df)

    total = len(df)
    for column in df:
        if column.startswith('strategy'):
            s = sum(df[column])
            print("%s: %d entries (%f)" % (column, s, s / total))

    trues = len(true_df)
    falses = len(df) - trues
    print("Number of trues %d, of falses %d, factor false %f" % (trues, falses, falses / trues))
    print("Number of epochs for %d steps: %f" % (steps, epochs_for(trues)))

    # truncate formula_name_ids with too many entries
    print("Truncate formula_name_ids with too many entries")
    print(df['formula_name_id'].value_counts())

    true_value_counts = true_df['formula_name_id'].value_counts()
    print(true_value_counts)

    # Grow the per-id cap until `steps` corresponds to fewer than 5 epochs (or factor 10 is reached).
    trues_dict = {}
    min_count = max(min(true_value_counts), 10000)
    for factor in range(1, 11):
        max_count = factor * min_count
        for formula_name_id, group in df.groupby('formula_name_id'):
            group_trues = group[group['label']]
            if len(group_trues) > max_count:
                group_trues = group_trues.sample(n=max_count)
            trues_dict[formula_name_id] = group_trues

        trues = sum(len(d) for d in trues_dict.values())
        epochs = epochs_for(trues)
        print("Number of epochs for %d steps for factor %d with reduced trues: %f" % (steps, factor, epochs))
        epochs_int = ceil(epochs)

        if epochs_int < 5:
            print("Stop with factor min-max-trues = %d and epochs = %d" % (factor, epochs_int))
            break

    # Keep at most epochs_int negatives per kept positive for every id.
    reduced_data = []
    for formula_name_id, group in df.groupby('formula_name_id'):
        group_falses = group[~group['label']]
        group_trues = trues_dict[formula_name_id]
        n = epochs_int * len(group_trues)
        if len(group_falses) > n:
            group_falses = group_falses.sample(n=n)
        reduced_data.append(pd.concat([group_trues, group_falses]))

    df = pd.concat(reduced_data).sample(frac=1).reset_index(drop=True)

    true_df = df[df['label']]
    trues = len(true_df)
    falses = len(df) - trues
    print("Number of trues %d, of falses %d, factor false %f" % (trues, falses, falses / trues))
    epochs = epochs_for(trues)
    print("Number of epochs for %d steps: %f" % (steps, epochs))

    epochs_int = ceil(epochs)

    print(df['formula_name_id'].value_counts())

    true_value_counts = true_df['formula_name_id'].value_counts()
    print(true_value_counts)

    true_value_counts.to_csv(remove_suffix(output, '.csv') + '_true_amounts.csv', encoding='utf8')

    print("Save df")
    print(df)
    df.to_csv(remove_suffix(output, '.csv') + '_full.csv', encoding='utf8')

    df = df[['name', 'formula', 'formula_name_id', 'label']]
    df.to_csv(output, encoding='utf8')
    print("Done")

    return epochs_int

def _truncate_to_batch_multiple(split_df, batch_size):
    """Drop random positive rows so that the split length becomes a multiple of ``batch_size``.

    Only rows with ``label`` True are removed, so the false/true ratio never decreases.
    The result is shuffled and re-indexed.

    Raises:
        ValueError: If the negatives alone exceed the truncated length.
    """
    new_len = (len(split_df) // batch_size) * batch_size
    falses = split_df[~split_df['label']]
    n_trues = new_len - len(falses)
    if n_trues < 0:
        raise ValueError(
            "Cannot truncate split to %d rows by removing positives only: it has %d negatives"
            % (new_len, len(falses))
        )
    trues = split_df[split_df['label']].sample(n=n_trues)
    return pd.concat([falses, trues]).sample(frac=1).reset_index(drop=True)


def create_split(txt_file, delimiter=';;', reduce=('name', 'formula', 'formula_name_id', 'label'), batch_size=16):
    """
    Create a stratified 90/10 train/test split of a CSV file and save it as a ``DatasetDict``.

    Both splits are shrunk to a multiple of ``batch_size`` by removing random positive rows.
    The full split is saved next to ``txt_file``, named after it without ``.txt``/``.csv``.
    If ``reduce`` is non-empty, a second dataset with only those columns is saved under the
    same name with ``_full`` removed. When the name contains no ``_full``, this second save
    targets the same directory as the first.

    Parameters:
        txt_file (str): Path to the input file; its ``label`` column is used as a boolean mask.
        delimiter (str): Field separator of ``txt_file``.
        reduce (Sequence[str] | None): Columns of the reduced dataset; falsy to skip it.
        batch_size (int): Split lengths are truncated to a multiple of this value.

    Raises:
        ValueError: If a split has too few positives to be truncated by removing positives only.
    """
    directory, file_name = os.path.split(txt_file)
    name = remove_suffix(remove_suffix(file_name, '.txt'), '.csv')
    # os.path.join instead of '/'.join: a bare file name must not turn into '/<name>'.
    path = os.path.join(directory, name)

    df = pandas.read_csv(txt_file, delimiter=delimiter, encoding='utf8')

    train_df, test_df = train_test_split(df, test_size=0.1, random_state=42, stratify=df['label'])

    # Reduce each split length to a multiple of the batch size by removing trues only, so the
    # false factor remains at least as high as before (important for the pretraining splits).
    train_df = _truncate_to_batch_multiple(train_df, batch_size)
    test_df = _truncate_to_batch_multiple(test_df, batch_size)

    print(len(train_df))
    print(len(test_df))
    print("Factor train: %s" % (len(train_df[~train_df['label']]) / sum(train_df['label'])))
    print("Factor test: %s" % (len(test_df[~test_df['label']]) / sum(test_df['label'])))

    dsd = DatasetDict()
    dsd['train'] = Dataset.from_pandas(train_df)
    dsd['test'] = Dataset.from_pandas(test_df)
    dsd.save_to_disk(path)

    if reduce:
        # list(): indexing a DataFrame with a tuple would be read as a single column key.
        columns = list(reduce)
        reduced_dsd = DatasetDict()
        reduced_dsd['train'] = Dataset.from_pandas(train_df[columns])
        reduced_dsd['test'] = Dataset.from_pandas(test_df[columns])
        # Only the file-derived name is rewritten, never the directory part.
        reduced_dsd.save_to_disk(os.path.join(directory, name.replace('_full', '')))

    print("Done")

def split_arqmath(txt_file, proportions=(0.7, 0.15, 0.15), delimiter=';;', max=None):
    """
    Split a CSV file into stratified train/validation/test sets and save them as a ``DatasetDict``.

    Parameters:
        txt_file (str): Path to the input file; its ``label`` column is used for stratification.
        proportions (tuple[float, float, float]): Shares of all rows for train and validation.
            The test split receives the remaining rows, so only the first two values are used.
        delimiter (str): Field separator of ``txt_file``.
        max (int | None): If given, randomly sample at most this many rows first. The value is
            also appended to the output directory name.

    Returns:
        DatasetDict: Shuffled ``train``, ``test`` and ``validation`` splits. They are also saved to
        ``<directory of txt_file>/<file name without .txt/.csv>[_<max>]``.
    """
    # NOTE: ``max`` shadows the builtin inside this function; the name is kept for callers.
    df = pandas.read_csv(txt_file, delimiter=delimiter, engine='python')

    if max is not None and len(df) > max:
        df = df.sample(n=max)

    # Calculate the split sizes (absolute row counts) based on the proportions
    total_size = len(df)
    split_sizes = [int(prop * total_size) for prop in proportions]

    # Split off train, then split the rest into validation and test
    train_split, temp_split = train_test_split(df, train_size=split_sizes[0], stratify=df['label'])
    valid_split, test_split = train_test_split(temp_split, train_size=split_sizes[1], stratify=temp_split['label'])

    splitted = DatasetDict()
    splitted['train'] = Dataset.from_pandas(train_split).shuffle()
    splitted['test'] = Dataset.from_pandas(test_split).shuffle()
    splitted['validation'] = Dataset.from_pandas(valid_split).shuffle()

    directory, file_name = os.path.split(txt_file)
    name = remove_suffix(remove_suffix(file_name, '.txt'), '.csv')
    # os.path.join instead of '/'.join: a bare file name must not turn into '/<name>'.
    path = os.path.join(directory, name)
    if max is not None:
        path += '_' + str(max)

    print(splitted)
    splitted.save_to_disk(path)
    return splitted

def _sample_up_to(frame, amount, kind, name_id):
    """Return ``amount`` random rows of ``frame``, or all of them (with a warning) if fewer exist."""
    if amount > len(frame):
        print("WARNING: Try to get %d %s values for %s, but only %d are available" % (amount, kind, name_id, len(frame)))
        return frame
    return frame.sample(n=amount)


def reduce_dataset(input, output, versions_per_id=250, factor_false=10, id='formula_name_id', label='label', count_strategies=True, eval_by_id=False, raw_data='../../data/formula-naming/data.json'):
    """
    Subsample every split of a saved ``DatasetDict`` per id and save the result.

    Let ``p`` be a split's share of all rows. For each split and each value of the ``id``
    column, up to ``int(versions_per_id * p)`` positives and
    ``int(versions_per_id * factor_false * p)`` negatives are sampled. With
    ``versions_per_id=None``, all positives are kept and up to ``factor_false`` times as many
    negatives are sampled. Columns whose name contains "unnamed" are dropped.

    Parameters:
        input (str): Path of the ``DatasetDict`` to reduce.
        output (str): Output path for ``save_to_disk``.
        versions_per_id (int | None): Positive budget per id across all splits.
        factor_false (int): Negatives per positive.
        id (str): Column the rows are grouped by.
        label (str): Boolean label column.
        count_strategies (bool): Add a ``strategy_count`` column (sum of the strategy flags) and
            an ``index`` column (row position). Keys already present in a row take precedence.
        eval_by_id (bool): Move ``test`` to ``validation`` and build a new ``test`` split. For every
            id and every raw version listed in ``raw_data``, it holds up to 100 reduced train rows
            of that id, with ``formula1`` set to that version.
        raw_data (str): JSON file containing a list of entries with ``id`` and ``versions``
            (each version has a ``formula``).
    """
    print(input)
    dsd = DatasetDict.load_from_disk(input)
    reduced_dsd = DatasetDict()

    lens = [len(dsd[key]) for key in dsd]
    total_len = sum(lens)
    # Each split gets a share of the per-id budget proportional to its size.
    proportions = [length / total_len for length in lens]

    for key, proportion in zip(dsd, proportions):
        print(key)
        df = dsd[key].to_pandas()

        # remove Unnamed columns (presumably index columns from earlier CSV round-trips)
        df = df[[c for c in df.columns if 'unnamed' not in c.lower()]]

        print(df.value_counts(id))
        parts = []
        for name_id, group in df.groupby(id):
            trues = group[group[label]]
            falses = group[~group[label]]

            if versions_per_id is None:
                amount_trues = len(trues)
                amount_falses = amount_trues * factor_false
            else:
                amount_trues = int(versions_per_id * proportion)
                amount_falses = int(versions_per_id * factor_false * proportion)

            parts.append(_sample_up_to(trues, amount_trues, 'true', name_id))
            parts.append(_sample_up_to(falses, amount_falses, 'false', name_id))

        # Concatenate once instead of once per id (quadratic copying).
        new_df = pd.concat(parts) if parts else pd.DataFrame()
        reduced_dsd[key] = Dataset.from_pandas(new_df.sample(frac=1))

    if count_strategies:
        strategy_columns = ['strategy_equality', 'strategy_manual', 'strategy_inequality', 'strategy_variables',
                            'strategy_random_formula', 'strategy_constants', 'strategy_distribute', 'strategy_swap']

        def add_strategy_count(example, idx):
            return {
                'strategy_count': sum(example[col] for col in strategy_columns),
                'index': idx,
                **example,
            }

        result = DatasetDict()
        for key in reduced_dsd:
            result[key] = reduced_dsd[key].map(add_strategy_count, with_indices=True)
        reduced_dsd = result

    if eval_by_id:
        df = reduced_dsd['train'].to_pandas()

        with open(raw_data, 'r', encoding='utf8') as raw_file:
            raw_entries = json.load(raw_file)
        raw_by_id = {entry['id']: entry for entry in raw_entries}

        slices = []
        for f_id, group in df.groupby(id):
            for version in raw_by_id[f_id]['versions']:
                rv_slice = group.sample(n=min(100, len(group))).copy()
                rv_slice['formula1'] = version['formula']
                slices.append(rv_slice)
        validation = pd.concat(slices) if slices else pd.DataFrame()

        reduced_columns = [c for c in validation.columns if not any(x in c.lower() for x in ('unnamed', 'index'))]
        validation = validation[reduced_columns]
        reduced_dsd['validation'] = reduced_dsd['test']
        reduced_dsd['test'] = Dataset.from_pandas(validation)

    reduced_dsd.save_to_disk(output)

def complement_datasets(dataset1, dataset2, complement):
    """
    Save the train rows of ``dataset1`` that do not occur in the train split of ``dataset2``.

    Rows are matched on all columns the two train splits share (left merge with
    ``indicator=True``). Columns present only in ``dataset2`` are carried over, empty for the
    complement rows; ``__index_level_0__`` is the exception and is dropped.

    Parameters:
        dataset1 (str): Path of the ``DatasetDict`` to take rows from.
        dataset2 (str): Path of the ``DatasetDict`` whose train rows are excluded.
        complement (str): Output path for ``save_to_disk``. The saved ``train`` split is the
            complement and the ``test`` split is copied from ``dataset1``.
    """
    dsd1 = DatasetDict.load_from_disk(dataset1)
    dsd2 = DatasetDict.load_from_disk(dataset2)

    key = 'train'

    df1 = dsd1[key].to_pandas()
    df2 = dsd2[key].to_pandas()

    # Merge the two DataFrames to find rows in df1 not in df2
    merged_df = df1.merge(df2, how='left', indicator=True)

    # Filter for rows only in df1 (left_only)
    complement_df = merged_df[merged_df['_merge'] == 'left_only']

    # The pandas index column written by Dataset.from_pandas is not always present;
    # drop it when it exists instead of raising KeyError.
    complement_df = complement_df.drop(columns=['_merge', '__index_level_0__'], errors='ignore')

    new_dsd = DatasetDict()
    new_dsd[key] = Dataset.from_pandas(complement_df)
    new_dsd['test'] = dsd1['test']

    new_dsd.save_to_disk(complement)

def reduce_size(file, max_rows=10000):
    """
    Copy the header and the first ``max_rows`` data rows of a CSV file to ``<file>_reduced.csv``.

    Parameters:
        file (str): Path to a comma-separated CSV file whose first row is a header.
            A trailing ``.csv`` is replaced by ``_reduced.csv`` in the output name.
        max_rows (int): Number of data rows (excluding the header) to copy.

    An empty input file produces an empty output file.
    """
    import csv
    from itertools import islice

    output_file = file.removesuffix('.csv') + '_reduced.csv'
    # Explicit UTF-8 (as used elsewhere in this module) instead of the locale-dependent default.
    with open(file, 'r', newline='', encoding='utf8') as input_csv, \
            open(output_file, 'w', newline='', encoding='utf8') as output_csv:
        reader = csv.reader(input_csv)
        writer = csv.writer(output_csv)

        header = next(reader, None)
        if header is None:
            # Empty input: leave the output empty instead of raising StopIteration.
            return
        writer.writerow(header)

        # islice stops after max_rows rows without reading the rest of the file.
        writer.writerows(islice(reader, max_rows))


def analyze(ds_path):
    """
    Print row, positive and negative counts for every split of a saved ``DatasetDict``.

    Parameters:
        ds_path (str): Directory previously written with ``DatasetDict.save_to_disk``.
    """
    dataset_dict = DatasetDict.load_from_disk(ds_path)
    for split_name, split in dataset_dict.items():
        frame = split.to_pandas()
        is_positive = frame['label']
        n_positive = len(frame[is_positive])
        n_negative = len(frame[~is_positive])
        print("Key %s, len %s, trues %s, falses %s" % (split_name, len(frame), n_positive, n_negative))


#merge_csv_files_in_folder('../../data/generated/mlm-math/v9', 'data_arqmath_asynch',  '../../data/generated/mlm-math/arqmath.csv')
#merge_csv_files_in_folder('../../data/generated/mlm-math-text/v8', 'data_arqmath_asynch',  '../../data/generated/mlm-math-text/arqmath.csv')
#merge_amps_arqmath('../../data/generated/mlm-math/amps.csv', '../../data/generated/mlm-math/arqmath.csv', '../../data/generated/mlm-math/data.csv')
#merge_amps_arqmath('../../data/generated/mlm-math-text/amps.csv', '../../data/generated/mlm-math-text/arqmath.csv', '../../data/generated/mlm-math-text/data.csv')
#create_split('../../data/generated/formula-naming/NSP_temp_V1000000False10_.csv', delimiter=',')

#merge_named_formulas(['../../data/generated/formula-naming/NSP_temp_V1000000False50.csv', '../../data/generated/formula-naming/NSP_temp_V1000000False50_.csv', '../../data/generated/formula-naming/NSP_temp_V1000000False25.csv', '../../data/generated/formula-naming/NSP_temp_V1000000False10_.csv'])
#merge_named_formulas(['../../data/generated/formula-naming/NSP_temp_V1000000False10_.csv'], '../../data/generated/data.csv')
#merge_named_formulas(['../../data/generated/formula-naming/NSP_temp_V100.csv'], '../../data/generated/data.csv')
#merge_named_formulas('../../data/generated/formula-naming/temp.csv')
#merge_named_formulas('../../data/generated/formula-naming/data', output='../../data/generated/formula-naming/nfir.csv')
#merge_named_formulas('../../data/generated/formula-naming/backup2', output='../../data/generated/formula-naming/nfir.csv', max_trues_per_id=None)

# todo
#epochs_int = merge_named_formulas('../../data/generated/formula-naming/temp.csv', output='../../data/generated/formula-naming/nfir.csv')

#analyze('../../final/data/nfir')
#analyze('../../final/data/ffir')

#generate(input='../../data/generated/formula-naming/nfir.csv', output='../../data/generated/formula-naming/ffir.csv', false_positives=4)

#create_split('../../data/generated/formula-naming/nfir.csv', delimiter=',')
#create_split('../../data/generated/formula-naming/ffir.csv', delimiter=',')

#reduce_dataset('../../final/data/nfir', '../../final/data/nfir-deberta', factor_false=2, count_strategies=False, versions_per_id=None)
#reduce_dataset('../../final/data/ffir', '../../final/data/ffir-deberta', factor_false=2, count_strategies=False, id='formula1_name_id', versions_per_id=None)
#complement_datasets('../../final/data/nfir', '../../final/data/nfir-deberta', '../../final/data/nfir-deberta-2')
#complement_datasets('../../final/data/ffir', '../../final/data/ffir-deberta', '../../final/data/ffir-deberta-2')

#create_split('../../data/generated/formula-naming/nfir_full.csv', delimiter=',')
#create_split('../../data/generated/formula-naming/ffir_full.csv', delimiter=',')

#reduce_dataset('../../data/generated/formula-naming/nfir_full', '../../data/generated/formula-naming/nfir-reduced', versions_per_id=250)
#reduce_dataset('../../data/generated/formula-naming/ffir_full', '../../data/generated/formula-naming/ffir-reduced', id='formula1_name_id', versions_per_id=250)

#reduce_dataset('../../final/data/nfir_full', '../../final/data/nfir-reduced-250')
#reduce_dataset('../../final/data/ffir', '../../final/data/ffir-reduced-250', id='formula1_name_id', count_strategies=False, eval_by_id=False)
#reduce_dataset('../../final/data/ffir', '../../final/data/ffir-reduced-250-eval', id='formula1_name_id', count_strategies=False, eval_by_id=True)

#split_arqmath('../../final/data/arqmath_question_text.csv', max=1000)

#create_split('../../final/data/nfir.csv', delimiter=',')
#create_split('../../final/data/ffir.csv', delimiter=',')
#split_data('../../final/data/mfm.csv', proportions=(0.9, 0.1))
#split_data('../../final/data/mtm.csv', proportions=(0.9, 0.1))

#merge_csv_files_in_folder('../../data/generated/arqmath/', 'formula_random', '../../final/data/arqmath_formula_random.csv', drop_duplicates=False)
#merge_csv_files_in_folder('../../data/generated/arqmath/', 'text_random', '../../final/data/arqmath_text_random.csv', drop_duplicates=False)
#merge_csv_files_in_folder('../../data/generated/arqmath/', 'formula_all', '../../final/data/arqmath_question_formula_all.csv', drop_duplicates=False)
#merge_csv_files_in_folder('../../data/generated/arqmath/', 'text_all', '../../final/data/arqmath_text_all.csv', drop_duplicates=False)
#merge_csv_files_in_folder('../../data/generated/arqmath/', 'all_answers', '../../final/data/arqmath_all_answers.csv', drop_duplicates=False)

#split_arqmath('../../final/data/arqmath_question_text.csv')
#split_arqmath('../../final/data/arqmath_question_text.csv')