import torch
import torch.nn as nn
from tqdm import tqdm
from module import *

class LogicModel(nn.Module):
    """Chain of logic submodules that read and write two state tensors.

    Each submodule carries an ``info`` dict. The keys used here are:

    * ``'arity'`` -- 1 if the module's output is stored in the first state
      tensor, 2 if it is stored in the second.
    * ``'dim'`` -- index along the first axis where the output is stored.
    * ``'head'`` -- ``-1`` for no head, ``-2`` to skip the module in
      ``forward``, or a non-negative index whose value is merged with the
      output via an element-wise maximum once the module is fixed.
    * ``'fixed'`` -- truthy once the module's parameters have been fixed.

    ``'dim'`` and ``'head'`` are written by ``add_module1`` / ``add_module2``;
    ``'arity'`` and ``'fixed'`` are expected to be provided by the submodule
    classes themselves (``LogicModule1`` / ``LogicModule2``, presumably
    supplied by ``from module import *`` -- verify against that module).
    """

    def __init__(self, dim1_occupied, dim2_occupied) -> None:
        super().__init__()
        # Next free slot index for arity-1 / arity-2 outputs.
        self.dim1 = dim1_occupied
        self.dim2 = dim2_occupied
        # Values the counters are reset to by ``fix_parameters``.
        self.dim_base1 = dim1_occupied
        self.dim_base2 = dim2_occupied
        # True once a module with a non-negative head has been added; no
        # further modules may be added until ``fix_parameters`` is called.
        self.closed = False

        self.submodules = nn.ModuleList()

    @staticmethod
    def _is_fixed(m) -> bool:
        """Return whether submodule ``m`` is marked as fixed (missing key -> False)."""
        return bool(m.info.get('fixed', False))

    def _require_open(self) -> None:
        """Raise ``AssertionError`` if the model is closed to new modules."""
        # Raised explicitly (instead of ``assert``) so the guard still runs
        # under ``python -O`` while callers see the same exception type.
        if self.closed:
            raise AssertionError(
                'model is closed; call fix_parameters() before adding modules')

    def _register(self, m, dim, head) -> None:
        """Record ``dim``/``head`` on ``m``, append it, and close the model if it has a head."""
        m.info['dim'] = dim
        m.info['head'] = head
        self.submodules.append(m)
        if head >= 0:
            self.closed = True

    def add_module1(self, head: int = -1) -> None:
        """Append a ``LogicModule1`` that writes to slot ``self.dim1``.

        Args:
            head: ``-1`` (default) for no head; a value ``>= 0`` marks the
                model as closed after this module is added.

        Raises:
            AssertionError: if the model is already closed.
        """
        self._require_open()
        # Build the module before touching any state so a failing constructor
        # cannot leave the model closed without a module having been added.
        m = LogicModule1(self.dim1, self.dim2)
        self._register(m, self.dim1, head)
        self.dim1 += 1

    def add_module2(self, head: int = -1) -> None:
        """Append a ``LogicModule2`` that writes to slot ``self.dim2``.

        Args:
            head: ``-1`` (default) for no head; a value ``>= 0`` marks the
                model as closed after this module is added.

        Raises:
            AssertionError: if the model is already closed.
        """
        self._require_open()
        m = LogicModule2(self.dim1, self.dim2)
        self._register(m, self.dim2, head)
        self.dim2 += 1

    def forward(self, a, A):
        """Run every submodule in order on detached copies of ``a`` and ``A``.

        The inputs are never modified. Modules whose ``info['head']`` is
        ``-2`` are skipped. Each remaining module is called with the current
        state, and its output is stored at ``info['dim']`` of the tensor
        selected by ``info['arity']``. For fixed modules with a non-negative
        head, the head slot additionally becomes the element-wise maximum of
        the output and the head's current value.

        Returns:
            Tuple ``(b, B)`` of the updated state tensors.

        Raises:
            RuntimeError: if a module reports an arity other than 1 or 2.
        """
        b = a.detach().clone()
        B = A.detach().clone()
        for m in self.submodules:
            info = m.info
            if info['head'] == -2:
                continue
            pred = m(b, B)
            arity = info['arity']
            if arity not in (1, 2):
                raise RuntimeError(
                    'unsupported module arity: {!r}'.format(arity))
            # Write into a fresh copy of whichever tensor is being updated.
            # The module above may have used the old tensor in its graph, and
            # an in-place write into it would make backward() fail.
            state = (b if arity == 1 else B).clone()
            state[info['dim']] = pred
            head = info['head']
            if self._is_fixed(m) and head >= 0:
                # Copy the current head value first so the in-place write
                # below does not touch a tensor that torch.max received.
                current = state[head].clone()
                state[head] = torch.max(pred, current)
            if arity == 1:
                b = state
            else:
                B = state
        return b, B

    def inference_(self, a, A) -> None:
        """Apply the fixed modules' ``logic`` to ``a`` / ``A`` in place until nothing changes.

        Runs under ``torch.no_grad()``. In each sweep every fixed module's
        ``logic(a, A)`` result is written at ``info['dim']`` of the tensor
        selected by its arity; when the module has a non-negative head, the
        head slot becomes the element-wise maximum of that result and the
        head's current value. Sweeps repeat until one leaves both tensors
        unchanged. Non-fixed modules are ignored. Nothing is returned.
        """
        with torch.no_grad():
            while True:
                a_old = a.clone()
                A_old = A.clone()
                for m in self.submodules:
                    if not self._is_fixed(m):
                        continue
                    info = m.info
                    pred = m.logic(a, A)
                    if info['arity'] == 1:
                        state = a
                    elif info['arity'] == 2:
                        state = A
                    else:
                        # Other arities are evaluated but written nowhere.
                        continue
                    dim = info['dim']
                    state[dim] = pred
                    head = info.get('head', -1)
                    if head >= 0:
                        state[head] = torch.max(state[dim], state[head])
                # Exact comparison: unlike ``(old - new).norm() == 0`` this
                # also terminates when entries are infinite (inf - inf is
                # NaN) and works for non-floating dtypes.
                if torch.equal(a_old, a) and torch.equal(A_old, A):
                    break

    def train_one_module(self, module_id: int, a, A, p, lr=1e-1, epoch=100) -> None:
        """Optimise all trainable parameters against the output of one submodule.

        The prediction is the slot written by ``self.submodules[module_id]``,
        taken from the first tensor returned by ``forward`` for an arity-1
        module and from the second one otherwise. The loss is
        ``log(sum((1-p)*pred)/sum(1-p)) - log(sum(p*pred)/sum(p))``.

        Args:
            module_id: index into ``self.submodules`` of the module to train.
            a, A: inputs passed to ``forward`` on every step.
            p: weights multiplied with the prediction; presumably a 0/1
                (or soft) label mask shaped like the prediction -- confirm
                with callers.
            lr: Adam learning rate.
            epoch: number of optimisation steps.
        """
        trainable = [q for q in self.parameters() if q.requires_grad]
        opt = torch.optim.Adam(trainable, lr=lr)
        target = self.submodules[module_id]
        # forward() stores arity-1 outputs in its first result and arity-2
        # outputs in its second, so read the prediction from the matching one.
        out_index = 0 if target.info['arity'] == 1 else 1
        dim = target.info['dim']
        with tqdm(range(epoch), ncols=80) as _t:
            for _ in _t:
                pred = self.forward(a, A)[out_index][dim]
                pos = (p * pred).sum()
                neg = ((1 - p) * pred).sum()
                loss = -(pos / p.sum()).log() + (neg / (1 - p).sum()).log()
                _t.set_postfix_str('POS: {:.2f} NEG: {:.2f}'.format(pos, neg))
                opt.zero_grad()
                loss.backward()
                opt.step()

    def fix_parameters(self) -> None:
        """Fix every submodule, reset the slot counters, and reopen the model.

        Raises:
            AssertionError: if the model is not closed. The check runs before
                anything is modified, so a failed call leaves the model and
                its submodules untouched.
        """
        if not self.closed:
            raise AssertionError(
                'fix_parameters() requires a closed model (add a module with head >= 0 first)')
        for m in self.submodules:
            m.fix_parameters()
        self.dim1 = self.dim_base1
        self.dim2 = self.dim_base2
        self.closed = False