# utils/dataset_utils.py
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, Subset
from collections import defaultdict
import pickle
import os
import logging

logger = logging.getLogger('GFedCL')

class ILIDataset(Dataset):
    """Sliding-window dataset over an ILI time series.

    Sample ``i`` pairs the flattened window of rows ``i .. i+sequence_length-1``
    with the single row ``predict_steps`` rows after the window ends, i.e. row
    ``i + sequence_length + predict_steps - 1``.

    Args:
        data: array-like of shape (num_weeks, num_states); converted to a
            float32 tensor.
        sequence_length: number of weeks (rows) used as the model input.
        predict_steps: forecast horizon in weeks; 1 targets the week directly
            after the input window.

    Raises:
        ValueError: if ``sequence_length`` or ``predict_steps`` is below 1.
    """

    def __init__(self, data, sequence_length=10, predict_steps=1):
        if sequence_length < 1:
            raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")
        if predict_steps < 1:
            raise ValueError(f"predict_steps must be >= 1, got {predict_steps}")

        self.data = torch.FloatTensor(data)
        self.sequence_length = sequence_length
        self.predict_steps = predict_steps

        # Each sample spans sequence_length + predict_steps rows. Clamp at zero
        # so a series shorter than one sample yields an empty dataset rather
        # than a negative __len__, which Python rejects with ValueError.
        self.num_sequences = max(
            0, len(self.data) - sequence_length - predict_steps + 1
        )

    def __len__(self):
        return self.num_sequences

    def __getitem__(self, idx):
        """Return ``(x_flat, y)`` for sample ``idx`` (negative indices allowed).

        For 2-D input data, ``x_flat`` has shape
        (sequence_length * num_states,) and ``y`` has shape (num_states,).

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        pos = idx + self.num_sequences if idx < 0 else idx
        if not 0 <= pos < self.num_sequences:
            raise IndexError(
                f"index {idx} out of range for ILIDataset of length "
                f"{self.num_sequences}"
            )

        window_end = pos + self.sequence_length
        x = self.data[pos:window_end]  # (sequence_length, num_states)
        # Honour the forecast horizon. With predict_steps == 1 this is the row
        # immediately after the window, the same as the previous behaviour.
        y = self.data[window_end + self.predict_steps - 1]  # (num_states,)

        # Models consume a flat feature vector. Targets stay continuous for
        # regression (presumably normalized upstream; not checked here).
        return x.reshape(-1), y

def _make_loader(dataset, opt, *, train):
    """Build a DataLoader for ``dataset`` from the loader settings on ``opt``.

    Training loaders shuffle and drop the final incomplete batch; evaluation
    loaders keep sample order and keep every sample.
    """
    return DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=train,
        num_workers=opt.num_workers,
        pin_memory=opt.pin_memory,
        drop_last=train,
    )


def setup_ili_loaders(opt):
    """Create per-client, per-task ILI regression data loaders.

    Loads ``<opt.data_dir>/processed/ili_processed.pkl`` and reads its
    ``train_data``, ``test_data``, ``client_assignments`` and ``task_splits``
    entries. For each client, the columns listed in
    ``client_assignments[client_id]`` are selected. Each task trains on rows
    ``task_splits[task_id]['start']:task_splits[task_id]['end']`` of the
    training data. Every task of a client evaluates on the full test data.

    Args:
        opt: configuration namespace. Attributes read: ``data_dir``,
            ``num_clients``, ``num_task``, ``sequence_length``, ``batch_size``,
            ``num_workers``, ``pin_memory``, plus ``states_per_client``,
            ``weeks_per_task``, ``train_rows`` and ``test_rows`` for the
            summary log only.

    Returns:
        ``defaultdict(dict)`` mapping ``client_id -> task_id ->
        {'train': DataLoader, 'test': DataLoader}``.

    Raises:
        ValueError: if a task's training slice for a client has too few rows
            to form a single (input window, target) pair.
    """
    processed_path = os.path.join(opt.data_dir, 'processed', 'ili_processed.pkl')
    # SECURITY: pickle.load can execute arbitrary code. This is only safe
    # because the file is expected to be a locally produced preprocessing
    # artifact. Never point data_dir at untrusted input.
    with open(processed_path, 'rb') as f:
        processed_data = pickle.load(f)

    train_data = processed_data['train_data']
    test_data = processed_data['test_data']
    client_assignments = processed_data['client_assignments']
    task_splits = processed_data['task_splits']

    client_loaders = defaultdict(dict)

    for client_id in range(opt.num_clients):
        client_states = client_assignments[client_id]

        # The test split is shared by every task of this client.
        client_test_data = test_data[:, client_states]
        test_dataset = ILIDataset(client_test_data, opt.sequence_length)

        for task_id in range(opt.num_task):
            start_row = task_splits[task_id]['start']
            end_row = task_splits[task_id]['end']

            # This client's states restricted to this task's time period.
            task_train_data = train_data[start_row:end_row, client_states]

            # One sample needs sequence_length input rows plus one target row.
            # Fail with a clear message instead of the opaque sampler error
            # raised for an empty (or negative-length) dataset.
            if len(task_train_data) <= opt.sequence_length:
                raise ValueError(
                    f"Client {client_id}, Task {task_id}: training slice "
                    f"[{start_row}:{end_row}] has {len(task_train_data)} rows, "
                    f"need more than sequence_length={opt.sequence_length}"
                )

            train_dataset = ILIDataset(task_train_data, opt.sequence_length)

            if len(train_dataset) < opt.batch_size:
                # drop_last=True discards the only (incomplete) batch.
                logger.warning(
                    "Client %s, Task %s: only %s training samples but "
                    "batch_size=%s; with drop_last=True this loader yields "
                    "no batches.",
                    client_id, task_id, len(train_dataset), opt.batch_size,
                )

            client_loaders[client_id][task_id] = {
                'train': _make_loader(train_dataset, opt, train=True),
                'test': _make_loader(test_dataset, opt, train=False),
            }

            logger.info(
                "Client %s, Task %s: Train samples: %s, Test samples: %s",
                client_id, task_id, len(train_dataset), len(test_dataset),
            )

    logger.info("\n=== ILI Regression Dataset Summary ===")
    logger.info("Number of clients: %s", opt.num_clients)
    logger.info("States per client: %s", opt.states_per_client)
    logger.info("Number of tasks: %s", opt.num_task)
    logger.info("Weeks per task: %s", opt.weeks_per_task)
    logger.info("Sequence length: %s", opt.sequence_length)
    logger.info("Total training weeks: %s", opt.train_rows)
    logger.info("Total testing weeks: %s", opt.test_rows)
    logger.info("Target: Continuous ILI values (normalized)")
    logger.info("=====================================\n")

    return client_loaders