import os
import pandas as pd
import numpy as np

# --- Locations ---
# Everything lives under one experiment root: the centralized model is read
# from `central_lti_dir`, and the per-component split is written to
# `component_dir` (created here if it does not exist yet).
base_dir = "/Users/home/Documents/naz/research_codes/uncert_prop/realworld_exp/hai_down1"
central_lti_dir, component_dir = (
    os.path.join(base_dir, leaf)
    for leaf in ("centralized_lti_system", "component_split")
)
os.makedirs(component_dir, exist_ok=True)

# --- Read the centralized model from disk ---
# The five system matrices use their first CSV column as the row index.
A, B, C, Q, R = (
    pd.read_csv(os.path.join(central_lti_dir, fname), index_col=0)
    for fname in (
        "A_complete.csv",
        "B_complete.csv",
        "C_complete.csv",
        "Q_complete.csv",
        "R_complete.csv",
    )
)
# The state, input and output tables keep pandas' default integer index.
states, U, Y = (
    pd.read_csv(os.path.join(central_lti_dir, fname))
    for fname in ("states_complete.csv", "U_complete.csv", "Y_complete.csv")
)

# Drop the leading identifier row and column from every system matrix and
# renumber the remaining rows from zero.
A_n, B_n, C_n, Q_n, R_n = (
    frame.iloc[1:, 1:].reset_index(drop=True) for frame in (A, B, C, Q, R)
)

# Store the trimmed full-system matrices as bare CSVs (no index, no header).
for out_name, trimmed in (
    ("A_complete.csv", A_n),
    ("B_complete.csv", B_n),
    ("C_complete.csv", C_n),
    ("Q_complete.csv", Q_n),
    ("R_complete.csv", R_n),
):
    trimmed.to_csv(os.path.join(component_dir, out_name), index=False, header=False)

# Initial state of the full system: row 0 of the states table with its
# first (identifier) column dropped, stored as a single headerless CSV row.
x0_complete = states.iloc[0, 1:].values
x0_path = os.path.join(component_dir, "x0_complete.csv")
x0_frame = pd.DataFrame(x0_complete.reshape(1, -1))
x0_frame.to_csv(x0_path, index=False, header=False)

# --- Group states, outputs and inputs by process ---
# Labels are taken from the matrix headers, skipping the first label on each
# axis: C's columns name the states, C's index names the outputs, and B's
# columns name the inputs. A label is expected to look like "<process>_<n>"
# (e.g., P1_1, P2_2, ...).
state_cols = [*C.columns[1:]]
output_rows = [*C.index[1:]]
input_cols = [*B.columns[1:]]

# Process names are the text before the first underscore of each state label,
# de-duplicated while keeping the order in which they first appear.
processes = list(dict.fromkeys(label.split('_')[0] for label in state_cols))
M = len(processes)


def _positions_with_prefix(labels, prefix):
    """Return the positions in `labels` whose label starts with `prefix`."""
    return [pos for pos, label in enumerate(labels) if label.startswith(prefix)]


# For every process, the positions of its states, outputs and inputs within
# the label lists above. The trailing underscore is part of the prefix so
# that one process name cannot match another that merely begins with it.
process_state_idx = {p: _positions_with_prefix(state_cols, p + "_") for p in processes}
process_output_idx = {p: _positions_with_prefix(output_rows, p + "_") for p in processes}
process_input_idx = {p: _positions_with_prefix(input_cols, p + "_") for p in processes}

def _write_block(directory, filename, block):
    """Write one component quantity to ``directory/filename`` as a bare CSV.

    Args:
        directory: Folder that receives the file.
        filename: Name of the CSV file to create.
        block: DataFrame to write, or ``None`` when the process has nothing
            for this quantity. ``None`` is written as an empty DataFrame so
            that the file is still created.
    """
    frame = pd.DataFrame() if block is None else block
    frame.to_csv(os.path.join(directory, filename), index=False, header=False)


# Export one folder per process, named C1, C2, ... in the order of
# `processes`. Each folder receives Y, U, A, B, C, Q, R and x0 as headerless,
# index-free CSVs; a quantity the process does not have is written as an
# empty file.
for pi, p in enumerate(processes, 1):
    subdir = os.path.join(component_dir, f"C{pi}")
    os.makedirs(subdir, exist_ok=True)

    idx = process_state_idx[p]
    out_idx = process_output_idx[p]
    in_idx = process_input_idx[p]

    # Output and input series for this process; the first row is skipped as
    # an identifier row.
    # NOTE(review): out_idx / in_idx are positions within C's row labels and
    # B's column labels (after their first label is dropped) and are applied
    # directly to the columns of Y and U -- confirm those tables share that
    # column order and have no extra leading column.
    _write_block(subdir, "Y.csv", Y.iloc[1:, out_idx] if out_idx else None)
    _write_block(subdir, "U.csv", U.iloc[1:, in_idx] if in_idx else None)

    # Diagonal block of A: this process's states acting on themselves.
    _write_block(subdir, "A.csv", A_n.iloc[idx, idx] if idx else None)

    # B and C need both a non-empty state set and a non-empty input/output set.
    _write_block(subdir, "B.csv", B_n.iloc[idx, in_idx] if idx and in_idx else None)
    _write_block(subdir, "C.csv", C_n.iloc[out_idx, idx] if out_idx and idx else None)

    # Q is sliced by the state positions, R by the output positions.
    _write_block(subdir, "Q.csv", Q_n.iloc[idx, idx] if idx else None)
    _write_block(subdir, "R.csv", R_n.iloc[out_idx, out_idx] if out_idx else None)

    # Initial state restricted to this process, stored as a single row.
    _write_block(
        subdir,
        "x0.csv",
        pd.DataFrame(x0_complete[idx].reshape(1, -1)) if idx else None,
    )

print(f"Component split complete. Files saved in {component_dir}")