import torch
import numpy as np
from scipy.optimize import linprog
from seed import set_seed
def generate_data(T, dim, CP, sigma=0.1, seed=42, con_num=50):
    """Generate a piecewise-stationary sequence of LP cost vectors.

    The time axis ``0..T-1`` is cut into contiguous segments at ``CP``
    randomly chosen split points. Each segment gets its own mean cost vector
    drawn uniformly from ``[0, 1)``. Every time step in that segment uses the
    mean plus independent noise drawn uniformly from ``[0, sigma)``. A single
    constraint system ``A x <= b`` is shared by all time steps.

    Args:
        T: Number of time steps.
        dim: Number of decision variables.
        CP: Number of requested split points. Split points are distinct
            values in ``[1, T-1]``, so at most ``T - 1`` are produced.
        sigma: Upper end of the uniform noise added to each cost entry.
        seed: Value passed to ``set_seed`` before any random draw.
        con_num: Number of inequality constraints (rows of ``A``).

    Returns:
        Tuple ``(A, b, c_all, c_mean_list, split_points)``:

        - ``A``: shape ``(con_num, dim)``, entries in ``[-2, -1)``.
        - ``b``: shape ``(con_num,)``, entries in ``(-100, 0]``.
        - ``c_all``: shape ``(T, dim)``, one cost vector per time step.
        - ``c_mean_list``: the per-segment mean vectors, in time order.
        - ``split_points``: sorted list of segment start indices (0 excluded).
    """
    # NOTE(review): reproducibility relies on set_seed seeding both torch's
    # and numpy's global RNGs (both are used below) -- confirm in `seed`.
    set_seed(seed)
    # Distinct split points in [1, T-1], sorted so segments are in time order.
    split_points = torch.sort(torch.randperm(T - 1)[:CP] + 1).values.tolist()
    segments = torch.tensor_split(torch.arange(T), split_points)

    # The draw order is kept fixed so a given seed reproduces the same data:
    # first A, then b, then each segment's mean followed by that segment's noise.
    A = np.random.uniform(low=-2, high=-1, size=(con_num, dim))
    b = -100 * np.random.rand(con_num)
    c_all = np.zeros((T, dim))
    c_mean_list = []

    for segment in segments:
        rows = segment.numpy()
        c_mean = np.random.rand(dim)
        c_mean_list.append(c_mean)
        # A single (len(rows), dim) draw consumes the legacy global RNG stream
        # in the same order as len(rows) separate size-dim draws, so this
        # matches a per-time-step loop exactly. The noise is one-sided (>= 0),
        # so with c_mean in [0, 1) every cost entry stays nonnegative.
        noise = np.random.uniform(low=0, high=sigma, size=(len(rows), dim))
        c_all[rows] = c_mean + noise

    return A, b, c_all, c_mean_list, split_points

def solve_all_lp(c_all, A, b, bounds=None):
    """Solve ``min c_t @ x  s.t.  A x <= b`` independently for every row ``c_t``.

    Each LP is solved with ``scipy.optimize.linprog`` using the HiGHS backend.

    Args:
        c_all: Array of shape ``(T, dim)``; row ``t`` is the cost vector used
            at time ``t``.
        A: Inequality constraint matrix with ``dim`` columns.
        b: Right-hand side of ``A x <= b``.
        bounds: Per-variable ``(low, high)`` bounds forwarded to ``linprog``.
            ``None`` (the default) means ``[(0, None)] * dim``, i.e.
            nonnegative variables.

    Returns:
        Tuple ``(x_all, f_all, lambda_all, mu_all)`` with shapes ``(T, dim)``,
        ``(T,)``, ``(T, A.shape[0])`` and ``(T, dim)``:

        - ``x_all``: primal solutions.
        - ``f_all``: optimal objective values.
        - ``lambda_all``: HiGHS marginals of the inequality constraints.
        - ``mu_all``: HiGHS marginals of the variable lower bounds.

        When an LP fails, its row is left as zeros and a message is printed.
    """
    T, dim = c_all.shape
    if bounds is None:
        bounds = [(0, None)] * dim

    x_all = np.zeros((T, dim))
    f_all = np.zeros(T)
    lambda_all = np.zeros((T, A.shape[0]))
    mu_all = np.zeros((T, dim))

    for t, c in enumerate(c_all):
        res = linprog(c, A_ub=A, b_ub=b, bounds=bounds, method='highs')
        if not res.success:
            # The row stays all-zero. Report the solver's reason so an
            # infeasible or unbounded step can be diagnosed.
            print(f"LP failed at time {t}: {res.message}")
            continue

        x_all[t] = res.x
        f_all[t] = res.fun
        # Dual information from HiGHS (sign conventions: see linprog docs).
        lambda_all[t] = res.ineqlin.marginals
        mu_all[t] = res.lower.marginals

    return x_all, f_all, lambda_all, mu_all

if __name__ == "__main__":
    # Experiment configuration: number of time steps, decision dimension,
    # number of change points, and number of inequality constraints.
    T, dim, CP, con_num = 1000, 100, 3, 50

    A, b, c_all, c_mean_list, split_points = generate_data(
        T, dim, CP, con_num=con_num
    )
    # Every decision variable is constrained to be nonnegative.
    bounds = [(0, None) for _ in range(dim)]

    x_all, f_all, lambda_all, mu_all = solve_all_lp(c_all, A, b, bounds)

    # Only the optimal objective values are reported. The primal solutions
    # (x_all) and dual values (lambda_all, mu_all) are also available here.
    print("Optimal values (f):", f_all)