import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from experiment.exp import get_results_save_path, get_visualization_save_path

def _summarize_runs(runs, indices):
    """Return the mean and standard deviation of *runs* along axis 0.

    Both statistics are restricted to the positions listed in *indices*.
    Presumably *runs* holds one row per repeated run and one column per
    ratio -- confirm against the code that writes the results file.
    """
    runs = np.array(runs)
    return np.mean(runs, axis=0)[indices], np.std(runs, axis=0)[indices]


def _relative_std_percent(mean, std):
    """Return ``std / mean * 100`` element-wise, with NaN where mean is zero."""
    # A plain ``std / mean`` emits a RuntimeWarning and produces inf/NaN when
    # the mean is zero; mark those entries as NaN explicitly instead.
    fraction = np.divide(
        std, mean, out=np.full(np.shape(mean), np.nan), where=mean != 0
    )
    return fraction * 100


def _draw_series(ax, ratios, mean, std, *, color, marker, linestyle, label):
    """Draw one mean curve with error bars and a shaded +/- one std band."""
    ax.plot(ratios, mean, marker=marker, linestyle=linestyle, color=color,
            label=label, markersize=5)
    # fmt='none' draws only the error bars; otherwise errorbar() would redraw
    # the line and markers at its default size on top of the curve above.
    ax.errorbar(ratios, mean, yerr=std, fmt='none', ecolor=color)
    ax.fill_between(ratios, mean - std, mean + std, color=color, alpha=0.1)


def plot_results(args, n):
    """Plot accuracy-vs-ratio curves and export their statistics to CSV.

    The results JSON located by ``get_results_save_path(args)`` is loaded.
    For each action ("remove" and "restore") one figure is produced that
    shows the ``base_<action>`` entry together with every entry whose key
    contains ``processed_`` and ``_<action>``. Each figure is saved next to
    the path from ``get_visualization_save_path(args)`` and then shown.
    Finally the per-ratio mean and relative standard deviation (std as a
    percentage of the mean) of every plotted entry are written to a CSV file.

    Args:
        args: Object forwarded to the path helpers; ``args.dataset`` is used
            in the figure titles.
        n: Number of ratio steps. Ratios ``0, 1/n, ..., 1`` are plotted, so
            each results entry needs at least ``n + 1`` values along its
            second axis.

    Raises:
        ValueError: If ``n`` is smaller than 1.
        KeyError: If a ``base_<action>`` entry is missing from the results.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n!r}")

    results_file = get_results_save_path(args)
    output_file = get_visualization_save_path(args)
    indices = np.arange(n + 1)
    ratios = indices / n
    dataset_name = args.dataset

    # Load the JSON results file
    with open(results_file, 'r', encoding='utf-8') as f:
        results = json.load(f)

    # Columns of the CSV export; std columns are expressed as percentages
    csv_columns = {"ratios": ratios}

    # The second '_'-separated field of a processed key identifies the model.
    processed_keys = [key for key in results if "processed_" in key]
    model_names = [key.split('_')[1] for key in processed_keys]
    # dict.fromkeys de-duplicates while keeping first-appearance order.
    unique_models = list(dict.fromkeys(model_names))

    # Give every model its own colour from the viridis colormap
    cmap = plt.get_cmap('viridis')
    color_map = {
        model: cmap(i / len(unique_models))
        for i, model in enumerate(unique_models)
    }
    base_color = 'red'

    for action in ("remove", "restore"):
        fig, ax = plt.subplots(figsize=(6, 6))

        # Base model curve
        base_key = f"base_{action}"
        mean, std = _summarize_runs(results[base_key], indices)
        csv_columns[f"mean_{base_key}"] = mean.tolist()
        csv_columns[f"std_{base_key} (%)"] = _relative_std_percent(mean, std).tolist()
        _draw_series(ax, ratios, mean, std, color=base_color, marker='x',
                     linestyle='--', label='Base Model')

        # Processed model curves for this action
        for key in processed_keys:
            if f"_{action}" not in key:
                continue
            model_name = key.split('_')[1]
            mean, std = _summarize_runs(results[key], indices)
            csv_columns[f"mean_{key}"] = mean.tolist()
            csv_columns[f"std_{key} (%)"] = _relative_std_percent(mean, std).tolist()
            # Model indices are stored zero-based; show them one-based.
            _draw_series(ax, ratios, mean, std, color=color_map[model_name],
                         marker='o', linestyle='-',
                         label=f"Processed {int(model_name)+1}")

        ax.set_xlabel('Ratio')
        ax.set_ylabel('Test Accuracy')
        ax.set_title(f'{dataset_name} - {action.capitalize()}')
        ax.legend(loc='best')
        ax.grid(True)
        fig.tight_layout()
        # NOTE(review): assumes output_file contains 'result.jpg'; otherwise
        # both actions are saved to the same path -- confirm against
        # get_visualization_save_path.
        fig.savefig(output_file.replace('result.jpg', f'{action}.jpg'))
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    # Save mean and std data to CSV. Only a trailing '.json' is stripped, so
    # a results path without that extension is never overwritten by the CSV.
    if results_file.endswith('.json'):
        csv_stem = results_file[:-len('.json')]
    else:
        csv_stem = results_file
    pd.DataFrame(csv_columns).to_csv(f"{csv_stem}_mean_std.csv", index=False)