
import os
import sys
import inspect

currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir) 

import random
import copy
import time
from dataclasses import dataclass

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import shutil

import torch.optim as optim
import tyro

import wandb
import pandas as pd

import matplotlib
import matplotlib.pyplot as plt

def moving_average(x, w):
    """Return the simple moving average of the 1-D sequence ``x`` over ``w`` samples.

    Only fully-overlapping windows are kept, so element ``i`` of the result is
    ``mean(x[i:i + w])`` and the result has ``len(x) - w + 1`` elements (empty
    when ``x`` has fewer than ``w`` samples).

    Raises:
        ValueError: if ``w`` is smaller than 1.
    """
    if w < 1:
        raise ValueError(f"window must be >= 1, got {w}")
    x = np.asarray(x, dtype=float)
    if x.size < w:
        # np.convolve(mode="valid") swaps its operands when the kernel is longer
        # than the signal, which would yield w - len(x) + 1 meaningless values.
        return np.empty(0)
    return np.convolve(x, np.ones(w), mode="valid") / w

    
def _ensure_parent_dir(path):
    """Create the directory that will contain ``path`` (including missing parents)."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _best_checkpoint_from_wandb(wandb_id, window, checkpoint_steps, api):
    """Pick the saved checkpoint step with the highest smoothed ``greedy_reward``.

    The run's logged ``greedy_reward`` history (NaNs dropped) is smoothed with a
    ``window``-sample moving average over full windows only, then sampled every
    ``checkpoint_steps`` entries.  Sample 0 is ignored, and the index of the best
    remaining sample is converted back to a step count.

    Returns:
        ``(checkpoint_id, n_smoothed)``: ``checkpoint_id`` is 0 when fewer than
        two samples are available; ``n_smoothed`` is the smoothed-series length
        (reported for diagnostics).
    """
    folder = "wandb_repo/Task Allocation and Scheduling with Path Planning using GNNs"
    run = api.run(os.path.join(folder, wandb_id))
    history = pd.DataFrame(run.scan_history())
    rewards = history['greedy_reward'].dropna().to_numpy()

    if window < 1 or len(rewards) < window:
        # Guard: np.convolve(mode="valid") swaps operands when the kernel is longer.
        smoothed = np.empty(0)
    else:
        smoothed = np.convolve(rewards, np.ones(window) / window, mode="valid")

    sampled = smoothed[::checkpoint_steps]
    if len(sampled) <= 1:
        return 0, len(smoothed)
    return (int(np.argmax(sampled[1:])) + 1) * checkpoint_steps, len(smoothed)


def get_checkpoints(save_location, env_prefix, seeds, model_name, models, problem_sets, problem_sets_prefixes, reward_models, baselines, run_name = 'sac_wip', key_words=None, save=False, initial_performance=False, checkpoint_steps=500, num_heads=8, num_layers=4):
    """Resolve (and collect) the checkpoint file to evaluate for every run.

    For each problem-set prefix, model, reward model and seed, a checkpoint id
    is chosen as follows:

    * ``initial_performance=True``: always 0.
    * Otherwise, if a cached id exists in
      ``final_checkpoints/checkpoint_ids/<...>_code.txt``, it is read from there.
    * Otherwise, the wandb run id is read from
      ``final_checkpoints/wandb_access_codes/<...>_code.txt``. The best-scoring
      checkpoint is picked from the run's ``greedy_reward`` history and cached.

    The matching ``<save_location>/<...>__<seed>__<model>__<id:05d>.pt`` file is
    then copied to the path obtained by replacing ``"final_checkpoints"`` with
    ``"final_checkpoints/checkpoint_ids"``. When that substring is absent, the
    file is used in place.

    Args:
        save_location: Directory prefix where the training ``.pt`` files live.
        env_prefix: Leading part of the environment/run id.
        seeds: Training seeds to collect.
        model_name: Per-model name fragment used in the run id; indexed in
            parallel with ``models``.
        models: Model identifiers used in file names and in the result keys.
        problem_sets: Unused; kept for interface compatibility.
        problem_sets_prefixes: Problem-set tags used in the run id and result keys.
        reward_models: Reward-model suffixes used in run names (for the
            ``sac_wip`` / ``sac_wip_simultaneous`` naming schemes). A name
            containing ``"no_critic"`` adds ``" nc"`` to the result key.
        baselines: Unused; kept for interface compatibility.
        run_name: Run-name prefix, which also selects the file-naming scheme.
        key_words: Optional extra tags appended as ``_<key>`` to file names.
        save: Unused (the plotting code it controlled was removed).
        initial_performance: If True, always select checkpoint 0.
        checkpoint_steps: Spacing, in logged reward samples, between saved checkpoints.
        num_heads, num_layers: Encoded in the wandb-access-code / cache file names.

    Returns:
        dict mapping ``"{model} {seed} {prefix}"`` (plus ``" nc"`` for
        ``no_critic`` reward models) to the checkpoint path to load.
        NOTE: reward models without ``"no_critic"`` share a key, so with several
        of them the last one wins.

    Raises:
        FileNotFoundError: if the wandb access-code file (when needed) or the
            source checkpoint does not exist.
    """
    problem_sizes = [200]
    model_checkpoints = {}
    api = None  # wandb client, created lazily only when an id must be recovered
    key_suffix = "" if key_words is None else "".join(f"_{kw}" for kw in key_words)

    for problem_set_prefix in problem_sets_prefixes:
        for problem_size in problem_sizes:
            for i, model in enumerate(models):
                for reward_model in reward_models:
                    for seed in seeds:
                        env_id = f"{env_prefix} {problem_set_prefix}{problem_size}p {model_name[i]}"
                        run_id = f"{run_name}__{'_'.join(env_id.split(' '))}"

                        # Name of the text file holding the wandb run id for this run.
                        if run_name == 'sac_wip':
                            code_stem = f"{run_id}_{reward_model}"
                        elif run_name == 'sac_wip_simultaneous':
                            code_stem = f"{run_id}{reward_model}"
                        else:
                            code_stem = run_id
                        file_name = (
                            f"final_checkpoints/wandb_access_codes/{code_stem}{key_suffix}"
                            f"__{num_heads}__{num_layers}__{model}_s{seed}_code.txt"
                        )
                        _ensure_parent_dir(file_name)

                        # The selected checkpoint id is cached next to it.
                        checkpoint_id_file = file_name.replace("wandb_access_codes", "checkpoint_ids")
                        _ensure_parent_dir(checkpoint_id_file)

                        if initial_performance:
                            checkpoint_id = 0
                        elif os.path.exists(checkpoint_id_file):
                            with open(checkpoint_id_file, "r") as f:
                                checkpoint_id = f.read().strip()
                        else:
                            if not os.path.exists(file_name):
                                raise FileNotFoundError(f"{file_name} does not exist")
                            with open(file_name, "r") as f:
                                wandb_id = f.read().strip()
                            if api is None:
                                api = wandb.Api(timeout=120)
                            # NOTE(review): the smoothing window is problem_size (200); confirm intended.
                            checkpoint_id, n_smoothed = _best_checkpoint_from_wandb(
                                wandb_id, problem_size, checkpoint_steps, api
                            )
                            print(f"{env_id} {model} {seed}: {wandb_id} : {n_smoothed} {checkpoint_id}")
                            with open(checkpoint_id_file, "w") as f:
                                f.write(str(checkpoint_id))

                        if run_name == 'sac_wip':
                            checkpoint_stem = f"{save_location}/{run_id}_{reward_model}"
                        else:
                            checkpoint_stem = f"{save_location}/{run_id}"
                        source = (
                            f"{checkpoint_stem}{key_suffix}"
                            f"__{seed}__{model}__{str(checkpoint_id).zfill(5)}.pt"
                        )
                        if not os.path.exists(source):
                            raise FileNotFoundError(f"{source} does not exist")

                        # Copy into final_checkpoints/checkpoint_ids for easy access; when the
                        # path has no "final_checkpoints" component, use the file in place
                        # (copying a file onto itself would raise SameFileError).
                        destination = source.replace("final_checkpoints", "final_checkpoints/checkpoint_ids")
                        if destination != source:
                            _ensure_parent_dir(destination)
                            shutil.copyfile(source, destination)
                        print(f"{source} -> {destination}")

                        result_key = f"{model} {seed} {problem_set_prefix}"
                        if 'no_critic' in reward_model:
                            result_key += " nc"
                        model_checkpoints[result_key] = destination

    return model_checkpoints

if __name__ == "__main__":
    # Usage: python <this script> REWARD_MODEL [REWARD_MODEL ...]
    # The original call omitted `save_location` and `reward_models` and passed the
    # remaining arguments shifted by one position, so it always raised TypeError.
    # The reward-model suffixes cannot be recovered from this file, so they are
    # taken from the command line rather than guessed.
    reward_models = sys.argv[1:]
    if not reward_models:
        sys.exit(f"usage: {sys.argv[0]} REWARD_MODEL [REWARD_MODEL ...]")

    # NOTE(review): assumed location of the training checkpoints. get_checkpoints
    # collects them by rewriting "final_checkpoints" -> "final_checkpoints/checkpoint_ids"
    # in this path; confirm.
    save_location = "final_checkpoints"
    env_prefix = "paper init"
    seeds = [10, 11, 12]
    # Run-id name fragments, paired index-by-index with `models`.
    model_name = ["hetgat", "hetgat_resnet", "hgt", "hgt_edge", "hgt_edge_resnet", "hgt-edge-resnet_wbb"]
    models = ["hetgat", "hetgat_resnet", "hgt", "hgt_edge", "hgt_edge_resnet", "hgt_edge_resnet_bb"]
    problem_sets = ["data/problem_set_r2_t5_s10_f30_w50_euc_2000", "data/problem_set_r5_t20_s10_f30_w50_euc_2000"]
    problem_sets_prefixes = ["2r5t", "5r20t"]
    baselines = ['milp_reward', 'improved_edf_reward', 'edf_reward'] # , 'max_sample_reward', 'mean_sample_reward', 'min_sample_reward']

    checkpoints = get_checkpoints(
        save_location=save_location,
        env_prefix=env_prefix,
        seeds=seeds,
        model_name=model_name,
        models=models,
        problem_sets=problem_sets,
        problem_sets_prefixes=problem_sets_prefixes,
        reward_models=reward_models,
        baselines=baselines,
    )
    for run_key, path in checkpoints.items():
        print(f"{run_key}: {path}")