import os
import argparse

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np 

from utils import seed_everything
from replay_buffer import ReplayMemory, load_mem_uncertain

def plot_buffer(buf, path, max_single_row=200, ncols=25):
    """Save one histogram per column of ``buf`` as a single image at ``path``.

    Every column ``buf[:, i]`` is drawn with ``sns.histplot`` (count statistic,
    square-root bin rule, KDE overlay) in its own panel titled
    ``Dimension i``.

    Args:
        buf: 2-D array-like; only ``buf.shape[1]`` and column slices
            ``buf[:, i]`` are used, so rows are samples and columns are the
            dimensions being plotted.
        path: destination file passed to ``Figure.savefig``.
        max_single_row: buffers with fewer than this many columns are laid
            out on one row of 10x10-inch panels; wider buffers use a
            multi-row grid of 5x5-inch panels.
        ncols: number of columns in the multi-row grid.

    Raises:
        ValueError: if ``buf`` has no columns to plot.
    """
    dims = buf.shape[1]
    if dims == 0:
        raise ValueError('plot_buffer: buffer has no dimensions to plot')

    if dims < max_single_row:
        # Narrow buffers: one row of large panels.
        nrows, grid_cols, panel_size = 1, dims, 10
    else:
        # Wide buffers: fixed-width grid with as many rows as needed. The old
        # hard-coded 11x25 grid held at most 275 dimensions and raised
        # IndexError beyond that (presumably the "fix humanoid plotting" issue).
        nrows = (dims + ncols - 1) // ncols
        grid_cols, panel_size = ncols, 5

    # squeeze=False always yields a 2-D array of Axes, so a single-dimension
    # buffer needs no special-case branch.
    fig, axes = plt.subplots(nrows, grid_cols, squeeze=False,
                             figsize=(panel_size * grid_cols, panel_size * nrows))
    flat_axes = axes.ravel()
    for i, ax in enumerate(flat_axes[:dims]):
        sns.histplot(buf[:, i], kde=True, color='skyblue',
                     edgecolor='black', stat='count', ax=ax, bins='sqrt')
        ax.set_title(f'Dimension {i}', fontsize=20)
    # Hide grid cells left empty when dims is not a multiple of ncols.
    for ax in flat_axes[dims:]:
        ax.set_visible(False)
    fig.savefig(path)
    # Close this specific figure to release its memory between calls.
    plt.close(fig)

if __name__ =='__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', default="Ant-v2",
                        help='Environment [WetChicken-v0, Pendulum-v0, HalfCheetah-v2, Hopper-v2]')
    parser.add_argument('--noise_seed', type=int, default=14,
                        help='random seed (default: 14)')
    parser.add_argument('--noise_weight', type=float, default=0.2,
                        help='how much noise to add in')
    parser.add_argument('--modes', default=0, type=int,
            help='number of modes in noise to simulate chaotic dynamics')
    parser.add_argument('--scratch_path',
                        default='/home/nwaftp23/scratch/uncertainty_estimation/mujoco',
                        help='root directory containing the per-environment buffers')
    args = parser.parse_args()
    seed_everything(42)

    # NOTE(review): the environment directory is hard-coded and does not use
    # args.env; args is still forwarded to load_mem_uncertain below.
    env = 'Ant-v2_test_aquisition'
    env_path = os.path.join(args.scratch_path, env)
    # Output folder follows the CLI noise settings instead of a hard-coded
    # 'noiseweight0.2_modes0' (identical with the default arguments).
    # NOTE(review): assumes this matches the folder naming used elsewhere for
    # non-default settings — confirm.
    buffer_path = os.path.join(
        env_path, f'noiseweight{args.noise_weight}_modes{args.modes}')
    os.makedirs(buffer_path, exist_ok=True)

    # Split name -> extra keyword arguments for load_mem_uncertain.
    split_kwargs = {
        'train': {},
        'test': {'test': True},
        'oracle': {'oracle': True},
    }
    memories = {}
    for split, load_kwargs in split_kwargs.items():
        memory = ReplayMemory(1000000, 2056, bootstrap=False,
                              ensemble_size=5, shuffle=True)
        load_mem_uncertain(args, memory, env_path, **load_kwargs)
        memories[split] = memory

    for split, memory in memories.items():
        # Each stored transition has seven fields; only the first two
        # (states and actions) are plotted.
        states, actions, _, _, _, _, _ = map(np.stack, zip(*memory.buffer))
        plot_buffer(states, os.path.join(buffer_path, f'{split}_data_states.png'))
        plot_buffer(actions, os.path.join(buffer_path, f'{split}_data_actions.png'))
