import torch
import numpy as np
import gym
import time, uuid
from execution_scripts import td3_n, bc_offline, revalued, combined, sac_n

from inverse_model import run_inverse_model
from dmc_datasets import environment_utils
import d4rl

if __name__ == "__main__":
    # Script entry point: build the hyper-parameter groups, construct the
    # environment, flatten everything into a single ``config_dict``, seed the
    # random number generators and hand the config to the selected runner.

    # ------------------------------------------------------------------
    # Hyper-parameter groups (merged into ``config_dict`` further down)
    # ------------------------------------------------------------------
    model_info = dict(
        layers=[256, 256, 256],       # base layer model spec
        hidden_activation="ReLU",     # activation for hidden layers
        critic_final_activation="",
    )

    bc_params = dict(gaussian_bc=False, bc_alpha=8)

    td3_params = dict(
        policy_noise_std=0.2,
        noise_clip=0.5,
        policy_update_freq=2,
        exploration_noise_std=0.1,
        td3_alpha=1,
    )

    lr_info = dict(
        optimiser="Adam",
        max_grad_norm=0.05,           # for clipping gradient
    )

    replay_buffer_params = dict(
        mem_size=1000000,             # alternative: 2**20
        batch_size=256,
        normalise_state=True,
        discrete_eps=1e-4,
        discrete_bins=3,
    )

    ensemble_info = dict(
        ensemble_num=1,
        critic_factor=5,              # how many critics per actor
    )

    performance_eval_config = dict(
        num_evals=10,                 # how many times to evaluate when testing performance of policy
        eval_counter=10000,           # after how many steps of learning to evaluate offline
    )

    training_config = dict(
        num_env_steps=3000001,
        online_steps=1000001,
        burn_in_steps=5000,           # number of random actions to fill buffer with
        update_ratio=1,
        n_steps=1,                    # n_step return
    )

    # Use the first CUDA device when torch reports one, otherwise the CPU.
    if torch.cuda.is_available():
        device = "cuda:0"
    else:
        device = "cpu"
    machine_config = {"device": device}

    # ------------------------------------------------------------------
    # Environment selection
    # ------------------------------------------------------------------
    dm_suite = False  # set to True to use the dm_suite branch below
    env_config = {"dm_suite": dm_suite}

    if not dm_suite:
        # Other gym ids that can be swapped in here:
        #   'hopper-medium-v2'
        #   'hopper-expert-v2'
        #   'hopper-medium-expert-v2'
        #   'hopper-medium-replay-v2'
        #   'halfcheetah-medium-v2'
        #   'halfcheetah-medium-expert-v2'
        #   'halfcheetah-medium-replay-v2'
        #   'halfcheetah-expert-v2'
        #   'walker2d-medium-v2'
        #   'walker2d-medium-expert-v2'
        #   'walker2d-medium-replay-v2'
        env_id = "walker2d-expert-v2"

        env = gym.make(env_id)

        # '<task>-<data quality ...>-<version>': the first piece is the task,
        # the last piece is discarded and the middle pieces are re-joined.
        task, *quality, _ = env_id.split("-")
        env_config["task"] = task
        env_config["data_quality"] = "-".join(quality)
    else:
        env_id = "cheetah-run-medium-expert"

        # '<task_name>-<aim>-<data quality ...>': everything after the second
        # piece is re-joined into the data-quality label.
        task_name, aim, *quality = env_id.split("-")
        env_config["action_bins"] = 3
        env_config["task_name"] = task_name
        env_config["aim"] = aim
        env_config["data_quality"] = "-".join(quality)
        print(env_config["data_quality"])
        env_config["task"] = task_name + "-" + aim

        env = environment_utils.make_env(
            task_name=task_name,
            task=aim,
            bin_size=3,
            factorised=True,
        )

        # This branch overrides the n-step return length, and enlarges the
        # burn-in for quadruped tasks.
        training_config["n_steps"] = 3
        if "quadruped" in env_id:
            training_config["burn_in_steps"] = 10000

    env_config["env_id"] = env_id
    env_config["env"] = env

    # ------------------------------------------------------------------
    # Flatten everything into one configuration dictionary
    # (insertion order: env settings, core settings, parameter groups,
    #  logging settings, run-mode flags)
    # ------------------------------------------------------------------
    offline = False  # True

    config_dict = {}
    config_dict.update(env_config)
    config_dict.update(
        seed=4,
        gamma=0.99,
        train_model=False,
        model_info=model_info,
        dep_targ=True,
    )
    for group in (
        ensemble_info,
        machine_config,
        performance_eval_config,
        training_config,
        lr_info,
        replay_buffer_params,
        bc_params,
        td3_params,
    ):
        config_dict.update(group)
    config_dict.update(
        wandb_project="ICLR State Only Offline",
        id=str(uuid.uuid4())[:8],     # first 8 characters of a random UUID
        wandb_log_iter=1000,          # when to log values
        use_wandb=False,
        offline=offline,
        save_model=False,
        algo_type="offline" if offline else "online",
    )

    # Action bounds are only recorded for the non-dm_suite branch.
    if not dm_suite:
        action_space = env.action_space
        config_dict["min_val"] = action_space.low
        config_dict["max_val"] = action_space.high

    # ------------------------------------------------------------------
    # Seeding
    # ------------------------------------------------------------------
    seed = config_dict["seed"]
    config_dict["rng"] = np.random.default_rng(seed)
    env.action_space.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # ------------------------------------------------------------------
    # Runner selection: exactly one call is active, the alternatives are
    # kept as comments so they can be toggled.
    # ------------------------------------------------------------------
    # td3_n(config_dict)

    # sac_n(config_dict)

    # bc_offline(config_dict, agent_type='state')
    # bc_offline(config_dict, agent_type='diff_state')
    # bc_offline(config_dict, agent_type='cont_diff_state')
    # bc_offline(config_dict, agent_type='action')

    # run_inverse_model(config_dict, model_type='state')
    # run_inverse_model(config_dict, model_type='diff_state')
    # run_inverse_model(config_dict, model_type='cont_diff_state')

    revalued(config_dict, agent_type="diff_state")
    # revalued(config_dict, agent_type='action')

    # combined(config_dict)
