import random
from typing import List, Literal, Optional

import numpy as np
import torch
from omegaconf import DictConfig


# Ordered (substring, state_dim, action_dim) rules. The first substring found
# in an env id wins, so the order of the entries is significant.
_ENV_DIMS = (
    ("ant", 29, 8),
    ("point", 6, 2),
    ("maze2d", 4, 2),
)

# Ordered (substring, number of task ids) rules matched against the source
# env id. Task ids run from 1 to the listed count, inclusive.
_MAZE_TASK_COUNTS = (
    ("umaze", 7),
    ("medium", 26),
    ("large", 46),
)


def process_args(
    args: DictConfig,
    phase: Literal["align", "adapt"] = "align",
    inference_task_ids: Optional[List[int]] = None,
) -> DictConfig:
    """Seed the RNGs and fill in the derived fields of ``args`` in place.

    The function draws a fresh seed (overwriting any ``args.seed`` already
    present) and seeds ``random``, ``numpy`` and ``torch`` with it. It then
    derives the per-domain state/action dimensions from the env ids, selects
    the task ids for the requested phase, and finally delegates to
    ``configure_model_params`` to size the models.

    Args:
        args: Configuration providing ``source_env_id`` and ``target_env_id``
            (plus whatever ``configure_model_params`` reads). Mutated in place.
        phase: ``"align"`` stores every task id except the inference ones in
            ``args.task_ids``; ``"adapt"`` stores the inference ids themselves.
        inference_task_ids: Task ids held out from the ``"align"`` phase.
            Defaults to ``[7]`` when omitted.

    Returns:
        The value returned by ``configure_model_params(args)``.

    Raises:
        ValueError: If ``phase`` is unknown, an env id matches none of the
            known robot or maze-size substrings, the two env ids disagree on
            their second ``-``-separated token, or an inference task id is
            not a valid task id for the maze.
    """
    if phase not in ("align", "adapt"):
        raise ValueError(
            f"Unrecognized phase: {phase!r} (expected 'align' or 'adapt')"
        )

    # Copy so the caller's list (or the default) is never aliased.
    inference_task_ids = (
        [7] if inference_task_ids is None else list(inference_task_ids)
    )

    # NOTE: the seed is always redrawn here, replacing any existing args.seed.
    args.seed = np.random.randint(2**31)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)

    for domain in ("source", "target"):
        env_id = args[domain + "_env_id"]
        for marker, state_dim, action_dim in _ENV_DIMS:
            if marker in env_id:
                args[domain + "_state_dim"] = state_dim
                args[domain + "_action_dim"] = action_dim
                break
        else:
            raise ValueError(f"Unrecognized env_id: {env_id}")

    # Both env ids must carry the same second "-"-separated token.
    source_parts = args.source_env_id.split("-")
    target_parts = args.target_env_id.split("-")
    if (
        len(source_parts) < 2
        or len(target_parts) < 2
        or source_parts[1] != target_parts[1]
    ):
        raise ValueError(
            "source and target env ids must share the same second "
            f"'-'-separated token, got {args.source_env_id!r} and "
            f"{args.target_env_id!r}"
        )

    for marker, num_tasks in _MAZE_TASK_COUNTS:
        if marker in args.source_env_id:
            all_task_ids = list(range(1, num_tasks + 1))
            break
    else:
        # Report the id that was actually inspected (the source one).
        raise ValueError(f"Unrecognized env_id: {args.source_env_id}")

    unknown_ids = [t for t in inference_task_ids if t not in all_task_ids]
    if unknown_ids:
        raise ValueError(
            f"inference_task_ids {unknown_ids} are not valid task ids for "
            f"{args.source_env_id!r} (expected values in 1..{len(all_task_ids)})"
        )

    held_out = set(inference_task_ids)
    proxy_task_ids = [t for t in all_task_ids if t not in held_out]

    args.task_ids = proxy_task_ids if phase == "align" else inference_task_ids
    args.num_task_ids = len(all_task_ids)

    return configure_model_params(args)


def configure_model_params(args):
    """Write the layer sizes of every entry in ``args.models`` in place.

    For each of the six mapping/policy/dynamics models, ``in_dim`` is set and
    the last element of ``hid_dims`` is overwritten (presumably the output
    layer width; confirm against the model builder). The discriminator only
    has its ``in_dim`` set, widened by ``args.num_task_ids`` when
    ``args.task_cond`` is truthy.

    Args:
        args: Configuration exposing ``num_task_ids``, ``task_cond``, the
            four ``{source,target}_{state,action}_dim`` values and ``models``.

    Returns:
        The same ``args`` object, after mutation.
    """
    n_tasks = args.num_task_ids
    models = args.models

    src_state, src_action = args.source_state_dim, args.source_action_dim
    tgt_state, tgt_action = args.target_state_dim, args.target_action_dim

    # (model name, value for in_dim, value for hid_dims[-1])
    layout = (
        ("source_policy", src_state + n_tasks, src_action),
        ("target_policy", tgt_state + n_tasks, tgt_action),
        ("state_map", tgt_state, src_state),
        ("action_map", src_action, tgt_action),
        ("inv_state_map", src_state, tgt_state),
        ("dynamics_model", tgt_state + tgt_action, tgt_state),
    )
    for name, in_dim, last_dim in layout:
        net = getattr(models, name)
        net.in_dim = in_dim
        net.hid_dims[-1] = last_dim

    # The discriminator takes two source states plus one source action, and
    # optionally the task vector; its hid_dims are left untouched.
    disc = models.discriminator
    disc.in_dim = src_state * 2 + src_action
    if args.task_cond:
        disc.in_dim += n_tasks

    return args