from dataclasses import dataclass
from types import MappingProxyType
from relign.helpers import update_dict
from lift.environments import (
    LightTunnel,
    LensPositioning,
    PositionOnly,
)
import d3rlpy


@dataclass(frozen=True)
class Config:
    """Static registry of environments, algorithms and default training args.

    All attributes below are un-annotated class attributes, so the
    ``dataclass`` decorator generates no fields for them; they are shared,
    class-level lookup tables wrapped in ``MappingProxyType`` so the top-level
    mappings cannot be reassigned by key. Note that the nested ``"args"``
    dicts are ordinary dicts and are therefore still mutable.
    """

    # Environment registry: short key -> environment class and its default
    # constructor keyword arguments.
    envs = MappingProxyType(
        {
            "lp": {
                "cls": LensPositioning,
                "args": {},
            },
            "po": {
                "cls": PositionOnly,
                "args": {
                    "n_actions": 2,
                    "noise_movement": 0.05,
                    "max_episode_steps": 50,
                },
            },
            # Previously registered under a second "lp" key, which silently
            # replaced the LensPositioning entry above (the last duplicate
            # key in a dict display wins). LightTunnel now has its own key.
            "lt": {
                "cls": LightTunnel,
                "args": {
                    "n_actions": 2,
                    "noise_movement": 0.05,
                    "max_episode_steps": 50,
                },
            },
        }
    )

    # Algorithm registry: name -> d3rlpy config class and its default
    # keyword arguments.
    algorithms = MappingProxyType(
        {
            "CQL": {
                "cls": d3rlpy.algos.CQLConfig,
                "args": {
                    "actor_learning_rate": 1e-3,
                    "critic_learning_rate": 1e-3,
                    "conservative_weight": 5,
                    "alpha_threshold": 10.0,
                },
            },
            "SAC": {
                "cls": d3rlpy.algos.SACConfig,
                "args": {},
            },
            "BC": {
                "cls": d3rlpy.algos.BCConfig,
                "args": {},
            },
        }
    )

    # Model arguments shared by every algorithm (currently none).
    default_model_args = MappingProxyType({})

    # Default training settings; "model" and "env" are keys into
    # ``algorithms`` and ``envs`` respectively.
    training_args = MappingProxyType(
        {
            "model": "SAC",
            "env": "po",
            "total_steps": int(2e4),
            "n_steps": 1000,
            "n_samples": 5000,
            "n_alignments_per_lens": 100,
            "n_lens_systems": 100,
            "max_transitions": -1,
            "online": False,
            "seed": None,
        }
    )

    @staticmethod
    def setup_params(params: dict) -> tuple:
        """Combine user ``params`` with the defaults registered on ``Config``.

        Args:
            params: User-supplied overrides. The keys ``"env"``, ``"model"``
                and ``"online"`` are optional; when absent, the defaults from
                ``Config.training_args`` are used.

        Returns:
            A 3-tuple ``(env_args, model_args, training_args)`` where
            ``model_args`` is ``default_model_args`` overlaid with the
            algorithm-specific arguments.

        Raises:
            KeyError: If the selected environment or model key is not
                registered in ``Config.envs`` / ``Config.algorithms``.
        """
        # Fall back to the registered defaults instead of raising KeyError
        # when the caller leaves these selectors out.
        env_key = params.get("env", Config.training_args["env"])
        model_key = params.get("model", Config.training_args["model"])
        online = params.get("online", Config.training_args["online"])

        # NOTE(review): update_dict is a project helper not visible here;
        # presumably it returns a new mutable dict of the defaults overridden
        # by matching entries in ``params`` -- confirm against relign.helpers.
        default_model_args = update_dict(Config.default_model_args, params)
        training_args = update_dict(Config.training_args, params)
        env_args = update_dict(Config.envs[env_key]["args"], params)
        model_args = update_dict(Config.algorithms[model_key]["args"], params)

        if online:
            # Online runs force the number of lens systems to zero
            # (presumably because no offline dataset is generated -- confirm).
            training_args["n_lens_systems"] = 0

        return env_args, {**default_model_args, **model_args}, training_args
