import torch
from torch import nn
import torch.nn.functional as F
from .gwnet_arch import GraphWaveNet
import pickle


class T3GWNet(nn.Module):
    """Graph WaveNet encoder combined with per-day event (text) embeddings.

    A pre-trained ``GraphWaveNet`` is loaded from a checkpoint, frozen, and has
    its ``end_conv_2`` layer replaced by an empty ``nn.Sequential`` so that it
    can be used as a feature extractor. Two trainable 1x1-convolution heads are
    placed on top of it:

    * ``ts_forecaster`` predicts from the time-series features alone;
    * ``mm_forecaster`` predicts from the time-series features concatenated
      with a projected event embedding (``text_transform``).

    ``forward`` returns the sum of both predictions.
    """

    def __init__(
        self,
        num_nodes,
        dropout=0.3,
        supports=None,
        gcn_bool=True,
        addaptadj=True,
        aptinit=None,
        in_dim=2,
        out_dim=12,
        residual_channels=32,
        dilation_channels=32,
        skip_channels=256,
        end_channels=512,
        kernel_size=2,
        blocks=4,
        layers=2,
        event_embedding_dim=1024,
        event_hidden_embedding_dim=256,
        pre_trained_ts_path=None,
        event_embedding_path=None,
    ):
        """Build the model and load the frozen pre-trained encoder.

        Args:
            num_nodes ... layers: forwarded unchanged (positionally) to
                ``GraphWaveNet``.
            event_embedding_dim: size of the raw per-day event embedding.
            event_hidden_embedding_dim: size of the projected event embedding.
            pre_trained_ts_path: path of a checkpoint readable by
                ``torch.load`` that contains a ``"model_state_dict"`` entry
                for ``GraphWaveNet``.
            event_embedding_path: path of a pickle file holding a mapping
                ``day index -> {"embedding": ...}``.

        Raises:
            ValueError: if either path is ``None``.
        """
        super().__init__()

        # Both resources are mandatory; fail early with an explicit message
        # instead of an opaque error from ``open(None)`` / ``torch.load(None)``.
        if pre_trained_ts_path is None:
            raise ValueError("pre_trained_ts_path must point to a pre-trained GraphWaveNet checkpoint.")
        if event_embedding_path is None:
            raise ValueError("event_embedding_path must point to the pickled event embeddings.")

        self.event_embedding_dim = event_embedding_dim
        self.event_hidden_embedding_dim = event_hidden_embedding_dim
        self.pre_trained_tsformer_path = pre_trained_ts_path

        self.ts_encoder = GraphWaveNet(
            num_nodes, dropout, supports, gcn_bool, addaptadj, aptinit,
            in_dim, out_dim, residual_channels, dilation_channels,
            skip_channels, end_channels, kernel_size, blocks, layers,
        )

        # Projects the raw event embedding down to the hidden event size.
        self.text_transform = self._pointwise_head(
            self.event_embedding_dim,
            self.event_hidden_embedding_dim,
            self.event_hidden_embedding_dim,
        )
        # Multi-modal head: consumes [event features ; time-series features].
        self.mm_forecaster = self._pointwise_head(
            end_channels + event_hidden_embedding_dim, end_channels, out_dim
        )
        # Time-series-only head.
        self.ts_forecaster = self._pointwise_head(end_channels, end_channels, out_dim)

        # SECURITY NOTE: pickle can execute arbitrary code while loading; only
        # point ``event_embedding_path`` at files from a trusted source.
        with open(event_embedding_path, "rb") as f:
            self.daily_info = pickle.load(f)

        self.load_pre_trained_model()
        # The checkpoint is loaded with the full architecture first; only then
        # is the last convolution swapped for a no-op so that its output is
        # bypassed.
        self.ts_encoder.end_conv_2 = nn.Sequential()

    @staticmethod
    def _pointwise_head(in_channels, hidden_channels, out_channels):
        """Return a Conv1x1 -> ReLU -> Dropout(0.15) -> Conv1x1 stack."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=hidden_channels, kernel_size=(1, 1), bias=True),
            nn.ReLU(),
            nn.Dropout(p=0.15),
            nn.Conv2d(in_channels=hidden_channels, out_channels=out_channels, kernel_size=(1, 1), bias=True),
        )

    def load_pre_trained_model(self):
        """Load the pre-trained encoder weights and freeze them."""

        # ``map_location="cpu"`` lets a checkpoint saved on any device be read
        # on any host; ``load_state_dict`` then copies the values into the
        # encoder's existing parameters, wherever those live.
        checkpoint_dict = torch.load(self.pre_trained_tsformer_path, map_location="cpu")
        self.ts_encoder.load_state_dict(checkpoint_dict["model_state_dict"])
        # freeze parameters
        self.ts_encoder.requires_grad_(False)

    def _lookup_event_embeddings(self, history_data: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
        """Return one raw event embedding per sample, zeros when none is known.

        The day index of a sample is read from the last channel of its first
        time step and first node, scaled by 365 and truncated to an integer.

        Args:
            history_data: input batch, indexed as ``[B, L, N, C]``.
            like: tensor whose batch size, dtype and device the result adopts.

        Returns:
            Tensor of shape ``(like.shape[0], event_embedding_dim)``.
        """
        # Slice before converting so that only B scalars leave the device.
        # NOTE(review): the conversion truncates; if the feature is stored as
        # ``k / 365`` float rounding could map day k to k - 1 - confirm against
        # the data pipeline.
        day_indices = (history_data[:, 0, 0, -1] * 365).long().tolist()

        event_emb = like.new_zeros((like.shape[0], self.event_embedding_dim))
        for batch_id, diy in enumerate(day_indices):
            if diy in self.daily_info:
                event_emb[batch_id, ...] = torch.as_tensor(
                    self.daily_info[diy]['embedding'], dtype=like.dtype, device=like.device
                )
        return event_emb

    def forward(self, history_data: torch.Tensor, future_data: torch.Tensor, batch_seen: int, epoch: int, train: bool, **kwargs) -> torch.Tensor:
        """Predict from time-series features and the event embedding of the day.

        Args:
            history_data (torch.Tensor): shape [B, L, N, C]. The last channel
                carries the day index divided by 365; the remaining channels
                are passed to the encoder.
            future_data (torch.Tensor): forwarded to the encoder unchanged.
            batch_seen (int): forwarded to the encoder unchanged.
            epoch (int): forwarded to the encoder unchanged.
            train (bool): forwarded to the encoder unchanged.

        Returns:
            torch.Tensor: sum of the time-series-only and the multi-modal
            prediction, [B, out_dim, N, 1] when the encoder yields
            [B, end_channels, N, 1].
        """

        ts_emb = self.ts_encoder.encoder_forward(history_data[..., 0:-1], future_data, batch_seen, epoch, train, **kwargs) # (B, D, N, 1)
        B, _, N, _ = ts_emb.shape

        event_emb = self._lookup_event_embeddings(history_data, ts_emb) # (B, E)
        event_emb = event_emb.unsqueeze(-1).unsqueeze(-1) # (B, E, 1, 1)
        event_emb = self.text_transform(event_emb) # (B, H, 1, 1)
        # Share the per-sample event features across all nodes.
        event_emb = event_emb.expand(B, self.event_hidden_embedding_dim, N, 1)

        fused_emb = torch.cat((event_emb, ts_emb), dim=1)

        mm_pred = self.mm_forecaster(fused_emb)
        ts_pred = self.ts_forecaster(ts_emb)

        return ts_pred + mm_pred