import pickle 
import os
import argparse
import sys
import time
sys.path.append('../pytorch-soft-actor-critic')

from tqdm import tqdm
import torch
import numpy as np
from scipy.stats import dirichlet
import gym

from env import sim_env, create_noise
from replay_buffer import ReplayMemory
from utils import seed_everything
from policy import LinearPolicy, load_policy, TestPolicy, TrainPolicy


def sim_policy(args, env, suffix, memory, 
        noise_params, gym_state=None, policy=None, 
        oracle=False, oracle_policy_path='', 
        env_path=''):
    """Roll out a policy in ``env`` via ``sim_env``, recording into ``memory``.

    The policy is selected by ``args.policy_type``:

    * ``'LinearRand'``  -> a new ``LinearPolicy(state_dim, action_dim)``
    * ``'TestPolicy'``  -> a new ``TestPolicy(action_max, action_dim)``
    * ``'TrainPolicy'`` -> a new ``TrainPolicy(action_max, action_dim)``
    * ``'PureRand'``    -> an empty policy list
    * anything else     -> the ``policy`` argument is used unchanged

    When ``args.oracle`` is set, a second rollout of
    ``50 * args.numb_episodes`` episodes is run into the same buffer and the
    share of first-phase ("expert") entries is written to
    ``<env_path>/oracle_ratio.txt``.

    Args:
        args: Parsed CLI namespace; reads ``policy_type``, ``numb_episodes``
            and ``oracle``.
        env: Gym-style environment exposing ``observation_space`` and
            ``action_space``.
        suffix: Unused here; kept for interface compatibility.
        memory: Replay buffer handed to ``sim_env``; its ``buffer`` attribute
            must support ``len()``.
        noise_params: Forwarded to ``sim_env`` unchanged.
        gym_state: Unused here; kept for interface compatibility.
        policy: Pre-built policy list used when ``args.policy_type`` does not
            name one of the types above. ``None`` means an empty list.
        oracle: Unused here (the flag is read from ``args.oracle``); kept for
            interface compatibility.
        oracle_policy_path: Checkpoint loaded into the linear agent for the
            second rollout when ``args.policy_type == 'LinearRand'``.
        env_path: Directory that receives ``oracle_ratio.txt``.

    Returns:
        The policy list used for the first rollout.
    """
    # A fresh list per call; a literal ``[]`` default would be shared by
    # every invocation of this function.
    if policy is None:
        policy = []

    state_dim = env.observation_space.shape[0]
    policy_type = args.policy_type
    agent = None
    if policy_type == 'LinearRand':
        agent = LinearPolicy(state_dim=state_dim,
                             action_dim=env.action_space.shape[0])
    elif policy_type in ('TestPolicy', 'TrainPolicy'):
        agent_cls = TestPolicy if policy_type == 'TestPolicy' else TrainPolicy
        action_dim = env.action_space.shape[0]
        agent = agent_cls(env.action_space.high, action_dim)

    if agent is not None:
        policy = [agent]
    elif policy_type == 'PureRand':
        policy = []

    deterministic_policy = True
    sim_env(env, args.numb_episodes, noise_params,
            rand_init=True, replay_buffer=memory, policy=policy,
            deterministic_policy=deterministic_policy)

    if not args.oracle:
        print(f'Buffer Size: {len(memory.buffer)}')
        return policy

    expert_data_points = len(memory.buffer)
    print(f'Expert Data: {expert_data_points}')
    # Keyed on ``policy_type`` (not ``args.policy``, which is the
    # Gaussian/Deterministic switch): only a LinearPolicy agent can take the
    # state dict stored in the LinearRand checkpoint, and ``agent`` is only
    # guaranteed to be that type under this condition.
    if policy_type == 'LinearRand':
        agent.load_state_dict(torch.load(oracle_policy_path))
        oracle_policy = [agent]
    else:
        oracle_policy = []
    sim_env(env, int(args.numb_episodes*50), noise_params,
            rand_init=True, replay_buffer=memory, policy=oracle_policy,
            deterministic_policy=deterministic_policy)

    total_data = len(memory.buffer)
    exp_percentage = expert_data_points/total_data*100
    print(f'Noob Data: {total_data-expert_data_points}')
    print(f'Percent of Expert Data: {exp_percentage}')
    data_ratio = {'exp_percentage':exp_percentage, 
        'noob_data': (100-exp_percentage), 'total_data':total_data}
    ratio_path = os.path.join(env_path, 'oracle_ratio.txt')
    with open(ratio_path,'w') as data:
        data.write(str(data_ratio))

    return policy

if __name__ == '__main__':
    # Script entry point: parse the CLI, build the environment and noise
    # model, collect rollouts into a replay buffer via ``sim_policy`` and
    # pickle that buffer (plus the linear policy weights, when applicable).

    def _str2bool(value):
        """Parse a CLI boolean so that ``--flag False`` really means False.

        ``type=bool`` treats every non-empty string as True. Only explicit
        false-like spellings are mapped to False here; anything else stays
        True, exactly as it did before.
        """
        return str(value).strip().lower() not in (
            '', '0', 'false', 'f', 'no', 'n', 'off')

    parser = argparse.ArgumentParser()
    parser.add_argument('--policy', default="Gaussian",
                        help='Policy Type: Gaussian | Deterministic (default: Gaussian)')
    parser.add_argument('--eval', type=_str2bool, default=True,
                        help='Evaluates a policy a policy every 10 episode (default: True)')
    parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                        help='discount factor for reward (default: 0.99)')
    parser.add_argument('--tau', type=float, default=0.005, metavar='G',
                        help='target smoothing coefficient(tau) (default: 0.005)')
    parser.add_argument('--lr', type=float, default=0.0003, metavar='G',
                        help='learning rate (default: 0.0003)')
    parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
                        help='Temperature parameter alpha determines the relative importance of the entropy term against the reward (default: 0.2)')
    parser.add_argument('--automatic_entropy_tuning', type=_str2bool, default=False, metavar='G',
                        help='Automaically adjust alpha (default: False)')
    parser.add_argument('--seed', type=int, default=1456, metavar='N',
                        help='random seed (default: 1456)')
    parser.add_argument('--batch_size', type=int, default=256, metavar='N',
                        help='batch size (default: 256)')
    parser.add_argument('--hidden_size', type=int, default=256, metavar='N',
                        help='hidden size (default: 256)')
    parser.add_argument('--updates_per_step', type=int, default=1, metavar='N',
                        help='model updates per simulator step (default: 1)')
    parser.add_argument('--start_steps', type=int, default=10000, metavar='N',
                        help='Steps sampling random actions (default: 10000)')
    parser.add_argument('--target_update_interval', type=int, default=1, metavar='N',
                        help='Value target update per no. of updates per step (default: 1)')
    parser.add_argument('--replay_size', type=int, default=10000000, metavar='N',
                        help='size of replay buffer (default: 10000000)')
    parser.add_argument('--cuda', action="store_true",
                        help='run on CUDA (default: False)')
    parser.add_argument('--policy_type', default='SAC', type=str,
                        help='pick type of police to run (SAC, LinearRand, PureRand)')
    parser.add_argument("--numb_episodes", default=100, type=int,
                        help='number of episodes to collect data for')
    parser.add_argument('--env', default="Pendulum-v0",
                        help='Mujoco Gym environment (default: Pendulum-v0)')
    parser.add_argument('--modes', default=1, type=int,
            help='number of modes in noise to simulate chaotic dynamics')
    parser.add_argument('--valley_distribution', action= 'store_true',
            help='whether or not to create noise via the valley distribution')
    parser.add_argument('--fat_tail', action= 'store_true',
            help='whether or not to create noise via the fat tail distribution')
    parser.add_argument('--gauss_noise', action= 'store_true',
            help='whether or not to create noise with 0 mean gaussian')
    parser.add_argument('--compute_canada', action= 'store_true',
            help='whether or not on compute canada')
    parser.add_argument('--ensemble_size', default=5, type = int,
            help='ensemble size for pets')
    parser.add_argument('--bootstrap', action= 'store_true',
            help='whether or not to bootstrap the data')
    parser.add_argument('--show', action= 'store_true',
            help='show graphs or save them')
    parser.add_argument('--noise_weight', type=float, default=0.2, 
                        help='how much noise to add in')
    parser.add_argument('--test_data', action= 'store_true',
            help='create test data set')
    parser.add_argument('--noise_seed', type=int, default=14, metavar='N',
                        help='random seed (default: 14)')
    parser.add_argument('--noisy_state', action = 'store_true',
                        help='noise on the state or the action')
    parser.add_argument('--oracle', action = 'store_true',
                        help='make oracle replay buffer for active learning')
    args = parser.parse_args()
    print(args)

    if args.compute_canada:
        store_dir = '/home/nwaftp23/projects/def-dpmeger/nwaftp23/uncertain_nf/mujoco'
        store_dir2 = '/home/nwaftp23/scratch/uncertainty_estimation/mujoco'
    else:
        store_dir = '/home/lucas/dyna_nf/replay_buffers'
        # Only the cluster layout defines a second location. Without this
        # fallback every local run died with a NameError on store_dir2.
        store_dir2 = store_dir
    branch_folder = os.path.join(
        args.env+'_test_aquisition',
        'noiseweight'+str(args.noise_weight)+'_'+'modes'+str(args.modes))
    store_dir = os.path.join(store_dir, branch_folder)
    store_dir2 = os.path.join(store_dir2, branch_folder)
    # exist_ok avoids the check-then-create race of os.path.exists().
    os.makedirs(store_dir, exist_ok=True)
    os.makedirs(store_dir2, exist_ok=True)
    suffix = 'seed'+str(args.noise_seed)

    env = gym.make(args.env)

    policy = []
    if args.policy_type == 'SAC' or args.oracle:
        model_dir ='/home/nwaftp23/pytorch-soft-actor-critic/models'
        state_dim = env.observation_space.shape[0]
        # Taken from the environment like everywhere else in this file,
        # instead of a hard-coded 42.
        # NOTE(review): confirm load_policy expects the env's action size.
        action_dim = env.action_space.shape[0]
        policy = load_policy(args, state_dim, action_dim, env, model_dir)
    noise_params = create_noise(args, store_dir, suffix, args.noise_seed)

    memory = ReplayMemory(args.replay_size, args.batch_size, bootstrap = args.bootstrap,
            ensemble_size = args.ensemble_size, shuffle = False)

    # Each dataset flavour gets its own seed offset and file-name prefix.
    if args.test_data:
        seed_offset, buffer_kind = 42, 'test'
    elif args.oracle:
        seed_offset, buffer_kind = 43, 'oracle'
    else:
        seed_offset, buffer_kind = 0, 'train'
    seed_everything(args.seed+seed_offset)

    policy_path = os.path.join(store_dir, ('LinearRand_'+suffix+'.pt'))
    policy = sim_policy(args, env, suffix, memory, noise_params, policy=policy,
        oracle=args.oracle, oracle_policy_path=policy_path,
        env_path=store_dir2)

    buffer_name = buffer_kind+'_buffer_'+suffix+'.pkl'
    buffer_path = os.path.join(store_dir, buffer_name)
    buffer_path2 = os.path.join(store_dir2, buffer_name)
    # The with-statement closes the file; no explicit close() is needed.
    with open(buffer_path, 'wb') as f:
        pickle.dump(memory.buffer, f)
    print(f'Data Path: {buffer_path2}')
    # Skip the duplicate write when both locations are the same directory.
    if buffer_path2 != buffer_path:
        with open(buffer_path2, 'wb') as f:
            pickle.dump(memory.buffer, f)
    if args.policy_type =='LinearRand':
        # Same file as policy_path above: '<store_dir>/LinearRand_<suffix>.pt'.
        torch.save(policy[0].state_dict(), policy_path)
