from .SpatioTemporalGraph import SpatioTemporalGraph
import torch
import networkx as nx
from torch_geometric.utils import from_networkx
from utils.utils import integrate


def compute_noise_snr(raw_data, snr_db=20, *, generator=None):
    """Return ``raw_data`` corrupted by additive Gaussian noise at a target SNR.

    Despite its name, this returns the *noisy signal* (signal + noise), not
    the noise itself. The signal power is the mean of ``raw_data ** 2`` over
    every element, so one noise level is applied to the whole tensor.

    Args:
        raw_data: Floating-point tensor to corrupt.
        snr_db: Desired signal-to-noise ratio in decibels.
        generator: Optional ``torch.Generator`` used to draw the noise, for
            reproducible corruption. When ``None`` the global torch RNG is
            used, exactly as before this parameter existed.

    Returns:
        Tensor with the same shape as ``raw_data``.
    """
    signal = raw_data
    signal_power = signal.pow(2).mean()

    # SNR_dB = 10 * log10(P_signal / P_noise)
    #   =>  P_noise = P_signal / 10 ** (SNR_dB / 10)
    snr_linear = 10 ** (snr_db / 10)
    noise_power = signal_power / snr_linear

    if generator is None:
        unit_noise = torch.randn_like(signal)
    else:
        unit_noise = torch.randn(
            signal.shape,
            generator=generator,
            dtype=signal.dtype,
            device=signal.device,
        )

    # Scale standard-normal samples by the noise std (sqrt of its power).
    return signal + unit_noise * noise_power.sqrt()
    

class SyntheticData(SpatioTemporalGraph):
    """Spatio-temporal graph dataset built by integrating synthetic dynamics.

    A Barabási–Albert graph is generated, ``dynamics`` is integrated on it
    ``n_ics`` times via ``integrate``, and additive Gaussian noise is
    optionally applied at a target SNR. Windowing and derivative options
    (``horizon``, ``history``, ``stride``, ``predict_deriv``, ``denoise``,
    ``deriv_method``) are forwarded to ``SpatioTemporalGraph``.
    """

    # Barabási–Albert graph used by ``get_raw_data``: total number of nodes
    # and number of edges attached from each new node. Kept at class level
    # so subclasses can change the graph size without touching the signature.
    BA_NUM_NODES = 70
    BA_EDGES_PER_NODE = 3

    def __init__(
        self,
        root,
        dynamics,
        t_span=None,
        t_max=300,
        num_samples=30,
        seed=42,
        n_ics=3,
        input_range=None,
        device='cpu',
        horizon=15,
        history=1,
        stride=24,
        predict_deriv=False,
        snr_db=-1,
        denoise=False,
        deriv_method="five_point",
        **integration_kwargs
    ):
        """
        Args:
            root: Root directory forwarded to ``SpatioTemporalGraph``.
            dynamics: Passed to ``integrate`` as ``dynamics``; also used as
                the prefix of the dataset name.
            t_span: Passed to ``integrate`` as ``t_span``. Defaults to a fresh
                ``[0, 1]`` list.
            t_max: Passed to ``integrate`` as ``t_eval_steps``.
            num_samples: Forwarded to the base class as ``n_samples``.
            seed: Forwarded to the base class; ``get_raw_data`` reads
                ``self.seed`` to seed the graph generator.
            n_ics: Number of independent integrations to stack.
            input_range: Passed to ``integrate`` as ``input_range``. Defaults
                to a fresh ``[0, 1]`` list.
            device: Device name used for the edge index and integration.
            snr_db: Target signal-to-noise ratio in dB. Negative values (the
                default) disable noise; any value ``>= 0`` adds noise and
                appends ``_{snr_db}_db`` to the dataset name.
            horizon, history, stride, predict_deriv, denoise, deriv_method:
                Forwarded unchanged to ``SpatioTemporalGraph``. ``denoise``
                also appends ``_denoise`` to the dataset name.
            **integration_kwargs: Extra keyword arguments passed to
                ``integrate``.
        """
        # Fresh lists per instance: a shared mutable default would leak any
        # mutation into every SyntheticData constructed without these args.
        self.t_span = [0, 1] if t_span is None else t_span
        self.t_max = t_max
        self.input_range = [0, 1] if input_range is None else input_range
        self.int_kwargs = integration_kwargs
        self.dynamics = dynamics
        self.snr_db = snr_db
        self.n_ics = n_ics

        name = dynamics if snr_db < 0 else f"{dynamics}_{snr_db}_db"
        if denoise:
            name += "_denoise"

        # Everything get_raw_data reads is assigned above, before the base
        # constructor runs (presumably it may call get_raw_data — confirm).
        super().__init__(
            root=root,
            name=name,
            n_samples=num_samples,
            seed=seed,
            device=device,
            horizon=horizon,
            history=history,
            stride=stride,
            predict_deriv=predict_deriv,
            denoise=denoise,
            deriv_method=deriv_method
        )

    def get_raw_data(self):
        """Build the graph and integrate ``n_ics`` trajectories on it.

        Returns:
            Tuple ``(edge_index, None, raw_data, t)``:
            ``edge_index`` comes from ``from_networkx`` and is moved to
            ``self.device``. ``raw_data`` stacks the per-run data returned by
            ``integrate`` along a new leading dimension. ``t`` stacks the
            matching time tensors the same way. The second element is
            always ``None`` (presumably edge attributes — confirm against
            ``SpatioTemporalGraph``).
        """
        device = torch.device(self.device)

        graph = nx.barabasi_albert_graph(
            self.BA_NUM_NODES, self.BA_EDGES_PER_NODE, seed=self.seed
        )
        edge_index = from_networkx(graph).edge_index.to(device)

        trajectories, times = [], []
        for _ in range(self.n_ics):
            data_k, t_k = integrate(
                input_range=self.input_range,
                t_span=self.t_span,
                t_eval_steps=self.t_max,
                dynamics=self.dynamics,
                device=device,
                graph=graph,
                rng=self.rng,
                **self.int_kwargs
            )
            trajectories.append(data_k)
            times.append(t_k)

        # Per the original annotations: raw_data is
        # (n_ics, t_max, n_nodes, n_features) and t is (n_ics, t_max).
        raw_data = torch.stack(trajectories, dim=0)
        t = torch.stack(times, dim=0)

        # Negative snr_db means "no noise", matching the name built in
        # __init__. 0 dB is a valid SNR (noise power == signal power), so it
        # must add noise too. The noise level uses the mean power over all
        # trajectories jointly. The draw uses torch's global RNG, not
        # self.rng.
        if self.snr_db >= 0:
            raw_data = compute_noise_snr(raw_data=raw_data, snr_db=self.snr_db)

        return edge_index, None, raw_data, t
        