import os
import sys
import numpy as np
import random
import wandb
from pathlib import Path
current_dir = os.path.dirname(os.path.abspath(__file__))

parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)

import diffuser.utils as utils
import torch
import argparse
from config.locomotion_config import Config
from diffuser.datasets.d4rl import load_environment, sequence_dataset
from diffuser.datasets.preprocessing import get_preprocess_fn
import shutil
import json
import pickle
from diffuser.utils.logger import Logger

def cycle(dl):
    """Yield items from ``dl`` endlessly.

    Each time ``dl`` is exhausted it is iterated again from the start, so the
    returned generator never terminates on its own. ``dl`` must therefore be
    re-iterable (e.g. a list or a DataLoader), not a one-shot iterator.
    """
    while True:
        # A fresh iteration over ``dl`` is started on every pass.
        yield from dl

def _worker_init_fn(worker_id, base_seed=7):
    """Seed NumPy's and Python's global RNGs for one DataLoader worker.

    The seed is ``base_seed + worker_id`` so that every worker draws a
    different, but reproducible, random stream.

    Args:
        worker_id: Index of the worker being initialised.
        base_seed: Experiment-level seed shared by all workers.
    """
    # np.random.seed() returns None, so its result must never be used as the
    # seed value: seeding with None silently falls back to OS entropy and
    # makes the workers non-reproducible.
    # The modulo keeps the value inside the range np.random.seed accepts.
    seed = (int(base_seed) + int(worker_id)) % (2 ** 32)
    print("worker id:", worker_id, " seed:", seed)
    np.random.seed(seed)
    random.seed(seed)
             
def main(args, logger):
    """Build dataset, models and trainer from ``Config`` and run training.

    Every hyper-parameter is read from the module-level ``Config`` object,
    which the ``__main__`` section populates from the command line before
    calling this function.

    ``Config.mode == 'new'`` wraps the diffusion model together with a prior
    and a posterior model inside ``models.future.FutureDiffusion``; any other
    mode trains the diffusion model on its own.

    Args:
        args: Parsed command-line namespace. It is not read inside this
            function; the parameter is kept for existing callers.
        logger: Logging handle forwarded to ``trainer.train`` as ``wandb_log``.
    """
    from functools import partial

    torch.backends.cudnn.benchmark = True
    utils.set_seed(Config.seed)

    use_future_model = Config.mode == 'new'

    # -----------------------------------------------------------------------------#
    # ---------------------------------- dataset ----------------------------------#
    # -----------------------------------------------------------------------------#

    if use_future_model:
        # The trajectory iterator is built here and handed to the dataset.
        env = load_environment(Config.dataset)
        preprocess_fn = get_preprocess_fn([], env)
        itr = sequence_dataset(
            env=env,
            preprocess_fn=preprocess_fn,
            po=Config.po,
            occlude_start_idx=Config.occlude_start_idx,
        )

        dataset_config = utils.Config(
            Config.loader,
            itr=itr,
            env_name=Config.dataset,
            horizon=Config.horizon,
            normalizer=Config.normalizer,
            preprocess_fns=Config.preprocess_fns,
            use_padding=Config.use_padding,
            max_path_length=Config.max_path_length,
            include_returns=Config.include_returns,
            include_gaits=Config.include_gaits,
            returns_scale=Config.returns_scale,
            discount=Config.discount,
            termination_penalty=Config.termination_penalty,
        )
    else:
        # Here the dataset class receives the environment name and the
        # partial-observability settings directly.
        dataset_config = utils.Config(
            Config.loader,
            savepath=Config.bucket+'/dataset_config.pkl',
            env=Config.dataset,
            horizon=Config.horizon,
            normalizer=Config.normalizer,
            preprocess_fns=Config.preprocess_fns,
            use_padding=Config.use_padding,
            max_path_length=Config.max_path_length,
            include_returns=Config.include_returns,
            include_gaits=Config.include_gaits,
            returns_scale=Config.returns_scale,
            discount=Config.discount,
            termination_penalty=Config.termination_penalty,
            po=Config.po,
            occlude_start_idx=Config.occlude_start_idx,
        )
    dataset = dataset_config()

    renderer = None

    observation_dim = dataset.observation_dim
    action_dim = dataset.action_dim

    # -----------------------------------------------------------------------------#
    # ------------------------------ model & trainer ------------------------------#
    # -----------------------------------------------------------------------------#
    if Config.diffusion == 'models.GaussianInvDynDiffusion':
        # In 'new' mode the diffused state is the observation extended by a
        # z_dim-sized latent; otherwise it is the plain observation.
        if use_future_model:
            diffused_dim = observation_dim + Config.z_dim
        else:
            diffused_dim = observation_dim

        model_config = utils.Config(
            Config.model,
            savepath=Config.bucket+'/model_config.pkl',
            horizon=Config.horizon,
            transition_dim=diffused_dim,
            cond_dim=diffused_dim,
            dim_mults=Config.dim_mults,
            returns_condition=Config.returns_condition,
            dim=Config.dim,
            condition_dropout=Config.condition_dropout,
            calc_energy=Config.calc_energy,
            device=Config.device,
            condition_dim=Config.condition_dim,
        )

        diffusion_config = utils.Config(
            Config.diffusion,
            savepath=Config.bucket+'/diffusion_config.pkl',
            horizon=Config.horizon,
            observation_dim=diffused_dim,
            action_dim=action_dim,
            n_timesteps=Config.n_diffusion_steps,
            loss_type=Config.loss_type,
            clip_denoised=Config.clip_denoised,
            predict_epsilon=Config.predict_epsilon,
            hidden_dim=Config.hidden_dim,
            ar_inv=Config.ar_inv,
            train_only_inv=Config.train_only_inv,
            ## loss weighting
            action_weight=Config.action_weight,
            loss_weights=Config.loss_weights,
            loss_discount=Config.loss_discount,
            returns_condition=Config.returns_condition,
            condition_guidance_w=Config.condition_guidance_w,
            device=Config.device,
        )
    else:
        # Any other diffusion class models observations and actions jointly.
        model_config = utils.Config(
            Config.model,
            savepath=Config.bucket+'/model_config.pkl',
            horizon=Config.horizon,
            transition_dim=observation_dim + action_dim,
            cond_dim=observation_dim,
            dim_mults=Config.dim_mults,
            returns_condition=Config.returns_condition,
            dim=Config.dim,
            condition_dropout=Config.condition_dropout,
            calc_energy=Config.calc_energy,
            device=Config.device,
        )
        diffusion_config = utils.Config(
            Config.diffusion,
            savepath=Config.bucket+'/diffusion_config.pkl',
            horizon=Config.horizon,
            observation_dim=observation_dim,
            action_dim=action_dim,
            n_timesteps=Config.n_diffusion_steps,
            loss_type=Config.loss_type,
            clip_denoised=Config.clip_denoised,
            predict_epsilon=Config.predict_epsilon,
            ## loss weighting
            action_weight=Config.action_weight,
            loss_weights=Config.loss_weights,
            loss_discount=Config.loss_discount,
            returns_condition=Config.returns_condition,
            condition_guidance_w=Config.condition_guidance_w,
            device=Config.device,
        )

    if use_future_model:
        posterior_config = utils.Config(
            Config.posterior_model,
            state_dim=observation_dim,
            act_dim=action_dim,
            hidden_size=Config.hidden_dim,
            z_dim=Config.z_dim,
            horizon=Config.horizon,
            max_length=Config.max_path_length,
            device=Config.device,
        )
        # With cond_z enabled the prior receives one extra input feature.
        if Config.cond_z == 0:
            prior_in_dim = Config.hidden_dim
        else:
            prior_in_dim = Config.hidden_dim + 1
        prior_config = utils.Config(
            Config.prior_model,
            hidden_dim=prior_in_dim,
            z_dim=Config.z_dim,
            device=Config.device,
        )

        future_diffusion_config = utils.Config(
            'models.future.FutureDiffusion',
            observation_dim=observation_dim,
            action_dim=action_dim,
            hidden_dim=Config.hidden_dim,
            z_dim=Config.z_dim,
            z_reg=Config.z_reg,
            horizon=Config.horizon,
            max_length=Config.max_path_length,
            future_mode=Config.future_mode,
            token_mode=Config.token_mode,
            cond_z=Config.cond_z,
            device=Config.device,
        )

    # The trainer configuration is the same for both modes.
    trainer_config = utils.Config(
        utils.Trainer,
        savepath=Config.bucket+'/trainer_config.pkl',
        train_batch_size=Config.batch_size,
        train_lr=Config.learning_rate,
        gradient_accumulate_every=Config.gradient_accumulate_every,
        ema_decay=Config.ema_decay,
        sample_freq=Config.sample_freq,
        save_freq=Config.save_freq,
        log_freq=Config.log_freq,
        label_freq=int(Config.n_train_steps // Config.n_saves),
        save_parallel=Config.save_parallel,
        bucket=Config.bucket,
        n_reference=Config.n_reference,
        n_workers=Config.n_workers,
        train_device=Config.device,
        save_checkpoints=Config.save_checkpoints,
    )

    # -----------------------------------------------------------------------------#
    # -------------------------------- instantiate --------------------------------#
    # -----------------------------------------------------------------------------#

    model = model_config()

    diffusion = diffusion_config(model)

    dataloader = cycle(torch.utils.data.DataLoader(
        dataset,
        batch_size=Config.batch_size,
        num_workers=Config.n_workers,
        shuffle=True,
        pin_memory=True,
        # A partial of a module-level function can be pickled, unlike a
        # lambda; this matters when worker processes are started with the
        # spawn or forkserver methods.
        worker_init_fn=partial(_worker_init_fn, base_seed=Config.seed),
    ))
    dataloader_vis = cycle(torch.utils.data.DataLoader(
        dataset, batch_size=1, num_workers=0, shuffle=True, pin_memory=True
    ))

    if use_future_model:
        prior_model = prior_config()
        posterior_model = posterior_config()
        final_model = future_diffusion_config(prior_model, posterior_model, diffusion)
    else:
        final_model = diffusion
    trainer = trainer_config(final_model, dataset, renderer, dataloader, dataloader_vis)

    if Config.resume == 1:
        # The checkpoint file name carries the step only when individual
        # checkpoints are kept.
        checkpoint_dir = os.path.join(Config.bucket, 'checkpoint')
        if Config.save_checkpoints:
            loadpath = os.path.join(checkpoint_dir, f'state_{Config.ckpt_step}.pt')
        else:
            loadpath = os.path.join(checkpoint_dir, 'state.pt')

        state_dict = torch.load(loadpath, map_location=Config.device)

        trainer.step = state_dict['step']
        trainer.model.load_state_dict(state_dict['model'])
        trainer.ema_model.load_state_dict(state_dict['ema'])

    # -----------------------------------------------------------------------------#
    # ------------------------ test forward & backward pass -----------------------#
    # -----------------------------------------------------------------------------#

    utils.report_parameters(model)
    print('Testing forward...')
    # Best-effort diagnostic: report (but do not abort on) samples that cannot
    # be pickled.
    try:
        pickle.dumps(dataset[0])
    except Exception as e:
        print(f"Serialization failed: {e}")

    batch = utils.batchify(dataset[0], Config.device)

    print(batch[0].shape, batch[2].shape)
    loss, _ = final_model.loss(*batch)
    loss.backward()

    # -----------------------------------------------------------------------------#
    # --------------------------------- main loop ---------------------------------#
    # -----------------------------------------------------------------------------#
    if Config.resume == 0:
        n_epochs = int(Config.n_train_steps // Config.n_steps_per_epoch)

        for i in range(n_epochs):
            print(f'Epoch {i} / {n_epochs}')
            trainer.train(n_train_steps=Config.n_steps_per_epoch, wandb_log=logger)
    else:
        print("trainer.step:", trainer.step)
        # NOTE(review): this requests n_train_steps additional steps on top of
        # the restored trainer.step -- confirm that is intended rather than
        # running only the remaining steps.
        # The int() cast guards against a float step count (e.g. 1e6).
        trainer.train(n_train_steps=int(Config.n_train_steps), wandb_log=logger)

if __name__ == '__main__':
    # Script entry point: parse the command line, copy the options onto the
    # shared ``Config`` object, create the run directory, record the settings
    # and start training.
    #
    # Remote-debugging hook (disabled). To attach a debugger, run:
    #   import ptvsd
    #   host = "localhost"
    #   port = 55557
    #   ptvsd.enable_attach(address=(host, port), redirect_output=True)
    #   ptvsd.wait_for_attach()
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_name", default="hopper-medium-expert-v2", type=str)
    parser.add_argument("--z_dim", default=8, type=int)
    parser.add_argument("--mode",default='new',type=str,help='old for DD while new for FutureDD')
    parser.add_argument("--save_freq",default=250000,type=int)
    # argparse applies ``type`` only to string defaults, so the default has to
    # be an int itself; a float such as 1e6 would stay a float.
    parser.add_argument("--n_train_steps",default=1000000,type=int)
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--hidden_dim",default=256,type=int)

    parser.add_argument("--token_mode",default=1,type=int,help='The token information to be encoded')
    parser.add_argument("--future_mode",default=1,type=int,help='The future information to be encoded')
    parser.add_argument("--n_workers",default=0, type=int, help='Num of workers for dataloader')
    parser.add_argument("--po",default=0, type=int, help='1 for POMDP and 0 for MDP')
    parser.add_argument("--seed",default=7, type=int, help='random seed')
    parser.add_argument("--occlude_start_idx",default=-2, type=int, help='number of dims of obs to be occluded')
    parser.add_argument("--condition_guidance_w",default=1.0,type=float)
    parser.add_argument("--resume",default=0,type=int,help='1 for resume a checkpoint and 0 for train from scratch')
    parser.add_argument("--trained_time",default='2023-08-27/18-43-16',type=str,help='time when model was created, helps to specify the model path')
    parser.add_argument("--ckpt_step",default=999999,type=int, help='The step of checkpoint to be evaluated')
    parser.add_argument("--prior_model",default='DiagGaussian',type=str, help='The prior model to be used')
    parser.add_argument("--cond_z",default=0,type=int)
    args = parser.parse_args()

    # Absolute path of this script; a copy is stored with the run logs below.
    current_file = os.path.abspath(__file__)

    from datetime import datetime
    # Runs are stored under .../mdp/ or .../pomdp/ depending on --po.
    if args.po == 0:
        bucket_root = '/data/your_name_hdd/dd/mdp/'
    else:
        bucket_root = '/data/your_name_hdd/dd/pomdp/'
    if args.resume == 0:
        # One timestamp for both parts, so the date and time components can
        # never come from different sides of midnight.
        Config.bucket = bucket_root + datetime.now().strftime('%Y-%m-%d/%H-%M-%S')
    else:
        # Resuming: reuse the directory of the run named by --trained_time.
        Config.bucket = bucket_root + args.trained_time

    Config.dataset = args.env_name
    Config.cond_z = args.cond_z

    # Per-environment return scale; other environments keep Config's value.
    if 'hopper' in Config.dataset:
        Config.returns_scale = 400.0
    elif 'halfcheetah' in Config.dataset:
        Config.returns_scale = 1200.0
    elif 'walker' in Config.dataset:
        Config.returns_scale = 550.0 # Determined using rewards from the dataset

    Config.ckpt_step = args.ckpt_step
    Config.z_dim = args.z_dim
    Config.mode = args.mode
    Config.save_freq = args.save_freq
    Config.n_train_steps = args.n_train_steps
    Config.batch_size = args.batch_size
    Config.hidden_dim = args.hidden_dim
    Config.future_mode = args.future_mode
    Config.token_mode = args.token_mode
    Config.n_workers = args.n_workers
    Config.seed = args.seed
    Config.po = args.po
    Config.occlude_start_idx = args.occlude_start_idx
    Config.condition_guidance_w = args.condition_guidance_w
    Config.resume = args.resume
    Config.trained_time = args.trained_time
    Config.prior_model = 'models.helpers.' + args.prior_model

    if args.mode == 'new':
        Config.loader = 'datasets.FutureSequenceDataset'
    else:
        Config.loader = 'datasets.SequenceDataset'

    log_path = Config.bucket + '/log'
    # exist_ok avoids the check-then-create race and allows resumed runs.
    os.makedirs(log_path, exist_ok=True)

    # Keep a snapshot of this script next to the logs.
    shutil.copy2(current_file, log_path)

    # Print the arguments
    print("env_name:", args.env_name)
    print("seed:", Config.seed)
    print("z_dim:", Config.z_dim)
    print("mode:", Config.mode)
    print("save_freq:", Config.save_freq)
    print("n_train_steps:", Config.n_train_steps)
    print("batch_size:", Config.batch_size)
    print("hidden_dim:", Config.hidden_dim)
    print("future_mode:", Config.future_mode)
    print("token_mode:", Config.token_mode)
    print("n_workers",Config.n_workers)
    print("loader:",Config.loader)
    print("po:",Config.po)
    print("occlude_start_idx:",Config.occlude_start_idx)

    print('saving logs to:', log_path, '------------')

    with open(log_path + '/args.json', 'w') as f:
        json.dump(vars(args), f)
    print(Config.__dict__)

    with open(log_path + '/config.json', 'w') as f:
        # Values JSON cannot encode natively are stored as their str() form.
        json.dump(Config.__dict__, f, default=str)

    model_name = 'FutureDD' if Config.mode == 'new' else 'DD'
    exp_name = '_'.join(
        str(part)
        for part in (
            model_name,
            Config.dataset,
            Config.z_dim,
            Config.hidden_dim,
            Config.token_mode,
            Config.future_mode,
        )
    )
    wandb.init(project="FutureDD",
               entity="your_entity2020",
               name=exp_name,
               config=Config.__dict__,
               settings=wandb.Settings(
                   start_method="thread",
                   _disable_stats=True,
               ),
               mode="online",
               )

    logger = Logger(Path(Config.bucket), use_tb=True, use_wandb=True)

    main(args, logger)
