import torch
import numpy as np

class NetRate_batch:
    def __init__(self, dist="exp", delta=1, lr=0.0001, penl_l1 = 10, hard_thres=0.01, eps=0.001, batch_size=1000, max_Iter=100):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if torch.cuda.is_available():
            print("cuda device enabled")

        self.hyparams = {
            'dist': dist,
            'delta': delta,
            'lr': lr,
            'penl_l1': penl_l1,
            'hard_thres': hard_thres,
            'stopping': eps,
            'batch_size': batch_size,
            'max_Iter': max_Iter
        }
    
    def get_infection_info(self, cascades, t, delta):
        nc, nd = cascades.shape
        infect_status = (cascades < t).float()  # [B, N]

        # 3. Delay 矩阵与截零
        Delay = cascades[:, :, None] - cascades[:, None, :]  # [B, N, N]
        # —— 关键：把正延迟（后感染）都设为 0
        Delay = Delay.clamp_max_(0.0)

        # 4. 指示矩阵
        temp_ind = (Delay < 0).float()  # Delay<0  
        Ind_1   = infect_status[:, :, None] * temp_ind
        Ind_2   = infect_status[:, :, None] * temp_ind * infect_status[:, None, :]
        Ind_pow = (Delay < -delta).float()

        # 5. 找 source：每条 cascade 中“延迟<0”的次数 >= N-1
        num_pre_infect = temp_ind.sum(dim=2)   # [B, N]
        cas_source = torch.argmax((num_pre_infect >= (nd-1)).int(), dim=1)  # [B]

        # 6. non_source_mask
        non_source_mask = torch.ones(nc, nd, dtype=torch.bool, device=self.device)
        non_source_mask[torch.arange(nc, device=self.device), cas_source] = False

        # 7. infected_mask：感染且非 source
        infected_mask = infect_status.bool() & non_source_mask  # [B, N]
        
        return Delay, infected_mask, Ind_1, Ind_2, Ind_pow
    
    def optimize(self, A_ini, cascades, t):
        # Move numpy arrays to PyTorch tensors on GPU
        A_update = torch.tensor(A_ini, device=self.device, dtype=torch.float32)
        cascades = np.asarray(cascades, dtype=np.float32)

        Likelihood_record = []
        nc, nd = cascades.shape
        Iter = 0

        Gra_sum = torch.zeros((nd, nd), device=self.device, dtype=torch.float32)
        likelihood_sum = 0.0

        for start in range(0, nc, self.hyparams['batch_size']):
            end = min(nc, start + self.hyparams['batch_size'])
            casc_batch = torch.from_numpy(cascades[start:end])      # CPU
            casc_batch = casc_batch.to(self.device, non_blocking=True)  # GPU

            infect_info = self.get_infection_info(casc_batch, t, self.hyparams['delta'])

            logSurv, Haza = self.update_likelihood_component(A_update, self.hyparams["dist"], infect_info, self.hyparams['hard_thres'], self.hyparams["delta"])
            
            likelihood_sum += self.get_loss(logSurv, Haza, infect_info[1], self.hyparams["hard_thres"])

        likelihood_sum -= self.hyparams["penl_l1"] * torch.norm(A_update.ravel(), p=1)
        Likelihood_record.append(likelihood_sum)
        Likelihood = likelihood_sum.clone()

        while ((likelihood_sum - Likelihood) / torch.abs(Likelihood)) > self.hyparams["stopping"] or (Iter == 0):
            Likelihood = likelihood_sum.clone()
            Gra_sum = torch.zeros((nd, nd), device=self.device, dtype=torch.float32)
            likelihood_sum = 0.0

            for start in range(0, nc, self.hyparams['batch_size']):
                end = min(nc, start + self.hyparams['batch_size'])
                casc_batch = torch.from_numpy(cascades[start:end])      # CPU
                casc_batch = casc_batch.to(self.device, non_blocking=True)  # GPU
                infect_info = self.get_infection_info(casc_batch, t, self.hyparams['delta'])

                logSurv, Haza = self.update_likelihood_component(A_update, self.hyparams["dist"], infect_info, self.hyparams['hard_thres'], self.hyparams["delta"])

                Gra_sum += self.get_grad(Haza, infect_info, 
                                self.hyparams["dist"], self.hyparams["hard_thres"], self.hyparams["delta"])
                
                likelihood_sum += self.get_loss(logSurv, Haza, infect_info[1], self.hyparams["hard_thres"])

            # Update explicit and implicit network
            A_update += self.hyparams["lr"] * Gra_sum
            A_update = torch.sign(A_update) * torch.clamp(
                torch.abs(A_update) - self.hyparams["lr"] * self.hyparams["penl_l1"], min=0)
            A_update = torch.clamp(A_update, min=0)
            A_update.diagonal().zero_()

            # Update Q1
            likelihood_sum -= self.hyparams["penl_l1"] * torch.norm(A_update.ravel(), p=1)
            
            Likelihood_record.append(likelihood_sum)
            Iter += 1

            if Iter >= self.hyparams['max_Iter']:
              break

            print(Iter)

        return {
            'A': A_update.cpu().numpy(),
            'loss': Likelihood_record
        }
    
    def get_grad(self, Haza, infect_info, dist, hard_thres, delta):
        Delay, Ind_1, Ind_2 = infect_info[0], infect_info[2], infect_info[3]
        a = torch.sum(Ind_2 * Haza, dim=1)
        a = a.clamp_min_(hard_thres)

        if dist == "exp":
            grad_S = Ind_1 * Delay
            grad_H = Ind_2 * (1 / a).unsqueeze(1)
        elif dist == "ray":
            grad_S = Ind_1 * (-Delay**2/2)
            grad_H = Ind_2 * (-Delay) * (1 / a).unsqueeze(1)
        elif dist == "pow":
            Ind_pow = infect_info[4]
            Delay_S = Delay.clone(); Delay_H = Delay.clone()
            mask = (Delay == 0)
            Delay_S[mask] = -delta; Delay_H[mask] = -hard_thres
            grad_S = Ind_1 * -torch.log(-Delay_S / delta) * Ind_pow
            grad_H = Ind_2 / (-Delay_H) * (1 / a).unsqueeze(1) * Ind_pow

        grad = (grad_S + grad_H)
        grad = torch.sum(grad, dim=0)
        grad.diagonal().zero_()
        return grad
    
    def update_likelihood_component(self, A_update, dist, infect_info, hard_thres, delta):
        Delay = infect_info[0]; Ind_2 = infect_info[3]
        nc = Delay.shape[0]
        theta = A_update.unsqueeze(0)
        
        if dist == "exp":
            logSurv = theta * Delay
            Haza = theta * Ind_2
        elif dist == "ray":
            logSurv = theta * -(Delay)**2/2
            Haza = theta * Ind_2 * (-Delay)
        elif dist == "pow":
            Ind_pow = infect_info[4]
            Delay_S = Delay.clone(); Delay_H = Delay.clone()
            mask = (Delay == 0)
            Delay_S[mask] = -delta; Delay_H[mask] = -hard_thres
            logSurv = theta * -torch.log(-Delay_S / delta) * Ind_pow
            Haza = theta * Ind_2 / (-Delay_H) * Ind_pow
        return logSurv, Haza

    def get_loss(self, logSurv, Haza, infected_mask, hard_thres):
        a = torch.sum(Haza, dim=1).clamp_min_(hard_thres)
        hazard_vals = infected_mask * torch.log(a)
        H = hazard_vals.sum(dim=1)

        return logSurv.sum() + H.sum()