import matplotlib
import os
import re
import numpy as np

from utils import root_dir


def parse_logs(directory, min_matches=20, num_rounds=19):
    """Collect per-round attack statistics from every ``.txt`` log in *directory*.

    Each log is scanned for two patterns:

    * ``"Successful adv samples: <int>",`` immediately followed by a line
      ``    "",`` (the regex matches the literal newline and 4-space indent);
    * ``Origin model accuracy on adv test data: <float>``.

    The i-th match of each pattern are paired into one row. A file is kept
    only when it has at least ``max(min_matches, num_rounds)`` matches of
    both patterns; its first *num_rounds* pairs are used, so every kept file
    contributes the same number of rows.

    Args:
        directory: Directory (``str`` or path-like) holding the log files.
            Entries whose name does not end in ``.txt``, or that are not
            regular files, are ignored.
        min_matches: Minimum number of matches of each pattern a file needs
            to be included. Defaults to 20.
        num_rounds: Number of leading pairs taken from each included file.
            Defaults to 19.
            NOTE(review): the defaults require 20 matches but keep only 19
            rows (preserved from the original); confirm whether dropping the
            last round is intentional.

    Returns:
        ``float64`` array of shape ``(n_files, num_rounds, 2)``.
        ``[..., 0]`` is the successful-sample count and ``[..., 1]`` is the
        accuracy. Files are stacked in sorted filename order. When no file
        qualifies, an empty ``(0, num_rounds, 2)`` array is returned.
    """
    successful_samples_pattern = re.compile(
        r'"Successful adv samples: (\d+)",\n    "",'
    )
    model_accuracy_pattern = re.compile(r'Origin model accuracy on adv test data: (\d+\.\d+)')
    # Demand at least num_rounds matches as well, so zip() never yields a
    # short (ragged) row that np.array could not stack.
    required = max(min_matches, num_rounds)

    parsed_data = []
    # os.listdir order is arbitrary; sort so row order is reproducible.
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith('.txt'):
            continue
        filepath = os.path.join(directory, filename)
        if not os.path.isfile(filepath):
            continue
        # The patterns are pure ASCII; replace undecodable bytes instead of
        # failing on a stray non-UTF-8 character.
        with open(filepath, 'r', encoding='utf-8', errors='replace') as file:
            content = file.read()
        successful_samples_matches = successful_samples_pattern.findall(content)
        model_accuracy_matches = model_accuracy_pattern.findall(content)
        if len(successful_samples_matches) < required or len(model_accuracy_matches) < required:
            continue
        parsed_data.append([
            [int(successful_samples), float(model_accuracy)]
            for successful_samples, model_accuracy in zip(
                successful_samples_matches[:num_rounds],
                model_accuracy_matches[:num_rounds],
            )
        ])

    if not parsed_data:
        # Keep a well-defined 3-D shape so callers can reduce over axis 0.
        return np.empty((0, num_rounds, 2), dtype=float)
    return np.array(parsed_data, dtype=float)


def test():
    # cifar_100_first_attack_stat = [30, 20, 40, 20, 30, 20, 20, 30, 20, 20, 20, 20, 20]
    # cifar_100_first_attack_stat += [30, 20, 20, 20, 40, 20, 30, 30, 20, 20, 20, 20]
    # cifar_10_first_attack_stat = [30, 20, 120, 20, 20, 40, 20, 20, 30, 20, 40]
    # cifar_10_first_attack_stat += [120, 20, 20, 30, 30, 30, 20, 20, 20, 20, 20, 30, 30, 30]
    # print(sum(cifar_100_first_attack_stat) / len(cifar_100_first_attack_stat))
    # print(sum(cifar_10_first_attack_stat) / len(cifar_10_first_attack_stat))
    # cifar_10_resnet34_18 = [330, 360, 600, 330, 600, 330, 600, 330, 330, 600, 450, 600]
    # print(sum(cifar_10_resnet34_18) / len(cifar_10_resnet34_18))

    init = 10
    adv_size = 10
    directory_path = root_dir / 'results'
    log_info = parse_logs(directory_path)
    # print(log_info)
    log_info_mean = np.mean(log_info, axis=0)
    print(log_info_mean)

    result = np.zeros_like(log_info_mean, dtype=float)
    result[:, 0] = np.cumsum(log_info_mean[::-1, 0])[::-1]

    for i in range(log_info_mean.shape[0]):
        total_attacks = np.sum(log_info_mean[i:, 0])
        if total_attacks > 0:
            total_non_success = np.sum(log_info_mean[i:, 0] * log_info_mean[i:, 1])
            result[i, 1] = 1 - total_non_success / total_attacks
    print(result)

    new_column = (log_info_mean.shape[0] * 10) / (result[:, 0] * result[:, 1])
    new_column_2 = np.zeros(shape=(log_info_mean.shape[0],))
    new_column_2.fill(log_info_mean.shape[0] * adv_size + init)
    result = np.column_stack((new_column_2, result, new_column))
    result = np.round(result, 2)
    print(result)


if __name__ == "__main__":
    # Script entry point: run the log summary.
    test()
