import math 
import time
import torch
import torch.nn.functional as F
import torch.distributions as D 
from utils.emissions import *



class FaCTHMM_model(torch.nn.Module):

    def __init__(self, data_dim=2, n_bits=4, n_flows=16, n_h=64, flow_layers=3, base_layers=3, dataset_stats=None):
        """Assemble the normalizing flow, the emission model and the HMM parameters.

        Args:
            data_dim: passed to ``Multi_RealNVP``, ``ActNorm`` and
                ``Gaussian_Dist_Exact_nf``; also stored as ``self.data_dim``.
            n_bits: number of binary factors of the hidden state; the joint
                chain therefore has ``2**n_bits`` states.
            n_flows: passed to every flow layer, to ``NormalizingFlowModel``
                and to the emission model.
            n_h: passed to ``Multi_RealNVP`` (presumably its hidden width --
                confirm in ``utils.emissions``).
            flow_layers: number of ``Multi_RealNVP`` layers; an ``ActNorm``
                layer is inserted between consecutive ones.
            base_layers: passed to ``Multi_RealNVP``.
            dataset_stats: passed to ``n_flow_init``.
        """
        super().__init__()

        # The flow is built as a plain list and wrapped exactly once; no
        # placeholder module is assigned to ``self.flow`` beforehand.
        flow_blocks = [n_flow_init(dataset_stats, n_flows)]
        for layer_idx in range(flow_layers):
            flow_blocks.append(Multi_RealNVP(data_dim, n_flows, n_h, base_layers))
            if layer_idx < flow_layers - 1:
                flow_blocks.append(ActNorm(data_dim, n_flows))
        self.flow = NormalizingFlowModel(flow_blocks, n_flows)
        self.emit = Gaussian_Dist_Exact_nf(data_dim, n_bits, n_flows)

        # Unnormalised log-probabilities of the initial joint state.
        self.p_h_0 = torch.nn.Parameter(torch.randn(2**n_bits))
        # (1, n_bits, 2): the pair of rates from which each bit's 2x2
        # transition matrix is built; initialised uniformly in [1e-6, 1).
        self.q_param = torch.nn.Parameter(torch.FloatTensor(1, n_bits, 2).uniform_(1e-6, 1.0))
        # (n_bits, 1) powers of two, most significant bit first.
        self.register_buffer("bit_base", 2**torch.arange(n_bits - 1, -1, -1).unsqueeze(1))
        self.n_bits = n_bits
        self.data_dim = data_dim

    def log_vector_kronecker(self, vec, mtx): # mtx^T *@ vec; vec: (batch,n_states) mtx: (batch,n_bits,2,2); n_states = 2^n_bits
        for i in range(mtx.shape[1]):
            vec = vec.view(vec.shape[0],2,-1).transpose(-2,-1)
            vec = (vec.unsqueeze(-1) + mtx[:,i].unsqueeze(1)).logsumexp(-2)
        return vec.view(vec.shape[0],-1)

    def log_vector_kronecker_decode(self, vec, mtx): # mtx^T *@ vec; vec: (batch,n_states) mtx: (batch,n_bits,2,2); n_states = 2^n_bits
        bt = vec.new_empty(0)
        for i in range(mtx.shape[1]):
            vec = vec.view(vec.shape[0],2,-1).transpose(-2,-1)
            bt = bt.view(vec.shape[0], i, 2, vec.shape[-2]).transpose(-2,-1)
            vec, bt_ =  (vec.unsqueeze(-1) + mtx[:,i].unsqueeze(1)).max(-2)
            bt = torch.cat([bt,bt_.unsqueeze(1)],dim=1)
        bt = bt.view(bt.shape[0], self.n_bits, -1)
        #bt = self.bit_base.matmul(bt).squeeze(1)
        bt = (self.bit_base*bt).sum(1)
        return vec.view(vec.shape[0],-1), bt
    
    def forward(self, s): # forward algorithm of HMM
        #========================================================= 
        # s[0]: lat,lon,t
        # s[1]: batchsize
        #=========================================================
        latlon = s[0][:,:-1]        
        bs = s[1]
        t_diff = s[0][bs[0]:,[-1]]
        #==emission probabilities=================================
        h, log_det_nf = self.flow(latlon)
        #p_oh = (self.emit(h) + log_det_nf.unsqueeze(-1)).view(latlon.shape[0],-1) # (n_obs, n_flows, 2**n_bits/n_flows)
        p_oh = (self.emit(h) + log_det_nf.unsqueeze(1)).view(latlon.shape[0],-1) # (n_obs, 2**n_bits/n_flows, n_flows)

        p_h_0 = self.p_h_0 - self.p_h_0.logsumexp(0,keepdim=True)
        #=========================================================           
        
        #===========calculate matrix exponential with n_bits 2 x 2 matrix=========== 
        q0_plus_q1 = self.q_param.sum(2) #(1,n_bits)
        
        #in case q0_plus_q1==0, the zeros in self.q_param will work on Q_param = Q_param.unsqueeze(2) * self.q_param
        #q0_plus_q1[q0_plus_q1==0] = 1
        
        Q_param = -(-q0_plus_q1*t_diff).expm1() / q0_plus_q1
        Q_param = Q_param.unsqueeze(2) * self.q_param #(n_obs,n_bits,2)
        #need inplace operation to cut off gradients in following becasue t_diff may be zero, leading to identity transition prob matrix 
        #=====20230825 speed==============
        #Q_param[Q_param==0]=0
        #Q_param[Q_param==1]=1
        #===========================
        Q_param = torch.cat([Q_param.log(),(-Q_param).log1p()],dim=2)[:,:,[2,0,1,3]].view(t_diff.shape[0],self.n_bits,2,2)  #(n_obs,n_bits,2,2)
        #log_Q_param = Q_param.log()
        #log1p_nQ_param = (-Q_param).log1p()
        #Q_param = torch.cat([log1p_nQ_param[:,:,:1],log_Q_param,log1p_nQ_param[:,:,-1:]],dim=2).view(t_diff.shape[0],self.n_bits,2,2)  #(n_obs,n_bits,2,2)
        #=========================================================================== 

        #=====================forward algorithm=====================================
        p_oh = p_oh.split(bs,0)
        Q_param = Q_param.split(bs[1:],0)
        alphas = p_h_0 + p_oh[0] #(n_seq,n_states)
        for b_id in range(len(bs)-1):
            alphas[:bs[b_id+1]] = self.log_vector_kronecker( alphas[:bs[b_id+1]],   Q_param[b_id] )
            alphas[:bs[b_id+1]] = alphas[:bs[b_id+1]] + p_oh[b_id+1]
 
            #alp_s = log_vector_kronecker( alphas[:bs[b_id+1]],   Q_param[b_id] ) #vector * kronecker
            ##need inplace operation to cut off gradients for logsumexp
            #alp_s[alp_s==-float("inf")]=-float("inf")
            #alphas[alphas==-float("inf")]=-float("inf")
            #alphas[:bs[b_id+1]] = alp_s      
        alphas = alphas.logsumexp(-1)
        return -alphas.sum()

    @torch.no_grad()
    def decode(self, s): # forward algorithm of HMM
        #========================================================= 
        # s[0]: lat,lon,t
        # s[1]: batchsize
        #=========================================================
        latlon = s[0][:,:-1]
        bs = s[1]
        t_diff = s[0][bs[0]:,[-1]]
        #==emission probabilities=================================
        h, log_det_nf = self.flow(latlon)
        p_oh = (self.emit(h) + log_det_nf.unsqueeze(1)).view(latlon.shape[0],-1) # (n_obs, 2**n_bits/n_flows, n_flows)

        p_h_0 = self.p_h_0 - self.p_h_0.logsumexp(0,keepdim=True)
        #=========================================================           

        #===========calculate matrix exponential with n_bits 2 x 2 matrix=========== 
        q0_plus_q1 = self.q_param.sum(2) #(1,n_bits)
        Q_param = -(-q0_plus_q1*t_diff).expm1() / q0_plus_q1
        Q_param = Q_param.unsqueeze(2) * self.q_param #(n_obs,n_bits,2)
        Q_param = torch.cat([Q_param.log(),(-Q_param).log1p()],dim=2)[:,:,[2,0,1,3]].view(t_diff.shape[0],self.n_bits,2,2)  #(n_obs,n_bits,2,2)
        #=========================================================================== 

        #=====================forward algorithm=====================================
        p_oh = p_oh.split(bs,0)
        Q_param = Q_param.split(bs[1:],0)
        v = p_h_0 + p_oh[0] #(n_seq,n_states)
        bt = []
        for b_id in range(len(bs)-1):
            v[:bs[b_id+1]], bt_ = self.log_vector_kronecker_decode( v[:bs[b_id+1]],   Q_param[b_id] )
            bt.append(2**self.n_bits-1-bt_)
            v[:bs[b_id+1]] = v[:bs[b_id+1]] + p_oh[b_id+1]
        #==========================================================================

        #============back tracing===============================
        P, qt = v.max(-1,keepdim=True)
        q = [ qt[:bs[-1]].clone() ]
        for r_id, bt_ in enumerate(reversed(bt)):
            qt[:bt_.shape[0]] = bt_.take_along_dim(qt[:bt_.shape[0]],1)
            q.append( qt[:bs[-r_id-2]].clone() )
        #=======================================================
        return P, torch.cat(q[::-1],dim=0)

    @torch.no_grad()
    def predict(self, s):
        #=========================================================
        # s[0]: lat,lon,t
        # s[1]: batchsize
        #=========================================================
        latlon = s[0][:,:-1]
        bs = s[1]
        t_diff = s[0][bs[0]:,[-1]]
        #==emission probabilities=================================
        h, log_det_nf = self.flow(latlon) 
        p_oh = (self.emit(h) + log_det_nf.unsqueeze(1)).view(latlon.shape[0],-1)  # (n_obs, 2**n_bits/n_flows, n_flows)
        
        p_h_0 = self.p_h_0 - self.p_h_0.logsumexp(0,keepdim=True)
        #=========================================================

        #===========calculate matrix exponential with n_bits 2 x 2 matrix===========
        q0_plus_q1 = self.q_param.sum(2) #(1,n_bits)
        Q_param = -(-q0_plus_q1*t_diff).expm1() / q0_plus_q1
        Q_param = Q_param.unsqueeze(2) * self.q_param #(n_obs,n_bits,2)
        Q_param = torch.cat([Q_param.log(),(-Q_param).log1p()],dim=2)[:,:,[2,0,1,3]].view(t_diff.shape[0],self.n_bits,2,2)  #(n_obs,n_bits,2,2)
        #===========================================================================

        #=================forward algorithm=====================================
        p_oh = p_oh.split(bs,0)
        Q_param = Q_param.split(bs[1:],0)
        coefs = [p_h_0.repeat(bs[0],1)] 
        for b_id in range(len(bs)-1):
            alphas = coefs[-1] + p_oh[b_id]
            coefs.append(self.log_vector_kronecker( alphas[:bs[b_id+1]],   Q_param[b_id] ) )
        coefs = torch.concat(coefs, dim=0)
        return coefs - coefs.logsumexp(-1, keepdim=True)
        
    @torch.no_grad()
    def sample(self, samp_size, s):
        if not hasattr(self, "g_comp"):
            samp_loc = self.emit.mu.view(2**self.n_bits, -1)
            samp_scale = 2**-0.5/self.emit.sqrt2_sigma_inv
            samp_scale = samp_scale.view(2**self.n_bits, -1)
            self.g_comp = [D.normal.Normal(loc=samp_loc[_], scale=samp_scale[_]) for _ in range(2**self.n_bits)]
        coefs = self.predict(s)
        w = D.multinomial.Multinomial(samp_size,logits=coefs,validate_args=False)
        w = w.sample().int()
        pred = [self.g_comp[bit].sample((_,)).split(w[:,bit].tolist(),0) for bit, _ in enumerate(w.sum(0))]
        pred = torch.cat([torch.cat([_[i] for _ in pred],0) for i in range(coefs.shape[0])],0) #(samp_size*n_batch, data_dim)
        flow_id = torch.repeat_interleave(torch.arange(coefs.shape[1],device=pred.device).fmod(self.flow.n_flows).repeat(coefs.shape[0]), w.view(-1))
        pred = self.flow.reverse(pred, flow_id)
        return pred.view(coefs.shape[0],samp_size,-1).transpose(0,1)

    @torch.no_grad()
    def sum_batch_crps(self, samp_size, s): #Modified from pyro
        pred = self.sample(samp_size, s)
        opts = dict(device=pred.device, dtype=pred.dtype)
        pred = pred.sort(dim=0).values
        weight = torch.arange(-samp_size+1, samp_size, 2, **opts).unsqueeze(-1).unsqueeze(-1)
        return ( (pred - s[0][:,:-1]).abs().mean(0) - (pred * weight).sum(0) / samp_size**2 ).sum()/self.data_dim

    @torch.no_grad()
    def predict_onestep(self, states, t_diff): #states: (n_states, 1); t_diff: (n_obs,1)
        if not hasattr(self, "g_comp"):
            samp_loc = self.emit.mu.view(2**self.n_bits, -1)
            samp_scale = 2**-0.5/self.emit.sqrt2_sigma_inv
            samp_scale = samp_scale.view(2**self.n_bits, -1)
            self.g_comp = [D.normal.Normal(loc=samp_loc[_], scale=samp_scale[_]) for _ in range(2**self.n_bits)]
        h = torch.stack([self.g_comp[_].sample() for _ in states])
        loc = self.flow.reverse(h, states.squeeze() % self.flow.n_flows)
        
        #===========calculate matrix exponential with n_bits 2 x 2 matrix===========
        q0_plus_q1 = self.q_param.sum(2) #(1,n_bits)
        #q0_plus_q1[q0_plus_q1==0] = 1
        Q_param = -(-q0_plus_q1*t_diff).expm1() / q0_plus_q1
        Q_param = Q_param.unsqueeze(2) * self.q_param #(n_obs,n_bits,2)
        Q_param = torch.cat([Q_param.log(),(-Q_param).log1p()],dim=2)[:,:,[2,0,1,3]].view(t_diff.shape[0],self.n_bits,2,2)  #(n_obs,n_bits,2,2)
        #===========================================================================
        
        alphas = torch.full((t_diff.shape[0], 2**self.n_bits), -torch.inf, device = t_diff.device)
        alphas[range(t_diff.shape[0]), states]=0
        alphas = self.log_vector_kronecker(alphas, Q_param)
        sampled_states = torch.multinomial(alphas.isfinite().float(),1)
        return sampled_states, loc
