import torch
import torch.nn as nn
import torch.nn.functional as F
from neural_utils import MLP, SigmoidBias, LogisticRegression
from valuation_func import *


class ValuationModule:
    """Base interface for valuation modules.

    A valuation module owns a set of valuation functions and provides a way
    to call them. Concrete modules override ``get_params`` and ``eval``.
    """

    def __init__(self, lang, perception_module):
        # Keep references to the logic language and the perception module.
        self.lang = lang
        self.pm = perception_module

    def get_params(self):
        """Return the list of learnable parameters (the base class has none)."""
        return None

    def eval(self, atom, zs):
        """Return the probability of ``atom`` given the representations ``zs``.

        Intended to be overridden for each perception module; the base
        implementation returns None.
        """
        return None


class YOLOValuationModule(nn.Module):
    """Valuation module for YOLO-style object-centric representations.

    Maps predicate names ('color', 'shape', 'in', 'closeby', 'online') to
    valuation functions. An atom is evaluated by grounding each of its terms
    to a tensor and calling the matching valuation function.

    Args:
        lang: logic language; must provide ``get_by_dtype_name`` and
            ``term_index``.
        device: torch device used for the valuation functions and encodings.
        pretrained: if True, load pretrained weights for the 'closeby' and
            'online' predicates from ``weights/``.
    """

    def __init__(self, lang, device, pretrained=True):
        super().__init__()
        self.lang = lang
        self.device = device
        # Detector class names; ``class_to_attrs`` indexes this list by class id.
        self.classes = ['red square', 'red circle', 'red triangle',
                        'yellow square', 'yellow circle', 'yellow triangle',
                        'blue square', 'blue circle', 'blue triangle']
        self.layers, self.vfs = self.init_valuation_functions(
            device, pretrained)
        # attribute term -> one-hot vector representation
        self.attrs = self.init_attr_encodings(device)
        self.pretrained = pretrained

    def init_valuation_functions(self, device, pretrained):
        """Build the valuation functions.

        Args:
            device: device given to the functions that take one, and used as
                ``map_location`` when loading pretrained weights.
            pretrained: if True, load ``weights/<name>_pretrain.pt`` into the
                'closeby' and 'online' functions and switch them to eval mode.

        Returns:
            A pair ``(layers, vfs)``. ``layers`` is an ``nn.ModuleList`` that
            registers every function as a submodule, so they are visible to
            ``parameters()``, ``.to()``, and similar methods. ``vfs`` is a
            dict from predicate name to valuation function.
        """
        v_color = YOLOColorValuationFunction()
        v_shape = YOLOShapeValuationFunction()
        v_in = YOLOInValuationFunction()
        v_closeby = YOLOClosebyValuationFunction(device)
        v_online = YOLOOnlineValuationFunction(device)
        vfs = {
            'color': v_color,
            'shape': v_shape,
            'in': v_in,
            'closeby': v_closeby,
            'online': v_online,
        }
        if pretrained:
            for name in ('closeby', 'online'):
                vfs[name].load_state_dict(torch.load(
                    f'weights/{name}_pretrain.pt', map_location=device))
                # NOTE(review): a later .train() on this module will switch
                # these back to training mode.
                vfs[name].eval()
            print('Pretrained  neural predicates have been loaded!')
        return nn.ModuleList([v_color, v_shape, v_in, v_closeby, v_online]), vfs

    def init_attr_encodings(self, device):
        """Encode every 'color' and 'shape' term as a one-hot tensor.

        Each vector's length is the number of terms of that dtype. The hot
        position is ``lang.term_index(term)``, which is assumed to be the
        term's index within its own dtype (TODO confirm against the language
        implementation).

        Returns:
            dict mapping each term to its one-hot tensor on ``device``.
        """
        attrs = {}
        for dtype_name in ('color', 'shape'):
            # Look up the terms once per dtype rather than once per term.
            terms = self.lang.get_by_dtype_name(dtype_name)
            num_classes = len(terms)
            for term in terms:
                index = torch.tensor(self.lang.term_index(term), device=device)
                attrs[term] = F.one_hot(index, num_classes=num_classes)
        return attrs

    def forward(self, zs, atom):
        """Convert object-centric representations into a valuation of ``atom``.

        Args:
            zs: batched object representations. ``zs.size(0)`` is the batch
                size and ``zs[:, i]`` selects object ``i``. Presumably shaped
                (batch, n_objects, features); verify against the caller.
            atom: logical atom with ``pred.name`` and ``terms``.

        Returns:
            The output of the predicate's valuation function. If the predicate
            has no valuation function, returns a float32 zero tensor of shape
            ``(batch,)``.
        """
        vf = self.vfs.get(atom.pred.name)
        if vf is None:
            return torch.zeros((zs.size(0), ), dtype=torch.float32,
                               device=self.device)
        args = [self.ground_to_vector(term, zs) for term in atom.terms]
        return vf(*args)

    def ground_to_vector(self, term, zs):
        """Return the tensor representation of a logical term.

        Object terms select their slice of ``zs``. Color and shape terms use
        their precomputed one-hot encoding. Any other term returns None; per
        the original comment, this is the image term.
        """
        dtype_name = term.dtype.name
        if dtype_name == 'object':
            return zs[:, self.lang.term_index(term)]
        if dtype_name in ('color', 'shape'):
            return self.attrs[term]
        return None

    def get_params(self):
        """Return the trainable parameters of the 'closeby' valuation function.

        NOTE(review): 'online' is also a neural predicate, but its parameters
        are not returned here. Confirm this is intended.
        """
        return self.vfs['closeby'].parameters()

    def class_to_attrs(self, cid):
        """Split detector class ``cid`` (0..8) into ``(color, shape)`` strings.

        Example: class 4 is 'yellow circle', which gives ('yellow', 'circle').
        """
        color_str, shape_str = self.classes[cid].split(' ')
        return color_str, shape_str


class SlotAttentionValuationModule(nn.Module):
    """Valuation module for slot-attention object-centric representations.

    Maps predicate names ('color', 'shape', 'in', 'size', 'material',
    'rightside', 'leftside', 'front') to valuation functions. An atom is
    evaluated by grounding each of its terms to a tensor and calling the
    matching valuation function.

    Args:
        lang: logic language; must provide ``term_index``.
        device: torch device used for the valuation functions and encodings.
        pretrained: if True, load pretrained weights for the 'rightside',
            'leftside' and 'front' predicates from ``weights/``.
    """

    def __init__(self, lang, device, pretrained=True):
        super().__init__()
        self.lang = lang
        self.device = device
        # Attribute vocabularies: a term's one-hot position is its index in
        # the list that matches its dtype.
        self.colors = ["cyan", "blue", "yellow",
                       "purple", "red", "green", "gray", "brown"]
        self.shapes = ["sphere", "cube", "cylinder"]
        self.sizes = ["large", "small"]
        self.materials = ["rubber", "metal"]
        self.sides = ["left", "right"]

        self.layers, self.vfs = self.init_valuation_functions(
            device, pretrained)

    def init_valuation_functions(self, device, pretrained):
        """Build the valuation functions.

        Args:
            device: device given to every valuation function and used as
                ``map_location`` when loading pretrained weights.
            pretrained: if True, load ``weights/<name>_pretrain.pt`` into the
                'rightside', 'leftside' and 'front' functions and switch them
                to eval mode.

        Returns:
            A pair ``(layers, vfs)``. ``layers`` is an ``nn.ModuleList`` that
            registers every function as a submodule. ``vfs`` is a dict from
            predicate name to valuation function.
        """
        v_color = SlotAttentionColorValuationFunction(device)
        v_shape = SlotAttentionShapeValuationFunction(device)
        v_in = SlotAttentionInValuationFunction(device)
        v_size = SlotAttentionSizeValuationFunction(device)
        v_material = SlotAttentionMaterialValuationFunction(device)
        v_rightside = SlotAttentionRightSideValuationFunction(device)
        v_leftside = SlotAttentionLeftSideValuationFunction(device)
        v_front = SlotAttentionFrontValuationFunction(device)
        vfs = {
            'color': v_color,
            'shape': v_shape,
            'in': v_in,
            'size': v_size,
            'material': v_material,
            'rightside': v_rightside,
            'leftside': v_leftside,
            'front': v_front,
        }

        if pretrained:
            for name in ('rightside', 'leftside', 'front'):
                vfs[name].load_state_dict(torch.load(
                    f'weights/{name}_pretrain.pt', map_location=device))
                # NOTE(review): a later .train() on this module will switch
                # these back to training mode.
                vfs[name].eval()
            print('Pretrained  neural predicates have been loaded!')
        return nn.ModuleList([v_color, v_shape, v_in, v_size, v_material,
                              v_rightside, v_leftside, v_front]), vfs

    def forward(self, zs, atom):
        """Convert object-centric representations into a valuation of ``atom``.

        Args:
            zs: batched object representations. ``zs.size(0)`` is the batch
                size and ``zs[:, i]`` selects object ``i``. Presumably shaped
                (batch, n_slots, features); verify against the caller.
            atom: logical atom with ``pred.name`` and ``terms``.

        Returns:
            The output of the predicate's valuation function.

        Raises:
            KeyError: if the predicate has no valuation function.
        """
        args = [self.ground_to_vector(term, zs) for term in atom.terms]
        return self.vfs[atom.pred.name](*args)

    def ground_to_vector(self, term, zs):
        """Return the tensor representation of a logical term.

        Object terms select their slice of ``zs``. Image terms return None.
        All other terms become a batched one-hot attribute encoding.
        """
        dtype_name = term.dtype.name
        if dtype_name == 'object':
            return zs[:, self.lang.term_index(term)]
        if dtype_name == 'image':
            return None
        return self.term_to_onehot(term, batch_size=zs.size(0))

    def term_to_onehot(self, term, batch_size):
        """Encode an attribute term as a ``(batch_size, vocab_size)`` one-hot tensor.

        Raises:
            ValueError: if the term's dtype is not color, shape, material,
                size or side. The previous ``assert True`` could never fire,
                so such terms silently became None.
        """
        vocabularies = {
            'color': self.colors,
            'shape': self.shapes,
            'material': self.materials,
            'size': self.sizes,
            'side': self.sides,
        }
        vocabulary = vocabularies.get(term.dtype.name)
        if vocabulary is None:
            raise ValueError('Invalid term: ' + str(term))
        return self.to_onehot_batch(vocabulary.index(term.name),
                                    len(vocabulary), batch_size)

    def to_onehot_batch(self, i, length, batch_size):
        """Return a ``(batch_size, length)`` float tensor with column ``i`` set to 1."""
        onehot = torch.zeros(batch_size, length, device=self.device)
        onehot[:, i] = 1.0
        return onehot

    def to_onehot(self, i, length):
        """Return a ``(length,)`` float tensor with position ``i`` set to 1."""
        onehot = torch.zeros(length, device=self.device)
        onehot[i] = 1.0
        return onehot
