import numpy as np
import pandas as pd

from collections.abc import Generator
from pathlib import Path
from typing import Any

import datasets
import pandas as pd
from datasets import Features, Sequence, Value

import random

import numpy as np

def generate_multivariate_ar_process(p, n_vars, n_samples, burn_in=50):
    """
    Simulate a multivariate autoregressive (AR) series in float16 precision.

    A random coefficient matrix is drawn and rescaled so that its largest
    absolute entry is just below one. Every new value is the *mean* of the
    coefficient-weighted lagged values plus Gaussian noise. The noise standard
    deviation is itself random: it is drawn once per call from U(0.1, 1) using
    the Python ``random`` module, while coefficients and noise come from
    ``np.random``.

    Parameters:
        p (int): AR order, i.e. how many past time steps feed each new value.
        n_vars (int): Number of variables in the multivariate process.
        n_samples (int): Number of samples returned (after burn-in).
        burn_in (int): Number of leading samples dropped to remove the
            transient caused by the all-zero initial state.

    Returns:
        np.ndarray: Array of shape (n_samples, n_vars) with dtype float16.
    """
    half = np.float16
    half_limits = np.finfo(half)
    n_total = burn_in + n_samples

    # Column k of the coefficient matrix weights variable (k % n_vars)
    # at lag (k // n_vars) + 1.
    coeffs = np.random.normal(0, 1, (n_vars, n_vars * p)).astype(half)
    coeffs /= np.max(np.abs(coeffs)) + half(1e-3)

    # The first p rows stay at zero and act as the initial condition.
    series = np.zeros((n_total, n_vars), dtype=half)

    # One noise scale per call, shared by all variables and time steps.
    noise_std = random.uniform(0.1, 1)
    shocks = np.random.normal(0, noise_std, (n_total, n_vars)).astype(half)

    for t in range(p, n_total):
        # Past rows flattened as [x[t-1], x[t-2], ..., x[t-p]] so that the
        # layout lines up with the columns of the coefficient matrix.
        history = series[t - p:t][::-1].ravel()

        for j in range(n_vars):
            weighted = np.mean(coeffs[j] * history)
            # Keep the value inside the finite float16 range.
            weighted = np.clip(weighted, half_limits.min, half_limits.max)
            series[t, j] = weighted + shocks[t, j]

        # A row containing NaN or Inf is reset to zeros as a whole.
        if not np.all(np.isfinite(series[t])):
            series[t] = np.zeros(n_vars, dtype=half)

    # Drop the burn-in prefix.
    return series[burn_in:]


def generate_multivariate_ar_with_seasonality(p, n_vars, n_samples, burn_in=50):
    """
    Simulate a multivariate AR series and add a sinusoidal seasonal component.

    The AR part uses a random coefficient matrix (rescaled so its largest
    absolute entry is just below one) and standard normal noise; each new
    value is the average of the coefficient-weighted lagged values plus noise.
    After the burn-in prefix is removed, a sine wave is added. The seasonal
    specification list holds a single entry, so only the first variable
    (column 0) receives a seasonal term. Its amplitude is drawn from
    U(0, 1.5) with the Python ``random`` module.

    Parameters:
        p (int): AR order, i.e. how many past time steps feed each new value.
        n_vars (int): Number of variables in the multivariate process.
        n_samples (int): Number of samples returned (after burn-in).
        burn_in (int): Number of leading samples dropped to remove the
            transient caused by the all-zero initial state.

    Returns:
        np.ndarray: Array of shape (n_samples, n_vars) holding the AR data
        with the seasonal component added.
    """
    # One entry per seasonal column; "frequency" is used as the period of the
    # sine wave, measured in samples.
    seasonal_specs = [
        {"amplitude": random.random()*1.5, "frequency": 30, "phase": np.pi/4},
    ]

    # Column k of the coefficient matrix weights variable (k % n_vars)
    # at lag (k // n_vars) + 1.
    coeffs = np.random.normal(0, 1, (n_vars, n_vars * p)).astype(np.float16)
    coeffs /= np.max(np.abs(coeffs)) + np.float16(1e-3)

    n_total = n_samples + burn_in
    n_lagged = n_vars * p

    # The first p rows stay at zero and act as the initial condition.
    series = np.zeros((n_total, n_vars))
    shocks = np.random.normal(0, 1, (n_total, n_vars))

    for t in range(p, n_total):
        # Past rows flattened as [x[t-1], x[t-2], ..., x[t-p]] so that the
        # layout lines up with the columns of the coefficient matrix.
        history = series[t - p:t][::-1].ravel()
        for j in range(n_vars):
            weighted_terms = coeffs[j] * history
            # Built-in sum keeps the left-to-right accumulation order.
            series[t, j] = sum(weighted_terms) / n_lagged + shocks[t, j]

    # Drop the burn-in prefix.
    series = series[burn_in:]

    # Build the seasonal term column by column, then add it in place.
    steps = np.arange(n_samples)
    seasonal = np.zeros_like(series)
    for col, spec in enumerate(seasonal_specs):
        angle = 2 * np.pi * steps / spec["frequency"] + spec["phase"]
        seasonal[:, col] = spec["amplitude"] * np.sin(angle)

    series += seasonal
    return series




def save_to_csv(data, filename="multivariate_ar_process.csv"):
    """
    Save multivariate time series data to a CSV file with a daily date index.

    Rows are labelled with consecutive calendar days starting at 2010-01-01,
    columns are named ``V1`` ... ``Vn``, and NaN entries are replaced by 0
    before writing. Missing parent directories of ``filename`` are created.

    Parameters:
        data (np.ndarray): Multivariate time series data of shape (n_samples, n_vars).
        filename (str): Path of the CSV file to write.

    Returns:
        pd.DataFrame: The frame that was written to disk.
    """
    # Create the destination folder first; to_csv does not create it and
    # would otherwise fail on a fresh checkout with nested output paths.
    Path(filename).parent.mkdir(parents=True, exist_ok=True)

    # Daily timestamps, one per row of data.
    time_index = pd.date_range(
        start=pd.Timestamp("2010-01-01"), periods=len(data), freq="D"
    )
    column_names = [f"V{i + 1}" for i in range(data.shape[1])]

    df = pd.DataFrame(data, columns=column_names, index=time_index).fillna(0)
    df.to_csv(filename)
    return df



def generate_synthetic_data(order=1, n_series=100, d=5, length=10000, folder_name="csv_datasets/test", seed=1111):
    """
    Generate AR series and store each one as a CSV file and a Hugging Face dataset.

    For every series index ``i`` the random generators are seeded with
    ``seed + i``, a series is simulated with
    ``generate_multivariate_ar_process``, written to
    ``{folder_name}/data_{i}_o{order}_{d}.csv`` and saved as a single-item
    dataset under ``datasets/testing/data{i}_o{order}_{d}``.

    Parameters:
        order (int): AR order passed to the generator.
        n_series (int): Number of independent series to create.
        d (int): Number of variables per series.
        length (int): Number of time steps per series (after burn-in).
        folder_name (str): Folder that receives the CSV files.
        seed (int): Base seed; series ``i`` uses ``seed + i``.
    """
    # Make sure the CSV target folder exists before the first write.
    Path(folder_name).mkdir(parents=True, exist_ok=True)

    for i in range(n_series):
        # The generator draws from both `random` and `np.random`; seeding only
        # one of them would leave the coefficients and noise unreproducible.
        random.seed(seed + i)
        np.random.seed(seed + i)

        data = generate_multivariate_ar_process(order, d, length, burn_in=50)
        df = save_to_csv(data, f"{folder_name}/data_{i}_o{order}_{d}.csv")

        def multivar_example_gen_func() -> Generator[dict[str, Any], None, None]:
            # Yields the whole frame as one multivariate item.
            yield {
                "target": df.to_numpy().T,  # array of shape (var, time)
                "start": df.index[0],
                "freq": pd.infer_freq(df.index),
                "item_id": "item_0",
            }

        features = Features(
            dict(
                target=Sequence(
                    Sequence(Value("float32")), length=len(df.columns)
                ),  # multivariate time series are saved as (var, time)
                start=Value("timestamp[s]"),
                freq=Value("string"),
                item_id=Value("string"),
            )
        )

        hf_dataset = datasets.Dataset.from_generator(
            multivar_example_gen_func, features=features
        )
        hf_dataset.save_to_disk(f"datasets/testing/data{i}_o{order}_{d}")



def generate_synthetic_data_with_seasonality(order=1, n_series=100, d=5, length=10000, folder_name="csv_datasets/test", seed=1111):
    """
    Generate seasonal AR series and store each one as a CSV file and a Hugging Face dataset.

    For every series index ``i`` the random generators are seeded with
    ``seed + i``, a series is simulated with
    ``generate_multivariate_ar_with_seasonality``, written to
    ``{folder_name}/data_{i}_o{order}_season_{d}.csv`` and saved as a
    single-item dataset under ``datasets/testing/data{i}_o{order}_season_{d}``.

    Parameters:
        order (int): AR order passed to the generator.
        n_series (int): Number of independent series to create.
        d (int): Number of variables per series.
        length (int): Number of time steps per series (after burn-in).
        folder_name (str): Folder that receives the CSV files.
        seed (int): Base seed; series ``i`` uses ``seed + i``.
    """
    # Make sure the CSV target folder exists before the first write.
    Path(folder_name).mkdir(parents=True, exist_ok=True)

    for i in range(n_series):
        # The generator draws from both `random` (seasonal amplitude) and
        # `np.random` (coefficients, noise); seed both for reproducibility.
        random.seed(seed + i)
        np.random.seed(seed + i)

        data = generate_multivariate_ar_with_seasonality(order, d, length, burn_in=50)
        df = save_to_csv(data, f"{folder_name}/data_{i}_o{order}_season_{d}.csv")

        def multivar_example_gen_func() -> Generator[dict[str, Any], None, None]:
            # Yields the whole frame as one multivariate item.
            yield {
                "target": df.to_numpy().T,  # array of shape (var, time)
                "start": df.index[0],
                "freq": pd.infer_freq(df.index),
                "item_id": "item_0",
            }

        features = Features(
            dict(
                target=Sequence(
                    Sequence(Value("float32")), length=len(df.columns)
                ),  # multivariate time series are saved as (var, time)
                start=Value("timestamp[s]"),
                freq=Value("string"),
                item_id=Value("string"),
            )
        )

        hf_dataset = datasets.Dataset.from_generator(
            multivar_example_gen_func, features=features
        )
        hf_dataset.save_to_disk(f"datasets/testing/data{i}_o{order}_season_{d}")



def strongly_dependent_data(T, d, burn_in=50, sigma_noise=0.1, sigma_mu=0.05):
    """
    Simulate a strongly dependent multivariate series with a random-walk drift.

    The drift ``mu`` follows a Gaussian random walk, and the main process
    accumulates that drift plus Gaussian noise at every step:

        mu[t] = mu[t-1] + N(0, sigma_mu)
        X[t]  = X[t-1] + mu[t] + N(0, sigma_noise)

    Both start at zero. All draws come from ``np.random``.

    Parameters:
        T (int): Number of time steps returned (after burn-in).
        d (int): Number of variables (dimensions).
        burn_in (int): Number of leading samples discarded.
        sigma_noise (float): Standard deviation of the process noise.
        sigma_mu (float): Standard deviation of the drift increments.

    Returns:
        np.ndarray: Array of shape (T, d).

    Raises:
        ValueError: If ``burn_in`` is negative.
    """
    if burn_in < 0:
        raise ValueError("burn_in must be non-negative")

    total_steps = T + burn_in

    # Row 0 of both arrays stays at zero and is the initial state.
    X = np.zeros((total_steps, d))  # Main time series
    mu = np.zeros((total_steps, d))  # Drift term

    for t in range(1, total_steps):
        # Draw order (drift first, then noise) is kept fixed so that a given
        # np.random seed always produces the same series.
        mu[t] = mu[t - 1] + np.random.normal(0, sigma_mu, d)
        X[t] = X[t - 1] + mu[t] + np.random.normal(0, sigma_noise, d)

    return X[burn_in:]



def generate_synthetic_data_strongly_dependent(n_series=100, d=5, length=10000, folder_name="csv_datasets/test", seed=1111):
    """
    Generate strongly dependent series and store each one as a CSV file and a Hugging Face dataset.

    For every series index ``i`` the random generators are seeded with
    ``seed + i``, a series is simulated with ``strongly_dependent_data``,
    written to ``{folder_name}/strong_data_{i}.csv`` and saved as a
    single-item dataset under ``datasets/testing/strong_data{i}``.

    Parameters:
        n_series (int): Number of independent series to create.
        d (int): Number of variables per series.
        length (int): Number of time steps per series.
        folder_name (str): Folder that receives the CSV files.
        seed (int): Base seed; series ``i`` uses ``seed + i``.
    """
    # Make sure the CSV target folder exists before the first write.
    Path(folder_name).mkdir(parents=True, exist_ok=True)

    for i in range(n_series):
        # strongly_dependent_data draws only from `np.random`, so seeding the
        # stdlib `random` module alone would not make the output reproducible.
        random.seed(seed + i)
        np.random.seed(seed + i)

        data = strongly_dependent_data(length, d)
        df = save_to_csv(data, f"{folder_name}/strong_data_{i}.csv")

        def multivar_example_gen_func() -> Generator[dict[str, Any], None, None]:
            # Yields the whole frame as one multivariate item.
            yield {
                "target": df.to_numpy().T,  # array of shape (var, time)
                "start": df.index[0],
                "freq": pd.infer_freq(df.index),
                "item_id": "item_0",
            }

        features = Features(
            dict(
                target=Sequence(
                    Sequence(Value("float32")), length=len(df.columns)
                ),  # multivariate time series are saved as (var, time)
                start=Value("timestamp[s]"),
                freq=Value("string"),
                item_id=Value("string"),
            )
        )

        hf_dataset = datasets.Dataset.from_generator(
            multivar_example_gen_func, features=features
        )
        hf_dataset.save_to_disk(f"datasets/testing/strong_data{i}")

