import numpy as np
import torch
import gym
import argparse

from tdmpc_pamdp import Trainer
from utils import Episode

# import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"


def run(args):
    """Drive the collect / train / evaluate loop of a ``Trainer``.

    The loop repeatedly resets the trainer, rolls out one episode of at most
    ``args.episode_length`` steps, and after every step:

    * runs ``args.num_updates`` calls to ``trainer.train`` once
      ``total_episodes >= args.seed_steps``;
    * every ``args.eval_freq`` environment steps, plays the current episode
      out until ``terminal``, then evaluates and saves the trainer and ends
      the episode;
    * when ``args.embed`` is set, calls ``trainer.vae_train`` every 10 steps
      once at least 1000 steps have been collected.

    Each finished episode is appended to ``trainer.buffer``.

    Args:
        args: Parsed command-line namespace. Fields read here are
            ``episode_length``, ``embed``, ``pretrain_steps``,
            ``max_timesteps``, ``seed_steps``, ``num_updates`` and
            ``eval_freq``; the whole namespace is also handed to ``Trainer``.
    """
    trainer = Trainer(args)

    total_timesteps = 0
    total_episodes = 0
    max_per_epi_steps = args.episode_length

    if args.embed:
        trainer.pretrain(args.pretrain_steps)

    # Baseline evaluation before any data has been collected.
    trainer.evaluate()

    while total_timesteps < args.max_timesteps:
        state = trainer.reset()
        episode = Episode(trainer.args, state)

        # Accumulated for the episode but not currently consumed anywhere.
        episode_reward = 0.

        for j in range(max_per_epi_steps):
            # Acting needs no gradients; training happens separately below.
            with torch.no_grad():
                act, act_param = trainer.plan(state, step=total_episodes, t0=(j==0), local_step=j)
                action = trainer.pad_action(act, act_param)
                state, reward, terminal = trainer.act(action, j, pre_state=state)

            episode += (state, act, act_param, reward, terminal)

            total_timesteps += 1
            episode_reward += reward

            print(total_episodes, j, total_timesteps)

            train_metrics = {}
            if total_episodes >= args.seed_steps:
                for i in range(args.num_updates):
                    train_log = trainer.train(total_episodes+i)

                    train_metrics.update(train_log)
                    trainer.upload_log(train_log)

            if total_timesteps % args.eval_freq == 0:
                # Play the current episode out before evaluating. This runs
                # under no_grad for the same reason as the main stepping code
                # above (the original left this rollout outside no_grad).
                # NOTE(review): the first catch-up step reuses the index `j`
                # of the step just taken (so t0 is True again when j == 0) —
                # confirm whether `j` should be advanced first.
                # NOTE(review): these transitions are not appended to
                # `episode`, and the loop has no step bound if the
                # environment never reports `terminal` — confirm intended.
                with torch.no_grad():
                    while not terminal:
                        act, act_param = trainer.plan(state, step=total_episodes, t0=(j==0), local_step=j)
                        action = trainer.pad_action(act, act_param)
                        state, reward, terminal = trainer.act(action, j, pre_state=state)
                        total_timesteps += 1
                        episode_reward += reward
                        j += 1

                trainer.evaluate()
                trainer.save_local()
                break

            if args.embed:
                if total_timesteps % 10 == 0 and total_timesteps >= 1000:
                    trainer.vae_train()

            if terminal:
                break

        trainer.buffer += episode
        total_episodes += 1


def build_parser():
    """Build the command-line parser for this training script.

    Integer-valued options use integer literals as defaults: argparse applies
    ``type`` only to string defaults, so a default written as ``1e5`` with
    ``type=int`` would reach the program as a float.

    Returns:
        argparse.ArgumentParser: parser with every option of the script
        registered.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", default='Platform-v0')  # 'Platform-v0', 'Goal-v0', "hard_goal-v0", 'simple_catch-v0', 'simple_move_4_direction_v1-v0'
    parser.add_argument("--seed", default=0, type=int)  # Sets Gym, PyTorch and Numpy seeds

    parser.add_argument("--max_timesteps", default=50_000, type=int)  # Max time steps to run environment for
    parser.add_argument("--eval_freq", default=50, type=int)  # How often (time steps) we evaluate
    # type=float so a value given on the command line is not left as a str.
    parser.add_argument("--expl_noise", default=0.1, type=float)  # Std of Gaussian exploration noise 0.1

    parser.add_argument("--train_every", default=50, type=int)  # Training interval in time steps (not read by run() in this file)
    parser.add_argument("--eval_eposides", default=50, type=int)
    parser.add_argument("--num_updates", default=25, type=int)  # trainer.train calls per environment step

    parser.add_argument("--pretrain_steps", default=10_000, type=int)  # How many steps to pre-train

    parser.add_argument('--action_n_dim', default=4, help='action_n_dim.', type=int)

    parser.add_argument("--dm_hard", default=0, type=int)
    parser.add_argument("--dm_epoch", default=50, type=int)  # how many epochs to train the dynamic model
    parser.add_argument("--dm_lr", default=5e-4, type=float)
    parser.add_argument("--dm_batchsize", default=256, type=int)
    parser.add_argument("--dm_saveflag", default=0, type=int)
    parser.add_argument("--dm_savepath", default='mpc_model/models/dm.pth', type=str)
    parser.add_argument("--dm_loadpath", default='mpc_model/models/dm.pth', type=str)
    parser.add_argument("--dm_valflag", default=0, type=int)
    parser.add_argument("--dm_valfreq", default=500, type=int)  # freq
    parser.add_argument("--dm_valrati", default=0., type=float)
    parser.add_argument("--dm_loadmodel", default=0, type=int)
    parser.add_argument("--dm_layers", default=64, type=int)
    parser.add_argument("--dm_datasetlen", default=100_000, type=int)

    parser.add_argument("--onepa", default=1, type=int)
    parser.add_argument("--use_terminal", default=1, type=int)
    parser.add_argument("--change_r", default=0, type=int)

    parser.add_argument("--r_hard", default=0, type=int)
    parser.add_argument("--r_epoch", default=50, type=int)  # presumably training epochs for the r_* model — confirm
    parser.add_argument("--r_lr", default=5e-4, type=float)
    parser.add_argument("--r_batchsize", default=256, type=int)
    parser.add_argument("--r_saveflag", default=0, type=int)
    parser.add_argument("--r_savepath", default='mpc_model/models/r.pth', type=str)
    parser.add_argument("--r_loadpath", default='mpc_model/models/r.pth', type=str)
    parser.add_argument("--r_valflag", default=0, type=int)
    parser.add_argument("--r_valfreq", default=500, type=int)  # freq
    parser.add_argument("--r_valrati", default=0., type=float)
    parser.add_argument("--r_loadmodel", default=0, type=int)
    parser.add_argument("--r_layers", default=64, type=int)
    parser.add_argument("--r_datasetlen", default=100_000, type=int)

    parser.add_argument("--c_hard", default=0, type=int)
    parser.add_argument("--c_epoch", default=50, type=int)  # presumably training epochs for the c_* model — confirm
    parser.add_argument("--c_lr", default=5e-4, type=float)
    parser.add_argument("--c_batchsize", default=256, type=int)
    parser.add_argument("--c_saveflag", default=0, type=int)
    parser.add_argument("--c_savepath", default='mpc_model/models/c.pth', type=str)
    parser.add_argument("--c_loadpath", default='mpc_model/models/c.pth', type=str)
    parser.add_argument("--c_valflag", default=0, type=int)
    parser.add_argument("--c_valfreq", default=500, type=int)  # freq
    parser.add_argument("--c_valrati", default=0., type=float)
    parser.add_argument("--c_loadmodel", default=0, type=int)
    parser.add_argument("--c_layers", default=64, type=int)
    parser.add_argument("--c_datasetlen", default=100_000, type=int)

    parser.add_argument("--mpc_horizon", default=2, type=int)
    parser.add_argument("--mpc_gamma", default=0.99, type=float)
    parser.add_argument("--mpc_popsize", default=1000, type=int)
    parser.add_argument("--mpc_num_elites", default=100, type=int)
    parser.add_argument("--mpc_patrical", default=1, type=int)
    parser.add_argument("--mpc_init_mean", default=0., type=float)
    parser.add_argument("--mpc_init_var", default=1., type=float)
    parser.add_argument("--mpc_epsilon", default=0.001, type=float)
    parser.add_argument("--mpc_alpha", default=0.1, type=float)
    parser.add_argument("--mpc_max_iters", default=1_000, type=int)
    parser.add_argument("--mpc_type", default="Random", type=str)  # CEM, Random
    # parser.add_argument("--mpc_mode", default="hard", type=str)  # hard, dl

    parser.add_argument("--max_buffer_size", default=1_000_000, type=int)
    parser.add_argument("--episode_length", default=25, type=int)  # max steps per collected episode
    parser.add_argument("--mixture_coef", default=0.05, type=float)

    parser.add_argument("--seed_steps", default=50, type=int)  # compared against the episode count in run() before training starts
    parser.add_argument("--train_single", default=0, type=int)

    parser.add_argument("--algo", default='tdmpc', type=str)
    parser.add_argument("--min_std", default=0.05, type=float)
    parser.add_argument("--cem_iter", default=6, type=int)
    parser.add_argument("--q_dim", default=512, type=int)
    parser.add_argument("--mpc_temperature", default=0.5, type=float)
    parser.add_argument("--td_lr", default=3e-4, type=float)
    parser.add_argument("--rho", default=0.5, type=float)
    parser.add_argument("--grad_clip_norm", default=10, type=int)
    parser.add_argument("--pi_update_freq", default=2, type=int)
    parser.add_argument("--pi_tau", default=0.005, type=float)
    parser.add_argument("--consistency_coef", default=2, type=float)
    parser.add_argument("--reward_coef", default=0.5, type=float)
    parser.add_argument("--contin_coef", default=0.5, type=float)
    parser.add_argument("--value_coef", default=0.1, type=float)
    parser.add_argument("--per_alpha", default=0.6, type=float)
    parser.add_argument("--per_beta", default=0.4, type=float)
    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--save_local_epi", default=1_000, type=int)
    parser.add_argument("--reward_scale", default=1., type=float)

    parser.add_argument("--inverting_gradients", default=0, type=int)
    parser.add_argument("--policy_type", default="hps", type=str)  # hps, patd3

    parser.add_argument("--use_policy", default=0, type=int)
    parser.add_argument("--use_model", default=1, type=int)

    parser.add_argument("--pi_layers", default=64, type=int)

    parser.add_argument('--which_model', default="normal", type=str)  # normal, h
    parser.add_argument('--embed', default=0, type=int)  # non-zero enables pretrain / vae_train in run()
    parser.add_argument("--oup_param", default='Gaussian', type=str)

    parser.add_argument('--save_dir', default="070901", type=str)
    # parser.add_argument('--save-frames', default=1, type=int)
    parser.add_argument('--visualise', default=0, type=int)
    parser.add_argument("--save_points", default=0, type=int)

    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    run(args)
    # Multi-seed sweep, kept for reference:
    # for i in range(0, 3):
    #     args.seed = i
    #     run(args)