from typing import List

import torch
import torch.nn as nn


class DAILAgent(nn.Module):
    """Agent that acts in the target domain through a source-domain policy.

    ``forward`` maps a target observation into the source domain with
    ``state_map``, queries ``source_policy`` on that observation concatenated
    with ``task_ids``, and converts the resulting action back with
    ``action_map``. The remaining MLPs (``target_policy``, ``inv_state_map``,
    ``dynamics_model``, ``discriminator``) are constructed here but not used
    by ``forward``.
    """

    def __init__(self, args):
        super().__init__()
        self.cond_dim = args.num_task_ids

        # NOTE: submodules are created in a fixed order on purpose; the order
        # determines RNG consumption during weight init and the layout of the
        # state dict, so do not reorder these lines.
        model_cfg = args.models
        self.source_policy = MLP(model_cfg.source_policy)
        self.target_policy = MLP(model_cfg.target_policy)
        self.state_map = MLP(model_cfg.state_map)
        self.action_map = MLP(model_cfg.action_map)
        self.inv_state_map = MLP(model_cfg.inv_state_map)
        self.dynamics_model = MLP(model_cfg.dynamics_model)
        self.discriminator = MLP(model_cfg.discriminator)

        # Optional flag: when set, action_map also receives the raw target obs.
        self.decode_with_state = getattr(args, 'decode_with_state', False)

        if not args.image_observation:
            return

        # Imported lazily so non-image setups do not need the conv models.
        from common.ours.models import ConvDecoder, ConvEncoder

        def build_encoder():
            # Both domains use identically configured encoders.
            return ConvEncoder(
                in_channels=3,
                out_dim=args.image_state_dim,
                coord_conv=args.use_coord_conv,
                pretrained=args.pretrained,
            )

        self.source_image_encoder = build_encoder()
        self.target_image_encoder = build_encoder()
        # Alias of the target encoder (original note: "used in evaluate").
        self.image_encoder = self.target_image_encoder

        if not args.use_image_decoder:
            self.source_image_decoder = None
            self.target_image_decoder = None
        else:
            self.source_image_decoder = ConvDecoder(
                image_latent_dim=args.image_state_dim)
            self.target_image_decoder = ConvDecoder(
                image_latent_dim=args.image_state_dim)

    def forward(self, target_obs, task_ids):
        """Return the target-domain action for ``target_obs``.

        Args:
            target_obs: target-domain observation tensor.
            task_ids: task conditioning tensor, concatenated on the last dim
                with the mapped observation (presumably a one-hot of size
                ``cond_dim`` -- TODO confirm against callers).

        Returns:
            Output of ``action_map``.
        """
        mapped_obs = self.state_map(target_obs)
        policy_input = torch.cat((mapped_obs, task_ids), dim=-1)
        decoder_input = self.source_policy(policy_input)
        if self.decode_with_state:
            # Give the action decoder access to the unmapped target obs too.
            decoder_input = torch.cat((decoder_input, target_obs), dim=-1)
        return self.action_map(decoder_input)


class MLP(nn.Module):
    """Fully connected network built from a parameter mapping.

    ``params`` must provide:
        in_dim: size of the input features.
        hid_dims: output size of each Linear layer, in order (at least one).
        activations: one activation name per layer, resolved with
            ``get_activations``. Extra trailing entries are ignored.

    Every Linear layer is followed by its activation, including the last
    one; use ``"none"`` to get a linear output layer.
    """

    def __init__(self, params):
        super().__init__()

        hid_dims: List[int] = list(params["hid_dims"])
        activations: List[nn.Module] = get_activations(params["activations"])

        if not hid_dims:
            raise ValueError("MLP requires at least one entry in 'hid_dims'")
        if len(activations) < len(hid_dims):
            raise ValueError(
                f"MLP got {len(activations)} activation(s) for "
                f"{len(hid_dims)} layer(s); provide one per entry in 'hid_dims'"
            )

        # Input size of layer i is the output size of layer i-1.
        in_dims = [params["in_dim"]] + hid_dims[:-1]
        layers: List[nn.Module] = []
        # zip stops at len(hid_dims), so surplus activations are ignored.
        for d_in, d_out, act in zip(in_dims, hid_dims, activations):
            layers.extend([nn.Linear(d_in, d_out), act])

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stacked Linear/activation layers to ``x``."""
        return self.net(x)


# Maps config names to activation module classes; instantiated per call so
# every layer gets its own module instance. None and "none" mean no-op.
_ACTIVATION_FACTORIES = {
    "relu": nn.ReLU,
    "leaky_relu": nn.LeakyReLU,
    "sigmoid": nn.Sigmoid,
    "mish": nn.Mish,
    "tanh": nn.Tanh,
    "elu": nn.ELU,
    "gelu": nn.GELU,
    "none": nn.Identity,
    None: nn.Identity,
}


def get_activations(name_list: List[str]) -> List[nn.Module]:
    """Instantiate one activation module per name in ``name_list``.

    Args:
        name_list: activation names, e.g. ``["relu", "none"]``. ``None`` and
            ``"none"`` both produce ``nn.Identity``.

    Returns:
        A list of fresh module instances, aligned index-for-index with
        ``name_list``.

    Raises:
        ValueError: if a name is not recognized. Previously such names were
            only printed and dropped, which silently shifted every later
            activation onto the wrong layer.
    """
    activations_list = []
    for name in name_list:
        try:
            factory = _ACTIVATION_FACTORIES[name]
        except (KeyError, TypeError):  # TypeError: unhashable name
            raise ValueError(f"Unrecognized activation: {name!r}") from None
        activations_list.append(factory())
    return activations_list
