# preprocess_ili_data.py
import numpy as np
import os
import pickle

def _load_matrix(input_path):
    """Parse a comma-separated text file into a 2-D float32 array, skipping blank lines."""
    rows = []
    with open(input_path, 'r') as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                # Tolerate empty lines (e.g. an extra blank line at the end of
                # the file) instead of crashing on float('').
                continue
            rows.append([float(x) for x in stripped.split(',')])
    if not rows:
        raise ValueError(f"No data rows found in {input_path}")
    return np.array(rows, dtype=np.float32)


def preprocess_ili_data(input_path='data/state360.txt', output_dir='data/processed',
                        *, states_per_client=1, num_tasks=6, rows_per_task=50):
    """
    Preprocess the ILI data for federated learning.

    The input is a text file with one comma-separated row of floats per line
    (rows x state columns). The whole matrix is min-max normalised with a single
    global minimum/maximum, the first ``num_tasks * rows_per_task`` rows become
    the training set (cut into consecutive tasks) and the remaining rows become
    the test set. Columns are dealt out to clients in consecutive groups of
    ``states_per_client``.

    With the defaults and a (360, 49) input this yields:
    - 49 clients, each with 1 state
    - 300 rows for training (6 tasks x 50 rows each)
    - 60 rows for testing

    Passing ``states_per_client=7, num_tasks=5, rows_per_task=60`` gives the
    7 clients x 7 states / 5 tasks x 60 rows layout instead.

    Args:
        input_path: Path of the comma-separated input file.
        output_dir: Directory for the output pickles (created if missing).
        states_per_client: Number of consecutive columns given to each client.
        num_tasks: Number of consecutive training tasks.
        rows_per_task: Number of training rows in each task.

    Returns:
        A dict with keys ``train_data``, ``test_data``, ``client_assignments``
        (client id -> list of column indices), ``normalization`` (``min`` and
        ``max`` of the raw data) and ``task_splits`` (task id -> ``start``,
        ``end`` and ``rows``). The same dict is pickled to
        ``<output_dir>/ili_processed.pkl``; one ``client_<id>_data.pkl`` file
        is written for every client in ``client_assignments``.

    Raises:
        ValueError: If a layout parameter is smaller than 1, the file holds no
            data rows or unparsable values, it has fewer rows than the training
            set needs, or the column count is not a multiple of
            ``states_per_client``.
    """
    if states_per_client < 1 or num_tasks < 1 or rows_per_task < 1:
        raise ValueError(
            "states_per_client, num_tasks and rows_per_task must all be >= 1"
        )

    os.makedirs(output_dir, exist_ok=True)

    # Read the data
    print(f"Loading data from {input_path}")
    data = _load_matrix(input_path)
    print(f"Loaded data shape: {data.shape}")

    n_rows, n_states = data.shape
    train_rows = num_tasks * rows_per_task
    if n_rows < train_rows:
        # Without this check the later slices would silently come back short
        # or empty while task_splits still advertised the missing rows.
        raise ValueError(
            f"Need at least {train_rows} rows for {num_tasks} tasks of "
            f"{rows_per_task} rows, but {input_path} has only {n_rows}"
        )
    if n_states % states_per_client != 0:
        raise ValueError(
            f"{n_states} state columns cannot be split evenly into groups of "
            f"{states_per_client}"
        )

    # Normalize data with one global min/max.
    # NOTE(review): the statistics include the test rows as well; confirm this
    # is intended before relying on the test split as fully held-out data.
    data_min = data.min()
    data_max = data.max()
    data_normalized = (data - data_min) / (data_max - data_min + 1e-8)

    # Split into train and test
    train_data = data_normalized[:train_rows]
    test_data = data_normalized[train_rows:]

    print(f"Train data shape: {train_data.shape}")
    print(f"Test data shape: {test_data.shape}")

    # Assign states to clients; the client count follows from the data width so
    # every assigned column is guaranteed to exist.
    num_clients = n_states // states_per_client
    client_assignments = {}
    for client_id in range(num_clients):
        start_state = client_id * states_per_client
        end_state = start_state + states_per_client
        client_assignments[client_id] = list(range(start_state, end_state))
        print(f"Client {client_id}: states {client_assignments[client_id]}")

    # Define task boundaries once; the per-client files below reuse them.
    task_splits = {}
    for task_id in range(num_tasks):
        start_row = task_id * rows_per_task
        end_row = start_row + rows_per_task
        task_splits[task_id] = {
            'start': start_row,
            'end': end_row,
            'rows': list(range(start_row, end_row))
        }
        print(f"Task {task_id}: rows {start_row}-{end_row-1}")

    processed_data = {
        'train_data': train_data,
        'test_data': test_data,
        'client_assignments': client_assignments,
        'normalization': {'min': data_min, 'max': data_max},
        'task_splits': task_splits
    }

    # Save processed data
    output_path = os.path.join(output_dir, 'ili_processed.pkl')
    with open(output_path, 'wb') as f:
        pickle.dump(processed_data, f)

    print(f"\nProcessed data saved to {output_path}")

    # Also save client-specific data files for easier access. Iterate over the
    # assignments themselves so every listed client gets a file.
    for client_id, client_states in client_assignments.items():
        client_data = {
            'client_id': client_id,
            'states': client_states,
            'train_data': train_data[:, client_states],  # Only this client's states
            'test_data': test_data[:, client_states],
            # Split training data by tasks
            'task_data': {
                task_id: train_data[split['start']:split['end'], client_states]
                for task_id, split in task_splits.items()
            }
        }

        client_path = os.path.join(output_dir, f'client_{client_id}_data.pkl')
        with open(client_path, 'wb') as f:
            pickle.dump(client_data, f)

        print(f"Client {client_id} data saved to {client_path}")

    return processed_data

# Script entry point: run the preprocessing with its default paths when this
# file is executed directly (nothing runs on import).
if '__main__' == __name__:
    preprocess_ili_data()