import torch
import numpy as np

class MMRate_torch:
    """EM estimator for an explicit (``A``) and an implicit (``B``) diffusion network.

    Every cascade is treated as having spread either through the explicit
    network ``A`` or through the implicit network ``B``; ``p`` is the prior
    probability of the explicit pathway.  The E-step computes, per cascade,
    the posterior weight of the explicit pathway.  The M-step then takes one
    proximal gradient-ascent step:

    * on ``A`` with an L1 penalty, projection onto ``A >= 0`` and a zero
      diagonal;
    * on ``B`` with a nuclear-norm and an L1 penalty and projection onto
      ``B >= 0`` (the diagonal is zeroed only in ``B_use``, the copy used in
      the likelihood).

    Parameters
    ----------
    dist : {"exp", "ray", "pow"}
        Transmission-delay model (exponential, Rayleigh, power law).
    delta : float
        Power-law cut-off; only delays strictly longer than ``delta`` enter
        the ``"pow"`` model.
    lr_A, lr_B : float
        Gradient-ascent step sizes for ``A`` and ``B``.
    penl_svd : float
        Nuclear-norm penalty weight on ``B``.
    penl_l1_B, penl_l1_A : float
        L1 penalty weights on ``B`` and ``A``.
    hard_thres : float
        Lower bound applied to hazard sums before taking ``log`` / dividing,
        and offset used to keep E-step posteriors away from 0.
    max_Iter : int
        Number of EM iterations (always run in full).
    eps : float
        Stored as ``hyparams['stopping']``; currently not used (see NOTE).
    """

    _DISTS = ("exp", "ray", "pow")

    def __init__(self, dist="exp", delta=1, lr_A=0.0025, lr_B=0.0025, penl_svd=1, penl_l1_B=0.2,
                 penl_l1_A=0.2, hard_thres=0.01, max_Iter=1000, eps=1e-7):
        self._check_dist(dist)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if torch.cuda.is_available():
            print("cuda device enabled")

        self.hyparams = {
            'dist': dist,
            'delta': delta,
            'lr_A': lr_A,
            'lr_B': lr_B,
            'penl_svd': penl_svd,
            'penl_l1_B': penl_l1_B,
            'penl_l1_A': penl_l1_A,
            'hard_thres': hard_thres,
            'max_Iter': max_Iter,
            # NOTE(review): stored but never read -- the EM loop always runs
            # max_Iter iterations and performs no convergence test.
            'stopping': eps
        }

    # ------------------------------------------------------------------ helpers

    @classmethod
    def _check_dist(cls, dist):
        """Raise ``ValueError`` unless *dist* is a supported delay model."""
        if dist not in cls._DISTS:
            raise ValueError(f"dist must be one of {cls._DISTS}, got {dist!r}")

    @staticmethod
    def _zero_diagonal(M):
        """Return a new tensor equal to square matrix *M* with its diagonal set to 0."""
        return M - torch.diag_embed(torch.diagonal(M))

    @staticmethod
    def _soft_threshold(X, thres):
        """Proximal operator of ``thres * ||X||_1``: shrink every entry toward 0 by *thres*."""
        return torch.sign(X) * torch.clamp(torch.abs(X) - thres, min=0)

    @staticmethod
    def _pow_delays(Delay, hard_thres, delta):
        """Copies of *Delay* whose zero entries are replaced for the power-law model.

        Zeros become ``-delta`` in the survival copy (so ``log(-D/delta)`` is 0)
        and ``-hard_thres`` in the hazard copy (so ``1/(-D)`` is finite).
        """
        mask = Delay == 0
        return Delay.masked_fill(mask, -delta), Delay.masked_fill(mask, -hard_thres)

    def _A_penalty(self, A):
        """L1 penalty subtracted from Q1."""
        return self.hyparams['penl_l1_A'] * torch.norm(A.ravel(), p=1)

    def _B_penalty(self, B):
        """Nuclear-norm plus L1 penalty subtracted from Q2."""
        return (self.hyparams['penl_svd'] * torch.norm(B, p='nuc')
                + self.hyparams['penl_l1_B'] * torch.norm(B.ravel(), p=1))

    def _likelihood_component(self, theta, dist, infect_info, hard_thres, delta):
        """Return ``(log_S, H)`` for rate matrix *theta* under delay model *dist*.

        ``log_S`` holds per-pair log-survival terms and ``H`` per-pair hazard
        terms (non-zero only where ``Ind_2`` is set); both have the shape of
        ``Delay``, with *theta* broadcast over the leading cascade axis.
        """
        Delay, Ind_2 = infect_info[0], infect_info[3]
        if dist == "exp":
            return theta * Delay, theta * Ind_2
        if dist == "ray":
            return theta * -(Delay) ** 2 / 2, theta * Ind_2 * (-Delay)
        if dist == "pow":
            Ind_pow = infect_info[4]
            Delay_S, Delay_H = self._pow_delays(Delay, hard_thres, delta)
            log_S = theta * -torch.log(-Delay_S / delta) * Ind_pow
            H = theta * Ind_2 / (-Delay_H) * Ind_pow
            return log_S, H
        self._check_dist(dist)  # always raises here

    def _likelihood_gradient(self, H, infect_info, weights, dist, hard_thres, delta):
        """Per-cascade gradient of a weighted Q term w.r.t. its rate matrix.

        *weights* is the per-cascade pathway weight (``P`` for Q1, ``1 - P``
        for Q2); *H* is the current hazard tensor used to normalise the
        log-hazard derivative.  The result still carries the cascade axis.
        """
        Delay, Ind_1, Ind_2 = infect_info[0], infect_info[2], infect_info[3]
        a = torch.sum(Ind_2 * H, dim=1).clamp_min_(hard_thres)
        w_S = weights.unsqueeze(1)
        w_H = (weights / a).unsqueeze(1)

        if dist == "exp":
            grad_S = Ind_1 * Delay * w_S
            grad_H = Ind_2 * w_H
        elif dist == "ray":
            grad_S = Ind_1 * (-Delay**2 / 2) * w_S
            grad_H = Ind_2 * (-Delay) * w_H
        elif dist == "pow":
            Ind_pow = infect_info[4]
            Delay_S, Delay_H = self._pow_delays(Delay, hard_thres, delta)
            grad_S = Ind_1 * -torch.log(-Delay_S / delta) * w_S * Ind_pow
            grad_H = Ind_2 / (-Delay_H) * w_H * Ind_pow
        else:
            self._check_dist(dist)  # always raises here
        return grad_S + grad_H

    @staticmethod
    def _weighted_Q(log_S, H, infected_mask, weights, hard_thres):
        """Weighted log-likelihood: survival terms plus log-hazard of infected non-source nodes."""
        log_surv = torch.sum(log_S * weights.unsqueeze(1), dim=(1, 2))
        hazard_sum = torch.sum(H, dim=1).clamp_min_(hard_thres)
        log_haz = (infected_mask * weights * torch.log(hazard_sum)).sum(dim=1)
        return log_surv.sum() + log_haz.sum()

    # --------------------------------------------------------------- public API

    def em_optimize(self, A_ini, B_ini, P_pathway_ini, P_pathway_ind_ini, supp_A, cascades, t):
        """Fit the explicit and implicit networks with EM.

        Parameters
        ----------
        A_ini, B_ini : array_like, shape (nd, nd)
            Initial explicit / implicit rate matrices.
        P_pathway_ini : float or array_like
            Initial prior probability of the explicit pathway.
        P_pathway_ind_ini : array_like, shape (nc,)
            Initial per-cascade weight of the explicit pathway.
        supp_A : array_like, shape (nd, nd)
            Mask multiplied into the gradient of ``A`` (0 entries get no gradient).
        cascades : ndarray, shape (nc, nd)
            Infection time of every node in every cascade.
        t : float
            Observation horizon; nodes with time ``>= t`` count as uninfected.

        Returns
        -------
        dict
            ``'A'``, ``'B'``, ``'B_use'`` (``B`` with zeroed diagonal) and ``'p'``
            as NumPy arrays, and ``'loss'``: one penalised log-likelihood
            tensor (on ``self.device``) per iteration.
        """
        A_ini = torch.tensor(A_ini, device=self.device, dtype=torch.float64)
        B_ini = torch.tensor(B_ini, device=self.device, dtype=torch.float64)
        P_pathway_ini = torch.tensor(P_pathway_ini, device=self.device, dtype=torch.float64)
        P_pathway_ind_ini = torch.tensor(P_pathway_ind_ini, device=self.device, dtype=torch.float64).unsqueeze(1)
        supp_A = torch.tensor(supp_A, device=self.device, dtype=torch.int32)

        infect_info = self.get_infection_info(cascades, t, self.hyparams['delta'])

        log_S_E_ini, log_S_I_ini, H_E_ini, H_I_ini, Q1_ini, Q2_ini, Q3_ini = self.initialize_likelihood(
            A_ini, B_ini, P_pathway_ini, P_pathway_ind_ini, infect_info)

        Likelihood_ini = Q1_ini + Q2_ini + Q3_ini
        loss_components = (log_S_E_ini, log_S_I_ini, H_E_ini, H_I_ini)
        params = (A_ini, B_ini, P_pathway_ini)
        losses = (Q1_ini, Q2_ini, Likelihood_ini)

        return self._run_em_iterations(params, supp_A, loss_components, infect_info, losses)

    def _run_em_iterations(self, params, supp_A, loss_components, infect_info, losses):
        """Run ``max_Iter`` EM steps from the given state and package the result."""
        A_update, B_update, P_pathway_update = params
        log_S_E, log_S_I, H_E, H_I = loss_components
        Q1_update, Q2_update, _ = losses
        # Defined before the loop so max_Iter == 0 returns the initial state
        # instead of raising UnboundLocalError.
        B_update_use = self._zero_diagonal(B_update)

        Likelihood_record = []
        for _ in range(self.hyparams["max_Iter"]):
            (A_update, B_update, B_update_use, P_pathway_update,
             Q1_update, Q3_update, Q2_update, log_S_E, H_E, log_S_I, H_I) = self.em_step(
                A_update, B_update, supp_A, P_pathway_update,
                log_S_E, log_S_I, H_E, H_I, infect_info, Q1_update, Q2_update
            )
            Likelihood_record.append(Q1_update + Q2_update + Q3_update)

        return {
            'A': A_update.cpu().numpy(),
            'B': B_update.cpu().numpy(),
            'B_use': B_update_use.cpu().numpy(),
            'p': P_pathway_update.cpu().numpy(),
            'loss': Likelihood_record
        }

    def initialize_likelihood(self, A_ini, B_ini, P_pathway_ini, P_pathway_ind_ini, infect_info):
        """Compute the likelihood components and penalised Q1, Q2, Q3 at the initial point.

        Returns ``(log_S_E, log_S_I, H_E, H_I, Q1, Q2, Q3)``.
        """
        hp = self.hyparams
        infected_mask = infect_info[1]

        log_S_E_ini, H_E_ini = self.update_explicit_likelihood_component(
            A_ini, hp["dist"], infect_info, hp['hard_thres'], hp["delta"])
        Q1_ini = (self.get_Q1(log_S_E_ini, H_E_ini, infected_mask, P_pathway_ind_ini, hp["hard_thres"])
                  - self._A_penalty(A_ini))

        # The likelihood ignores self-loops of B; the penalty uses the full B.
        B_ini_use = self._zero_diagonal(B_ini)
        log_S_I_ini, H_I_ini = self.update_implicit_likelihood_component(
            B_ini_use, hp["dist"], infect_info, hp['hard_thres'], hp['delta'])
        Q2_ini = (self.get_Q2(log_S_I_ini, H_I_ini, infected_mask, P_pathway_ind_ini, hp["hard_thres"])
                  - self._B_penalty(B_ini))

        Q3_ini = self.get_Q3(P_pathway_ind_ini, P_pathway_ini, infect_info[5])

        return log_S_E_ini, log_S_I_ini, H_E_ini, H_I_ini, Q1_ini, Q2_ini, Q3_ini

    def em_step(self, A_update, B_update, supp_A, P_pathway_update, log_S_E, log_S_I, H_E, H_I, infect_info, Q1_update, Q2_update):
        """One E-step followed by one M-step.

        ``Q1_update`` and ``Q2_update`` are accepted for interface compatibility
        but are recomputed, not read.

        Returns ``(A, B, B_use, p, Q1, Q3, Q2, log_S_E, H_E, log_S_I, H_I)``.
        """
        infected_mask, non_source_mask = infect_info[1], infect_info[5]

        # E-step: per-cascade posterior weight of the explicit pathway.
        P_pathway_ind = self.e_step(P_pathway_update, log_S_E, log_S_I, H_E, H_I, infected_mask)

        # M-step for the pathway prior: average posterior over cascades.
        P_pathway_update = torch.mean(P_pathway_ind[:, 0], dim=0)
        Q3_update = self.get_Q3(P_pathway_ind, P_pathway_update, non_source_mask)

        A_update, B_update, B_update_use, Q1_update, Q2_update, log_S_E, log_S_I, H_E, H_I = \
            self.m_step(A_update, supp_A, B_update, H_E, H_I, infect_info, P_pathway_ind)

        return (A_update, B_update, B_update_use, P_pathway_update,
                Q1_update, Q3_update, Q2_update, log_S_E, H_E, log_S_I, H_I)

    def get_infection_info(self, cascades, t, delta):
        """Precompute delay and indicator tensors for all cascades.

        ``Delay[c, j, i] = t_j - t_i`` where that is negative (``j`` before
        ``i``), else 0.  Returns
        ``(Delay, infected_mask, Ind_1, Ind_2, Ind_pow, non_source_mask)``:

        * ``Ind_1[c, j, i]``: ``j`` infected and ``j`` precedes ``i``;
        * ``Ind_2[c, j, i]``: additionally ``i`` infected;
        * ``Ind_pow[c, j, i]``: ``j`` precedes ``i`` by more than ``delta``;
        * ``non_source_mask[c, i]``: ``i`` is not the cascade source;
        * ``infected_mask[c, i]``: ``i`` infected and not the source.
        """
        nc, nd = cascades.shape
        infect_index_status = (cascades < t).astype(int)

        Delay = cascades[:, :, None] - cascades[:, None, :]
        Delay[Delay > 0] = 0

        temp_ind = (Delay < 0).astype(int)
        Ind_1 = infect_index_status[:, :, None] * temp_ind
        Ind_2 = Ind_1 * infect_index_status[:, None, :]
        Ind_pow = (Delay < -delta).astype(int)

        # Source = node preceding all nd - 1 others; if none does (e.g. tied
        # earliest times), argmax of an all-False row falls back to node 0.
        cas_source = np.argmax(np.sum(Delay < 0, axis=2) >= nd - 1, axis=1)

        Delay = torch.tensor(Delay, device=self.device, dtype=torch.float64)
        Ind_1 = torch.tensor(Ind_1, device=self.device, dtype=torch.float64)
        Ind_2 = torch.tensor(Ind_2, device=self.device, dtype=torch.float64)
        Ind_pow = torch.tensor(Ind_pow, device=self.device, dtype=torch.float64)

        non_source_mask = torch.ones(nc, nd, dtype=torch.bool, device=self.device)
        non_source_mask[torch.arange(nc, device=self.device),
                        torch.as_tensor(cas_source, device=self.device)] = False

        infected_mask = torch.tensor(infect_index_status > 0, device=self.device) & non_source_mask

        return Delay, infected_mask, Ind_1, Ind_2, Ind_pow, non_source_mask

    def update_explicit_likelihood_component(self, A_update, dist, infect_info, hard_thres, delta):
        """Return ``(log_S_E, H_E)`` for the explicit network ``A_update``."""
        return self._likelihood_component(A_update, dist, infect_info, hard_thres, delta)

    def update_implicit_likelihood_component(self, B_update_use, dist, infect_info, hard_thres, delta):
        """Return ``(log_S_I, H_I)`` for the implicit network ``B_update_use``."""
        return self._likelihood_component(B_update_use, dist, infect_info, hard_thres, delta)

    def get_Q1(self, log_S_E, H_E, infected_mask, P_pathway_ind, hard_thres):
        """Explicit-pathway log-likelihood weighted by ``P_pathway_ind`` (unpenalised)."""
        return self._weighted_Q(log_S_E, H_E, infected_mask, P_pathway_ind, hard_thres)

    def get_Q2(self, log_S_I, H_I, infected_mask, P_pathway_ind, hard_thres):
        """Implicit-pathway log-likelihood weighted by ``1 - P_pathway_ind`` (unpenalised)."""
        return self._weighted_Q(log_S_I, H_I, infected_mask, 1 - P_pathway_ind, hard_thres)

    def get_Q3(self, P_pathway_ind, P_pathway_update, non_source_mask):
        """Bernoulli(p) log-likelihood of the pathway weights, summed over non-source nodes."""
        # ones_like keeps P_pathway_ind's dtype/device; torch.ones would be
        # float32 and silently demote p when P_pathway_update is 0-dim.
        p = P_pathway_update * torch.ones_like(P_pathway_ind)
        log_terms = P_pathway_ind * torch.log(p) + (1 - P_pathway_ind) * torch.log(1 - p)
        return (log_terms * non_source_mask).sum()

    def e_step(self, P_pathway_update, log_S_E, log_S_I, H_E, H_I, infected_mask):
        """Posterior weight of the explicit pathway for each cascade.

        Returns a ``(nc, nd)`` tensor whose rows repeat the cascade's weight.
        """
        nd = log_S_E.shape[1]
        S_ratio = torch.exp(torch.sum(log_S_I, dim=1) - torch.sum(log_S_E, dim=1))

        a = torch.sum(H_E, dim=1).clamp_min_(self.hyparams["hard_thres"])
        b = torch.sum(H_I, dim=1) / a
        # Hazard ratios only matter for infected non-source nodes.
        H_ratio = torch.ones_like(b)
        H_ratio[infected_mask] = b[infected_mask]

        c = 1 / (1 + ((1 - P_pathway_update) / P_pathway_update) * torch.prod(S_ratio * H_ratio, dim=1))
        # Keeps the weight away from exactly 0 so log terms stay finite.
        c = torch.abs(c - self.hyparams['hard_thres'])
        return c.unsqueeze(1).repeat(1, nd)

    def get_Q1_grad(self, H_E, infect_info, P_pathway_ind, supp_A, dist, hard_thres, delta):
        """Gradient of Q1 w.r.t. ``A``, masked by ``supp_A`` and with a zero diagonal."""
        grad_E = self._likelihood_gradient(H_E, infect_info, P_pathway_ind, dist, hard_thres, delta) * supp_A
        return self._zero_diagonal(torch.sum(grad_E, dim=0))

    def get_Q2_grad(self, H_I, infect_info, P_pathway_ind, dist, hard_thres, delta):
        """Gradient of Q2 w.r.t. ``B`` (summed over cascades)."""
        grad_I = self._likelihood_gradient(H_I, infect_info, 1 - P_pathway_ind, dist, hard_thres, delta)
        return torch.sum(grad_I, dim=0)

    def m_step(self, A_update, supp_A, B_update, H_E, H_I, infect_info, P_pathway_ind):
        """One proximal gradient-ascent step on ``A`` and on ``B``.

        Returns ``(A, B, B_use, Q1, Q2, log_S_E, log_S_I, H_E, H_I)`` evaluated
        at the updated matrices.
        """
        hp = self.hyparams
        dist, thres, delta = hp["dist"], hp["hard_thres"], hp["delta"]
        infected_mask = infect_info[1]

        # Explicit network: ascent, L1 prox, projection onto A >= 0, no self-loops.
        Gra_E = self.get_Q1_grad(H_E, infect_info, P_pathway_ind, supp_A, dist, thres, delta)
        A_update = A_update + hp["lr_A"] * Gra_E
        A_update = self._soft_threshold(A_update, hp["lr_A"] * hp["penl_l1_A"])
        A_update = self._zero_diagonal(torch.clamp(A_update, min=0))

        log_S_E, H_E = self.update_explicit_likelihood_component(A_update, dist, infect_info, thres, delta)
        Q1_update = (self.get_Q1(log_S_E, H_E, infected_mask, P_pathway_ind, thres)
                     - self._A_penalty(A_update))

        # Implicit network: ascent, nuclear-norm prox (singular-value
        # soft-thresholding), L1 prox, projection onto B >= 0.
        Gra_I = self.get_Q2_grad(H_I, infect_info, P_pathway_ind, dist, thres, delta)
        B_update = B_update + hp["lr_B"] * Gra_I
        U, S, Vh = torch.linalg.svd(B_update, full_matrices=False)
        S = torch.clamp(S - hp["lr_B"] * hp["penl_svd"], min=0)
        B_update = U @ torch.diag(S) @ Vh
        B_update = torch.clamp(self._soft_threshold(B_update, hp["lr_B"] * hp["penl_l1_B"]), min=0)
        B_update_use = self._zero_diagonal(B_update)

        log_S_I, H_I = self.update_implicit_likelihood_component(B_update_use, dist, infect_info, thres, delta)
        Q2_update = (self.get_Q2(log_S_I, H_I, infected_mask, P_pathway_ind, thres)
                     - self._B_penalty(B_update))

        return A_update, B_update, B_update_use, Q1_update, Q2_update, log_S_E, log_S_I, H_E, H_I
