import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ConditionalLinear(nn.Module):
    """Linear layer whose output is scaled element-wise by a per-step vector.

    The scale vector for each sample is looked up from an ``nn.Embedding`` table
    indexed by the step ``t`` (initialised uniformly in [0, 1)).

    Args:
        num_in: number of input features.
        num_out: number of output features (and width of each scale vector).
        n_steps: number of rows in the step-embedding table.
    """

    def __init__(self, num_in, num_out, n_steps):
        super().__init__()
        self.num_out = num_out
        self.lin = nn.Linear(num_in, num_out)
        self.embed = nn.Embedding(n_steps, num_out)
        self.embed.weight.data.uniform_()

    def forward(self, x, t):
        """Apply the linear map, then scale it by the embedding of ``t``.

        Args:
            x: input of shape ``(batch, num_in)`` or ``(batch, L, num_in)``.
            t: integer step indices; its first dimension is the batch.

        Returns:
            Tensor with the same leading shape as ``x`` and last dim ``num_out``.
        """
        out = self.lin(x)
        gamma = self.embed(t)
        if out.dim() == 2:
            # (batch, num_out): one scale vector per row. Viewing gamma as
            # (batch, 1, num_out) here would broadcast to (batch, batch, num_out).
            return gamma.view(-1, self.num_out) * out
        # (batch, L, num_out): view gamma as (batch, -1, num_out) so the
        # per-sample scale broadcasts over the middle axis.
        return gamma.view(t.size(0), -1, self.num_out) * out


class ConditionalGuidedModel(nn.Module):
    """MLP noise predictor whose hidden layers are modulated by the diffusion step.

    The network input is ``y_t``, optionally concatenated along the last axis
    with ``y_0_hat`` (when ``config.model.cat_y_pred``) or with ``x`` (when only
    ``config.model.cat_x``). It passes through three ``ConditionalLinear`` layers
    with softplus activations and a plain ``nn.Linear`` back to
    ``MTS_args.diff_in`` features.
    """

    def __init__(self, config, MTS_args):
        super().__init__()
        # Embedding tables get timesteps + 1 rows.
        n_steps = config.diffusion.timesteps + 1
        self.cat_x = config.model.cat_x
        self.cat_y_pred = config.model.cat_y_pred
        # NOTE(review): lin1 always expects 2 * diff_in input features, even when
        # neither cat_x nor cat_y_pred is set (the input is then y_t alone) —
        # confirm y_t is that wide in that configuration.
        data_dim = MTS_args.diff_in * 2
        dim_denoisy = MTS_args.diff_dim

        self.lin1 = ConditionalLinear(data_dim, dim_denoisy, n_steps)
        self.lin2 = ConditionalLinear(dim_denoisy, dim_denoisy, n_steps)
        self.lin3 = ConditionalLinear(dim_denoisy, dim_denoisy, n_steps)
        self.lin4 = nn.Linear(dim_denoisy, MTS_args.diff_in)

    def forward(self, x, y_t, y_0_hat, t):
        """Predict the noise for ``y_t`` at step ``t``.

        Args:
            x: covariates, concatenated only when ``cat_x`` is set and
                ``cat_y_pred`` is not.
            y_t: noisy target.
            y_0_hat: point prediction, concatenated when ``cat_y_pred`` is set.
            t: step indices forwarded to each ``ConditionalLinear``.

        Returns:
            Tensor whose last dimension is ``MTS_args.diff_in``.
        """
        # y_0_hat takes priority over x whenever cat_y_pred is set (regardless
        # of cat_x). All concatenations use the feature (last) axis.
        if self.cat_y_pred:
            eps_pred = torch.cat((y_t, y_0_hat), dim=-1)
        elif self.cat_x:
            eps_pred = torch.cat((y_t, x), dim=-1)
        else:
            eps_pred = y_t

        for layer in (self.lin1, self.lin2, self.lin3):
            eps_pred = F.softplus(layer(eps_pred, t))
        return self.lin4(eps_pred)

class ConditionalGuidedGCNModel(nn.Module):
    """Graph-convolutional denoiser conditioned on a point forecast ``y_0_hat``.

    Every position of the forecast horizon is a graph node, and every pair of
    nodes is connected with a weight that decays with their squared time gap
    (``build_observation_graph``). Two graph-convolution layers follow; the
    second is also fed a projection of the condition. Then come an attention
    weighting over the nodes and a two-layer head that emits one value per node.
    """

    def __init__(self, config, MTS_args):
        super().__init__()

        # Length of the beta/alpha schedules below; valid steps are 0 .. n_steps - 1.
        self.n_steps = config.diffusion.timesteps
        self.cat_y_pred = config.model.cat_y_pred
        self.n_dlayer = 2  # stored only; not read by this class

        # GCN hyper-parameters: optional on MTS_args, with fallbacks.
        self.hidden_dim = getattr(MTS_args, 'hidden_dim', 64)
        self.max_kernel_size = getattr(MTS_args, 'max_kernel_size', 5)  # stored only; not read by this class
        # Only read pred_len when no explicit temporal_decay is provided.
        self.temporal_decay = (
            MTS_args.temporal_decay if hasattr(MTS_args, 'temporal_decay') else MTS_args.pred_len
        )

        # Embedding dimensions
        self.d_model = getattr(MTS_args, 'd_model', 64)
        # With cat_y_pred each node carries two value channels (condition and
        # noisy target), and the time embedding shrinks by 2 so F_init is
        # exactly d_model wide in that case.
        self.temp_embed_dim = self.d_model - 2 if self.cat_y_pred else self.d_model
        # y_t is documented as (batch, horizon, 1); cat_y_pred prepends one more channel.
        value_channels = 2 if self.cat_y_pred else 1

        # Keep this creation order: it determines the order of random initialisation.
        self.W1 = nn.Linear(value_channels + self.temp_embed_dim, self.hidden_dim)
        # Second GCN layer sees [F1, cond_proj(condition)] concatenated.
        self.W2 = nn.Linear(2 * self.hidden_dim, self.hidden_dim)

        # Conditional projection of the (batch, horizon, 1) condition.
        self.cond_proj = nn.Linear(1, self.hidden_dim)

        # One attention logit per node.
        self.attention_weight = nn.Linear(self.hidden_dim, 1)

        # Final FFN layers for mean estimation
        self.mid_dim = self.hidden_dim
        self.W3 = nn.Linear(self.hidden_dim, self.mid_dim)
        self.W4 = nn.Linear(self.mid_dim, 1)

        # Fixed linear beta schedule, stored as buffers (not learnable parameters).
        self.register_buffer('beta', torch.linspace(config.diffusion.beta_start, config.diffusion.beta_end, self.n_steps))
        self.register_buffer('alpha', 1. - self.beta)
        self.register_buffer('alpha_bar', torch.cumprod(self.alpha, dim=0))

    def rotary_position_embedding(self, timestamps):
        """Return an interleaved sin/cos embedding of ``timestamps``.

        Despite its name, nothing is rotated. This is an absolute sinusoidal
        embedding: frequency ``k`` contributes ``sin(ts * w_k), cos(ts * w_k)``
        as adjacent features, with ``w_k = 10000 ** (-2k / temp_embed_dim)``.

        Args:
            timestamps: tensor of shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, seq_len, temp_embed_dim)``.
        """
        batch_size, seq_len = timestamps.shape
        positions = timestamps.unsqueeze(-1)

        freqs = torch.arange(0, self.temp_embed_dim, 2, device=timestamps.device).float()
        inv_freq = 1.0 / (10000 ** (freqs / self.temp_embed_dim))

        sinusoid = positions * inv_freq.unsqueeze(0).unsqueeze(0)
        # Interleave sin and cos along the last axis.
        emb = torch.stack([torch.sin(sinusoid), torch.cos(sinusoid)], dim=-1)
        # ceil(D / 2) frequencies yield D + 1 features when D is odd; trim to D.
        emb = emb.reshape(batch_size, seq_len, -1)[..., :self.temp_embed_dim]
        return emb

    def build_observation_graph(self, y_t, mask, timestamps):
        """Build the normalised adjacency and initial node features.

        Args:
            y_t: node values of shape ``(batch, L, C)``.
            mask: currently ignored; every pair of positions is connected.
            timestamps: tensor of shape ``(batch, L)``.

        Returns:
            ``(A_hat, F_init)``. ``A_hat = D^{-1/2} A D^{-1/2}`` with
            ``A[b, i, j] = exp(-((ts[b, j] - ts[b, i]) / temporal_decay) ** 2)``.
            ``F_init`` is ``y_t`` concatenated with the time embedding.
        """
        emb = self.rotary_position_embedding(timestamps)
        F_init = torch.cat([y_t, emb.to(y_t.device)], dim=-1)

        # Pairwise time gaps via broadcasting: gap[b, i, j] = ts[b, j] - ts[b, i].
        gap = timestamps.unsqueeze(1) - timestamps.unsqueeze(2)
        A_O = torch.exp(-torch.square(gap / self.temporal_decay))

        # The degree matrix is diagonal, so D^{-1/2} A D^{-1/2} is just a
        # row/column rescaling; no dense matrix inverse is needed.
        deg_inv_sqrt = torch.rsqrt(torch.sum(A_O, dim=2) + 1e-8)
        A_hat = deg_inv_sqrt.unsqueeze(-1) * A_O * deg_inv_sqrt.unsqueeze(-2)
        return A_hat, F_init

    def noise_aware_gcn(self, F_init, A_hat, condition, t):
        """Apply two graph-convolution layers, the second conditioned on ``condition``.

        Args:
            F_init: node features ``(batch, L, W1.in_features)``.
            A_hat: normalised adjacency ``(batch, L, L)``.
            condition: condition values ``(batch, L, 1)``.
            t: accepted for interface compatibility; not used here.

        Returns:
            Node features of shape ``(batch, L, hidden_dim)``.
        """
        F1 = F.relu(self.W1(A_hat @ F_init))
        E_c = self.cond_proj(condition)
        F2 = F.relu(self.W2(A_hat @ torch.cat([F1, E_c], dim=-1)))
        return F2

    def compute_target_mean(self, y_t, y_0_pred, t):
        """Compute the target mean for a denoising step (Eq. 10).

        ``mean = sqrt(ab_{t-1}) * beta_t / (1 - ab_t) * y_0_pred
               + sqrt(alpha_t) * (1 - ab_{t-1}) / (1 - ab_t) * y_t``,
        where ``ab`` is ``alpha_bar`` and ``ab_{-1} := 1``.

        Args:
            y_t, y_0_pred: tensors of the same shape.
            t: an int, a 0-d tensor, or a 1-D tensor of per-sample steps. In
                the 1-D case the coefficients broadcast over the leading
                (batch) dimension of ``y_t``.
        """
        t = torch.as_tensor(t, device=self.alpha_bar.device)
        alpha_bar_t = self.alpha_bar[t]
        # alpha_bar_{t-1}, with alpha_bar_{-1} := 1. Clamp so index -1 never
        # wraps around to the last step.
        alpha_bar_t_prev = torch.where(
            t > 0,
            self.alpha_bar[(t - 1).clamp(min=0)],
            torch.ones_like(alpha_bar_t),
        )
        beta_t = self.beta[t]

        coeff1 = (torch.sqrt(alpha_bar_t_prev) * beta_t) / (1 - alpha_bar_t)
        coeff2 = (torch.sqrt(self.alpha[t]) * (1 - alpha_bar_t_prev)) / (1 - alpha_bar_t)

        if coeff1.dim() == 1 and y_t.dim() > 1:
            # Per-sample steps: shape coefficients as (batch, 1, ..., 1).
            shape = (-1,) + (1,) * (y_t.dim() - 1)
            coeff1 = coeff1.view(shape)
            coeff2 = coeff2.view(shape)

        return coeff1 * y_0_pred + coeff2 * y_t

    def forward(self, x, y, y_t, y_0_hat, t, timestamps=None):
        """
        x: historical observations [batch_size, lookback, n_vars] (not used here)
        y: ground truth future values [batch_size, horizon, 1] (training only)
        y_t: noisy future values at diffusion step t [batch_size, horizon, 1]
        y_0_hat: predicted future values from transformer [batch_size, horizon, 1]
        t: diffusion step [batch_size] (values from 0 to n_steps-1)
        timestamps: timestamps for each position [batch_size, horizon];
            defaults to 0 .. horizon-1 when None

        Returns: estimated mean μ_θ for p(y^{t-1} | y^t, c), shape [batch_size, horizon, 1]
        """
        batch_size, horizon, _ = y_t.shape
        if timestamps is None:
            timestamps = torch.arange(horizon, device=y_0_hat.device).unsqueeze(0).repeat(batch_size, 1)

        # During training: future mixup to create condition c (Eq. 6)
        if self.training:
            mix_matrix = torch.rand(size=(batch_size, horizon, 1), device=y_t.device)
            condition_c = y_0_hat * mix_matrix + (1 - mix_matrix) * y
        else:
            condition_c = y_0_hat

        if self.cat_y_pred:
            y_t = torch.cat([condition_c, y_t], dim=-1)

        A_hat, F_init = self.build_observation_graph(y_t, None, timestamps)
        F2 = self.noise_aware_gcn(F_init, A_hat, condition_c, t)

        # Normalise attention over the horizon (node) axis. Normalising over
        # dim=0 would couple independent samples of the batch.
        attention_scores = torch.softmax(self.attention_weight(F2) / math.sqrt(self.hidden_dim), dim=1)

        F_refine = attention_scores * F2
        hidden = F.relu(self.W3(F_refine))
        return self.W4(hidden)

# deterministic feed forward neural network
class DeterministicFeedForwardNeuralNetwork(nn.Module):
    """Plain MLP built as an ``nn.Sequential``.

    Each hidden layer is ``Linear -> [BatchNorm1d] -> LeakyReLU -> Dropout``,
    followed by a final ``Linear`` to ``dim_out``. A ``Dropout`` module is
    appended even when ``dropout_rate`` is 0, so module indices inside
    ``network`` do not depend on the rate.

    Args:
        dim_in: number of input features.
        dim_out: number of output features.
        hid_layers: hidden-layer widths, in order (any iterable of ints; may be empty).
        use_batchnorm: insert ``BatchNorm1d`` after each hidden ``Linear``.
        negative_slope: negative slope of the ``LeakyReLU`` activations.
        dropout_rate: dropout probability after each hidden activation.
    """

    def __init__(self, dim_in, dim_out, hid_layers,
                 use_batchnorm=False, negative_slope=0.01, dropout_rate=0):
        super().__init__()
        hidden_widths = list(hid_layers)  # accept tuples/ranges as well as lists
        self.dim_in = dim_in  # dimension of nn input
        self.dim_out = dim_out  # dimension of nn output
        self.hid_layers = hid_layers  # hidden layer widths as passed by the caller
        # Input width of every Linear except the output layer's own output.
        self.nn_layers = [self.dim_in] + hidden_widths
        self.use_batchnorm = use_batchnorm
        self.negative_slope = negative_slope
        self.dropout_rate = dropout_rate
        self.network = nn.Sequential(*self.create_nn_layers())

    def create_nn_layers(self):
        """Return the list of modules that make up ``self.network``."""
        layers = []
        for width_in, width_out in zip(self.nn_layers[:-1], self.nn_layers[1:]):
            layers.append(nn.Linear(width_in, width_out))
            if self.use_batchnorm:
                layers.append(nn.BatchNorm1d(width_out))
            layers.append(nn.LeakyReLU(negative_slope=self.negative_slope))
            layers.append(nn.Dropout(p=self.dropout_rate))
        layers.append(nn.Linear(self.nn_layers[-1], self.dim_out))
        return layers

    def forward(self, x):
        """Run ``x`` through the MLP."""
        return self.network(x)


# early stopping scheme for hyperparameter tuning
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    Lower costs are better. A call counts as an improvement when the cost is at
    least ``delta`` below the best cost seen so far. Otherwise a counter is
    incremented; once it reaches ``patience`` consecutive non-improving calls,
    ``early_stop`` is set to True.

    Attributes:
        best_score: best (lowest) cost seen so far; None until the first valid call.
        best_epoch: ``epoch + 1`` for the call that produced ``best_score``.
        counter: number of consecutive non-improving calls.
        early_stop: True once ``counter >= patience``.
    """

    def __init__(self, patience=10, delta=0):
        """
        Args:
            patience (int): Number of consecutive non-improving calls tolerated
                            before ``early_stop`` is set.
                            Default: 10
            delta (float): Minimum decrease in the monitored cost that counts as
                           an improvement; should be a small non-negative value.
                           Default: 0
        """
        self.patience = patience
        self.delta = delta
        self.counter = 0
        self.best_score = None
        self.best_epoch = None
        self.early_stop = False

    def __call__(self, val_cost, epoch, verbose=False):
        """Record ``val_cost`` for ``epoch`` and update the stopping state.

        A NaN cost is always treated as non-improving. NaN compares False
        against everything, so it would otherwise be stored as the new best, and
        the next cost would then overwrite it regardless of its value.
        """
        score = val_cost

        improved = not math.isnan(score) and (
            self.best_score is None or score <= self.best_score - self.delta
        )

        if improved:
            self.best_score = score
            self.best_epoch = epoch + 1
            self.counter = 0
            return

        self.counter += 1
        if verbose:
            print("EarlyStopping counter: {} out of {}...".format(
                self.counter, self.patience))
        if self.counter >= self.patience:
            self.early_stop = True
