"""Train a Decision Transformer baseline on Atari replay data.

Script-style entry point:
  1. build the experiment, state-encoder, DT-model and training configs,
  2. load the replay dataset and construct the model and AdamW optimizer/scheduler,
  3. optionally create ALE evaluation environments and an evaluator,
  4. run training, then close the environments and finish the Weights & Biases run.
"""
import sys, gc, os

# Make the project-local packages (utils, learn, baselines) importable.
# NOTE(review): relative to the current working directory, so this assumes the
# script is launched from the repository root — confirm.
sys.path.append('.')

from utils import helpers

# Select GPUs 2 and 3. This runs before `import torch` on purpose: presumably it sets
# CUDA_VISIBLE_DEVICES, which only takes effect if it is set before CUDA is initialised.
# Do not move `import torch` (or anything that imports torch) above this line.
helpers.set_cuda_visible_devices('2,3')

import torch

from learn.trainer import Trainer
from learn.evaluator import AtariEvaluator

from learn.config import StateEncoderConfig, TrainConfig, AtariReplayDataExperimentConfig, GamesConfig
from baselines.dt import DTConfig, DecisionTransformer

from learn.dataset import AtariReplayDatasetForCAStRL

from utils import common, logs_handler, ale_env, misc
from utils.replay_atari_data import ALE_DATA, to_ale_name

# ---------------------------------------------------------------------------------------------------------------------

logger = logs_handler.get_logger('run')

# ---------------------------------------------------------------------------------------------------------------------

# ALE names of every game present in the replay data.
# NOTE(review): `games` is not referenced anywhere else in this file; the run is driven by `games_cfg` below.
games = [to_ale_name(key) for key in ALE_DATA.keys()]
games_cfg = GamesConfig(envs=['pong'])

# ---------------------------------------------------------------------------------------------------------------------
# Optional checkpoint to warm-start from; None means training starts from the freshly built model.
pretrained_ckpt_path = None

# NOTE(review): `pretrained` is not read anywhere else in this file; the checkpoint is loaded
# based on `cfg.pretrained_ckpt_path` further down.
pretrained = pretrained_ckpt_path is not None

cfg = AtariReplayDataExperimentConfig(
    wandb_project='baselines-atari',
    name=f'dt-vswin-{games_cfg.envs[0]}',
    seed=123,
    debug=False,
    always_ready=False,
    use_strl=False,
    low_contrast_mode=None,
    batch_size=128,
    games_cfg=games_cfg,
    num_steps=500000,
    pretrained_ckpt_path=pretrained_ckpt_path,
    seq_len=16,
)

# NOTE(review): out_dir / out_prefix / out_suffix are read before `cfg.ready()` is called below —
# confirm they already hold their final values at this point.
ckpt_path = os.path.join(cfg.out_dir, f'{cfg.out_prefix}_ckpt_{cfg.out_suffix}.pth')

state_encoder_cfg = StateEncoderConfig(embed_dim=24, causal=True)

# NOTE(review): the positional args are presumably the action-vocabulary size (18, matching
# unknown_action=18) and the token block size (3 tokens per timestep * seq_len) — confirm against DTConfig.
dt_cfg = DTConfig(
    18,
    3 * cfg.seq_len,
    n_layer=6,
    n_head=8,
    n_embd=192,
    model_type='naive',
    unknown_action=18,
    action_discrete=True,
    state_encoder_cfg=state_encoder_cfg.get(),
    num_channels=1,
    state_encoder_type='transformer',
)

if cfg.debug:
    # Short smoke-test run: 2 epochs capped at 500 batches, ckpt_path=None, and only 2 eval trials.
    train_cfg = TrainConfig(num_epochs=2, max_num_batches=500, ckpt_path=None)
    train_cfg.eval_cfg.trials = 2
else:
    train_cfg = TrainConfig(num_epochs=5, max_num_batches=None, ckpt_path=ckpt_path)
    train_cfg.eval_cfg.trials = 10

# Attach the sub-configs to the experiment config; train_cfg is snapshotted via .get()
# only after the eval trial count has been set.
cfg.state_encoder_cfg = state_encoder_cfg
cfg.train_cfg = train_cfg.get()

cfg.ready(init_wandb=True)

# ---------------------------------------------------------------------------------------------------------------------

logger.info(f'Num CUDA Devices: {torch.cuda.device_count()}')

# Drop unreferenced Python objects, then release PyTorch's cached-but-unused CUDA memory,
# before the dataset and model are allocated.
gc.collect()
torch.cuda.empty_cache()

# ---------------------------------------------------------------------------------------------------------------------

dataset = AtariReplayDatasetForCAStRL.from_config(cfg)

# ---------------------------------------------------------------------------------------------------------------------

model = DecisionTransformer(dt_cfg)

logger.info(f'No. Parameters = {common.get_num_parameters(model) / 1e6:.2f}M')

# ---------------------------------------------------------------------------------------------------------------------

# NOTE(review): num_iterations is a fixed 2000 while cfg.num_steps is 500000. Confirm this is the
# intended scheduler horizon (e.g. a warm-up length) and not meant to track the total training steps.
optimizer, scheduler = misc.configure_adamw_optimizer(
    model,
    weight_decay=0.1,
    num_iterations=2000,
)

# ---------------------------------------------------------------------------------------------------------------------

# Evaluation environments and the evaluator are only built when in-training evaluation is enabled;
# otherwise the trainer receives None for both.
# The environments are closed in `finally`, so they are released even if env/evaluator setup,
# checkpoint loading or training raises.
envs, evaluator = None, None
try:
    if train_cfg.eval_cfg.enabled:
        envs = []
        for game in cfg.games:
            # Append each env as soon as it exists, so a failure on a later game still closes the earlier ones.
            envs.append(ale_env.Env(game=game, img_size=cfg.image_size, grayscale=True,
                                    clip_reward=False, buffer_size=cfg.stack_size, seed=cfg.seed))

        evaluator = AtariEvaluator(stage='inference', history_length=cfg.seq_len, top_k=train_cfg.eval_cfg.top_k)
        evaluator.set_device(cfg.device)
        evaluator.set_model(model)

    trainer = Trainer(stage='train', cfg=cfg, model=model, evaluator=evaluator, envs=envs,
                      optimizer=optimizer, scheduler=scheduler)

    if cfg.pretrained_ckpt_path is not None:
        # NOTE(review): strict=False / warn=True presumably tolerate mismatched state-dict keys and
        # warn about them — confirm against Trainer.load_checkpoint.
        trainer.load_checkpoint(path=cfg.pretrained_ckpt_path, strict=False, warn=True)
        logger.info(f'ckpt "{cfg.pretrained_ckpt_path}" loaded!')

    trainer.train(dataset, **train_cfg.get())
finally:
    # envs is None when evaluation is disabled (or setup failed before the list was created).
    for env in envs or []:
        env.close()

# Only reached when everything above completed without raising; on failure the exception
# propagates and misc.wandb_finish() is not called.
misc.wandb_finish()

# ---------------------------------------------------------------------------------------------------------------------
