import numpy as np
import torch.nn as nn
import torch
from logic_utils import get_index_by_predname


class VFCReasoner(nn.Module):
    """Reasoner that chains a perception module, a facts converter and an
    inference module.

    The three sub-modules are stored as ``pm``, ``fc`` and ``im``; the ground
    atoms, background knowledge and clauses are kept alongside them so that
    valuations can be indexed and printed by atom.
    """

    def __init__(self, perception_module, facts_converter, infer_module, atoms, bk, clauses):
        """Store the sub-modules and the logic program components.

        Parameters
        ----------
        perception_module : callable used as ``pm(x)`` / ``pm.percept(x)``
        facts_converter : callable used as ``fc(zs, atoms, bk)`` /
            ``fc.convert(zs, atoms)``
        infer_module : callable used as ``im(v_0)`` / ``im.infer(v_0)``;
            ``print_program`` additionally reads its ``W`` attribute
        atoms : indexable collection of ground atoms; position ``i``
            corresponds to column ``i`` of a valuation
        bk : background knowledge forwarded to the facts converter
        clauses : indexable collection of clauses, used by ``print_program``
        """
        super().__init__()
        self.pm = perception_module
        self.fc = facts_converter
        self.im = infer_module
        self.atoms = atoms
        self.bk = bk
        self.clauses = clauses
        print("ATOMS: ", atoms)

    def get_params(self):
        """Return the inference module's parameters followed by the facts
        converter's, joined with ``+``."""
        return self.im.get_params() + self.fc.get_params()

    def forward(self, x):
        """Run perception, fact conversion (with background knowledge) and
        inference on ``x`` and return the resulting valuation."""
        zs = self.pm(x)
        v_0 = self.fc(zs, self.atoms, self.bk)
        # Debug aid: self.print_valuation_batch(v_0)
        return self.im(v_0)

    def infer(self, x):
        """Same pipeline as ``forward`` but through the sub-modules'
        ``percept`` / ``convert`` / ``infer`` entry points.

        x: raw image that consists of several objects
        """
        zs = self.pm.percept(x)
        # NOTE(review): unlike ``forward``, the background knowledge is not
        # passed to the converter here -- confirm this is intentional.
        v_0 = self.fc.convert(zs, self.atoms)
        # Debug aid: self.print_valuation_batch(v_0)
        return self.im.infer(v_0)

    def predict(self, v, predname):
        """Return the column of ``v`` that belongs to predicate ``predname``.

        v: batch * |atoms|
        """
        target_index = get_index_by_predname(
            pred_str=predname, atoms=self.atoms)
        return v[:, target_index]

    def predict_multi(self, v, prednames):
        """Return the columns of ``v`` for several predicates, side by side.

        v: batch * |atoms|

        The result has one row per batch element and one column per entry of
        ``prednames`` (in the given order), i.e. batch * len(prednames).
        """
        target_indices = [
            get_index_by_predname(pred_str=predname, atoms=self.atoms)
            for predname in prednames
        ]
        # Each selected column is (batch, 1); they must be joined along
        # dim=1.  Joining along dim=0 would stack different predicates on
        # the batch axis and yield (batch * len(prednames), 1).
        columns = [v[:, i].unsqueeze(-1) for i in target_indices]
        return torch.cat(columns, dim=1)

    def print_program(self):
        """Print a summary of the logic program by discretizing the
        continuous clause weights.

        ``self.im.W`` is softmaxed along dim 1; for every row the clause
        (from ``self.clauses``) with the largest weight is printed together
        with that weight.
        """
        print('====== LEARNED PROGRAM ======')
        weights = torch.softmax(self.im.W, 1)
        for i, row in enumerate(weights):
            max_i = np.argmax(row.detach().cpu().numpy())
            print('C_'+str(i)+': ',
                  self.clauses[max_i], row[max_i].detach().cpu().item())

    def print_valuation_batch(self, valuation, n=40):
        """Print the learned program, then for each batch element every atom
        whose value exceeds 0.1, in descending order of value.

        ``n`` is currently unused; it is kept so existing callers that pass
        it keep working.
        """
        self.print_program()
        for b in range(valuation.size(0)):
            print('===== BATCH: ', b, '=====')
            v = valuation[b].detach().cpu().numpy()
            # argsort of the negated values gives a descending order.
            for i in np.argsort(-v):
                if v[i] > 0.1:
                    print(i, self.atoms[i], ': ', round(v[i], 3))

    def get_valuation_text(self, valuation):
        """Return one string with, for each batch element, a header line
        followed by the text of its top atoms (see ``get_top_atoms``)."""
        return ''.join(
            '----BATCH ' + str(b) + '----\n'
            + self.atoms_to_text(
                self.get_top_atoms(valuation[b].detach().cpu().numpy()))
            + '\n'
            for b in range(valuation.size(0))
        )

    def get_top_atoms(self, v, threshold=0.7):
        """Return the atoms whose value in ``v`` is strictly greater than
        ``threshold`` (default 0.7), in the order of ``self.atoms``.

        ``v`` is indexed by atom position, so it needs at least as many
        entries as there are atoms.
        """
        return [atom for i, atom in enumerate(self.atoms) if v[i] > threshold]

    def atoms_to_text(self, atoms):
        """Return the string form of every atom, each followed by ``', '``
        (so a non-empty result ends with a trailing separator)."""
        return ''.join(str(atom) + ', ' for atom in atoms)
