import random
from typing import List, Literal, Optional

import numpy as np
import torch
from omegaconf import DictConfig

# yapf: disable
# Per-maze-type translation between two task-id numbering schemes.
# Spot checks against the ``goal_to_id`` tables are consistent with a maze2d
# cell (r, c) corresponding to the point goal (4 * (c - 1), 4 * (r - 1)).
# Example for "umaze": maze2d id 2 is cell (1, 2); it maps to point id 3,
# which is goal (4, 0).
# NOTE(review): the direction (keys = maze2d ids, values = point ids) is
# inferred from the "m2p" name plus those spot checks; confirm with callers.
task_id_map_m2p = {
    'umaze': {1: 1, 2: 3, 3: 5, 4: 6, 5: 2, 6: 4, 7: 7},
    'medium': {1: 1, 2: 6, 3: 18, 4: 22, 5: 2, 6: 7, 7: 14, 8: 19, 9: 23, 10: 8, 11: 11, 12: 15, 13: 3, 14: 9, 15: 16, 16: 20, 17: 24, 18: 4, 19: 12, 20: 17, 21: 25, 22: 5, 23: 10, 24: 13, 25: 21, 26: 26},
    'large': {1: 1, 2: 7, 3: 12, 4: 14, 5: 22, 6: 29, 7: 31, 8: 37, 9: 41, 10: 2, 11: 15, 12: 23, 13: 32, 14: 42, 15: 3, 16: 8, 17: 13, 18: 16, 19: 20, 20: 24, 21: 33, 22: 38, 23: 43, 24: 4, 25: 25, 26: 44, 27: 5, 28: 9, 29: 17, 30: 26, 31: 30, 32: 34, 33: 39, 34: 45, 35: 10, 36: 18, 37: 27, 38: 35, 39: 6, 40: 11, 41: 19, 42: 21, 43: 28, 44: 36, 45: 40, 46: 46}
}

# Inverse of ``task_id_map_m2p`` for each maze type (values become keys and
# vice versa).
task_id_map_inv_m2p = {
    key: {v:k for k,v in val.items()} for key, val in task_id_map_m2p.items()
}

# Goal cell -> 1-based task id.
# The outer key is the env family and the inner key is the maze type
# ("umaze", "medium", "large"). ``goal_to_task_id`` snaps a goal position
# onto these keys as follows:
#   "maze2d": integer grid coordinates (goal rounded to the nearest integer).
#   "point":  coordinates that are multiples of 4 (goal rounded to the nearest
#             multiple of 4); this table is also used for "ant" env ids.
goal_to_id = {
    "maze2d": {
        "umaze": {(1, 1): 1, (1, 2): 2, (1, 3): 3, (2, 3): 4, (3, 1): 5, (3, 2): 6, (3, 3): 7,},
        "medium": {(1, 1):1, (1, 2):2, (1, 5):3, (1, 6):4, (2, 1):5, (2, 2):6, (2, 4):7, (2, 5):8, (2, 6):9, (3, 2):10, (3, 3):11, (3, 4):12, (4, 1):13, (4, 2):14, (4, 4):15, (4, 5):16, (4, 6):17, (5, 1):18, (5, 3):19, (5, 4):20, (5, 6):21, (6, 1):22, (6, 2):23, (6, 3):24, (6, 5):25, (6, 6):26,},
        "large": {(1, 1):1, (1, 2):2, (1, 3):3, (1, 4):4, (1, 6):5, (1, 7):6, (1, 8):7, (1, 9):8, (1, 10):9, (2, 1):10, (2, 4):11, (2, 6):12, (2, 8):13, (2, 10):14, (3, 1):15, (3, 2):16, (3, 3):17, (3, 4):18, (3, 5):19, (3, 6):20, (3, 8):21, (3, 9):22, (3, 10):23, (4, 1):24, (4, 6):25, (4, 10):26, (5, 1):27, (5, 2):28, (5, 4):29, (5, 6):30, (5, 7):31, (5, 8):32, (5, 9):33, (5, 10):34, (6, 2):35, (6, 4):36, (6, 6):37, (6, 8):38, (7, 1):39, (7, 2):40, (7, 4):41, (7, 5):42, (7, 6):43, (7, 8):44, (7, 9):45, (7, 10):46,},
    },
    "point": {
        "umaze": {(0, 0):1, (0, 8):2, (4, 0):3, (4, 8):4, (8, 0):5, (8, 4):6, (8, 8):7,},
        "medium": {(0, 0):1, (0, 4):2, (0, 12):3, (0, 16):4, (0, 20):5, (4, 0):6, (4, 4):7, (4, 8):8, (4, 12):9, (4, 20):10, (8, 8):11, (8, 16):12, (8, 20):13, (12, 4):14, (12, 8):15, (12, 12):16, (12, 16):17, (16, 0):18, (16, 4):19, (16, 12):20, (16, 20):21, (20, 0):22, (20, 4):23, (20, 12):24, (20, 16):25, (20, 20):26,},
        "large": {(0, 0):1, (0, 4):2, (0, 8):3, (0, 12):4, (0, 16):5, (0, 24):6, (4, 0):7, (4, 8):8, (4, 16):9, (4, 20):10, (4, 24):11, (8, 0):12, (8, 8):13, (12, 0):14, (12, 4):15, (12, 8):16, (12, 16):17, (12, 20):18, (12, 24):19, (16, 8):20, (16, 24):21, (20, 0):22, (20, 4):23, (20, 8):24, (20, 12):25, (20, 16):26, (20, 20):27, (20, 24):28, (24, 0):29, (24, 16):30, (28, 0):31, (28, 4):32, (28, 8):33, (28, 16):34, (28, 20):35, (28, 24):36, (32, 0):37, (32, 8):38, (32, 16):39, (32, 24):40, (36, 0):41, (36, 4):42, (36, 8):43, (36, 12):44, (36, 16):45, (36, 24):46,},
    }
}
# yapf: enable


def goal_to_task_id(goal, env_id):
    """Map a goal position to its 1-based task id using ``goal_to_id``.

    The goal is snapped onto the grid used by the lookup tables:

    * ``maze2d`` env ids: each coordinate is rounded half-up to an integer.
    * ``point`` / ``ant`` env ids: each coordinate is rounded half-up to the
      nearest multiple of 4. Both families share the ``"point"`` table.

    Args:
        goal: Array-like goal position; converted with ``np.array``.
        env_id: Environment name. It must contain ``"maze2d"``, ``"point"``
            or ``"ant"``, plus one of ``"umaze"``, ``"medium"``, ``"large"``.

    Returns:
        The integer task id.

    Raises:
        ValueError: If the env family or the maze type cannot be determined
            from ``env_id``.
        KeyError: If the snapped goal is not a known goal cell.
    """
    goal = np.array(goal)

    if "maze2d" in env_id:
        family = "maze2d"
        snapped = np.floor(goal + 0.5)
    elif "point" in env_id or "ant" in env_id:
        family = "point"
        # Point/ant goal cells lie on a grid of spacing 4.
        snapped = 4 * np.floor(goal / 4 + 0.5)
    else:
        raise ValueError(f"Unrecognized env_id: {env_id!r}")

    # Same precedence as checking umaze, then medium, then large.
    for maze_type in ("umaze", "medium", "large"):
        if maze_type in env_id:
            break
    else:
        raise ValueError(f"Cannot determine maze type from env_id: {env_id!r}")

    goal_tuple = tuple(int(v) for v in snapped.astype("int"))
    try:
        return goal_to_id[family][maze_type][goal_tuple]
    except KeyError:
        raise KeyError(
            f"Goal {goal.tolist()} (snapped to {goal_tuple}) is not a known "
            f"{family}/{maze_type} goal cell"
        ) from None


def distance(x, y):
    """Return the squared Euclidean distance between ``x`` and ``y`` along dim 1.

    The element-wise difference is taken in absolute value, squared and summed
    over dimension 1. For 2-D inputs this gives one value per row.
    """
    abs_diff = (x - y).abs()
    squared = abs_diff.square()
    return squared.sum(dim=1)


def calc_accuracy(domain_logits, domain_ids):
    """Return the fraction of rows whose confident prediction matches ``domain_ids``.

    A class counts as predicted for a row when its softmax probability over
    dim 1 is strictly greater than 0.5. The predicted mask is multiplied
    element-wise with ``domain_ids``; the sum of that product is then divided
    by the number of rows.

    NOTE(review): ``domain_ids`` presumably is one-hot with the same shape as
    ``domain_logits`` — confirm against callers.
    """
    probabilities = torch.softmax(domain_logits, dim=1)
    predicted = probabilities > 0.5
    n_correct = torch.sum(domain_ids * predicted)
    n_rows = len(domain_logits)
    return n_correct / n_rows


def calc_output_ratio(domain_logits):
    """Return the fraction of rows whose class-0 softmax probability exceeds 0.5.

    The softmax is taken over dim 1. A probability of exactly 0.5 does not count.
    """
    class0_prob = torch.softmax(domain_logits, dim=1)[:, 0]
    n_class0 = (class0_prob > 0.5).sum()
    return n_class0 / len(domain_logits)


# (substring of env_id, state_dim, action_dim). Entries are checked in order and
# the first match wins, so "ant" is tested before "point" and "maze2d".
_ENV_DIMS = (
    ("ant", 29, 8),
    ("point", 6, 2),
    ("maze2d", 4, 2),
)

# Number of task ids per maze layout; ids run from 1 to this value inclusive.
_NUM_TASK_IDS = {"umaze": 7, "medium": 26, "large": 46}


def _env_dims(env_id):
    """Return ``(state_dim, action_dim)`` for ``env_id``; raise ValueError if unknown."""
    for keyword, state_dim, action_dim in _ENV_DIMS:
        if keyword in env_id:
            return state_dim, action_dim
    raise ValueError(f"Unrecognized env_id: {env_id}")


def process_args(
    args: DictConfig,
    phase: Literal["align", "adapt"] = "align",
    inference_task_ids: Optional[List[int]] = None,
) -> DictConfig:
    """Fill in derived settings on ``args`` in place, reseed all RNGs, and return it.

    Reads ``source_env_id``, ``target_env_id`` and ``maze_type``. Writes
    ``seed``, ``{source,target}_state_dim``, ``{source,target}_action_dim``,
    ``policy.state_dim``, ``policy.out_dim``, ``all_task_ids``, ``task_ids``,
    ``num_task_ids`` and ``policy.cond_dim``.

    Args:
        args: Config to update; it must contain a nested ``policy`` node.
        phase: ``"align"`` sets ``task_ids`` to every task id except
            ``inference_task_ids``. ``"adapt"`` sets ``task_ids`` to
            ``inference_task_ids``.
        inference_task_ids: Task ids reserved for inference. Defaults to ``[7]``.

    Returns:
        The same ``args`` object.

    Raises:
        ValueError: On an unknown ``phase``, env id or ``maze_type``, or if an
            inference task id is out of range or repeated.
    """
    if inference_task_ids is None:
        inference_task_ids = [7]
    if phase not in ("align", "adapt"):
        raise ValueError(
            f"Unrecognized phase: {phase!r} (expected 'align' or 'adapt')"
        )

    # NOTE(review): a fresh seed is drawn on every call, so any ``seed`` already
    # present in ``args`` is overwritten.
    args.seed = int(np.random.randint(2**31))
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)

    for domain in ("source", "target"):
        state_dim, action_dim = _env_dims(args[domain + "_env_id"])
        args[domain + "_state_dim"] = state_dim
        args[domain + "_action_dim"] = action_dim

    # Policy input/output sizes are the maximum over the two domains.
    args.policy.state_dim = max(args.source_state_dim, args.target_state_dim)
    args.policy.out_dim = max(args.source_action_dim, args.target_action_dim)

    num_task_ids = _NUM_TASK_IDS.get(args.maze_type)
    if num_task_ids is None:
        raise ValueError(f"Unrecognized maze_type: {args.maze_type}")
    all_task_ids = list(range(1, num_task_ids + 1))
    args.all_task_ids = all_task_ids

    # Proxy tasks: every task id except the inference ones.
    proxy_task_ids = list(all_task_ids)
    for task_id in inference_task_ids:
        try:
            proxy_task_ids.remove(task_id)
        except ValueError:
            raise ValueError(
                f"Inference task id {task_id} is out of range or repeated for "
                f"maze_type {args.maze_type!r} (valid ids: 1..{num_task_ids})"
            ) from None

    if phase == "align":
        args.task_ids = proxy_task_ids
    else:
        args.task_ids = list(inference_task_ids)
    args.num_task_ids = num_task_ids
    args.policy.cond_dim = num_task_ids

    return args


def get_success(obs, target, env_id):
    """Return whether ``obs[:2]`` lies within the success radius of ``target``.

    The radius is 0.1 for ``maze2d`` env ids and 1.2 for all other env ids.
    Distance is the Euclidean norm of ``obs[:2] - target``.
    NOTE(review): ``obs[:2]`` presumably holds the agent's (x, y) position.
    """
    radius = 0.1 if "maze2d" in env_id else 1.2
    offset = obs[:2] - target
    return np.linalg.norm(offset) <= radius