"""
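Compare Proximal Policy Optimization (PPO) or vanilla policy gradient (VPG)
with variants that train on only a subset of each iteration's sampled
episodes (selected with ``gp_sparsify.sparsify``; the "KQ" methods),
optionally using a Gaussian-process model as a control variate, and plot the
learning curves averaged over several seeds.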
"""

import argparse
import json
import os
from pprint import pprint

import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gymnasium as gym

import machina as mc
from machina.pols import GaussianPol, CategoricalPol, MultiCategoricalPol
from machina.algos import ppo_clip, vpg
from machina.vfuncs import DeterministicSVfunc
from machina.envs import GymEnv, C2DEnv
from machina.traj import Traj
from machina.traj import epi_functional as ef
from machina.samplers import EpiSampler
from machina import logger
from machina.utils import measure, set_device

from simple_net import PolNet, VNet, PolNetLSTM, VNetLSTM, GPNet
from gp_sparsify import sparsify, gp_train, torch_episode, weight_convert
import copy

import matplotlib.pyplot as plt
from scipy import stats

from tqdm import tqdm

# Pool of per-run seeds drawn once from a fixed master seed; run_algo uses
# seeds[pt_seed] for each run and then advances pt_seed.
random.seed(4242)
seeds = [random.randrange(9999) for _ in range(1000)]
pt_seed = 0  # index into `seeds` of the seed for the next run

def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``'False'``) into True, so value-taking boolean options need this parser.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'Boolean value expected, got {!r}.'.format(value))


def args_init():
    """Parse the command-line options of this experiment script.

    Boolean options that take a value (``--ctrl_var``, ``--rew_gp``,
    ``--discount_grad``, ``--cv_only``, ``--use_thin``, ``--use_herd``) accept
    e.g. ``True``/``False``/``1``/``0``.

    Returns:
        argparse.Namespace with ``batch_size`` forced equal to
        ``max_epis_per_iter``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--log', type=str, default='garbage',
                        help='Directory name of log.')
    parser.add_argument('--env_name', type=str,
                        default='Hopper-v4', help='Name of environment.')
    parser.add_argument('--c2d', action='store_true',
                        default=False, help='If True, action is discretized.')
    parser.add_argument('--record', action='store_true',
                        default=False, help='If True, movie is saved.')
    parser.add_argument('--seed', type=int, default=256)
    parser.add_argument('--max_epis', type=int,
                        default=1000000, help='Number of episodes to run.')
    parser.add_argument('--num_parallel', type=int, default=16,
                        help='Number of processes to sample.')
    parser.add_argument('--cuda', type=int, default=-1, help='cuda device number.')
    parser.add_argument('--max_epis_per_iter', type=int, default=64,
                        help='Number of episodes to use in an iteration.')
    parser.add_argument('--epoch_per_iter', type=int, default=10,
                        help='Number of epoch in an iteration')
    parser.add_argument('--pol_lr', type=float, default=3e-4,
                        help='Policy learning rate')
    parser.add_argument('--vf_lr', type=float, default=3e-4,
                        help='Value function learning rate')

    parser.add_argument('--n_kq', type=int, default=0,
                        help='Number of episodes after sparsification')
    parser.add_argument('--gp_lr', type=float, default=3e-4,
                        help='Gaussian process learning rate')
    # type=bool would make "--ctrl_var False" evaluate to True; parse explicitly.
    parser.add_argument('--ctrl_var', type=_str2bool, default=True,
                        help='If using a control variate')
    parser.add_argument('--rew_gp', type=_str2bool, default=True,
                        help='If modeling the reward by GP')

    parser.add_argument('--rnn', action='store_true',
                        default=False, help='If True, network is reccurent.')
    parser.add_argument('--rnn_batch_size', type=int, default=8,
                        help='Number of sequences included in batch of rnn.')
    parser.add_argument('--max_grad_norm', type=float, default=10,
                        help='Value of maximum gradient norm.')

    parser.add_argument('--algo_type', type=str,
                        choices=['ppo', 'vpg'], default='ppo', help='Type of Proximal Policy Optimization.')

    parser.add_argument('--clip_param', type=float, default=0.2,
                        help='Value of clipping liklihood ratio.')

    parser.add_argument('--kl_targ', type=float, default=0.01,
                        help='Target value of kl divergence.')
    parser.add_argument('--init_kl_beta', type=float,
                        default=1, help='Initial kl coefficient.')

    parser.add_argument('--discount_grad', type=_str2bool, default=True,
                        help='Whether or not dicounting the gradient (not the reward)')
    parser.add_argument('--gamma', type=float, default=0.995,
                        help='Discount factor.')
    parser.add_argument('--lam', type=float, default=1,
                        help='Tradeoff value of bias variance.')
    parser.add_argument('--num_iter', type=int, default=300,
                        help='Number of minibatch policy gradient iterations.')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='Equal to max_epis_per_iter in this code.')
    parser.add_argument('--cv_only', type=_str2bool, default=False)
    parser.add_argument('--use_thin', type=_str2bool, default=False)
    parser.add_argument('--use_herd', type=_str2bool, default=False)
    args = parser.parse_args()
    args.batch_size = args.max_epis_per_iter  # special requirement in this code

    return args

def run_algo(args):
    """Train one policy with the configuration in ``args`` and record its learning curve.

    Each call takes the next seed from the module-level ``seeds`` pool (cycling
    through it), seeds numpy/torch with ``fix_seed`` and overwrites
    ``args.seed`` with it.

    When ``args.n_kq > 0``, the (up to) ``max_epis_per_iter`` episodes sampled
    in each iteration are reduced to ``n_kq`` episodes with ``sparsify`` (or a
    random subset when ``args.cv_only``). If ``args.ctrl_var`` is also set,
    advantages are ``rets`` minus returns derived from the GP mean
    (``gp_net.mean_func``) instead of coming from the value function.

    NOTE(review): this function reads the module-level ``device``, which is
    only defined in the ``__main__`` section of this script.

    Args:
        args: parsed options (see ``args_init`` / ``arg_method``);
            ``args.seed`` is mutated.

    Returns:
        dict of numpy arrays, one entry per iteration:
        ``'rews'``: mean of ``sum(epi['rews'])`` over the episodes sampled
        (before sparsification);
        ``'steps'``: cumulative ``traj.num_step`` of the trajectories trained
        on before this iteration;
        ``'iters'``: 0-based iteration index.
    """
    global pt_seed
    fix_seed(seeds[pt_seed])
    args.seed = seeds[pt_seed]
    # Advance and wrap so that seeds[pt_seed] is always a valid index.
    pt_seed = (pt_seed + 1) % len(seeds)

    list_rews = []
    list_steps = []
    list_iters = []

    # NOTE(review): args.log / args.record are not used (movie recording is disabled).
    env = GymEnv(args.env_name)
    env.action_space.seed(args.seed)
    if args.c2d:
        env = C2DEnv(env)

    observation_space = env.observation_space
    action_space = env.action_space

    if args.rnn:
        pol_net = PolNetLSTM(observation_space, action_space,
                             h_size=256, cell_size=256)
    else:
        pol_net = PolNet(observation_space, action_space)
    if isinstance(action_space, gym.spaces.Box):
        pol = GaussianPol(observation_space, action_space, pol_net, args.rnn)
    elif isinstance(action_space, gym.spaces.Discrete):
        pol = CategoricalPol(observation_space, action_space, pol_net, args.rnn)
    elif isinstance(action_space, gym.spaces.MultiDiscrete):
        pol = MultiCategoricalPol(
            observation_space, action_space, pol_net, args.rnn)
    else:
        raise ValueError('Only Box, Discrete, and MultiDiscrete are supported')

    if args.rnn:
        vf_net = VNetLSTM(observation_space, h_size=256, cell_size=256)
    else:
        vf_net = VNet(observation_space)
    vf = DeterministicSVfunc(observation_space, vf_net, args.rnn)

    sampler = EpiSampler(env, pol, num_parallel=args.num_parallel, seed=args.seed)

    optim_pol = torch.optim.Adam(pol_net.parameters(), args.pol_lr)
    optim_vf = torch.optim.Adam(vf_net.parameters(), args.vf_lr)

    if args.n_kq > 0:
        if args.n_kq >= args.max_epis_per_iter:
            raise ValueError('n_kq needs to be smaller than max_epis_per_iter')
        # GP model used by sparsify/gp_train and, with ctrl_var, as control variate.
        gp_net = GPNet(observation_space, action_space,
                       args.ctrl_var, args.rew_gp, args.gamma, args.discount_grad).to(device)
        optim_gp = torch.optim.Adam(gp_net.parameters(), args.gp_lr)

    total_step = 0

    for i_iter in tqdm(range(1, args.num_iter + 1)):
        wei, wei_all = None, None
        traj_cv, st_cv = None, None
        list_steps.append(total_step)
        list_iters.append(i_iter - 1)
        with measure('sample', log_enable=False):
            epis = sampler.sample(pol, max_epis=args.max_epis_per_iter)
            epis = epis[:args.max_epis_per_iter]
            rewards = [sum(epi['rews']) for epi in epis]
            list_rews.append(np.mean(rewards))
            if args.n_kq > 0:
                if args.ctrl_var:
                    # Copy of every sampled episode whose rewards (or returns)
                    # are replaced by the GP mean prediction.
                    traj_cv = Traj()
                    traj_cv.add_epis(copy.deepcopy(epis))
                    # black-out rewards before computing the GP input features
                    for epi in traj_cv.current_epis:
                        epi['rews'] *= 0.
                    emb_all = [gp_net.input_feature(torch_episode(epi)) for epi in traj_cv.current_epis]
                    traj_cv = ef.compute_vs(traj_cv, vf)

                    for epi, emb in zip(traj_cv.current_epis, emb_all):
                        cv = gp_net.mean_func(emb.to(device)).detach().cpu().numpy().reshape(-1,)
                        if gp_net.rew_gp:
                            epi['rews'] = cv
                        else:
                            epi['rets'] = cv
                    if gp_net.rew_gp:
                        traj_cv = ef.compute_rets(traj_cv, args.gamma)
                        traj_cv = ef.compute_advs(traj_cv, args.gamma, args.lam)
                    else:
                        for epi in traj_cv.current_epis:
                            epi['advs'] = epi['rets'] - epi['vs']
                if args.cv_only:
                    # Random subset instead of sparsification; no weights.
                    idx = np.random.permutation(len(epis))[:args.n_kq]
                    wei = None
                else:
                    idx, wei = sparsify(gp_net, epis, args.n_kq, args.use_thin, args.use_herd)
                epis = [epis[i] for i in idx]
                if args.ctrl_var:
                    cv_rets = [traj_cv.current_epis[i]['rets'] for i in idx]
            elif args.discount_grad:
                idx_all = np.arange(len(epis))
                wei_all = np.ones(len(epis))
                wei_all = weight_convert(epis, idx_all, wei_all,
                                         discount_grad=args.discount_grad, gamma=args.gamma)
                if wei is None:
                    wei = wei_all

        with measure('train', log_enable=False):
            traj = Traj()
            traj.add_epis(epis)
            traj = ef.compute_rets(traj, args.gamma)
            if args.n_kq > 0 and args.ctrl_var:
                # Advantage = observed return minus the GP-based return of the
                # same (selected) episode.
                for epi, cv_ret in zip(traj.current_epis, cv_rets):
                    epi['advs'] = epi['rets'] - cv_ret
            else:
                traj = ef.compute_vs(traj, vf)
                traj = ef.compute_advs(traj, args.gamma, args.lam)
            if traj_cv is not None:
                traj_cv, _, st_cv = ef.centerize_advs(traj_cv)
                traj_cv.register_epis()
            # st_cv (from centering traj_cv, or None) is forwarded here.
            traj, _, _ = ef.centerize_advs(traj, 0., st_cv)
            traj = ef.compute_h_masks(traj)
            if args.n_kq > 0:
                gp_train(gp_net, optim_gp, traj, wei, cv_only=args.cv_only)

            traj.register_epis()

            if args.algo_type == 'ppo':
                ppo_clip.w_train(traj=traj, pol=pol, vf=vf, clip_param=args.clip_param,
                                 optim_pol=optim_pol, optim_vf=optim_vf, epoch=args.epoch_per_iter,
                                 max_grad_norm=args.max_grad_norm,
                                 wei=wei, traj_cv=traj_cv, wei_cv=wei_all, log_enable=False)
            else:
                vpg.w_train(traj=traj, pol=pol, vf=vf,
                            optim_pol=optim_pol, optim_vf=optim_vf, wei=wei,
                            traj_cv=traj_cv, wei_cv=wei_all, log_enable=False)

        total_step += traj.num_step
        del traj
        del traj_cv
    del sampler
    data = {'rews': np.array(list_rews),
            'steps': np.array(list_steps),
            'iters': np.array(list_iters)}
    return data

def fix_seed(seed=42, strong_fix=False):
    """Seed the numpy and torch (CPU and CUDA) random number generators.

    Args:
        seed: value passed to ``np.random.seed``, ``torch.manual_seed`` and
            ``torch.cuda.manual_seed``.
        strong_fix: if True, also request deterministic cuDNN kernels and
            deterministic torch algorithms. This can be slower, and torch
            raises on operations that have no deterministic implementation.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    if strong_fix:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        # This must be a call: assigning to the attribute would replace the
        # torch function with a bool and never enable determinism.
        torch.use_deterministic_algorithms(True)

def arg_method(method_name, args, n_kq):
    """Return a copy of ``args`` configured for the experiment ``method_name``.

    The name is matched by substring:
      * ``'vpg'`` anywhere -> ``algo_type='vpg'``, otherwise ``'ppo'``.
      * exactly ``'vpg'`` or ``'ppo'`` -> baseline that samples only ``n_kq``
        episodes per iteration (``max_epis_per_iter = batch_size = n_kq``).
      * ``'KQ'`` -> ``n_kq`` set, ``ctrl_var = rew_gp = True``.
      * ``'no-mean'`` -> ``n_kq`` set, ``ctrl_var = rew_gp = False``
        (overrides ``'KQ'``).
      * ``'thin'`` / ``'herd'`` -> ``use_thin`` / ``use_herd`` set
        (``'thin'`` takes precedence).
    Any other name (e.g. ``'ppo-large'``) keeps the episode settings of ``args``.

    ``args`` itself is not modified.
    """
    ret = copy.deepcopy(args)
    ret.algo_type = 'vpg' if 'vpg' in method_name else 'ppo'
    if method_name in ('vpg', 'ppo'):
        ret.max_epis_per_iter = n_kq
        # args_init requires batch_size == max_epis_per_iter; keep them in sync.
        ret.batch_size = ret.max_epis_per_iter
    if 'KQ' in method_name:
        ret.n_kq = n_kq
        ret.ctrl_var = True
        ret.rew_gp = True
    if 'no-mean' in method_name:
        # Checked after 'KQ' so that it overrides the control-variate settings.
        ret.n_kq = n_kq
        ret.ctrl_var = False
        ret.rew_gp = False
    if 'thin' in method_name:
        ret.use_thin = True
    elif 'herd' in method_name:
        ret.use_herd = True
    return ret

def data_array(rews, xs, x_size):
    """Resample several learning curves onto a shared integer grid and summarize them.

    Args:
        rews: sequence of 1-D reward arrays, one per run.
        xs: matching x-coordinates (iterations or steps) for each run;
            ``np.interp`` expects them to be increasing.
        x_size: upper bound on the number of grid points.

    Returns:
        ``(grid, mean, sem)``. ``grid`` is ``np.arange(L)``, where ``L`` is the
        smallest of ``x_size`` and every ``xs[k][-1] + 1``. ``mean`` is the
        across-run mean of the linearly interpolated rewards, and ``sem`` is
        their standard error (``scipy.stats.sem``).
    """
    grid_len = min([x_size] + [x_coords[-1] + 1 for x_coords in xs])
    grid = np.arange(grid_len)
    curves = np.stack([np.interp(grid, xs[k], rew) for k, rew in enumerate(rews)])
    return grid, np.mean(curves, axis=0), stats.sem(curves, axis=0)


def stat_method(m_name, args, n_kq, stat_size, mins):
    """Run method ``m_name`` ``stat_size`` times and aggregate its learning curves.

    The arguments for the method are built once with ``arg_method``. Each run
    is a call to ``run_algo``.

    Returns:
        dict mapping ``'steps'`` and ``'iters'`` to the ``(x, mean, sem)``
        tuple returned by ``data_array`` for reward against that x-axis, with
        ``mins[axis]`` as the grid-size bound.
    """
    method_args = arg_method(m_name, args, n_kq)
    collected = {key: [] for key in ('rews', 'steps', 'iters')}
    for run in range(1, stat_size + 1):
        result = run_algo(method_args)
        print('Running ' + m_name + ' ({}/{}) has finished!'.format(run, stat_size))
        for key, series in collected.items():
            series.append(result[key])

    return {
        axis: data_array(collected['rews'], collected[axis], mins[axis])
        for axis in ('steps', 'iters')
    }


if __name__ == '__main__':

    args = args_init()
    fix_seed()

    # `device` must stay a module-level global: run_algo reads it directly.
    device_name = "cuda:{}".format(args.cuda) if (args.cuda >= 0 and torch.cuda.is_available()) else 'cpu'
    device = torch.device(device_name)
    set_device(device)

    # NOTE: the assignments below override the corresponding command-line flags.
    args.env_name = 'Hopper-v4'  # alternatives: 'HalfCheetah-v4', 'InvertedDoublePendulum-v4', 'Walker2d-v4'
    args.num_iter = 500
    stat_size = 5  # number of independent runs (seeds) per method
    lr = 3e-4
    args.pol_lr = args.vf_lr = args.gp_lr = lr

    # Episodes per iteration kept by the KQ methods and sampled by the plain
    # 'ppo'/'vpg' baselines (see arg_method).
    n_kq = 8
    method_names = ['ppo', 'ppo-KQ', 'ppo-KQ-no-mean', 'ppo-large']
    # method_names = ['vpg', 'vpg-KQ', 'vpg-KQ-no-mean', 'vpg-large']
    # method_names = ['ppo-KQ-thin', 'ppo-KQ-herd', 'ppo-KQ-thin-no-mean', 'ppo-KQ-herd-no-mean']
    # method_names = ['vpg-KQ-thin', 'vpg-KQ-herd', 'vpg-KQ-thin-no-mean', 'vpg-KQ-herd-no-mean']
    # Upper bound on the number of grid points per x-axis when averaging runs.
    mins = {'iters': args.num_iter, 'steps': 2000000}

    # Create the output directory before training, so that a missing
    # directory does not make savefig fail after all runs have finished.
    os.makedirs('results', exist_ok=True)

    data = dict()
    for m_name in method_names:
        data[m_name] = stat_method(m_name, args, n_kq, stat_size, mins)

    x_names = ['iters', 'steps']
    for x_name in x_names:
        fig = plt.figure()
        # Truncate the step axis to the grid length of the KQ method, if present.
        m_steps = mins['steps']
        if 'ppo-KQ' in method_names:
            x, _, _ = data['ppo-KQ']['steps']
            m_steps = min(m_steps, len(x))
        elif 'vpg-KQ' in method_names:
            x, _, _ = data['vpg-KQ']['steps']
            m_steps = min(m_steps, len(x))

        for m_name in method_names:
            x, y, y_std = data[m_name][x_name]
            if x_name == 'steps':
                # Subsample the per-step grid to keep the plot light.
                itvl = 1 if m_steps < 10000 else m_steps // 10000
                x = x[0:m_steps:itvl]
                y = y[0:m_steps:itvl]
                y_std = y_std[0:m_steps:itvl]
            plt.plot(x, y, label=m_name)
            plt.fill_between(x, y - y_std, y + y_std, alpha=0.3)
        plt.legend(loc='upper left')
        fig.savefig("results/{}_{}_{}_{}_{}_{}_{}.pdf".format(args.env_name, x_name, args.num_iter, args.max_epis_per_iter, n_kq, stat_size, lr))
        plt.close(fig)

            
