import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from typing import List, Dict, Tuple, Optional
import math
from dataclasses import dataclass

class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` positions along axis 2 of a tensor.

    Used after a padded ``nn.Conv1d`` to remove the right-hand padding. A
    non-positive ``chomp_size`` makes the module a no-op that returns its
    input object unchanged.
    """

    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        # Guard clause: nothing to trim, hand the input straight back.
        if self.chomp_size <= 0:
            return x
        trimmed = x[:, :, :-self.chomp_size]
        # Slicing yields a strided view; materialise it for later convs.
        return trimmed.contiguous()

class TemporalBlock(nn.Module):
    """Residual block of two dilated, causally-trimmed 1-D convolutions.

    Each branch is ``conv -> chomp -> Tanh -> dropout``. Both convolutions pad
    by ``(kernel_size - 1) * dilation`` on each side, and the matching
    ``Chomp1d`` removes that padding from the right again. As a result, the
    output length equals the input length and each output step only sees the
    current and earlier input steps. When ``in_ch != out_ch``, a 1x1
    convolution projects the input for the residual path.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel_size: Width of both convolution kernels.
        dilation: Dilation of both convolutions.
        dropout: Dropout probability applied after each branch activation.
    """

    def __init__(self, in_ch, out_ch, kernel_size, dilation, dropout=0.1):
        super().__init__()
        padding = (kernel_size - 1) * dilation
        # NOTE: the ``relu*`` attributes actually hold Tanh activations; the
        # names are kept so existing code that references them keeps working.
        self.conv1 = nn.Conv1d(in_ch, out_ch, kernel_size, padding=padding, dilation=dilation)
        self.chomp1 = Chomp1d(padding); self.relu1 = nn.Tanh(); self.do1 = nn.Dropout(dropout)
        self.conv2 = nn.Conv1d(out_ch, out_ch, kernel_size, padding=padding, dilation=dilation)
        self.chomp2 = Chomp1d(padding); self.relu2 = nn.Tanh(); self.do2 = nn.Dropout(dropout)
        self.down = nn.Conv1d(in_ch, out_ch, 1) if in_ch != out_ch else None
        self.relu = nn.Tanh()

    def forward(self, x):
        """Apply the block to ``x`` of shape (N, in_ch, L); returns (N, out_ch, L)."""
        # Dropout was previously constructed but never applied, so the
        # ``dropout`` argument had no effect; apply it after each activation.
        out = self.do1(self.relu1(self.chomp1(self.conv1(x))))
        out = self.do2(self.relu2(self.chomp2(self.conv2(out))))
        res = x if self.down is None else self.down(x)
        return self.relu(out + res)

class SharedTCNEmbedder(nn.Module):
    """Embed univariate windows with a dilated TCN and a linear head.

    The TCN stacks ``n_layers`` ``TemporalBlock``s whose dilation doubles per
    layer (1, 2, 4, ...). Their output is average-pooled over time, projected
    to ``embedding_dim`` and L2-normalised along the last dimension.

    Args:
        window_size: Accepted for interface compatibility; not used here
            (adaptive pooling makes the module length-agnostic).
        embedding_dim: Size of the output embedding.
        n_layers: Number of temporal blocks; ``0`` gives a pool + linear model.
        ch: Channel width of every temporal block.
        ksize: Kernel size of every temporal block.
        dropout: Dropout probability passed to every temporal block.
    """

    def __init__(
        self,
        window_size: int,
        embedding_dim: int,
        n_layers: int = 6,
        ch: int = 16,
        ksize: int = 3,
        dropout: float = 0.2,
    ):
        super().__init__()
        layers = []
        in_ch = 1  # windows enter as a single channel (see forward)

        for i in range(n_layers):
            layers.append(
                TemporalBlock(
                    in_ch,
                    ch,
                    kernel_size=ksize,
                    dilation=2 ** i,
                    dropout=dropout,
                )
            )
            in_ch = ch

        self.tcn = nn.Sequential(*layers)
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            # Size the projection from the channels that actually leave the
            # TCN: equal to ``ch`` when n_layers >= 1, but 1 when n_layers == 0
            # (empty Sequential is the identity).
            nn.Linear(in_ch, embedding_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unit-norm embeddings of shape (batch, embedding_dim).

        ``x`` is expected to be 2-D (batch, window): it gets a channel axis of
        size 1 inserted at position 1 before the 1-channel TCN.
        """
        x = x.unsqueeze(1)   # (B, W) -> (B, 1, W)
        h = self.tcn(x)      # temporal blocks preserve the length W
        z = self.head(h)     # pool over time, flatten, project
        return F.normalize(z, dim=-1)
    
class Global_Embedder(nn.Module):
    """Encode a multi-series window with a small Tanh MLP.

    The (B, W, N) input is flattened to (B, W * N) and passed through
    ``Linear -> Tanh -> Linear -> Tanh -> Linear`` with hidden widths 64 and
    32. ``W * N`` must equal ``window_size * num_series`` for the first layer.
    """

    def __init__(self, window_size: int, num_series: int, global_embedding_dim: int = 128):
        super().__init__()
        self.window_size = window_size
        self.num_series = num_series
        self.global_embedding_dim = global_embedding_dim

        # Build the MLP layer by layer; order (and therefore indices and
        # parameter initialisation order) matches Linear, Tanh, Linear, Tanh, Linear.
        stages = []
        width_in = window_size * num_series
        for width_out in (64, 32):
            stages.append(nn.Linear(width_in, width_out))
            stages.append(nn.Tanh())
            width_in = width_out
        stages.append(nn.Linear(width_in, global_embedding_dim))
        self.encoder = nn.Sequential(*stages)

    def forward(self, global_windows: torch.Tensor) -> torch.Tensor:
        """Return embeddings of shape (B, global_embedding_dim)."""
        # Tuple unpacking requires a 3-D input.
        batch, window, series = global_windows.shape
        return self.encoder(global_windows.reshape(batch, window * series))

class Global_Predictor(nn.Module):
    """Linearly map a global embedding to one output value per series.

    Args:
        num_series: Output width of the decoder.
        global_embedding_dim: Input width of the decoder.
    """

    def __init__(self, num_series: int, global_embedding_dim: int = 128):
        super().__init__()
        self.num_series = num_series
        self.global_embedding_dim = global_embedding_dim
        self.decoder = nn.Linear(global_embedding_dim, num_series)

    def forward(self, global_embeddings: torch.Tensor) -> torch.Tensor:
        """Apply the linear decoder to the last dimension of ``global_embeddings``."""
        return self.decoder(global_embeddings)


class TopoDistill(nn.Module):
    def __init__(
        self,
        num_sequences: int,
        window_size: int,
        embedding_dim: int = 128,
        global_embedding_dim = 128,
        adapter_dim: int = 32,
        TCN_n_layers: int = 6,
        TCN_ch: int = 16,
        TCN_ksize: int = 3,
        global_win_last_n: int = 5
    ):
        """Build the shared TCN embedder plus the global embedder/predictor pair.

        Args:
            num_sequences: Number of series; sets the global input/output width.
            window_size: Window length forwarded to the shared embedder.
            embedding_dim: Output size of the shared TCN embedder.
            global_embedding_dim: Size of the global embedding.
            adapter_dim: Stored on the instance; not used in this method.
            TCN_n_layers: Number of temporal blocks in the shared embedder.
            TCN_ch: Channel width of the shared embedder's temporal blocks.
            TCN_ksize: Kernel size of the shared embedder's temporal blocks.
            global_win_last_n: Window length the global embedder is sized for.
        """
        super().__init__()

        self.num_sequences = num_sequences
        self.window_size = window_size
        self.embedding_dim = embedding_dim
        self.adapter_dim = adapter_dim

        # Sub-modules are created in the same order as before so parameter
        # initialisation and state_dict ordering are unchanged.
        tcn_options = dict(
            n_layers=TCN_n_layers,
            ch=TCN_ch,
            ksize=TCN_ksize,
            dropout=0,  # the shared TCN is always built without dropout here
        )
        self.shared_embedder = SharedTCNEmbedder(
            window_size=window_size, embedding_dim=embedding_dim, **tcn_options
        )
        self.global_embedder = Global_Embedder(
            window_size=global_win_last_n,
            num_series=num_sequences,
            global_embedding_dim=global_embedding_dim,
        )
        self.global_predictor = Global_Predictor(
            num_series=num_sequences, global_embedding_dim=global_embedding_dim
        )

        # Training path by default; see set_inference_mode.
        self.inference_mode = False
        
    def set_inference_mode(self, mode: bool):
        """Store ``mode`` as the flag that ``forward`` branches on.

        A truthy value makes ``forward`` compute only the individual
        embeddings; a falsy one selects the full training path.
        """
        setattr(self, 'inference_mode', mode)

    
    def forward(self, batch_data: Dict) -> Dict:
        """Run one batch through the model.

        In inference mode only ``batch_data['individual_windows']`` is read.
        The result holds just ``'individual_embeddings'``.

        Otherwise, ``'global_windows'`` (must be 3-D; its last dim is taken as
        the series count N), ``'individual_windows'`` and ``'series_ids'`` are
        read. Each global embedding row is repeated N times, consecutively, and
        handed to ``self.get_adapted_embedding`` together with ``series_ids``.
        The reconstruction and adapter entries of the result are always None.
        """
        if self.inference_mode:
            embeddings = self.shared_embedder(batch_data['individual_windows'])
            return {'individual_embeddings': embeddings}

        # Keys are read in this order so a missing key reports the same name.
        global_windows = batch_data['global_windows']
        individual_windows = batch_data['individual_windows']
        series_ids = batch_data['series_ids']

        # Unpacking also rejects a global tensor that is not 3-D.
        _, _, num_series = global_windows.shape

        global_embeddings = self.global_embedder(global_windows)
        individual_embeddings = self.shared_embedder(individual_windows)

        # Row b becomes rows b*N .. b*N + N - 1; presumably this lines up with
        # individual windows laid out sample-major — TODO confirm with caller.
        per_series_global = torch.repeat_interleave(global_embeddings, num_series, dim=0)
        # NOTE(review): get_adapted_embedding is not defined in the code shown here.
        observed_embeddings = self.get_adapted_embedding(per_series_global, series_ids)

        return dict(
            global_embeddings=global_embeddings,
            individual_embeddings=individual_embeddings,
            individual_reconstruction=None,
            observed_embeddings=observed_embeddings,
            reconstructed_global=None,
            adapter_vectors=None,
        )
