import json
import os
import argparse
import sys
import time
from datetime import timedelta
import warnings
warnings.filterwarnings("ignore")

from tqdm import tqdm
import torch
import torch.nn as nn
import numpy as np
import gym
import imageio

from replay_buffer import ReplayMemory, load_mem_uncertain
from plot_functions import plot_input_distribution, plot_rmse_likelihood, plot_state_hm
from env import sim_env
from analyze_fit import calc_rmse
from estimate_uncertainty import find_best_points
from utils import (instantiate_model, normalize, un_normalize,
    gen_folder_uncertain, seed_everything)
from policy import load_policy

def str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is
    ``True`` because every non-empty string is truthy, so the flag could
    never be switched off from the command line.

    Args:
        value: the raw command-line string (an actual ``bool`` is passed
            through unchanged).

    Returns:
        ``True`` for true/t/yes/y/1 and ``False`` for false/f/no/n/0 or the
        empty string (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def build_parser():
    """Build the command-line parser for the acquisition experiment.

    Option names, types and default values are unchanged; only the two
    boolean options (now parsed with ``str2bool``) and a few incorrect help
    strings differ from the previous inline definition.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', default="WetChicken-v0",
                        help='Environment [WetChicken-v0, Pendulum-v0, HalfCheetah-v2, Hopper-v2]')
    parser.add_argument('--seed', type=int, default=1456,
                        help='random seed (default: %(default)s)')
    parser.add_argument('--noise_seed', type=int, default=14,
                        help='random seed (default: %(default)s)')
    parser.add_argument('--num_layers', default=3, help='total number of flows', type=int)
    parser.add_argument('--hids', type=int, default=256, help='hidden units in flows')
    parser.add_argument('--lr', default=5e-4, type=float, help='flows learning rate')
    parser.add_argument('--gamma', default=0.999, type=float, help='schedule for lr step')
    parser.add_argument('--batch_size', default=2056, type=int, help='size of training batch size')
    parser.add_argument('--save_model', action='store_true', help='save model or not')
    parser.add_argument('--bins', type=int, default=10, help='number of bins for spline NSF')
    parser.add_argument('--domain', type=float, default=1.2, help='domain for spline NSF')
    parser.add_argument('--show', action='store_true', help='show graphs')
    parser.add_argument('--epochs', default=100, type=int,
                        help='number of epochs for dyna model')
    parser.add_argument('--model', default="nflows_ensemble", type=str,
                        help='Selects the dynamics model [nflows, gp, mc_drop, nn_ensemble, nflows_ensemble]')
    parser.add_argument('--ensemble_size', default=5, type=int,
                        help='number of components in uncertainty models')
    parser.add_argument('--bootstrap', action='store_true',
                        help='whether or not to bootstrap the data, used for model PE')
    parser.add_argument('--compute_canada', action='store_true',
                        help='whether or not on compute canada')
    parser.add_argument('--noise_weight', type=float, default=0.2,
                        help='how much noise to add in')
    parser.add_argument('--modes', default=0, type=int,
                        help='number of modes in noise to simulate chaotic dynamics')
    parser.add_argument('--valley_distribution', action='store_true',
                        help='whether or not to create noise via the valley distribution')
    parser.add_argument('--fat_tail', action='store_true',
                        help='whether or not to create noise via the fat tail distribution')
    parser.add_argument('--replay_size', type=int, default=1000000,
                        help='size of replay buffer (default: %(default)s)')
    parser.add_argument('--epochs_multiplier', type=int, default=100,
                        help='number of printouts')
    parser.add_argument('--policy_type', default='SAC', type=str,
                        help='pick type of policy to run (SAC, LinearRand, PureRand)')
    parser.add_argument('--policy', default="Gaussian",
                        help='Policy Type: Gaussian | Deterministic (default: Gaussian)')
    parser.add_argument('--eval', type=str2bool, default=True,
                        help='Evaluates a policy every 10 episode (default: True)')
    parser.add_argument('--tau', type=float, default=0.005, metavar='G',
                        help='target smoothing coefficient(tau) (default: 0.005)')
    parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
                        help='Temperature parameter alpha determines the relative importance of the entropy term against the reward (default: 0.2)')
    parser.add_argument('--automatic_entropy_tuning', type=str2bool, default=False, metavar='G',
                        help='Automatically adjust alpha (default: False)')
    parser.add_argument('--updates_per_step', type=int, default=1, metavar='N',
                        help='model updates per simulator step (default: 1)')
    parser.add_argument('--start_steps', type=int, default=10000, metavar='N',
                        help='Steps sampling random actions (default: 10000)')
    parser.add_argument('--target_update_interval', type=int, default=1, metavar='N',
                        help='Value target update per no. of updates per step (default: 1)')
    parser.add_argument('--cuda', action="store_true",
                        help='run on CUDA (default: False)')
    parser.add_argument('--data_size', type=int, default=200,
                        help='controls size of the data (negative number use all data)')
    parser.add_argument('--rqs', action="store_true",
                        help='rational quadratic or cubic spline')
    parser.add_argument('--dropout_masks', action="store_true",
                        help='fixed set of dropout masks')
    parser.add_argument('--multihead', action="store_true",
                        help='multihead ensemble')
    parser.add_argument('--base_distro', action="store_true",
                        help='ensemble in base distro')
    parser.add_argument('--uncertain_nflows', action="store_true",
                        help='uncertainty in nflow layers')
    parser.add_argument('--rc_data', action="store_true",
                        help='rc car data or not')
    parser.add_argument('--index', type=int, default=-50,
                        help='Index for hyperparam list')
    parser.add_argument('--uncertainty_suffix', default="base",
                        help='[base, out, mean, max]')
    parser.add_argument('--test_acquisition', action="store_true",
                        help='test different acquisitions')
    parser.add_argument('--acquisition_type', type=str, default='sample_bald',
                        help='how to acquire new points')
    parser.add_argument('--points_2_add', type=int, default=10,
                        help='how many new points to acquire')
    parser.add_argument('--test_num_samples', action="store_true",
                        help='test different number of samples for MC on Hopper-v2')
    parser.add_argument('--numb_samps', type=int, default=10,
                        help='numb samps for MC test num_samples')
    return parser


def resolve_numb_samps(args):
    """Return the sample count handed to ``find_best_points``, or ``None``.

    The count is keyed on the model type first; ``batchbald`` only supplies
    a count when the model is not one of the three listed types.
    ``--test_num_samples`` overrides everything with ``--numb_samps``.
    ``None`` means no rule applies to this configuration.
    """
    numb_samps = None
    if args.model == 'nn_ensemble':
        numb_samps = 5000
    elif args.model == 'mc_drop':
        numb_samps = 2500
    elif args.model == 'nflows_ensemble':
        numb_samps = 5000
    elif args.acquisition_type == 'batchbald':
        numb_samps = 1000
    if args.test_num_samples:
        numb_samps = args.numb_samps
    return numb_samps


def main():
    """Run the active-learning experiment.

    Each of the ``--epochs_multiplier`` rounds trains a freshly instantiated
    model on the current training buffer, evaluates it on the test buffer,
    moves ``--points_2_add`` transitions from the oracle buffer into the
    training buffer (chosen by ``find_best_points`` or uniformly at random),
    then writes plots, loss/RMSE arrays and the model to disk.
    """
    parser = build_parser()
    args = parser.parse_args()
    print(args)
    numb_samps = resolve_numb_samps(args)
    # Previously an unsupported combination only surfaced as a NameError at
    # the find_best_points call, after old results had been deleted and a
    # full training round had run. Fail before touching anything instead.
    if (numb_samps is None and args.acquisition_type != 'random'
            and args.epochs_multiplier > 0):
        parser.error(
            f'no sample count is defined for --model {args.model} with '
            f'--acquisition_type {args.acquisition_type}; '
            'use --test_num_samples together with --numb_samps')
    step_ahead_max = 30
    seed_everything(args.seed)
    if args.compute_canada:
        store_dir = '/home/nwaftp23/scratch/uncertainty_estimation/mujoco/'
        save_model_dir = '/home/nwaftp23/scratch/uncertain_nf/models'
    else:
        store_dir = './results'
        save_model_dir = './models'
    output_preproc = normalize
    output_postproc = un_normalize
    input_preproc = normalize
    input_postproc = un_normalize
    branch_folder, child_folder = gen_folder_uncertain(args)
    env_dir = os.path.join(store_dir, branch_folder)
    ##
    ##env_dir = env_dir + '_100'
    ##
    store_dir = os.path.join(env_dir, child_folder)
    save_model_dir = os.path.join(save_model_dir, branch_folder, child_folder)
    results_dir = os.path.join(store_dir, 'results/')
    imgs_dir = os.path.join(store_dir, 'epoch_imgs/')
    # exist_ok avoids the check-then-create race when several jobs share
    # the same output tree.
    for directory in (store_dir, save_model_dir, results_dir, imgs_dir):
        os.makedirs(directory, exist_ok=True)
    # The stamp files are opened in append mode, so every entry must end
    # with a newline or successive runs are glued together.
    date_stamp = f'Date: \n{time.strftime("%Y-%m-%d_%H_%M_%S")}\n'
    for stamp_dir in (store_dir, save_model_dir):
        with open(os.path.join(stamp_dir, 'date_ran.txt'), mode='a') as stamp_file:
            stamp_file.write(date_stamp)
    # Start from a clean slate: drop images and result files of earlier runs.
    for stale_dir in (imgs_dir, results_dir):
        for stale_name in os.listdir(stale_dir):
            os.remove(os.path.join(stale_dir, stale_name))
    # NOTE(review): --cuda is not consulted here; the GPU is used whenever
    # one is available. Confirm this is intended.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.env != 'SARCOS':
        env = gym.make(args.env)
    memory = ReplayMemory(args.replay_size, args.batch_size, bootstrap=args.bootstrap,
            ensemble_size=args.ensemble_size, shuffle=True)
    load_mem_uncertain(args, memory, env_dir)
    test_memory = ReplayMemory(args.replay_size, 1028, bootstrap=args.bootstrap,
            ensemble_size=args.ensemble_size, shuffle=False)
    load_mem_uncertain(args, test_memory, env_dir, test=True)
    if args.data_size > 0:
        memory.reduce_buffer(args.data_size)
    #test_memory.reduce_buffer(args.data_size*10)
    test_memory.reduce_buffer(2000)
    oracle_memory = ReplayMemory(args.replay_size, args.batch_size,
        bootstrap=args.bootstrap, ensemble_size=args.ensemble_size,
        shuffle=False)
    load_mem_uncertain(args, oracle_memory, env_dir, oracle=True)
    # The oracle pool must not contain points that are already in the
    # training buffer.
    oracle_memory.remove_portion(memory.buffer)
    if args.env in ['Pendulum-v0', 'WetChicken-v0']:
        plot_state_hm(memory, test_memory, imgs_dir, env=args.env, show=False)
    if args.env != 'SARCOS':
        state_dim = env.observation_space.shape[0]
    else:
        state_dim = 14
    # Hard-coded overrides of the observation size reported by gym.
    if args.env == 'Ant-v2':
        state_dim = 27
    elif args.env == 'Hopper-v2':
        state_dim = 11
    elif args.env == 'Humanoid-v2':
        state_dim = 257
    if args.env != 'SARCOS':
        action_dim = env.action_space.shape[0]
    else:
        action_dim = 7
    context_dims = action_dim + state_dim
    # NOTE(review): for SARCOS args.output_dim (state_dim == 14) differs from
    # output_dim (7) passed to instantiate_model -- confirm which one the
    # model code reads.
    args.output_dim = state_dim
    if args.env != 'SARCOS':
        output_dim = state_dim
    else:
        output_dim = 7
    args.context_dim = context_dims
    for args_dir in (store_dir, save_model_dir):
        with open(os.path.join(args_dir, 'commandline_args.txt'), 'w') as args_file:
            json.dump(args.__dict__, args_file, indent=2)
    model = instantiate_model(args, output_dim, context_dims, device, input_preproc,
        output_preproc, step_ahead_max)
    test_losses = []
    rmses = []
    numb_points_2_add = args.points_2_add
    nflows = args.model == 'nflows_ensemble'
    train_set_size = [len(memory.buffer)]
    time_estimates = []
    for i in range(args.epochs_multiplier):
        start_time = time.time()
        train_loss = model.train(args.epochs, memory)
        model.detach_model()
        test_loss = model.loss(test_memory)
        if args.env == 'Humanoid-v2':
            samp_oracle = oracle_memory.sample(125)
        else:
            samp_oracle = oracle_memory.sample(1000)
        # Acquisitions other than the two BALD variants get a larger
        # candidate pool (this re-samples, replacing the pool drawn above).
        if args.acquisition_type not in ['sample_bald', 'batchbald']:
            samp_oracle = oracle_memory.sample(10000)
        if args.acquisition_type != 'random':
            states, actions, _, next_states, _, _, _ = map(np.stack, zip(*memory.buffer))
            states = torch.tensor(states, dtype=torch.float32).to(model.device)
            actions = torch.tensor(actions, dtype=torch.float32).to(model.device)
            next_states = torch.tensor(next_states, dtype=torch.float32).to(model.device)
            inps = torch.hstack([states, actions])
            outs = next_states
            inps = model.input_preproc(inps, model.stats_inputs)
            outs = model.output_preproc(outs, model.stats_outputs)
            points_2_add, time_taken = find_best_points(samp_oracle, numb_samps,
                model, input_preproc, args.ensemble_size,
                device, acquisition_criteria=args.acquisition_type,
                nflows=nflows, numb_points_2_add=numb_points_2_add,
                x_train=inps, y_train=outs)
        else:
            rand_samp = oracle_memory.sample(numb_points_2_add)
            # Re-pack the seven per-field arrays into one tuple per transition.
            points_2_add = [tuple(rand_samp[k][j] for k in range(7))
                            for j in range(numb_points_2_add)]
            time_taken = 0
            torch.cuda.empty_cache()
        time_estimates.append(time_taken)
        np.save(os.path.join(env_dir, 'time_estimates_'+args.acquisition_type+'_'+str(args.ensemble_size)),np.array(time_estimates))
        memory.add_to_buffer(points_2_add)
        oracle_memory.remove_portion(points_2_add)
        rmse = calc_rmse(test_memory, input_preproc, output_postproc,
            model, ensemble_size=args.ensemble_size, device=device)
        # A random 50-point subset of the test set used to be materialised
        # here but was never read. Only the draw is kept so that the global
        # NumPy RNG stream (and thus seeded runs) stays unchanged.
        np.random.choice(len(test_memory.buffer), size=50, replace=False)
        test_losses.append(test_loss)
        rmses.append(rmse)
        mean_dyna_loss = torch.tensor(train_loss).mean()
        test_likelihood = np.exp(-np.array(test_losses))
        plot_rmse_likelihood(train_loss, np.arange(len(train_loss)),
            'train_loss', store_dir=results_dir)
        plot_rmse_likelihood(test_losses, train_set_size,
            'test_loss', store_dir=results_dir)
        plot_rmse_likelihood(rmses, train_set_size, 'rmse', store_dir=results_dir)
        plot_rmse_likelihood(test_likelihood, train_set_size,
            'likelihood', store_dir=results_dir)
        # Appended after plotting so the x-axis has one entry per metric;
        # this value is the training-set size of the next round.
        train_set_size.append(len(memory.buffer))
        end_time = time.time()
        train_time = str(timedelta(seconds=(end_time-start_time)))
        performance_string = f'Total Epochs: {(i+1)}, '\
                             f'Train Loss: {mean_dyna_loss:.2f}, '\
                             f'test Loss: {test_loss:.2f}, '\
                             f'Train Time: {train_time}'
        print(performance_string)
        print(f'Last Train Loss: {train_loss[-1]:.2f}, RMSE Test: {rmse}, '\
            f' Train Set Size: {len(memory.buffer)-numb_points_2_add}')
        np.save(os.path.join(results_dir, ('train_loss_array')), np.array(train_loss))
        np.save(os.path.join(results_dir, ('test_loss_array')), np.array(test_losses))
        np.save(os.path.join(results_dir, ('rmse_array')), np.array(rmses))
        print('Saving Model')
        model_path = os.path.join(save_model_dir, ('model.pt'))
        model.save_model(model_path)
        # Every round retrains from scratch on the enlarged training buffer.
        model = instantiate_model(args, output_dim, context_dims, device, input_preproc,
            output_preproc, step_ahead_max)
        print("-----------------------------------------------")


if __name__ == '__main__':
    main()
