import random
import subprocess
from typing import List, Literal, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from omegaconf import DictConfig

# yapf: disable
# Goal cell -> task id lookup, keyed first by environment family and then by
# maze layout ("umaze", "medium", "large").
#   * "maze2d": keys are integer (row, col)-style grid cells; get_task_id rounds
#     a continuous goal to the nearest integer before looking it up.
#   * "point": keys lie on a 4-unit grid (every coordinate is a multiple of 4);
#     get_task_id also uses this table for "ant" environments and snaps the goal
#     to the nearest multiple of 4 first.
# Within each layout, task ids are contiguous: 1..7 (umaze), 1..26 (medium),
# 1..46 (large).
goal_to_id = {
    "maze2d": {
        "umaze": {(1, 1): 1, (1, 2): 2, (1, 3): 3, (2, 3): 4, (3, 1): 5, (3, 2): 6, (3, 3): 7,},
        "medium": {(1, 1):1, (1, 2):2, (1, 5):3, (1, 6):4, (2, 1):5, (2, 2):6, (2, 4):7, (2, 5):8, (2, 6):9, (3, 2):10, (3, 3):11, (3, 4):12, (4, 1):13, (4, 2):14, (4, 4):15, (4, 5):16, (4, 6):17, (5, 1):18, (5, 3):19, (5, 4):20, (5, 6):21, (6, 1):22, (6, 2):23, (6, 3):24, (6, 5):25, (6, 6):26,},
        "large": {(1, 1):1, (1, 2):2, (1, 3):3, (1, 4):4, (1, 6):5, (1, 7):6, (1, 8):7, (1, 9):8, (1, 10):9, (2, 1):10, (2, 4):11, (2, 6):12, (2, 8):13, (2, 10):14, (3, 1):15, (3, 2):16, (3, 3):17, (3, 4):18, (3, 5):19, (3, 6):20, (3, 8):21, (3, 9):22, (3, 10):23, (4, 1):24, (4, 6):25, (4, 10):26, (5, 1):27, (5, 2):28, (5, 4):29, (5, 6):30, (5, 7):31, (5, 8):32, (5, 9):33, (5, 10):34, (6, 2):35, (6, 4):36, (6, 6):37, (6, 8):38, (7, 1):39, (7, 2):40, (7, 4):41, (7, 5):42, (7, 6):43, (7, 8):44, (7, 9):45, (7, 10):46,},
    },
    "point": {
        "umaze": {(0, 0):1, (0, 8):2, (4, 0):3, (4, 8):4, (8, 0):5, (8, 4):6, (8, 8):7,},
        "medium": {(0, 0):1, (0, 4):2, (0, 12):3, (0, 16):4, (0, 20):5, (4, 0):6, (4, 4):7, (4, 8):8, (4, 12):9, (4, 20):10, (8, 8):11, (8, 16):12, (8, 20):13, (12, 4):14, (12, 8):15, (12, 12):16, (12, 16):17, (16, 0):18, (16, 4):19, (16, 12):20, (16, 20):21, (20, 0):22, (20, 4):23, (20, 12):24, (20, 16):25, (20, 20):26,},
        "large": {(0, 0):1, (0, 4):2, (0, 8):3, (0, 12):4, (0, 16):5, (0, 24):6, (4, 0):7, (4, 8):8, (4, 16):9, (4, 20):10, (4, 24):11, (8, 0):12, (8, 8):13, (12, 0):14, (12, 4):15, (12, 8):16, (12, 16):17, (12, 20):18, (12, 24):19, (16, 8):20, (16, 24):21, (20, 0):22, (20, 4):23, (20, 8):24, (20, 12):25, (20, 16):26, (20, 20):27, (20, 24):28, (24, 0):29, (24, 16):30, (28, 0):31, (28, 4):32, (28, 8):33, (28, 16):34, (28, 20):35, (28, 24):36, (32, 0):37, (32, 8):38, (32, 16):39, (32, 24):40, (36, 0):41, (36, 4):42, (36, 8):43, (36, 12):44, (36, 16):45, (36, 24):46,},
    }
}
# yapf: enable


def get_task_id(target, env_id):
    """Map a goal position to its integer task id for the given environment.

    Args:
        target: Goal coordinates; anything ``np.array`` accepts (presumably an
            (x, y) pair -- TODO confirm against callers).
        env_id: Environment id whose second ``-``-separated field is the maze
            layout (``"umaze"``, ``"medium"`` or ``"large"``).

    Returns:
        The task id stored in ``goal_to_id`` for the snapped goal cell.

    Raises:
        ValueError: If ``env_id`` names none of the maze2d / point / ant
            families, or has no maze-layout field.
        KeyError: If the layout or the snapped goal is not in ``goal_to_id``.
    """
    # Decide the family before parsing the layout so that malformed ids reach
    # the ValueError below instead of an IndexError from the split.
    if "maze2d" in env_id:
        family = "maze2d"
    elif "point" in env_id or "ant" in env_id:
        family = "point"  # ant environments share the point-maze goal table
    else:
        raise ValueError("Unrecognized env_id " + env_id)

    parts = env_id.split("-")
    if len(parts) < 2:
        raise ValueError("Unrecognized env_id " + env_id)
    maze_type = parts[1]

    target = np.array(target)
    if family == "maze2d":
        # Round each coordinate to the nearest integer grid cell.
        cell = (target + 0.5).astype("int")
    else:
        # Snap each coordinate to the nearest multiple of 4 (the point table's grid).
        cell = ((target + 2) / 4).astype("int") * 4
    return goal_to_id[family][maze_type][tuple(cell)]


def get_activations(name_list: List[str]):
    """Build one activation module per entry of ``name_list``.

    Args:
        name_list: Activation names. Supported: ``"relu"``, ``"leaky_relu"``,
            ``"sigmoid"``, ``"tanh"``, and ``None`` / ``"none"`` for identity.

    Returns:
        A list of freshly constructed ``nn.Module`` instances, in the same
        order and of the same length as ``name_list``.

    Raises:
        ValueError: If a name is not recognized.
    """
    factories = {
        "relu": nn.ReLU,
        "leaky_relu": nn.LeakyReLU,
        "sigmoid": nn.Sigmoid,
        "tanh": nn.Tanh,
        "none": nn.Identity,
        None: nn.Identity,
    }
    activations_list = []
    for name in name_list:
        # Fail loudly: skipping an unknown name would shorten the result and
        # misalign it with any caller expecting one module per name.
        try:
            factory = factories[name]
        except (KeyError, TypeError):  # TypeError: unhashable name
            raise ValueError(f"Unrecognized activation: {name}") from None
        activations_list.append(factory())
    return activations_list


def process_args(
    args: DictConfig,
    phase: Literal["align", "adapt"] = "align",
    inference_task_ids: Optional[List[int]] = None,
) -> DictConfig:
    """Fill in derived config fields and seed Python, NumPy and torch RNGs.

    Mutates ``args`` in place and returns it. Sets ``seed``,
    ``<domain>_state_dim`` / ``<domain>_action_dim`` for the ``source`` and
    ``target`` domains, ``all_task_ids``, ``num_task_ids`` and ``task_ids``.

    Args:
        args: Config providing ``source_morph``, ``target_morph`` (one of
            ``"ant"``, ``"point"``, ``"maze2d"``) and ``maze_type`` (one of
            ``"umaze"``, ``"medium"``, ``"large"``).
        phase: ``"align"`` sets ``task_ids`` to every task except the
            inference tasks; ``"adapt"`` sets it to the inference tasks only.
        inference_task_ids: Held-out task ids; defaults to ``[7]``.

    Returns:
        The same ``args`` object.

    Raises:
        ValueError: On an unknown phase, morphology or maze type, or an
            inference task id outside the layout's task-id range.
    """
    if inference_task_ids is None:
        inference_task_ids = [7]
    if phase not in ("align", "adapt"):
        raise ValueError(f"Unrecognized phase: {phase}")

    # NOTE(review): a new seed is drawn from NumPy's global RNG on every call,
    # overwriting any seed already present in ``args`` -- confirm intended.
    args.seed = np.random.randint(2**31)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)

    # (state_dim, action_dim) for each supported agent morphology.
    morph_dims = {"ant": (29, 8), "point": (6, 2), "maze2d": (4, 2)}
    for domain in ("source", "target"):
        morph = args[domain + "_morph"]
        if morph not in morph_dims:
            raise ValueError(f"Unrecognized {domain}_morph: {morph}")
        state_dim, action_dim = morph_dims[morph]
        args[domain + "_state_dim"] = state_dim
        args[domain + "_action_dim"] = action_dim

    # Task ids are 1..N, one per goal cell of the layout (N matches the size
    # of the corresponding goal_to_id table).
    num_tasks_by_maze = {"umaze": 7, "medium": 26, "large": 46}
    if args.maze_type not in num_tasks_by_maze:
        raise ValueError(f"Unrecognized maze_type: {args.maze_type}")
    num_tasks = num_tasks_by_maze[args.maze_type]
    all_task_ids = list(range(1, num_tasks + 1))
    args.all_task_ids = all_task_ids
    args.num_task_ids = num_tasks

    # Validate up front so a bad id produces a descriptive error rather than
    # the bare ValueError from list.remove.
    invalid = [t for t in inference_task_ids if t not in all_task_ids]
    if invalid:
        raise ValueError(
            f"Inference task ids {invalid} are outside 1..{num_tasks} "
            f"for maze_type {args.maze_type}"
        )
    held_out = set(inference_task_ids)
    proxy_task_ids = [t for t in all_task_ids if t not in held_out]

    if phase == "align":
        args.task_ids = proxy_task_ids
    else:
        args.task_ids = list(inference_task_ids)

    return args


def calc_accuracy(logits, labels):
    """Return the fraction of thresholded predictions that agree with labels.

    Each element of ``logits`` is predicted positive when it exceeds 0.5.
    Agreement is scored as pred*label + (1-pred)*(1-label), summed over all
    elements, and divided by ``len(logits)`` (the size of the first dimension).

    NOTE(review): a 0.5 cut-off implies these are probabilities, not raw
    logits (whose natural threshold would be 0) -- confirm with callers.
    """
    predicted = (logits > 0.5).long()
    agreement = predicted * labels + (1 - predicted) * (1 - labels)
    total_agreeing = agreement.sum()
    return total_agreeing / len(logits)


def sigmoid_cross_entropy_with_logits(logits, labels):
    """Mean sigmoid cross-entropy between ``logits`` and ``labels``.

    Uses the numerically stable element-wise form
        max(x, 0) - x * z + log(1 + exp(-|x|))
    and averages it over every element.

    Args:
        logits: Tensor of raw scores, any shape.
        labels: Targets broadcastable against ``logits``.

    Returns:
        Scalar tensor holding the mean loss.
    """
    # clamp(min=0) is an element-wise max(x, 0) for any shape; the former
    # unsqueeze(1)/cat/max trick only reduced correctly for 1-D inputs.
    loss = (
        torch.clamp(logits, min=0)
        - logits * labels
        + torch.log1p(torch.exp(-torch.abs(logits)))
    )
    return torch.mean(loss)


def logsigmoid(a):
    """Element-wise log(sigmoid(a)), evaluated stably as -softplus(-a)."""
    negated = -a
    return -F.softplus(negated)


def logit_bernoulli_entropy(logits):
    """Mean entropy (nats) of Bernoulli distributions given by ``logits``.

    With p = sigmoid(x), the entropy -p*log(p) - (1-p)*log(1-p) simplifies to
    (1 - p) * x - log(sigmoid(x)), which is what is computed per element
    before averaging over all elements.
    """
    prob = torch.sigmoid(logits)
    per_element = (1. - prob) * logits - logsigmoid(logits)
    return per_element.mean()
