import torch
import torch.nn as nn
from ..seq2seqbase import Seq2SeqAttrs
from stmodels.layers import DenseNet
from torch_scatter import scatter_mean, scatter_sum
from stmodels.embedding.time import *
from stmodels.libs.utils import clamp_preserve_gradients

class OperatorNet(nn.Module, Seq2SeqAttrs):
    """Operator network combining a graph "branch" term with a "trunk" term.

    ``forward`` evaluates, for one query time (0-dim ``t``) or several
    (1-dim ``t``):

    * a *branch* term: past node states are weighted along graph edges by a
      kernel computed from relative edge location and an embedding of
      ``t - t_past``, averaged over the past steps, and mean-aggregated onto
      the destination nodes;
    * a *trunk* term: a network applied to node location and the same time
      embedding, averaged over the past steps.

    The two terms are multiplied element-wise, clamped and batch-normalised.

    The attributes ``time_dim``, ``location_dim``, ``ker_embed_size`` and
    ``embed_size`` are read from ``self``; presumably they are set by
    ``Seq2SeqAttrs.__init__`` -- TODO confirm.

    Shape conventions used in the comments below (taken from the original
    inline note "H: B,T,N,D" and from how the tensors are indexed; axis
    meanings are assumptions to be confirmed against the caller):
    ``B`` batch, ``T`` past steps, ``N`` nodes, ``E`` edges, ``D`` state
    features, ``L`` location features, ``Q`` query times.
    """

    def __init__(self, level_sizes, **model_kwargs):
        """Create the time embedding, branch kernel, trunk net, bias and norm.

        Args:
            level_sizes: forwarded unchanged to ``Seq2SeqAttrs.__init__``.
            **model_kwargs: forwarded unchanged to ``Seq2SeqAttrs.__init__``.
        """
        nn.Module.__init__(self)
        Seq2SeqAttrs.__init__(self, level_sizes, **model_kwargs)

        # Both the branch kernel and the trunk net consume a location vector
        # concatenated with a time embedding of width 2 * time_dim.
        feature_dim = self.location_dim + self.time_dim * 2

        # Construction order is kept as in the original so that seeded
        # parameter initialisation and state_dict keys are unchanged.
        self.time_emb_net = nn.ModuleList(
            [TriTimeEmbedding(1, self.time_dim * 2)]
        )
        self.branch_weight = DenseNet(
            [feature_dim, self.ker_embed_size, self.embed_size], nn.GELU
        )
        self.trunck_nn = DenseNet([feature_dim, self.embed_size], nn.GELU)
        self.branch_bias = nn.Parameter(torch.randn(self.embed_size))
        self.batch_norm_net = nn.BatchNorm1d(self.embed_size)

    def time_emb(self, time):
        """Embed ``time`` and return a tensor shaped like it, last axis widened.

        Args:
            time: tensor whose last axis is the input of the time embedding
                (the callers in this class always pass a last axis of size 1).

        Returns:
            Tensor with shape ``time.shape[:-1] + (2 * time_dim,)``.
        """
        out_shape = list(time.shape)
        out_shape[-1] = self.time_dim * 2
        # A single reshape suffices; it also normalises whatever layout the
        # embedding module returns to the expected leading axes.
        return self.time_emb_net[0](time).reshape(out_shape)

    @staticmethod
    def _unsqueeze_after_batch(x, n_axes):
        """Insert ``n_axes`` singleton axes directly after axis 0 of ``x``."""
        for _ in range(n_axes):
            x = x.unsqueeze(1)
        return x

    def _space_time_features(self, spatial, time_diff):
        """Concatenate per-item spatial features with per-step time embeddings.

        Args:
            spatial: ``(B, M, L)`` features of ``M`` items (edges or nodes).
            time_diff: ``(*lead, 1)`` time differences, where ``lead`` is
                ``(T,)`` for a single query time or ``(Q, T)`` for several.

        Returns:
            ``(B, *lead, M, L + 2 * time_dim)`` tensor, spatial part first.
        """
        lead = tuple(time_diff.shape[:-1])
        batch, n_items = spatial.shape[0], spatial.shape[1]

        # expand() creates broadcast views; torch.cat below makes the only
        # real copy, so the repeated operands are never materialised.
        s_feat = self._unsqueeze_after_batch(spatial, len(lead))
        s_feat = s_feat.expand(batch, *lead, n_items, -1)

        t_feat = self.time_emb(time_diff)           # (*lead, 2 * time_dim)
        t_feat = t_feat.unsqueeze(0).unsqueeze(-2)  # (1, *lead, 1, 2 * time_dim)
        t_feat = t_feat.expand(batch, *lead, n_items, -1)

        return torch.cat([s_feat, t_feat], dim=-1)

    def _branch_from_time_diff(self, time_diff, conditional_state,
                               edge_idx_all, location):
        """Shared implementation of ``branch_net`` and ``batch_branch_net``.

        ``time_diff`` is ``(T, 1)`` or ``(Q, T, 1)``; the result is
        ``(B, N, D)`` or ``(B, Q, N, D)`` respectively.
        """
        edge_idx_all = edge_idx_all.long()
        src, dst = edge_idx_all[0], edge_idx_all[1]

        h_from = conditional_state[:, :, src, :]        # (B, T, E, D)
        # Only the most recent step of the destination state is used, so
        # gather that step alone instead of all T steps.
        h_to_last = conditional_state[:, -1][:, dst]    # (B, E, D)
        edge_loc = location[:, src] - location[:, dst]  # (B, E, L)

        # Kernel weights: (B, T, E, D) or (B, Q, T, E, D).
        w = self.branch_weight(self._space_time_features(edge_loc, time_diff))

        # Query axes present between batch and time (0 or 1 in this class).
        n_query_axes = time_diff.dim() - 2
        h_from = self._unsqueeze_after_batch(h_from, n_query_axes)
        h_to_last = self._unsqueeze_after_batch(h_to_last, n_query_axes)

        # Average over the past-step axis (third from last), then add the
        # latest destination state.
        message = (h_from * w).mean(dim=-3) + h_to_last

        n_nodes, state_dim = conditional_state.shape[2], conditional_state.shape[3]
        out = conditional_state.new_zeros(
            (*message.shape[:-2], n_nodes, state_dim)
        )
        # Mean of incoming edge messages per destination node.
        out = scatter_mean(message, dst, dim=-2, out=out)
        out += self.branch_bias
        return out

    def _trunk_from_time_diff(self, time_diff, location):
        """Shared implementation of ``trunck_net`` and ``batch_trunck_net``."""
        features = self._space_time_features(location, time_diff)
        # Average over the past-step axis (third from last).
        return self.trunck_nn(features).mean(dim=-3)

    def branch_net(self, t, conditional_state, edge_idx_all, location, t_past):
        """Branch term for a single (0-dim) query time ``t``.

        Args:
            t: 0-dim tensor, the query time.
            conditional_state: ``(B, T, N, D)`` past node states.
            edge_idx_all: ``(2, E)`` edge indices; row 0 is read as source
                and row 1 as destination of each edge.
            location: ``(B, N, L)`` node locations.
            t_past: ``(T,)`` times of the past states.

        Returns:
            ``(B, N, D)`` tensor.
        """
        assert len(t_past) == conditional_state.shape[1], (
            "t_past must have one entry per past step of conditional_state"
        )
        time_diff = (t - t_past).unsqueeze(-1)  # (T, 1)
        return self._branch_from_time_diff(
            time_diff, conditional_state, edge_idx_all, location
        )

    def trunck_net(self, t, location, t_past):
        """Trunk term for a single (0-dim) query time ``t``.

        Args:
            t: 0-dim tensor, the query time.
            location: ``(B, N, L)`` node locations.
            t_past: ``(T,)`` times of the past states.

        Returns:
            ``(B, N, embed_size)`` tensor.
        """
        time_diff = (t - t_past).unsqueeze(-1)  # (T, 1)
        return self._trunk_from_time_diff(time_diff, location)

    def batch_norm(self, x):
        """Apply ``BatchNorm1d`` over the last axis, preserving ``x``'s shape.

        All leading axes are flattened into the batch axis for normalisation;
        the last axis must have size ``embed_size``.
        """
        original_shape = x.shape
        flat = x.reshape(-1, self.embed_size)
        return self.batch_norm_net(flat).reshape(original_shape)

    def forward(self, t, hidden_state, conditional_state, edge_idx_all,
                location, t_past):
        """Evaluate the operator at query time(s) ``t``.

        Args:
            t: 0-dim tensor (single query time) or 1-dim tensor ``(Q,)``.
            hidden_state: accepted for interface compatibility; not used.
            conditional_state: ``(B, T, N, D)`` past node states.
            edge_idx_all: ``(2, E)`` edge indices (source row, destination row).
            location: ``(B, N, L)`` node locations.
            t_past: ``(T,)`` times of the past states.

        Returns:
            ``(B, N, D)`` for a 0-dim ``t``; ``(B, Q, N, D)`` for a 1-dim ``t``.
        """
        assert len(t_past) == conditional_state.shape[1], (
            "t_past must have one entry per past step of conditional_state"
        )

        # Work on a fresh leaf copy of t that tracks gradients.
        # NOTE(review): this copy is local to forward, so callers cannot take
        # gradients of the output with respect to it -- confirm the intent.
        t = t.detach().clone().requires_grad_(True)

        if t.dim() == 0:
            branch = self.branch_net(
                t, conditional_state, edge_idx_all, location, t_past
            )
            trunck = self.trunck_net(t, location, t_past)
        else:
            branch = self.batch_branch_net(
                t, conditional_state, edge_idx_all, location, t_past
            )
            trunck = self.batch_trunck_net(t, location, t_past)

        # Bound the product before normalisation to keep extreme values out
        # of the batch statistics.
        output = torch.clamp(branch * trunck, -1e8, 1e8)
        return self.batch_norm(output)

    def batch_branch_net(self, t, conditional_state, edge_idx_all, location,
                         t_past):
        """Branch term for several query times ``t`` of shape ``(Q,)``.

        Arguments are as in ``branch_net`` except that ``t`` is 1-dim.

        Returns:
            ``(B, Q, N, D)`` tensor.
        """
        assert len(t_past) == conditional_state.shape[1], (
            "t_past must have one entry per past step of conditional_state"
        )
        time_diff = (t[:, None] - t_past[None, :]).unsqueeze(-1)  # (Q, T, 1)
        return self._branch_from_time_diff(
            time_diff, conditional_state, edge_idx_all, location
        )

    def batch_trunck_net(self, t, location, t_past):
        """Trunk term for several query times ``t`` of shape ``(Q,)``.

        Arguments are as in ``trunck_net`` except that ``t`` is 1-dim.

        Returns:
            ``(B, Q, N, embed_size)`` tensor.
        """
        time_diff = (t[:, None] - t_past[None, :]).unsqueeze(-1)  # (Q, T, 1)
        return self._trunk_from_time_diff(time_diff, location)