import os
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# Select the non-interactive Agg backend; every figure in this script is
# written to a PNG file with savefig and never shown on screen.
matplotlib.use('Agg')

# Directory that plot_data scans for experiment log files (*.txt).
dir_name = './Expt_Data'
# When True, plot_data prints the parse of a sample file name and exits.
debug = False

def process_file(file_name):
    """Collect the cosine metrics logged in an experiment output file.

    The file is read line by line until a line starting with ``'Test Set:'``
    is met (that line and everything after it are ignored).  Lines starting
    with ``'Intra Cos: '`` or ``'Inter Cos: '`` contribute their third
    whitespace-separated token, converted to ``float``.

    Returns:
        A tuple ``(intra_cos, inter_cos, nc_intra, nc_inter)`` where the first
        two items are the lists of collected values in file order, and the
        last two are the second-to-last entry of the respective list, or
        ``None`` when that list holds fewer than two values.
    """
    # Map each line prefix to the list that accumulates its values.
    series = {'Intra Cos: ': [], 'Inter Cos: ': []}

    with open(file_name, 'r') as log:
        for row in log:
            if row.startswith('Test Set:'):
                break
            for prefix, values in series.items():
                if row.startswith(prefix):
                    values.append(float(row.split()[2]))
                    break

    intra_cos = series['Intra Cos: ']
    inter_cos = series['Inter Cos: ']

    # The reported value is the second-to-last logged entry, when available.
    nc_intra = intra_cos[-2] if len(intra_cos) >= 2 else None
    nc_inter = inter_cos[-2] if len(inter_cos) >= 2 else None

    return intra_cos, inter_cos, nc_intra, nc_inter

def _parse_value_token(part):
    """Classify one underscore-separated token of a file name.

    Returns ``(True, value)`` when the token is recognised as a value and
    ``(False, None)`` when it should be treated as (part of) a key.
    """
    # Anything float() accepts is a numeric value.  Note that float() also
    # accepts spellings such as 'nan' and 'inf', so tokens like that are
    # treated as values, not as key fragments.
    try:
        return True, float(part)
    except ValueError:
        pass

    lowered = part.lower()
    if lowered == 'true':
        return True, True
    if lowered == 'false':
        return True, False
    if lowered == 'none':
        return True, None
    # A non-numeric token that contains a digit (e.g. 'vgg11') is kept as a
    # string value.
    if any(c.isdigit() for c in part):
        return True, part
    return False, None


def filename_to_dict(file_name):
    """Parse an experiment file name into a dictionary of parameters.

    The final extension is removed (everything after the last ``'.'``) and
    the remainder is split on ``'_'``.  Consecutive non-value tokens are
    joined with ``'_'`` to form a key; the next value token is stored under
    that key.  Two keys get special treatment:

    * ``dataset``: the token that follows is always taken verbatim as its
      value.
    * ``model_type``: the token that follows is always taken verbatim as its
      value, even when it contains no digit (e.g. ``'ResNet'``).

    Numeric values are stored as ``float``; ``true``/``false``/``none``
    (case-insensitive) become ``True``/``False``/``None``.

    Example:
        ``"bn_True_dataset_CIFAR10_lr_0.001.txt"`` ->
        ``{'bn': True, 'dataset': 'CIFAR10', 'lr': 0.001}``

    Raises:
        ValueError: if ``dataset`` is the last token and therefore has no
            value following it.
    """
    result = {}
    raw_file_name = '.'.join(file_name.split('.')[:-1])
    parts = raw_file_name.split('_')

    key = ""
    i = 0
    while i < len(parts):
        part = parts[i]

        if part == 'dataset':
            # The dataset name is consumed together with its key.
            if i + 1 >= len(parts):
                raise ValueError(
                    "file name %r ends with 'dataset' but has no dataset value"
                    % (file_name,)
                )
            result[part] = parts[i + 1]
            i += 2
            continue

        if key == "model_type":
            # Model names need not contain digits, so accept any token here.
            is_value, value = True, part
        else:
            is_value, value = _parse_value_token(part)

        if is_value:
            result[key] = value
            key = ""
        elif key == "":
            key = part
        else:
            # Multi-word key such as 'weight_decay' or 'rand_seed'.
            key += "_" + part
        i += 1

    return result



def plot_data(ax, fixed_params, axis_param, random_avg_param, max_axis_param=None):
    """Plot averaged ``nc_intra`` / ``nc_inter`` values against ``axis_param``.

    Every ``*.txt`` file in ``dir_name`` is parsed with ``filename_to_dict``.
    Files whose parsed parameters agree with all entries of ``fixed_params``
    are grouped by their ``axis_param`` value; within each group the
    ``nc_intra`` and ``nc_inter`` values returned by ``process_file`` are
    averaged, and the standard error (population std / sqrt(n)) is drawn as
    error bars.

    Args:
        ax: Matplotlib axes to draw on.
        fixed_params: Mapping of parameter name to required value.  A file
            matches when ``file_dict.get(name) == value`` for every entry, so
            a required value of ``None`` also matches files that do not carry
            that parameter at all.
        axis_param: Name of the parameter used for the x-axis.
        random_avg_param: Name of the parameter that is averaged over.  It is
            accepted for interface compatibility; grouping is done on
            ``axis_param`` alone, so all matching files sharing an axis value
            are averaged together.
        max_axis_param: Optional upper bound; files whose axis value exceeds
            it are ignored.  ``None`` disables the bound.

    Files that match ``fixed_params`` but carry no numeric ``axis_param``
    value, or whose log yields no ``nc_intra``/``nc_inter`` value, are
    skipped with a warning instead of aborting the whole plot.
    """
    import warnings  # local import so this block needs no file-level change

    # axis value -> collected nc values of all matching files
    grouped_values = {}

    for file_name in os.listdir(dir_name):
        # Only process files with .txt extension
        if not file_name.endswith('.txt'):
            continue
        file_dict = filename_to_dict(file_name)

        # Keep only files that agree with every entry of fixed_params
        if not all(file_dict.get(key) == value for key, value in fixed_params.items()):
            continue

        try:
            axis_value = float(file_dict.get(axis_param))
        except (TypeError, ValueError):
            warnings.warn(
                "Skipping %s: no numeric %r value in file name" % (file_name, axis_param)
            )
            continue
        # Compare against None explicitly so that a bound of 0 is honoured.
        if max_axis_param is not None and axis_value > max_axis_param:
            continue

        _, _, nc_intra, nc_inter = process_file(os.path.join(dir_name, file_name))
        if nc_intra is None or nc_inter is None:
            # process_file found fewer than two entries; averaging a None
            # would raise inside numpy, so leave this run out.
            warnings.warn(
                "Skipping %s: log has too few Intra/Inter Cos entries" % (file_name,)
            )
            continue

        group = grouped_values.setdefault(axis_value, {'nc_intra': [], 'nc_inter': []})
        group['nc_intra'].append(nc_intra)
        group['nc_inter'].append(nc_inter)

    if debug:
        print(filename_to_dict("bn_True_dataset_CIFAR10_epochs_100_lr_0.001_model_type_ResNet_rand_seed_265358_weight_decay_0.0005.txt"))
        exit()

    # Per-group mean and standard error, ordered by axis value
    axis_values = sorted(grouped_values)
    nc_intra_avg_values = []
    nc_inter_avg_values = []
    nc_intra_std_err_values = []
    nc_inter_std_err_values = []
    for axis_value in axis_values:
        intra = grouped_values[axis_value]['nc_intra']
        inter = grouped_values[axis_value]['nc_inter']
        nc_intra_avg_values.append(np.mean(intra))
        nc_inter_avg_values.append(np.mean(inter))
        nc_intra_std_err_values.append(np.std(intra) / np.sqrt(len(intra)))
        nc_inter_std_err_values.append(np.std(inter) / np.sqrt(len(inter)))

    # Create a line plot with error bars
    ax.errorbar(axis_values, nc_intra_avg_values, yerr=nc_intra_std_err_values, label='min_intra', fmt='o-', capsize=5, markersize=3)
    ax.errorbar(axis_values, nc_inter_avg_values, yerr=nc_inter_std_err_values, label='max_inter', fmt='-.', capsize=5, markersize=3)
    ax.set_xscale('log')
    ax.set_ylim(-0.33, 1)
    ax.legend()

# Figure "plot_CV.png": 3 datasets (rows) x 5 architectures (columns).
fig, axes = plt.subplots(3, 5, figsize=(15, 9))

# Parameter shown on the x-axis and the parameter that is averaged over.
axis_param = 'weight_decay'
random_avg_param = 'rand_seed'

# File-name parameters selecting the ResNet runs of each dataset (column 4).
cv_resnet_params = {
    "CIFAR10": {'bn': False, 'dataset': 'CIFAR10', 'epochs': 100.0, 'lr': 0.001, 'model_type': 'ResNet'},
    "MNIST": {"bn": False, "dataset": "MNIST", "lr": 1e-3, "epochs": 100, "model_type": "ResNet"},
    "CIFAR100": {"bn": False, "dataset": "CIFAR100", "lr": 1e-2, "epochs": 200, "train_samples": None, "test_samples": None, "model_type": "ResNet"},
}

# Columns 0-3: VGG11 then VGG19, each first without and then with batch norm.
cv_vgg_variants = [(model, bn) for model in ("vgg11", "vgg19") for bn in (False, True)]

for row, dataset in enumerate(("CIFAR10", "MNIST", "CIFAR100")):
    for col, (model, bn) in enumerate(cv_vgg_variants):
        fixed_params = {
            "bn": bn,
            "dataset": dataset,
            "lr": 1e-2,
            "epochs": 200,
            "train_samples": None,
            "test_samples": None,
            "model_type": model,
        }
        plot_data(axes[row, col], fixed_params, axis_param, random_avg_param, max_axis_param=0.01)

    fixed_params = cv_resnet_params[dataset]
    plot_data(axes[row, 4], fixed_params, axis_param, random_avg_param, max_axis_param=0.01)

# Column titles go on the top row
cv_titles = (
    'VGG11 (no BN)',
    'VGG11 (Contains BN)',
    'VGG19 (no BN)',
    'VGG19 (Contains BN)',
    'ResNet (Contains BN)',
)
for col, title in enumerate(cv_titles):
    axes[0, col].set_title(title)

# Row labels go on the first column
for row, label in enumerate(('CIFAR10', 'MNIST', 'CIFAR100')):
    axes[row, 0].set_ylabel(label)

# x-axis labels only on the bottom row
for bottom_ax in axes[2]:
    bottom_ax.set_xlabel('weight_decay')

# Adjust spacing
fig.tight_layout(rect=[0, 0, 1, 0.95])

# Save the plot
plt.savefig("plot_CV.png")

# Figure "plot_CV_2x3.png": 2 datasets (rows) x 3 architectures (columns).
fig, axes = plt.subplots(2, 3, figsize=(9,6))

# Parameter shown on the x-axis and the parameter that is averaged over.
axis_param = 'weight_decay'
random_avg_param = 'rand_seed'

# One entry per row: dataset name and the parameters of its ResNet panel.
cv_small_rows = [
    ("CIFAR10", {'bn': False, 'dataset': 'CIFAR10', 'epochs': 100.0, 'lr': 0.001, 'model_type': 'ResNet'}),
    ("CIFAR100", {"bn": False, "dataset": "CIFAR100", "lr": 1e-2, "epochs": 200, "train_samples": None, "test_samples": None, "model_type": "ResNet"}),
]

for row, (dataset, resnet_params) in enumerate(cv_small_rows):
    # Columns 0-1: VGG11 without and with batch norm.
    for col, bn in enumerate((False, True)):
        fixed_params = {
            "bn": bn,
            "dataset": dataset,
            "lr": 1e-2,
            "epochs": 200,
            "train_samples": None,
            "test_samples": None,
            "model_type": "vgg11",
        }
        plot_data(axes[row, col], fixed_params, axis_param, random_avg_param, max_axis_param=0.01)

    # Column 2: ResNet.
    fixed_params = resnet_params
    plot_data(axes[row, 2], fixed_params, axis_param, random_avg_param, max_axis_param=0.01)

# Column titles go on the top row
for col, title in enumerate(('VGG11 (no BN)', 'VGG11 (Contains BN)', 'ResNet (Contains BN)')):
    axes[0, col].set_title(title)

# Row labels go on the first column
for row, label in enumerate(('CIFAR10', 'CIFAR100')):
    axes[row, 0].set_ylabel(label)

# x-axis labels only on the bottom row
for bottom_ax in axes[1]:
    bottom_ax.set_xlabel('weight_decay')

# Adjust spacing
fig.tight_layout(rect=[0, 0, 1, 0.95])

# Save the plot
plt.savefig("plot_CV_2x3.png")

# Figure "plot_MLP.png": 3 datasets (rows) x 6 MLP variants (columns).
fig, axes = plt.subplots(3, 6, figsize=(16, 8))

# Column titles go on the top row
mlp_titles = (
    '3-layer MLP without BN',
    '3-layer MLP with BN',
    '6-layer MLP without BN',
    '6-layer MLP with BN',
    '9-layer MLP without BN',
    '9-layer MLP with BN',
)
for col, title in enumerate(mlp_titles):
    axes[0, col].set_title(title)

# Row labels go on the first column
for row, label in enumerate(('conic hull dataset', 'MLP3 dataset', 'MLP6 dataset')):
    axes[row, 0].set_ylabel(label)

# x-axis labels only on the bottom row
for bottom_ax in axes[2]:
    bottom_ax.set_xlabel('weight_decay')

# Parameter shown on the x-axis and the parameter that is averaged over.
axis_param = 'weight_decay'
random_avg_param = 'rand_seed'

# Row 0 (conic dataset), columns 0-4.  The 6-layer panels are selected with
# model_depth_MLP=None, which matches files lacking that parameter.
conic_panels = [
    {"dataset": "conic", "model_type": "MLP", "lr": 1e-3, "epochs": 200, 'model_depth_MLP': 3, 'bn': False},
    {"dataset": "conic", "model_type": "MLP", "lr": 1e-3, "epochs": 200, 'model_depth_MLP': 3, 'bn': True},
    {"dataset": "conic", "model_type": "MLP", "lr": 1e-3, "epochs": 200, 'model_depth_MLP': None, 'bn': False},
    {"dataset": "conic", "model_type": "MLP", "lr": 1e-3, "epochs": 200, 'model_depth_MLP': None, 'bn': True},
    {"dataset": "conic", "model_type": "MLP", "lr": 1e-3, "epochs": 200, 'model_depth_MLP': 9, 'bn': False},
]
for col, fixed_params in enumerate(conic_panels):
    plot_data(axes[0, col], fixed_params, axis_param, random_avg_param)

# Row 0, column 5: there are some problems with file naming here, so this
# panel is selected with epochs=100 and without a model_type constraint.
fixed_params = {"dataset": "conic", "lr": 1e-3, "epochs": 100, 'model_depth_MLP': 9, 'bn': True}
print(fixed_params)
plot_data(axes[0, 5], fixed_params, axis_param, random_avg_param)

# Rows 1-2 (mlp3 / mlp6 datasets): depth 3, 6, 9, each without then with BN.
for row, dataset in ((1, "mlp3"), (2, "mlp6")):
    col = 0
    for depth in (3, 6, 9):
        for bn in (False, True):
            fixed_params = {
                "dataset": dataset,
                "model_type": "MLP",
                "lr": 1e-3,
                "epochs": 100,
                'model_depth_MLP': depth,
                'bn': bn,
            }
            plot_data(axes[row, col], fixed_params, axis_param, random_avg_param)
            col += 1

plt.savefig("plot_MLP.png")

# Code for small main content plots
# Figure "plot_MLP_2x4.png": 2 datasets (rows) x 4 MLP variants (columns).
fig, axes = plt.subplots(2, 4, figsize=(16, 8))

# Column titles go on the top row
mlp_small_titles = (
    '3-layer MLP without BN',
    '3-layer MLP with BN',
    '6-layer MLP without BN',
    '6-layer MLP with BN',
)
for col, title in enumerate(mlp_small_titles):
    axes[0, col].set_title(title)

# Row labels go on the first column
for row, label in enumerate(('conic hull dataset', 'MLP3 dataset')):
    axes[row, 0].set_ylabel(label)

# x-axis labels only on the bottom row
for bottom_ax in axes[1]:
    bottom_ax.set_xlabel('weight_decay')

# Parameter shown on the x-axis and the parameter that is averaged over.
axis_param = 'weight_decay'
random_avg_param = 'rand_seed'

# One entry per row: (dataset, epochs, depth values for the two column pairs).
# For the conic dataset the 6-layer panels are selected with
# model_depth_MLP=None, which matches files lacking that parameter.
mlp_small_rows = [
    ("conic", 200, (3, None)),
    ("mlp3", 100, (3, 6)),
]

for row, (dataset, epochs, depths) in enumerate(mlp_small_rows):
    col = 0
    for depth in depths:
        # Each depth is drawn without and then with batch norm.
        for bn in (False, True):
            fixed_params = {
                "dataset": dataset,
                "model_type": "MLP",
                "lr": 1e-3,
                "epochs": epochs,
                'model_depth_MLP': depth,
                'bn': bn,
            }
            plot_data(axes[row, col], fixed_params, axis_param, random_avg_param)
            col += 1

plt.savefig("plot_MLP_2x4.png")
