import numpy as np
import torch.nn as nn
import torch
from logic_utils import get_index_by_predname


from matplotlib.patches import Polygon
import matplotlib.pyplot as plt




class NSFReasoner_mi(nn.Module):
    """The Neuro-Symbolic Forward Reasoner.

    Args:
        perception_model (nn.Module): The perception model.
        facts_converter (nn.Module): The facts converter module.
        infer_module (nn.Module): The differentiable forward-chaining inference module.
        atoms (list(atom)): The set of ground atoms (facts).
    """

    def __init__(self,  V_T, check_existence_model,infer_module_MI, atoms, bk, clauses, atoms_mi,  train=False):
        super().__init__()

        self.atoms = atoms
        self.V_T = V_T
        self.bk = bk
        self.clauses = clauses
        self._train = train
        self.atoms_mi = atoms_mi
        self.ce = check_existence_model
        self.im_mi = infer_module_MI

    def get_params(self):
        return self.im_mi.get_params()  # + self.fc.get_params()

    def forward(self, x):

        V_0_mi = self.ce(self.atoms, self.atoms_mi, self.V_T)
        V_T_mi = self.im_mi (V_0_mi)
        a = np.where(V_T_mi>0.91)[1]
        b = np.array(self.atoms_mi)
        c = b[a]


        return  V_T_mi

    def predict(self, v, predname):
        """Extracting a value from the valuation tensor using a given predicate.
        """
        # v: batch * |atoms|
        target_index = get_index_by_predname(
            pred_str=predname, atoms=self.atoms)
        return v[:, target_index]

    def predict_multi(self, v, prednames):
        """Extracting values from the valuation tensor using given predicates.

        prednames = ['kp1', 'kp2', 'kp3']
        """
        # v: batch * |atoms|
        target_indices = []
        for predname in prednames:
            target_index = get_index_by_predname(
                pred_str=predname, atoms=self.atoms)
            target_indices.append(target_index)
        prob = torch.cat([v[:, i].unsqueeze(-1)
                         for i in target_indices], dim=1)
        B = v.size(0)
        N = len(prednames)
        assert prob.size(0) == B and prob.size(
            1) == N, 'Invalid shape in the prediction.'
        return prob

    def print_program(self):
        """Print asummary of logic programs using continuous weights.
        """
        print('====== LEARNED PROGRAM ======')
        C = self.clauses
        Ws_softmaxed = torch.softmax(self.im_mi.W, 1)
        program= []
        for i, W_ in enumerate(Ws_softmaxed):
            max_i = np.argmax(W_.detach().cpu().numpy())
            print('C_'+str(i)+': ',
                  C[max_i], W_[max_i].detach().cpu().item())
    def get_params_softmax(self):
        Ws_softmaxed = torch.softmax(self.im_mi.W, 1)
        return Ws_softmaxed


    def store_program(self):
        """Print asummary of logic programs using continuous weights.
        """
        C = self.clauses
        Ws_softmaxed = torch.softmax(self.im_mi.W, 1)
        program= []
        for i, W_ in enumerate(Ws_softmaxed):
            max_i = np.argmax(W_.detach().cpu().numpy())

            program.append(('C_'+str(i)+': ',
                  C[max_i], W_[max_i].detach().cpu().item()))
        return program
    def print_valuation_batch(self, valuation, n=40):
        self.print_program()
        for b in range(valuation.size(0)):
            print('===== BATCH: ', b, '=====')
            v = valuation[b].detach().cpu().numpy()
            idxs = np.argsort(-v)
            for i in idxs:
                if v[i] > 0.1:
                    print(i, self.atoms[i], ': ', round(v[i], 3))

    def get_valuation_text(self, valuation):
        text_batch = ''  # texts for each batch
        for b in range(valuation.size(0)):
            top_atoms = self.get_top_atoms(valuation[b].detach().cpu().numpy())
            text = '----BATCH ' + str(b) + '----\n'
            text += self.atoms_to_text(top_atoms)
            text += '\n'
            # texts.append(text)
            text_batch += text
        return text_batch

    def get_top_atoms(self, v):
        top_atoms = []
        for i, atom in enumerate(self.atoms):
            if v[i] > 0.7:
                top_atoms.append(atom)
        return top_atoms

    def atoms_to_text(self, atoms):
        text = ''
        for atom in atoms:
            text += str(atom) + ', '
        return text