# Activate the project virtualenv before running: source /net/scratch/user/venvs/bin/activate

import wandb
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from tenacity import retry, wait_exponential, stop_after_attempt
import pickle



@retry(
    wait=wait_exponential(multiplier=1, min=2, max=20),
    stop=stop_after_attempt(5),
)
def get_wandb_runs(api, entity, project_name):
    """Return the runs of ``<entity>/<project_name>`` from a Weights & Biases API handle.

    The call is retried by tenacity: at most 5 attempts, with exponential
    back-off clamped to 2-20 seconds between them. After the final failure
    tenacity raises ``tenacity.RetryError`` (``reraise`` is not set).

    NOTE(review): if ``api.runs`` is lazy and only talks to the server while
    being iterated, failures during iteration are not covered by this retry.
    Confirm this against the installed wandb version.
    """
    project_path = "{}/{}".format(entity, project_name)
    return api.runs(project_path)

def load_runs_from_file(file_path):
    """Deserialize and return the single object pickled in ``file_path``.

    NOTE: ``pickle.load`` can execute arbitrary code. Only call this on
    trusted files, such as a cache this script wrote itself.
    """
    with open(file_path, mode="rb") as handle:
        payload = pickle.load(handle)
    return payload

def pickle_all_wandb_runs_dict(project_name, reload=0, entity='pyt-geo',
                               cache_root='/directory/pickle'):
    """Collect every finished W&B run of a project, caching the result as a pickle.

    The cache lives at ``<cache_root>/<project_name>/wandb_runs.pkl``. If it
    exists and ``reload`` is falsy, it is loaded and returned. Otherwise all
    runs are downloaded and the cache is rewritten.

    Args:
        project_name: W&B project to read.
        reload: truthy to delete any existing cache and re-download.
        entity: W&B entity (user or team) that owns the project.
        cache_root: parent directory of the per-project cache directory.

    Returns:
        ``(valid_runs, len_eval_steps, dataset_name, eval_dataset_name)``.
        ``valid_runs`` maps run name to a dict holding the run's config
        entries plus one list per history column. The three metadata values
        are overwritten by each finished run as it is processed, so they
        reflect the last such run. They stay ``None`` if no run got that far.

    NOTE: the cache is read back with ``pickle.load``; keep ``cache_root`` on
    trusted storage.
    NOTE(review): ``run.history()`` may return a down-sampled history by
    default, while ``run.scan_history()`` yields every logged row. Confirm
    which one the plots need.
    """
    cache_dir = os.path.join(cache_root, project_name)
    os.makedirs(cache_dir, exist_ok=True)
    cache_file = os.path.join(cache_dir, "wandb_runs.pkl")

    # A forced reload drops the old cache so it is rebuilt from the API below.
    if reload and os.path.exists(cache_file):
        print(f"Reload flag is set. Deleting cache file: {cache_file}")
        os.remove(cache_file)

    if os.path.exists(cache_file):
        print(f"Loading runs from cache: {cache_file}")
        with open(cache_file, 'rb') as f:
            cached_data = pickle.load(f)
        return (cached_data['valid_runs'], cached_data['len_eval_steps'],
                cached_data['dataset_name'], cached_data['eval_dataset_name'])

    print("Starting the API...")
    api = wandb.Api(timeout=5000)
    print("Getting wandb runs...")
    runs = get_wandb_runs(api, entity, project_name)
    print("Got wandb runs...")

    valid_runs = {}
    len_eval_steps = dataset_name = eval_dataset_name = None
    for i, run in enumerate(runs):
        # Best effort: one broken run is reported and skipped rather than
        # aborting the whole download.
        try:
            if run.state != 'finished':
                continue

            # Config values are stored as-is; every history column becomes a list.
            run_dict = dict(run.config)
            history = run.history()
            for metric in history.columns:
                run_dict[metric] = history[metric].tolist()

            # NOTE(review): runs are keyed by display name, so duplicate names
            # overwrite each other. Consider run.id if that matters.
            valid_runs[run.name] = run_dict

            len_eval_steps = len(run_dict['eval_test_loss'])
            dataset_name = run.config['dataset']
            eval_dataset_name = run.config['eval_dataset']
            print(f"Processed run {i}: {run.name}")
        except Exception as e:
            print(f"Failed to process run {i}: {e}")

    cache_data = {
        'valid_runs': valid_runs,
        'len_eval_steps': len_eval_steps,
        'dataset_name': dataset_name,
        'eval_dataset_name': eval_dataset_name,
    }

    # Write to a temp file and atomically swap it in. An interrupted dump then
    # cannot leave a truncated pickle that would break every later cache hit.
    tmp_file = cache_file + ".tmp"
    with open(tmp_file, 'wb') as f:
        pickle.dump(cache_data, f)
    os.replace(tmp_file, cache_file)
    print(f"Saved valid runs and metadata to cache: {cache_file}")

    return valid_runs, len_eval_steps, dataset_name, eval_dataset_name


def _last_value(value):
    """Return the final non-NaN element of a logged series, or ``value`` itself for scalars.

    History metrics arrive as lists (one entry per history row, possibly NaN
    where the metric was not logged), while config entries are plain scalars.
    An empty series raises IndexError; an all-NaN series returns its last NaN.
    """
    if isinstance(value, (list, tuple, np.ndarray)):
        for item in reversed(value):
            if not (isinstance(item, float) and np.isnan(item)):
                return item
        return value[-1]
    return value


def _upload_pdf_to_wandb(pdf_file, wandb_project, run_name):
    """Log ``pdf_file`` as a 'report' artifact in a new W&B run, then finish that run."""
    os.environ['WANDB_DIR'] = '/net/scratch/user/wandblog/'
    wandb.init(project=wandb_project, entity='pyt-geo', name=run_name)
    artifact = wandb.Artifact('plot_pdf', type='report')
    artifact.add_file(pdf_file)
    wandb.log_artifact(artifact)
    wandb.finish()


def plot_metric(valid_runs, wandb_project = 'AME Plots', run_name = 'Full Plots',
                sort_metric = 'num_ingredients', plot_this_metric = 'test_accuracy', colormap_range = (0, 1),
                x_label = 'Training Epochs', y_label = 'Accuracy', title_label = 'Thing', y_lim = (0, 1),
                save_directory = '/net/scratch/user/extractresults'):
    """Plot one metric over time for every run, coloured by a sort metric, then upload the PDF.

    Args:
        valid_runs: mapping of run name to a dict of config values and
            per-row metric lists (the structure ``pickle_all_wandb_runs_dict``
            builds).
        wandb_project: W&B project that receives the PDF artifact.
        run_name: name of the W&B run used for the upload.
        sort_metric: key whose final value (or scalar config value) selects
            each run's viridis colour.
        plot_this_metric: key of the series drawn on the y-axis.
        colormap_range: currently unused; kept for backward compatibility.
        x_label, y_label, title_label: axis labels and title (LaTeX-rendered).
        y_lim: currently unused; the y-axis is left to autoscale.
        save_directory: directory where ``wandb_plot.pdf`` is written.

    Raises:
        ValueError: if no run in ``valid_runs`` contains ``sort_metric``.
    """
    max_markers = 12
    marker = 'd'
    line_thickness = 2.0
    log_every = 1  # x-axis spacing between logged points; adjust if evals are sparser
    figsize = (6, 5)

    plt.rc('text', usetex=True)
    plt.rc('font', family='serif', size=25)
    colormap = plt.get_cmap('viridis')

    # Colour scale spans the sort-metric values of all runs that have it.
    # A run lacking the key is reported in the loop below instead of
    # aborting the whole plot.
    sort_values = [_last_value(run[sort_metric])
                   for run in valid_runs.values() if sort_metric in run]
    if not sort_values:
        raise ValueError(f"No run contains the sort metric {sort_metric!r}")
    norm = plt.Normalize(min(sort_values), max(sort_values))

    fig = plt.figure(figsize=figsize)
    # Loop variable is `name`, not `run_name`, so the W&B run name parameter
    # is not overwritten by the last plotted run.
    for name, run_data in valid_runs.items():
        try:
            sort_value = _last_value(run_data[sort_metric])
            series = run_data[plot_this_metric]
            color = colormap(norm(sort_value))

            # Evenly spaced indices so at most `max_markers` diamonds are drawn.
            x = np.arange(len(series))
            markevery_indices = np.linspace(0, len(x) - 1, max_markers, dtype=int)

            plt.plot(x * log_every, series, color=color,
                     label=f'Run {name} (ingredients={sort_value})',
                     linewidth=line_thickness, marker=marker,
                     markevery=markevery_indices, alpha=0.5)
        except Exception as e:
            print(f"Error processing run {name}: {e}")

    plt.xlabel(x_label, fontsize=35)
    plt.ylabel(y_label, fontsize=35)
    plt.title(title_label, fontsize=40)
    plt.gca().autoscale(enable=True, axis='x', tight=True)
    plt.grid(True)

    # Save BEFORE showing: on interactive backends plt.show() may close the
    # figure, and a later savefig would then write an empty page.
    os.makedirs(save_directory, exist_ok=True)
    pdf_file = os.path.join(save_directory, 'wandb_plot.pdf')
    fig.savefig(pdf_file, format='pdf', bbox_inches='tight')
    plt.show()
    plt.close(fig)

    _upload_pdf_to_wandb(pdf_file, wandb_project, run_name)


# Configuration for the default plot produced when this file is run as a script.
sort_metric = 'eval_test_loss'
plot_this_metric = 'eval_test_accuracy'
colormap_range = [0, 15]
x_label = 'Training Epochs'
y_label = 'Test Accuracy'
# Raw string: the backslash escapes '%' for LaTeX (usetex=True). In a plain
# string literal '\%' is an invalid escape sequence and triggers a warning.
title_label = r'60\% Out-Of-Dist.'
y_lim = (0, 1)
project_name = 'proj_name'

# Only hit the W&B API and draw plots when executed directly, not on import.
if __name__ == "__main__":
    valid_runs, len_eval_steps, dataset_name, eval_dataset_name = pickle_all_wandb_runs_dict(project_name, reload=0)
    plot_metric(valid_runs, sort_metric = sort_metric, plot_this_metric = plot_this_metric,
                colormap_range = colormap_range, x_label = x_label, y_label = y_label, title_label = title_label,
                y_lim = y_lim)