import numpy as np
import torch
import torch.nn as nn
from .contconv import SpatialInducedConv
from .seq2seqbase import Seq2SeqAttrs
from .operator.operatorencoder import *
from .recurrent.decoder import *
from .operator.odedecoder import *
from .operator.intdecoder import *
from stmodels.embedding.time import *
from .operatefunc import *

def count_parameters(model):
    """Return the total number of trainable scalar parameters in ``model``.

    Every tensor yielded by ``model.parameters()`` contributes ``numel()``
    elements, but only when its ``requires_grad`` flag is set; frozen
    parameters are ignored. A model without parameters yields 0.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

class STONet(nn.Module, Seq2SeqAttrs):
    """Encoder/decoder spatio-temporal model built on ``SpatialInducedConv`` kernels.

    Hyper-parameters read below (``embed_size``, ``location_dim``, ``time_dim``,
    ``output_dim``, ``ker_embed_size``, ``level_sizes``, ``level_num``,
    ``cont_dec``, ``reconst_loss``) are expected to be provided by
    ``Seq2SeqAttrs.__init__``.

    A forward pass:
      1. appends ``operator_function`` outputs (computed from time embeddings
         and node locations) to the static edge attributes of the batch,
      2. runs ``encoder_model`` one step at a time over the input sequence,
      3. decodes with ``ODEDecoderModel`` when ``cont_dec == True`` and with
         ``IntDecoderModel`` otherwise,
      4. maps hidden states to ``output_dim`` with ``projection_layer``.
    """

    def __init__(self, level_sizes, **model_kwargs):
        super().__init__()
        Seq2SeqAttrs.__init__(self, level_sizes, **model_kwargs)

        hidden_size = self.embed_size * 2
        # NOTE(review): presumably the static edge attributes carry
        # 2 * location_dim columns and the two gathered endpoint values add 2.
        kernel_input_size = 2 + 2 * self.location_dim

        def build_kernel():
            # Both recurrent kernels share exactly the same configuration.
            return SpatialInducedConv(
                embed_size=hidden_size,
                ker_input_size=kernel_input_size,
                ker_embed_size=self.ker_embed_size,
                level_sizes=self.level_sizes,
                level_num=self.level_num,
            )

        # Keep this construction order: it can determine the RNG draws used for
        # parameter initialisation and the ordering of ``parameters()``.
        self.time_embedding = TriTimeEmbedding(self.time_dim, self.embed_size)
        self.operator_function = OperateFunc(
            self.embed_size + self.location_dim, embed_size=hidden_size
        )
        self.projection_layer = nn.Linear(self.embed_size, self.output_dim)
        self.conv_ru = build_kernel()
        self.conv_c = build_kernel()
        self.encoder_model = EncoderModel(
            self.conv_ru, self.conv_c, level_sizes, **model_kwargs
        )

        decoder_cls = ODEDecoderModel if self.cont_dec == True else IntDecoderModel
        self.decoder_model = decoder_cls(level_sizes, **model_kwargs)

        # ``cl_decay_steps`` is consumed by ``_compute_sampling_threshold``.
        self.cl_decay_steps = int(model_kwargs.get('cl_decay_steps', 1000))
        self.use_curriculum_learning = bool(
            model_kwargs.get('use_curriculum_learning', False)
        )

    def update_edge_attr(self, batch):
        """Append operator-function values to every edge-attribute group of ``batch``.

        ``batch.data`` must be 4-D; its first three sizes are used as
        ``(batch_size, in_len, node_num)``. ``batch.edge_index``,
        ``batch.edge_attr`` and ``batch.edge_range`` each hold three groups
        (``mid``, ``down``, ``up``). For every group, the static attributes are
        tiled over ``(batch_size, in_len)`` and concatenated, along the last
        axis, with the operator output looked up at the edge indices.

        ``batch.edge_attr`` is replaced in place by the new list, and ``batch``
        is returned. NOTE(review): because of this mutation the same batch
        object cannot be passed through here twice; the second call would see
        already-tiled 4-D attributes and the ``repeat`` would fail.
        """
        batch_size, in_len, node_num, _ = batch.data.shape

        # Operator input per (time, node): time embedding broadcast over nodes,
        # joined with node locations broadcast over time.
        time_feat = self.time_embedding(batch.time_in)[:, :, None, :]
        time_feat = time_feat.expand(batch_size, in_len, node_num, -1)
        loc_feat = batch.x[:, None, :, :].expand(batch_size, in_len, node_num, -1)
        # NOTE(review): presumably OperateFunc emits one channel, so this is
        # (batch_size, in_len, node_num) after the squeeze -- confirm.
        node_values = self.operator_function(torch.cat([time_feat, loc_feat], dim=-1))
        node_values = node_values.squeeze(dim=-1)

        index_mid, index_down, index_up = batch.edge_index
        static_mid, static_down, static_up = batch.edge_attr
        range_mid, range_down, range_up = batch.edge_range

        def gather_endpoints(edge_index, edge_range, levels):
            # Index node_values' last axis with each level's slice of the edge
            # index (presumably a (2, E) source/target tensor), join the levels
            # along the edge axis, then put the edge axis before the endpoints.
            per_level = [
                node_values[..., edge_index[:, edge_range[lvl, 0]:edge_range[lvl, 1]]]
                for lvl in levels
            ]
            return torch.cat(per_level, dim=-1).transpose(-1, -2)

        def attach(static_attr, dynamic_attr):
            # Tile the static attributes over (batch, time) and append the
            # gathered values as extra feature columns.
            tiled = static_attr[None, None, ...].repeat(batch_size, in_len, 1, 1)
            return torch.cat([tiled, dynamic_attr], dim=-1)

        # The ``mid`` group spans level_num levels, visited from the last one
        # down to level 0.
        dynamic_mid = gather_endpoints(
            index_mid, range_mid, reversed(range(self.level_num))
        )
        new_mid = attach(static_mid, dynamic_mid)

        # The ``down``/``up`` groups span level_num - 1 levels; when there are
        # none (a single level) the ``mid`` values are reused instead.
        cross_levels = range(self.level_num - 1)

        dynamic_down = (
            gather_endpoints(index_down, range_down, cross_levels)
            if len(cross_levels) else dynamic_mid
        )
        new_down = attach(static_down, dynamic_down)

        dynamic_up = (
            gather_endpoints(index_up, range_up, cross_levels)
            if len(cross_levels) else dynamic_mid
        )
        new_up = attach(static_up, dynamic_up)

        batch.edge_attr = [new_mid, new_down, new_up]
        return batch

    def _compute_sampling_threshold(self, batches_seen):
        """Return ``k / (k + exp(batches_seen / k))`` with ``k = cl_decay_steps``.

        The value decays from just below 1 towards 0 as ``batches_seen`` grows;
        presumably it is a curriculum-learning sampling probability (it is not
        called anywhere in this class).
        """
        decay = self.cl_decay_steps
        return decay / (decay + np.exp(batches_seen / decay))

    def get_kernel(self):
        """Return the two spatial kernels as a ``(conv_ru, conv_c)`` tuple."""
        kernels = (self.conv_ru, self.conv_c)
        return kernels

    def encoder(self, batch):
        """Run the encoder over every input step and stack the last-layer states.

        ``encoder_model`` is called once per step in ``range(seq_len)`` with the
        previous hidden state (``None`` at the first step). Element ``[-1]`` of
        each returned state is kept, and the collection is stacked along a new
        axis 1.
        """
        hidden = None
        top_layer_states = []
        for step in range(self.encoder_model.seq_len):
            _, hidden = self.encoder_model(batch, step, hidden)
            top_layer_states.append(hidden[-1])
        return torch.stack(top_layer_states, dim=1)

    def decoder(self, encoder_hidden_states, batch, batches_seen=None):
        """Decode from the stacked encoder states.

        Time grids are taken from ``batch.t_past`` / ``batch.t_eval`` when those
        attributes exist. Otherwise they are built with the dtype and device of
        ``encoder_hidden_states``:
          * ``t_past``: ``-seq_len + 1, ..., 0``
          * ``t_eval``: ``0, ..., horizon`` for the continuous decoder
            (``cont_dec == True``), ``0, ..., horizon - 1`` otherwise.

        ``batches_seen`` is accepted but not used here.
        """
        continuous = self.cont_dec == True

        if hasattr(batch, 't_past'):
            t_past = batch.t_past
        else:
            past_steps = np.arange(-self.encoder_model.seq_len + 1, 1)
            t_past = torch.tensor(past_steps).to(encoder_hidden_states)

        edge_index_all = batch.edge_index_all
        location = batch.x

        if hasattr(batch, 't_eval'):
            t_eval = batch.t_eval
        else:
            horizon = self.decoder_model.horizon
            stop = horizon + 1 if continuous else horizon
            t_eval = torch.tensor(np.arange(0, stop)).to(encoder_hidden_states)

        return self.decoder_model(
            t_eval, encoder_hidden_states, edge_index_all, location, t_past
        )

    def forward(self, batch, batches_seen=None):
        """Encode, decode and project; optionally also reconstruct the inputs.

        Returns the projected decoder output. When ``reconst_loss == True`` it
        returns ``(outputs, reconst)``, where ``reconst`` is ``projection_layer``
        applied to the stacked encoder states. Note that ``batch.edge_attr`` is
        overwritten by ``update_edge_attr``.
        """
        batch = self.update_edge_attr(batch)
        encoded = self.encoder(batch)
        decoded = self.decoder(encoded, batch, batches_seen=batches_seen)

        outputs = self.projection_layer(decoded)
        if self.reconst_loss == True:
            return outputs, self.projection_layer(encoded)
        return outputs
    
