import torch
import numpy as np
import torch.nn.functional as F

class NetEM_final:
    def __init__(self, dist="exp", delta=1, lr_A=0.0025, lr_B=0.0025, penl_svd=1, penl_l1_B=0.2,
                 penl_l1_A=0.2, hard_thres=0.01, max_Iter=1000, eps=1e-7, svd_freq=10, loss_freq=100,
                 batch_size=1000, rank_guess=5, oversample=20):
        """Pick the compute device and collect the hyperparameters in ``self.hyparams``.

        Args:
            dist: Transmission-time model name; the likelihood helpers branch on
                "exp", "ray" and "pow".
            delta: Offset used by the "pow" branches and by the ``Ind_pow`` indicator.
            lr_A: Step size of the update applied to A.
            lr_B: Step size of the update applied to B.
            penl_svd: Weight of the singular-value shrinkage applied to B.
            penl_l1_B: Weight of the element-wise soft threshold applied to B.
            penl_l1_A: Weight of the element-wise soft threshold applied to A.
            hard_thres: Small positive floor used when clamping hazard sums.
            max_Iter: Number of outer iterations run by ``em_optimize``.
            eps: Stored under the key ``'stopping'``; not read by the methods
                visible in this class.
            svd_freq: The singular-value shrinkage runs when ``k % svd_freq == 0``.
            loss_freq: The objective is recorded when ``k % loss_freq == 0``.
            batch_size: Number of cascades processed per mini-batch.
            rank_guess: Stored under ``'rank_guess'``; only referenced by the
                commented-out randomized-SVD call.
            oversample: Stored under ``'p'``; only referenced by the
                commented-out randomized-SVD call.
        """
        cuda_ok = torch.cuda.is_available()
        self.device = torch.device("cuda" if cuda_ok else "cpu")
        if cuda_ok:
            print("cuda device enabled")

        # Note the two renamed entries: eps -> 'stopping', oversample -> 'p'.
        self.hyparams = dict(
            dist=dist,
            delta=delta,
            lr_A=lr_A,
            lr_B=lr_B,
            penl_svd=penl_svd,
            penl_l1_B=penl_l1_B,
            penl_l1_A=penl_l1_A,
            hard_thres=hard_thres,
            max_Iter=max_Iter,
            stopping=eps,
            svd_freq=svd_freq,
            loss_freq=loss_freq,
            batch_size=batch_size,
            rank_guess=rank_guess,
            p=oversample,
        )

    def em_optimize(self, A_ini, B_ini, P_pathway_ini, P_pathway_ind_ini, supp_A, cascades, t):    
        # Move numpy arrays to PyTorch tensors on GPU
        A_update = torch.tensor(A_ini, device=self.device, dtype=torch.float32)
        B_update = torch.tensor(B_ini, device=self.device, dtype=torch.float32)
        P_pathway_update = torch.tensor(P_pathway_ini, device=self.device, dtype=torch.float32)
        P_pathway_ind = torch.tensor(P_pathway_ind_ini, device=self.device, dtype=torch.float32)
        supp_A = torch.tensor(supp_A, device=self.device, dtype=torch.int32)
        cascades = np.asarray(cascades, dtype=np.float32)

        nc, nd = cascades.shape 
        P_pathway_ind_sum = torch.zeros((nd), device=self.device, dtype=torch.float32)
        Gra_E_sum = torch.zeros((nd, nd), device=self.device, dtype=torch.float32)
        Gra_I_sum = torch.zeros((nd, nd), device=self.device, dtype=torch.float32)

        Likelihood_record = []
        
        for k in range(self.hyparams["max_Iter"]):
            epoch_loss = torch.tensor(0.0, device=self.device, dtype=torch.float32)

            P_pathway_ind_sum.fill_(0)
            Gra_E_sum.fill_(0)
            Gra_I_sum.fill_(0)

            B_update_use = B_update.clone()
            B_update_use.diagonal().zero_()

            for start in range(0, nc, self.hyparams['batch_size']):
                end = min(nc, start + self.hyparams['batch_size'])
                batch_cpu = torch.from_numpy(cascades[start:end])         # 零拷贝
                casc_batch = batch_cpu.to(self.device, non_blocking=True)

                infect_info = self.get_infection_info(casc_batch, t, self.hyparams['delta'])

                log_S_E, H_E = self.update_explicit_likelihood_component(A_update, self.hyparams["dist"], infect_info, self.hyparams['hard_thres'], self.hyparams["delta"])
                log_S_I, H_I = self.update_implicit_likelihood_component(B_update_use, self.hyparams["dist"], infect_info, self.hyparams['hard_thres'], self.hyparams["delta"])

                P_pathway_ind = self.e_step(P_pathway_update, log_S_E, log_S_I, H_E, H_I, infect_info[1])
                P_pathway_ind_sum += torch.sum(P_pathway_ind, dim=0)

                Gra_E_sum += self.get_Q1_grad(H_E, infect_info, P_pathway_ind, supp_A, 
                                self.hyparams["dist"], self.hyparams["hard_thres"], self.hyparams["delta"])
                
                Gra_I_sum += self.get_Q2_grad(H_I, infect_info, P_pathway_ind, 
                                    self.hyparams["dist"], self.hyparams["hard_thres"], self.hyparams["delta"])
                
                if k % self.hyparams['loss_freq'] == 0:
                    Q1 = self.get_Q1(
                    log_S_E, H_E, infect_info[1], P_pathway_ind, 
                    self.hyparams["hard_thres"])
                    
                    Q2 = self.get_Q2(
                    log_S_I, H_I, infect_info[1], P_pathway_ind,
                    self.hyparams["hard_thres"])

                    Q3 = self.get_Q3(P_pathway_ind, P_pathway_update, infect_info[5])

                    epoch_loss += Q1 + Q2 + Q3
                
            P_pathway_update = P_pathway_ind_sum / nc
                
            A_update = A_update + self.hyparams["lr_A"] * Gra_E_sum
            A_update = torch.sign(A_update) * torch.clamp(torch.abs(A_update) - self.hyparams["lr_A"] * self.hyparams["penl_l1_A"], min=0.0)
            A_update = torch.clamp(A_update, min=0)
            A_update.diagonal().zero_()

            B_update = B_update + self.hyparams["lr_B"] * Gra_I_sum
            if k % self.hyparams['svd_freq'] == 0:
                U, S, Vh = torch.linalg.svd(B_update, False)
                S = torch.clamp(S - self.hyparams["lr_B"] * self.hyparams["penl_svd"], min=0)
                B_update = U @ torch.diag(S) @ Vh
                # B_update = self.randomized_svd_shrink(
                # B_update,
                # rank=self.hyparams['rank_guess'],
                # lr_B=self.hyparams["lr_B"],
                # penl_svd=self.hyparams["penl_svd"],
                # p=self.hyparams['p'],
                # n_iter=1
                # )
            B_update = B_update = torch.sign(B_update) * torch.clamp(torch.abs(B_update) - self.hyparams["lr_B"] * self.hyparams["penl_l1_B"], min=0.0)
            B_update = torch.clamp(B_update, min=0)

            if k % self.hyparams['loss_freq'] == 0:
                epoch_loss -= self.hyparams["penl_l1_A"] * torch.norm(A_update.ravel(), p=1) + \
                                (self.hyparams["penl_svd"] * torch.norm(B_update, p='nuc') - 
                                self.hyparams["penl_l1_B"] * torch.norm(B_update.ravel(), p=1))
                
                Likelihood_record.append(epoch_loss)
            
        return {
            'A': A_update.cpu().numpy(),
            'B': B_update.cpu().numpy(),
            'B_use': B_update_use.cpu().numpy(),
            'p': P_pathway_update.cpu().numpy(),
            'loss': Likelihood_record
        }

    def get_infection_info(self, cascades, t, delta):
        nc, nd = cascades.shape
        infect_status = (cascades < t).float()  # [B, N]

        # 3. Delay 矩阵与截零
        Delay = cascades[:, :, None] - cascades[:, None, :]  # [B, N, N]
        # —— 关键：把正延迟（后感染）都设为 0
        Delay = Delay.clamp_max_(0.0)

        # 4. 指示矩阵
        temp_ind = (Delay < 0).float()  # Delay<0  
        Ind_1   = infect_status[:, :, None] * temp_ind
        Ind_2   = infect_status[:, :, None] * temp_ind * infect_status[:, None, :]
        Ind_pow = (Delay < -delta).float()

        # 5. 找 source：每条 cascade 中“延迟<0”的次数 >= N-1
        num_pre_infect = temp_ind.sum(dim=2)   # [B, N]
        cas_source = torch.argmax((num_pre_infect >= (nd-1)).int(), dim=1)  # [B]

        # 6. non_source_mask
        non_source_mask = torch.ones(nc, nd, dtype=torch.bool, device=self.device)
        non_source_mask[torch.arange(nc, device=self.device), cas_source] = False

        # 7. infected_mask：感染且非 source
        infected_mask = infect_status.bool() & non_source_mask  # [B, N]
        
        return Delay, infected_mask, Ind_1, Ind_2, Ind_pow, non_source_mask

    def update_explicit_likelihood_component(self, A_update, dist, infect_info, hard_thres, delta):
        Delay = infect_info[0]
        Ind_2 = infect_info[3]
        nc = Delay.shape[0]
        theta_E = A_update.unsqueeze(0)

        if dist == "exp":
            log_S_E = theta_E * Delay
            H_E = theta_E * Ind_2
        elif dist == "ray":
            log_S_E = theta_E * -(Delay) ** 2 / 2
            H_E = theta_E * Ind_2 * (-Delay)
        elif dist == "pow":
            Ind_pow = infect_info[4]
            Delay_S = Delay.clone()
            Delay_H = Delay.clone()
            mask = (Delay == 0)
            Delay_S[mask] = -delta
            Delay_H[mask] = -hard_thres
            log_S_E = theta_E * -torch.log(-Delay_S / delta) * Ind_pow
            H_E = theta_E * Ind_2 / (-Delay_H) * Ind_pow

        return log_S_E, H_E

    def update_implicit_likelihood_component(self, B_update_use, dist, infect_info, hard_thres, delta):
        Delay = infect_info[0]
        Ind_2 = infect_info[3]
        nc = Delay.shape[0]
        theta_I = B_update_use.unsqueeze(0)

        if dist == "exp":
            log_S_I = theta_I * Delay
            H_I = theta_I * Ind_2
        elif dist == "ray":
            log_S_I = theta_I * -(Delay) ** 2 / 2
            H_I = theta_I * Ind_2 * (-Delay)
        elif dist == "pow":
            Ind_pow = infect_info[4]
            Delay_S = Delay.clone()
            Delay_H = Delay.clone()
            mask = (Delay == 0)
            Delay_S[mask] = -delta
            Delay_H[mask] = -hard_thres
            log_S_I = theta_I * -torch.log(-Delay_S / delta) * Ind_pow
            H_I = theta_I * Ind_2 / (-Delay_H) * Ind_pow

        return log_S_I, H_I
    
    def get_Q1(self, log_S_E, H_E, infected_mask, P_pathway_ind, hard_thres):
        logSurv_E = torch.sum(log_S_E * P_pathway_ind.unsqueeze(1), dim=(1, 2))  
        a = torch.sum(H_E, dim=1).clamp_min_(hard_thres)
        hazard_vals = infected_mask * P_pathway_ind * torch.log(a)
        Haza_E = hazard_vals.sum(dim=1)

        return logSurv_E.sum() + Haza_E.sum()

    def get_Q2(self, log_S_I, H_I, infected_mask, P_pathway_ind, hard_thres):
        logSurv_I = torch.sum(log_S_I * (1 - P_pathway_ind).unsqueeze(1), dim=(1, 2))
        b = torch.sum(H_I, dim=1).clamp_min_(hard_thres)
        hazard_vals = infected_mask * (1 - P_pathway_ind) * torch.log(b)
        Haza_I = hazard_vals.sum(dim=1)

        return logSurv_I.sum() + Haza_I.sum()
    
    def get_Q3(self, P_pathway_ind, P_pathway_update, non_source_mask):
        nc = P_pathway_ind.shape[0]
        nd = P_pathway_update.shape[0]

        p = P_pathway_update.unsqueeze(0).expand(nc, nd)
        log_terms = (P_pathway_ind * torch.log(p) + (1 - P_pathway_ind) * torch.log(1 - p))
        log_terms = log_terms * non_source_mask
        
        return log_terms.sum()
    
    def e_step(self, P_pathway_update, log_S_E, log_S_I, H_E, H_I, infected_mask):
        nc = log_S_E.shape[0]
        S_ratio = torch.exp(torch.sum(log_S_I, dim=1) - torch.sum(log_S_E, dim=1))

        a = torch.sum(H_E, dim=1).clamp_min_(self.hyparams["hard_thres"])
        b = torch.sum(H_I, dim=1) / a
        H_ratio = torch.ones_like(b)
        H_ratio[infected_mask] = b[infected_mask]

        c = 1 / (1 + ((1 - P_pathway_update) / P_pathway_update).unsqueeze(0) * S_ratio * H_ratio)
        c = torch.nan_to_num(c, nan=self.hyparams['hard_thres'])

        return torch.abs(c - self.hyparams['hard_thres'])
    
    def get_Q1_grad(self, H_E, infect_info, P_pathway_ind, supp_A, dist, hard_thres, delta):
        Delay, Ind_1, Ind_2 = infect_info[0], infect_info[2], infect_info[3]
        a = torch.sum(Ind_2 * H_E, dim=1)
        a = a.clamp_min_(hard_thres)

        if dist == "exp":
            grad_S_E = Ind_1 * Delay * P_pathway_ind.unsqueeze(1)
            grad_H_E = Ind_2 * (P_pathway_ind / a).unsqueeze(1)
        elif dist == "ray":
            grad_S_E = Ind_1 * (-Delay**2 / 2) * P_pathway_ind.unsqueeze(1)
            grad_H_E = Ind_2 * (-Delay) * (P_pathway_ind / a).unsqueeze(1)
        elif dist == "pow":
            Ind_pow = infect_info[4]
            Delay_S, Delay_H = Delay.clone(), Delay.clone()
            mask = (Delay == 0)
            Delay_S[mask] = -delta
            Delay_H[mask] = -hard_thres
            grad_S_E = Ind_1 * -torch.log(-Delay_S / delta) * P_pathway_ind.unsqueeze(1) * Ind_pow
            grad_H_E = Ind_2 / (-Delay_H) * (P_pathway_ind / a).unsqueeze(1) * Ind_pow

        grad_E = (grad_S_E + grad_H_E) * supp_A
        grad_E = torch.sum(grad_E, dim=0)
        grad_E.diagonal().zero_()
        return grad_E

    def get_Q2_grad(self, H_I, infect_info, P_pathway_ind, dist, hard_thres, delta):
        Delay, Ind_1, Ind_2 = infect_info[0], infect_info[2], infect_info[3]
        a = torch.sum(Ind_2 * H_I, dim=1)
        a = a.clamp_min_(hard_thres)

        if dist == "exp":
            grad_S_I = Ind_1 * Delay * (1 - P_pathway_ind).unsqueeze(1)
            grad_H_I = Ind_2 * ((1 - P_pathway_ind) / a).unsqueeze(1)
        elif dist == "ray":
            grad_S_I = Ind_1 * (-Delay**2 / 2) * (1 - P_pathway_ind).unsqueeze(1)
            grad_H_I = Ind_2 * (-Delay) * ((1 - P_pathway_ind) / a).unsqueeze(1)
        elif dist == "pow":
            Ind_pow = infect_info[4]
            Delay_S, Delay_H = Delay.clone(), Delay.clone()
            mask = (Delay == 0)
            Delay_S[mask] = -delta
            Delay_H[mask] = -hard_thres
            grad_S_I = Ind_1 * -torch.log(-Delay_S / delta) * (1 - P_pathway_ind).unsqueeze(1) * Ind_pow
            grad_H_I = Ind_2 / (-Delay_H) * ((1 - P_pathway_ind) / a).unsqueeze(1) * Ind_pow

        grad_I = grad_S_I + grad_H_I
        grad_I = torch.sum(grad_I, dim=0)
        return grad_I
    
    def randomized_svd_shrink(self, B, rank, lr_B, penl_svd, p, n_iter=2):
        # B: [d,d], rank=k
        d = B.shape[0]
        l = rank + p

        # 1. 随机投影构造子空间
        Omega = torch.randn(d, l, device=self.device, dtype=torch.float32)
        Y = B @ Omega
        for _ in range(n_iter):
            Y = B @ (B.T @ Y)
        Q, _ = torch.linalg.qr(Y, mode='reduced')  # [d,l]

        # 2. 在子空间做 SVD
        B_tilde = Q.T @ B                           # [l,d]
        U_t, S, Vh = torch.linalg.svd(B_tilde, full_matrices=False)
        U = Q @ U_t                                 # [d,l]

        # 3. 核范数软阈值
        S_shrink = torch.clamp(S[:rank] - lr_B * penl_svd, min=0)
        U_k = U[:, :rank]   # [d,k]
        Vh_k = Vh[:rank, :] # [k,d]

        # 4. 重构
        return U_k @ torch.diag(S_shrink) @ Vh_k     # [d,d]