from procgen import ProcgenEnv

import os
import yaml
import time
import math
import types
import torch
import argparse
import numpy as np
from tqdm import trange
from datetime import datetime
from collections import deque
import utils.logger as logger

from torch.nn import functional as F
from torch.func import vmap, grad, functional_call

# pytorch distributed training
import torch.multiprocessing as mp

from utils.runners import Runner
from torch.optim import Adam, SGD, RMSprop
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.tensorboard import SummaryWriter

from utils.utils import build_cnn, build_resnet, build_mlp
from utils.utils import ActorCritic, count_vars, safemean, set_seed
from vec_env import ( VecExtractDictObs, VecMonitor, VecNormalize)

def learn(world_size, algo, actor_critic, writer, venv, device,
          total_timesteps, nsteps, algo_config, log_config, log_dir=None):
    """Train ``actor_critic`` with the RAT actor update and an MSE value regression.

    Each outer epoch collects ``nsteps`` transitions per environment through a
    ``Runner``, then runs ``algo_config.pi_epochs`` shuffled minibatch passes of
    actor updates followed by ``algo_config.v_epochs`` passes of value updates.

    Args:
        world_size: Number of training processes (not used inside this function).
        algo: Algorithm name; only ``'rat'`` is supported.
        actor_critic: Model exposing ``pi_net``, ``v_net``, ``forward_pi``,
            ``forward_v``, ``is_discrete``, ``with_popart``, ``last_v_layer``
            and ``obs_rms``.
        writer: TensorBoard writer or ``None`` (not used inside this function).
        venv: Vectorized environment exposing ``num_envs``.
        device: Torch device for tensors created here and by the ``Runner``.
        total_timesteps: Environment-step budget used to derive the epoch count.
        nsteps: Rollout length per environment per epoch.
        algo_config: Namespace of algorithm hyper-parameters. Optional fields:
            ``lr_decay`` (``'cosine'`` enables cosine annealing), ``eta_min``
            (cosine floor, default 0.01) and ``target_kl`` (KL target for the
            adaptive LR rule, default 0.008).
        log_config: Logging options (not used inside this function).
        log_dir: Logging directory (not used inside this function).

    Raises:
        NotImplementedError: Unsupported optimizer, ``norm_obj`` or ``algo``.
        ValueError: If ``pi_minibatches``/``v_minibatches`` exceed the number of
            samples collected per epoch (which would give empty minibatches).
    """
    gamma = .999  # discount factor handed to the Runner
    lam = .95  # lambda handed to the Runner (presumably for GAE-style advantages)

    per_epoch_timesteps = nsteps * venv.num_envs
    # epochs = total_timesteps // (per_epoch_timesteps * world_size) + 1
    epochs = total_timesteps // per_epoch_timesteps + 1
    # The main loop iterates ``epochs + 1`` times; keep that count in one place so
    # the LR schedule below is sized for the number of epochs actually run.
    num_outer_epochs = epochs + 1

    pi_minibatch_size = per_epoch_timesteps // algo_config.pi_minibatches
    v_minibatch_size = per_epoch_timesteps // algo_config.v_minibatches
    # A zero minibatch size would make the ``range(..., step)`` loops fail with an
    # opaque error; reject it up front with an actionable message instead.
    if algo_config.pi_epochs > 0 and pi_minibatch_size < 1:
        raise ValueError(f"pi_minibatches={algo_config.pi_minibatches} exceeds the "
                         f"{per_epoch_timesteps} samples collected per epoch")
    if algo_config.v_epochs > 0 and v_minibatch_size < 1:
        raise ValueError(f"v_minibatches={algo_config.v_minibatches} exceeds the "
                         f"{per_epoch_timesteps} samples collected per epoch")
    # Actor minibatches visited per pass over the data (the last one may be partial).
    n_pi_minibatches = len(range(0, per_epoch_timesteps, pi_minibatch_size)) if pi_minibatch_size > 0 else 0

    # Instantiate the runner object
    runner = Runner(env=venv, model=actor_critic, nsteps=nsteps, gamma=gamma, lam=lam, adv_type=algo_config.adv_type, device=device)
    epinfobuf = deque(maxlen=100)  # episode infos of the most recent 100 episodes

    params_pi = list(actor_critic.pi_net.parameters())
    # Detached aliases of the trainable actor parameters for ``functional_call``; they
    # share storage with the live parameters, so in-place optimizer steps are visible.
    dict_params = {k: v.detach() for k, v in actor_critic.pi_net.named_parameters() if v.requires_grad}
    # NOTE(review): buffers rarely have requires_grad=True, so this dict is usually
    # empty and functional_call falls back to the module's own buffers — confirm intended.
    dict_buffers = {k: v.detach() for k, v in actor_critic.pi_net.named_buffers() if v.requires_grad}

    if algo_config.optimizer == 'sgd':
        # momentum is enabled only so SGD keeps a ``momentum_buffer`` (≈ last gradient)
        # that the Kaczmarz correction below reads; 1e-6 makes its effect on the step negligible
        pi_optimizer = SGD(params_pi, lr=algo_config.lr_pi, momentum=1e-6)
    else:
        raise NotImplementedError

    if getattr(algo_config, 'lr_decay', None) == 'cosine':
        # The scheduler is stepped once per actor minibatch update, so T_max must
        # count every such update over the whole run for the LR to end at eta_min
        # (CosineAnnealingLR rises again once it is stepped past T_max).
        total_pi_updates = num_outer_epochs * algo_config.pi_epochs * n_pi_minibatches
        pi_scheduler = CosineAnnealingLR(pi_optimizer, T_max=max(total_pi_updates, 1),
                                         eta_min=getattr(algo_config, 'eta_min', 0.01))
    else:
        pi_scheduler = None

    v_optimizer = Adam(actor_critic.v_net.parameters(), lr=algo_config.lr_v)

    def RAT_ActorUpdate(_obs, _act, _adv, _outputs_old):
        """Run one RAT actor step on a minibatch; return ``(loss, loss_pi, info)``.

        Advantages are centred, rescaled according to ``algo_config.norm_obj`` and
        preconditioned by solving ``(H H^T diag(ratio) / n + damping * I) x = adv``,
        where row ``i`` of ``H`` is the gradient of ``log pi(a_i | s_i)`` w.r.t. the
        actor parameters. The surrogate ``-ratio * logp * x`` (minus an entropy
        bonus) is then minimised with one SGD step.
        """
        _outputs = actor_critic.forward_pi(_obs)

        if actor_critic.is_discrete:
            _logp_full = F.log_softmax(_outputs, dim=-1)
            _logp_full_old = F.log_softmax(_outputs_old, dim=-1)
            _llr = torch.gather(_logp_full - _logp_full_old, dim=-1, index=_act.unsqueeze(-1)).squeeze(1)
            _ratio = torch.exp(_llr)
            _p_log_p = torch.exp(_logp_full) * _logp_full
            _entropy = - _p_log_p.sum(-1).mean()
            _logp = torch.gather(_logp_full, dim=-1, index=_act.unsqueeze(-1)).squeeze(1)
            _real_kl = (torch.exp(_logp_full_old) * (_logp_full_old - _logp_full)).sum(dim=-1).mean()

            def compute_logp(params, buffers, batch_obs, batch_act):
                # log pi(a|s) for a single (unbatched) sample; vmapped below
                batch_obs, batch_act = batch_obs.unsqueeze(0), batch_act.unsqueeze(0)
                batch_outs = functional_call(actor_critic.pi_net, (params, buffers), (batch_obs,) )
                batch_logp_full = F.log_softmax(batch_outs, dim=-1)
                batch_logp = torch.gather(batch_logp_full, dim=-1, index=batch_act.unsqueeze(-1)).squeeze(1)
                return batch_logp.squeeze(0)

        else:
            # Diagonal Gaussian policy: the network outputs [mu, log_std]
            _mu, _logstd = _outputs.chunk(2, dim=-1)
            _dist = torch.distributions.Normal(_mu, torch.exp(_logstd))
            _logp = _dist.log_prob(_act).sum(dim=-1)

            _mu_old, _logstd_old = _outputs_old.chunk(2, dim=-1)
            _dist_old = torch.distributions.Normal(_mu_old, torch.exp(_logstd_old))
            _logp_old = _dist_old.log_prob(_act).sum(dim=-1)

            _llr = _logp - _logp_old
            _ratio = torch.exp(_llr)
            _entropy = _dist.entropy().sum(dim=-1).mean()
            # closed-form KL(old || new) between diagonal Gaussians
            _real_kl = (_logstd - _logstd_old + 0.5 * ( torch.exp(_logstd_old).pow(2) + (_mu_old - _mu).pow(2) ) / torch.exp(_logstd).pow(2) - 0.5).sum(dim=-1).mean()

            def compute_logp(params, buffers, batch_obs, batch_act):
                # Gaussian log-density for a single (unbatched) sample; vmapped below
                batch_obs, batch_act = batch_obs.unsqueeze(0), batch_act.unsqueeze(0)
                batch_outs = functional_call(actor_critic.pi_net, (params, buffers), (batch_obs,) )
                batch_mu, batch_logstd = batch_outs.chunk(2, dim=-1)

                var = torch.exp(batch_logstd)**2
                batch_logp = (
                    -((batch_act - batch_mu) ** 2) / (2 * var)
                    - batch_logstd
                    - math.log(math.sqrt(2 * math.pi))
                )

                return batch_logp.sum(dim=-1).squeeze(0)

        # zero mean of advantage
        _adv = _adv - _adv.mean()

        # clamp the ratio
        if algo_config.clamp_ratio:
            _ratio = torch.clamp(_ratio, algo_config.min_ratio, algo_config.max_ratio)

        if algo_config.norm_obj == 'adv':
            _rms_sqrt = torch.sqrt( _adv.pow(2).mean() ).detach()
        elif algo_config.norm_obj == 'obj':
            _rms_sqrt = torch.sqrt( (_ratio * _adv).pow(2).mean() ).detach() # might related to variance reduction in importance sampling
        elif algo_config.norm_obj == 'ratio':
            _rms_sqrt = _ratio.mean().detach() * torch.sqrt( _adv.pow(2).mean() ).detach()
        else:
            raise NotImplementedError
        _adv = _adv / (_rms_sqrt + 1e-8)

        # Per-sample gradients of log pi(a|s); torch.func.grad does not touch .grad,
        # so no zero_grad is needed before this.
        ft_compute_sample_grad = vmap(grad(compute_logp), in_dims=(None, None, 0, 0))
        ft_per_sample_grads = ft_compute_sample_grad(dict_params, dict_buffers, _obs, _act) # num_samples x param_shape

        with torch.no_grad():
            num_sa = _obs.shape[0]
            H = torch.cat([v.contiguous().view(num_sa, -1) for v in ft_per_sample_grads.values()], dim=-1)  # num_samples x num_params
            # (H H^T) diag(ratio) / n: scaling column j by ratio[j] is equivalent to
            # right-multiplying by diag(ratio) without building an n x n matrix.
            HHT = (H @ H.t()) * _ratio.unsqueeze(0) / num_sa # num_samples x num_samples

            # Momentum buffers exist only after the first optimizer step.
            gk_list = [ v['momentum_buffer'].contiguous().flatten() for v in pi_optimizer.state.values() if v.get('momentum_buffer') is not None ]
            if algo_config.is_karzmarz and len(gk_list) > 0:
                g_k = torch.cat(gk_list, dim=0)
                _adv = _adv - torch.mv(H, g_k)

            _png_adv = torch.linalg.solve( HHT + algo_config.cg_damping * torch.eye(num_sa, device=device), _adv)

        # update actor
        _loss_pi = (- _ratio.detach() * _logp * _png_adv).mean()
        pi_optimizer.zero_grad()
        _loss = _loss_pi - algo_config.ent_coef * _entropy
        _loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(params_pi, algo_config.max_grad_norm)
        pi_optimizer.step()
        if pi_scheduler is not None:
            pi_scheduler.step()

        # Useful extra info
        with torch.no_grad():
            clipfrac = 0.0
            pi_info = dict(kl=_real_kl.item(), curr_lr=pi_optimizer.param_groups[0]['lr'], ent=_entropy.item(), cf=clipfrac,
                           grad_norm=grad_norm.item(), ratio_max=_ratio.max().item(), ratio_min=_ratio.min().item())

        return _loss, _loss_pi, pi_info

    # choose the policy update rule
    if algo in {'rat'}:
        update_actor = RAT_ActorUpdate
    else:
        raise NotImplementedError

    tepochs = trange(num_outer_epochs, desc='Epoch starts', leave=True)

    # Main loop: collect experience in env and update/log each epoch
    inds = np.arange(per_epoch_timesteps)
    compute_time = []  # wall-clock seconds spent on actor updates, one entry per epoch
    # Bound up-front so the reporting below works even when a config disables the
    # actor or critic passes (pi_epochs == 0 or v_epochs == 0).
    pi_info, mb_loss_pi, mb_loss_v = None, None, None

    for _ in tepochs:
        tepochs.set_description('Stepping environment...')

        actor_critic.eval() # eval mode while collecting rollouts
        obs, ret, act, adv, outputs_old, epinfos = runner.run() #pylint: disable=E0632

        epinfobuf.extend(epinfos)
        tepochs.set_description('Minibatch training...')

        # pop art
        if actor_critic.with_popart:
            actor_critic.last_v_layer.update(ret) # update the mean/var
            ret = actor_critic.last_v_layer.normalize(ret)
            adv = actor_critic.last_v_layer.normalize(adv)

        if actor_critic.obs_rms is not None:
            actor_critic.obs_rms.training = True
            obs = actor_critic.obs_rms(obs) # norm obs for training
            actor_critic.obs_rms.training = False
            # recalculate outputs_old with normalized obs
            with torch.no_grad():
                outputs_old = actor_critic.forward_pi(obs)

        actor_critic.train()  # set to train mode
        actor_tstart = time.perf_counter()
        for _ in range(algo_config.pi_epochs):
            np.random.shuffle(inds)
            for start in range(0, per_epoch_timesteps, pi_minibatch_size):
                mbinds = inds[start:start + pi_minibatch_size]
                mb_obs, mb_act, mb_adv, mb_outputs_old = obs[mbinds], act[mbinds], adv[mbinds], outputs_old[mbinds]
                _, mb_loss_pi, pi_info = update_actor(mb_obs, mb_act, mb_adv, mb_outputs_old)
        compute_time.append(time.perf_counter() - actor_tstart)

        # kl adaptive lr adjustment: shrink the LR when the last minibatch KL is more
        # than twice the target, grow it when it is below half the target
        if algo_config.use_kl_adaptive_lr and pi_info is not None:
            target_kl = getattr(algo_config, 'target_kl', 0.008)
            curr_kl = pi_info['kl']
            if curr_kl > target_kl * 2:
                pi_optimizer.param_groups[0]['lr'] = max(pi_optimizer.param_groups[0]['lr'] / 1.5, 1e-4)
            elif curr_kl < target_kl / 2:
                pi_optimizer.param_groups[0]['lr'] = min(pi_optimizer.param_groups[0]['lr'] * 1.5, 5e-2)

        for _ in range(algo_config.v_epochs):
            np.random.shuffle(inds)
            for start in range(0, per_epoch_timesteps, v_minibatch_size):
                mbinds = inds[start:start + v_minibatch_size]
                _obs, _ret = obs[mbinds], ret[mbinds]
                _vals = actor_critic.forward_v(_obs) # get the value estimate

                # value loss
                mb_loss_v = F.mse_loss(_vals, _ret)

                v_optimizer.zero_grad()
                mb_loss_v.backward()
                torch.nn.utils.clip_grad_norm_(actor_critic.v_net.parameters(), 5.0)
                v_optimizer.step()

        postfix = {}
        if mb_loss_pi is not None:
            postfix['loss_pi'] = mb_loss_pi.item()
        if mb_loss_v is not None:
            postfix['loss_v'] = mb_loss_v.item()
        if pi_info is not None:
            postfix.update(entropy=pi_info['ent'], kl=pi_info['kl'], cf=pi_info['cf'], lr=pi_info['curr_lr'])
        tepochs.set_postfix(**postfix)

        # clean GPU cache
        torch.cuda.empty_cache()

def train_fn(rank, world_size, algo, seed, algo_config, env_config, nets_config, log_config, device=-1):
    """Set up one training process (seed, logging, env, model) and run ``learn``.

    Args:
        rank: Process index. Only rank 0 configures the logger and the
            TensorBoard writer and writes ``config.yaml``.
        world_size: Total number of processes (forwarded to ``learn``).
        algo: Algorithm name (used in the log path and forwarded to ``learn``).
        seed: Random seed, or ``None`` to draw one offset by ``10000 * rank``.
        algo_config: Algorithm hyper-parameter namespace.
        env_config: Environment namespace.
        nets_config: Network namespace.
        log_config: Logging namespace.
        device: ``-1`` picks CUDA when available (else CPU); a non-negative
            integer selects ``cuda:<device>``.

    ``env_config.env_name`` is one of three forms:
        * a short Gym/MuJoCo name (a key of ``gym_env_ids`` below);
        * an Atari spec containing ``'atari'``, of the form ``'<prefix>.<AtariId>'``;
        * a Procgen spec ``'<game>-<easy|hard>-<start_level>-<num_levels>'``.

    Raises:
        ValueError: Unsupported Procgen distribution mode, or a negative
            ``device`` other than ``-1``.
        NotImplementedError: Unknown ``nets_config.type``.
    """
    # Short names of the supported Gym / MuJoCo tasks and their registered ids.
    gym_env_ids = {'cartpole': 'CartPole-v1', 'acrobot': 'Acrobot-v1', 'mountaincar': 'MountainCar-v0',
                   'lunarlander': 'LunarLander-v2', 'carracing': 'CarRacing-v2', 'invertedpendulum': 'InvertedPendulum-v4',
                   'inverteddoublependulum': 'InvertedDoublePendulum-v4',
                   'hopper': 'Hopper-v4', 'halfcheetah': 'HalfCheetah-v4', 'walker2d': 'Walker2d-v4',
                   'humanoid': 'Humanoid-v4', 'humanoidstandup': 'HumanoidStandup-v4', 'reacher': 'Reacher-v4',
                   'swimmer': 'Swimmer-v3', 'ant': 'Ant-v4'}

    # Timestamp embedded in the log directory name.
    time_now = datetime.now().strftime('%Y%m%d-%H%M%S')

    # Random seed
    if seed is None:
        seed = np.random.randint(1e6) + 10000 * rank # different seeds for each process
    set_seed(seed, torch_deterministic=True)

    env_name = env_config.env_name
    num_envs = env_config.num_envs

    if env_name in gym_env_ids:
        timesteps_per_proc = env_config.timesteps_per_proc

    elif 'atari' not in env_name:
        # Procgen spec: '<game>-<distribution_mode>-<start_level>-<num_levels>'
        env_name, distribution_mode, start_level, num_levels = env_name.split('-')
        start_level, num_levels = int(start_level), int(num_levels)

        if distribution_mode == 'easy':
            timesteps_per_proc = env_config.timesteps_per_proc_easy
        elif distribution_mode == 'hard':
            timesteps_per_proc = env_config.timesteps_per_proc_hard
        else:
            # Fail here rather than with an unbound ``timesteps_per_proc`` after
            # the environment and logger have already been created.
            raise ValueError(f"Unsupported Procgen distribution mode {distribution_mode!r}; expected 'easy' or 'hard'")

    if rank == 0:
        if env_name in gym_env_ids:
            log_dir = f"logs/{algo}.karzmarz_{algo_config.is_karzmarz}.{nets_config.type}.a{nets_config.a_hidden_size}x{nets_config.a_num_layers}x{nets_config.a_dropout}e{algo_config.pi_epochs}x{algo_config.pi_minibatches}.c{nets_config.c_hidden_size}x{nets_config.c_num_layers}x{nets_config.c_dropout}e{algo_config.v_epochs}x{algo_config.v_minibatches}.{algo_config.grad}_{algo_config.post_grad}_{algo_config.max_grad_norm}.{algo_config.sigma_type}.damping_{algo_config.cg_damping}.lr_pi_{algo_config.lr_pi}/{env_name}.{time_now}_{seed}"
        else:
            log_dir = f"logs/{algo}.karzmarz_{algo_config.is_karzmarz}.{nets_config.type}{'_bn' if nets_config.with_bn else ''}_{algo_config.pi_epochs}epoch.damping_{algo_config.cg_damping}.lr_pi_{algo_config.lr_pi}/{env_config.env_name}.{time_now}_{seed}"

        format_strs = ['csv', 'stdout']
        logger.configure(dir=log_dir, format_strs=format_strs)
        writer = SummaryWriter(log_dir=log_dir)
        logger.info("creating environment")
    else:
        log_dir = None
        writer = None

    if 'atari' in env_name:
        from stable_baselines3.common.env_util import make_atari_env
        from stable_baselines3.common.vec_env import VecFrameStack
        env_name = env_name.split('.')[1]
        # terminal_on_life_loss would give better value bootstrapping, but it is not
        # enabled because VecMonitor's episodic return/length would then be wrong.
        venv = make_atari_env(env_name, n_envs=num_envs)
        venv = VecFrameStack(venv, n_stack=3) # 3 stacked frames (compatible with Procgen number of channels)
        timesteps_per_proc = env_config.timesteps_per_proc # 10M for atari envs

    elif env_name in gym_env_ids:
        from stable_baselines3.common.env_util import make_vec_env
        env_kwargs = {'continuous': False} if env_name == 'carracing' else {}
        venv = make_vec_env(gym_env_ids[env_name], n_envs=num_envs, env_kwargs=env_kwargs)

    else:
        venv = ProcgenEnv(num_envs=num_envs, env_name=env_name, num_levels=num_levels, start_level=start_level, distribution_mode=distribution_mode, rand_seed=seed)
        venv = VecExtractDictObs(venv, "rgb")
        venv = VecMonitor(venv=venv, filename=log_dir)

    if device == -1:
        # Select best available device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif device >= 0:
        device = f"cuda:{device}"
    else:
        raise ValueError(f"device must be -1 (auto) or a non-negative CUDA index, got {device}")

    obs_space = venv.observation_space

    # Create actor-critic module
    if nets_config.type == 'resnet':
        kwargs = {'with_bn': nets_config.with_bn, 'depths': [8, 16], 'device': device}
        fn_neural_nets, preprocess = build_resnet(obs_space.shape[0], nets_config.hidden_size, **kwargs)
        # now the obs_space becomes channel x height x width
        obs_shape = (obs_space.shape[2], obs_space.shape[0], obs_space.shape[1])

    elif nets_config.type == 'cnn':
        img_size = obs_space.shape[1]
        kwargs = {'with_bn': nets_config.with_bn, 'p_dropout': nets_config.dropout, 'device': device}
        fn_neural_nets, preprocess = build_cnn(img_size, nets_config.hidden_size, **kwargs)
        # now the obs_space becomes channel x height x width
        obs_shape = (obs_space.shape[2], obs_space.shape[0], obs_space.shape[1])

    elif nets_config.type == 'mlp':
        kwargs = {'device': device}
        fn_neural_nets, preprocess = build_mlp(obs_space, **kwargs)
        obs_shape = obs_space.shape

    else:
        raise NotImplementedError

    # Discrete action spaces expose ``n``; continuous (Box) ones expose ``shape``.
    act_num, act_dim = None, None
    try:
        act_num = venv.action_space.n
    except AttributeError:
        act_dim = venv.action_space.shape[0]

    actor_critic = ActorCritic(fn_neural_nets, obs_shape, nets_config=nets_config, n_actions=act_num,
                            dim_actions=act_dim, with_popart=algo_config.with_popart,
                            sigma_type=algo_config.sigma_type, device=device).to(device)

    venv = VecNormalize(venv=venv, norm_ret=env_config.norm_ret, obs_preprocess=preprocess) # img transform and reward normalization

    if rank == 0:
        logger.info(f'Running on device: {device}')
        logger.info("training...")

        # Count variables
        var_counts = count_vars(actor_critic)
        logger.log(f'\nNumber of parameters: {var_counts}\n')

        config = {'algo_config': algo_config.__dict__,
                'env_config': env_config.__dict__,
                'nets_config': nets_config.__dict__,
                'log_config': log_config.__dict__}

        # Close the file deterministically instead of leaking the handle.
        with open(f"{log_dir}/config.yaml", 'w') as fout:
            yaml.dump(config, fout)

    learn(world_size, algo, actor_critic, writer, venv, device,
          total_timesteps=timesteps_per_proc, nsteps=env_config.nsteps,
          algo_config=algo_config, log_config=log_config, log_dir=log_dir)

def main():
    """Parse CLI arguments, load the YAML config, apply overrides and launch training.

    ``--config`` names a file under ``configs/``. Every other hyper-parameter
    option that is given (i.e. not left at ``None``) overrides the matching field
    of the loaded config before training starts. With ``--n_proc > 1`` training is
    launched in that many processes via ``torch.multiprocessing.spawn``.
    """
    parser = argparse.ArgumentParser(description='Process procgen training arguments.')
    parser.add_argument('--config', type=str, default='rat_mlp.yaml')  # file name under configs/
    parser.add_argument('--device', type=int, default=-1)  # -1: use any available device
    parser.add_argument('--env_name', type=str, default=None)  # overrides env_config.env_name
    parser.add_argument('--n_proc', type=int, default=1)  # distributed training: number of processes
    parser.add_argument('--port_num', type=int, default=29500)  # distributed training: MASTER_PORT
    # Actor / critic network overrides (nets_config)
    parser.add_argument('--a_dropout', type=float, default=None)
    parser.add_argument('--a_hidden_size', type=int, default=None)
    parser.add_argument('--a_num_layers', type=int, default=None)
    parser.add_argument('--c_dropout', type=float, default=None)
    parser.add_argument('--c_hidden_size', type=int, default=None)
    parser.add_argument('--c_num_layers', type=int, default=None)
    # Algorithm overrides (algo_config)
    parser.add_argument('--norm_obj', type=str, default=None)
    parser.add_argument('--optimizer', type=str, default=None)
    parser.add_argument('--sigma_type', type=str, default=None, choices=['vector', 'mu_shared', 'separate', 'linear'])
    parser.add_argument('--cg_damping', type=float, default=None)
    parser.add_argument('--pi_epochs', type=int, default=None)
    # Environment override (env_config); only read for Gym/MuJoCo and Atari envs
    parser.add_argument('--timesteps_per_proc', type=int, default=None)
    parser.add_argument('--lr_pi', type=float, default=None)
    parser.add_argument('--grad', type=str, default=None)
    parser.add_argument('--post_grad', type=str, default=None)
    parser.add_argument('--seed', type=int, default=None)

    args = parser.parse_args()

    with open(f'configs/{args.config}') as fin:
        config = yaml.safe_load(fin)

    algo = config['algo']
    algo_config = types.SimpleNamespace(**config['algo_config'])
    env_config = types.SimpleNamespace(**config['env_config'])
    nets_config = types.SimpleNamespace(**config['nets_config'])
    log_config = types.SimpleNamespace(**config['log_config'])

    # Apply command-line overrides; an option left at None keeps the YAML value.
    overrides = (
        (env_config, ('env_name', 'timesteps_per_proc')),
        (nets_config, ('a_dropout', 'a_hidden_size', 'a_num_layers',
                       'c_dropout', 'c_hidden_size', 'c_num_layers')),
        (algo_config, ('norm_obj', 'optimizer', 'sigma_type', 'cg_damping',
                       'pi_epochs', 'lr_pi', 'grad', 'post_grad')),
    )
    for namespace, names in overrides:
        for name in names:
            value = getattr(args, name)
            if value is not None:
                setattr(namespace, name, value)

    if args.n_proc > 1:
        # multiple processes on this node: rendezvous address for torch distributed
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(args.port_num)

        mp.spawn(train_fn, args=(args.n_proc, algo, args.seed, algo_config, env_config, nets_config, log_config, args.device),
                        nprocs=args.n_proc, # INFO: for TPU, either 1 or the maximum number of TPU chips
                        join=True)

    else:
        train_fn(0, args.n_proc, algo, args.seed, algo_config, env_config, nets_config, log_config, args.device)

if __name__ == '__main__':
    # Run training only when this file is executed as a script, not on import.
    # NOTE(review): this guard also matters because main() may use mp.spawn,
    # whose child processes re-import this module.
    main()
