# Run script: builds a CAStRL model and runs the Trainer in its 'finetune_bc' stage
# on an Atari replay dataset.
import sys, gc, os

# Put the current working directory on the import path so the `utils` and `learn`
# imports below can resolve (assumes the script is launched from the directory that
# contains those packages -- TODO confirm).
sys.path.append('.')

from utils import helpers

# NOTE: this call intentionally sits between the imports, before `import torch`.
# Presumably it sets CUDA_VISIBLE_DEVICES, which must happen before CUDA is
# initialised -- verify against utils.helpers before reordering.
helpers.set_cuda_visible_devices('0,1,2,3')

import torch

from learn.trainer import Trainer
from learn.evaluator import AtariEvaluator

from learn.config import CAStRLConfig, ContextGPTConfig, StRLConfig,\
    TrainConfig, AtariReplayDataExperimentConfig, GamesConfig, StateEncoderConfig, EvaluateConfig
from learn.castrl import CAStRL

from learn.dataset import AtariReplayDatasetForCAStRL

from utils import common, logs_handler, ale_env, misc
from utils.replay_atari_data import ALE_DATA, to_ale_name

# ---------------------------------------------------------------------------------------------------------------------

# Logger used for every message emitted by this script.
logger = logs_handler.get_logger('run')

# ---------------------------------------------------------------------------------------------------------------------

# Every key of ALE_DATA converted with to_ale_name; kept for reference, while the
# run itself is restricted to the single game listed in `games_cfg`.
games = [to_ale_name(data_key) for data_key in ALE_DATA.keys()]
games_cfg = GamesConfig(envs=['qbert'])

# ---------------------------------------------------------------------------------------------------------------------
# Checkpoint to start from. Setting this to None disables both the checkpoint
# load and the freeze/unfreeze schedule further down.
pretrained_ckpt_path = 'castrl-pretrain_ckpt_08-31-2023_02-36-44-PM.pth'
pretrained = pretrained_ckpt_path is not None

# The run is named after the first (here: only) configured game.
experiment_name = f'castrl-{games_cfg.envs[0]}'

# Top-level experiment settings.
cfg = AtariReplayDataExperimentConfig(
    wandb_project='benchmarks-atari',
    name=experiment_name,
    seed=123,
    debug=False,
    always_ready=False,
    use_strl=False,
    low_contrast_mode=None,
    batch_size=128,
    games_cfg=games_cfg,
    num_steps=500000,
    pretrained_ckpt_path=pretrained_ckpt_path,
    seq_len=16,
)

# Output artefacts (loss CSV and model checkpoint) live in the run's output
# directory and share the config's prefix/suffix naming.
csv_path = os.path.join(
    cfg.out_dir, f'{cfg.out_prefix}_loss_{cfg.out_suffix}.csv')
ckpt_path = os.path.join(
    cfg.out_dir, f'{cfg.out_prefix}_ckpt_{cfg.out_suffix}.pth')

# Model sub-configs, assembled bottom-up into the CAStRL config.
state_encoder_cfg = StateEncoderConfig(embed_dim=96)
context_gpt_cfg = ContextGPTConfig(num_layers=6)
strl_cfg = StRLConfig(context_type='next_state')
castrl_cfg = CAStRLConfig(
    context_dim=192,
    use_actions=True,
    expander_dims=[1024] * 3,
    state_encoder_cfg=state_encoder_cfg,
    context_gpt_cfg=context_gpt_cfg,
    strl_cfg=strl_cfg,
)

# Debug runs use a reduced evaluation/training budget and write no
# checkpoint or CSV; full runs use the real budget and output paths.
if cfg.debug:
    eval_trials, train_epochs, batch_limit = 2, 2, 500
    ckpt_out, csv_out = None, None
else:
    eval_trials, train_epochs, batch_limit = 16, 10, None
    ckpt_out, csv_out = ckpt_path, csv_path

# The freeze/unfreeze schedule is only set when starting from a pretrained
# checkpoint; otherwise both markers stay None.
if pretrained:
    freeze_marker, unfreeze_marker = 1, 2
else:
    freeze_marker, unfreeze_marker = None, None

eval_cfg = EvaluateConfig(
    trials=eval_trials,
    step=1,
    include_first_step=True,
)
train_cfg = TrainConfig(
    num_epochs=train_epochs,
    max_num_batches=batch_limit,
    ckpt_path=ckpt_out,
    csv_path=csv_out,
    eval_cfg=eval_cfg,
    freeze_at=freeze_marker,
    unfreeze_at=unfreeze_marker,
    freeze_all=True,
    unfreeze_all=True,
)

# Attach the model and training settings (as returned by each config's get())
# to the experiment config. This is done before cfg.ready() is called below.
cfg.castrl_cfg = castrl_cfg.get()
cfg.train_cfg = train_cfg.get()

# NOTE(review): presumably finalises the config and, with init_wandb=True, starts
# the Weights & Biases run -- confirm in learn.config.
cfg.ready(init_wandb=True)

# ---------------------------------------------------------------------------------------------------------------------

# Record how many CUDA devices PyTorch can see.
logger.info(f'Num CUDA Devices: {torch.cuda.device_count()}')

# Run a garbage-collection pass and release PyTorch's unused cached CUDA memory
# before the dataset and model are created.
gc.collect()
torch.cuda.empty_cache()

# ---------------------------------------------------------------------------------------------------------------------

# Dataset is built entirely from the experiment config.
dataset = AtariReplayDatasetForCAStRL.from_config(cfg)

# ---------------------------------------------------------------------------------------------------------------------

# Optional action-token weighting, currently disabled:
# castrl_cfg.actions_weights = misc.get_tokens_weights(dataset.actions, num_tokens=castrl_cfg.num_action_tokens)
# logger.info(f'Actions Weights: {castrl_cfg.actions_weights}')

# Instantiate the model from the CAStRL config, with actions_pretrained=False.
model_kwargs = castrl_cfg.get()
model = CAStRL(actions_pretrained=False, **model_kwargs)

# Report the model size in millions of parameters.
num_params_millions = common.get_num_parameters(model) / 1e6
logger.info('No. Parameters = {:.2f}M'.format(num_params_millions))

# ---------------------------------------------------------------------------------------------------------------------

# AdamW optimizer and its scheduler, both produced by the project helper.
optimizer, scheduler = misc.configure_adamw_optimizer(
    model,
    weight_decay=0.1,
    num_iterations=2000,
)
    
# ---------------------------------------------------------------------------------------------------------------------

# Without evaluation the trainer receives neither an evaluator nor environments.
evaluator, envs = None, None

if train_cfg.eval_cfg.enabled:
    # One emulator environment per configured game.
    envs = [
        ale_env.Env(
            game=game,
            img_size=cfg.image_size,
            grayscale=True,
            clip_reward=False,
            buffer_size=cfg.stack_size,
            seed=cfg.seed,
        )
        for game in cfg.games
    ]

    # The evaluator is given the device and the model being trained.
    evaluator = AtariEvaluator(
        stage='evaluate_bc',
        history_length=cfg.seq_len,
        top_k=train_cfg.eval_cfg.top_k,
    )
    evaluator.set_device(cfg.device)
    evaluator.set_model(model)

trainer = Trainer(
    stage='finetune_bc',
    cfg=cfg,
    model=model,
    evaluator=evaluator,
    envs=envs,
    optimizer=optimizer,
    scheduler=scheduler,
)

# Load the optional checkpoint and run training. The emulator environments are
# closed in a `finally` clause so they are released even when loading or training
# raises; previously they were only closed after a successful run.
try:
    if cfg.pretrained_ckpt_path is not None:
        # Start from the configured checkpoint (loaded with strict=False, warn=True).
        trainer.load_checkpoint(path=cfg.pretrained_ckpt_path, strict=False, warn=True)
        logger.info(f'ckpt "{cfg.pretrained_ckpt_path}" loaded!')

    trainer.train(dataset, **train_cfg.get())
finally:
    # `envs` is None when evaluation is disabled, so there is nothing to close.
    if envs is not None:
        for env in envs:
            env.close()

# ---------------------------------------------------------------------------------------------------------------------

# Only reached when training completed without an exception, exactly as before.
misc.wandb_finish()

# ---------------------------------------------------------------------------------------------------------------------
