# Define the config for robomimic
import optax
from ml_collections import ConfigDict

from openx.data.datasets.bridge import bridge_dataset_transform
from openx.data.utils import NormalizationType, StateEncoding
from openx.networks.action_heads import L2ActionHead
from openx.networks.core import Model
from openx.networks.mlp import MLP, Concatenate
from openx.networks.resnet import ResNet18, ResNet34, ResNet50
from openx.utils.spec import ModuleSpec


def get_config(config_str: str = "state,2,50,None,True"):
    """Build the Bridge training config from a compact, comma-separated spec.

    Args:
        config_str: Five comma-separated fields,
            ``"<data_type>,<n_obs>,<resnet_class>,<num_kp>,<augs>"``:

            * ``data_type`` -- ``"state"`` or ``"img"``. ``"state"`` additionally
              routes ``observation->state`` into the model (mapped to a ``None``
              encoder); ``"img"`` uses only the image encoder.
            * ``n_obs`` -- ``"1"`` or ``"2"``; forwarded to the dataloader as ``n_obs``.
            * ``resnet_class`` -- ``"18"``, ``"34"`` or ``"50"``; selects the ResNet
              image encoder class.
            * ``num_kp`` -- ``"None"`` or one of 64/96/128/160/192/256. ``"None"``
              sets ``spatial_coordinates=True`` and ``num_kp=None`` on the encoder;
              otherwise the integer value is passed as ``num_kp``.
            * ``augs`` -- ``"True"`` or ``"False"``; ``"True"`` adds brightness,
              contrast, saturation and hue jitter on top of the scale /
              aspect-ratio augmentation that is always applied.

    Returns:
        A ``ConfigDict`` holding ``structure``, ``envs``, ``model``,
        ``dataloader``, ``optimizer``, ``lr_schedule`` and the training-loop
        parameters (steps and logging/validation/eval/save frequencies, seed).

    Raises:
        ValueError: if ``config_str`` does not split into exactly five fields.
        AssertionError: if any field is outside its allowed set.
    """
    data_type, n_obs, resnet_class, num_kp, augs = config_str.split(",")

    assert data_type in {"state", "img"}, f"data_type must be 'state' or 'img', got {data_type!r}"
    assert resnet_class in {"18", "34", "50"}, f"resnet_class must be 18, 34 or 50, got {resnet_class!r}"
    assert n_obs in {"1", "2"}, f"n_obs must be 1 or 2, got {n_obs!r}"
    assert num_kp in {"None", "64", "96", "128", "160", "192", "256"}, f"unsupported num_kp {num_kp!r}"
    assert augs in {"True", "False"}, f"augs must be 'True' or 'False', got {augs!r}"

    resnet_class = {"18": ResNet18, "34": ResNet34, "50": ResNet50}[resnet_class]

    structure = {
        "observation": {
            "image": {
                "agent": (224, 288),  # Height x width
            },
            "state": {
                StateEncoding.EE_POS: NormalizationType.NONE,
                StateEncoding.EE_EULER: NormalizationType.NONE,
                StateEncoding.GRIPPER: NormalizationType.NONE,
            },
        },
        "action": {
            "achieved_delta": {
                StateEncoding.EE_POS: NormalizationType.GAUSSIAN,
                StateEncoding.EE_EULER: NormalizationType.GAUSSIAN,
            },
            "desired_absolute": {StateEncoding.GRIPPER: NormalizationType.NONE},
        },
    }

    dataloader = dict(
        datasets=dict(
            bridge=dict(
                path="REDACTED/preprocessed/bridge_256x341/1.0.0",
                train_split="train",
                val_split="val",
                transform=ModuleSpec.create(bridge_dataset_transform),
            ),
        ),
        n_obs=int(n_obs),
        n_action=1,
        augment_kwargs=dict(
            scale_range=(0.8, 1.0),
            aspect_ratio_range=(1.12, 1.328),
            aligned=True,
            # Color jitter is only enabled when the ``augs`` field is "True".
            **(
                dict(
                    brightness=0.1,
                    contrast_range=[0.9, 1.1],
                    saturation_range=[0.9, 1.1],
                    hue=0.025,
                )
                if augs == "True"
                else dict()
            ),
        ),
        chunk_img=True,
        goal_conditioned=True,
        shuffle_size=250000,
        batch_size=384,
        recompute_statistics=False,
    )

    encoders = {
        # One encoder spec keyed on both the observation and goal images.
        # NOTE(review): presumably the weights are shared between the two; confirm against Model.
        "observation->image->agent,goal->image->agent": ModuleSpec.create(
            resnet_class,
            spatial_coordinates=num_kp == "None",
            act="swish",
            num_kp=None if num_kp == "None" else int(num_kp),
        ),
    }
    # Compare the parsed ``data_type`` field, not the whole ``config_str``: the full
    # spec (e.g. "state,2,50,None,True") never equals "state", so testing it would
    # silently drop the state input for every "state" config.
    if data_type == "state":
        encoders["observation->state"] = None

    model = ModuleSpec.create(
        Model,
        encoders=encoders,
        trunk=ModuleSpec.create(Concatenate, features=None, flatten_time=True),
        action_head=ModuleSpec.create(
            L2ActionHead,
            model=ModuleSpec.create(
                MLP, hidden_dims=(512, 512, 512), dropout_rate=None, activate_final=True, use_layer_norm=True
            ),
        ),
    )

    lr_schedule = ModuleSpec.create(
        optax.warmup_cosine_decay_schedule,
        init_value=1e-6,
        peak_value=2e-4,
        warmup_steps=1000,
        decay_steps=500000,
        end_value=1e-6,
    )
    optimizer = ModuleSpec.create(optax.adamw)

    envs = None
    return ConfigDict(
        dict(
            structure=structure,
            envs=envs,
            model=model,
            dataloader=dataloader,
            optimizer=optimizer,
            lr_schedule=lr_schedule,
            # Add training parameters
            steps=500000,
            log_freq=500,
            val_freq=2500,
            eval_freq=20000,
            save_freq=100000,
            val_steps=25,
            seed=0,
        )
    )
