import numpy as np
import os

# save function
def save_data_and_graph(X, graph, root_dir, prefix):
    """Persist a sampled series and its ground-truth graph as ``.npy`` files.

    The series goes to ``./data/<root_dir>/data/<prefix>.npy`` and the graph to
    ``./data/<root_dir>/label/<prefix>.npy`` (paths relative to the current
    working directory). Both directories are created first if missing;
    existing files with the same name are overwritten by ``np.save``.

    Args:
        X: array to store under ``data/``.
        graph: array to store under ``label/``.
        root_dir: dataset sub-folder inside ``./data``.
        prefix: file stem shared by both outputs.
    """
    base_dir = os.path.join('./data', root_dir)
    targets = [
        (os.path.join(base_dir, "data"), X),
        (os.path.join(base_dir, "label"), graph),
    ]

    # Create every output directory before writing anything.
    for directory, _ in targets:
        os.makedirs(directory, exist_ok=True)

    file_name = f"{prefix}.npy"
    for directory, payload in targets:
        np.save(os.path.join(directory, file_name), payload)

class Lorenz96Generator:
    """Simulate the Lorenz-96 system and expose its ground-truth causal graph.

    The state x (length ``D``) evolves as
        dx_i/dt = (x_{i+1} - x_{i-2}) * x_{i-1} - x_i + F
    with cyclic indices, integrated with a fixed-step RK4 scheme.

    Args:
        D: number of variables.
        F: forcing constant.
        T: number of recorded time steps returned by :meth:`sample`.
        dt: RK4 step size.
        burn_in: number of integration steps discarded before recording.
        noise_scale: std of i.i.d. Gaussian noise added to the recorded
            trajectory after integration (it does not feed back into the
            dynamics). No noise is added when ``noise_scale <= 0``.
        seed: if not None, reseeds NumPy's *global* RNG in the constructor;
            :meth:`sample` draws from that global RNG.
    """

    def __init__(self, D, F, T, dt, burn_in, noise_scale, seed):
        self.D = D
        self.F = F
        self.T = T
        self.dt = dt
        self.burn_in = burn_in
        self.noise_scale = noise_scale
        self.seed = seed

        if self.seed is not None:
            np.random.seed(self.seed)

        # Binary graph, graph[target, source], read directly off the equations.
        self.graph = self._build_causal_graph()
        # Simple structural importance: out-degree of each node (column sums).
        self.node_importance = self.graph.sum(axis=0)

    def _lorenz96_rhs(self, x):
        """Right-hand side of the Lorenz-96 ODE, vectorised over all nodes.

        ``np.roll(x, k)[i] == x[(i - k) % D]``, so the rolled copies below are
        x_{i+1}, x_{i-2} and x_{i-1} with periodic boundaries. Each element is
        computed with the same float operations, in the same order, as the
        scalar formula.
        """
        x = np.asarray(x, dtype=float)
        x_next = np.roll(x, -1)       # x_{i+1}
        x_prev2 = np.roll(x, 2)       # x_{i-2}
        x_prev1 = np.roll(x, 1)       # x_{i-1}
        return (x_next - x_prev2) * x_prev1 - x + self.F

    def _rk4_step(self, x):
        """Advance the state by one classical Runge-Kutta 4 step of size ``dt``."""
        dt = self.dt
        k1 = self._lorenz96_rhs(x)
        k2 = self._lorenz96_rhs(x + 0.5 * dt * k1)
        k3 = self._lorenz96_rhs(x + 0.5 * dt * k2)
        k4 = self._lorenz96_rhs(x + dt * k3)
        return x + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)

    def _build_causal_graph(self):
        """
        Build binary graph matrix A where A[target, source] = 1
        if source appears in the equation for target.
        """
        D = self.D
        A = np.zeros((D, D), dtype=int)
        targets = np.arange(D)
        # dx_i depends on x_{i-2}, x_{i-1}, x_i and x_{i+1} (cyclic indices).
        for offset in (-2, -1, 0, 1):
            A[targets, (targets + offset) % D] = 1
        return A

    def sample(self):
        """Integrate one trajectory and return it as a ``(T, D)`` array.

        Draws a fresh random initial condition (``F`` plus a standard-normal
        perturbation), discards ``burn_in`` steps, records ``T`` steps and
        optionally adds observation noise. Prints per-variable mean and std.
        """
        x = self.F * np.ones(self.D) + np.random.randn(self.D)

        # Burn-in: let the system settle onto the attractor.
        for _ in range(self.burn_in):
            x = self._rk4_step(x)

        # Collect data
        X = np.zeros((self.T, self.D))
        for t in range(self.T):
            x = self._rk4_step(x)
            X[t] = x

        if self.noise_scale > 0.0:
            X += self.noise_scale * np.random.randn(self.T, self.D)

        print(X.mean(axis=0), X.std(axis=0))

        return X

    def save(self, root_dir, tag):
        """Draw a new sample and save it with the graph via ``save_data_and_graph``."""
        prefix = f"D{self.D}_T{self.T}_F{self.F}_{tag}"
        save_data_and_graph(self.sample(), self.graph, root_dir=root_dir, prefix=prefix)

class VAR3Generator:
    """Sparse, stationary VAR(3) process with a known causal graph.

    x_t = A1 x_{t-1} + A2 x_{t-2} + A3 x_{t-3} + noise_scale * e_t,
    with e_t ~ N(0, I). Every lag matrix has a non-zero diagonal. The
    off-diagonal non-zeros are spread over the three lags at distinct positions.

    Args:
        D: number of variables.
        T: number of time steps returned by :meth:`sample`.
        sparsity: target density of the binary graph, counting the diagonal.
        max_lag: must be 3.
        burn_in: number of simulated steps discarded before recording.
        noise_scale: std of the innovation noise.
        seed: if not None, reseeds NumPy's *global* RNG in the constructor;
            coefficient generation and :meth:`sample` draw from that RNG.
    """

    def __init__(self, D, T, sparsity, max_lag, burn_in, noise_scale, seed):
        self.D = D
        self.T = T
        self.sparsity = sparsity
        self.max_lag = max_lag
        self.burn_in = burn_in
        self.noise_scale = noise_scale
        self.seed = seed

        assert self.max_lag == 3, "This generator is specifically VAR(3)."

        if self.seed is not None:
            np.random.seed(self.seed)

        # Build sparse coefficient matrices and enforce stationarity
        self.A = self._generate_sparse_stationary_matrices()
        # Binary graph[target, source] = 1 if source enters target's equation at any lag
        self.graph = self._build_causal_graph_and_importance()
        # Structural importance, consistent with Lorenz96Generator: out-degree
        # of each source node (column sums of graph[target, source]).
        self.node_importance = self.graph.sum(axis=0)

    # --------- internal helpers ---------
    def _generate_sparse_stationary_matrices(self):
        """Draw sparse lag matrices (A1, A2, A3) and rescale them to be stationary."""
        D, p = self.D, self.max_lag
        total_entries = D * D
        # `sparsity` is the target density of the whole graph *including* the
        # always-present diagonal, so the D diagonal entries come out of the
        # budget. Clamp to [0, D*(D-1)]: for sparsity * D**2 < D the budget
        # would be negative, and for sparsity > 1 it would exceed the
        # off-diagonal pool; either made np.random.choice raise.
        n_nonzero = int(self.sparsity * total_entries) - D
        n_nonzero = min(max(n_nonzero, 0), D * (D - 1))

        # Spread non-zeros uniformly across lags
        base_nonzero_per_lag = n_nonzero // p
        remainder = n_nonzero % p
        nonzeros_per_lag = [
            base_nonzero_per_lag + (1 if i < remainder else 0) for i in range(p)
        ]

        # Track flattened positions already used by an earlier lag so that each
        # off-diagonal edge appears in exactly one lag.
        used_position = np.zeros(D * D, dtype=bool)
        all_positions = np.arange(D * D)

        A_list = []
        for lag_idx in range(p):
            A_k = np.zeros((D, D))
            k_nonzero = nonzeros_per_lag[lag_idx]

            # Uniformly choose unused off-diagonal positions without replacement.
            # Flattened diagonal indices are i*D + i = i*(D+1).
            available_positions = all_positions[~used_position]
            available_positions = available_positions[available_positions % (D + 1) != 0]

            chosen = np.random.choice(available_positions, size=k_nonzero, replace=False)
            rows = chosen // D
            cols = chosen % D
            used_position[chosen] = True

            # Sample coefficients: magnitude ~ U(0.2, 0.8), random sign
            magnitudes = np.random.uniform(0.2, 0.8, size=k_nonzero)
            signs = np.random.choice([-1.0, 1.0], size=k_nonzero)
            A_k[rows, cols] = magnitudes * signs

            # Self-dependence: magnitude ~ U(0.5, 1.0), random sign, damped by lag
            for d in range(D):
                A_k[d, d] = np.random.uniform(0.5, 1.0) * np.random.choice([-1.0, 1.0]) / (lag_idx + 1)

            A_list.append(A_k)

        A1, A2, A3 = A_list

        # Ensure stationarity by scaling so that spectral radius of companion matrix < 1
        A1, A2, A3 = self._enforce_stationarity(A1, A2, A3)

        return A1, A2, A3

    def _enforce_stationarity(self, A1, A2, A3, target=0.95, max_iter=100):
        """Shrink the lag matrices until the companion spectral radius is <= ``target``.

        The VAR(3) process is stable iff every eigenvalue of the companion matrix
        [[A1, A2, A3], [I, 0, 0], [0, I, 0]] lies strictly inside the unit circle.
        The matrices are scaled in place and also returned.

        The primary strategy multiplies all three matrices by ``target / rho``.
        It is kept so that existing seeds reproduce the same data. When higher
        lags matter, rho shrinks less than proportionally under uniform scaling,
        so it can creep towards ``target`` from above without crossing it.
        After ``max_iter`` attempts the method falls back to scaling A_k by
        c**k, which scales every companion eigenvalue by exactly c.
        """
        D = self.D

        def spectral_radius(B1, B2, B3):
            top_row = np.hstack([B1, B2, B3])
            I = np.eye(D)
            zeros = np.zeros((D, D))
            middle_row = np.hstack([I, zeros, zeros])
            bottom_row = np.hstack([zeros, I, zeros])
            companion = np.vstack([top_row, middle_row, bottom_row])
            eigvals = np.linalg.eigvals(companion)
            return np.max(np.abs(eigvals))

        rho = spectral_radius(A1, A2, A3)
        n_iter = 0
        while rho > target and n_iter < max_iter:
            # shrink a bit relative to current rho
            scale = target / rho
            A1 *= scale
            A2 *= scale
            A3 *= scale
            rho = spectral_radius(A1, A2, A3)
            n_iter += 1

        if rho > target:
            # Exact fallback: eigenvalues of the scaled system are c * lambda.
            # The small margin keeps rounding from leaving rho just above target.
            c = (target / rho) * (1.0 - 1e-6)
            A1 *= c
            A2 *= c ** 2
            A3 *= c ** 3

        # If already stable, we keep as is
        return A1, A2, A3

    def _build_causal_graph_and_importance(self):
        """Return the binary graph[target, source]: non-zero at any of the three lags."""
        A1, A2, A3 = self.A

        nonzero_mask = (A1 != 0.0) | (A2 != 0.0) | (A3 != 0.0)
        graph = nonzero_mask.astype(int)

        return graph

    def sample(self):
        """Simulate the process and return a ``(T, D)`` array.

        The first ``max_lag`` states are standard normal. The next ``burn_in``
        steps are discarded along with those initial lags. Prints per-variable
        mean and std of the returned series.
        """
        D, T, p = self.D, self.T, self.max_lag
        A1, A2, A3 = self.A

        total_steps = T + self.burn_in + p

        # Initialize with Gaussian noise
        X = np.zeros((total_steps, D))
        X[:p] = np.random.randn(p, D)

        for t in range(p, total_steps):
            x_tm1 = X[t - 1]
            x_tm2 = X[t - 2]
            x_tm3 = X[t - 3]
            noise = self.noise_scale * np.random.randn(D)
            X[t] = A1 @ x_tm1 + A2 @ x_tm2 + A3 @ x_tm3 + noise

        # Drop burn-in and initial lags
        X_out = X[self.burn_in + p : self.burn_in + p + T]

        print(X_out.mean(axis=0), X_out.std(axis=0))

        return X_out

    def save(self, root_dir, tag):
        """Draw a new sample and save it with the graph via ``save_data_and_graph``."""
        prefix = f"D{self.D}_T{self.T}_{tag}"
        save_data_and_graph(self.sample(), self.graph, root_dir=root_dir, prefix=prefix)

def save_data(*, include_lorenz96=False, include_var3=True):
    """Generate and save the benchmark datasets under ``./data``.

    For each configuration, one generator is built (seeded once in its
    constructor). ``trials`` independent samples are then drawn from it and
    saved with tags ``trials_1`` ... ``trials_<trials>``.

    Args:
        include_lorenz96: also generate the Lorenz-96 grid (D x F x T) into
            ``./data/Lorenz96``. Off by default, matching the previous
            behaviour where that loop was commented out.
        include_var3: generate the VAR(3) grid (D x T) into ``./data/VAR3``.
    """
    D = [10, 50]
    F = [10, 50]
    T = [500, 1000, 2000]
    trials = 5
    seed = 2025

    # generate lorenz96 data
    if include_lorenz96:
        for d in D:
            for f in F:
                for t in T:
                    print(f"Generating Lorenz96 data: D={d}, F={f}, T={t}")
                    l96 = Lorenz96Generator(D=d, F=f, T=t, dt=0.01, burn_in=1000, noise_scale=1, seed=seed+d+f+t)
                    for trial in range(trials):
                        l96.save(root_dir="Lorenz96", tag=f"trials_{trial + 1}")

    # generate VAR(3) data
    if include_var3:
        sparsity = 0.3
        for d in D:
            for t in T:
                print(f"Generating VAR(3) data: D={d}, T={t}, sparsity={sparsity}")
                var3 = VAR3Generator(D=d, T=t, sparsity=sparsity, max_lag=3, burn_in=1000, noise_scale=1, seed=seed+d+t)
                for trial in range(trials):
                    var3.save(root_dir="VAR3", tag=f"trials_{trial + 1}")

if __name__ == "__main__":
    # Regenerate the saved benchmark datasets (VAR(3) grid by default).
    save_data()

    # ----- Lorenz-96 examples -----
    l96_small = Lorenz96Generator(D=10, F=10.0, T=500, dt=0.01, burn_in=1000, noise_scale=1, seed=42)
    X_l96_small = l96_small.sample()
    print("Lorenz-96 (D=10, F=10, T=500) shape:", X_l96_small.shape)
    print("Lorenz-96 graph shape:", l96_small.graph.shape)
    print(X_l96_small.mean(axis=0), X_l96_small.std(axis=0))

    l96_large = Lorenz96Generator(D=50, F=50.0, T=1000, dt=0.01, burn_in=2000, noise_scale=1, seed=123)
    X_l96_large = l96_large.sample()
    print("Lorenz-96 (D=50, F=50, T=1000) shape:", X_l96_large.shape)
    print("Lorenz-96 graph shape:", l96_large.graph.shape)
    print(X_l96_large.mean(axis=0), X_l96_large.std(axis=0))

    # ----- VAR(3) examples -----
    var_small = VAR3Generator(D=10, T=500, sparsity=0.3, max_lag=3, burn_in=200, noise_scale=1, seed=7)
    X_var_small = var_small.sample()
    print("VAR(3) (D=10, T=500) shape:", X_var_small.shape)
    print("VAR graph shape:", var_small.graph.shape)
    print("VAR sparsity (graph):", var_small.graph.mean())
    print(X_var_small.mean(axis=0), X_var_small.std(axis=0))

    var_large = VAR3Generator(D=50, T=1000, sparsity=0.3, max_lag=3, burn_in=500, noise_scale=1, seed=99)
    X_var_large = var_large.sample()
    print("VAR(3) (D=50, T=1000) shape:", X_var_large.shape)
    print("VAR graph shape:", var_large.graph.shape)
    print("VAR sparsity (graph):", var_large.graph.mean())
    print(X_var_large.mean(axis=0), X_var_large.std(axis=0))

    # ----- GLV examples -----
    # GLVGenerator is neither defined nor imported in this module, so calling
    # it unconditionally raised NameError and the plot below was never reached.
    # Look it up dynamically: the demo runs if a GLV generator is made available
    # at module level, and is skipped cleanly otherwise.
    glv_cls = globals().get("GLVGenerator")
    if glv_cls is None:
        print("GLVGenerator is not available in this module; skipping GLV examples.")
    else:
        glv_small = glv_cls(D=50, T=500, sparsity=0.3, dt=0.1, burn_in=1000, noise_scale=0.01, seed=21)
        X_glv_small = glv_small.sample()
        print("GLV (D=50, T=500) shape:", X_glv_small.shape)
        print("GLV graph shape:", glv_small.graph.shape)
        print("GLV sparsity (graph):", glv_small.graph.mean())
        print(X_glv_small.mean(axis=0), X_glv_small.std(axis=0))

        import matplotlib.pyplot as plt

        plt.figure(figsize=(6, 5))
        plt.imshow(glv_small.graph, cmap="Greys", interpolation="nearest")
        plt.colorbar(label="Edge present")
        plt.xlabel("Source node")
        plt.ylabel("Target node")
        plt.title("GLV Causal Graph (Adjacency Matrix)")
        plt.tight_layout()
        plt.show()