import torch
import torch.nn as nn
from torch.nn import functional as F
from tqdm import tqdm
import random
from device import device
from dataloader import DataCLUTRR
from reasoning_module import ReasoningModule

class ReasoningModel(nn.Module):
    """Induces rules by training stacks of ``ReasoningModule``s over sample graphs.

    Every sample graph is a tensor whose first axis indexes predicates and is
    zero-padded to ``wid_max`` rows by ``read_data``:

    * ``[0, dim_base)``           -- base predicates supplied by the dataloader;
    * ``[dim_base, dim_invent)``  -- rows written by ``fixed_modules``
      (``fixed_modules[k]`` writes row ``dim_base + k``);
    * ``[dim_invent, dim_train)`` -- rows written by the trainable ``cache``.

    The last cache module is trained for a target ``head`` predicate; the
    earlier cache modules are auxiliary (their head is -1).
    """

    def __init__(self, data: DataCLUTRR, fuzzy_induction=False,
                 merge_inductive_preds=True, dim_train=2,
                 dim_max=100, reweight='replace'):
        """Create an empty model bound to ``data``.

        Args:
            data: dataloader providing ``id2r``, ``r2id``, ``n_predicate`` and
                ``get_data`` (and ``ans_edge_st``/``ans_edge_ed`` in fuzzy mode).
            fuzzy_induction: use weighted noisy-OR back-propagation and
                argmax-based progress counting.
            merge_inductive_preds: after accepting a rule, zero the weight of
                samples whose target is then derived by inference.
            dim_train: number of modules in one training cache (``wid_train``).
            dim_max: number of predicate rows each graph is padded to (``wid_max``).
            reweight: stored only; no code in this class currently reads it.
        """
        super(ReasoningModel, self).__init__()
        self.dataloader = data
        self.predicate_id2str = data.id2r
        self.predicate_str2id = data.r2id
        self.fuzzy_induction = fuzzy_induction
        self.merge_inductive_preds = merge_inductive_preds
        self.randinit = False  # passed to new ReasoningModules; set after a rejected cache
        self.reweight = reweight

        self.fixed_modules = nn.ModuleList()  # accepted modules, in row order
        self.cache = nn.ModuleList()          # modules currently being trained

        self.data = list()  # samples [weight, target, query, graph] after read_data

        self.n_predicate_base = data.n_predicate
        self.n_predicate_invent = 0
        self.n_predicate_train = 0

        self.dim_base = self.n_predicate_base
        self.dim_invent = self.dim_base
        self.dim_train = self.dim_base
        self.wid_train = dim_train
        self.wid_max = dim_max

    def read_data(self, filename):
        """Load samples from ``filename`` and normalise them in place.

        Each sample gets a leading weight of 1.0, and its graph (last element)
        is zero-padded along the predicate axis from ``n_predicate_base`` to
        ``wid_max`` rows so invented predicates have room to be written.
        """
        self.data = self.dataloader.get_data(filename)
        for sample in self.data:
            sample.insert(0, 1.0)
            A = sample[-1]
            assert A.shape[0] == self.n_predicate_base
            padded = torch.zeros((self.wid_max, A.shape[1], A.shape[2]), device=device)
            padded[:A.shape[0]] = A
            sample[-1] = padded

    def generate_train_batch(self, batch_size=10, head=None):
        """Sample a positive/negative batch for one target predicate.

        If ``head`` is None it is taken from the first sample (in the current
        order of ``self.data``) with a non-zero weight. Shuffles ``self.data``
        in place.

        Returns:
            ``(head, pos, neg)`` where each entry of ``pos``/``neg`` is a
            ``(weight, target, query, graph)`` tuple. ``pos`` holds up to
            ``batch_size`` non-zero-weight samples with ``target == head``, and
            ``neg`` holds up to ``batch_size`` samples with a different target
            (of any weight). Returns ``None`` if no head can be chosen or no
            positive sample exists.
        """
        pos = list()
        neg = list()

        if head is None:
            for wt, target, query_edge, A in self.data:
                if wt != 0:
                    head = target
                    break
        if head is None:
            return None

        random.shuffle(self.data)
        if self.fuzzy_induction:
            # NOTE(review): a second shuffle adds no extra randomness; it is
            # kept so seeded runs consume the RNG exactly as before.
            random.shuffle(self.data)
        for wt, target, query_edge, A in self.data:
            if target != head and len(neg) < batch_size:
                neg.append((wt, target, query_edge, A))
            if wt == 0.0:
                continue
            if target == head and len(pos) < batch_size:
                pos.append((wt, target, query_edge, A))
            if len(pos) == batch_size and len(neg) == batch_size:
                break

        if not pos:
            return None
        return head, pos, neg

    def info(self, head, dim):
        """Build the metadata dict attached to a module (target head, output row)."""
        return {'head': head, 'dim': dim}

    def forward(self, graph, layernorm=False):
        """Run every cache module in order on a detached copy of ``graph``.

        Each module's output overwrites row ``module.info['dim']``, so later
        modules can read earlier ones. With ``layernorm`` the output is
        rescaled so its non-zero entries average to 1. An all-zero output is
        left unchanged instead of becoming 0/0 = NaN.
        """
        A = graph.clone().detach()
        for module in self.cache:
            pred = module(A)
            if layernorm:
                n_nonzero = pred.nonzero().shape[0]
                if n_nonzero > 0:
                    m = pred.sum() / n_nonzero
                    pred = pred / m
            A[module.info['dim']] = pred
        return A

    def init_training_cache(self, head=-1):
        """Ensure a cache of ``wid_train`` trainable modules follows ``dim_invent``.

        The cache is rebuilt whenever its width does not match, for example
        after ``check_cache`` cleared it or set ``dim_train = -1``. In that
        case the first ``wid_train - 1`` modules get head -1 and the last one
        gets ``head``. Otherwise only the last module's head is updated.
        """
        if self.dim_train - self.dim_invent != self.wid_train:
            self.cache = nn.ModuleList()
            for i in range(self.wid_train - 1):
                aux_dim = self.dim_invent + i
                self.cache.append(ReasoningModule(aux_dim, self.dim_base, self.info(-1, aux_dim), self.randinit, 1.0))
            self.dim_train = self.dim_invent + self.wid_train
            out_dim = self.dim_train - 1
            self.cache.append(ReasoningModule(out_dim, self.dim_base, self.info(head, out_dim), self.randinit, 1.0))
        elif head is not None:
            self.cache[-1].info['head'] = head

    def train_one_module(self, batch, epoch=100, lr=1e-1, norm=0.9):
        """Train the cache so its last module fires on positives but not negatives.

        The loss is ``-(weighted mean positive score) + (1 - norm) * (mean
        negative score)``. It is minimised with Adam for ``epoch`` steps over
        all parameters that require gradients. Afterwards the last cache
        module's parameters are fixed and its rule is printed.

        Returns:
            The last module of the cache.
        """
        head, pos, neg = batch
        self.init_training_cache(head)
        opt = torch.optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=lr)
        out_dim = self.dim_train - 1
        with tqdm(range(epoch), ncols=80) as _t:
            for ep in _t:
                scorep = torch.zeros((), device=device)
                scoren = torch.zeros((), device=device)
                sump = 0.0
                for wt, target, query, graph in pos:
                    pred = self.forward(graph)[out_dim][query]
                    scorep = scorep + wt*pred
                    sump += wt
                for wt, target, query, graph in neg:
                    pred = self.forward(graph)[out_dim][query]
                    scoren = scoren + pred
                scorep = scorep / sump
                if neg:
                    # Without negatives the mean would be tensor 0/0 = NaN,
                    # which poisons the loss and, after opt.step(), every parameter.
                    scoren = scoren / len(neg)
                loss = -(scorep+1e-10) + (1-norm)*(scoren+1e-10)

                _t.set_postfix_str('pos: {:.3f} neg:{:.3f}'.format(scorep.item(),scoren.item()))
                opt.zero_grad()
                loss.backward()
                opt.step()
        self.cache[-1].fix_parameters()
        print(self.cache[-1].output_model(self.predicate_id2str))
        return self.cache[-1]

    def check_cache(self, acc_threshold=0.0):
        """Evaluate the trained cache and, if good enough, move it into ``fixed_modules``.

        All cache modules have their parameters fixed. The last cache module
        is then evaluated on every sample: a firing prediction counts as
        correct if the sample's target equals the module's head, and as wrong
        otherwise.

        If nothing is correct, or accuracy is below ``acc_threshold``:
            the data is shuffled, ``randinit`` is enabled, and ``dim_train``
            is set to -1 so the next ``init_training_cache`` rebuilds the cache.

        Otherwise:
            - The last cache module and every cache module it transitively
              reads (via ``phi1``/``phi2``) are appended to ``fixed_modules``,
              packed into consecutive rows from ``dim_invent``.
            - The cache is cleared.
            - With ``merge_inductive_preds``, one inference pass is run on
              every stored graph in place, and samples whose target then fires
              get weight 0.

        ``self.reweight`` is not consulted: explained samples are always zeroed.
        """
        for module in self.cache:
            module.fix_parameters()

        correct = 0
        wrong = 0
        head = self.cache[-1].info['head']
        out_dim = self.dim_train - 1
        with torch.no_grad():
            with tqdm(self.data, ncols=80) as _t:
                _t.set_description_str('Checking correctness')
                for _, target, query, graph in _t:
                    pred = self.forward(graph)[out_dim][query]
                    if pred:
                        if target == head:
                            correct += 1
                        else:
                            wrong += 1
                    _t.set_postfix_str(str(correct)+'/'+str(correct+wrong))
        if correct == 0 or correct/(correct+wrong) < acc_threshold:
            random.shuffle(self.data)
            self.randinit = True
            self.dim_train = -1
            return
        accuracy = correct/(correct+wrong)
        self.cache[-1].info['accuracy'] = accuracy
        self.cache[-1].info['correct'] = correct
        self.cache[-1].info['wrong'] = wrong

        # Collect the cache rows the last module depends on (transitively).
        # Rows below dim_invent are base/fixed predicates and stay where they are.
        target_dims = set()
        stack = [out_dim]
        while stack:
            dim = stack.pop()
            if dim < self.dim_invent or dim in target_dims:
                continue
            target_dims.add(dim)
            module = self.cache[dim - self.dim_invent]
            stack.append(module.phi1)
            stack.append(module.phi2)

        # Append the needed modules in row order, packed into consecutive rows
        # starting at dim_invent; remember old row -> new row.
        first_new = len(self.fixed_modules)
        mapping = dict()
        for new_dim, dim in enumerate(sorted(target_dims), start=self.dim_invent):
            module = self.cache[dim - self.dim_invent]
            module.info['dim'] = new_dim
            self.fixed_modules.append(module)
            mapping[dim] = new_dim
        # Rewire the inputs of the newly fixed modules. fixed_modules is indexed
        # from 0 (entry k writes row dim_base + k), so the new entries start at
        # list index first_new, not at row dim_invent.
        for i in range(first_new, len(self.fixed_modules)):
            module = self.fixed_modules[i]
            if module.phi1 in mapping:
                module.phi1 = mapping[module.phi1]
            if module.phi2 in mapping:
                module.phi2 = mapping[module.phi2]
        self.dim_invent = self.dim_base + len(self.fixed_modules)
        self.dim_train = self.dim_invent
        self.cache = nn.ModuleList()

        # Merge the new predicate into the stored graphs and zero the weight
        # of samples it now explains.
        if self.merge_inductive_preds:
            with torch.no_grad():
                for sample in self.data:
                    _, target, query, graph = sample
                    self.inference_(graph, fast_train=False)
                    if graph[target][query] > 0:
                        sample[0] = 0.0
        self.randinit = False

    def train_model(self, max_try_time=1000, batch_size=10, epoch=100, lr=1e-1, norm=0.9):
        """Repeatedly sample a batch, train a cache, and try to accept it.

        Stops (without printing 'Finished') as soon as a fresh cache of
        ``wid_train`` rows would no longer fit below ``wid_max``. Also stops
        when no training batch can be formed, or after ``max_try_time``
        attempts.

        After each attempt, prints how many samples are covered. Outside fuzzy
        mode, this counts zero-weight samples. In fuzzy mode, it counts
        samples whose argmax over the answer rows
        ``[ans_edge_st, ans_edge_ed)`` at the query, after one inference
        pass, equals the target.
        """
        ans_st = getattr(self.dataloader, 'ans_edge_st', None)
        ans_ed = getattr(self.dataloader, 'ans_edge_ed', None)
        for ep in range(max_try_time):
            # The next cache writes rows dim_invent .. dim_invent + wid_train - 1,
            # which must all exist in the padded graphs.
            if self.dim_invent + self.wid_train > self.wid_max:
                return
            batch = self.generate_train_batch(batch_size)
            if batch is None:
                break
            self.train_one_module(batch, epoch, lr, norm)
            self.check_cache()
            if not self.fuzzy_induction:
                cnt = sum(1 for wt, _, _, _ in self.data if wt == 0)
            else:
                cnt = 0
                with torch.no_grad():
                    for _, target, query_edge, A in self.data:
                        graph = A.clone()
                        self.inference_(graph, max_loop=1)
                        best = graph[ans_st:ans_ed].argmax(dim=0)[query_edge]
                        if best == target - ans_st:
                            cnt += 1
            print(str(cnt)+'/'+str(len(self.data))+' dim: '+str(self.dim_invent)+' try: '+str(ep))
        print('Finished')

    def propagation_inv2base_(self, graph: torch.Tensor, origin_graph: torch.Tensor):
        """Push fixed invented predicates back into their head rows, in place.

        Base rows are cleared first. Each fixed module with a head (not -1)
        then ORs its row into ``graph[head]``: a noisy-OR scaled by
        ``module.weight`` in fuzzy mode, and ``FuzzyLogic.LOR`` otherwise.
        Finally, each base row becomes the element-wise max with
        ``origin_graph``, so the original base facts are never dropped.
        """
        with torch.no_grad():
            graph[:self.dim_base] = 0.0
            for module in self.fixed_modules:
                head = module.info['head']
                dim = module.info['dim']
                if head == -1:
                    continue
                if self.fuzzy_induction:
                    graph[head] = 1-(1-graph[head])*(1-graph[dim]*module.weight)
                else:
                    graph[head] = FuzzyLogic.LOR(graph[head], graph[dim])
            graph[:self.dim_base] = torch.max(origin_graph[:self.dim_base], graph[:self.dim_base])

    def propagation_base2inv_(self, graph: torch.Tensor):
        """Recompute every fixed invented row from the current graph, in place, in order."""
        with torch.no_grad():
            graph[self.dim_base:self.dim_invent] = 0.0
            for module in self.fixed_modules:
                pred = module(graph)
                graph[module.info['dim']] = pred

    def inference_(self, graph: torch.Tensor, max_loop=1, fast_train=False):
        """Forward-chain the fixed modules over ``graph`` in place.

        Each of up to ``max_loop`` rounds recomputes the invented rows, then
        pushes them back to the base rows. Rounds stop early when that
        back-propagation step changes nothing.

        With ``fast_train``, the newest fixed module's row is first added into
        its head row (clamped to 1). Inference then returns immediately if
        that changed nothing. This path requires at least one fixed module.
        """
        with torch.no_grad():
            base = graph[:self.dim_base].clone()  # original base facts
            if fast_train:
                graph_old = graph.clone()
                module = self.fixed_modules[-1]
                head = module.info['head']
                dim = module.info['dim']
                graph[head] = (graph[head]+graph[dim]).clamp(max=1.0)
                if (graph_old - graph).norm(p=1) == 0:
                    return
            for _ in range(max_loop):
                self.propagation_base2inv_(graph)
                graph_old = graph.clone()
                self.propagation_inv2base_(graph, base)
                if (graph_old - graph).norm(p=1) == 0:
                    return


                


class FuzzyLogic:
    """Product-based fuzzy connectives.

    Only ``+``, ``-`` and ``*`` are used, so the operands may be plain floats
    or tensors.
    """

    @staticmethod
    def LAND(p1, p2):
        """Fuzzy AND using the product t-norm: ``p1 * p2``."""
        conjunction = p1 * p2
        return conjunction

    @staticmethod
    def LOR(p1, p2):
        """Fuzzy OR using the probabilistic sum: ``p1 + p2 - p1 * p2``."""
        either = p1 + p2
        return either - FuzzyLogic.LAND(p1, p2)

            