#from quadrabot_gym.scripts.play import load_env


import torch
from copy import deepcopy
import numpy as np
import os
import sys
import shutil
import json
import wandb
# Append the parent of this script's directory to sys.path, presumably so the
# `config` and `diffuser` packages imported below resolve when this file is
# run directly as a script.
current_dir = os.path.dirname(os.path.abspath(__file__))

parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)

import gym
import diffuser.utils as utils
import torch.nn.functional as F
import argparse
from config.locomotion_config import Config
from diffuser.utils.arrays import to_torch, to_np, to_device
from diffuser.datasets.d4rl import suppress_output
from diffuser.environments.hopper import HopperFullObsEnv
from diffuser.environments.half_cheetah import HalfCheetahFullObsEnv
from diffuser.environments.walker2d import Walker2dFullObsEnv
from diffuser.datasets.d4rl import load_environment, sequence_dataset
from diffuser.datasets.preprocessing import get_preprocess_fn

# Default target return (used as Config.test_ret) for each dataset name.
# The __main__ block falls back to this table when --test_ret is left at 0.0.
eval_testrt = {
    # hopper
    'hopper-medium-v2': 0.85,
    'hopper-medium-replay-v2': 0.85,
    'hopper-medium-expert-v2': 0.85,
    # walker2d
    'walker2d-medium-v2': 0.75,
    'walker2d-medium-replay-v2': 0.65,
    'walker2d-medium-expert-v2': 0.9,
    # halfcheetah
    'halfcheetah-medium-v2': 0.5,
    'halfcheetah-medium-replay-v2': 0.4,
    'halfcheetah-medium-expert-v2': 0.8,
}


def cycle(dl):
    """Yield items from ``dl`` forever, starting a new pass after each one.

    ``dl`` is iterated afresh on every pass; items are not cached. A
    re-iterable source such as a shuffling DataLoader is therefore walked
    again (and, e.g., reshuffled) each time.

    Raises:
        ValueError: if a complete pass over ``dl`` yields nothing (an empty
            or already-exhausted iterable). Without this check the generator
            would spin forever without ever yielding.
    """
    while True:
        produced_any = False
        for data in dl:
            produced_any = True
            yield data
        if not produced_any:
            raise ValueError("cycle() got an empty (or exhausted) iterable")

def _worker_init_fn(worker_id, base_seed=7):
    """Seed NumPy and ``random`` inside a DataLoader worker.

    DataLoader workers may start from a copy of the parent's NumPy RNG state,
    for example under fork. Each worker therefore offsets the inherited state
    by its ``worker_id``, which gives a distinct but reproducible seed.

    Args:
        worker_id: index of the DataLoader worker (0-based).
        base_seed: currently unused. It is kept so existing callers that pass
            it keep working.
    """
    import random

    # np.random.seed() returns None, so compute the seed first and then apply
    # it. Seeds must fit in uint32, hence the modulo.
    seed = (int(np.random.get_state()[1][0]) + worker_id) % 2**32
    print("worker id:", worker_id, " seed:", seed)
    np.random.seed(seed)
    random.seed(seed)
    
def main(args):
    """Evaluate a trained DD / FutureDD checkpoint by rolling it out in gym.

    Rebuilds the dataset, the model and diffusion objects, and the trainer
    from the module-level ``Config``. In ``'new'`` mode it also builds the
    prior / posterior / FutureDiffusion wrappers. It then loads the checkpoint
    stored under ``Config.bucket`` and runs ``Config.num_eval`` environments
    in lockstep until every episode has terminated.

    At each step the EMA model plans conditioned on the current normalized
    observation and the target return ``Config.test_ret``. Its
    inverse-dynamics model turns the first two planned states into an action.
    Per-episode rewards, their mean and their std are printed and logged to
    wandb.

    Args:
        args: parsed command-line namespace. Reads ``po``,
            ``occlude_start_idx`` and ``sample_mode``.
    """
    from config.locomotion_config import Config
    import wandb
    torch.backends.cudnn.benchmark = True
    utils.set_seed(Config.seed)

    Config.device = 'cuda'

    # With save_checkpoints enabled, checkpoints are stored per step;
    # otherwise a single state.pt is used.
    ckpt_dir = os.path.join(Config.bucket, 'checkpoint')
    ckpt_file = f'state_{Config.ckpt_step}.pt' if Config.save_checkpoints else 'state.pt'
    loadpath = os.path.join(ckpt_dir, ckpt_file)

    state_dict = torch.load(loadpath, map_location=Config.device)

    # -----------------------------------------------------------------------------#
    # ---------------------------------- dataset ----------------------------------#
    # -----------------------------------------------------------------------------#

    if Config.mode == 'new':
        env_name = Config.dataset
        env = load_environment(Config.dataset)
        preprocess_fn = get_preprocess_fn([], env)
        itr = sequence_dataset(env=env, preprocess_fn=preprocess_fn, po=Config.po, occlude_start_idx=Config.occlude_start_idx)

        dataset_config = utils.Config(
            Config.loader,
            itr=itr,
            env_name=Config.dataset,
            horizon=Config.horizon,
            normalizer=Config.normalizer,
            preprocess_fns=Config.preprocess_fns,
            use_padding=Config.use_padding,
            max_path_length=Config.max_path_length,
            include_returns=Config.include_returns,
            include_gaits=Config.include_gaits,
            returns_scale=Config.returns_scale,
            discount=Config.discount,
            termination_penalty=Config.termination_penalty,
        )
    else:
        dataset_config = utils.Config(
            Config.loader,
            savepath=Config.bucket + '/dataset_config.pkl',
            env=Config.dataset,
            horizon=Config.horizon,
            normalizer=Config.normalizer,
            preprocess_fns=Config.preprocess_fns,
            use_padding=Config.use_padding,
            max_path_length=Config.max_path_length,
            include_returns=Config.include_returns,
            include_gaits=Config.include_gaits,
            returns_scale=Config.returns_scale,
            discount=Config.discount,
            termination_penalty=Config.termination_penalty,
            po=Config.po,
            occlude_start_idx=Config.occlude_start_idx,
        )
    dataset = dataset_config()
    dataloader = cycle(torch.utils.data.DataLoader(
            dataset, batch_size=Config.batch_size, num_workers=Config.n_workers, shuffle=True, pin_memory=True,
            worker_init_fn=lambda worker_id: _worker_init_fn(worker_id, Config.seed),
        ))
    dataloader_vis = cycle(torch.utils.data.DataLoader(
            dataset, batch_size=1, num_workers=0, shuffle=True, pin_memory=True
        ))

    renderer = None
    observation_dim = dataset.observation_dim
    action_dim = dataset.action_dim

    if Config.diffusion == 'models.GaussianInvDynDiffusion':
        if Config.mode == 'new':
            transition_dim = observation_dim + Config.z_dim
        else:
            transition_dim = observation_dim
    else:
        transition_dim = observation_dim + action_dim

    # -----------------------------------------------------------------------------#
    # ---------------------------------- models -----------------------------------#
    # -----------------------------------------------------------------------------#

    if Config.mode == 'new':
        model_config = utils.Config(
            Config.model,
            savepath='model_config.pkl',
            horizon=Config.horizon,
            transition_dim=observation_dim + Config.z_dim,  # observation_dim + latent_dim
            cond_dim=observation_dim + Config.z_dim,
            dim_mults=Config.dim_mults,
            returns_condition=Config.returns_condition,
            dim=Config.dim,
            condition_dropout=Config.condition_dropout,
            calc_energy=Config.calc_energy,
            device=Config.device,
            condition_dim=Config.condition_dim,
        )

        diffusion_config = utils.Config(
            Config.diffusion,
            savepath='diffusion_config.pkl',
            horizon=Config.horizon,
            observation_dim=observation_dim + Config.z_dim,  # observation_dim + latent_dim
            action_dim=action_dim,
            n_timesteps=Config.n_diffusion_steps,
            loss_type=Config.loss_type,
            clip_denoised=Config.clip_denoised,
            predict_epsilon=Config.predict_epsilon,
            hidden_dim=Config.hidden_dim,
            ar_inv=Config.ar_inv,
            train_only_inv=Config.train_only_inv,
            ## loss weighting
            action_weight=Config.action_weight,
            loss_weights=Config.loss_weights,
            loss_discount=Config.loss_discount,
            returns_condition=Config.returns_condition,
            condition_guidance_w=Config.condition_guidance_w,
            device=Config.device,
        )
        posterior_config = utils.Config(
            Config.posterior_model,
            state_dim=observation_dim,
            act_dim=action_dim,
            hidden_size=Config.hidden_dim,
            z_dim=Config.z_dim,
            horizon=Config.horizon,
            max_length=Config.max_path_length,
            device=Config.device,
        )
        # With cond_z set, the prior receives one extra input feature (the
        # target return; see the rollout loop below).
        if Config.cond_z == 0:
            prior_in_dim = Config.hidden_dim
        else:
            prior_in_dim = Config.hidden_dim + 1
        prior_config = utils.Config(
            Config.prior_model,
            hidden_dim=prior_in_dim,
            z_dim=Config.z_dim,
            device=Config.device,
        )
        future_diffusion_config = utils.Config(
            'models.future.FutureDiffusion',
            observation_dim=observation_dim,
            cond_z=Config.cond_z,
            action_dim=action_dim,
            hidden_dim=Config.hidden_dim,
            z_dim=Config.z_dim,
            z_reg=Config.z_reg,
            horizon=Config.horizon,
            max_length=Config.max_path_length,
            device=Config.device,
        )
        trainer_config = utils.Config(
            utils.Trainer,
            savepath='trainer_config.pkl',
            train_batch_size=Config.batch_size,
            train_lr=Config.learning_rate,
            gradient_accumulate_every=Config.gradient_accumulate_every,
            ema_decay=Config.ema_decay,
            sample_freq=Config.sample_freq,
            save_freq=Config.save_freq,
            log_freq=Config.log_freq,
            label_freq=int(Config.n_train_steps // Config.n_saves),
            save_parallel=Config.save_parallel,
            bucket=Config.bucket,
            n_reference=Config.n_reference,
            train_device=Config.device,
            save_checkpoints=Config.save_checkpoints,
        )
        model = model_config()
        diffusion = diffusion_config(model)
        prior_model = prior_config()
        posterior_model = posterior_config()
        future_diffusion = future_diffusion_config(prior_model, posterior_model, diffusion)

        trainer = trainer_config(future_diffusion, dataset, renderer, dataloader, dataloader_vis)

    else:
        model_config = utils.Config(
            Config.model,
            savepath='model_config.pkl',
            horizon=Config.horizon,
            transition_dim=transition_dim,
            cond_dim=observation_dim,
            dim_mults=Config.dim_mults,
            dim=Config.dim,
            returns_condition=Config.returns_condition,
            device=Config.device,
            condition_dim=Config.condition_dim,
        )
        diffusion_config = utils.Config(
            Config.diffusion,
            savepath='diffusion_config.pkl',
            horizon=Config.horizon,
            observation_dim=observation_dim,
            action_dim=action_dim,
            n_timesteps=Config.n_diffusion_steps,
            loss_type=Config.loss_type,
            clip_denoised=Config.clip_denoised,
            predict_epsilon=Config.predict_epsilon,
            hidden_dim=Config.hidden_dim,
            ## loss weighting
            action_weight=Config.action_weight,
            loss_weights=Config.loss_weights,
            loss_discount=Config.loss_discount,
            returns_condition=Config.returns_condition,
            device=Config.device,
            condition_guidance_w=Config.condition_guidance_w,
        )

        trainer_config = utils.Config(
            utils.Trainer,
            savepath='trainer_config.pkl',
            train_batch_size=Config.batch_size,
            train_lr=Config.learning_rate,
            gradient_accumulate_every=Config.gradient_accumulate_every,
            ema_decay=Config.ema_decay,
            sample_freq=Config.sample_freq,
            save_freq=Config.save_freq,
            log_freq=Config.log_freq,
            label_freq=int(Config.n_train_steps // Config.n_saves),
            save_parallel=Config.save_parallel,
            bucket=Config.bucket,
            n_reference=Config.n_reference,
            train_device=Config.device,
        )

        model = model_config()
        diffusion = diffusion_config(model)
        trainer = trainer_config(diffusion, dataset, renderer, dataloader, dataloader_vis)

    print(utils.report_parameters(model))
    trainer.step = state_dict['step']
    trainer.model.load_state_dict(state_dict['model'])
    trainer.ema_model.load_state_dict(state_dict['ema'])

    # -----------------------------------------------------------------------------#
    # -------------------------------- evaluation ---------------------------------#
    # -----------------------------------------------------------------------------#

    num_eval = Config.num_eval
    device = Config.device

    # Gym environments. (For the Unitree-go-running environment, load it via
    # load_env instead; rendering must stay disabled there.)
    env_list = [gym.make(Config.dataset) for _ in range(num_eval)]
    dones = [0 for _ in range(num_eval)]
    episode_rewards = [0 for _ in range(num_eval)]
    if Config.mode == 'new':
        assert trainer.ema_model.diffusion_model.condition_guidance_w == Config.condition_guidance_w
    else:
        assert trainer.ema_model.condition_guidance_w == Config.condition_guidance_w
    # Every environment is conditioned on the same target return.
    returns = to_device(Config.test_ret * torch.ones(num_eval, 1), device)
    if args.po == 1:
        occlude_start_idx = args.occlude_start_idx

    if args.po == 0:
        obs = np.concatenate([env.reset()[None] for env in env_list], axis=0)
        recorded_obs = [deepcopy(obs[:, None])]
    else:
        # Reset each env exactly once. The full observation is recorded; only
        # its first `occlude_start_idx` dims are given to the planner.
        robs = np.concatenate([env.reset()[None] for env in env_list], axis=0)
        obs = robs[:, :occlude_start_idx]
        recorded_obs = [deepcopy(robs[:, None])]

    while sum(dones) < num_eval:
        obs = dataset.normalizer.normalize(obs, 'observations')
        if Config.mode == 'new':
            obs = torch.from_numpy(obs).to(device).type(torch.float32)
            obs_embed = trainer.ema_model.embed_state(obs)
            if Config.cond_z == 0:
                prior_input = obs_embed
            else:
                # The prior was built with hidden_dim + 1 inputs when cond_z is
                # set. Append the target return as one extra feature, one row
                # per environment (not a hard-coded batch of 10).
                ret_feature = obs_embed.new_full((*obs_embed.shape[:-1], 1), Config.test_ret)
                prior_input = torch.cat((obs_embed, ret_feature), dim=-1)
            z_dist = trainer.ema_model.prior_model(F.relu(prior_input))
            # 'mu' uses the prior mean; any other sample_mode draws a sample.
            z = z_dist.mu if args.sample_mode == 'mu' else z_dist.sample()
            obs_new = torch.cat((obs, z), dim=-1)
            conditions = {0: to_torch(obs_new, device=device)}
            samples = trainer.ema_model.diffusion_model.conditional_sample(conditions, returns=returns)
        else:
            conditions = {0: to_torch(obs, device=device)}
            samples = trainer.ema_model.conditional_sample(conditions, returns=returns)

        # The inverse-dynamics model maps the first two planned states to an action.
        obs_comb = torch.cat([samples[:, 0, :], samples[:, 1, :]], dim=-1)
        if Config.mode == 'new':
            obs_comb = obs_comb.reshape(-1, 2 * (observation_dim + Config.z_dim))
            action = trainer.ema_model.diffusion_model.inv_model(obs_comb)
        else:
            obs_comb = obs_comb.reshape(-1, 2 * observation_dim)
            action = trainer.ema_model.inv_model(obs_comb)
        action = to_np(action)
        action = dataset.normalizer.unnormalize(action, 'actions')

        obs_list = []
        r_obs_list = []
        for i, env in enumerate(env_list):
            this_obs, this_reward, this_done, _ = env.step(action[i])
            if args.po == 0:
                obs_list.append(this_obs[None])
            else:
                r_obs_list.append(this_obs[None])
                obs_list.append(this_obs[None][:, :occlude_start_idx])
            # Envs that already finished are still stepped each iteration;
            # their rewards are no longer accumulated.
            if dones[i]:
                continue
            episode_rewards[i] += this_reward
            if this_done:
                dones[i] = 1
                print(f"Episode ({i}): {episode_rewards[i]}")

        obs = np.concatenate(obs_list, axis=0)
        if args.po == 0:
            recorded_obs.append(deepcopy(obs[:, None]))
        else:
            robs = np.concatenate(r_obs_list, axis=0)
            recorded_obs.append(deepcopy(robs[:, None]))

    # Executed (full) observations stacked along the time axis. They are only
    # needed if a renderer is plugged in, e.g.
    # renderer.composite(os.path.join('images', 'sample-executed.png'), recorded_obs)
    recorded_obs = np.concatenate(recorded_obs, axis=1)
    episode_rewards = np.array(episode_rewards)
    for i in range(num_eval):
        wandb.log({f"episode_reward_{i}": episode_rewards[i]})
    print(f"average_ep_reward: {np.mean(episode_rewards)}, std_ep_reward: {np.std(episode_rewards)}")
    wandb.log({"average_ep_reward": np.mean(episode_rewards)})
    wandb.log({"std_ep_reward": np.std(episode_rewards)})

if __name__ == '__main__':
    # To attach a remote debugger (ptvsd), insert before argument parsing:
    #   import ptvsd
    #   ptvsd.enable_attach(address=("localhost", 55557), redirect_output=True)
    #   ptvsd.wait_for_attach()
    parser = argparse.ArgumentParser()
    parser.add_argument("--trained_time",default='2023-08-27/18-43-16',type=str,help='time when model was created, helps to specify the model path')
    parser.add_argument("--ckpt_step",default=999999,type=int, help='The step of checkpoint to be evaluated')
    parser.add_argument("--env_name", default="hopper-medium-expert-v2", type=str)
    parser.add_argument("--z_dim", default=8, type=int)
    parser.add_argument("--mode",default='new',type=str,help='old for DD while new for FutureDD')
    parser.add_argument("--batch_size", default=32, type=int)
    # NOTE(review): --hidden_dim is parsed but never copied into Config, so
    # Config.hidden_dim's own default is what the model actually uses. Confirm
    # against the training config before wiring it through; a mismatch would
    # break checkpoint loading.
    parser.add_argument("--hidden_dim",default=256,type=int)
    parser.add_argument("--token_mode",default=1,type=int,help='The token information to be encoded')
    parser.add_argument("--future_mode",default=1,type=int,help='The future information to be encoded')
    parser.add_argument("--n_workers",default=0, type=int, help='Num of workers for dataloader')
    parser.add_argument("--po",default=0, type=int, help='1 for POMDP and 0 for MDP')
    parser.add_argument("--seed",default=7, type=int, help='random seed')
    parser.add_argument("--occlude_start_idx",default=-2, type=int, help='number of dims of obs to be occluded')
    parser.add_argument("--condition_guidance_w",default=1,type=float)
    parser.add_argument("--test_ret",default=0.0,type=float)
    parser.add_argument("--prior_model",default='DiagGaussian',type=str, help='The prior model to be used')
    parser.add_argument("--sample_mode",default='mu',type=str, help='The way z to be sampled')
    parser.add_argument("--cond_z",default=0,type=int)

    args = parser.parse_args()

    current_file = os.path.abspath(__file__)

    # MDP and POMDP runs are stored under separate roots.
    run_root = 'mdp' if args.po == 0 else 'pomdp'
    Config.bucket = f'/data/your_name_hdd/dd/{run_root}/' + args.trained_time

    Config.dataset = args.env_name

    if 'hopper' in Config.dataset:
        Config.returns_scale = 400.0
    elif 'halfcheetah' in Config.dataset:
        Config.returns_scale = 1200.0
    elif 'walker' in Config.dataset:
        Config.returns_scale = 550.0  # Determined using rewards from the dataset

    Config.ckpt_step = args.ckpt_step
    Config.z_dim = args.z_dim
    Config.mode = args.mode
    Config.batch_size = args.batch_size
    Config.future_mode = args.future_mode
    Config.token_mode = args.token_mode
    Config.n_workers = args.n_workers
    Config.seed = args.seed
    Config.po = args.po
    Config.occlude_start_idx = args.occlude_start_idx
    Config.condition_guidance_w = args.condition_guidance_w
    Config.prior_model = 'models.helpers.' + args.prior_model
    Config.sample_mode = args.sample_mode
    Config.cond_z = args.cond_z

    # --test_ret 0.0 means "use the per-dataset default from eval_testrt".
    if args.test_ret == 0.0:
        if args.env_name not in eval_testrt:
            parser.error(
                f"no default --test_ret for {args.env_name!r}; pass --test_ret explicitly"
            )
        Config.test_ret = eval_testrt[args.env_name]
    else:
        Config.test_ret = args.test_ret

    if args.mode == 'new':
        Config.loader = 'datasets.FutureSequenceDataset'
    else:
        Config.loader = 'datasets.SequenceDataset'

    log_path = Config.bucket + '/log'
    os.makedirs(log_path, exist_ok=True)

    # Keep a copy of this evaluation script next to the logs.
    shutil.copy2(current_file, log_path)

    # Print the arguments
    print("env_name:", args.env_name)
    print("seed:", Config.seed)
    print("z_dim:", Config.z_dim)
    print("mode:", Config.mode)
    print("save_freq:", Config.save_freq)
    print("n_train_steps:", Config.n_train_steps)
    print("batch_size:", Config.batch_size)
    print("hidden_dim:", Config.hidden_dim)
    print("future_mode:", Config.future_mode)
    print("token_mode:", Config.token_mode)
    print("n_workers",Config.n_workers)
    print("loader:",Config.loader)
    print("po:",Config.po)
    print("occlude_start_idx:",Config.occlude_start_idx)

    print('saving logs to:', log_path, '------------')

    with open(log_path + '/args.json', 'w', encoding='utf-8') as f:
        json.dump(vars(args), f)
    print(Config.__dict__)

    # Config is a class, so its __dict__ holds non-JSON values; stringify them.
    with open(log_path + '/config.json', 'w', encoding='utf-8') as f:
        json.dump(Config.__dict__, f, default=str)

    exp_name = args.prior_model + '_' + Config.dataset + '_' + str(Config.ckpt_step) + '_' + str(Config.condition_guidance_w) + '_' + str(Config.test_ret)

    wandb.init(project="FutureDD_eval",
               entity="your_entity2020",
               name=exp_name,
               config=Config.__dict__,
               settings=wandb.Settings(
                   start_method="thread",
                   _disable_stats=True,
               ),
               mode="online",
               )

    main(args)