import random

import numpy as np
import torch
from scipy.integrate import odeint

# Rescale the coefficients of the VAR model so that it is stable; the non-zero entries of beta change as a result.
def make_var_stationary(beta, radius=0.97):
    """Shrink VAR coefficients until the model is stable.

    The coefficients are stacked on top of a shifted identity to form the
    companion matrix of the VAR process.  While the largest eigenvalue
    modulus of that matrix exceeds ``radius``, every coefficient is
    multiplied by 0.95 and the check is repeated.

    Args:
        beta: Coefficient matrix of shape ``(p, p * lag)``.
        radius: Largest eigenvalue modulus that is accepted.

    Returns:
        ``beta`` itself (same object) when it is already stable, otherwise
        a rescaled array of the same shape.
    """
    n_series = beta.shape[0]
    n_lags = beta.shape[1] // n_series
    n_shift_rows = n_series * (n_lags - 1)

    # Lower rows of the companion matrix; they do not depend on the
    # coefficients, so they are built once.
    companion_bottom = np.hstack(
        (np.eye(n_shift_rows), np.zeros((n_shift_rows, n_series)))
    )

    while True:
        companion = np.vstack((beta, companion_bottom))
        spectral_radius = max(np.abs(np.linalg.eigvals(companion)))
        if not spectral_radius > radius:
            return beta
        beta = 0.95 * beta

# Simulate time-series data with known causal relationships.
def simulate_var(p, T, lag, sparsity=0.2, beta_value=1.0, sd=0.1, seed=0):
    """Simulate a stable VAR(lag) process with a known causal graph.

    Every series depends on itself and on ``int(p * sparsity) - 1`` randomly
    chosen other series (never fewer than zero).  The same coefficient
    pattern is reused for every lag and then rescaled by
    ``make_var_stationary``.

    Args:
        p: Number of series.
        T: Number of time points to return.
        lag: Order of the VAR model.
        sparsity: Target fraction of non-zero entries per row of the
            causal graph, counting the diagonal.
        beta_value: Coefficient value before rescaling.
        sd: Standard deviation of the Gaussian noise.
        seed: Seed passed to ``np.random.seed``; ``None`` leaves the global
            NumPy random state untouched.

    Returns:
        Tuple ``(X, beta, GC)``: ``X`` has shape ``(T, p)``, ``beta`` has
        shape ``(p, p * lag)`` (rescaled coefficients) and ``GC`` is a
        ``(p, p)`` int matrix where ``GC[i, j] == 1`` means series ``j``
        enters the equation of series ``i``.
    """
    if seed is not None:
        np.random.seed(seed)

    # Ground truth: every series drives itself.
    GC = np.eye(p, dtype=int)
    beta = np.eye(p) * beta_value

    # Number of off-diagonal parents per series.  Clamped at zero because
    # for p * sparsity < 1 the raw value is -1, which np.random.choice
    # rejects as a negative size.
    num_nonzero = max(int(p * sparsity) - 1, 0)
    for i in range(p):
        # choice() is still called when num_nonzero == 0, exactly as for
        # any previously valid argument, so the call sequence on the
        # random generator is unchanged.
        choice = np.random.choice(p - 1, size=num_nonzero, replace=False)
        # Indices were drawn from p - 1 values; shift to skip the diagonal.
        choice[choice >= i] += 1
        beta[i, choice] = beta_value
        GC[i, choice] = 1

    # Repeat the coefficient pattern for every lag: (p, p) -> (p, p * lag).
    beta = np.hstack([beta for _ in range(lag)])
    beta = make_var_stationary(beta)

    # Generate data; the first burn_in points are discarded.
    burn_in = 100
    total = T + burn_in
    errors = np.random.normal(scale=sd, size=(p, total))
    X = np.zeros((p, total))
    X[:, :lag] = errors[:, :lag]
    for t in range(lag, total):
        # Column-major flattening lines the past `lag` columns up with the
        # horizontally stacked coefficient blocks.
        X[:, t] = np.dot(beta, X[:, (t - lag):t].flatten(order='F'))
        # NOTE: the noise column is indexed with t - 1; kept as is so that
        # seeded outputs stay the same.
        X[:, t] += errors[:, t - 1]

    return X.T[burn_in:], beta, GC

# Simulate time-series data with known causal relationships.
def simulate_var_1000(p, T, lag, sparsity=0.2, beta_value=1.0, sd=0.1, seed=0):
    """Simulate a VAR process; thin wrapper around ``simulate_var``.

    This function used to carry a line-for-line copy of ``simulate_var``.
    It now forwards to it so the two cannot drift apart.

    Args:
        p: Number of series.
        T: Number of time points to return.
        lag: Order of the VAR model.
        sparsity: Target fraction of non-zero entries per row of the
            causal graph, counting the diagonal.
        beta_value: Coefficient value before rescaling.
        sd: Standard deviation of the Gaussian noise.
        seed: Seed passed to ``np.random.seed``; ``None`` leaves the global
            NumPy random state untouched.

    Returns:
        The ``(X, beta, GC)`` tuple produced by ``simulate_var``.
    """
    return simulate_var(p, T, lag, sparsity=sparsity, beta_value=beta_value,
                        sd=sd, seed=seed)


# Simulate time-series data with causal relationships; n is the number of causal relations drawn per series.
def simulate_var2(p, T, lag, sparsity=0.2, beta_value=1.0, sd=0.1, seed=0, n=2):
    """Simulate a stable VAR(lag) process with ``n`` drawn parents per series.

    Every series depends on itself and on the ``n`` other series returned
    by ``generate_list``.  The same coefficient pattern is reused for every
    lag and then rescaled by ``make_var_stationary``.

    Args:
        p: Number of series.
        T: Number of time points to return.
        lag: Order of the VAR model.
        sparsity: Unused; kept so the signature matches the other VAR
            simulators.
        beta_value: Coefficient value before rescaling.
        sd: Standard deviation of the Gaussian noise.
        seed: Seed for both ``np.random`` and ``random``; ``None`` leaves
            the global random states untouched.
        n: Number of other series drawn as parents of each series.

    Returns:
        ``None`` when ``p <= n``.  Otherwise a tuple ``(X, beta, GC)``:
        ``X`` has shape ``(T, p)``, ``beta`` has shape ``(p, p * lag)``
        (rescaled coefficients) and ``GC`` is a ``(p, p)`` int matrix where
        ``GC[i, j] == 1`` means series ``j`` enters the equation of
        series ``i``.
    """
    if seed is not None:
        np.random.seed(seed)
        random.seed(seed)
    # There must be more series than parents to draw.
    if p <= n:
        return None

    # Ground truth: every series drives itself.
    GC = np.eye(p, dtype=int)
    beta = np.eye(p) * beta_value

    for i in range(p):
        parents = generate_list(n, p, i)
        beta[i, parents] = beta_value
        GC[i, parents] = 1

    # Repeat the coefficient pattern for every lag: (p, p) -> (p, p * lag).
    beta = np.hstack([beta for _ in range(lag)])
    beta = make_var_stationary(beta)

    # Generate data; the first burn_in points are discarded.
    burn_in = 100
    total = T + burn_in
    errors = np.random.normal(scale=sd, size=(p, total))
    X = np.zeros((p, total))
    X[:, :lag] = errors[:, :lag]
    for t in range(lag, total):
        # Column-major flattening lines the past `lag` columns up with the
        # horizontally stacked coefficient blocks.
        X[:, t] = np.dot(beta, X[:, (t - lag):t].flatten(order='F'))
        # NOTE: the noise column is indexed with t - 1; kept as is so that
        # seeded outputs stay the same.
        X[:, t] += errors[:, t - 1]

    return X.T[burn_in:], beta, GC

def generate_list(n, p, i):
    """Draw ``n`` distinct integers from ``range(p)``, none equal to ``i``.

    ``n`` values are sampled without replacement.  If ``i`` is among them
    it is replaced, in place, by a randomly drawn value of ``range(p)``
    that is not in the list yet.

    Args:
        n: Number of values to draw.
        p: Exclusive upper bound of the values.
        i: Value that must not appear in the result.

    Returns:
        A list of ``n`` distinct integers from ``range(p)`` without ``i``.

    Raises:
        ValueError: If ``random.sample`` rejects ``n`` (negative or larger
            than ``p``), or if ``i`` was drawn and no unused value is left
            to replace it (``n == p``).
    """
    num_list = random.sample(range(0, p), n)
    if i in num_list:
        if n >= p:
            # Every value of range(p) is already in the list, so the
            # rejection loop below would never terminate.
            raise ValueError(
                "cannot draw %d distinct values from range(%d) excluding %r"
                % (n, p, i)
            )
        idx = num_list.index(i)
        new = i
        # Redraw until the candidate is not in the list (which still
        # contains i, so the candidate cannot be i either).
        while new in num_list:
            new = random.randint(0, p - 1)
        num_list[idx] = new
    return num_list

def lorenz(x, t, F):
    """Right-hand side of the Lorenz-96 ODE.

    For each component ``i`` (indices taken modulo ``len(x)``):
    ``dx_i/dt = (x[i+1] - x[i-2]) * x[i-1] - x[i] + F``.

    Args:
        x: Current state, one value per variable.
        t: Time; unused, present because ``odeint`` passes it.
        F: Forcing constant.

    Returns:
        Array with the time derivative of every component.
    """
    state = np.asarray(x, dtype=float)
    ahead_one = np.roll(state, -1)   # x[(i + 1) % p]
    behind_one = np.roll(state, 1)   # x[(i - 1) % p]
    behind_two = np.roll(state, 2)   # x[(i - 2) % p]
    return (ahead_one - behind_two) * behind_one - state + F

# Simulate Lorenz-96 data.
def simulate_lorenz_96(p, T, F=10.0, delta_t=0.1, sd=0.1, burn_in=1000,
                       seed=0):
    """Simulate noisy Lorenz-96 data together with its causal graph.

    The ODE defined by ``lorenz`` is integrated with ``scipy``'s ``odeint``
    from a small random initial state, Gaussian noise is added to the
    solution, and the first ``burn_in`` points are dropped.

    Args:
        p: Number of variables.
        T: Number of time points to return.
        F: Forcing constant passed to ``lorenz``.
        delta_t: Time scale; the grid has ``T + burn_in`` points spread
            evenly over ``[0, (T + burn_in) * delta_t]``.
        sd: Standard deviation of the noise added to the solution.
        burn_in: Number of leading time points to discard.
        seed: Seed passed to ``np.random.seed``; ``None`` leaves the global
            NumPy random state untouched.

    Returns:
        Tuple ``(X, GC)``: ``X`` has shape ``(T, p)`` and ``GC`` is a
        ``(p, p)`` int matrix with ones at ``[i, i]``, ``[i, i+1]``,
        ``[i, i-1]`` and ``[i, i-2]`` (indices modulo ``p``).
    """
    if seed is not None:
        np.random.seed(seed)

    n_steps = T + burn_in

    # Integrate the system from a small random initial state.
    initial_state = np.random.normal(scale=0.01, size=p)
    time_grid = np.linspace(0, n_steps * delta_t, n_steps)
    trajectory = odeint(lorenz, initial_state, time_grid, args=(F,))
    trajectory += np.random.normal(scale=sd, size=(n_steps, p))

    # Ground truth: each variable depends on itself and on the neighbours
    # that appear in its Lorenz-96 equation.
    GC = np.zeros((p, p), dtype=int)
    rows = np.arange(p)
    for offset in (0, 1, -1, -2):
        GC[rows, (rows + offset) % p] = 1

    return trajectory[burn_in:], GC

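# seg: number of sequences the dataset is split into (one model is trained per sequence); val_rate: fraction of each sequence held out as the validation set.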
def data_segmentation(data, lag=5, seg=1, val_rate=0.2):
    """Window a series into (input, target) pairs and split them by segment.

    Sample ``i`` has the input ``data[i:i + lag]`` transposed to
    ``(series, lag)`` and the target ``data[i + lag]``, which gives
    ``len(data) - lag`` samples.  The samples are cut into ``seg``
    consecutive segments of equal length; the tail of each segment is its
    validation set.  Samples left over after the last full segment are
    dropped.

    Segment boundaries are derived from the number of samples, not from
    ``len(data)``.  Using the raw length made the last segment's slices run
    past the end of the samples, leaving them truncated or empty.

    Args:
        data: 2-D ``torch`` tensor indexed as ``[time, series]`` (the
            windows are combined with ``torch.stack``).
        lag: Number of past time points in every input window.
        seg: Number of segments; must be at least 1.
        val_rate: Fraction of every segment used for validation; expected
            to lie in ``[0, 1]``.

    Returns:
        Tuple ``(train_x, train_y, val_x, val_y)`` of lists with one tensor
        per segment.  Inputs have shape ``(samples, series, lag)`` and
        targets have shape ``(samples, series)``.

    Raises:
        ValueError: If ``seg`` is smaller than 1.
        RuntimeError: Raised by ``torch.stack`` when ``data`` is too short
            to form a single window.
    """
    if seg < 1:
        raise ValueError("seg must be at least 1, got %r" % (seg,))

    T = len(data)

    # The last row is never part of an input window: it is only a target.
    history = data[:T - 1, :]
    windows = [
        history[i:i + lag, :].transpose(1, 0)
        for i in range(history.shape[0] - lag + 1)
    ]
    data_x = torch.stack(windows)  # (T - lag, series, lag)
    data_y = data[lag:T, :]        # (T - lag, series)

    n_samples = data_x.shape[0]
    seg_len = n_samples // seg
    val_len = int(val_rate * seg_len)

    train_x = []
    train_y = []
    val_x = []
    val_y = []
    for k in range(seg):
        start = k * seg_len
        stop = start + seg_len
        split = stop - val_len  # first validation sample of this segment
        train_x.append(data_x[start:split])
        train_y.append(data_y[start:split])
        val_x.append(data_x[split:stop])
        val_y.append(data_y[split:stop])
    return train_x, train_y, val_x, val_y
