import dataclasses
import re
from pathlib import Path
from typing import Any
import torch

import hydra
from omegaconf import OmegaConf

from common import MODEL_SIZE, TASK_SET


def cfg_to_dataclass(cfg, frozen=False):
    """
    Builds a dataclass instance mirroring the top-level keys of an OmegaConf config.
    Plain dataclass attributes prevent graph breaks when used with torch.compile.

    Every top-level key becomes a field typed `Any` whose default is the
    corresponding container value. The generated class also gets a dict-style
    `get(name, default=None)` accessor.
    """
    def _constant(value):
        # Factory that hands back the captured value; avoids late-binding issues.
        return lambda: value

    def get(self, val, default=None):
        return getattr(self, val, default)

    container = OmegaConf.to_container(cfg)
    field_specs = [
        (name, Any, dataclasses.field(default_factory=_constant(value)))
        for name, value in container.items()
    ]
    config_cls = dataclasses.make_dataclass("Config", field_specs, frozen=frozen)
    # Attached after class creation so that `get` is not treated as a field.
    config_cls.get = get
    return config_cls()


def parse_cfg(cfg: OmegaConf) -> Any:
    """
    Parses a Hydra config. Mostly for convenience.

    The config is mutated in place (simple arithmetic strings are evaluated and
    derived fields are filled in) and is then converted to a dataclass instance
    via `cfg_to_dataclass`.

    Args:
        cfg: Hydra/OmegaConf config. The keys `task`, `seed`, `exp_name`,
            `vmin`, `vmax` and `num_bins` are read unconditionally; `obs` is
            read only when `compile` is truthy. `model_size`, `compile` and
            `student_cfg` are optional.

    Returns:
        A dataclass instance holding the parsed config. If the config has a
        non-null `student_cfg`, it is converted to a nested dataclass that
        inherits `latent_dim`, `multitask`, `tasks` and `task_dim` from the
        parent config.

    Raises:
        AssertionError: If `model_size` is not a key of `MODEL_SIZE`, or if
            `compile` is enabled together with non-state observations or
            multitask training.
    """

    # NOTE: null-valued keys are intentionally left as-is. An earlier pass here
    # read each value and rebound a local to True without ever writing it back
    # to the config, so it had no effect and has been dropped.

    # Algebraic expressions: turn strings such as "3*4" into numbers.
    # fullmatch (not match) so that strings which merely *start* with an
    # expression, e.g. "10/2.5" or "3*4abc", are left untouched instead of
    # being silently truncated and evaluated.
    arithmetic = re.compile(r"(\d+)([+\-*/])(\d+)")
    for k in cfg.keys():
        try:
            v = cfg[k]
            if isinstance(v, str):
                match = arithmetic.fullmatch(v)
                if match:
                    # NOTE(security): eval is only ever given the fully matched
                    # "<digits><op><digits>" text, never arbitrary config input.
                    result = eval(match.group(1) + match.group(2) + match.group(3))
                    if isinstance(result, float) and result.is_integer():
                        result = int(result)
                    cfg[k] = result
        except Exception:
            # Best effort: keys that cannot be read (e.g. mandatory values that
            # are still missing) or evaluated (e.g. division by zero) are
            # skipped and keep their original value.
            pass

    # Convenience
    cfg.work_dir = Path(hydra.utils.get_original_cwd()) / 'logs' / cfg.task / str(cfg.seed) / cfg.exp_name
    cfg.task_title = cfg.task.replace("-", " ").title()
    cfg.bin_size = (cfg.vmax - cfg.vmin) / (cfg.num_bins-1) if cfg.num_bins > 1 else 1e6 # Bin size for discrete regression

    # Model size: copy the preset hyperparameters for the chosen size into cfg
    if cfg.get('model_size', None) is not None:
        assert cfg.model_size in MODEL_SIZE.keys(), \
            f'Invalid model size {cfg.model_size}. Must be one of {list(MODEL_SIZE.keys())}'
        for k, v in MODEL_SIZE[cfg.model_size].items():
            cfg[k] = v
        if cfg.task == 'mt30' and cfg.model_size == 19:
            cfg.latent_dim = 512 # This checkpoint is slightly smaller

    # Multi-task
    cfg.multitask = cfg.task in TASK_SET.keys()
    if cfg.multitask:
        cfg.task_title = cfg.task.upper()
        # Account for slight inconsistency in task_dim for the mt30 experiments
        cfg.task_dim = 96 if cfg.task == 'mt80' or cfg.get('model_size', 5) in {1, 317} else 64
    else:
        cfg.task_dim = 0
    cfg.tasks = TASK_SET.get(cfg.task, [cfg.task])

    # Check torch.compile compatibility
    if cfg.get('compile', False):
        assert cfg.obs == 'state', 'torch.compile only supports state observations at the moment.'
        assert not cfg.multitask, 'torch.compile does not support multitask training at the moment.'

    # Use the GPU when one is accessible, otherwise fall back to the CPU
    cfg.device = "cuda" if torch.cuda.is_available() else "cpu"

    # Convert to a dataclass, including the nested student model config if present
    cfg = cfg_to_dataclass(cfg)
    # A null student_cfg is treated like a missing one; converting None would
    # fail inside cfg_to_dataclass.
    if getattr(cfg, 'student_cfg', None) is not None:
        cfg.student_cfg = cfg_to_dataclass(OmegaConf.create(cfg.student_cfg))
        cfg.student_cfg.latent_dim = cfg.latent_dim
        cfg.student_cfg.multitask = cfg.multitask
        cfg.student_cfg.tasks = cfg.tasks
        cfg.student_cfg.task_dim = cfg.task_dim
    return cfg
