import torch
import torch.nn as nn
import numpy as np


class RecurrentCycle(torch.nn.Module):
    """Learnable periodic pattern with one value per (phase, channel).

    ``data`` has shape (cycle_len, channel_size). ``forward`` reads the pattern
    starting at arbitrary phases and tiles it over several consecutive cycles.
    """

    def __init__(self, cycle_len, channel_size):
        super(RecurrentCycle, self).__init__()
        self.cycle_len = cycle_len
        self.channel_size = channel_size
        self.data = torch.nn.Parameter(torch.zeros(cycle_len, channel_size), requires_grad=True)

    def init_data(self, data):
        """Initialize the cycle data with provided data.

        Args:
            data: tensor of shape (cycle_len, channel_size). It is copied in place
                through ``.data``, so the copy is not recorded by autograd.

        Raises:
            ValueError: if ``data`` does not have exactly that shape, including a
                wrong number of dimensions (previously surfaced as IndexError).
        """
        if tuple(data.shape) != (self.cycle_len, self.channel_size):
            raise ValueError(f"Data shape must be ({self.cycle_len}, {self.channel_size})")
        self.data.data.copy_(data)

    def forward(self, index, num_cycles):
        """Return ``num_cycles`` full cycles starting at each phase in ``index``.

        Args:
            index: integer tensor of starting phases; flattened to N entries.
                Values are wrapped with ``% cycle_len`` (tensor ``%`` follows the
                divisor's sign, so negative phases wrap too).
            num_cycles: number of consecutive cycles to read per start phase.

        Returns:
            Tensor of shape [N, num_cycles, cycle_len, channel_size].
        """
        offsets = torch.arange(num_cycles * self.cycle_len, device=index.device)
        indices = (index.view(-1, 1) + offsets) % self.cycle_len
        return self.data[indices].reshape(-1, num_cycles, self.cycle_len, self.data.shape[-1])
    

class RESAM(nn.Module):
    """Cycle-residual forecaster built on a ridge-regressed sin/cos basis.

    Pipeline (see ``forward``):
      1. optional per-series standardization (``use_revin``);
      2. subtract a learnable recurring cycle (``cycleQueue``);
      3. fit ridge coefficients of a fixed sin/cos frequency basis to the
         (optionally sub-sampled) residual;
      4. map the coefficients through a linear or MLP head;
      5. evaluate the basis on the forecast horizon and add the cycle back.

    ``configs`` must provide ``seq_len``, ``pred_len``, ``enc_in``,
    ``model_type`` ('linear' or 'mlp'), ``d_model``, ``use_revin``, ``cycle``
    and ``sample_ratio``. ``lambda_reg`` is optional and defaults to 0.1.
    """

    def __init__(self, configs):
        super(RESAM, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.enc_in = configs.enc_in
        self.model_type = configs.model_type
        self.d_model = configs.d_model
        self.use_revin = configs.use_revin
        # Ridge strength: formerly hard-coded to 0.1, now optionally configurable.
        self.lambda_reg = getattr(configs, 'lambda_reg', 0.1)
        self.cycle_len = configs.cycle
        self.sample_ratio = configs.sample_ratio

        # Fixed frequency basis. ``freqs`` is a plain attribute used only for its
        # length; the device-following copy used in computation is ``angle_scale``.
        self.freqs = self._generate_frequencies(self.cycle_len)
        self.num_coeffs = len(self.freqs) * 2  # sin + cos for each frequency
        self.register_buffer('angle_scale', 2 * torch.pi * self.freqs)

        # Basis evaluated on the forecast horizon: steps seq_len .. seq_len+pred_len-1.
        t_dec = torch.arange(self.pred_len) + self.seq_len
        X_dec = self._generate_basis_matrix(t_dec)
        self.register_buffer('X_dec', X_dec)

        # Head mapping fitted coefficients to forecast coefficients.
        if self.model_type == 'linear':
            self.model = nn.Linear(self.num_coeffs, self.num_coeffs)
        elif self.model_type == 'mlp':
            self.model = nn.Sequential(
                nn.Linear(self.num_coeffs, self.d_model),
                nn.ReLU(),
                nn.Linear(self.d_model, self.num_coeffs)
            )
        else:
            # Explicit raise instead of `assert`, which is stripped under -O.
            raise ValueError(f"model_type must be 'linear' or 'mlp', got {self.model_type!r}")

        # Number of whole cycles needed to cover the horizon / the input (ceil).
        self.n_cycle_pred = self.pred_len // self.cycle_len + 1
        self.n_cycle = self.seq_len // self.cycle_len + 1 if self.seq_len % self.cycle_len else self.seq_len // self.cycle_len
        self.cycleQueue = RecurrentCycle(cycle_len=self.cycle_len, channel_size=self.enc_in)

    def _process_sequence(self, x):
        """Fold ``x`` [B, L, C] into [B, n_cycle, cycle_len, C].

        Shorter inputs are zero-padded at the front; longer inputs keep only the
        most recent ``n_cycle * cycle_len`` steps.
        """
        B, L, C = x.shape
        valid_length = self.n_cycle * self.cycle_len
        if L < valid_length:
            x = nn.functional.pad(x, (0, 0, valid_length - L, 0))
        else:
            x = x[:, -valid_length:, :]
        return x.reshape(B, self.n_cycle, self.cycle_len, C)

    def _generate_frequencies(self, cycle):
        """Return the basis frequencies (cycles per time step) as a float32 tensor.

        Periods are expressed in minutes. ``1440 / cycle`` converts them to
        per-step frequencies, presuming ``cycle`` is the number of steps per day
        (TODO confirm with the data pipeline).
        """
        ranges = [
            (1, 60, 5),                     # high frequency: 1 min - 1 hour
            (60, 1440, 15),                 # high frequency: 1 - 24 hour
            (1440, 1440*7, 360),            # low frequency: 1 - 7 day
            (1440*7, 1440*7*52, 1440*7),    # low frequency: 1 - 52 week
        ]
        freqs = []
        for (start, end, step) in ranges:
            periods = np.arange(start, end, step)
            freqs.extend(1440 / cycle / periods)
        return torch.tensor(freqs, dtype=torch.float32)

    def _generate_basis_matrix(self, t):
        """Create the basis matrix [T, num_coeffs] = [sin | cos] for timestamps ``t`` [T]."""
        t = t.unsqueeze(-1).to(torch.float32)
        angle = self.angle_scale * t
        return torch.cat([torch.sin(angle), torch.cos(angle)], dim=-1)

    def generate_coeffs(self, x, x_t, x_t_plus_reg):
        """Solve the ridge normal equations for every (batch, channel) series.

        Args:
            x: values at the fitted time steps, shape [B, L', C].
            x_t: transposed design matrix ``X^T``, shape [D, L'].
            x_t_plus_reg: regularized Gram matrix ``X^T X + lambda*I``, shape [D, D].

        Returns:
            Coefficients of shape [B, C, D].
        """
        B, _, C = x.shape
        # Right-hand sides X^T v for all B*C series, stacked as columns: [D, B*C].
        rhs = torch.einsum('dl,blc->dbc', x_t, x).reshape(x_t.shape[0], B * C)
        # The Gram matrix is shared by every series, so factor it once and solve
        # all right-hand sides together, not B*C identical factorizations.
        try:
            chol = torch.linalg.cholesky(x_t_plus_reg)
            y = torch.linalg.solve_triangular(chol, rhs, upper=False)
            coeffs_flat = torch.linalg.solve_triangular(chol.transpose(-2, -1), y, upper=True)
        except RuntimeError:
            # Not numerically positive-definite: fall back to a general solve.
            coeffs_flat = torch.linalg.solve(x_t_plus_reg, rhs)
        return coeffs_flat.transpose(0, 1).reshape(B, C, -1)

    def sample_uniform(self, num_steps=96, num_samples=1, device='cuda'):
        """Draw ``num_samples`` distinct step indices uniformly from [0, num_steps)."""
        sampled_indices = torch.randperm(num_steps, device=device)[:num_samples]
        return sampled_indices

    def sample_long_tail(self, num_steps=96, num_samples=1, temperature=0.9, device='cuda'):
        """Draw ``num_samples`` distinct step indices with linearly decaying weights.

        Index 0 gets weight ``num_steps`` and the last index gets weight 1, each
        raised to ``1 / temperature``. Every weight is strictly positive, so every
        step is reachable and ``num_samples == num_steps`` is valid. (Weights used
        to start at 0, which excluded the last step and made that case raise.)
        NOTE(review): the oldest steps are the most likely to be drawn; confirm
        this direction is intended.
        """
        linear_weights = torch.arange(num_steps, 0, -1, device=device, dtype=torch.float32)
        powered_weights = linear_weights ** (1 / temperature)
        probs = powered_weights / powered_weights.sum()
        sampled_indices = torch.multinomial(probs, num_samples, replacement=False)
        return sampled_indices

    def sample_and_generate(self, x, sample=True, sample_method='uniform'):
        """Fit the basis to ``x`` [B, seq_len, C] and return a forecast [B, pred_len, C].

        When ``sample`` is true, only a random subset of the input steps is used
        for the fit ('uniform' or 'long_tail' sampling).

        Raises:
            ValueError: on an unknown ``sample_method``.
        """
        if sample:
            # Clamp so we never request zero steps or more steps than exist.
            num_samples = min(self.seq_len, max(1, int(self.seq_len * self.sample_ratio)))
            if sample_method == 'uniform':
                t_enc = self.sample_uniform(num_steps=self.seq_len, num_samples=num_samples, device=x.device)
            elif sample_method == 'long_tail':
                t_enc = self.sample_long_tail(num_steps=self.seq_len, num_samples=num_samples, device=x.device)
            else:
                raise ValueError("Unknown sampling method")
        else:
            t_enc = torch.arange(self.seq_len, device=x.device)
        X_enc = self._generate_basis_matrix(t_enc)  # [L', num_coeffs]
        X_enc_T = X_enc.transpose(0, 1)  # [num_coeffs, L']
        X_enc_T_X = torch.matmul(X_enc_T, X_enc)
        I = torch.eye(self.num_coeffs, device=x.device)
        # Coefficient 0 is left unregularized.
        # NOTE(review): column 0 is sin at the highest frequency (1440/cycle), not
        # a constant. When 1440/cycle is an integer, that column is ~0 for integer
        # t, which makes this system near-singular. Confirm this is intended.
        I[0, 0] = 0
        X_enc_T_X_plus_reg = X_enc_T_X + self.lambda_reg * I

        coeffs = self.generate_coeffs(x[:, t_enc, :], X_enc_T, X_enc_T_X_plus_reg)
        coeffs_pred = self.model(coeffs)  # [B, C, D]
        y = torch.matmul(coeffs_pred, self.X_dec.t())  # [B, C, pred_len]
        return y.permute(0, 2, 1)

    def _input_cycles(self, cycle_index):
        """Return the learned cycle aligned with the input window: [N, seq_len, enc_in].

        N is the number of entries in ``cycle_index``. Reshaping per index (not
        per batch row) keeps samples separate and lets a single shared index
        broadcast over the batch.
        """
        # Start early enough that the last seq_len of n_cycle*cycle_len steps begin at cycle_index.
        start = cycle_index - (self.n_cycle * self.cycle_len - self.seq_len) % self.cycle_len
        cycles = self.cycleQueue(start, self.n_cycle)
        return cycles.reshape(-1, self.n_cycle * self.cycle_len, self.enc_in)[:, -self.seq_len:, :]

    def forward(self, x, cycle_index, x_mark=None, dec_inp=None, y_mark=None, sample=True, sample_method='uniform'):
        """Forecast ``pred_len`` steps from an input window.

        Args:
            x: input window; presumably [B, seq_len, enc_in] (TODO confirm with callers).
            cycle_index: integer tensor with the cycle phase of each sample's first
                input step. Use one entry per sample, or a single entry shared by the batch.
            x_mark, dec_inp, y_mark: unused; accepted for interface compatibility.
            sample, sample_method: forwarded to ``sample_and_generate``.

        Returns:
            Forecast of shape [B, pred_len, C].
        """
        if self.use_revin:
            seq_mean = x.mean(dim=1, keepdim=True)
            seq_var = x.var(dim=1, keepdim=True) + 1e-5
            x = (x - seq_mean) / seq_var.sqrt()

        # The first forecast step sits seq_len steps after the first input step.
        pred_cycles = self.cycleQueue((cycle_index + self.seq_len) % self.cycle_len, self.n_cycle_pred)
        pred_cycles = pred_cycles.reshape(-1, self.n_cycle_pred * self.cycle_len, self.enc_in)[:, :self.pred_len, :]
        x = x - self._input_cycles(cycle_index)

        y = self.sample_and_generate(x, sample=sample, sample_method=sample_method)
        y = y + pred_cycles

        if self.use_revin:
            y = y * seq_var.sqrt() + seq_mean

        return y

    def get_trend(self, x, cycle_index, x_mark=None, dec_inp=None, y_mark=None):
        """Return ``(trend_tensor, trend_predict)``: the de-cycled input and its basis forecast.

        The forecast uses all input steps (no sampling). With ``use_revin`` both
        outputs are rescaled by the input std, but the mean is not added back.
        Works for any batch size; the batch used to be folded to 1, which
        silently applied the last sample's cycle to every sample.
        """
        seq_cycles = self._input_cycles(cycle_index)
        if self.use_revin:
            seq_mean = torch.mean(x, dim=1, keepdim=True)
            seq_var = torch.var(x, dim=1, keepdim=True) + 1e-5
            norm_input_tensor = (x - seq_mean) / torch.sqrt(seq_var)
            trend_tensor = norm_input_tensor - seq_cycles
        else:
            trend_tensor = x - seq_cycles

        trend_predict = self.sample_and_generate(trend_tensor, sample=False)
        if self.use_revin:
            trend_tensor = trend_tensor * torch.sqrt(seq_var)
            trend_predict = trend_predict * torch.sqrt(seq_var)
        return trend_tensor, trend_predict
