# DPL model for BOIA
from utils.args import *
from models.utils.utils_problog import *
from models.sddoiadpl import SDDOIADPL
from utils.losses import SDDOIA_Cumulative
from utils.dpl_loss import SDDOIA_DPL


def get_parser() -> ArgumentParser:
    """Build the command-line argument parser for this model.

    The parser is created with a short description and then populated by
    the project helpers ``add_management_args`` and ``add_experiment_args``.

    Returns:
        ArgumentParser: parser with management and experiment arguments added
    """
    # The original adjacent literals were concatenated without a separating
    # space ("Learning viaConcept Extractor ."); use one well-formed string.
    parser = ArgumentParser(description="Learning via Concept Extractor.")
    add_management_args(parser)
    add_experiment_args(parser)
    return parser


class BoiaDSLDPL(SDDOIADPL):
    """DPL model for BOIA with learnable world-to-label matrices.

    Concept probabilities are combined into joint "world" distributions
    (one per action group) and mapped to label probabilities through three
    learnable matrices that are row-wise soft-maxed at inference time.
    """

    NAME = "boiadsldpl"

    def __init__(
        self,
        encoder,
        n_images=1,
        c_split=(),
        args=None,
        model_dict=None,
        n_facts=21,
        nr_classes=4,
    ):
        """Initialize method

        Args:
            self: instance
            encoder (nn.Module): encoder
            n_images (int, default=1): number of images
            c_split: concept splits
            args: command line arguments
            model_dict (default=None): model dictionary
            n_facts (int, default=21): number of concepts
            nr_classes (int, default=4): number of classes for the multiclass classification problem

        Returns:
            None: This function does not return a value.
        """
        super().__init__(
            encoder,
            n_images=n_images,
            c_split=c_split,
            args=args,
            model_dict=model_dict,
            n_facts=n_facts,
            nr_classes=nr_classes,
        )

        # (Re)create the world-to-label matrices as learnable parameters.
        # Rows index joint worlds (2**5 = 32 for forward/stop, 2**6 = 64 for
        # left and for right); columns index the outputs of each group.
        # NOTE(review): `self.device` is expected to be set by the parent
        # constructor - confirm.
        self.FS_w_q = torch.nn.Parameter(torch.randn([32, 4], requires_grad=True).to(self.device))
        self.L_w_q = torch.nn.Parameter(torch.randn([64, 2], requires_grad=True).to(self.device))
        self.R_w_q = torch.nn.Parameter(torch.randn([64, 2], requires_grad=True).to(self.device))

    @staticmethod
    def _joint_worlds(pCs, first_col, n_concepts):
        """Outer product of consecutive two-column concept distributions.

        Args:
            pCs: tensor indexed as (batch, columns); each concept occupies two
                adjacent columns
            first_col (int): column index of the first concept of the group
            n_concepts (int): number of consecutive concepts to combine

        Returns:
            worlds: tensor of shape (batch, 2 ** n_concepts); the first concept
                is the most significant index, matching a row-major flattening
                of the broadcasted product
        """
        worlds = pCs[:, first_col : first_col + 2]
        for k in range(1, n_concepts):
            col = first_col + 2 * k
            factor = pCs[:, col : col + 2]
            # (batch, m, 1) * (batch, 1, 2) -> (batch, m, 2) -> (batch, 2 * m)
            worlds = (worlds.unsqueeze(2) * factor.unsqueeze(1)).flatten(start_dim=1)
        return worlds

    def problog_inference(self, pCs, query=None):
        """Computes label probabilities from concept probabilities. Works with two encoded bits.

        For each action group the joint world distribution is built as the
        outer product of the group's concepts and then multiplied by the
        row-wise softmax of the corresponding learnable matrix.

        Args:
            self: instance
            pCs: probability of concepts, two columns per concept
            query (default=None): unused, kept for interface compatibility

        Returns:
            pred: tensor with 8 columns (4 forward/stop, 2 left, 2 right),
                rescaled away from exactly 0 and 1
        """
        softmax = torch.nn.functional.softmax

        # FORWARD / STOP: columns 0-9
        #   0:2  traffic light is green     2:4  follow car ahead
        #   4:6  road is clear              6:8  traffic light is red
        #   8:10 traffic sign present
        w_FS = self._joint_worlds(pCs, 0, 5)  # (batch, 32)
        labels_FS = torch.einsum("bi,ik->bk", w_FS, softmax(self.FS_w_q, dim=-1))

        # LEFT / LEFT-STOP: columns 18-29
        #   18:20 there is LEFT lane        20:22 tl green on LEFT
        #   22:24 follow car going LEFT     24:26 no lane on LEFT
        #   26:28 LEFT obstacle             28:30 solid line on LEFT
        w_L = self._joint_worlds(pCs, 18, 6)  # (batch, 64)
        label_L = torch.einsum("bi,ik->bk", w_L, softmax(self.L_w_q, dim=-1))

        # RIGHT / RIGHT-STOP: columns 30-41
        #   30:32 there is RIGHT lane       32:34 tl green on RIGHT
        #   34:36 follow car going RIGHT    36:38 no lane on RIGHT
        #   38:40 RIGHT obstacle            40:42 solid line on RIGHT
        w_R = self._joint_worlds(pCs, 30, 6)  # (batch, 64)
        label_R = torch.einsum("bi,ik->bk", w_R, softmax(self.R_w_q, dim=-1))

        pred = torch.cat([labels_FS, label_L, label_R], dim=1)  # this is 8 dim

        # avoid overflow: shrink values slightly so that, for inputs in [0, 1],
        # the result is never exactly 0 or 1
        pred = (pred + 1e-5) / (1 + 2 * 1e-5)

        return pred

    def get_pred_from_prob(self, pCs, presence=True):
        """Computes label predictions from per-concept "true" probabilities.

        Each probability p is expanded into the pair (1 - p, p), the pairs are
        laid out side by side and passed to ``problog_inference``.

        Args:
            self: instance
            pCs: probability of each concept being true; assumed to be indexed
                as (batch, n_concepts) - TODO confirm against callers
            presence (bool, default=True): unused, kept for interface
                compatibility

        Returns:
            py: output of ``problog_inference`` on the expanded probabilities
        """
        # (batch, n) -> (batch, n, 2) -> (batch, 2 * n), ordered as
        # [1 - p_0, p_0, 1 - p_1, p_1, ...]
        pC = torch.stack((1 - pCs, pCs), dim=-1).flatten(start_dim=1)

        return self.problog_inference(pC)

    @staticmethod
    def get_loss(args):
        """Loss function for the architecture

        Args:
            args: command line arguments

        Returns:
            loss: loss function

        Raises:
            err: NotImplementedError if the loss function is not available
        """
        if args.dataset in ["boia"]:
            return SDDOIA_DPL(SDDOIA_Cumulative)
        # The exception used to be *returned*, handing callers an exception
        # instance as if it were a loss; raise it as the docstring promises.
        raise NotImplementedError("Wrong dataset choice")
