import torch
import torch.nn as nn
import torch.nn.functional as F
import torchcde
from torchdiffeq import odeint

import math
import time 
from main.model.interpolation import KernelInterpolation, WeightedKernelInterpolation, GPInterpolation, WeightedGPInterpolation


class SimpleCDEFunc(nn.Module):
    """Vector field of a Neural CDE: maps a hidden state to a (hidden x input) matrix.

    A two-layer MLP ``z -> relu(W1 z + b1) -> tanh(W2 . + b2)`` whose output is
    reshaped to ``(batch, hidden_channels, input_channels)``.

    Args:
        input_channels: Number of channels of the control path X.
        hidden_channels: Size of the hidden state z.
        seq_len: Unused; accepted for signature compatibility with callers.
        intermediate_channels: Width of the MLP's hidden layer (default 128).

    Attributes:
        nfe: Number of vector-field evaluations since it was last set to 0.
        last_tanh_mean: Mean |tanh output| of the most recent evaluation.
        last_tanh_saturation: Fraction of tanh outputs with |value| > 0.95 in the
            most recent evaluation.
    """

    def __init__(self, input_channels, hidden_channels, seq_len, intermediate_channels=128):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.intermediate_channels = intermediate_channels

        self.weight1 = nn.Parameter(torch.empty(intermediate_channels, hidden_channels))
        self.bias1 = nn.Parameter(torch.empty(intermediate_channels))
        self.weight2 = nn.Parameter(torch.empty(hidden_channels * input_channels, intermediate_channels))
        self.bias2 = nn.Parameter(torch.empty(hidden_channels * input_channels))

        self.nfe = 0
        self.last_tanh_saturation = 0.0
        self.last_tanh_mean = 0.0
        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise both layers the way ``nn.Linear`` does by default."""
        nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
        # Each bias is bounded by 1/sqrt(fan_in) of its *own* layer.
        for weight, bias in ((self.weight1, self.bias1), (self.weight2, self.bias2)):
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(bias, -bound, bound)

    def forward(self, t, z):
        """Evaluate the vector field at hidden state ``z``.

        Args:
            t: Current time (unused).
            z: Hidden state of shape ``(batch, hidden_channels)``.

        Returns:
            Tensor of shape ``(batch, hidden_channels, input_channels)``.
        """
        self.nfe += 1
        z1 = F.relu(torch.einsum('oi,bi->bo', self.weight1, z) + self.bias1)
        z2 = torch.tanh(torch.einsum('oi,bi->bo', self.weight2, z1) + self.bias2)

        # Diagnostics only; `.item()` copies to the host on every evaluation.
        with torch.no_grad():
            abs_out = z2.abs()
            self.last_tanh_mean = abs_out.mean().item()
            self.last_tanh_saturation = (abs_out > 0.95).float().mean().item()

        return z2.view(z.size(0), self.hidden_channels, self.input_channels)

class NeuralCDE(nn.Module):
    """A standard Neural CDE model with selectable interpolation methods.

    ``forward`` builds a continuous path X from ``coeffs``, encodes ``X(t0)``
    with a linear layer, integrates ``dz = f(z) dX`` over ``X.interval`` with
    ``torchcde.cdeint``, and applies a linear readout to the terminal state.

    Args:
        input_channels: Channel count of the interpolated path X.
        hidden_channels: Size of the CDE hidden state.
        output_channels: Size of the prediction.
        seq_len: Forwarded to ``SimpleCDEFunc``.
        interpolation: One of ``'cubic'``, ``'linear'``, ``'kernel'``, ``'gp'``.
        kernel_params: Options forwarded to the kernel/GP interpolation
            (an empty dict when omitted).
        tol: Used as both ``atol`` and ``rtol`` of the CDE solver.
        add_time: For cubic/linear without ``t_grid``, when True channel 0 of
            ``coeffs[0]`` is read as the time axis. Also forwarded as
            ``include_time`` to the kernel/GP interpolations.
        t_grid: Optional fixed observation times, registered as a buffer.
    """

    def __init__(self, input_channels, hidden_channels, output_channels, seq_len, interpolation="cubic", kernel_params=None, tol=1e-4, add_time=True, t_grid=None):
        super().__init__()
        self.func = SimpleCDEFunc(input_channels, hidden_channels, seq_len)
        self.initial = nn.Linear(input_channels, hidden_channels)
        self.readout = nn.Linear(hidden_channels, output_channels)
        self.interpolation = interpolation
        # Avoid a shared mutable default; a caller-supplied dict is used as-is.
        self.kernel_params = {} if kernel_params is None else kernel_params
        self.tol = tol
        self.add_time = add_time

        # Wall-clock seconds spent building interpolations, summed over forward calls.
        self.fit_time_accum = 0.0

        if t_grid is not None:
            self.register_buffer('t_grid', t_grid)
        else:
            self.t_grid = None

    def reset_fit_timer(self):
        """Reset the accumulated interpolation-construction time to zero."""
        self.fit_time_accum = 0.0

    def make_interpolation(self, coeffs):
        """Create the interpolation object X for ``coeffs`` (also used for diagnostics).

        Returns:
            A torchcde-style path object, or ``None`` if ``self.interpolation``
            is not one of the supported names.
        """
        if self.interpolation in ('cubic', 'linear'):
            path_cls = torchcde.CubicSpline if self.interpolation == 'cubic' else torchcde.LinearInterpolation
            if self.t_grid is not None:
                return path_cls(coeffs, t=self.t_grid)
            if self.add_time:
                # NOTE(review): assumes channel 0 of coeffs[0] holds the time stamps
                # shared by the whole batch — confirm for spline-coefficient layouts.
                return path_cls(coeffs, t=coeffs[0, :, 0])
            # No time information available: use torchcde's default time axis
            # (integer sample index), the same fallback as the recurrent models.
            return path_cls(coeffs)

        if self.interpolation == 'kernel':
            return KernelInterpolation(coeffs=coeffs, t=self.t_grid, kernel_params=self.kernel_params, include_time=self.add_time)

        if self.interpolation == 'gp':
            return GPInterpolation(coeffs=coeffs, t=self.t_grid, gp_params=self.kernel_params, include_time=self.add_time)

        return None

    def forward(self, coeffs):
        """Predict outputs for a batch of input series.

        Raises:
            ValueError: If ``self.interpolation`` is not supported.
        """
        self.func.nfe = 0

        t0 = time.perf_counter()
        X = self.make_interpolation(coeffs)
        if coeffs.is_cuda:
            # Wait for queued kernels so the timer measures the real work.
            torch.cuda.synchronize()
        self.fit_time_accum += (time.perf_counter() - t0)

        if X is None:
            raise ValueError(
                f"Unknown interpolation: {self.interpolation!r}. "
                "choose: 'cubic', 'linear', 'kernel', 'gp'."
            )

        X0 = X.evaluate(X.interval[0])
        z0 = self.initial(X0)
        z_T = torchcde.cdeint(X=X, z0=z0, func=self.func, t=X.interval, atol=self.tol, rtol=self.tol)
        # z_T holds the states at the two requested times; index 1 is the end of the interval.
        z_T_final = z_T[:, 1]
        return self.readout(z_T_final)


class ParallelCDEFunc(nn.Module):
    """
    A CDE vector field that evaluates ``num_heads`` independent MLPs in parallel.

    The solver state stacks heads along the batch dimension: ``z`` has shape
    ``(batch * num_heads, hidden_channels)`` and row ``b * num_heads + h`` is
    processed by head ``h``. Each head is ``relu(W1 z + b1) -> tanh(W2 . + b2)``
    and its output is reshaped to ``(hidden_channels, input_channels)``.

    Args:
        input_channels: Number of channels of the control path.
        hidden_channels: Hidden-state size of each head.
        num_heads: Number of independent heads.
        seq_len: Stored but not otherwise used here.

    Attributes:
        nfe: Number of evaluations since it was last set to 0.
        last_tanh_mean: Mean |tanh output| of the most recent evaluation.
        last_tanh_saturation: Fraction of tanh outputs with |value| > 0.95.
    """

    def __init__(self, input_channels, hidden_channels, num_heads, seq_len):
        super(ParallelCDEFunc, self).__init__()
        self.seq_len = seq_len
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.num_heads = num_heads

        intermediate_channels = hidden_channels

        self.layer1_weights = nn.Parameter(torch.empty(num_heads, intermediate_channels, hidden_channels))
        self.layer1_bias = nn.Parameter(torch.empty(num_heads, intermediate_channels))
        self.layer2_weights = nn.Parameter(torch.empty(num_heads, hidden_channels * input_channels, intermediate_channels))
        self.layer2_bias = nn.Parameter(torch.empty(num_heads, hidden_channels * input_channels))

        self.nfe = 0
        self.last_tanh_saturation = 0.0
        self.last_tanh_mean = 0.0

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise every head like an independent ``nn.Linear`` layer."""
        with torch.no_grad():
            # Initialise head by head. On a 3-D (heads, out, in) tensor PyTorch's
            # fan-in rule uses size(1) * size(2) = out * in, which would make the
            # weights roughly sqrt(out) times smaller than intended.
            for weight in (self.layer1_weights, self.layer2_weights):
                for head in range(self.num_heads):
                    nn.init.kaiming_uniform_(weight[head], a=math.sqrt(5))

            # Each bias is bounded by 1/sqrt(fan_in) of its own layer.
            for weight, bias in ((self.layer1_weights, self.layer1_bias),
                                 (self.layer2_weights, self.layer2_bias)):
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight[0])
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                nn.init.uniform_(bias, -bound, bound)

    def forward(self, t, z):
        """Evaluate all heads on the stacked state ``z``.

        Args:
            t: Current time (unused).
            z: Tensor of shape ``(batch * num_heads, hidden_channels)``.

        Returns:
            Tensor of shape ``(batch * num_heads, hidden_channels, input_channels)``.

        Raises:
            ValueError: If ``z.size(0)`` is not a multiple of ``num_heads``.
        """
        self.nfe += 1

        batch_dim = z.size(0)
        if batch_dim == 0:
            return torch.zeros(0, self.hidden_channels, self.input_channels, device=z.device, dtype=z.dtype)
        if batch_dim % self.num_heads != 0:
            raise ValueError(
                f"Leading dimension {batch_dim} is not divisible by num_heads={self.num_heads}."
            )

        batch_size = batch_dim // self.num_heads
        z_bhi = z.reshape(batch_size, self.num_heads, self.hidden_channels)

        hidden = F.relu(
            torch.einsum('hoi,bhi->bho', self.layer1_weights, z_bhi) + self.layer1_bias.unsqueeze(0)
        )
        output = torch.tanh(
            torch.einsum('hoi,bhi->bho', self.layer2_weights, hidden) + self.layer2_bias.unsqueeze(0)
        )

        # Diagnostics only; `.item()` copies to the host on every evaluation.
        with torch.no_grad():
            abs_out = output.abs()
            self.last_tanh_mean = abs_out.mean().item()
            self.last_tanh_saturation = (abs_out > 0.95).float().mean().item()

        return output.reshape(batch_dim, self.hidden_channels, self.input_channels)
    
class QFormerCDE(nn.Module):
    """Multi-head Neural CDE with learned query attention over the input sequence.

    Each head (one per entry of ``bandwidths``) owns a learnable query vector.
    Scaled dot-product attention of the queries over the time steps of ``coeffs``
    gives one weight per (head, time step). Together with the head's bandwidth
    (or GP length scale), these weights parameterise a weighted interpolation.
    All heads are integrated jointly by ``ParallelCDEFunc``, with heads stacked
    along the batch dimension. The terminal states are then aggregated and
    passed to a linear readout.

    Args:
        input_channels: Channel count of the interpolated path fed to the CDE.
        hidden_channels: Hidden-state size of each head.
        output_channels: Size of the prediction.
        seq_len: Forwarded to ``ParallelCDEFunc``.
        qformer_params: Optional dict with keys ``kernel`` (default
            ``'gaussian'``; ``'gp'`` selects the GP interpolation),
            ``bandwidths`` (default ``[1.0]``; its length sets the number of
            heads), ``noise_std`` (default 0.01, GP only), ``tol`` (default 1e-4,
            solver atol/rtol) and ``aggregation`` (``'concat'``, ``'mean'`` or
            ``'max'``; default ``'concat'``).
        add_time: Forwarded as ``include_time`` to the interpolation.
        t_grid: Optional fixed observation times, registered as a buffer.

    Raises:
        ValueError: If ``aggregation`` is not supported.
    """

    def __init__(self, input_channels, hidden_channels, output_channels, seq_len, qformer_params=None, add_time=True, t_grid=None):
        super(QFormerCDE, self).__init__()
        if qformer_params is None:
            qformer_params = {}

        self.input_channels = input_channels

        self.kernel = qformer_params.get('kernel', 'gaussian')
        self.bandwidths = qformer_params.get('bandwidths', [1.0])
        self.noise_std = qformer_params.get('noise_std', 0.01)
        self.tol = qformer_params.get('tol', 1e-4)
        self.aggregation = qformer_params.get('aggregation', 'concat')
        self.add_time = add_time

        # Wall-clock seconds spent building interpolations, summed over forward calls.
        self.fit_time_accum = 0.0

        if t_grid is not None:
            self.register_buffer('t_grid', t_grid)
        else:
            self.t_grid = None

        self.num_heads = len(self.bandwidths)
        self.hidden_channels = hidden_channels

        # NOTE(review): the queries are dotted with the last dimension of the raw
        # `coeffs`, which is assumed to have this size — confirm the layout when
        # t_grid is None and add_time is False.
        if self.t_grid is not None:
            raw_dim = input_channels
        else:
            raw_dim = input_channels + (1 if not add_time else 0)

        self.queries = nn.Parameter(torch.randn(self.num_heads, raw_dim))

        self.cde_func = ParallelCDEFunc(input_channels, self.hidden_channels, self.num_heads, seq_len)

        self.initial_layer = nn.Linear(input_channels, self.hidden_channels)
        if self.aggregation == 'concat':
            readout_input_features = self.hidden_channels * self.num_heads
        elif self.aggregation in ['mean', 'max']:
            readout_input_features = self.hidden_channels
        else:
            raise ValueError(f"value: {self.aggregation}. choose: 'concat', 'mean', 'max'.")

        self.readout = nn.Linear(readout_input_features, output_channels)

    def reset_fit_timer(self):
        """Reset the accumulated interpolation-construction time to zero."""
        self.fit_time_accum = 0.0

    def make_interpolation(self, coeffs):
        """Build the attention-weighted interpolation for every (sample, head) pair.

        Args:
            coeffs: Tensor of shape ``(batch, length, d)`` where ``d`` equals the
                query dimension.

        Returns:
            A ``WeightedGPInterpolation`` if ``kernel == 'gp'``, otherwise a
            ``WeightedKernelInterpolation``. Both cover ``batch * num_heads``
            series, where row ``b * num_heads + h`` is sample ``b``, head ``h``.
        """
        batch_size = coeffs.size(0)

        scale = math.sqrt(self.queries.size(-1))
        attention_scores = torch.einsum('md,bld->bml', self.queries, coeffs) / scale
        attention_weights = F.softmax(attention_scores, dim=-1)  # (batch, heads, length)

        # All three tensors below are ordered sample-major, head-minor.
        coeffs_parallel = coeffs.repeat_interleave(self.num_heads, dim=0)
        weights_parallel = attention_weights.reshape(-1, attention_weights.size(-1))
        scales_parallel = torch.tensor(
            self.bandwidths, device=coeffs.device, dtype=coeffs.dtype
        ).repeat(batch_size)

        if self.kernel == 'gp':
            return WeightedGPInterpolation(
                coeffs=coeffs_parallel,
                weights=weights_parallel,
                t=self.t_grid,
                gp_params={'length_scale': scales_parallel, 'noise_std': self.noise_std},
                include_time=self.add_time
            )

        return WeightedKernelInterpolation(
            coeffs=coeffs_parallel,
            weights=weights_parallel,
            t=self.t_grid,
            kernel_params={'kernel': self.kernel, 'bandwidth': scales_parallel},
            include_time=self.add_time
        )

    def _aggregate(self, states):
        """Combine per-head terminal states ``(batch, heads, hidden)`` into readout features."""
        if self.aggregation == 'concat':
            return states.reshape(states.size(0), -1)
        if self.aggregation == 'mean':
            return torch.mean(states, dim=1)
        if self.aggregation == 'max':
            return torch.max(states, dim=1).values
        raise ValueError(f"value: {self.aggregation}. choose: 'concat', 'mean', 'max'.")

    def forward(self, coeffs):
        """Predict outputs for a batch of input series ``coeffs``."""
        self.cde_func.nfe = 0

        t0 = time.perf_counter()
        X = self.make_interpolation(coeffs)
        if coeffs.is_cuda:
            # Wait for queued kernels so the timer measures the real work.
            torch.cuda.synchronize()
        self.fit_time_accum += (time.perf_counter() - t0)

        X0 = X.evaluate(X.interval[0])
        z0 = self.initial_layer(X0)

        z_T = torchcde.cdeint(
            X=X, z0=z0, func=self.cde_func, t=X.interval,
            atol=self.tol, rtol=self.tol
        )

        # Index 1 is the state at the end of the interval; un-stack the heads.
        batch_size = coeffs.size(0)
        final_states = z_T[:, 1, :].reshape(batch_size, self.num_heads, self.hidden_channels)

        return self.readout(self._aggregate(final_states))

class ConvCDE(nn.Module):
    """Multi-head Neural CDE whose per-head attention weights come from a small 1-D CNN.

    Two ``Conv1d`` + ReLU layers run over the time axis of ``coeffs``. A linear
    layer then maps each time step to one score per head, and a softmax over time
    turns the scores into interpolation weights. Each head (one per entry of
    ``bandwidths``) uses its own bandwidth (or GP length scale). Heads are
    integrated jointly by ``ParallelCDEFunc`` and their terminal states are
    aggregated before a linear readout.

    Args:
        input_channels: Channel count of the interpolated path fed to the CDE.
        hidden_channels: Hidden-state size of each head; also the CNN width.
        output_channels: Size of the prediction.
        seq_len: Forwarded to ``ParallelCDEFunc``.
        conv_params: Optional dict with keys ``kernel`` (default ``'gaussian'``;
            ``'gp'`` selects the GP interpolation), ``bandwidths`` (default
            ``[1.0]``), ``noise_std`` (default 0.01, GP only), ``tol`` (default
            1e-4), ``aggregation`` (default ``'concat'``) and
            ``conv_kernel_size`` (default 3).
        add_time: Forwarded as ``include_time`` to the interpolation.
        t_grid: Optional fixed observation times, registered as a buffer.

    Raises:
        ValueError: If ``aggregation`` is not ``'concat'``, ``'mean'`` or ``'max'``.
    """

    def __init__(self, input_channels, hidden_channels, output_channels, seq_len, conv_params=None, add_time=True, t_grid=None):
        super(ConvCDE, self).__init__()
        if conv_params is None:
            conv_params = {}

        self.input_channels = input_channels

        self.kernel = conv_params.get('kernel', 'gaussian')
        self.bandwidths = conv_params.get('bandwidths', [1.0])
        self.noise_std = conv_params.get('noise_std', 0.01)
        self.tol = conv_params.get('tol', 1e-4)
        self.aggregation = conv_params.get('aggregation', 'concat')
        self.add_time = add_time

        # Wall-clock seconds spent building interpolations, summed over forward calls.
        self.fit_time_accum = 0.0

        if t_grid is not None:
            self.register_buffer('t_grid', t_grid)
        else:
            self.t_grid = None

        kernel_size = conv_params.get('conv_kernel_size', 3)

        self.num_heads = len(self.bandwidths)
        self.hidden_channels = hidden_channels

        # NOTE(review): assumed last dimension of the raw `coeffs` — confirm the
        # layout when t_grid is None and add_time is False.
        if self.t_grid is not None:
            raw_dim = input_channels
        else:
            raw_dim = input_channels + (1 if not add_time else 0)

        # 'same' padding keeps the sequence length for every kernel size.
        # Symmetric kernel_size // 2 padding would lengthen the sequence by one
        # per layer for even kernel sizes, misaligning the weights with `coeffs`.
        self.conv1 = nn.Conv1d(
            in_channels=raw_dim, out_channels=self.hidden_channels,
            kernel_size=kernel_size, padding='same', padding_mode='replicate'
        )
        self.conv2 = nn.Conv1d(
            in_channels=self.hidden_channels, out_channels=self.hidden_channels,
            kernel_size=kernel_size, padding='same', padding_mode='replicate'
        )
        self.to_heads = nn.Linear(self.hidden_channels, self.num_heads)

        self.cde_func = ParallelCDEFunc(input_channels, self.hidden_channels, self.num_heads, seq_len)
        self.initial_layer = nn.Linear(input_channels, self.hidden_channels)

        if self.aggregation == 'concat':
            readout_input_features = self.hidden_channels * self.num_heads
        elif self.aggregation in ['mean', 'max']:
            readout_input_features = self.hidden_channels
        else:
            raise ValueError(f"Unknown aggregation: {self.aggregation}")

        self.readout = nn.Linear(readout_input_features, output_channels)

    def reset_fit_timer(self):
        """Reset the accumulated interpolation-construction time to zero."""
        self.fit_time_accum = 0.0

    def make_interpolation(self, coeffs):
        """Build the CNN-weighted interpolation for every (sample, head) pair.

        Args:
            coeffs: Tensor of shape ``(batch, length, raw_dim)``.

        Returns:
            A ``WeightedGPInterpolation`` if ``kernel == 'gp'``, otherwise a
            ``WeightedKernelInterpolation``. Both cover ``batch * num_heads``
            series, where row ``b * num_heads + h`` is sample ``b``, head ``h``.
        """
        batch_size = coeffs.size(0)

        # Conv1d expects (batch, channels, length).
        features = coeffs.permute(0, 2, 1)
        features = F.relu(self.conv1(features))
        features = F.relu(self.conv2(features))
        raw_scores = self.to_heads(features.permute(0, 2, 1))              # (batch, length, heads)
        attention_weights = F.softmax(raw_scores.transpose(1, 2), dim=-1)   # (batch, heads, length)

        # All three tensors below are ordered sample-major, head-minor.
        coeffs_parallel = coeffs.repeat_interleave(self.num_heads, dim=0)
        weights_parallel = attention_weights.reshape(-1, attention_weights.size(-1))
        scales_parallel = torch.tensor(
            self.bandwidths, device=coeffs.device, dtype=coeffs.dtype
        ).repeat(batch_size)

        if self.kernel == 'gp':
            return WeightedGPInterpolation(
                coeffs=coeffs_parallel,
                weights=weights_parallel,
                t=self.t_grid,
                gp_params={'length_scale': scales_parallel, 'noise_std': self.noise_std},
                include_time=self.add_time
            )

        return WeightedKernelInterpolation(
            coeffs=coeffs_parallel,
            weights=weights_parallel,
            t=self.t_grid,
            kernel_params={'kernel': self.kernel, 'bandwidth': scales_parallel},
            include_time=self.add_time
        )

    def _aggregate(self, states):
        """Combine per-head terminal states ``(batch, heads, hidden)`` into readout features."""
        if self.aggregation == 'concat':
            return states.reshape(states.size(0), -1)
        if self.aggregation == 'mean':
            return torch.mean(states, dim=1)
        if self.aggregation == 'max':
            return torch.max(states, dim=1).values
        raise ValueError(f"Unknown aggregation: {self.aggregation}")

    def forward(self, coeffs):
        """Predict outputs for a batch of input series ``coeffs``."""
        self.cde_func.nfe = 0

        t0 = time.perf_counter()
        X = self.make_interpolation(coeffs)
        if coeffs.is_cuda:
            # Wait for queued kernels so the timer measures the real work.
            torch.cuda.synchronize()
        self.fit_time_accum += (time.perf_counter() - t0)

        X0 = X.evaluate(X.interval[0])
        z0 = self.initial_layer(X0)

        z_T = torchcde.cdeint(
            X=X, z0=z0, func=self.cde_func, t=X.interval,
            atol=self.tol, rtol=self.tol
        )

        # Index 1 is the state at the end of the interval; un-stack the heads.
        batch_size = coeffs.size(0)
        final_states = z_T[:, 1, :].reshape(batch_size, self.num_heads, self.hidden_channels)

        return self.readout(self._aggregate(final_states))
    

class ODERNNFunc(nn.Module):
    """Autonomous ODE right-hand side ``dh/dt = net(h)`` used between observations.

    ``net`` is Linear -> Tanh -> Linear over ``hidden_channels``. The final
    Linear layer starts at zero, so the initial dynamics are the zero field
    (the hidden state does not move).

    Attributes:
        nfe: Number of evaluations since it was last set to 0.
        last_tanh_saturation, last_tanh_mean: Initialised to 0.0 and never
            updated by this class.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        self.hidden_channels = hidden_channels

        width = hidden_channels
        self.net = nn.Sequential(nn.Linear(width, width), nn.Tanh(), nn.Linear(width, width))

        self.nfe = 0
        self.last_tanh_saturation = 0.0
        self.last_tanh_mean = 0.0

        self._reset_parameters()

    def _reset_parameters(self):
        """Re-initialise the first layer nn.Linear-style and zero the last layer."""
        first, last = self.net[0], self.net[2]

        nn.init.kaiming_uniform_(first.weight, a=math.sqrt(5))
        limit = 1 / math.sqrt(nn.init._calculate_fan_in_and_fan_out(first.weight)[0])
        nn.init.uniform_(first.bias, -limit, limit)

        for param in (last.weight, last.bias):
            nn.init.zeros_(param)

    def forward(self, t, h):
        """Return ``net(h)``; ``t`` is ignored."""
        self.nfe = self.nfe + 1
        return self.net(h)

class ODERNN(nn.Module):
    """ODE-RNN: GRU updates at observations, learned ODE dynamics in between.

    At every observation the hidden state is updated by a ``GRUCell``. Between
    consecutive observation times it is evolved by ``odeint`` under
    ``ODERNNFunc``. The final hidden state is passed to a linear readout.

    Args:
        input_channels: Channels per observation (input to the GRU cell).
        hidden_channels: Hidden-state size.
        output_channels: Size of the prediction.
        seq_len: Unused; accepted for signature compatibility.
        tol: ``atol``/``rtol`` of ``odeint``.
        add_time: When ``t_grid`` is None and this is True, channel 0 of the
            first sample is used as the time axis of the whole batch; otherwise
            the integer index ``0..seq_len-1`` is used.
        t_grid: Optional fixed observation times, registered as a buffer.
    """

    def __init__(self, input_channels, hidden_channels, output_channels, seq_len=None, tol=1e-3, add_time=True, t_grid=None):
        super(ODERNN, self).__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.tol = tol
        self.add_time = add_time

        # Wall-clock seconds spent in the recurrence, summed over forward calls.
        self.fit_time_accum = 0.0

        if t_grid is not None:
            self.register_buffer('t_grid', t_grid)
        else:
            self.t_grid = None

        # Jump update applied at each observation.
        self.gru_cell = nn.GRUCell(input_channels, hidden_channels)

        # Continuous dynamics between observations.
        self.func = ODERNNFunc(hidden_channels)

        self.readout = nn.Linear(hidden_channels, output_channels)

        self._reset_parameters()

    def _reset_parameters(self):
        """Kaiming-uniform GRU/readout weights with zero biases."""
        for name, param in self.gru_cell.named_parameters():
            if 'weight' in name:
                nn.init.kaiming_uniform_(param, a=math.sqrt(5))
            elif 'bias' in name:
                nn.init.zeros_(param)

        nn.init.kaiming_uniform_(self.readout.weight, a=math.sqrt(5))
        nn.init.zeros_(self.readout.bias)

    def reset_fit_timer(self):
        """Reset the accumulated timing counter to zero."""
        self.fit_time_accum = 0.0

    def forward(self, coeffs):
        """Run the ODE-RNN over ``coeffs`` of shape ``(batch, seq_len, input_channels)``.

        Raises:
            ValueError: If the sequence contains no observations.
        """
        batch_size = coeffs.size(0)
        seq_len = coeffs.size(1)
        device = coeffs.device
        if seq_len == 0:
            raise ValueError("ODERNN requires at least one observation per sequence.")

        self.func.nfe = 0

        if self.t_grid is not None:
            time_grid = self.t_grid
        elif self.add_time:
            # NOTE(review): sample 0's time channel is used for the whole batch,
            # i.e. all sequences are assumed to share observation times.
            time_grid = coeffs[0, :, 0]
        else:
            time_grid = torch.arange(seq_len, dtype=coeffs.dtype, device=device)

        # Match the input dtype so the model also runs after e.g. `.double()`.
        h = torch.zeros(batch_size, self.hidden_channels, device=device, dtype=coeffs.dtype)

        # This timer covers the whole recurrence (ODE solves and GRU updates).
        t_start_fit = time.perf_counter()

        h = self.gru_cell(coeffs[:, 0, :], h)
        for i in range(1, seq_len):
            t0, t1 = time_grid[i - 1], time_grid[i]
            # Skip the solve when consecutive times (nearly) coincide.
            if torch.abs(t1 - t0) > 1e-5:
                h = odeint(self.func, h, torch.stack([t0, t1]), atol=self.tol, rtol=self.tol)[1]
            h = self.gru_cell(coeffs[:, i, :], h)

        if coeffs.is_cuda:
            torch.cuda.synchronize()
        self.fit_time_accum += (time.perf_counter() - t_start_fit)

        return self.readout(h)


class GRUD(nn.Module):
    """GRU with exponential hidden-state decay over the time gap between observations.

    Before consuming observation ``i > 0``, the hidden state is scaled by
    ``gamma = exp(-relu(decay_layer(dt_i)))``. Here ``dt_i`` is the time elapsed
    since observation ``i - 1``, clamped to be non-negative. ``decay_layer`` is
    zero-initialised, so gamma starts at 1 (no decay). No missingness mask or
    input decay is used.

    Args:
        input_channels: Channels per observation (input to the GRU cell).
        hidden_channels: Hidden-state size.
        output_channels: Size of the prediction.
        seq_len: Unused; accepted for signature compatibility.
        add_time: When ``t_grid`` is None and this is True, channel 0 of each
            sample is its own time axis; otherwise the integer index is used.
        t_grid: Optional fixed observation times, registered as a buffer.
    """

    def __init__(self, input_channels, hidden_channels, output_channels, seq_len=None, add_time=True, t_grid=None):
        super(GRUD, self).__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.add_time = add_time

        # Kept for interface parity with the other models; never incremented here.
        self.fit_time_accum = 0.0

        if t_grid is not None:
            self.register_buffer('t_grid', t_grid)
        else:
            self.t_grid = None

        self.gru_cell = nn.GRUCell(input_channels, hidden_channels)

        # Maps the scalar time gap to a per-unit decay rate.
        self.decay_layer = nn.Linear(1, hidden_channels)

        self.readout = nn.Linear(hidden_channels, output_channels)

        self._reset_parameters()

    def _reset_parameters(self):
        """Kaiming-uniform GRU/readout weights, zero biases, zero decay layer."""
        for name, param in self.gru_cell.named_parameters():
            if 'weight' in name:
                nn.init.kaiming_uniform_(param, a=math.sqrt(5))
            elif 'bias' in name:
                nn.init.zeros_(param)

        nn.init.kaiming_uniform_(self.readout.weight, a=math.sqrt(5))
        if self.readout.bias is not None:
            nn.init.zeros_(self.readout.bias)

        nn.init.zeros_(self.decay_layer.weight)
        nn.init.zeros_(self.decay_layer.bias)

    def reset_fit_timer(self):
        """Reset the (unused) timing counter to zero."""
        self.fit_time_accum = 0.0

    def forward(self, coeffs):
        """Run the decaying GRU over ``coeffs`` of shape ``(batch, seq_len, input_channels)``."""
        batch_size = coeffs.size(0)
        seq_len = coeffs.size(1)
        device = coeffs.device

        if self.t_grid is not None:
            time_grid = self.t_grid.unsqueeze(0).expand(batch_size, -1)
        elif self.add_time:
            time_grid = coeffs[:, :, 0]
        else:
            time_grid = torch.arange(seq_len, dtype=coeffs.dtype, device=device).unsqueeze(0).expand(batch_size, -1)

        # delta_t[:, i] = time since the previous observation (0 for the first one);
        # negative gaps are clamped to zero.
        delta_t = torch.zeros_like(time_grid)
        delta_t[:, 1:] = time_grid[:, 1:] - time_grid[:, :-1]
        delta_t = F.relu(delta_t)

        # Match the input dtype so the model also runs after e.g. `.double()`.
        h = torch.zeros(batch_size, self.hidden_channels, device=device, dtype=coeffs.dtype)

        for i in range(seq_len):
            if i > 0:
                dt = delta_t[:, i].unsqueeze(-1)  # (batch, 1)
                gamma = torch.exp(-F.relu(self.decay_layer(dt)))
                h = h * gamma
            h = self.gru_cell(coeffs[:, i, :], h)

        return self.readout(h)
    


def compute_logsig_windows(x, window_size, depth):
    """Compute depth-1 or depth-2 log-signatures of ``x`` over consecutive windows.

    Window ``k`` spans samples ``k * window_size`` through
    ``min((k + 1) * window_size, length - 1)`` inclusive. Consecutive windows
    therefore share an endpoint, every sample-to-sample increment belongs to
    exactly one window, and a shorter final window covers any remainder. The
    path is treated as piecewise linear between samples.

    Args:
        x: Tensor of shape ``(batch, length, channels)``.
        window_size: Number of increments per window (>= 1).
        depth: ``1`` returns the increment over each window. ``2`` additionally
            returns the Levy area (the level-2 log-signature term), an
            antisymmetric ``channels x channels`` matrix flattened row-major.

    Returns:
        Tensor of shape ``(batch, num_windows, channels)`` for depth 1, or
        ``(batch, num_windows, channels + channels**2)`` for depth 2, where
        ``num_windows = ceil((length - 1) / window_size)`` (0 if ``length < 2``).

    Raises:
        ValueError: If ``depth`` is not 1 or 2, or ``window_size < 1``.
    """
    if depth not in (1, 2):
        raise ValueError("Only Depth 1 and 2 are supported for native implementation.")
    if window_size < 1:
        raise ValueError(f"window_size must be a positive integer, got {window_size}.")

    B, L, C = x.shape
    out_dim = C if depth == 1 else C + C * C
    if L < 2:
        return x.new_zeros(B, 0, out_dim)

    num_steps = L - 1
    num_windows = (num_steps - 1) // window_size + 1  # ceil(num_steps / window_size)

    # Window boundaries share endpoints: [starts[k], ends[k]].
    starts = torch.arange(num_windows, device=x.device) * window_size
    ends = torch.clamp(starts + window_size, max=num_steps)
    displacement = x[:, ends, :] - x[:, starts, :]

    if depth == 1:
        return displacement

    # Levy area of a piecewise-linear path over each window:
    #   A_ij = 1/2 * sum_s [ (x_s - x_start)_i dx_s_j - (x_s - x_start)_j dx_s_i ]
    window_id = torch.arange(num_windows, device=x.device).repeat_interleave(window_size)[:num_steps]
    dx = x[:, 1:, :] - x[:, :-1, :]                          # (B, L-1, C)
    rel = x[:, :-1, :] - x[:, window_id * window_size, :]     # offset from each step's window start
    outer = torch.einsum('bsi,bsj->bsij', rel, dx)
    area_steps = 0.5 * (outer - outer.transpose(-1, -2))
    area = x.new_zeros(B, num_windows, C, C).index_add(1, window_id, area_steps)

    return torch.cat([displacement, area.reshape(B, num_windows, C * C)], dim=-1)

class LogNeuralCDE(nn.Module):
    """Neural CDE driven by windowed log-signatures (log-ODE method).

    ``compute_logsig_windows`` splits ``coeffs`` into windows of ``step_size``
    steps. Each window's log-signature is used as the increment of a
    piecewise-linear control path, and that path drives a Neural CDE with a
    ``SimpleCDEFunc`` vector field, a linear initial map and a linear readout.

    Args:
        input_channels: Channels of ``coeffs``.
        hidden_channels: Size of the CDE hidden state.
        output_channels: Size of the prediction.
        seq_len: Forwarded to ``SimpleCDEFunc``.
        log_ncde_params: Optional dict with ``step_size`` (default 5) and
            ``depth`` (1 or 2, default 1).
        tol: Used as both ``atol`` and ``rtol`` of the CDE solver.
        add_time: Stored; not otherwise used here.
        t_grid: Accepted for signature compatibility; unused.

    Raises:
        ValueError: If ``depth`` is not 1 or 2.
    """

    def __init__(self, input_channels, hidden_channels, output_channels, seq_len,
                 log_ncde_params=None, tol=1e-4, add_time=True, t_grid=None):
        super().__init__()
        params = {} if log_ncde_params is None else log_ncde_params

        self.step_size = params.get('step_size', 5)
        self.depth = params.get('depth', 1)
        if self.depth not in (1, 2):
            raise ValueError(f"depth must be 1 or 2, got {self.depth!r}.")
        self.hidden_channels = hidden_channels
        self.tol = tol
        self.add_time = add_time

        self.raw_input_channels = input_channels

        # Depth 1: per-window increment (C); depth 2: increment plus a flattened C x C term.
        if self.depth == 1:
            self.cde_input_dim = self.raw_input_channels
        else:
            self.cde_input_dim = self.raw_input_channels + (self.raw_input_channels ** 2)

        self.func = SimpleCDEFunc(self.cde_input_dim, hidden_channels, seq_len)
        self.initial = nn.Linear(self.cde_input_dim, hidden_channels)
        self.readout = nn.Linear(hidden_channels, output_channels)

        # Wall-clock seconds spent building the control path, summed over forward calls.
        self.fit_time_accum = 0.0

    def reset_fit_timer(self):
        """Reset the accumulated path-construction time to zero."""
        self.fit_time_accum = 0.0

    def forward(self, coeffs):
        """Predict outputs for ``coeffs`` of shape ``(batch, length, input_channels)``.

        Raises:
            ValueError: If the sequence is too short to form any window.
        """
        self.func.nfe = 0

        t0 = time.perf_counter()

        logsig = compute_logsig_windows(coeffs, self.step_size, self.depth)  # (batch, windows, cde_input_dim)
        if logsig.size(1) == 0:
            raise ValueError(
                f"Sequence of length {coeffs.size(1)} yields no log-signature windows "
                f"(step_size={self.step_size})."
            )

        # The log-ODE method uses each window's log-signature as the *increment*
        # of the control over that window, so build the path by cumulative sum.
        # It starts at the first observation (zero-padded for the level-2 terms)
        # so the initial encoder still sees the starting value.
        start = F.pad(coeffs[:, :1, :], (0, self.cde_input_dim - coeffs.size(-1)))
        logsig_path = torch.cat([start, start + logsig.cumsum(dim=1)], dim=1)

        X = torchcde.LinearInterpolation(logsig_path)

        if coeffs.is_cuda:
            # Wait for queued kernels so the timer measures the real work.
            torch.cuda.synchronize()
        self.fit_time_accum += (time.perf_counter() - t0)

        # Solve the CDE over the whole path and read out the terminal state.
        X0 = X.evaluate(X.interval[0])
        z0 = self.initial(X0)

        z_T = torchcde.cdeint(X=X, z0=z0, func=self.func, t=X.interval, atol=self.tol, rtol=self.tol)

        z_T_final = z_T[:, 1]
        return self.readout(z_T_final)