import os

from typing import Tuple, List, Union, Callable
from dataclasses import dataclass, field, asdict

import numpy as np
import wandb

import torch

from utils import helpers, common, logs_handler

# Module-level logger used by the config classes below; requested from the
# project's ``logs_handler`` helper under the name 'cfg'.
logger = logs_handler.get_logger('cfg')

@dataclass
class GamesConfig:
    """Lists of game environments plus the unlabeled-data switches.

    Every list default is produced by a factory, so each instance owns its
    own list objects.
    """
    # Games used by default.
    envs: List[str] = field(
        default_factory=lambda: [
            'breakout',
            'assault',
            'pong',
            'bowling',
            'qbert',
            'seaquest',
        ]
    )
    # Extra games reserved for evaluation; empty unless provided.
    eval_envs: List[str] = field(default_factory=list)
    # Games whose data is meant to be used without labels; empty unless provided.
    unlabeled_envs: List[str] = field(default_factory=list)
    # Ratio associated with the unlabeled games.
    unlabeled_ratio: float = 1.0
    # Toggle for the unlabeled-data handling.
    use_unlabeled: bool = False

    def get(self):
        """Return all fields as a plain ``dict`` (via ``dataclasses.asdict``)."""
        as_dict = asdict(self)
        return as_dict
    
@dataclass
class StRLConfig:
    """Hyper-parameters grouped under the StRL settings.

    NOTE(review): the exact role of each value is defined by the code that
    consumes this config, which is not part of this module.
    """
    # Window sizes.
    window: int = 8
    mask_window: int = 4
    # Scalar coefficients.
    gamma: float = 1.0
    beta: float = 0.1
    eta: float = 1.0
    # Name of the context variant to use.
    context_type: str = 'masked_state'

    def get(self):
        """Return all fields as a plain ``dict`` (via ``dataclasses.asdict``)."""
        as_dict = asdict(self)
        return as_dict

@dataclass
class StateEncoderConfig:
    """Architecture settings for the state encoder.

    NOTE(review): field meanings are inferred from their names only; the
    encoder that reads them lives outside this module.
    """
    embed_dim: int = 48
    # Window sizes along the temporal and spatial dimensions.
    temporal_window: int = 2
    spatial_window: Tuple[int, int] = (7, 7)
    # Patch settings.
    temporal_patches: int = 4
    patch_size: Tuple[int, int] = (4, 4)
    # The first entry defaults to None even though the annotation lists ints.
    pool_size: Tuple[int, int, int] = (None, 1, 1)
    # Per-stage values; built by a factory so instances never share lists.
    num_heads: List[int] = field(
        default_factory=lambda: [3, 6, 12, 24]
    )
    depths: List[int] = field(
        default_factory=lambda: [2, 2, 6, 2]
    )
    qkv_bias: bool = True
    drop_path_rate: float = 0.1
    causal: bool = True

    def get(self):
        """Return all fields as a plain ``dict`` (via ``dataclasses.asdict``)."""
        return asdict(self)

@dataclass
class ContextGPTConfig:
    """Size and dropout settings for the context GPT component.

    NOTE(review): field meanings are inferred from their names only; the
    model that reads them lives outside this module.
    """
    # Model size.
    num_layers: int = 6
    num_heads: int = 8
    # Dropout rates.
    attn_drop: float = 0.1
    resid_pdrop: float = 0.1
    mlp_ratio: float = 4.0

    def get(self):
        """Return all fields as a plain ``dict`` (via ``dataclasses.asdict``)."""
        return asdict(self)
    
@dataclass
class CAStRLConfig:
    """Top-level model configuration bundling three nested sub-configs.

    The nested ``state_encoder_cfg``, ``context_gpt_cfg`` and ``strl_cfg``
    are created fresh for every instance. ``get`` flattens everything,
    nested configs included, into plain dicts.
    """
    context_dim: int = 192
    num_channels: int = 1
    expander_dims: List[int] = field(
        default_factory=lambda: [1024, 1024, 1024]
    )
    # Both hidden-dimension lists default to None (produced by a factory).
    state_hidden_dims: List[int] = field(default_factory=lambda: None)
    action_hidden_dims: List[int] = field(default_factory=lambda: None)
    # Action-related settings.
    num_action_tokens: int = 18
    max_seq_len: int = 32
    unknown_action: Union[int, float] = 18
    action_discrete: bool = True
    actions_weights: List[float] = None
    use_actions: bool = False

    # Nested sub-configurations, each built with its own defaults.
    state_encoder_cfg: StateEncoderConfig = field(default_factory=StateEncoderConfig)
    context_gpt_cfg: ContextGPTConfig = field(default_factory=ContextGPTConfig)
    strl_cfg: StRLConfig = field(default_factory=StRLConfig)

    def get(self):
        """Return all fields, nested configs included, as a nested ``dict``."""
        return asdict(self)

@dataclass
class EvaluateConfig:
    """Settings that control evaluation runs.

    NOTE(review): how ``step``, ``trials`` and ``top_k`` are interpreted is
    decided by the evaluation code, which is not part of this module.
    """
    # Master switch and scheduling.
    enabled: bool = True
    step: int = 1
    include_first_step: bool = False
    # Number of trials and how their results are reduced.
    trials: int = 10
    reduce_method: Callable = np.mean
    top_k: int = None

    def get(self):
        """Return all fields as a plain ``dict`` (via ``dataclasses.asdict``)."""
        as_dict = asdict(self)
        return as_dict

@dataclass
class TrainConfig:
    """Training-loop settings together with the nested evaluation config.

    NOTE(review): the semantics of the individual knobs are defined by the
    trainer that consumes this config, which is not part of this module.
    """
    # Loop length and stepping.
    num_epochs: int = 10
    max_num_batches: int = None
    grad_accum_step: int = 1
    scheduler_step: int = 1
    max_grad_norm: float = 1.0
    # Freezing controls.
    freeze_at: int = None
    unfreeze_at: int = None
    freeze_all: bool = False
    unfreeze_all: bool = False
    # Output locations.
    csv_path: str = None
    ckpt_path: str = None
    # Monitored metric name and the callable stored as its mode.
    monitor: str = 'loss'
    mode: Callable = min
    # Fresh EvaluateConfig per instance.
    eval_cfg: EvaluateConfig = field(default_factory=EvaluateConfig)

    def get(self):
        """Return all fields, ``eval_cfg`` included, as a nested ``dict``."""
        as_dict = asdict(self)
        return as_dict

@dataclass
class AtariReplayDataExperimentConfig:
    """Experiment-level configuration for runs on Atari replay data.

    After construction, ``__post_init__``:

    * fills in ``out_prefix`` / ``out_suffix`` / ``out_dir`` defaults,
    * resolves ``pretrained_ckpt_path`` and ``cache_dir`` against ``out_dir``,
    * checks that ``working_dir`` and ``out_dir`` exist,
    * expands the per-game settings (``games``, ``num_steps``,
      ``replay_data_dir``, ``unlabeled_ratio``) from ``games_cfg``,
    * logs the resolved values, and
    * calls :meth:`ready` when ``always_ready`` is true.

    The ``games_cfg`` instance passed in is only read, never modified.
    """
    name: str
    debug: bool = False
    pretrained_ckpt_path: str = None
    wandb_group: str = None
    wandb_project: str = 'castrl-atari'
    working_dir: str = '/data/datasets'
    cache_dir: str = 'cache'
    # Replaced by a copy of ``games`` in __post_init__ whenever ``cache_dir``
    # is not None.
    cachefile_prefix: Union[str, List[str]] = None
    # Evaluated once, when the class is defined.
    num_workers: int = os.cpu_count()
    out_dir: str = None
    out_prefix: str = None
    out_suffix: str = None

    # Template, relative to ``working_dir``; ``{game_dir}`` is filled in by
    # ``get_replay_data_dir``.
    _replay_data_dir: str = 'atari/dqn/{game_dir}/2/replay_logs'

    games_cfg: GamesConfig = field(default_factory=lambda: GamesConfig())

    # Derived in __post_init__.
    games: List[str] = field(init=False)
    replay_data_dir: List[str] = field(init=False)

    # A single int is broadcast to one entry per labeled game.
    num_steps: Union[int, List[int]] = 500000
    start_buffer: int = 0
    num_buffers: int = 50
    # NOTE(review): any value passed here is overwritten in __post_init__ with
    # a per-dataset list built from ``games_cfg``.
    unlabeled_ratio: Union[float, List[float]] = None
    unknown_label: int = 18
    map_ale_actions: bool = True

    seed: int = 123
    data_ratio: List[float] = None
    image_size: Tuple[int, int] = (84, 84)
    stack_size: int = 4
    seq_len: int = 16
    overlap_ratios: Union[List[float], str] = None
    # NOTE(review): no annotation, so this is a plain class attribute rather
    # than a dataclass field -- it cannot be passed to the constructor and is
    # absent from ``get()``. Confirm whether that is intended.
    frame_rate = 1
    batch_size: int = 128
    use_strl: bool = False
    low_contrast_mode: str = None

    # Filled in by ``ready``.
    num_gpus: int = field(init=False)
    device: int = field(init=False)
    device_ids: List[int] = field(init=False)

    always_ready: bool = True

    def get_replay_data_dir(self, game):
        """Return the replay-log path (relative to ``working_dir``) for ``game``.

        The snake_case game name is turned into the directory name by
        capitalizing each ``_``-separated part and joining them, e.g.
        ``'space_invaders'`` -> ``'SpaceInvaders'``.
        """
        game_dir = ''.join(part.capitalize() for part in game.split('_'))
        return self._replay_data_dir.format(game_dir=game_dir)

    def __post_init__(self):
        """Resolve defaults, derive the per-game lists and log the result.

        Raises:
            AssertionError: if ``working_dir`` or ``out_dir`` does not exist
                (not checked when Python runs with assertions disabled).
        """
        self.out_prefix = self.out_prefix or self.name
        self.out_suffix = self.out_suffix or helpers.str_datetime()
        self.out_dir = self.out_dir or self.working_dir

        if self.pretrained_ckpt_path is not None:
            self.pretrained_ckpt_path = os.path.join(self.out_dir, self.pretrained_ckpt_path)

        assert os.path.exists(self.working_dir), f'WORKING_DIR - "{self.working_dir}" doesn\'t exist'
        assert os.path.exists(self.out_dir), f'OUT_DIR - "{self.out_dir}" doesn\'t exist'

        if self.debug:
            self.num_steps = 10

        games_cfg = self.games_cfg
        # Copy before extending: ``+=`` on the original list object would
        # modify ``games_cfg.envs`` in place, corrupting the nested config and
        # duplicating games whenever one GamesConfig is reused.
        labeled_games = list(games_cfg.envs)
        if not games_cfg.use_unlabeled:
            # Unlabeled mode is off: those games are handled like regular ones.
            labeled_games += games_cfg.unlabeled_envs
        if isinstance(self.num_steps, int):
            self.num_steps = [self.num_steps] * len(labeled_games)

        # One replay dir / ratio entry per dataset; regular games get no ratio.
        replay_dirs = [self.get_replay_data_dir(game) for game in labeled_games]
        self.unlabeled_ratio = [None] * len(labeled_games)
        if games_cfg.use_unlabeled:
            for game in games_cfg.unlabeled_envs:
                replay_dirs.append(self.get_replay_data_dir(game))
                self.unlabeled_ratio.append(games_cfg.unlabeled_ratio)

        # Evaluation games are listed in ``games`` but get no replay dir here.
        self.games = labeled_games + list(games_cfg.eval_envs)
        self.replay_data_dir = [
            os.path.join(self.working_dir, data_dir) for data_dir in replay_dirs
        ]

        if self.cache_dir is not None:
            self.cache_dir = os.path.join(self.out_dir, self.cache_dir)
            self.cachefile_prefix = self.games.copy()

        logger.info(f'DEBUG - {self.debug}')
        logger.info(f'NUM_WORKERS - {self.num_workers}')
        logger.info(f'WORKING_DIR - "{self.working_dir}"')
        logger.info(f'OUT_DIR - "{self.out_dir}"')
        logger.info(f'REPLAY_DATA_DIR - {self.replay_data_dir}')
        logger.info(f'GAMES - {self.games}')
        logger.info(f'NUM_STEPS: {self.num_steps}')
        if self.overlap_ratios:
            logger.info(f'OVERLAP_RATIOS: {self.overlap_ratios}')
        logger.info(f'CACHE_DIR - {self.cache_dir}')
        logger.info(f'PRETRAINED_CKPT_PATH - {self.pretrained_ckpt_path}')
        logger.info(f'OUT_PREFIX - "{self.out_prefix}" | OUT_SUFFIX - "{self.out_suffix}"')

        if self.always_ready:
            self.ready()

    def ready(self, init_wandb=True, rank_zero_only=True):
        """Seed the run, record the available devices and optionally start wandb.

        Args:
            init_wandb: when false, ``wandb.init`` is never called.
            rank_zero_only: when true, ``wandb.init`` is only called if
                ``helpers.is_rank_zero()`` is true. When false and no
                ``wandb_group`` was configured, the ``<prefix>_<suffix>``
                string is used as the group and the run name is left unset.

        wandb is also skipped entirely in ``debug`` mode.
        """
        if (not rank_zero_only) and (self.wandb_group is None):
            wandb_group = f'{self.out_prefix}_{self.out_suffix}'
            wandb_run_name = None
        else:
            wandb_group = self.wandb_group
            wandb_run_name = f'{self.out_prefix}_{self.out_suffix}'

        common.set_seed(self.seed)

        self.num_gpus = torch.cuda.device_count()
        self.device = common.get_cuda_current_device()
        self.device_ids = list(range(self.num_gpus))

        logger.info(f'Out Device Id: {self.device}')
        logger.info(f'Device Ids: {self.device_ids}')

        if init_wandb and (not self.debug) and \
            ((not rank_zero_only) or helpers.is_rank_zero()):
            wandb.init(
                project=self.wandb_project,
                group=wandb_group,
                name=wandb_run_name,
                config=self,
            )

    def get(self):
        """Return all dataclass fields, nested ``games_cfg`` included, as a dict."""
        return asdict(self)
