import os

import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm
from gym_recording_modified.playback import get_recordings
from typing import List, Any, Dict, Union, Callable
import csv

# Recording key(s) requested from ``get_recordings`` for every recognised
# ``extracted_type``.  A plain string is forwarded as a string and a tuple is
# forwarded as a fresh list, mirroring what each type has always asked for.
_EXTRACT_KEYS_BY_TYPE = {
    'rewards': 'rewards',
    'rewards2': 'rewards',
    'sum-rewards': 'rewards',
    'episode_steps': ('episode_steps',),
    'episode_returns': ('episode_returns',),
    'losses': ('losses',),
    'td_errors': ('td_errors',),
    'weight_differences': ('weight_differences',),
    'sum-steps-per-episode': 'episodes_end_point',
    'num-steps-per-episode': 'episodes_end_point',
    'num-steps-per-episode2': 'episodes_end_point',
    'episode-num-at-steps': 'episodes_end_point',
    'mean-steps-for-last-100-episodes': 'episodes_end_point',
    'rewards3': 'episodes_end_point',
    'sum-episodic-rewards': ('rewards', 'episodes_end_point'),
    'observations': ('observations', 'episodes_end_point'),
    'actions': ('actions', 'episodes_end_point'),
}

# Types whose recording entry is returned untouched, mapped to that entry's key.
_PASS_THROUGH_KEY_BY_TYPE = {
    'rewards': 'rewards',
    'rewards2': 'rewards',
    'observations': 'observations',
    'actions': 'actions',
    'episode_returns': 'episode_returns',
    'episode_steps': 'episode_steps',
    'losses': 'losses',
    'td_errors': 'td_errors',
    'weight_differences': 'weight_differences',
}

# Types whose per-seed result is actually collected.  Every other recognised
# type is currently disabled and raises once the first seed has been processed.
_COLLECTED_TYPES = frozenset({
    'rewards',
    'sum-steps-per-episode',
    'num-steps-per-episode',
    'num-steps-per-episode2',
    'mean-steps-for-last-100-episodes',
    'sum-rewards',
    'episode_returns',
    'episode_steps',
    'td_errors',
    'weight_differences',
    'losses',
})


def _episode_lengths(end_points):
    """Return the step count between each pair of consecutive end points."""
    return [end_points[i + 1] - start for i, start in enumerate(end_points[:-1])]


def _transform_seed_recording(extracted_type, recording):
    """Reduce one seed's recording to the quantity named by ``extracted_type``."""
    if extracted_type in _PASS_THROUGH_KEY_BY_TYPE:
        return recording[_PASS_THROUGH_KEY_BY_TYPE[extracted_type]]

    if extracted_type == 'sum-rewards':
        return [np.sum(recording['rewards'])]

    # Every remaining type is derived from the episode boundaries.
    end_points = recording['episodes_end_point']

    if extracted_type == 'sum-episodic-rewards':
        rewards = recording['rewards']
        return [np.sum(rewards[start: end_points[i + 1]])
                for i, start in enumerate(end_points[:-1])]

    if extracted_type == 'rewards3':
        # Negated episode lengths.
        return [0 - length for length in _episode_lengths(end_points)]

    if extracted_type == 'num-steps-per-episode':
        # Mean episode length over consecutive windows; a window closes as
        # soon as it has accumulated at least ``window_size`` steps.
        window_size = 1000
        averages = []
        num_steps = 0
        num_episodes = 0
        for length in _episode_lengths(end_points):
            num_steps += length
            num_episodes += 1
            if num_steps >= window_size:
                averages.append(num_steps / num_episodes)
                num_steps = 0
                num_episodes = 0
        return averages

    if extracted_type == 'sum-steps-per-episode':
        # Fewer than 1001 end points yields a fixed -1000000 placeholder;
        # otherwise the negated final end point is reported.
        if len(end_points) < 1001:
            return [-1000000]
        return [-end_points[-1]]

    if extracted_type == 'num-steps-per-episode2':
        # Drop the first and the last two episodes, then remove every entry
        # at an even position of what is left.
        trimmed = _episode_lengths(end_points)[1:-2]
        return np.delete(trimmed, np.arange(0, len(trimmed), 2), None)

    if extracted_type == 'mean-steps-for-last-100-episodes':
        # The slice arithmetic needs an array-like that supports subtracting
        # a scalar -- presumably a numpy array; verify against get_recordings.
        tail = end_points[-101:-1] - end_points[-101]
        return _episode_lengths(tail)

    if extracted_type == 'episode-num-at-steps':
        # One entry per step holding the 1-based number of its episode.
        episode_num_at_steps = []
        for episode_num, start in enumerate(end_points[:-1]):
            episode_num_at_steps += [episode_num + 1] * (end_points[episode_num + 1] - start)
        return episode_num_at_steps

    raise ValueError('extracted_type cannot be identified')


def extract_seed_data(root_seed_dir: str, extracted_type: str = None,
        process_func: Callable[[List[Any]], List[Any]] = None) -> List[Any]:
    """Extract one quantity from the recordings of every seed under a directory.

    Each entry of ``root_seed_dir`` is treated as one seed directory and is
    loaded with ``get_recordings`` (assumed to return a mapping keyed by the
    requested recording names).  The loaded recording is then reduced to the
    quantity named by ``extracted_type`` and collected, one item per seed, in
    directory-listing order.

    :param root_seed_dir: directory whose entries are the per-seed directories
    :type root_seed_dir: str
    :param extracted_type: name of the quantity to extract; must be one of the
        keys of ``_EXTRACT_KEYS_BY_TYPE``
    :type extracted_type: str
    :param process_func: optional callable applied to the full list of
        per-seed results before it is returned
    :type process_func: Callable[[List[Any]], List[Any]]
    :raises ValueError: if ``extracted_type`` is not recognised, or if it is
        recognised but not among the currently collected types (in which case
        the error is raised after the first seed has been processed)
    :rtype: List[Any]
    """
    try:
        keys = _EXTRACT_KEYS_BY_TYPE[extracted_type]
    except (KeyError, TypeError):
        # TypeError covers unhashable values, which cannot be a valid name.
        raise ValueError('extracted_type cannot be identified') from None

    # Hand get_recordings a fresh list on every call so the table is never
    # exposed to mutation.
    extract = keys if isinstance(keys, str) else list(keys)

    seed_data = []
    for seed_value in os.listdir(root_seed_dir):
        seed_dir = os.path.join(root_seed_dir, seed_value)
        recording = get_recordings(seed_dir, extract=extract)
        single_seed_data = _transform_seed_recording(extracted_type, recording)

        if extracted_type not in _COLLECTED_TYPES:
            raise ValueError('This values does not exist')
        seed_data.append(single_seed_data)

    if process_func is not None:
        seed_data = process_func(seed_data)

    return seed_data


def collect_experiments_data(root_dir: str, extracted_type: str = None,
                    process_func: Callable[[List[Any]], List[Any]] = None) -> Union[Dict[str, Any],List[Any]]:
    """Collect per-seed data for every configuration found below ``root_dir``.

    The tree is walked recursively.  A directory is treated as a
    configuration directory (one whose children are seed directories) when
    the first entry of its first child looks like a recording file; such a
    directory is handed to ``extract_seed_data``.  Any other directory
    becomes a dict mapping each child name to the result of recursing into
    that child.

    :param root_dir: root of the experiment tree
    :param extracted_type: forwarded to ``extract_seed_data``
    :param process_func: forwarded to ``extract_seed_data`` and applied to the
        list of seed results of every configuration
    :return: nested dict mirroring the directory layout with lists at the
        leaves, or a single list when ``root_dir`` is itself a configuration
        directory
    """
    recording_markers = ('.csv', '.pkl', '.npy', 'log_file')

    def recurse(current_dir, pbar):
        child_dirs = os.listdir(current_dir)

        # NOTE: heuristic -- only the first entry of the first child is
        # inspected to decide whether the children are seed directories.
        grand_child_dir = os.listdir(os.path.join(current_dir, child_dirs[0]))[0]
        is_seed_dir = any(marker in grand_child_dir for marker in recording_markers)

        if is_seed_dir:
            pbar.update(1)
            return extract_seed_data(current_dir, extracted_type, process_func)

        # Not a configuration directory: descend into every child.
        return {folder: recurse(os.path.join(current_dir, folder), pbar)
                for folder in child_dirs}

    # The context manager closes the progress bar even if extraction fails.
    with tqdm(total=count_grand_child_dirs(root_dir)) as pbar:
        return recurse(root_dir, pbar)

def find_failed_jobs(root_dir: str, extracted_type: str = None,
                    process_func: Callable[[List[Any]], List[Any]] = None,
                    expected_length: int = 50) -> Union[Dict[str, Any],List[Any]]:
    """Return the seed directories below ``root_dir`` whose run looks incomplete.

    A directory counts as a seed directory when any of its entries has
    ``.csv``, ``.pkl`` or ``.npy`` in its name.  Such a directory is reported
    as failed when it holds no entry with ``episode_returns`` in its name, or
    when the array loaded from that entry does not have exactly
    ``expected_length`` items (if several entries match, the last one listed
    decides).  Directories that are not seed directories are searched
    recursively; plain files found at those levels are skipped.

    :param root_dir: root of the experiment tree
    :param extracted_type: unused; kept for signature compatibility
    :param process_func: unused; kept for signature compatibility
    :param expected_length: number of recorded episode returns a complete run has
    :return: list of paths of the failed seed directories
    """
    recording_markers = ('.csv', '.pkl', '.npy')
    failed_jobs = []

    def recurse(current_dir, pbar):
        entries = os.listdir(current_dir)

        # NOTE: heuristic -- a directory holding any recording-like file is a seed dir.
        is_seed_dir = any(marker in entry for entry in entries for marker in recording_markers)
        if is_seed_dir:
            pbar.update(1)
            failed = True  # stays True when no episode_returns file exists
            for recorded_file in entries:
                if 'episode_returns' in recorded_file:
                    failed = len(np.load(os.path.join(current_dir, recorded_file))) != expected_length
            if failed:
                failed_jobs.append(current_dir)
            return

        for entry in entries:
            child = os.path.join(current_dir, entry)
            # Stray files at intermediate levels cannot be listed; skip them
            # instead of crashing with NotADirectoryError.
            if os.path.isdir(child):
                recurse(child, pbar)

    # NOTE(review): the total comes from count_grand_child_dirs while the bar
    # advances once per seed directory, so the two may not line up -- confirm.
    with tqdm(total=count_grand_child_dirs(root_dir)) as pbar:
        recurse(root_dir, pbar)
    return failed_jobs


def count_grand_child_dirs(root_dir: str):
    """Count the directories below ``root_dir`` whose children are seed directories.

    A directory is counted (and not descended into any further) when the
    first entry of its first child has ``.csv``, ``.pkl``, ``.npy`` or
    ``log_file`` in its name.  Every other directory is expanded and each of
    its entries is examined in turn.
    """
    recording_markers = ('.csv', '.pkl', '.npy', 'log_file')
    total = 0

    # Explicit stack; children are pushed reversed so they are visited in
    # listing order, depth first.
    pending = [root_dir]
    while pending:
        here = pending.pop()
        entries = os.listdir(here)

        # NOTE: heuristic -- only the first entry of the first child is inspected.
        probe = os.listdir(os.path.join(here, entries[0]))[0]
        if any(marker in probe for marker in recording_markers):
            total += 1
            continue

        pending.extend(os.path.join(here, entry) for entry in reversed(entries))

    return total

def count_dict_to_dataframe(dictionary):
    """Count the seed rows stored in the list-valued leaves of a nested dict.

    Every value that is a ``list`` contributes its length; every other value
    is treated as a nested mapping and searched in turn.
    """
    total = 0
    pending = [dictionary]

    while pending:
        node = pending.pop()
        for name in node.keys():
            child = node[name]
            if isinstance(child, list):
                total += len(child)
            else:
                pending.append(child)

    return total


def dict_to_dataframe(dictionary, columns, range_added=0):
    '''
    Converts a recursive data dictionary to a pandas dataframe
        dictionary : dict
            nested dictionary to be converted; list values are leaves holding
            one sequence per seed, any other value is a nested dictionary
        columns : list
            names for the leading parameter columns, one per nesting level
        range_added : int
            number of extra value columns on top of the length of the last
            seed row that was processed

    Each output row is the chain of dictionary keys leading to a leaf followed
    by the values of one seed row.  The value columns are labelled
    0, 1, 2, ...  A dictionary without any seed rows yields an empty frame.
    '''

    episode_len = None

    def recurse(node, pbar):
        nonlocal episode_len
        rows = []

        for key in node.keys():
            value = node[key]
            if isinstance(value, list):
                # Leaf: one output row per seed, prefixed with the leaf key.
                for seed_row in value:
                    seed_row = list(seed_row)
                    rows.append([key] + seed_row)
                    episode_len = len(seed_row)
                    pbar.update(1)
            else:
                # Inner node: prefix every row of the subtree with this key.
                rows.extend([key] + row for row in recurse(value, pbar))

        return rows

    # The context manager closes the progress bar even if conversion fails.
    with tqdm(total=count_dict_to_dataframe(dictionary)) as pbar:
        data = recurse(dictionary, pbar)

    # No seed rows at all: fall back to zero value columns instead of failing
    # on ``None + range_added``.
    num_value_columns = 0 if episode_len is None else episode_len

    df = pd.DataFrame(data, columns = columns + list(range(0, num_value_columns+range_added)))

    return df


def sum_df_rewards(df, columns, denominator=1, col_title='Total Return'):
    '''
    Collapses every column that is not listed in 'columns' into one row-wise total
        df : pandas DataFrame
        columns : list
            identifying columns that are kept as they are
        denominator :
            accepted but currently not applied to the totals
        col_title : str
            name given to the column holding the totals
    '''
    row_totals = df.set_index(columns).sum(axis=1)

    summed = row_totals.reset_index()
    summed.columns = columns + [col_title]
    return summed


def cumsum_df_rewards(df, columns):
    '''
    Replaces every column that is not listed in 'columns' by its running
    (cumulative) sum taken from left to right within each row
        df : pandas DataFrame
        columns : list
            identifying columns that are kept as they are
    '''
    return df.set_index(columns).cumsum(axis=1).reset_index()

def _iter_seed_runs(path):
    """Yield ``(alg, exp_technique, step_size, exp_value, seed, seed_dir)`` per run.

    Walks the fixed five-level layout
    ``path/alg/exp_technique/step_size/exp_value/seed``.
    """
    for alg in os.listdir(path):
        alg_dir = os.path.join(path, alg)
        for exp_technique in os.listdir(alg_dir):
            exp_dir = os.path.join(alg_dir, exp_technique)
            for step_size in os.listdir(exp_dir):
                step_dir = os.path.join(exp_dir, step_size)
                for exp_value in os.listdir(step_dir):
                    exp_value_dir = os.path.join(step_dir, exp_value)
                    for seed in os.listdir(exp_value_dir):
                        yield (alg, exp_technique, step_size, exp_value, seed,
                               os.path.join(exp_value_dir, seed))


def extract_all_returns_to_data_frame(path, num_entries, sum_rewards=False):
    """
    Extracts traces rooted at path and returns a pandas df with one row per run

    arguments:
        path : str
            Path to the root of the directory which contains the results of the same environment
        num_entries : int
            Only the first num_entries recorded rewards of each run are used
        sum_rewards : bool
            When True each run is reduced to a single 'Total Return' column;
            otherwise the rewards are stored in columns 0, 1, 2, ...

    returns:
        DataFrame with the columns "Algorithm", "Exploration Technique",
        "Step Size", "Exp Value" (rounded to one decimal) and "Seed", followed
        by the reward column(s). A tree without runs yields an empty frame.

    NOTE: all runs MUST have the same length of episode. Otherwise they cannot be put into a dataframe
    NOTE: This will currently only work for the tabular directory structure
    """
    columns = ["Algorithm", "Exploration Technique", "Step Size", "Exp Value", "Seed"]

    # Materialise the walk once: it gives the progress-bar total and the work list.
    runs = list(_iter_seed_runs(path))

    all_params = []
    all_rewards = []
    with tqdm(total=len(runs)) as pbar:
        for alg, exp_technique, step_size, exp_value, seed, seed_dir in runs:
            recording = get_recordings(seed_dir, extract='rewards')
            # Take the first entry of the returned mapping -- presumably the
            # only one, since just 'rewards' was requested.
            rewards = list(recording.values())[0][:num_entries]

            if sum_rewards:
                rewards = [np.sum(rewards)]

            all_params.append([alg, exp_technique, float(step_size), float(exp_value), int(seed)])
            all_rewards.append(rewards)
            pbar.update(1)

    if not all_params:
        # Nothing on disk: an empty reward array has no second axis to size
        # the reward columns from, so build the empty frame directly.
        return pd.DataFrame(columns=columns + (['Total Return'] if sum_rewards else []))

    all_rewards = np.array(all_rewards)

    param_df = pd.DataFrame(data=all_params, columns=columns)
    param_df["Exp Value"] = param_df["Exp Value"].astype(float).round(1)

    if sum_rewards:
        reward_df = pd.DataFrame(data=all_rewards, columns=['Total Return'])
    else:
        reward_df = pd.DataFrame(data=all_rewards, columns=range(0, all_rewards.shape[1]))
    return pd.concat([param_df, reward_df], axis=1)


def extract_all_returns_to_data_frame_config(config_path, num_entries, config, sum_rewards=False):
    """
    Loads the rewards of every seed directory found directly under config_path

    arguments:
        config_path : str
            Directory whose entries are the seed directories of one configuration
        num_entries : int
            Only the first num_entries recorded rewards of each seed are used
        config :
            Accepted but not used
        sum_rewards : bool
            When True each seed is reduced to a one-element list holding its summed rewards

    returns:
        numpy array with one row per seed, in directory-listing order
    """

    def load_seed(seed_name):
        recording = get_recordings(os.path.join(config_path, seed_name), extract='rewards')
        # First entry of the returned mapping, cut to the requested length.
        trace = list(recording.values())[0][:num_entries]
        return [np.sum(trace)] if sum_rewards else trace

    return np.array([load_seed(seed_name) for seed_name in os.listdir(config_path)])


def load_results(path, load_type, param_names, filters=None):
    '''
    loads all results and returns the selected parameters and the stacked data
        path : root of results; every directory without sub-directories is one run
        load_type : 'trace' loads the first file of the run whose name contains
            'trace.episode_returns'; 'params' loads no data at all; any other
            value (e.g. 'offline_eval_returns', 'offline_eval_steps', 'actions',
            'states', 'td_errs') loads '<load_type>.npy' from the run directory
        param_names : list of parameters to extract from each run's 'args.csv';
            the special name 'exp_val' resolves to 'temp', 'epsilon' or 'eta'
            depending on the run's 'exploration_strategy'
        filters : optional dictionary mapping a parameter name to the values
            that are allowed for it; runs with any other value are skipped

    returns (params, data): params is an array with one row of parameter
    values per kept run, data holds one row per kept run (None when nothing
    was loaded)

    raises FileNotFoundError when load_type is 'trace' and a kept run has no
    trace file
    '''

    def load_params(root):
        params = {}
        # newline='' is what the csv module expects of files handed to it.
        with open(os.path.join(root, 'args.csv'), mode='r', newline='') as infile:
            for line in csv.reader(infile):
                # Blank or single-field rows carry no key/value pair.
                if len(line) >= 2:
                    params[line[0]] = line[1]
        return params

    def lookup(params, name):
        if name != 'exp_val':
            return params.get(name, None)
        # 'exp_val' is an alias whose real key depends on the strategy.
        strategy = params['exploration_strategy']
        if 'softmax' in strategy:
            return params.get('temp', None)
        if strategy == 'epsilon-greedy':
            return params.get('epsilon', None)
        return params.get('eta', None)

    run_rows = []
    all_params = []

    for root, subdirs, files in os.walk(path):
        if subdirs:
            continue  # only leaf directories are runs

        params = load_params(root)
        params_run = []  # relevant parameters for a single run

        excluded = False
        for p in param_names:
            param = lookup(params, p)
            params_run.append(param)

            if filters is not None:
                allowed_values = filters.get(p)
                if allowed_values is not None and param not in allowed_values:
                    excluded = True

        if excluded:
            continue

        all_params.append(params_run)

        if load_type == 'params':
            continue

        if load_type == 'trace':
            trace_file = next((name for name in files if 'trace.episode_returns' in name), None)
            if trace_file is None:
                # Without this check the previous run's data would be reused
                # for this run, silently corrupting the result.
                raise FileNotFoundError(
                    "no 'trace.episode_returns' file found in " + root)
            load_file = os.path.join(root, trace_file)
        else:
            load_file = os.path.join(root, load_type + '.npy')

        run_data = np.load(load_file)
        run_rows.append(np.reshape(run_data, (1, run_data.shape[0])))

    # Stack once at the end instead of re-allocating the array for every run.
    data = np.concatenate(run_rows, axis=0) if run_rows else None

    return np.array(all_params), data
