import sys, gc, os

# Add the current working directory to the module search path so that the
# project packages imported below (`utils`, `learn`) can be resolved.
# Presumably the script is meant to be launched from the repository root.
sys.path.append('.')

from utils import helpers

# NOTE: this call is placed before `import torch`. Presumably it restricts
# the process to GPUs 0-3 (e.g. via CUDA_VISIBLE_DEVICES), which would have to
# happen before CUDA is initialised -- confirm against `utils.helpers` before
# reordering these lines.
helpers.set_cuda_visible_devices('0,1,2,3')

import torch

from learn.trainer import Trainer

from learn.config import CAStRLConfig, ContextGPTConfig, StRLConfig,\
    TrainConfig, AtariReplayDataExperimentConfig, GamesConfig, StateEncoderConfig
from learn.castrl import CAStRL

from learn.dataset import AtariReplayDatasetForCAStRL

from utils import common, logs_handler, misc
from utils.replay_atari_data import ALE_DATA, to_ale_name


# ---------------------------------------------------------------------------------------------------------------------

# Script-wide logger, registered under the name 'run'.
logger = logs_handler.get_logger('run')

# ---------------------------------------------------------------------------------------------------------------------

# Convert every key of ALE_DATA with `to_ale_name`, then split the resulting
# list: the last five entries are passed as `eval_envs`, all the others as
# `envs`.
games = [to_ale_name(key) for key in ALE_DATA.keys()]
games_cfg = GamesConfig(
    envs=games[:-5],
    eval_envs=games[-5:],
)

logger.info(f'Test Envs: {games_cfg.eval_envs}')

# ---------------------------------------------------------------------------------------------------------------------
# Experiment-level settings. `pretrained_ckpt_path` is None here; when it is
# set to a path, the checkpoint is loaded into the trainer before training.
pretrained_ckpt_path = None
cfg = AtariReplayDataExperimentConfig(
    wandb_project='castrl-atari',
    name='castrl-pretrain',
    seed=123,
    debug=False,
    always_ready=False,
    use_strl=True,
    low_contrast_mode=None,
    batch_size=64,
    games_cfg=games_cfg,
    num_steps=10000,
    pretrained_ckpt_path=pretrained_ckpt_path,
    seq_len=16,
)

# Loss CSV and checkpoint file live in the experiment's output directory and
# share its prefix/suffix naming.
csv_path, ckpt_path = (
    os.path.join(cfg.out_dir, f'{cfg.out_prefix}_loss_{cfg.out_suffix}.csv'),
    os.path.join(cfg.out_dir, f'{cfg.out_prefix}_ckpt_{cfg.out_suffix}.pth'),
)

# Model configuration: the three sub-configs are built first and then handed
# to the top-level CAStRL config.
state_encoder_cfg = StateEncoderConfig(embed_dim=96)
context_gpt_cfg = ContextGPTConfig(num_layers=12)
strl_cfg = StRLConfig(context_type='next_state')
castrl_cfg = CAStRLConfig(
    context_dim=768,
    use_actions=False,
    expander_dims=[2048, 2048, 2048],
    state_encoder_cfg=state_encoder_cfg,
    context_gpt_cfg=context_gpt_cfg,
    strl_cfg=strl_cfg,
)

# Debug runs are short (2 epochs, at most 500 batches) and write neither a
# checkpoint nor a loss CSV; regular runs train for 10 epochs with no batch
# cap and write both files.
if cfg.debug:
    train_cfg = TrainConfig(
        num_epochs=2,
        max_num_batches=500,
        ckpt_path=None,
        csv_path=None,
    )
else:
    train_cfg = TrainConfig(
        num_epochs=10,
        max_num_batches=None,
        ckpt_path=ckpt_path,
        csv_path=csv_path,
    )
train_cfg.eval_cfg.enabled = False

# Attach the model and training settings to the experiment config, then
# finalise it (with W&B initialisation requested).
# NOTE(review): `castrl_cfg.actions_weights` is only assigned later, after the
# dataset has been built. Whether the value stored on `cfg` here reflects that
# later assignment depends on what `get()` returns -- confirm.
cfg.castrl_cfg = castrl_cfg.get()
cfg.train_cfg = train_cfg.get()

cfg.ready(init_wandb=True)

# ---------------------------------------------------------------------------------------------------------------------

# Run a garbage-collection pass and ask PyTorch to release its cached CUDA
# memory before the dataset and the model are created.
gc.collect()
torch.cuda.empty_cache()

# ---------------------------------------------------------------------------------------------------------------------

# Build the replay dataset from the experiment config.
dataset = AtariReplayDatasetForCAStRL.from_config(cfg)

# ---------------------------------------------------------------------------------------------------------------------

# Derive per-action-token weights from the dataset's actions and store them on
# the model config before the model is instantiated from it.
castrl_cfg.actions_weights = misc.get_tokens_weights(
    dataset.actions,
    num_tokens=castrl_cfg.num_action_tokens,
)
logger.info(f'Actions Weights: {castrl_cfg.actions_weights}')

model = CAStRL(**castrl_cfg.get())

# Report the model size in millions of parameters.
num_params_millions = common.get_num_parameters(model) / 1e6
logger.info('No. Parameters = {:.2f}M'.format(num_params_millions))

# ---------------------------------------------------------------------------------------------------------------------

# NOTE(review): `num_iterations=2000` differs from `num_steps=10000` in the
# experiment config; confirm the two are meant to be independent.
optimizer, scheduler = misc.configure_adamw_optimizer(
    model,
    weight_decay=0.1,
    num_iterations=2000,
)
    
# ---------------------------------------------------------------------------------------------------------------------
# No evaluator or environments are supplied to the trainer in this script.
evaluator = None
envs = None
trainer = Trainer(
    stage='pretrain',
    cfg=cfg,
    model=model,
    evaluator=evaluator,
    envs=envs,
    optimizer=optimizer,
    scheduler=scheduler,
)

# Optionally load an existing checkpoint; loading is non-strict with warnings
# enabled.
if cfg.pretrained_ckpt_path is not None:
    trainer.load_checkpoint(
        path=cfg.pretrained_ckpt_path,
        strict=False,
        warn=True,
    )
    logger.info(f'ckpt "{cfg.pretrained_ckpt_path}" loaded!')

# Train on the replay dataset using the training settings from `train_cfg`.
trainer.train(dataset, **train_cfg.get())

# ---------------------------------------------------------------------------------------------------------------------

# Close the W&B run once training has returned.
misc.wandb_finish()

# ---------------------------------------------------------------------------------------------------------------------
