from typing import List

import torch
import torch.nn as nn


class DAILAgent(nn.Module):
    """Container module holding the seven MLP sub-networks of the agent.

    Each sub-network is an ``MLP`` configured from the attribute of
    ``args.models`` that carries the same name as the sub-network.
    Only ``state_map``, ``source_policy`` and ``action_map`` take part in
    ``forward``; the remaining networks are merely constructed here
    (presumably consumed by training code elsewhere -- verify against callers).
    """

    def __init__(self, args):
        """Instantiate every sub-network from ``args.models``.

        Args:
            args: configuration object exposing ``args.models.<name>`` for
                each sub-network name listed below.
        """
        super().__init__()
        # Construction order is kept fixed so that parameter registration
        # order (and therefore state_dict ordering) stays stable.
        for net_name in (
            "source_policy",
            "target_policy",
            "state_map",
            "action_map",
            "inv_state_map",
            "dynamics_model",
            "discriminator",
        ):
            setattr(self, net_name, MLP(getattr(args.models, net_name)))

    def forward(self, target_obs, task_ids):
        """Compute a target-domain action from a target-domain observation.

        The observation is passed through ``state_map``, concatenated with
        ``task_ids`` along the last dimension, fed to ``source_policy`` and
        the result is finally passed through ``action_map``.

        Args:
            target_obs: input tensor for ``state_map``.
            task_ids: tensor concatenated to the mapped observation on the
                last dimension (all other dimensions must match for
                ``torch.cat`` to succeed).

        Returns:
            The output of ``action_map``.
        """
        mapped_obs = self.state_map(target_obs)
        policy_input = torch.cat([mapped_obs, task_ids], dim=-1)
        return self.action_map(self.source_policy(policy_input))


class MLP(nn.Module):
    """Feed-forward stack of ``Linear`` layers, each followed by an activation."""

    def __init__(self, params):
        """Build the network from a parameter mapping.

        Args:
            params: mapping with the keys
                ``"in_dim"`` -- input width of the first linear layer,
                ``"hid_dims"`` -- output width of each linear layer, in
                order (the last entry is the output width of the network),
                ``"activations"`` -- activation names, resolved through
                ``get_activations``; entry ``i`` follows linear layer ``i``.
                Activations beyond ``len(hid_dims)`` are ignored.

        Raises:
            IndexError: if ``hid_dims`` is empty or fewer activations than
                layer widths are available (for list-like inputs).
        """
        super().__init__()

        widths: List[int] = params["hid_dims"]
        nonlinearities: List[nn.Module] = get_activations(params["activations"])

        modules = []
        fan_in = params["in_dim"]
        # At least one layer is always attempted, so an empty ``hid_dims``
        # fails on the index lookup instead of yielding an empty network.
        for idx in range(max(len(widths), 1)):
            modules.append(nn.Linear(fan_in, widths[idx]))
            modules.append(nonlinearities[idx])
            # The next layer consumes what this one produced.
            fan_in = widths[idx]

        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.net(x)


def get_activations(name_list: List[str]) -> List[nn.Module]:
    """Create one activation module per entry of ``name_list``.

    Recognised names are ``"relu"``, ``"leaky_relu"``, ``"sigmoid"``,
    ``"mish"`` and ``"none"``; ``None`` is accepted as a synonym of
    ``"none"``. Both map to ``nn.Identity``.

    Args:
        name_list: activation names, in the order they should be returned.

    Returns:
        A list with exactly one freshly constructed module per entry of
        ``name_list``, in the same order.

    Raises:
        ValueError: if an entry is not a recognised activation name.
            Skipping such an entry would shorten the result and shift every
            following activation onto the wrong position, so the problem is
            reported immediately instead.
    """
    # Map names to classes (not instances) so that every entry of
    # ``name_list`` gets its own module object.
    factories = {
        "relu": nn.ReLU,
        "leaky_relu": nn.LeakyReLU,
        "sigmoid": nn.Sigmoid,
        "mish": nn.Mish,
        "none": nn.Identity,
        None: nn.Identity,
    }

    activations_list = []
    for name in name_list:
        try:
            factory = factories[name]
        except (KeyError, TypeError):
            # TypeError covers unhashable entries, which cannot be valid names.
            known = [key for key in factories if key is not None]
            raise ValueError(
                f"Unrecognized activation: {name!r} (expected one of {known} or None)"
            ) from None
        activations_list.append(factory())
    return activations_list
