import torch
import numpy as np

class NetRate_torch:
    """NetRate-style network inference by proximal gradient ascent, in PyTorch.

    Estimates a non-negative rate matrix ``A`` of shape (nd, nd), where
    ``A[i, j]`` weights transmission from node ``i`` to node ``j``, from a set
    of cascades. It does so by maximising the cascade log-likelihood minus the
    L1 penalty ``penl_l1 * sum|A|``.

    Pairwise models selected by ``dist`` (see ``update_likelihood_component``),
    with ``tau = t_j - t_i > 0`` the delay between an earlier node ``i`` and a
    later node ``j``:

    * ``"exp"``: log-survival ``-A * tau``; hazard ``A``.
    * ``"ray"``: log-survival ``-A * tau**2 / 2``; hazard ``A * tau``.
    * ``"pow"``: log-survival ``-A * log(tau / delta)``; hazard ``A / tau``;
      both counted only when ``tau > delta``.
    """

    _SUPPORTED_DISTS = ("exp", "ray", "pow")

    def __init__(self, dist="exp", delta=1, lr=0.0001, penl_l1=10, hard_thres=0.01,
                 eps=0.0001, max_iter=None):
        """Store the hyper-parameters and pick the compute device.

        Args:
            dist: transmission model, one of "exp", "ray", "pow".
            delta: minimum delay (and log scale) of the "pow" model.
            lr: gradient-ascent step size.
            penl_l1: weight of the L1 penalty on A.
            hard_thres: floor applied to each node's summed hazard before its
                log is taken. In "pow" it also replaces zero delays in the
                hazard denominator.
            eps: stop once the relative likelihood improvement between two
                consecutive iterations is not larger than this.
            max_iter: optional cap on the number of iterations. None (the
                default) keeps the loop bounded only by ``eps``.

        Raises:
            ValueError: if ``dist`` is not a supported model.
        """
        self._check_dist(dist)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if torch.cuda.is_available():
            print("cuda device enabled")

        self.hyparams = {
            'dist': dist,
            'delta': delta,
            'lr': lr,
            'penl_l1': penl_l1,
            'hard_thres': hard_thres,
            'stopping': eps,
            'max_iter': max_iter,
        }

    @staticmethod
    def _check_dist(dist):
        """Raise ValueError unless ``dist`` names a supported transmission model."""
        if dist not in NetRate_torch._SUPPORTED_DISTS:
            raise ValueError(
                f"Unsupported dist {dist!r}; expected one of {NetRate_torch._SUPPORTED_DISTS}")

    @staticmethod
    def _pow_delays(Delay, delta, hard_thres):
        """Return (Delay_S, Delay_H): copies of Delay with zeros replaced.

        Zeros become -delta in Delay_S and -hard_thres in Delay_H. This keeps
        the power-law log and division finite on pairs that are masked out
        anyway.
        """
        zero = Delay == 0
        return Delay.masked_fill(zero, -delta), Delay.masked_fill(zero, -hard_thres)

    def get_infection_info(self, cascades, t, delta):
        """Precompute the delay and indicator tensors used by the model.

        Args:
            cascades: array-like of shape (nc, nd). ``cascades[c, i]`` is the
                time of node i in cascade c; nodes with time >= t count as
                uninfected. NOTE(review): an uninfected node's recorded time is
                used as the end of its survival window, so it is assumed to be
                finite (presumably t). Confirm against the data loader.
            t: observation cut-off time.
            delta: minimum delay of the power-law model.

        Returns:
            Tuple ``(Delay, infected_mask, Ind_1, Ind_2, Ind_pow)`` on
            ``self.device``:

            * ``Delay`` (nc, nd, nd) float64: ``t_i - t_j`` where i strictly
              precedes j, else 0.
            * ``infected_mask`` (nc, nd) bool: infected nodes other than the
              cascade source.
            * ``Ind_1`` (nc, nd, nd) float64: 1 where i is infected and
              precedes j.
            * ``Ind_2`` (nc, nd, nd) float64: ``Ind_1`` additionally requiring
              j to be infected.
            * ``Ind_pow`` (nc, nd, nd) float64: 1 where the delay exceeds delta.
        """
        cascades = np.asarray(cascades)
        nc, nd = cascades.shape
        infected = cascades < t                                   # (nc, nd)

        Delay = cascades[:, :, None] - cascades[:, None, :]       # t_i - t_j
        Delay[Delay > 0] = 0                                      # keep "i before j" only
        precedes = Delay < 0

        Ind_1 = infected[:, :, None] & precedes
        Ind_2 = Ind_1 & infected[:, None, :]
        Ind_pow = Delay < -delta

        # The source strictly precedes every other node of its cascade. For an
        # all-False row argmax returns 0, as in the original implementation.
        cas_source = np.argmax(precedes.sum(axis=2) >= nd - 1, axis=1)
        # Infected nodes minus the source, computed without a per-cascade loop.
        infected_mask = infected.copy()
        infected_mask[np.arange(nc), cas_source] = False

        def to_float_tensor(x):
            return torch.as_tensor(np.asarray(x, dtype=np.float64), device=self.device)

        return (
            to_float_tensor(Delay),
            torch.as_tensor(infected_mask, dtype=torch.bool, device=self.device),
            to_float_tensor(Ind_1),
            to_float_tensor(Ind_2),
            to_float_tensor(Ind_pow),
        )

    def optimize(self, A_ini, cascades, t):
        """Fit the rate matrix starting from ``A_ini``.

        Args:
            A_ini: initial (nd, nd) matrix (array-like or tensor). It is
                copied and never modified.
            cascades: see ``get_infection_info``.
            t: observation cut-off time.

        Returns:
            dict with ``'A'`` (numpy array, final iterate) and ``'loss'`` (list
            of 0-dim tensors: the penalised log-likelihood after each
            iteration).
        """
        A_ini = torch.as_tensor(A_ini, dtype=torch.float64, device=self.device).clone()
        infect_info = self.get_infection_info(cascades, t, self.hyparams['delta'])
        Haza_ini, Likelihood_ini = self._penalised_likelihood(A_ini, infect_info)
        return self._run_iterations(A_ini, Haza_ini, infect_info, Likelihood_ini)

    def _penalised_likelihood(self, A, infect_info):
        """Return (Haza, log-likelihood - penl_l1 * sum|A|) for matrix A."""
        hp = self.hyparams
        logSurv, Haza = self.update_likelihood_component(
            A, hp["dist"], infect_info, hp["hard_thres"], hp["delta"])
        likelihood = self.get_loss(logSurv, Haza, infect_info[1], hp["hard_thres"])
        return Haza, likelihood - hp["penl_l1"] * A.abs().sum()

    def _run_iterations(self, A_update, Haza, infect_info, Likelihood_update):
        """Repeat ``update_network`` until the objective stops improving.

        At least one step is always taken. The loop ends when
        ``(L_new - L_old) / |L_old| <= eps``, which includes any decrease. It
        also ends after ``hyparams['max_iter']`` steps when that is set.
        """
        max_iter = self.hyparams.get("max_iter")
        stopping = self.hyparams["stopping"]
        Likelihood_record = []
        while True:
            Likelihood = Likelihood_update
            A_update, Haza, Likelihood_update = self.update_network(
                A_update, Haza, infect_info, Likelihood_update)
            Likelihood_record.append(Likelihood_update)
            if max_iter is not None and len(Likelihood_record) >= max_iter:
                break
            # Negated ">" so that a NaN ratio (e.g. 0/0) stops the loop, exactly
            # as the original while-condition did.
            if not ((Likelihood_update - Likelihood) / torch.abs(Likelihood) > stopping):
                break

        return {
            'A': A_update.cpu().numpy(),
            'loss': Likelihood_record
        }

    def update_network(self, A_update, Haza, infect_info, Likelihood_update):
        """Take one proximal gradient-ascent step and re-evaluate the objective.

        The step has four parts:

        1. A gradient-ascent step on the log-likelihood.
        2. Soft-thresholding by ``lr * penl_l1`` (the L1 proximal step).
        3. Clamping to ``A >= 0``.
        4. Zeroing the diagonal.

        ``A_update`` is not modified in place. ``Likelihood_update`` is kept
        for interface compatibility and is not used.

        Returns:
            ``(A_new, Haza_new, likelihood_new)``.
        """
        hp = self.hyparams
        Gra = self.get_grad(Haza, infect_info, hp["dist"], hp["hard_thres"], hp["delta"])
        stepped = A_update + hp["lr"] * Gra
        shrunk = torch.sign(stepped) * torch.clamp(
            torch.abs(stepped) - hp["lr"] * hp["penl_l1"], min=0)
        A_new = torch.clamp(shrunk, min=0)
        A_new = A_new - torch.diag_embed(torch.diagonal(A_new))

        Haza_new, likelihood = self._penalised_likelihood(A_new, infect_info)
        return A_new, Haza_new, likelihood

    def get_grad(self, Haza, infect_info, dist, hard_thres, delta):
        """Gradient of the unpenalised log-likelihood with respect to A.

        Returns:
            (nd, nd) tensor, summed over cascades, with a zero diagonal.

        Raises:
            ValueError: if ``dist`` is not a supported model.
        """
        self._check_dist(dist)
        Delay, Ind_1, Ind_2 = infect_info[0], infect_info[2], infect_info[3]
        # Total hazard into each node j (summed over sources i), floored as in get_loss.
        total_hazard = torch.sum(Ind_2 * Haza, dim=1).clamp_min(hard_thres)
        inv_total = (1 / total_hazard).unsqueeze(1)   # broadcast over the source axis

        if dist == "exp":
            grad_S = Ind_1 * Delay
            grad_H = Ind_2 * inv_total
        elif dist == "ray":
            grad_S = Ind_1 * (-Delay**2/2)
            grad_H = Ind_2 * (-Delay) * inv_total
        else:  # "pow"
            Ind_pow = infect_info[4]
            Delay_S, Delay_H = self._pow_delays(Delay, delta, hard_thres)
            grad_S = Ind_1 * -torch.log(-Delay_S / delta) * Ind_pow
            grad_H = Ind_2 / (-Delay_H) * inv_total * Ind_pow

        grad = torch.sum(grad_S + grad_H, dim=0)
        return grad - torch.diag_embed(torch.diagonal(grad))

    def update_likelihood_component(self, A_update, dist, infect_info, hard_thres, delta):
        """Per-pair log-survival and hazard terms for matrix ``A_update``.

        Returns:
            ``(logSurv, Haza)``, each (nc, nd, nd). ``A_update`` is broadcast
            over cascades instead of being copied ``nc`` times.

        Raises:
            ValueError: if ``dist`` is not a supported model.
        """
        self._check_dist(dist)
        Delay = infect_info[0]
        Ind_2 = infect_info[3]
        theta = A_update.unsqueeze(0)

        if dist == "exp":
            logSurv = theta * Delay
            Haza = theta * Ind_2
        elif dist == "ray":
            logSurv = theta * -(Delay)**2/2
            Haza = theta * Ind_2 * (-Delay)
        else:  # "pow"
            Ind_pow = infect_info[4]
            Delay_S, Delay_H = self._pow_delays(Delay, delta, hard_thres)
            logSurv = theta * -torch.log(-Delay_S / delta) * Ind_pow
            Haza = theta * Ind_2 / (-Delay_H) * Ind_pow
        return logSurv, Haza

    def get_loss(self, logSurv, Haza, infected_mask, hard_thres):
        """Unpenalised log-likelihood.

        The value is the sum of all log-survival terms plus, for each infected
        non-source node, the log of its total incoming hazard (floored at
        ``hard_thres``).
        """
        total_hazard = torch.sum(Haza, dim=1).clamp_min(hard_thres)
        H = (infected_mask * torch.log(total_hazard)).sum(dim=1)
        return logSurv.sum() + H.sum()