import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class ENC_Image(nn.Module):
    """Image encoder: a three-layer convolutional stack followed by an MLP.

    The convolutional output is flattened per sample and projected to a
    ``z_dim``-dimensional embedding.

    Args:
        in_dim: Base channel count; the first convolution takes
            ``in_dim * H`` input channels.
        z_dim: Size of the returned embedding.
        H: Multiplier applied to ``in_dim`` for the first convolution
            (NOTE(review): presumably the number of stacked frames - confirm
            against the caller).
        setting: Mapping providing ``"kernel"``, ``"stride"`` and
            ``"padding"`` (each indexed at 0, 1, 2 - one entry per
            convolution) and ``"n_selected"``. The MLP expects
            ``64 * n_selected`` flattened features, so ``n_selected`` must
            equal the number of spatial positions left after the conv stack.
    """

    def __init__(self, in_dim, z_dim, H, setting):
        super().__init__()

        kernels = setting["kernel"]
        strides = setting["stride"]
        pads = setting["padding"]
        n_selected = setting["n_selected"]
        h_dim = 64

        # (in_channels, out_channels) of each convolution, in order.
        channel_plan = [
            (in_dim * H, h_dim // 2),
            (h_dim // 2, h_dim),
            (h_dim, h_dim),
        ]
        conv_layers = []
        for idx, (c_in, c_out) in enumerate(channel_plan):
            if idx > 0:
                # ReLU sits between convolutions; none after the last one.
                conv_layers.append(nn.ReLU())
            conv_layers.append(
                nn.Conv2d(c_in, c_out, kernel_size=kernels[idx],
                          stride=strides[idx], padding=pads[idx])
            )
        self.nn_encoder_im = nn.Sequential(*conv_layers)

        # Consecutive widths of the fully connected head.
        fc_widths = [h_dim * n_selected, 256, 256, z_dim]
        fc_layers = []
        for idx, (d_in, d_out) in enumerate(zip(fc_widths[:-1], fc_widths[1:])):
            if idx > 0:
                fc_layers.append(nn.ReLU())
            fc_layers.append(nn.Linear(d_in, d_out))
        self.nn_encoder_fc = nn.Sequential(*fc_layers)

        self.h_dim = h_dim
        self.z_dim = z_dim
        # Integer square root of n_selected (presumably the side length of a
        # square feature map - TODO confirm).
        self.im_dim = int(math.sqrt(n_selected))

    def forward(self, x):
        """Encode ``x`` with the conv stack, flatten per sample, apply the MLP."""
        conv_out = self.nn_encoder_im(x)
        flat = conv_out.flatten(start_dim=1)
        return self.nn_encoder_fc(flat)
    

class ENC_Feature(nn.Module):
    """Feature-vector encoder: LayerNorm followed by a two-layer MLP.

    Args:
        in_dim: Size of the last dimension of the input features.
        z_dim: Size of the returned embedding.
    """

    def __init__(self, in_dim, z_dim):
        super().__init__()

        hidden = 256
        self.z_dim = z_dim
        layers = [
            nn.LayerNorm(in_dim),      # normalise over the feature dimension
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, z_dim),
        ]
        self.nn_encoder_fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the ``z_dim``-dimensional embedding of ``x``."""
        return self.nn_encoder_fc(x)


class Latent2Action(nn.Module):
    """Behaviour-cloning head that maps a latent state to an action.

    The movement part of the action is modelled as a diagonal Gaussian whose
    mean and (pre-softplus) standard deviation come from ``nn_bc``. When
    ``has_gripper`` is true, the last action component is treated as a binary
    class predicted by a separate two-logit classifier ``nn_grasp``.

    Args:
        in_dim: Size of the latent input.
        act_dim: Total action size, including the gripper component if any.
        has_gripper: Whether the last action component is a gripper class.
    """

    def __init__(self, in_dim, act_dim, has_gripper=False):
        super(Latent2Action, self).__init__()

        # The gripper component is handled by its own head, so it is excluded
        # from the Gaussian movement dimensions.
        move_dim = act_dim - 1 if has_gripper else act_dim

        # Outputs mean and raw std for every movement dimension.
        self.nn_bc = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, move_dim * 2),
        )

        if has_gripper:
            # Two logits: one per gripper class.
            self.nn_grasp = nn.Sequential(
                nn.LayerNorm(in_dim),
                nn.Linear(in_dim, 256),
                nn.ReLU(),
                nn.Linear(256, 256),
                nn.ReLU(),
                nn.Linear(256, 2),
            )
        else:
            self.nn_grasp = None

        self.move_dim = move_dim
        self.has_gripper = has_gripper
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x_t, a_t):
        """Compute the behaviour-cloning loss for a batch.

        Args:
            x_t: Latent states, indexed as ``(batch, in_dim)``.
            a_t: Target actions, indexed as ``(batch, act_dim)``. The first
                ``move_dim`` columns are the movement targets; with a gripper,
                the last column is cast to int64 and used as the class index
                (so it must hold 0 or 1).

        Returns:
            ``(bc_loss, [a_mse, g_mse])`` where ``bc_loss`` is the scalar
            Gaussian negative log-likelihood (plus the gripper cross-entropy
            when present), ``a_mse`` is the per-sample mean squared movement
            error and ``g_mse`` is the per-sample squared gripper error, or a
            single zero when there is no gripper.
        """
        x = self.nn_bc(x_t)
        act_mean, act_std = torch.split(x, self.move_dim, dim=-1)
        # softplus keeps the std positive; the offset bounds it away from zero.
        act_std = F.softplus(act_std) + 1e-3

        a_move = a_t[:, :self.move_dim]
        sq_err = (act_mean - a_move) ** 2

        # Per-dimension Gaussian log-density of the target movement.
        logp = -0.5 * (math.log(2 * math.pi)
                       + 2.0 * torch.log(act_std) + sq_err / act_std ** 2)
        bc_loss = -torch.mean(torch.sum(logp, dim=1))

        a_mse = torch.mean(sq_err, dim=1)

        if self.has_gripper:
            y = self.nn_grasp(x_t)
            grasp_loss = self.criterion(y, a_t[:, -1].type(torch.int64))
            bc_loss = bc_loss + grasp_loss
            y_max = torch.argmax(y, dim=1)
            g_mse = (a_t[:, -1] - y_max) ** 2
        else:
            # Placeholder that lives on the same device/dtype as the other
            # metrics, so callers can combine them without a device mismatch.
            g_mse = a_mse.new_zeros(1)
        return bc_loss, [a_mse, g_mse]

    def get_action(self, x_t):
        """Return the deterministic action for the first sample of ``x_t``.

        Args:
            x_t: Latent states, indexed as ``(batch, in_dim)``; only the first
                row contributes to the result.

        Returns:
            A 1-D ``numpy`` array holding the predicted movement mean,
            followed by the arg-max gripper class when ``has_gripper`` is set.
        """
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            x = self.nn_bc(x_t)
            act_mean, _ = torch.split(x, self.move_dim, dim=-1)
            action = act_mean[0].cpu().tolist()

            if self.has_gripper:
                y = self.nn_grasp(x_t)
                # Take the first sample only, matching the movement part, so
                # the action length does not depend on the batch size.
                action.append(torch.argmax(y, dim=1)[0].item())
        return np.array(action)


class BC_Model(nn.Module):
    """Behaviour-cloning model: a state encoder followed by an action head.

    Args:
        input_type: ``"image"`` to encode observations with ``ENC_Image`` or
            ``"feature"`` to encode them with ``ENC_Feature``.
        obs_dim: Observation size passed to the encoder as its ``in_dim``.
        action_dim: Total action size passed to ``Latent2Action``.
        consider_gripper: Forwarded to ``Latent2Action`` as ``has_gripper``.
        config: Object exposing ``z_dep_dim``, the latent size.
        setting: Convolution settings forwarded to ``ENC_Image``; unused for
            ``"feature"`` input.

    Raises:
        ValueError: If ``input_type`` is neither ``"image"`` nor ``"feature"``.
    """

    def __init__(self, input_type, obs_dim, action_dim,
                 consider_gripper, config, setting):
        super(BC_Model, self).__init__()

        # Fail early: an unknown type would otherwise leave the encoder
        # undefined and only surface as an AttributeError at call time.
        if input_type not in ("image", "feature"):
            raise ValueError(
                f"Unsupported input_type {input_type!r}; "
                "expected 'image' or 'feature'"
            )

        self.config = config
        z_s_dim = config.z_dep_dim

        if input_type == "image":
            self.state_single_encoder = ENC_Image(obs_dim, z_s_dim, 1, setting)
        else:
            self.state_single_encoder = ENC_Feature(obs_dim, z_s_dim)

        self.latent2action = Latent2Action(z_s_dim, action_dim, consider_gripper)

        self.z_s_dim = z_s_dim
        self.input_type = input_type

    def forward(self, x_list, a_list):
        """Compute the behaviour-cloning loss for a batch.

        When ``x_list`` is 3-D, index 0 along the second dimension of both
        ``x_list`` and ``a_list`` is used (NOTE(review): presumably the first
        time step of a sequence - confirm); otherwise both are used as given.

        Returns:
            The ``(loss, metrics)`` pair produced by ``Latent2Action``.
        """
        if x_list.dim() == 3:
            x_t = x_list[:, 0, :]
            a_t = a_list[:, 0, :]
        else:
            x_t = x_list
            a_t = a_list
        z_s = self.state_single_encoder(x_t)
        act_loss, act_mse = self.latent2action(z_s, a_t)
        return act_loss, act_mse

    def get_action(self, x_t):
        """Encode ``x_t`` and return the action predicted by the action head."""
        z_t = self.state_single_encoder(x_t)
        a_hat = self.latent2action.get_action(z_t)
        return a_hat
