# Training script: Starformer baseline on Atari replay data across several games.
# NOTE: the statements in this import section are order-sensitive; see comments below.
import sys, gc, os

# Append the current working directory to the import path so the local `utils`, `learn`
# and `baselines` packages below can be found; assumes the script is launched from the
# directory that contains them — TODO confirm.
sys.path.append('.')

from utils import helpers

# Called BEFORE `import torch`: presumably this exports CUDA_VISIBLE_DEVICES, which only
# takes effect if set before the CUDA runtime initialises — confirm in utils.helpers.
helpers.set_cuda_visible_devices('0,1')

import torch

from learn.trainer import Trainer
from learn.evaluator import AtariEvaluator

from learn.config import TrainConfig, AtariReplayDataExperimentConfig, GamesConfig
from baselines.starformer import StarformerConfig, Starformer

from learn.dataset import AtariReplayDatasetForCAStRL

from utils import common, logs_handler, ale_env, misc
from utils.replay_atari_data import ALE_DATA, to_ale_name

# ---------------------------------------------------------------------------------------------------------------------

logger = logs_handler.get_logger('run')

# ---------------------------------------------------------------------------------------------------------------------

# Map every replay-data game key to its ALE environment name, then train on the last five.
games = [to_ale_name(data_key) for data_key in ALE_DATA.keys()]
games_cfg = GamesConfig(envs=games[-5:])

# ---------------------------------------------------------------------------------------------------------------------

# Optional checkpoint to warm-start from; it is loaded further below (non-strict) when not None.
pretrained_ckpt_path = None

# Convenience flag mirroring the setting above; not read elsewhere in this script.
pretrained = pretrained_ckpt_path is not None

cfg = AtariReplayDataExperimentConfig(
    wandb_project='baselines-atari',
    name=f'star-multigame-eval-{len(games_cfg.envs)}',
    seed=123,
    debug=False,
    always_ready=False,
    use_strl=False,
    low_contrast_mode=None,
    batch_size=128,
    games_cfg=games_cfg,
    num_steps=100000,
    pretrained_ckpt_path=pretrained_ckpt_path,
    seq_len=16,
)

# Checkpoint file location: <out_dir>/<out_prefix>_ckpt_<out_suffix>.pth
ckpt_path = os.path.join(
    cfg.out_dir,
    f'{cfg.out_prefix}_ckpt_{cfg.out_suffix}.pth',
)

# Starformer on 4x84x84 inputs with 7x7 patches; 18 is given both positionally and as
# `unknown_action` (presumably the size of the full ALE action set — confirm).
star_cfg = StarformerConfig(
    18,
    img_size=(4, 84, 84),
    patch_size=(7, 7),
    N_head=8,
    D=192,
    local_N_head=4,
    local_D=96,
    model_type='star',
    n_layer=6,
    maxT=cfg.seq_len,
    unknown_action=18,
    action_discrete=True,
)

# Debug runs use a much shorter schedule, fewer eval trials, and pass no checkpoint path.
if cfg.debug:
    train_cfg = TrainConfig(num_epochs=2, max_num_batches=500, ckpt_path=None)
    train_cfg.eval_cfg.trials = 2
else:
    train_cfg = TrainConfig(num_epochs=10, max_num_batches=13500, ckpt_path=ckpt_path)
    train_cfg.eval_cfg.trials = 10

# Attach the sub-configs before finalising; init_wandb=True presumably starts the wandb run.
cfg.star_cfg = star_cfg
cfg.train_cfg = train_cfg.get()

cfg.ready(init_wandb=True)

# ---------------------------------------------------------------------------------------------------------------------

num_cuda_devices = torch.cuda.device_count()
logger.info(f'Num CUDA Devices: {num_cuda_devices}')

# Collect unreachable Python objects first, then ask PyTorch to release unused cached GPU memory.
gc.collect()
torch.cuda.empty_cache()

# ---------------------------------------------------------------------------------------------------------------------

dataset = AtariReplayDatasetForCAStRL.from_config(cfg)

# ---------------------------------------------------------------------------------------------------------------------

model = Starformer(star_cfg)

# Report model size in millions of parameters.
param_millions = common.get_num_parameters(model) / 1e6
logger.info(f'No. Parameters = {param_millions:.2f}M')

# ---------------------------------------------------------------------------------------------------------------------

# NOTE(review): the optimizer/scheduler helper is told num_iterations=2000, whereas train_cfg
# allows up to 10 epochs x 13500 batches — confirm what `num_iterations` controls in
# misc.configure_adamw_optimizer (total decay horizon vs. warmup).
optimizer, scheduler = misc.configure_adamw_optimizer(
    model,
    weight_decay=0.1,
    num_iterations=2000,
)
# ---------------------------------------------------------------------------------------------------------------------

# Evaluation environments and the evaluator are only built when in-training evaluation is
# enabled; otherwise the trainer receives evaluator=None and envs=None.
eval_enabled = train_cfg.eval_cfg.enabled
evaluator = None
envs = [] if eval_enabled else None

try:
    if eval_enabled:
        # One environment per configured game, all seeded with the experiment seed.
        for game in cfg.games:
            envs.append(ale_env.Env(game=game, img_size=cfg.image_size, grayscale=True,
                                    clip_reward=False, buffer_size=cfg.stack_size, seed=cfg.seed))

        evaluator = AtariEvaluator(stage='inference', history_length=cfg.seq_len, top_k=train_cfg.eval_cfg.top_k)
        evaluator.set_device(cfg.device)
        evaluator.set_model(model)

    trainer = Trainer(stage='train', cfg=cfg, model=model, evaluator=evaluator, envs=envs,
                      optimizer=optimizer, scheduler=scheduler)

    if cfg.pretrained_ckpt_path is not None:
        # strict=False / warn=True: presumably tolerates (and reports) key mismatches between
        # the checkpoint and this model — confirm in Trainer.load_checkpoint.
        trainer.load_checkpoint(path=cfg.pretrained_ckpt_path, strict=False, warn=True)
        logger.info(f'ckpt "{cfg.pretrained_ckpt_path}" loaded!')

    trainer.train(dataset, **train_cfg.get())
finally:
    # Close every environment that was actually created, even if env construction, trainer
    # setup, checkpoint loading or training raised part-way through.
    for env in envs or []:
        env.close()

# ---------------------------------------------------------------------------------------------------------------------

# Only reached when training completed without raising; a failed run propagates its exception
# instead of being marked as finished here.
misc.wandb_finish()

# ---------------------------------------------------------------------------------------------------------------------
