import torch
import torch.nn as nn

from yolov5.models.experimental import attempt_load
from yolov5.utils.general import non_max_suppression

from slot_attention.model import SlotAttention_model
import sys
sys.path.insert(0, './yolov5')


class PerceptionModule:
    """
    Glue between a neural perception model and the neural predicates.

    Records metadata about the perception network, such as how many entities
    it emits (``e``) and how many dimensions each entity has (``d``).
    ``nn_id`` identifies the perception model.
    """

    def __init__(self, nn_id, e, d, device=None):
        self.nn_id = nn_id
        self.e = e  # number of entities
        self.d = d  # number of dimensions per entity
        self.device = device
        # Placeholder for the network mapping R^x -> R^{e*d}
        # (YOLO, slot attention, ...); not set by this base class.
        self.nn = None

    def percept(self, x):
        """Return a random (e, d) tensor; ``x`` is currently ignored (stub)."""
        output_shape = (self.e, self.d)
        return torch.rand(output_shape)

    def to_pos(self, z):
        """Hook for object-centric subclasses; the base version returns None."""
        return None

    def __repr__(self):
        template = "PerceptionModule(nn_id={}, entities={}, dimension={})"
        return template.format(self.nn_id, self.e, self.d)

    # str() and repr() render the same text.
    __str__ = __repr__


class YOLOPerceptionModule(nn.Module):
    """Wrap a YOLOv5 detector as a fixed-size, object-centric perception module.

    Detector predictions go through non-max suppression (at most ``e``
    detections per image). They are then zero-padded to exactly ``e`` rows and
    converted by ``YOLOPreprocess`` into per-object attribute vectors.

    Args:
        nn_id: identifier of the perception model.
        e: number of entities (maximum detections kept per image).
        d: number of dimensions per entity.
        device: device to load the detector on. ``'cpu'`` selects the
            torch.hub model instead of the local weights file.
        train: if False, the detector's parameters are frozen
            (``requires_grad = False``).
    """

    def __init__(self, nn_id, e, d, device=None, train=False):
        super().__init__()
        self.nn_id = nn_id
        self.e = e  # num of entities
        self.d = d  # num of dimension
        self.device = device
        # Stored as ``trainable`` rather than ``train``. An instance attribute
        # named ``train`` would shadow ``nn.Module.train()``, so calling
        # ``.eval()`` / ``.train()`` on this module (or on any parent module,
        # which recurses into children) would raise
        # ``TypeError: 'bool' object is not callable``.
        self.trainable = train
        self.model = self.load_model(
            path='yolov5/weights/best.pt', device=device)
        # YOLO returns class labels. YOLOPreprocess decomposes them into
        # color/shape attribute vectors weighted by the detection probability.
        self.preprocess = YOLOPreprocess(device)

    def load_model(self, path, device):
        """Load the YOLOv5 detector and freeze it unless ``self.trainable``.

        Args:
            path: weights file passed to ``attempt_load`` (non-CPU devices only).
            device: target device. ``'cpu'`` loads the torch.hub 'yolov5s' model.
        Returns:
            The loaded detector module.
        """
        if device == 'cpu':
            # NOTE(review): on CPU this loads the generic hub 'yolov5s' model
            # and ignores ``path`` — confirm this is intended.
            yolo_net = torch.hub.load('ultralytics/yolov5', 'yolov5s')
        else:
            yolo_net = attempt_load(weights=path)
            yolo_net.to(device)
        # Honour the ``train`` flag regardless of which branch loaded the model.
        if not self.trainable:
            for param in yolo_net.parameters():
                param.requires_grad = False
        return yolo_net

    def forward(self, imgs):
        """Detect objects in ``imgs`` and return preprocessed per-object vectors."""
        pred = self.model(imgs)[0]  # the yolo model returns a tuple
        # non_max_suppression returns a list of tensors (one per image),
        # because the number of detected objects varies from image to image.
        yolo_output = self.pad_result(
            non_max_suppression(pred, max_det=self.e))
        return self.preprocess(yolo_output)

    def pad_result(self, output):
        """Zero-pad per-image detections to exactly ``self.e`` rows and stack them.

        Args:
            output: list with one detection tensor per image, each of shape
                (n_obj, 6).
        Returns:
            Tensor of shape (batch, self.e, 6).
        """
        padded_list = []
        for objs in output:
            # Guard: never keep more than ``e`` rows, so all entries stack.
            objs = objs[:self.e]
            missing = self.e - objs.size(0)
            if missing > 0:
                # new_zeros matches the detections' dtype and device.
                pad = objs.new_zeros((missing, objs.size(1)))
                objs = torch.cat([objs, pad], dim=0)
            padded_list.append(objs)
        return torch.stack(padded_list)


class SlotAttentionPerceptionModule(nn.Module):
    """Wrap a Slot Attention model as a perception module.

    Args:
        nn_id: identifier of the perception model.
        e: number of entities (the model is built with n_slots=10).
        d: number of dimensions (the model is built with
            encoder_hidden_channels=64).
        device: target device. On ``'cpu'`` no pretrained weights are loaded.
        train: if False, the pretrained model's parameters are frozen.
    """

    # Checkpoint used when ``load_model`` receives an empty path.
    _DEFAULT_CHECKPOINT = "slot_attention_pretrain/logs/slot-attention-clevr"

    def __init__(self, nn_id, e, d, device, train=False):
        super().__init__()
        self.nn_id = nn_id
        self.e = e  # num of entities -> n_slots=10
        self.d = d  # num of dimension -> encoder_hidden_channels=64
        self.device = device
        # Stored as ``trainable`` rather than ``train``. An instance attribute
        # named ``train`` would shadow ``nn.Module.train()`` and make
        # ``.eval()`` / ``.train()`` raise ``TypeError``.
        self.trainable = train
        self.model = self.load_model(
            path='')

    def load_model(self, path):
        """Build the Slot Attention model, loading pretrained weights off-CPU.

        Args:
            path: checkpoint path. An empty string selects the default
                checkpoint (``_DEFAULT_CHECKPOINT``).
        Returns:
            The model. On CPU it is freshly initialised and not frozen.
        """
        if self.device == 'cpu':
            # TODO: pretrained weights are only loaded on non-CPU devices.
            print('GPU required')
            sa_net = SlotAttention_model(n_slots=10, n_iters=3, n_attr=18,
                                         encoder_hidden_channels=64,
                                         attention_hidden_channels=128)
            return sa_net

        # load pretrained concept embedding module
        sa_net = SlotAttention_model(n_slots=10, n_iters=3, n_attr=18,
                                     encoder_hidden_channels=64,
                                     attention_hidden_channels=128, device=self.device)
        # map_location puts the checkpoint tensors directly on the target device
        # instead of the device they were saved from.
        log = torch.load(path or self._DEFAULT_CHECKPOINT,
                         map_location=self.device)
        sa_net.load_state_dict(log['weights'], strict=True)
        sa_net.to(self.device)

        print("Pretrained slot attention model loaded!")
        if not self.trainable:
            for param in sa_net.parameters():
                param.requires_grad = False
        return sa_net

    def forward(self, imgs):
        """Run the wrapped Slot Attention model on ``imgs``."""
        return self.model(imgs)


class YOLOPreprocess(nn.Module):
    """Convert padded YOLO detections into per-object attribute vectors.

    Each detection row ``[x1, y1, x2, y2, prob, class]`` becomes
    ``[x1, y1, x2, y2, color1, color2, color3, shape1, shape2, shape3, prob]``.
    Box coordinates are divided by ``img_size``. The class's one-hot color and
    shape codes are scaled by ``prob``.

    Args:
        device: device for the lookup tables and for the output.
        img_size: value the box coordinates are divided by.
    """

    def __init__(self, device, img_size=128):
        super().__init__()
        self.device = device
        self.img_size = img_size
        self.classes = ['red square', 'red circle', 'red triangle',
                        'yellow square', 'yellow circle',  'yellow triangle',
                        'blue square', 'blue circle', 'blue triangle']
        one_hot = torch.eye(3, dtype=torch.int64, device=device)
        # Lookup tables indexed by class id (order of ``self.classes``).
        # They are registered as non-persistent buffers so ``module.to(...)``
        # moves them with the module, without adding entries to state_dict.
        # colors: red/yellow/blue, each shared by three consecutive classes.
        self.register_buffer(
            'colors', one_hot.repeat_interleave(3, dim=0), persistent=False)
        # shapes: square/circle/triangle, cycling within each color group.
        self.register_buffer(
            'shapes', one_hot.repeat(3, 1), persistent=False)

    def forward(self, x):
        """Preprocess a batch of padded YOLO detections.

        Args:
            x: tensor of shape (batch, n_obj, 6) whose rows are
                ``[x1, y1, x2, y2, prob, class]``.
        Returns:
            Tensor of shape (batch, n_obj, 11) on ``self.device``.
        """
        class_id = x[..., -1].to(torch.int64)
        # Slice (not index) to keep a trailing singleton dim for broadcasting.
        prob = x[..., -2:-1]
        color = self.colors[class_id] * prob
        shape = self.shapes[class_id] * prob
        xyxy = x[..., 0:4] / self.img_size
        return torch.cat([xyxy, color, shape, prob], dim=-1).to(self.device)
