import torch
import numpy as np
import time

from utils import batch_by_size, cal_ranks, cal_performance
from torch.optim import Adam, SGD, Adagrad
from torch.optim.lr_scheduler import ExponentialLR
from models import KGEModule

import torch
import torch.nn as nn
import torch.nn.functional as F

class FusionModule(nn.Module):
    """Gated fusion of a KG embedding and an EHR embedding of equal width.

    Both inputs are first passed through their own linear projection. A
    two-layer gate (Linear -> ReLU -> Linear -> Sigmoid) computed from the
    concatenated projections then mixes them element-wise::

        zf = g * Pk(zk) + (1 - g) * Pe(ze)

    Rows flagged as having no EHR embedding bypass the gate and return the
    projected KG embedding ``Pk(zk)`` unchanged.

    Args:
        d: Width of both input embeddings and of the fused output.
    """

    def __init__(self, d=64):
        super().__init__()
        # Creation order is kept stable so seeded initialisation and
        # state_dict keys do not change.
        self.Pk = nn.Linear(d, d)        # projection of the KG embedding
        self.Pe = nn.Linear(d, d)        # projection of the EHR embedding
        self.g1 = nn.Linear(2 * d, d)    # gate hidden layer on [zk ; ze]
        self.g2 = nn.Linear(d, d)        # gate output layer
        self.act = nn.ReLU()
        self.sig = nn.Sigmoid()

    def forward(self, zk, ze, has_ehr=None):
        """Fuse ``zk`` and ``ze``.

        Args:
            zk: KG embeddings, last dimension ``d``.
            ze: EHR embeddings, same shape as ``zk``.
            has_ehr: Optional mask with the shape of ``zk`` minus its last
                dimension; truthy where the EHR row is available. Any dtype
                that converts to bool is accepted (bool, 0/1 ints, floats).
                ``None`` means every row has an EHR embedding.

        Returns:
            Tensor shaped like ``zk`` holding the fused embeddings.
        """
        zk = self.Pk(zk)
        ze = self.Pe(ze)
        hidden = self.act(self.g1(torch.cat([zk, ze], dim=-1)))
        gate = self.sig(self.g2(hidden))                 # [B, d]
        zf = gate * zk + (1 - gate) * ze
        if has_ehr is not None:
            # Coerce to bool explicitly: `~` on an integer mask is a bitwise
            # NOT (giving -1 / -2), which would silently corrupt the blend.
            available = torch.as_tensor(
                has_ehr, dtype=torch.bool, device=zf.device
            ).unsqueeze(-1)
            # Select rather than blend arithmetically, so that non-finite
            # values in a masked-out fused row (0 * nan == nan) cannot leak
            # into the forward output.
            zf = torch.where(available, zf, zk)
        return zf

class BaseModel(object):
    """Training and link-prediction evaluation wrapper around ``KGEModule``.

    The wrapped model is expected to expose ``forward(h, t, r)`` returning a
    scalar loss, a ``regul`` attribute, a ``struct`` attribute, ``ent_embed``
    and ``rel_embed`` lookups, and ``test_head`` / ``test_tail`` scorers;
    these are the only members this class touches.

    Optionally, a ``FusionModule`` can be enabled (``enable_fusion``) so that
    evaluation scores use entity embeddings fused with an external EHR table.
    """

    def __init__(self, n_ent, n_rel, args, struct):
        """Build the KGE model and place it on the GPU when one is available.

        Args:
            n_ent: Number of entities.
            n_rel: Number of relations.
            args: Hyper-parameter namespace (``optim``, ``lr``, ``decay_rate``,
                ``n_epoch``, ``n_batch``, ``lamb``, ``epoch_per_test``,
                ``thres``, ``test_batch_size`` are read by this class).
            struct: Passed through to ``KGEModule`` unchanged.
        """
        self.model = KGEModule(n_ent, n_rel, args, struct)
        # Every other method follows the model's device, so falling back to
        # CPU here keeps the class usable on hosts without CUDA.
        if torch.cuda.is_available():
            self.model.cuda()

        self.n_ent = n_ent
        self.n_rel = n_rel
        self.time_tot = 0   # accumulated wall-clock training time in seconds
        self.args = args

        # Fusion state; populated by enable_fusion().
        self.fusion = None
        self.Ek_full = None
        self.Ee_full = None
        self.ehr_mask = None

    def _device(self):
        """Return the device holding the model's parameters (CPU if none)."""
        try:
            return next(self.model.parameters()).device
        except StopIteration:
            return torch.device('cpu')

    def train(self, train_data, tester_val, tester_tst):
        """Train the model and track the best validation MRR.

        Args:
            train_data: ``(head, tail, rela)`` index tensors of equal length.
            tester_val: Zero-argument callable returning
                ``(mrr, mr, hits@1, hits@10)`` on the validation split.
            tester_tst: Same as ``tester_val`` for the test split.

        Returns:
            ``(best_mrr, best_str)``: the best validation MRR seen and the
            report line produced at that evaluation. ``best_str`` is ``''``
            when no evaluation improved on the initial ``best_mrr`` of 0.

        Note:
            The training loss comes from ``self.model.forward`` only. The
            fusion module's parameters are handed to the optimizer when
            fusion was enabled before this call, but nothing in this loop
            routes a loss through them.
        """
        head, tail, rela = train_data
        n_train = len(head)
        dev = self._device()

        params = list(self.model.parameters())
        if self.fusion is not None:
            params += list(self.fusion.parameters())

        # Anything other than 'adam' / 'adagrad' falls back to SGD.
        optim_cls = {'adam': Adam, 'adagrad': Adagrad}.get(
            self.args.optim.lower(), SGD
        )
        self.optimizer = optim_cls(params, lr=self.args.lr)
        scheduler = ExponentialLR(self.optimizer, self.args.decay_rate)

        n_epoch = self.args.n_epoch
        n_batch = self.args.n_batch
        best_mrr = 0
        # Initialised up front so both return paths are always well-defined.
        best_str = ''

        for epoch in range(n_epoch):
            start = time.time()

            self.epoch = epoch
            # The triples are re-shuffled on top of the previous epoch's
            # order, which is still a uniformly random permutation.
            rand_idx = torch.randperm(n_train)
            head = head[rand_idx].to(dev)
            tail = tail[rand_idx].to(dev)
            rela = rela[rand_idx].to(dev)

            last_loss = None
            for h, t, r in batch_by_size(n_batch, head, tail, rela, n_sample=n_train):
                # Clear gradients of everything the optimizer owns, which
                # includes the fusion parameters when present.
                self.optimizer.zero_grad()

                loss = self.model.forward(h, t, r)
                loss = loss + self.args.lamb * self.model.regul
                loss.backward()
                self.optimizer.step()
                # Project entity embeddings back into the unit L2 ball.
                self.prox_operator()
                last_loss = loss.detach()

            self.time_tot += time.time() - start
            scheduler.step()

            if (epoch + 1) % self.args.epoch_per_test == 0:
                valid_mrr, _, valid_1, valid_10 = tester_val()
                test_mrr, _, test_1, test_10 = tester_tst()
                out_str = '$valid mrr:%.4f, H@1:%.4f, H@10:%.4f\t\t$test mrr:%.4f, H@1:%.4f, H@10:%.4f\n'%(valid_mrr, valid_1, valid_10, test_mrr, test_1, test_10)

                if valid_mrr > best_mrr:
                    best_mrr = valid_mrr
                    best_str = out_str
                # Abort unpromising runs: the best MRR so far is still below
                # the configured threshold at an evaluation point.
                if best_mrr < self.args.thres:
                    print('\tearly stopped in Epoch:{}, best_mrr:{}'.format(epoch+1, best_mrr), self.model.struct)
                    return best_mrr, best_str

            # Reports the loss of the last batch; skipped when the epoch
            # produced no batch at all.
            if last_loss is not None:
                print(f"Epoch {epoch}, loss = {last_loss.item()}")
        return best_mrr, best_str

    def prox_operator(self,):
        """Rescale rows of entity parameters so their L2 norm is at most 1.

        Applies to every model parameter whose name contains ``'ent'``. Rows
        with norm below 1 are left unchanged. The update is done in place and
        outside autograd.
        """
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if 'ent' not in name:
                    continue
                # Clamping the divisor to >= 1 only ever shrinks rows.
                norms = torch.norm(param, p=2, dim=1, keepdim=True).clamp_(min=1)
                param.copy_((param / norms).view(self.n_ent, -1))

    def test_link(self, test_data, head_filter, tail_filter):
        """Evaluate head and tail prediction on ``test_data``.

        Args:
            test_data: ``(heads, tails, relas)`` index tensors of equal length.
            head_filter: Array multiplied element-wise into the head scores;
                must broadcast against ``[n_test, n_candidates]``
                (presumably a 0/1 filtered-setting mask — confirm with caller).
            tail_filter: Same as ``head_filter`` for the tail scores.

        Returns:
            ``(mrr, mr, hits@1, hits@10)``, each averaged over the head and
            tail prediction tasks.
        """
        heads, tails, relas = test_data
        batch_size = self.args.test_batch_size
        n_test = len(heads)
        dev = self._device()

        head_probs = []
        tail_probs = []
        # Evaluation needs no autograd graph; skipping it saves memory.
        with torch.no_grad():
            for start in range(0, n_test, batch_size):
                end = min(start + batch_size, n_test)
                batch_h = heads[start:end].to(dev)
                batch_t = tails[start:end].to(dev)
                batch_r = relas[start:end].to(dev)

                if self.fusion is None:
                    h_embed = self.model.ent_embed(batch_h)
                    t_embed = self.model.ent_embed(batch_t)
                else:
                    h_embed = self.fused_lookup(batch_h)
                    t_embed = self.fused_lookup(batch_t)
                r_embed = self.model.rel_embed(batch_r)

                head_scores = torch.sigmoid(self.model.test_head(r_embed, t_embed))
                tail_scores = torch.sigmoid(self.model.test_tail(h_embed, r_embed))

                head_probs.append(head_scores.cpu().numpy())
                tail_probs.append(tail_scores.cpu().numpy())

        head_probs = np.concatenate(head_probs) * head_filter
        tail_probs = np.concatenate(tail_probs) * tail_filter
        head_ranks = cal_ranks(head_probs, label=heads.detach().cpu().numpy())
        tail_ranks = cal_ranks(tail_probs, label=tails.detach().cpu().numpy())
        h_mrr, h_mr, h_h1, h_h10 = cal_performance(head_ranks)
        t_mrr, t_mr, t_h1, t_h10 = cal_performance(tail_ranks)
        mrr = (h_mrr + t_mrr) / 2
        mr = (h_mr + t_mr) / 2
        h1 = (h_h1 + t_h1) / 2
        h10 = (h_h10 + t_h10) / 2
        return mrr, mr, h1, h10

    def enable_fusion(self, Ee_full: torch.Tensor, ehr_mask: torch.Tensor):
        """Attach an EHR embedding table and create the fusion module.

        Call this before ``train`` if the fusion parameters should be handed
        to the optimizer.

        Args:
            Ee_full: ``[n_ent, d]`` EHR table (zeros for non-overlap rows),
                where ``d`` is the width of the model's entity embeddings.
            ehr_mask: ``[n_ent]`` mask, truthy where ``Ee_full`` is available.

        Raises:
            ValueError: If ``Ee_full`` is not 2-D with ``d`` columns.
        """
        dev = self._device()
        # detach() shares storage with the live embedding table, so later
        # in-place updates to it are visible here, while no gradient flows
        # back into the KG embeddings through the fusion path.
        self.Ek_full = self.model.ent_embed.weight.detach()   # [n_ent, d]
        d = self.Ek_full.size(1)
        if Ee_full.dim() != 2 or Ee_full.size(1) != d:
            raise ValueError(
                'Ee_full must have shape [n_ent, {}], got {}'.format(
                    d, tuple(Ee_full.shape)
                )
            )
        # Match the KG table's dtype so both feed the same Linear layers.
        self.Ee_full = Ee_full.to(device=dev, dtype=self.Ek_full.dtype)
        self.ehr_mask = ehr_mask.to(dev).bool()
        self.fusion = FusionModule(d=d).to(dev)

    def fused_lookup(self, ent_ids: torch.Tensor):
        """Return fused embeddings for ``ent_ids``.

        Args:
            ent_ids: Integer tensor of entity ids, any shape.

        Returns:
            Tensor of shape ``ent_ids.shape + (d,)``.

        Raises:
            RuntimeError: If ``enable_fusion`` has not been called.
        """
        if self.fusion is None:
            raise RuntimeError('fused_lookup() requires enable_fusion() to be called first')
        # reshape (not view) so non-contiguous id tensors are accepted too.
        flat = ent_ids.reshape(-1)
        fused = self.fusion(
            self.Ek_full[flat], self.Ee_full[flat], self.ehr_mask[flat]
        )
        return fused.view(*ent_ids.shape, -1)





