import torch
import numpy as np
import gym
#import gymnasium as gym
import time, uuid
from execution_scripts import td3_n, sac_n, decqn, bc_offline, revalued, maddpg, maac, true_revalued

from dmc_datasets import environment_utils

if __name__ == '__main__':
    # Experiment launcher: assemble grouped settings into one flat config
    # dict, build the train/test environments, seed the RNGs, then hand the
    # config to the training routine selected at the bottom of this block.

    # Base network spec. Per the original author's note, [256, 256] is the
    # "small" variant and [256, 256, 256] gives two hidden layers.
    model_info = {'layers': [256, 256, 256],  # base layer model spec
                  'hidden_activation': 'ReLU',  # activation for hidden layers
                  # 'actor_final_activation': 'Sigmoid',  # 'Softmax' / 'Sigmoid' -- using conditional logic instead, a bit easier
                  'critic_final_activation': '',
                  }

    layer_params = {'actor_hidden_dim': 256,  # 512,
                    'critic_hidden_dim': 256,  # 512,
                    }

    bc_params = {'bc_alpha': 0.07}

    td3_params = {'policy_noise_std': 0.2,
                  'noise_clip': 0.5,
                  'policy_update_freq': 2,
                  'exploration_noise_std': 0.1,
                  'td3_alpha': 1}

    lr_info = {'optimiser': 'Adam',
               'max_grad_norm': 0.05,  # for clipping gradient
               }

    replay_buffer_params = {'mem_size': 250000,  # 250000, 1000000
                            'batch_size': 256,
                            'normalise_state': True,
                            }

    ensemble_info = {'ensemble_num': 1,
                     'critic_factor': 10}  # how many critics per actor

    performance_eval_config = {'num_evals': 10,  # how many times to evaluate when testing performance of policy
                               'eval_counter': 10000,  # after how many steps of learning to evaluate offline
                               }

    training_config = {'num_env_steps': 1000001,
                       'online_steps': 1000001,
                       'burn_in_steps': 10000,  # number of random actions to fill buffer with
                       'update_ratio': 1,
                       'utd_ratio': 1,
                       'n_steps': 3}  # 3 or 1 -- length of the n-step return

    machine_config = {'device': 'cuda:0' if torch.cuda.is_available() else 'cpu', }

    dm_suite = True
    is_continuous = False

    env_config = {}
    env_config['dm_suite'] = dm_suite
    env_config['is_continuous'] = is_continuous
    # Single source of truth for the discretisation; passed to make_env below
    # so the config and the actual environments cannot drift apart.
    env_config['action_bins'] = 3

    exploration_config = {'epsilon': 1,
                          'epsilon_min': 0.05,  # 0.1 0.05, 0.2
                          'exploration_decay': 0.99995}  # 0.99999 0.99995

    # Longer online training budgets for harder tasks. Keyed on
    # '<task_name>-<aim>', so every data-quality variant of a task
    # (e.g. 'dog-trot-expert') gets the same override.
    online_steps_overrides = {'humanoid-stand': 5000001,
                              'humanoid-walk': 5000001,
                              'humanoid-run': 10000001,
                              'dog-trot': 3000001,
                              'dog-run': 5000001,
                              }

    if dm_suite:
        # env_id format: '<task_name>-<aim>[-<data_quality>]'. Uncomment one.
        # env_id = 'finger-spin'
        # env_id = 'walker-walk'
        # env_id = 'walker-run'
        # env_id = 'fish-swim'
        # env_id = 'quadruped-walk'
        # env_id = 'quadruped-run'
        env_id = 'cheetah-run'
        # env_id = 'humanoid-stand'
        # env_id = 'humanoid-walk'
        # env_id = 'humanoid-run'
        # env_id = 'dog-walk'
        # env_id = 'dog-trot'
        # env_id = 'dog-run'

        # env_id = 'quadruped-walk-medium'
        # env_id = 'quadruped-walk-expert'
        # env_id = 'quadruped-walk-medium-expert'
        # env_id = 'quadruped-walk-random-medium-expert'

        # env_id = 'cheetah-run-expert'
        # env_id = 'cheetah-run-medium'
        # env_id = 'cheetah-run-medium-expert'
        # env_id = 'cheetah-run-random-medium-expert'

        # env_id = 'humanoid-stand-expert'
        # env_id = 'humanoid-stand-medium'
        # env_id = 'humanoid-stand-medium-expert'
        # env_id = 'humanoid-stand-random-medium-expert'

        # env_id = 'dog-trot-expert'
        # env_id = 'dog-trot-medium'
        # env_id = 'dog-trot-medium-expert'
        # env_id = 'dog-trot-random-medium-expert'

        # Everything after '<task_name>-<aim>' is the (possibly hyphenated)
        # data-quality suffix; it is '' for plain online task ids.
        env_config['task_name'], env_config['aim'], *quality_parts = env_id.split('-')
        env_config['data_quality'] = '-'.join(quality_parts)
        print(env_config['data_quality'])
        env_config['task'] = env_config['task_name'] + '-' + env_config['aim']

        # Training and evaluation environments are built with identical settings.
        env = environment_utils.make_env(task_name=env_config['task_name'], task=env_config['aim'],
                                         bin_size=env_config['action_bins'],
                                         is_continuous=is_continuous, factorised=True)

        test_env = environment_utils.make_env(task_name=env_config['task_name'], task=env_config['aim'],
                                              bin_size=env_config['action_bins'],
                                              is_continuous=is_continuous, factorised=True)

        if env_config['task'] in online_steps_overrides:
            training_config['online_steps'] = online_steps_overrides[env_config['task']]
    else:
        # Only the DM Control Suite path is wired up. Fail loudly here rather
        # than with a NameError on env_id/env/test_env further down.
        raise NotImplementedError('Only dm_suite=True environments are configured in this script.')

    env_config['env_id'] = env_id
    env_config['env'] = env
    env_config['test_env'] = test_env

    layer_params['action_constrained_dim'] = 50

    # Flatten all setting groups into one dict. The spreads happen here, so
    # training_config must already carry any per-task overrides applied above.
    config_dict = {**env_config,
                   'seed': 0,
                   'gamma': 0.99,
                   'train_model': False,
                   'model_info': model_info,
                   'dep_targ': True,
                   **ensemble_info,
                   **machine_config,
                   **performance_eval_config,
                   **training_config,
                   **lr_info,
                   **replay_buffer_params,
                   **bc_params,
                   **td3_params,
                   **layer_params,
                   **exploration_config,
                   'wandb_project': 'Attention Net',
                   'id': str(uuid.uuid4())[:8],
                   'wandb_log_iter': 1000,  # when to log values
                   }

    config_dict['use_wandb'] = False  # True

    config_dict['offline'] = False

    config_dict['save_model'] = False  # True

    config_dict['transform_reward'] = 'humanoid' in config_dict['env_id']

    if config_dict['offline']:
        config_dict['algo_type'] = 'offline'
        config_dict['mem_size'] = 2000000  # larger buffer for offline datasets
    else:
        config_dict['algo_type'] = 'online'

    if is_continuous:
        config_dict['min_val'] = config_dict['env'].action_space.low
        config_dict['max_val'] = config_dict['env'].action_space.high

    # Seed every RNG source used here. The training env's action space is
    # seeded; test_env's is not. NOTE(review): confirm test_env never samples
    # random actions.
    seed = config_dict['seed']
    config_dict['rng'] = np.random.default_rng(seed)
    config_dict['env'].action_space.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Select the training routine to run (exactly one should be active).
    # td3_n(config_dict)

    # sac_n(config_dict)

    # bc_offline(config_dict)

    # revalued(config_dict)

    # true_revalued(config_dict)

    # maac(config_dict)

    maddpg(config_dict)

