import gymnasium as gym
import sys
sys.path.append('../')  # Adjust path to the parent directory if needed
import gym_envs
import seaborn as sns
import numpy as np
import pandas as pd
import torch
from gflownet_traj_balance import GFlowNet
from stable_baselines3 import PPO
from gflownet_sub_traj_balance import SubTBGFlowNet
from sklearn.neighbors import KernelDensity
import random
# Utilities
import pickle
import matplotlib.pyplot as plt
import os
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

def seed_all(seed):
    """
    Seed the Python, NumPy and PyTorch random number generators.

    cuDNN is additionally switched to deterministic mode with its benchmark
    mode turned off.

    :param seed: seed value handed to random, numpy.random and torch.
    """
    random.seed(seed)
    np.random.seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

    torch.manual_seed(seed)

# Convert between hyperparameter dicts and model-name strings
def param_to_name(hyperparameters):
    """
    Build the model name that encodes a run's hyperparameters.

    The name is the underscore-joined values of the shared fields (method,
    total_iterations, nn_hidden_sizes, activation, initial_z, seed) followed
    by a method-specific list of fields. For 'gfn' and 'gfnsub' the values of
    filter_upper and filter_lower are appended as well whenever
    hyperparameters['use_filter'] is truthy. A method without an entry in the
    table below yields the shared fields only.

    :param hyperparameters: dict containing hyperparameter values.
    :return: str model name
    """
    shared_fields = ('method', 'total_iterations', 'nn_hidden_sizes',
                     'activation', 'initial_z', 'seed')
    on_policy_fields = ('learning_rate', 'batch_size', 'gamma', 'ent_coef',
                        'clip_range', 'timesteps_per_epoch')
    replay_fields = ('learning_rate', 'batch_size', 'gamma', 'buffer_size',
                     'train_freq', 'gradient_steps', 'learning_starts')
    gfn_fields = replay_fields + (
        'explorative_as_initial', 'explorative_num', 'temperature',
        'sample_method', 'rl_start', 'rl_length', 'use_filter',
        'pessimistic_update', 'epsilon_random')
    gfnsub_fields = replay_fields + (
        'temperature', 'sample_method', 'rl_start', 'rl_length', 'use_filter',
        'lambda', 'weighting', 'pessimistic_update', 'epsilon_random')
    gafn_fields = ('learning_rate', 'batch_size', 'buffer_size', 'train_freq',
                   'gradient_steps', 'learning_starts', 'sample_method',
                   'epsilon_random')

    # (methods sharing a layout, ordered fields appended after the shared ones)
    field_table = (
        (('ppo',), on_policy_fields),
        (('sac', 'td3'), replay_fields),
        (('dqn',), on_policy_fields),
        (('gfn',), gfn_fields),
        (('gfnsub',), gfnsub_fields),
        (('gafn',), gafn_fields),
    )

    model_name = "_".join(f"{hyperparameters[field]}" for field in shared_fields)

    method = hyperparameters['method']
    for methods, fields in field_table:
        if method not in methods:
            continue
        model_name += "".join(f"_{hyperparameters[field]}" for field in fields)
        # Only the layouts that carry 'use_filter' can have filter bounds.
        if 'use_filter' in fields and hyperparameters['use_filter']:
            model_name += "".join(f"_{hyperparameters[field]}"
                                  for field in ('filter_upper', 'filter_lower'))
        break

    return model_name

def name_to_param(model_name):
    """
    Parse a model name string into a dictionary of hyperparameters.

    The name is split on '_'. The first six components are the shared fields
    (method, total_iterations, nn_hidden_sizes, activation, initial_z, seed);
    the following components are assigned to a method-specific list of keys.
    Components made only of digits are converted to int, everything else
    (floats, 'True'/'False', ...) stays a string. When use_filter parses to
    the string 'True', two more components are read as filter_upper and
    filter_lower. One further trailing component, if present, is stored
    unconverted as temperature_rate; otherwise temperature_rate is the int 40.

    :param model_name: str containing the formatted model name.
    :return: dict containing hyperparameters
    :raises ValueError: if the method prefix is not a known method, or a
        shared integer field cannot be converted.
    :raises IndexError: if the name has fewer components than the method needs.
    """
    replay_keys = ['learning_rate', 'batch_size', 'gamma', 'buffer_size',
                   'train_freq', 'gradient_steps', 'learning_starts']
    on_policy_keys = ['learning_rate', 'batch_size', 'gamma', 'ent_coef',
                      'clip_range', 'timesteps_per_epoch']
    keys_by_method = {
        'ppo': on_policy_keys,
        'sac': replay_keys,
        'td3': replay_keys,
        'dqn': on_policy_keys,
        'gfn': replay_keys + [
            'explorative_as_initial', 'explorative_num', 'temperature',
            'sample_method', 'rl_start', 'rl_length', 'use_filter',
            'pessimistic_update', 'epsilon_random'],
        'gfnsub': replay_keys + [
            'temperature', 'sample_method', 'rl_start', 'rl_length',
            'use_filter', 'lambda', 'weighting', 'pessimistic_update',
            'epsilon_random'],
        'gafn': ['learning_rate', 'batch_size', 'buffer_size', 'train_freq',
                 'gradient_steps', 'learning_starts', 'sample_method',
                 'epsilon_random'],
        'teacher': ['learning_rate', 'batch_size', 'buffer_size', 'train_freq',
                    'gradient_steps', 'learning_starts', 'temperature',
                    'sample_method', 'epsilon_random'],
    }

    components = model_name.split('_')

    # Base parameters
    params = {
        'method': components[0],
        'total_iterations': int(components[1]),
        'nn_hidden_sizes': components[2],
        'activation': components[3],
        'initial_z': components[4],
        'seed': int(components[5]),
    }

    method = params['method']
    if method not in keys_by_method:
        # Previously this fell through and crashed on an unbound local.
        raise ValueError(
            f"Unknown method {method!r} in model name {model_name!r}")

    def convert(value):
        # Only pure digit strings become ints; floats and booleans stay str.
        return int(value) if value.isdigit() else value

    index = 6
    for key in keys_by_method[method]:
        params[key] = convert(components[index])
        index += 1

    # Filter bounds are only present in the name when filtering was enabled.
    if params.get('use_filter') == 'True':
        for key in ('filter_upper', 'filter_lower'):
            params[key] = convert(components[index])
            index += 1

    # Optional trailing component; kept as a raw string when present.
    if len(components) > index:
        params['temperature_rate'] = components[index]
    else:
        params['temperature_rate'] = 40

    return params

def dir_to_names(output_dir):
    """
    From the output directory, get a list of tuples (model_dir, hyperparameters).

    Every "*.pkl" file in output_dir is treated as one model: its file name
    without the extension is parsed with name_to_param, and model_dir is the
    file's path with the ".pkl" suffix removed.

    :param output_dir: str, path to the directory containing model files.
    :return: list of tuples, where each tuple is (model_dir, hyperparameters).
    """
    model_tuples = []

    # Sorted so the result order does not depend on the file system.
    for file in sorted(os.listdir(output_dir)):
        if not file.endswith(".pkl"):
            continue
        model_name = os.path.splitext(file)[0]  # file name without extension
        hyperparameters = name_to_param(model_name)
        # Build the path from the stem rather than str.replace(".pkl", ""),
        # which would also strip ".pkl" from directory names or the stem.
        model_tuples.append((os.path.join(output_dir, model_name), hyperparameters))

    return model_tuples

def gen_val_errs(model_dir, env, num_of_samples = 200000, steps = 16000, force_update = False, learning_start = 96):
    """
    Compute validation errors over training and cache them in
    "<model_dir>/val_errs.pkl".

    For i = steps + learning_start, then every `steps` up to len(samples),
    the error is env.unwrapped.get_error(samples[max(i - num_of_samples, 0):i]),
    where samples are the converted states returned by visited_states.

    :param model_dir: model path without the ".pkl" suffix; also the directory
        that holds val_errs.pkl.
    :param env: environment whose unwrapped object provides get_error (and
        get_state, via visited_states).
    :param num_of_samples: maximum number of most recent samples per window.
    :param steps: spacing between evaluation points.
    :param force_update: recompute even if a non-empty cache file exists.
    :param learning_start: offset added to the first evaluation point.
    :return: list of errors, or None when a non-empty cache already exists
        and force_update is False.
    """
    cache_path = f"{model_dir}/val_errs.pkl"
    if os.path.exists(cache_path) and not force_update:
        try:
            with open(cache_path, "rb") as f:
                to_check = pickle.load(f)
                if len(to_check) > 0: # error generated
                    print(f"{model_dir} already has val_errs.pkl")
                    print(to_check)
                    return None
        except Exception as exc:
            # Best effort: an unreadable cache is simply regenerated below.
            print("Cannot load", model_dir, "-", repr(exc))

    # Load the samples only once we know they are needed.
    samples = visited_states(model_dir, env)[1]
    errs = []
    for i in range(steps + learning_start, len(samples)+1, steps):
        print(i)
        errs.append(env.unwrapped.get_error(samples[max(i - num_of_samples, 0) : i]))

    with open(cache_path, "wb") as f:
        pickle.dump(errs, f)
    return errs

# load the cached validation errors of one model
def l1_loss(model_dir):
    """
    Load the object pickled in "<model_dir>/val_errs.pkl".

    :param model_dir: directory that contains val_errs.pkl.
    :return: the unpickled content of the file.
    """
    cache_path = f"{model_dir}/val_errs.pkl"
    with open(cache_path, "rb") as handle:
        return pickle.load(handle)

def kl_curve(model_dir):
    """
    Load the object pickled in "<model_dir>/val_errs.pkl".

    Reads the same file as l1_loss.

    :param model_dir: directory that contains val_errs.pkl.
    :return: the unpickled content of the file.
    """
    cache_path = f"{model_dir}/val_errs.pkl"
    with open(cache_path, "rb") as handle:
        return pickle.load(handle)

# load the end states recorded during training and convert them with the environment
def visited_states(model_dir, env):
    """
    Load a run's pickle and convert its recorded end states via the env.

    Reads "<model_dir>.pkl". Element 0 of the loaded object is turned into a
    NumPy array and passed to env.unwrapped.get_state; element 1 is returned
    unchanged.

    :param model_dir: model path without the ".pkl" suffix.
    :param env: environment whose unwrapped object provides get_state.
    :return: tuple (element 1 of the pickle, converted states).
    """
    pickle_path = f"{model_dir}.pkl"
    with open(pickle_path, "rb") as handle:
        payload = pickle.load(handle)

    states = env.unwrapped.get_state(np.array(payload[0]))
    return payload[1], states

def read_tensorboard_data(folder, tag):
    """
    Read one scalar tag from the most recently modified event file in folder.

    :param folder: directory containing files whose names start with "events".
    :param tag: scalar tag to read.
    :return: pandas DataFrame built from the accumulator's scalar events.
    :raises IndexError: if the folder holds no such event file.
    """
    candidates = []
    for entry in os.listdir(folder):
        if entry.startswith("events"):
            candidates.append(os.path.join(folder, entry))

    # With several event files, the one modified last wins.
    newest = sorted(candidates, key=os.path.getmtime)[-1]

    accumulator = EventAccumulator(newest)
    accumulator.Reload()
    return pd.DataFrame(accumulator.Scalars(tag))


def prepare_log_data(folder1, folder2 = None, tag = 'train/log_Z'):
    """
    Collect one TensorBoard scalar per run, grouped by method label.

    Runs in folder1 are grouped with get_model_dict (first returned dict).
    If folder2 is given, its runs are grouped with get_model_dict_supp and
    replace the folder1 entry for any label they share. Labels without runs
    are skipped.

    :param folder1: directory scanned with dir_to_names.
    :param folder2: optional second directory whose groups take precedence.
    :param tag: scalar tag passed to read_tensorboard_data.
    :return: dict mapping label -> list of DataFrames, one per run.
    """
    result = {}

    model_dict1, _ = get_model_dict(dir_to_names(folder1))
    count = 0
    for key, models in model_dict1.items():
        # get_model_dict returns every label, including ones with no runs.
        if not models:
            continue
        result[key] = [read_tensorboard_data(model, tag) for model in models]
        count = len(models)
    print(count)

    if folder2 is not None:
        supp_dict = get_model_dict_supp(dir_to_names(folder2))
        count = 0
        for key, models in supp_dict.items():
            if not models:
                continue
            # Overwrites the folder1 entry for the same label.
            result[key] = [read_tensorboard_data(model, tag) for model in models]
            count = len(models)
        print(count)
    return result


def prepare_log_data2(folder1, tag = 'train/log_Z'):
    """
    Collect one TensorBoard scalar per run, grouped by ablation label.

    Same as prepare_log_data for a single folder, but uses the second dict
    returned by get_model_dict. Labels without runs are skipped.

    :param folder1: directory scanned with dir_to_names.
    :param tag: scalar tag passed to read_tensorboard_data.
    :return: dict mapping label -> list of DataFrames, one per run.
    """
    _, model_dict2 = get_model_dict(dir_to_names(folder1))

    result = {}
    count = 0
    for key, models in model_dict2.items():
        # get_model_dict returns every label, including ones with no runs.
        if not models:
            continue
        result[key] = [read_tensorboard_data(model, tag) for model in models]
        count = len(models)
    print(count)
    return result

def prepare_performance_data(folder, env, gen_val_scores = False):
    """
    Aggregate cached validation errors per method label.

    For every non-empty group in the first dict returned by get_model_dict,
    the per-run error curves loaded with l1_loss are stacked and reduced to
    an element-wise mean and population standard deviation across runs.
    All runs of a group must have error curves of the same shape.

    :param folder: directory scanned with dir_to_names.
    :param env: environment forwarded to gen_val_errs.
    :param gen_val_scores: when True, regenerate val_errs.pkl for every model
        (force_update=True) before aggregating.
    :return: dict mapping label -> (mean array, std array).
    """
    model_list = dir_to_names(folder)
    if gen_val_scores:
        for model in model_list:
            gen_val_errs(model[0], env, force_update=True)
    model_dict1, _ = get_model_dict(model_list)

    # aggregate over key in model_dict1
    result = {}
    for key, models in model_dict1.items():
        if not models:
            continue
        runs = []
        for model in models:
            l1 = l1_loss(model)
            print(model)
            print(l1)
            runs.append(np.asarray(l1, dtype=float))
        print(len(runs))

        # np.std avoids sqrt(E[x^2] - mean^2), which can turn slightly
        # negative from rounding and produce NaN.
        stacked = np.stack(runs)
        result[key] = (stacked.mean(axis=0), stacked.std(axis=0))
    return result

def prepare_ablation_data(folder, env, gen_val_scores = False):
    """
    Aggregate cached validation errors per ablation label.

    For every non-empty group in the second dict returned by get_model_dict,
    the per-run error curves loaded with l1_loss are stacked and reduced to
    an element-wise mean and population standard deviation across runs.
    All runs of a group must have error curves of the same shape.

    :param folder: directory scanned with dir_to_names.
    :param env: environment forwarded to gen_val_errs (also printed).
    :param gen_val_scores: when True, regenerate val_errs.pkl for every model
        (force_update=True) before aggregating.
    :return: dict mapping label -> (mean array, std array).
    """
    model_list = dir_to_names(folder)
    if gen_val_scores:
        for model in model_list:
            gen_val_errs(model[0], env, force_update=True)
    _, model_dict2 = get_model_dict(model_list)

    # aggregate over key in model_dict2
    print(env)
    result = {}
    for key, models in model_dict2.items():
        if not models:
            continue
        print(key)
        runs = [np.asarray(l1_loss(model), dtype=float) for model in models]
        print(len(runs))

        # np.std avoids sqrt(E[x^2] - mean^2), which can turn slightly
        # negative from rounding and produce NaN.
        stacked = np.stack(runs)
        result[key] = (stacked.mean(axis=0), stacked.std(axis=0))

    # NOTE: an optional "Baseline" entry aggregated from the 'GFN-RP' group of
    # the first get_model_dict dict used to be added here; it is disabled.
    return result

def prepare_supp_data(folder1, folder2, env, gen_val_scores = False, force_update = False):
    """
    Aggregate cached validation errors from two folders.

    folder1 is grouped with get_model_dict (first returned dict) and folder2
    with get_model_dict_supp; a label present in both is taken from folder2.
    For each non-empty group, the per-run error curves loaded with l1_loss
    are stacked and reduced to an element-wise mean and population standard
    deviation. All runs of a group must have error curves of the same shape.

    :param folder1: first directory scanned with dir_to_names.
    :param folder2: second directory; its groups overwrite folder1's.
    :param env: environment forwarded to gen_val_errs.
    :param gen_val_scores: when True, call gen_val_errs for every model.
    :param force_update: forwarded to gen_val_errs.
    :return: dict mapping label -> (mean array, std array).
    """
    result = {}

    def aggregate_into_result(model_dict):
        # Mean/std across the runs of every label that has at least one run.
        for key, models in model_dict.items():
            if not models:
                continue
            runs = [np.asarray(l1_loss(model), dtype=float) for model in models]
            print(len(runs))
            # np.std avoids sqrt(E[x^2] - mean^2), which can turn slightly
            # negative from rounding and produce NaN.
            stacked = np.stack(runs)
            result[key] = (stacked.mean(axis=0), stacked.std(axis=0))

    # first load folder 1, then use folder2 to overwrite folder1 if the key has been overlapped
    model_list = dir_to_names(folder1)
    if gen_val_scores:
        for model in model_list:
            gen_val_errs(model[0], env, force_update=force_update)
    model_dict1, _ = get_model_dict(model_list)
    aggregate_into_result(model_dict1)

    model_list = dir_to_names(folder2)
    if gen_val_scores:
        for model in model_list:
            gen_val_errs(model[0], env, force_update=force_update)
    aggregate_into_result(get_model_dict_supp(model_list))

    return result


def get_model_dict(model_list):
    """
    Group model paths by comparison label and by ablation label.

    Each entry of model_list is (model_path, config). For method 'gfn' with
    sample_method 0, 2 or 3, an ablation label is composed from:
      - "BF" when config['use_filter'] is not the string 'False',
      - "RP" for sample_method 2 or "MP" for sample_method 3,
      - "TD" when float(config['temperature']) is not 0,
    joined with " + " ("Vanilla" when none applies). The path goes into that
    ablation group, except that a plain "RP" run with pessimistic_update == 8
    is recorded only as 'PBP-GFN' in the comparison dict. Additionally,
    "Vanilla" runs are listed under 'GFN', other plain "RP" runs under
    'GFN-RP', and "BF + MP + TD" runs under 'Ours'. 'gfn' runs with any other
    sample_method go to 'GFN' only. Methods 'gafn', 'gfnsub' and 'teacher'
    map to 'GAFN', 'SubTB' and 'Teacher'; other methods are ignored.

    :param model_list: list of (model_path, hyperparameter dict) tuples.
    :return: tuple (comparison dict, ablation dict) of label -> list of paths.
    """
    comparison_labels = ("Ours", "GFN", "GAFN", "PBP-GFN", "GFN-RP", "SubTB", "Teacher")
    ablation_labels = ("BF + MP + TD", "BF + MP", "MP + TD", "MP", "TD", "Vanilla",
                       "BF + RP + TD", "BF + RP", "RP + TD", "BF + TD", "BF", "RP")
    model_dict1 = {label: [] for label in comparison_labels}
    model_dict2 = {label: [] for label in ablation_labels}

    for entry in model_list:
        path, config = entry[0], entry[1]
        method = config['method']

        if method == "gafn":
            model_dict1["GAFN"].append(path)
            continue
        if method == "gfnsub":
            model_dict1["SubTB"].append(path)
            continue
        if method == "teacher":
            model_dict1["Teacher"].append(path)
            continue
        if not (method == "gfn"):
            continue

        sampler = config['sample_method']
        if sampler == 0:
            sampler_tag = ""
        elif sampler == 2:
            sampler_tag = "RP"
        elif sampler == 3:
            sampler_tag = "MP"
        else:
            model_dict1['GFN'].append(path)
            continue

        filtered = not (config['use_filter'] == 'False')
        annealed = not (float(config['temperature']) == 0)

        pieces = ["BF" if filtered else "", sampler_tag, "TD" if annealed else ""]
        label = " + ".join(piece for piece in pieces if piece) or "Vanilla"

        if label == "RP":
            if config['pessimistic_update'] == 8:
                # Recorded for the comparison only, not as an ablation run.
                model_dict1['PBP-GFN'].append(path)
                continue
            model_dict1['GFN-RP'].append(path)
        elif label == "Vanilla":
            model_dict1['GFN'].append(path)
        elif label == "BF + MP + TD":
            model_dict1['Ours'].append(path)

        model_dict2[label].append(path)

    return model_dict1, model_dict2

def get_model_dict_supp(model_list):
    """
    Group the model paths of a supplementary folder by comparison label.

    'gfn' runs with sample_method 2 and use_filter == 'False' go to 'PBP-GFN'
    when pessimistic_update == 8 and to 'GFN-RP' otherwise; sample_method 2
    runs with filtering are ignored. 'gfn' runs with any other sample_method
    are collected under 'GFN', a label that is only created when such a run
    exists. 'gafn' maps to 'GAFN' and 'gfnsub' to 'SubTB'; other methods are
    ignored.

    :param model_list: list of (model_path, hyperparameter dict) tuples.
    :return: dict mapping label -> list of model paths.
    """
    model_dict1 = {"GAFN": [], "PBP-GFN": [], "GFN-RP": [], "SubTB":[]}
    for model in model_list:
        model_config = model[1]
        if model_config['method'] == "gfn":
            if model_config['sample_method'] == 2:
                if model_config['use_filter'] == 'False':
                    if model_config['pessimistic_update'] == 8:
                        model_dict1['PBP-GFN'].append(model[0])
                    else:
                        model_dict1['GFN-RP'].append(model[0])
            else:
                # 'GFN' is not pre-seeded above, so a plain lookup raised
                # KeyError; create the list on first use instead.
                model_dict1.setdefault('GFN', []).append(model[0])

        elif model_config['method'] == "gafn":
            model_dict1["GAFN"].append(model[0])

        elif model_config['method'] == "gfnsub":
            model_dict1["SubTB"].append(model[0])
    return model_dict1

def load_tensorboard_data_pytorch(directories):
    """
    Load TensorBoard scalar events from a list of directories.

    Each directory is read with the EventAccumulator this module imports
    (the previous implementation referenced an undefined `tb` name, so every
    directory ended in the error branch). Only scalar summaries are loaded.

    Args:
        directories (list): List of directories containing TensorBoard event files.

    Returns:
        dict: A dictionary where keys are directory paths and values are lists
        of dicts with the keys "tag", "value", "step" and "wall_time".
        Directories that do not exist or fail to load are reported on stdout
        and left out of the result.
    """
    data = {}

    for directory in directories:
        if not os.path.isdir(directory):
            print(f"Directory does not exist: {directory}")
            continue

        try:
            # size_guidance of 0 keeps every scalar event instead of a sample.
            accumulator = EventAccumulator(directory, size_guidance={"scalars": 0})
            accumulator.Reload()
            events = [
                {
                    "tag": tag,
                    "value": event.value,
                    "step": event.step,
                    "wall_time": event.wall_time,
                }
                for tag in accumulator.Tags()["scalars"]
                for event in accumulator.Scalars(tag)
            ]
        except Exception as e:
            # Best effort: report the directory and continue with the rest.
            print(f"Error reading TensorBoard data in directory {directory}: {e}")
        else:
            data[directory] = events

    return data



