import torch
import numpy as np
import time

from utils import batch_by_size, cal_ranks, cal_performance
from torch.optim import Adam, SGD, Adagrad
from torch.optim.lr_scheduler import ExponentialLR
from models import KGEModule

class BaseModel(object):
    """Training and link-prediction evaluation wrapper around ``KGEModule``.

    The wrapped model is moved to the GPU at construction time, and both
    ``train`` and ``test_link`` move their batches to the GPU with
    ``.cuda()``.
    """

    def __init__(self, n_ent, n_rel, args, struct):
        """Store the configuration, optionally load EHR features, build the model.

        Args:
            n_ent: number of entities.
            n_rel: number of relations.
            args: option namespace; it is stored as ``self.args`` and also
                handed to ``KGEModule``.
            struct: structure description forwarded to ``KGEModule``.

        Raises:
            AssertionError: if ``args.ehr_only`` is set but ``args.ehr_path``
                is missing or empty.
        """
        # args has to be stored first: everything below reads it.
        self.args = args
        self.n_ent = n_ent
        self.n_rel = n_rel
        self.time_tot = 0

        # Load EHR features if requested. This happens *before* the model is
        # built so that the constructor can already see ``ehr_feats_tensor``.
        if getattr(args, 'ehr_only', False):
            assert getattr(args, 'ehr_path', None), "--ehr_path is required for --ehr_only"
            # original author's note: expected shape [N, d_ehr] = [1238, 64]
            ehr_np = np.load(args.ehr_path)
            args.ehr_feats_tensor = torch.from_numpy(ehr_np).float()

        self.model = KGEModule(n_ent, n_rel, args, struct)
        self.model.cuda()

    def train(self, train_data, tester_val, tester_tst):
        """Train the model and periodically evaluate it.

        Args:
            train_data: ``(head, tail, rela)`` index tensors of equal length.
            tester_val: zero-argument callable returning four values
                (unpacked as mrr, mr, hits@1, hits@10) for the validation set.
            tester_tst: same as ``tester_val`` for the test set.

        Returns:
            ``(best_mrr, best_str)``. ``best_str`` is the space-separated
            numeric line "valid_mrr valid_h1 valid_h10 test_mrr test_h1
            test_h10" of the best validation epoch, or ``''`` when no
            evaluation ever improved on the initial ``best_mrr`` of 0.
        """
        args = self.args
        # The full training set is kept untouched; every epoch derives its
        # own shuffled (and possibly down-sampled) view from it.
        head_all, tail_all, rela_all = train_data
        n_train = len(head_all)

        kd_soft_all = None
        if getattr(args, 'kd', False) and getattr(args, 'kd_targets', None):
            # [n_train, n_rel] soft targets, row i belonging to triple i
            kd_soft_all = torch.from_numpy(np.load(args.kd_targets)).float().cuda()
            assert kd_soft_all.size(0) == n_train
        kd_only = getattr(args, 'kd_only', False)

        if args.optim in ('adam', 'Adam'):
            self.optimizer = Adam(self.model.parameters(), lr=args.lr)
        elif args.optim in ('adagrad', 'Adagrad'):
            self.optimizer = Adagrad(self.model.parameters(), lr=args.lr)
        else:
            self.optimizer = SGD(self.model.parameters(), lr=args.lr)

        scheduler = ExponentialLR(self.optimizer, args.decay_rate)

        n_epoch = args.n_epoch
        n_batch = args.n_batch
        # sim_rel_ids is a Python list of ints (set in train.py)
        sim_rel_ids = getattr(args, "sim_rel_ids", None)
        sim_keep_prob = float(getattr(args, "sim_keep_prob", 0.3))

        best_mrr = 0
        # Defined up front so both return paths work even when validation
        # never ran or never improved.
        best_str = ''

        for epoch in range(n_epoch):
            start = time.time()
            self.epoch = epoch

            # Fresh permutation of the *full* data set every epoch.
            rand_idx = torch.randperm(n_train)
            head = head_all[rand_idx].cuda()
            tail = tail_all[rand_idx].cuda()
            rela = rela_all[rand_idx].cuda()
            kd_soft = None
            if kd_soft_all is not None:
                kd_soft = kd_soft_all[rand_idx.to(kd_soft_all.device)]

            # --- Down-sample SIM edges if requested ---
            if sim_rel_ids:
                sim_ids = torch.tensor(sim_rel_ids, device=rela.device)
                is_sim = (rela.view(-1, 1) == sim_ids.view(1, -1)).any(dim=1)
                # Keep all non-SIM; keep SIM with probability sim_keep_prob (0..1)
                keep_sim = torch.rand_like(rela.float()) < sim_keep_prob
                keep = (~is_sim) | (is_sim & keep_sim)
                head, tail, rela = head[keep], tail[keep], rela[keep]
                if kd_soft is not None:
                    # keep the KD targets row-aligned with the kept triples
                    kd_soft = kd_soft[keep]

            n_kept = len(head)
            loss = None
            pos = 0

            # NOTE(review): assumes batch_by_size walks the first n_sample rows
            # in consecutive slices; the KD slicing below relies on that order.
            for h, t, r in batch_by_size(n_batch, head, tail, rela, n_sample=n_kept):
                self.model.zero_grad()

                # base loss (disabled when kd_only is set)
                base_loss = 0.0 if kd_only else self.model.forward(h, t, r)

                if kd_soft is not None:
                    B = h.size(0)
                    kd_batch = kd_soft[pos:pos + B]             # [B, n_rel]
                    pos += B
                    zS = self.model.pair_all_rel_logits(h, t)   # [B, n_rel]
                    kd_loss = torch.nn.functional.binary_cross_entropy_with_logits(zS, kd_batch)
                    loss = (1.0 - args.kd_alpha) * base_loss + args.kd_alpha * kd_loss
                else:
                    loss = base_loss

                loss = loss + args.lamb * self.model.regul
                loss.backward()
                self.optimizer.step()
                self.prox_operator()

            self.time_tot += time.time() - start
            scheduler.step()

            if (epoch + 1) % args.epoch_per_test == 0:
                valid_mrr, valid_mr, valid_1, valid_10 = map(float, tester_val())
                test_mrr,  test_mr,  test_1,  test_10  = map(float, tester_tst())

                out_pretty = (
                    f"$valid mrr:{valid_mrr:.4f}, H@1:{valid_1:.4f}, H@10:{valid_10:.4f}\t\t"
                    f"$test mrr:{test_mrr:.4f}, H@1:{test_1:.4f}, H@10:{test_10:.4f}"
                )
                out_numeric = (
                    f"{valid_mrr:.6f} {valid_1:.6f} {valid_10:.6f} "
                    f"{test_mrr:.6f} {test_1:.6f} {test_10:.6f}\n"
                )

                if args.mode != 'search':
                    print(out_pretty)

                # keep best and store NUMERIC string so train.py can write/parse safely
                if valid_mrr > best_mrr:
                    best_mrr = valid_mrr
                    best_str = out_numeric

                if best_mrr < args.thres:
                    print(f"\tearly stopped in Epoch:{epoch+1}, best_mrr:{best_mrr:.6f}", self.model.struct)
                    return best_mrr, best_str

            # Loss of the last batch of this epoch; skipped when the epoch
            # produced no batch at all.
            if loss is not None:
                print(f"Epoch {epoch}, loss = {loss.item()}")

        return best_mrr, best_str

    def prox_operator(self,):
        """Rescale rows of entity parameters so that no row has L2 norm > 1.

        Applies to every model parameter whose name contains ``'ent'``. Rows
        whose norm is already below 1 are left unchanged. The parameters are
        modified in place.
        """
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if 'ent' not in name:
                    continue
                row_norm = torch.norm(param, p=2, dim=1, keepdim=True)
                # dividing by max(norm, 1) only shrinks rows longer than 1
                param.div_(row_norm.clamp(min=1))

    def test_link(self, test_data, head_filter, tail_filter):
        """Evaluate head and tail prediction on ``test_data``.

        Args:
            test_data: ``(heads, tails, relas)`` index tensors.
            head_filter: array multiplied element-wise into the head scores
                (presumably a 0/1 mask with one row per test triple; confirm
                against the caller).
            tail_filter: same as ``head_filter`` for the tail scores.

        Returns:
            ``(mrr, mr, h1, h10)``, each the mean of the head-side and
            tail-side value returned by ``cal_performance``.
        """
        heads, tails, relas = test_data

        # --- Optional relation whitelist for evaluation (e.g., DDI-only) ---
        whitelist = getattr(self.args, "eval_rel_whitelist", None)
        if whitelist is not None and len(whitelist) > 0:
            relas_t = relas if isinstance(relas, torch.Tensor) else torch.LongTensor(relas)
            mask = torch.zeros_like(relas_t, dtype=torch.bool)
            for rid in whitelist:
                mask |= (relas_t == int(rid))
            # If mask drops everything, fall back to original to avoid empty slice
            if mask.any():
                heads, tails, relas = heads[mask], tails[mask], relas[mask]
                # Also shrink the filters to match the kept rows
                mask_np = mask.cpu().numpy()
                head_filter = head_filter[mask_np]
                tail_filter = tail_filter[mask_np]

        batch_size = self.args.test_batch_size
        n_test = len(heads)

        head_probs = []
        tail_probs = []
        # Pure inference: no autograd graph is needed for scoring.
        with torch.no_grad():
            for start in range(0, n_test, batch_size):
                end = min(start + batch_size, n_test)
                batch_h = heads[start:end].cuda()
                batch_t = tails[start:end].cuda()
                batch_r = relas[start:end].cuda()

                h_embed = self.model.ent_embed(batch_h)
                r_embed = self.model.rel_embed(batch_r)
                t_embed = self.model.ent_embed(batch_t)

                head_scores = torch.sigmoid(self.model.test_head(r_embed, t_embed))
                tail_scores = torch.sigmoid(self.model.test_tail(h_embed, r_embed))

                head_probs.append(head_scores.cpu().numpy())
                tail_probs.append(tail_scores.cpu().numpy())

        head_probs = np.concatenate(head_probs) * head_filter
        tail_probs = np.concatenate(tail_probs) * tail_filter
        head_ranks = cal_ranks(head_probs, label=heads.cpu().numpy())
        tail_ranks = cal_ranks(tail_probs, label=tails.cpu().numpy())
        h_mrr, h_mr, h_h1, h_h10 = cal_performance(head_ranks)
        t_mrr, t_mr, t_h1, t_h10 = cal_performance(tail_ranks)
        mrr = (h_mrr + t_mrr) / 2
        mr = (h_mr + t_mr) / 2
        h1 = (h_h1 + t_h1) / 2
        h10 = (h_h10 + t_h10) / 2
        return mrr, mr, h1, h10




