import torch
import torch.nn as nn
from neural_utils import MLP, SigmoidBias, LogisticRegression, BinaryClassification


class ValuationFunction(nn.Module):
    """Pass-through valuation function: the output is the input itself."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` untouched."""
        return x


##
# Valuation functions for YOLO
##

class YOLOColorValuationFunction(nn.Module):
    """Valuation of a color attribute on YOLO object representations."""

    def __init__(self):
        super().__init__()

    def forward(self, z, a):
        """
        z: 2-d tensor B * d
            object-centric representation
            [x1, y1, x2, y2, color1, color2, color3, shape1, shape2, shape3, prob]
        a: 1-d tensor |color|
            one-hot encoding of the color

        Returns a 1-d tensor of size B: for every row of z, the sum of the
        color entries (columns 4..6) weighted by a.
        """
        batch_size = z.size(0)
        color_scores = z[:, 4:7]
        # tile the selector so that every row of z gets its own copy
        selector = a.repeat((batch_size, 1))
        return torch.sum(selector * color_scores, dim=1)


class YOLOShapeValuationFunction(nn.Module):
    """Valuation of a shape attribute on YOLO object representations."""

    def __init__(self):
        super().__init__()

    def forward(self, z, a):
        """
        z: 2-d tensor B * d
            should be specified as object i
            object-centric representation
            [x1, y1, x2, y2, color1, color2, color3, shape1, shape2, shape3, prob]
        a: tensor over the 3 shape entries
            selector for the shape (e.g. a one-hot encoding); it is
            broadcast against the B * 3 shape slice of z

        Returns, for every row of z, the sum of the shape entries
        (columns 7..9) weighted by a.
        """
        shape_scores = z[:, 7:10]
        weighted = a * shape_scores
        return weighted.sum(dim=1)


class YOLOInValuationFunction(nn.Module):
    """Valuation that reads the trailing entry of each object vector."""

    def __init__(self):
        super().__init__()

    def forward(self, z, img):
        """
        z: 2-d tensor B * d
            object-centric representation
            [x1, y1, x2, y2, color1, color2, color3, shape1, shape2, shape3, prob]
        img: not used by this function

        Returns the last column of z (the ``prob`` entry in the layout above).
        """
        last_column = z[:, -1]
        return last_column


class YOLOClosebyValuationFunction(nn.Module):
    """Valuation built on the distance between two bounding-box centers."""

    def __init__(self, device):
        super().__init__()
        self.device = device
        # the regressor is fed one scalar per object pair: the center distance
        self.logi = LogisticRegression(input_dim=1)
        self.logi.to(device)

    def forward(self, z_1, z_2):
        """
        z_1, z_2: 2-d tensor B * d
            object-centric representation
            [x1, y1, x2, y2, color1, color2, color3, shape1, shape2, shape3, prob]

        Returns the output of ``self.logi`` on the per-row Euclidean distance
        between the two box centers, with singleton dimensions squeezed away.
        """
        # both centers are laid out as (2, B), so the norm runs over dim 0
        offset = self.to_center(z_1) - self.to_center(z_2)
        dist = torch.norm(offset, dim=0).unsqueeze(-1)
        self.diff = dist  # kept on the instance for debugging
        return self.logi(dist).squeeze()

    def to_center(self, z):
        """Return the box centers of z stacked as a (2, B) tensor: x row, y row."""
        center_x = (z[:, 0] + z[:, 2]) / 2
        center_y = (z[:, 1] + z[:, 3]) / 2
        return torch.stack((center_x, center_y))


class YOLOOnlineValuationFunction(nn.Module):
    """Valuation built on how well five box centers fit a straight line."""

    def __init__(self, device):
        super().__init__()
        self.logi = LogisticRegression(input_dim=1)
        self.logi.to(device)

    def forward(self, z_1, z_2, z_3, z_4, z_5):
        """
        z_1 .. z_5: 2-d tensor B * d
            object-centric representation
            [x1, y1, x2, y2, color1, color2, color3, shape1, shape2, shape3, prob]

        For every batch row, fits y = w0 + w1 * x to the five box centers by
        ordinary least squares and feeds the norm of the residual to
        ``self.logi``. The squeezed output of ``self.logi`` is returned.
        """
        objects = (z_1, z_2, z_3, z_4, z_5)
        xs = torch.stack([self.to_center_x(obj) for obj in objects],
                         dim=1).unsqueeze(-1)
        Y = torch.stack([self.to_center_y(obj) for obj in objects],
                        dim=1).unsqueeze(-1)
        # design matrix with a leading column of ones for the bias: (B, 5, 2)
        design = torch.cat([torch.ones_like(xs), xs], dim=2)
        design_t = design.transpose(1, 2)
        # normal equations: W = (X^T X)^-1 X^T Y, one (2, 1) solution per row
        gram_inv = torch.inverse(design_t @ design)
        W = (gram_inv @ design_t) @ Y
        fitted = (W.transpose(1, 2) * design).sum(dim=2).unsqueeze(-1)
        diff = torch.norm(Y - fitted, dim=1)
        self.diff = diff
        return self.logi(diff).squeeze()

    def to_center_x(self, z):
        """Return the horizontal midpoint of every box in z."""
        return (z[:, 0] + z[:, 2]) / 2

    def to_center_y(self, z):
        """Return the vertical midpoint of every box in z."""
        return (z[:, 1] + z[:, 3]) / 2


class ___YOLOOnlineValuationFunction(nn.Module):
    """MLP-based variant of the "online" valuation over five box centers.

    The MLP receives, per batch row, the ten numbers
    [cx1, cy1, cx2, cy2, ..., cx5, cy5] built from the five bounding boxes.
    """

    def __init__(self, device):
        # zero-argument super(): naming another class here raises TypeError
        # because this class does not inherit from it
        super().__init__()
        self.mlp = MLP(in_channels=10, out_channels=1, hidden_dim=512)
        self.mlp.to(device)

    def forward(self, z_1, z_2, z_3, z_4, z_5):
        """
        z_1 .. z_5: 2-d tensor B * d
            object-centric representation
            [x1, y1, x2, y2, color1, color2, color3, shape1, shape2, shape3, prob]

        Returns the squeezed output of ``self.mlp`` on a B * 10 tensor of
        box centers.
        """
        centers = [self.to_center(z) for z in (z_1, z_2, z_3, z_4, z_5)]
        # to_center yields (2, B); transpose to (B, 2) before concatenating so
        # that every row holds the coordinates of ONE sample. Reshaping the
        # (10, B) concatenation with view() would interleave values of
        # different samples whenever B > 1.
        x = torch.cat([c.t() for c in centers], dim=1)
        return self.mlp(x).squeeze()

    def to_center(self, z):
        """Return the box centers of z stacked as a (2, B) tensor: x row, y row."""
        x = (z[:, 0] + z[:, 2]) / 2
        y = (z[:, 1] + z[:, 3]) / 2
        return torch.stack((x, y))


##
# Valuation functions for slot attention
##


class SlotAttentionInValuationFunction(nn.Module):
    """Valuation that reads the object indicator of a slot."""

    def __init__(self, device):
        super().__init__()

    def forward(self, z, a):
        """
        z: 2-d tensor B * d
            object-centric representation
            obj_prob + coords + shape + size + material + color
            coords = [x,y,z]
            CLASSES = {
                "shape": ["sphere", "cube", "cylinder"],
                "size": ["large", "small"],
                "material": ["rubber", "metal"],
                "color": ["cyan", "blue", "yellow", "purple", "red", "green", "gray", "brown"],}
        a: not used by this function

        Returns the first column of z (``obj_prob`` in the layout above).
        """
        object_prob = z[:, 0]
        return object_prob


class SlotAttentionShapeValuationFunction(nn.Module):
    """Valuation of a shape attribute on slot-attention representations."""

    def __init__(self, device):
        super().__init__()

    def forward(self, z, a):
        """
        z: 2-d tensor B * d
            object-centric representation
            obj_prob + coords + shape + size + material + color
            coords = [x,y,z]
            CLASSES = {
                "shape": ["sphere", "cube", "cylinder"],
                "size": ["large", "small"],
                "material": ["rubber", "metal"],
                "color": ["cyan", "blue", "yellow", "purple", "red", "green", "gray", "brown"],}
        a: 1-d tensor |shape|
            one-hot encoding of the shape

        Returns, for every row of z, the sum of the shape entries
        (columns 4..6) weighted by a.
        """
        shape_scores = z[:, 4:7]
        weighted = a * shape_scores
        return weighted.sum(dim=1)


class SlotAttentionSizeValuationFunction(nn.Module):
    """Valuation of a size attribute on slot-attention representations."""

    def __init__(self, device):
        super().__init__()

    def forward(self, z, a):
        """
        z: 2-d tensor B * d
            object-centric representation
            obj_prob + coords + shape + size + material + color
            coords = [x,y,z]
            CLASSES = {
                "shape": ["sphere", "cube", "cylinder"],
                "size": ["large", "small"],
                "material": ["rubber", "metal"],
                "color": ["cyan", "blue", "yellow", "purple", "red", "green", "gray", "brown"],}
        a: 1-d tensor |size|
            one-hot encoding of the size

        Returns, for every row of z, the sum of the size entries
        (columns 7..8) weighted by a.
        """
        size_scores = z[:, 7:9]
        weighted = a * size_scores
        return weighted.sum(dim=1)


class SlotAttentionMaterialValuationFunction(nn.Module):
    """Valuation of a material attribute on slot-attention representations."""

    def __init__(self, device):
        super().__init__()

    def forward(self, z, a):
        """
        z: 2-d tensor B * d
            object-centric representation
            obj_prob + coords + shape + size + material + color
            coords = [x,y,z]
            CLASSES = {
                "shape": ["sphere", "cube", "cylinder"],
                "size": ["large", "small"],
                "material": ["rubber", "metal"],
                "color": ["cyan", "blue", "yellow", "purple", "red", "green", "gray", "brown"],}
        a: 1-d tensor |material|
            one-hot encoding of the material

        Returns, for every row of z, the sum of the material entries
        (columns 9..10) weighted by a.
        """
        material_scores = z[:, 9:11]
        weighted = a * material_scores
        return weighted.sum(dim=1)


class SlotAttentionColorValuationFunction(nn.Module):
    """Valuation of a color attribute on slot-attention representations."""

    def __init__(self, device):
        super().__init__()

    def forward(self, z, a):
        """
        z: 2-d tensor B * d
            object-centric representation
            obj_prob + coords + shape + size + material + color
            coords = [x,y,z]
            CLASSES = {
                "shape": ["sphere", "cube", "cylinder"],
                "size": ["large", "small"],
                "material": ["rubber", "metal"],
                "color": ["cyan", "blue", "yellow", "purple", "red", "green", "gray", "brown"],}
        a: 1-d tensor |color|
            one-hot encoding of the color

        Returns, for every row of z, the sum of the color entries
        (columns 11..18) weighted by a.
        """
        color_scores = z[:, 11:19]
        weighted = a * color_scores
        return weighted.sum(dim=1)

# TODO


class SlotAttentionSideValuationFunction(nn.Module):
    """Valuation of a "side" attribute predicted from an object's coordinates.

    A logistic regression maps the three coordinates of each object to two
    scores; ``forward`` selects among those scores with ``a``.
    """

    def __init__(self, device):
        super().__init__()
        self.logi = LogisticRegression(input_dim=3, output_dim=2)
        self.logi.to(device)

    def forward(self, z, a):
        """
        z: 2-d tensor B * d
            object-centric representation
            obj_prob + coords + shape + size + material + color
            coords = [x,y,z]
            CLASSES = {
                "shape": ["sphere", "cube", "cylinder"],
                "size": ["large", "small"],
                "material": ["rubber", "metal"],
                "color": ["cyan", "blue", "yellow", "purple", "red", "green", "gray", "brown"],}
        a: 1-d tensor over the two outputs of ``self.logi``
            presumably a one-hot encoding of the side -- TODO confirm which
            index stands for which side

        Returns, for every row of z, the sum of the side scores weighted by a.
        """
        z_xyz = z[:, 1:4]
        z_side = self.logi(z_xyz)
        # weight the predicted side scores; the previous code referenced an
        # undefined name here and raised NameError on every call
        return (a * z_side).sum(dim=1)


class SlotAttentionRightSideValuationFunction(nn.Module):
    """Valuation computed from an object's x coordinate and its objectness."""

    def __init__(self, device):
        super().__init__()
        self.logi = LogisticRegression(input_dim=1, output_dim=1)
        self.logi.to(device)

    def forward(self, z):
        """
        z: 2-d tensor B * d
            object-centric representation
            obj_prob + coords + shape + size + material + color
            coords = [x,y,z]
            CLASSES = {
                "shape": ["sphere", "cube", "cylinder"],
                "size": ["large", "small"],
                "material": ["rubber", "metal"],
                "color": ["cyan", "blue", "yellow", "purple", "red", "green", "gray", "brown"],}

        Returns the squeezed output of ``self.logi`` on the x coordinate
        (column 1), multiplied by the object probability (column 0).
        """
        x_coord = z[:, 1].unsqueeze(-1)  # one input feature per row: (B, 1)
        side_prob = self.logi(x_coord).squeeze()
        objectness = z[:, 0]
        return side_prob * objectness


class SlotAttentionLeftSideValuationFunction(nn.Module):
    """Valuation computed from an object's x coordinate and its objectness."""

    def __init__(self, device):
        super().__init__()
        self.logi = LogisticRegression(input_dim=1, output_dim=1)
        self.logi.to(device)

    def forward(self, z):
        """
        z: 2-d tensor B * d
            object-centric representation
            obj_prob + coords + shape + size + material + color
            coords = [x,y,z]
            CLASSES = {
                "shape": ["sphere", "cube", "cylinder"],
                "size": ["large", "small"],
                "material": ["rubber", "metal"],
                "color": ["cyan", "blue", "yellow", "purple", "red", "green", "gray", "brown"],}

        Returns the squeezed output of ``self.logi`` on the x coordinate
        (column 1), multiplied by the object probability (column 0).
        """
        x_coord = z[:, 1].unsqueeze(-1)  # one input feature per row: (B, 1)
        side_prob = self.logi(x_coord).squeeze()
        objectness = z[:, 0]
        return side_prob * objectness


class SlotAttentionFrontValuationFunction(nn.Module):
    """Pairwise valuation computed from the coordinates of two objects."""

    def __init__(self, device):
        super().__init__()
        self.logi = LogisticRegression(input_dim=6, output_dim=1)
        self.logi.to(device)

    def forward(self, z_1, z_2):
        """
        z_1, z_2: 2-d tensor B * d
            object-centric representation
            obj_prob + coords + shape + size + material + color
            coords = [x,y,z]
            CLASSES = {
                "shape": ["sphere", "cube", "cylinder"],
                "size": ["large", "small"],
                "material": ["rubber", "metal"],
                "color": ["cyan", "blue", "yellow", "purple", "red", "green", "gray", "brown"],}

        Returns the squeezed output of ``self.logi`` on the six concatenated
        coordinates, multiplied by the object probabilities of both objects.
        """
        # coordinates of both objects side by side: (B, 6)
        coords_pair = torch.cat((z_1[:, 1:4], z_2[:, 1:4]), dim=1)
        relation_prob = self.logi(coords_pair).squeeze()
        both_present = z_1[:, 0] * z_2[:, 0]
        return relation_prob * both_present


class SlotAttentionSameXRangeValuationFunction(nn.Module):
    """Pairwise valuation computed from the x distance between two objects."""

    def __init__(self, device):
        super().__init__()
        self.mlp = BinaryClassification(input_dim=1)
        self.mlp.to(device)

    def forward(self, z_1, z_2):
        """
        z_1, z_2: 2-d tensor B * d
            object-centric representation
            obj_prob + coords + shape + size + material + color
            coords = [x,y,z]
            CLASSES = {
                "shape": ["sphere", "cube", "cylinder"],
                "size": ["large", "small"],
                "material": ["rubber", "metal"],
                "color": ["cyan", "blue", "yellow", "purple", "red", "green", "gray", "brown"],}

        Returns the squeezed output of ``self.mlp`` on the absolute
        difference of the two x coordinates (column 1), one value per row.
        """
        # Per-sample |x1 - x2| as a (B, 1) feature. A norm over dim=0 of the
        # (B,) difference would reduce over the batch and yield one scalar for
        # the whole batch; for B == 1 both forms give the same value.
        dist = torch.abs(z_1[:, 1] - z_2[:, 1]).unsqueeze(-1)
        return self.mlp(dist).squeeze()

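# NOTE: possibly not required -- forward() of the class below performs the
# same color lookup as SlotAttentionColorValuationFunction and never uses
# its own self.mlp.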


class SlotAttentionInFrontValuationFunction(nn.Module):
    """Valuation that currently weights the color entries of z by a.

    A binary classifier is created in the constructor, but ``forward`` does
    not call it.
    """

    def __init__(self, device):
        super().__init__()
        self.mlp = BinaryClassification(input_dim=1)
        self.mlp.to(device)

    def forward(self, z, a):
        """
        z: 2-d tensor B * d
            object-centric representation
            obj_prob + coords + shape + size + material + color
            coords = [x,y,z]
            CLASSES = {
                "shape": ["sphere", "cube", "cylinder"],
                "size": ["large", "small"],
                "material": ["rubber", "metal"],
                "color": ["cyan", "blue", "yellow", "purple", "red", "green", "gray", "brown"],}
        a: 1-d tensor |color|
            one-hot encoding of the color

        Returns, for every row of z, the sum of the color entries
        (columns 11..18) weighted by a.
        """
        color_scores = z[:, 11:19]
        weighted = a * color_scores
        return weighted.sum(dim=1)
