

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class TemporalBlock(nn.Module):
    """Two stacked dilated 1-D convolutions wrapped in a residual connection.

    Each convolution is followed by ``Chomp1d(padding)`` (which slices
    ``padding`` steps off the end of the last axis), a ReLU and dropout.
    The block input is added to the result before a final ReLU; when the
    channel counts differ, the input first goes through a 1x1 convolution.

    Args:
        n_inputs: Number of input channels.
        n_outputs: Number of output channels of both convolutions.
        kernel_size: Kernel size of both convolutions.
        stride: Stride of both convolutions.
        dilation: Dilation of both convolutions.
        padding: Padding passed to both convolutions and to ``Chomp1d``.
        dropout: Dropout probability applied after each ReLU.
    """

    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super().__init__()
        conv_options = {"stride": stride, "padding": padding, "dilation": dilation}

        # First stage: conv -> chomp -> ReLU -> dropout.
        self.conv1 = nn.Conv1d(n_inputs, n_outputs, kernel_size, **conv_options)
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        # Second stage: same layout, operating on n_outputs channels.
        self.conv2 = nn.Conv1d(n_outputs, n_outputs, kernel_size, **conv_options)
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        stages = (
            self.conv1, self.chomp1, self.relu1, self.dropout1,
            self.conv2, self.chomp2, self.relu2, self.dropout2,
        )
        self.net = nn.Sequential(*stages)

        # A 1x1 convolution matches channel counts for the skip path when needed.
        if n_inputs == n_outputs:
            self.downsample = None
        else:
            self.downsample = nn.Conv1d(n_inputs, n_outputs, 1)
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        """Re-draw all convolution weights from N(0, 0.01); biases are left as created."""
        convolutions = [self.conv1, self.conv2]
        if self.downsample is not None:
            convolutions.append(self.downsample)
        for conv in convolutions:
            conv.weight.data.normal_(0, 0.01)

    def forward(self, x):
        """Return ``relu(net(x) + skip(x))`` where ``skip`` is identity or the 1x1 conv."""
        out = self.net(x)
        if self.downsample is None:
            shortcut = x
        else:
            shortcut = self.downsample(x)
        return self.relu(out + shortcut)


class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` entries along the third axis of the input.

    Args:
        chomp_size: Number of trailing steps to remove. ``0`` leaves the
            input untouched.
    """

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """Return a contiguous tensor equal to ``x`` without its last ``chomp_size`` steps."""
        if self.chomp_size == 0:
            # ``x[:, :, :-0]`` is ``x[:, :, :0]`` (an empty tensor), so the
            # "nothing to remove" case must bypass the negative slice.
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()


class TCN(nn.Module):
    """Sequential stack of ``TemporalBlock`` levels with doubling dilation.

    Level ``i`` uses dilation ``2 ** i``, stride 1 and padding
    ``(kernel_size - 1) * 2 ** i``.

    Args:
        num_inputs: Channel count of the input fed to the first level.
        num_channels: Output channel count of each level, in order; its
            length sets the number of levels.
        kernel_size: Kernel size used by every level.
        dropout: Dropout probability passed to every level.
    """

    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super().__init__()
        blocks = []
        previous_channels = num_inputs
        for level, level_channels in enumerate(num_channels):
            dilation_size = 2 ** level
            block = TemporalBlock(
                previous_channels,
                level_channels,
                kernel_size,
                stride=1,
                dilation=dilation_size,
                padding=(kernel_size - 1) * dilation_size,
                dropout=dropout,
            )
            blocks.append(block)
            # The next level consumes what this one produced.
            previous_channels = level_channels

        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        """Run ``x`` through every level in order."""
        return self.network(x)


class SimpleTCN(nn.Module):
    """Dilated 1-D convolution stack with residual connections for forecasting.

    The input is projected to ``hidden_dim`` features, passed through
    ``num_layers`` convolutions with dilation ``2 ** i``, and the last
    ``forecast_horizon`` time steps are mapped to ``output_dim`` values.

    In causal mode every convolution is padded by
    ``(kernel_size - 1) * dilation`` and its output is trimmed back to the
    input length, so output step ``t`` only depends on input steps ``<= t``
    and the residual connection always applies. In non-causal mode the
    padding is ``dilation * (kernel_size - 1) // 2``, which preserves the
    sequence length for odd kernel sizes.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Channel count of the convolution stack.
        num_layers: Number of dilated convolutions.
        output_dim: Number of values predicted per time step.
        kernel_size: Kernel size of every convolution.
        dropout: Dropout probability applied after each ReLU.
        use_causal: Whether convolutions may only look at past steps.
        forecast_horizon: Number of trailing time steps returned.
    """

    def __init__(self, input_dim, hidden_dim=64, num_layers=3, output_dim=1,
                 kernel_size=3, dropout=0.1, use_causal=True, forecast_horizon=24):
        super(SimpleTCN, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.output_dim = output_dim
        self.forecast_horizon = forecast_horizon
        self.use_causal = use_causal

        self.input_projection = nn.Linear(input_dim, hidden_dim)

        self.tcn_layers = nn.ModuleList()
        for i in range(num_layers):
            dilation = 2 ** i
            if use_causal:
                # Conv1d pads both ends; forward() trims the surplus tail.
                padding = (kernel_size - 1) * dilation
            else:
                # "Same" padding has to scale with the dilation, otherwise
                # the sequence shrinks and the residual can never be added.
                padding = dilation * (kernel_size - 1) // 2

            self.tcn_layers.append(
                nn.Conv1d(hidden_dim, hidden_dim, kernel_size,
                          dilation=dilation, padding=padding)
            )

        self.output_layer = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal init for convolutions, Xavier-normal for linears, zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Predict the last ``forecast_horizon`` steps of the sequence.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, input_dim)``; a 4-D
                tensor with a singleton second axis is squeezed first.

        Returns:
            Tensor of shape ``(batch_size, steps, output_dim)`` where
            ``steps`` is at most ``forecast_horizon``; the last axis is
            dropped when ``output_dim == 1``.
        """
        if len(x.shape) == 4:
            x = x.squeeze(1)
        _, seq_len, _ = x.shape

        x = self.input_projection(x)  # (batch_size, seq_len, hidden_dim)

        x = x.transpose(1, 2)  # (batch_size, hidden_dim, seq_len)

        for tcn_layer in self.tcn_layers:
            residual = x
            out = tcn_layer(x)
            if self.use_causal:
                # Keep only the first seq_len outputs: these are the ones
                # whose receptive field ends at the current time step. The
                # remaining outputs were computed from right-hand padding.
                out = out[:, :, :seq_len]
            out = self.dropout(self.relu(out))

            # Sizes can still differ in non-causal mode with an even
            # kernel size; the skip connection is omitted in that case.
            if residual.size() == out.size():
                out = out + residual
            x = out

        x = x.transpose(1, 2)  # (batch_size, seq_len, hidden_dim)

        x = x[:, -self.forecast_horizon:, :]  # (batch_size, forecast_horizon, hidden_dim)

        x = self.output_layer(x)  # (batch_size, forecast_horizon, output_dim)

        if self.output_dim == 1:
            x = x.squeeze(-1)  # (batch_size, forecast_horizon)

        return x


class TCNModel(nn.Module):
    """Wrapper around ``SimpleTCN`` that also accepts dict and 2-D inputs.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Channel count of the convolution stack.
        num_layers: Number of dilated convolutions.
        output_dim: Number of values predicted per time step.
        kernel_size: Kernel size of every convolution.
        dropout: Dropout probability.
        use_causal: Forwarded to ``SimpleTCN``.
        forecast_horizon: Number of trailing time steps returned.
    """

    def __init__(self, input_dim, hidden_dim=64, num_layers=3, output_dim=1, 
                 kernel_size=3, dropout=0.1, use_causal=True, forecast_horizon=24):
        super().__init__()

        backbone_config = {
            "input_dim": input_dim,
            "hidden_dim": hidden_dim,
            "num_layers": num_layers,
            "output_dim": output_dim,
            "kernel_size": kernel_size,
            "dropout": dropout,
            "use_causal": use_causal,
            "forecast_horizon": forecast_horizon,
        }
        self.tcn = SimpleTCN(**backbone_config)

        # Kept on the wrapper so get_model_info() can report them.
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.output_dim = output_dim
        self.forecast_horizon = forecast_horizon

    def forward(self, x):
        """Select the tensor to use, add a length-1 axis to 2-D input, run the TCN.

        For dict input the first present key among ``'numerical'``,
        ``'textual'``, ``'structured'`` is used; otherwise the first value.
        """
        if isinstance(x, dict):
            for preferred_key in ('numerical', 'textual', 'structured'):
                if preferred_key in x:
                    x = x[preferred_key]
                    break
            else:
                # None of the known keys: fall back to the first entry.
                x = list(x.values())[0]

        if x.ndim == 2:
            x = x.unsqueeze(1)

        return self.tcn(x)

    def get_model_info(self):
        """Return the stored hyper-parameters together with parameter counts."""
        total_params = 0
        trainable_params = 0
        for param in self.parameters():
            count = param.numel()
            total_params += count
            if param.requires_grad:
                trainable_params += count

        info = {"model_type": "TCN"}
        info["input_dim"] = self.input_dim
        info["hidden_dim"] = self.hidden_dim
        info["num_layers"] = self.num_layers
        info["output_dim"] = self.output_dim
        info["total_params"] = total_params
        info["trainable_params"] = trainable_params
        return info


def create_tcn_model(input_dim, hidden_dim=64, num_layers=3, output_dim=1, 
                    kernel_size=3, dropout=0.1, use_causal=True, forecast_horizon=24):
    """Build a ``TCNModel`` from the given hyper-parameters.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Channel count of the convolution stack.
        num_layers: Number of dilated convolutions.
        output_dim: Number of values predicted per time step.
        kernel_size: Kernel size of every convolution.
        dropout: Dropout probability.
        use_causal: Whether the convolutions are causal.
        forecast_horizon: Number of trailing time steps the model returns.
            Defaults to 24, the same default ``TCNModel`` uses, so existing
            calls build the same model as before.

    Returns:
        The constructed ``TCNModel``.
    """
    return TCNModel(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        output_dim=output_dim,
        kernel_size=kernel_size,
        dropout=dropout,
        use_causal=use_causal,
        forecast_horizon=forecast_horizon,
    )


