# simulated data
import numpy as np
import pandas as pd
from scipy.stats import truncnorm

class DataGenerator_linear:
    """Simulate trial (RCT) and external data under a linear outcome model.

    Every generated frame has covariate columns ``X1`` ... ``X<n_features>``
    followed by the outcome ``Y`` and the binary arm indicator ``A``.

    Args:
        mu: Mean shift applied to the covariates of the external data.
        n_features: Number of covariates.
        random_state: Seed passed to ``np.random.seed``.
    """

    def __init__(self, mu=0.5, n_features=8, random_state=42):
        self.n_features = n_features
        self.random_state = random_state
        # NOTE: this seeds NumPy's *global* generator, so constructing an
        # instance resets the random stream for everything that shares it.
        np.random.seed(random_state)

        # True baseline coefficients of the outcome model.
        self.beta = np.random.uniform(-1, 1, self.n_features)
        self.mu = mu

    @staticmethod
    def _to_frame(X, y, A):
        """Assemble covariates, outcome and arm indicator into one DataFrame."""
        columns = {f'X{j + 1}': X[:, j] for j in range(X.shape[1])}
        columns['Y'] = y
        columns['A'] = A
        return pd.DataFrame(columns)

    def generate_rct_data(self, n_samples=400, noise_std=1.0):
        """Generate the idealized experimental (RCT) data.

        Covariates are standard normal. ``A`` is 1 with probability 2.995/4
        and 0 otherwise. The outcome is
        ``X @ beta + A * (0.3 + 0.3 * sum(X)) + noise``.

        Args:
            n_samples: Number of rows to generate.
            noise_std: Standard deviation of the additive Gaussian noise.

        Returns:
            DataFrame with columns ``X1..X<n_features>``, ``Y``, ``A``.
        """
        # The order of the random draws below is kept fixed so that a given
        # seed always reproduces the same data set.
        X = np.random.randn(n_samples, self.n_features)
        A = np.random.choice([0, 1], size=X.shape[0], p=[1.005 / 4, 2.995 / 4])
        # alpha[0] is the effect intercept, alpha[1:] the per-covariate
        # effect modifiers; sized from n_features instead of a fixed 8.
        alpha = np.array([0.3] * (self.n_features + 1))
        y = np.dot(X, self.beta) + A * (alpha[0] + X @ alpha[1:]) + noise_std * np.random.randn(n_samples)

        return self._to_frame(X, y, A)

    def generate_external_data(self, n_samples=800, noise_std=1.5, delta_t=0.1):
        """Generate external data that differs from the RCT population.

        Covariates are ``mu + 2 * N(0, 1)``, each baseline coefficient is
        rescaled by an independent factor in [0.8, 1.2), and the outcome is
        shifted by ``delta_t * T`` where ``T`` is drawn uniformly from
        {0, 1, 2}. ``T`` itself is not included in the returned frame, and
        ``A`` is 0 for every row.

        Args:
            n_samples: Number of rows to generate.
            noise_std: Standard deviation of the additive Gaussian noise.
            delta_t: Outcome shift per unit of ``T``.

        Returns:
            DataFrame with columns ``X1..X<n_features>``, ``Y``, ``A``.
        """
        T = np.random.choice([0, 1, 2], size=n_samples, p=[1 / 3, 1 / 3, 1 / 3])
        X = self.mu + 2 * np.random.randn(n_samples, self.n_features)
        changed_delta = self.beta * np.random.uniform(0.8, 1.2, size=self.n_features)
        y = np.dot(X, changed_delta) + noise_std * np.random.randn(n_samples) + delta_t * T
        return self._to_frame(X, y, 0)


def truncated_normal_scipy(size, mean=0, std=0.5, low=-1, high=1):
    """Sample from N(mean, std**2) truncated to [low, high] via SciPy's truncnorm.

    Args:
        size: Output shape forwarded to ``truncnorm.rvs``.
        mean: Location of the underlying normal.
        std: Scale of the underlying normal.
        low: Lower truncation bound on the data scale.
        high: Upper truncation bound on the data scale.

    Returns:
        Array of samples with the requested shape.
    """
    # truncnorm expects its clip points in standard-deviation units around
    # the location, so convert both data-scale bounds first.
    lower_z, upper_z = ((bound - mean) / std for bound in (low, high))
    return truncnorm.rvs(lower_z, upper_z, loc=mean, scale=std, size=size)


def generate_truncated_normal(size=1000, dim=8, mean=0, std=0.5, low=-1, high=1):
    """Draw a matrix of independent truncated-normal samples.

    Every entry is drawn independently from N(mean, std**2) restricted to the
    interval [low, high].

    Args:
        size: Number of rows (samples).
        dim: Number of columns (features).
        mean: Location of the underlying normal.
        std: Scale of the underlying normal.
        low: Lower truncation bound, on the data scale.
        high: Upper truncation bound, on the data scale.

    Returns:
        numpy array of shape ``(size, dim)``.
    """
    # truncnorm takes its clip points in standardized units relative to
    # loc/scale; passing low/high unconverted would truncate to
    # [mean + low * std, mean + high * std] instead of [low, high].
    a = (low - mean) / std
    b = (high - mean) / std
    return truncnorm.rvs(a, b, loc=mean, scale=std, size=(size, dim))

class DataGenerator_exp:
    """Simulate trial (RCT) and external data under an exponential outcome model.

    Every generated frame has covariate columns ``X1`` ... ``X<n_features>``
    followed by the outcome ``Y`` and the binary arm indicator ``A``.

    Args:
        mu: Mean passed to the covariate sampler for the external data.
        n_features: Number of covariates.
        random_state: Seed passed to ``np.random.seed``.
    """

    def __init__(self, mu=0.5, n_features=8, random_state=42):
        self.n_features = n_features
        self.random_state = random_state
        # NOTE: this seeds NumPy's *global* generator, so constructing an
        # instance resets the random stream for everything that shares it.
        np.random.seed(random_state)

        # True parameters: a multiplicative scale in [1, 3] (rounded to two
        # decimals) and the coefficients inside the exponent.
        self.true_a = np.round(1 + 2 * np.random.rand(), 2)
        self.true_b = np.random.uniform(-1, 1, self.n_features)
        self.mu = mu

    @staticmethod
    def _to_frame(X, y, A):
        """Assemble covariates, outcome and arm indicator into one DataFrame."""
        columns = {f'X{j + 1}': X[:, j] for j in range(X.shape[1])}
        columns['Y'] = y
        columns['A'] = A
        return pd.DataFrame(columns)

    def generate_rct_data(self, n_samples=400, noise_std=1.0):
        """Generate the idealized experimental (RCT) data.

        Covariates come from ``generate_truncated_normal`` with mean 0,
        std 1.0, low -2 and high 2. ``A`` is 1 with probability 2.995/4 and
        0 otherwise. The outcome is
        ``true_a * exp(X @ true_b + A * (0.3 + 0.3 * sum(X))) + noise``.

        Args:
            n_samples: Number of rows to generate.
            noise_std: Standard deviation of the additive Gaussian noise.

        Returns:
            DataFrame with columns ``X1..X<n_features>``, ``Y``, ``A``.
        """
        # The order of the random draws below is kept fixed so that a given
        # seed always reproduces the same data set.
        X = generate_truncated_normal(n_samples, self.n_features, mean=0, std=1.0, low=-2, high=2)
        A = np.random.choice([0, 1], size=X.shape[0], p=[1.005 / 4, 2.995 / 4])
        # alpha[0] is the effect intercept, alpha[1:] the per-covariate
        # effect modifiers; sized from n_features instead of a fixed 8.
        alpha = np.array([0.3] * (self.n_features + 1))
        y = self.true_a * np.exp(np.dot(X, self.true_b) + A * (alpha[0] + np.dot(X, alpha[1:]))) + noise_std * np.random.randn(n_samples)

        return self._to_frame(X, y, A)

    def generate_external_data(self, n_samples=800, noise_std=2.0, delta_t=1.0):
        """Generate external data that differs from the RCT population.

        Covariates come from ``generate_truncated_normal`` with mean ``mu``,
        std 2.0, low -4 and high 4. Each exponent coefficient is rescaled by
        an independent factor in [0.8, 1.2), and the outcome is shifted by
        ``delta_t * T`` where ``T`` is drawn uniformly from {0, 1, 2}. ``T``
        itself is not included in the returned frame, and ``A`` is 0 for
        every row.

        Args:
            n_samples: Number of rows to generate.
            noise_std: Standard deviation of the additive Gaussian noise.
            delta_t: Outcome shift per unit of ``T``.

        Returns:
            DataFrame with columns ``X1..X<n_features>``, ``Y``, ``A``.
        """
        # NOTE(review): the original carried trailing remarks "0.3" next to
        # delta_t and "1.1-1.3" next to the rescaling range -- presumably
        # alternative settings that were tried; confirm before reusing them.
        T = np.random.choice([0, 1, 2], size=n_samples, p=[1/3, 1/3, 1/3])
        X = generate_truncated_normal(n_samples, self.n_features, mean=self.mu, std=2.0, low=-4, high=4)
        changed_b = self.true_b * np.random.uniform(0.8, 1.2, size=self.n_features)
        y = self.true_a * np.exp(np.dot(X, changed_b)) + noise_std * np.random.randn(n_samples) + delta_t * T
        return self._to_frame(X, y, 0)

# Usage example: write one RCT file and one external (RWE) file per outcome
# model and per covariate shift mu.
if __name__ == "__main__":
    import os

    # Binds the module-level name `pd` that the generator classes reference.
    import pandas as pd

    output_dir = 'dataset/simulated_data'
    # DataFrame.to_csv does not create missing directories, so make sure the
    # target folder exists before the first write.
    os.makedirs(output_dir, exist_ok=True)

    mu_values = [0.1, 0.2, 0.3, 0.4, 0.5]
    generators = (
        ('linear', DataGenerator_linear),
        ('exp', DataGenerator_exp),
    )
    for label, generator_cls in generators:
        for m in mu_values:
            # Each generator re-seeds on construction, so every (model, mu)
            # pair is generated from the same seed.
            generator = generator_cls(mu=m, n_features=8, random_state=42)
            data_rct = generator.generate_rct_data()
            data_rwe = generator.generate_external_data()
            data_rct.to_csv(f'{output_dir}/rct_data_{label}_{m}.csv', index=False)
            data_rwe.to_csv(f'{output_dir}/rwe_data_{label}_{m}.csv', index=False)

