# main_ili.py
import torch
import numpy as np
import random
import os
import logging
import matplotlib
matplotlib.use('Agg')

# Import the ILI configuration
from configs.ILI import parse_args
from gfedcl import ParallelServerGFedCL
from utils.plot_utils import plot_results, plot_all_tasks_accuracy
from process_ili_data import preprocess_ili_data

# Set random seeds for reproducibility
def set_seed(seed):
    """Seed the NumPy, PyTorch and Python RNGs with the same value.

    When CUDA is available, the seed is also passed to
    ``torch.cuda.manual_seed`` and cuDNN is switched to deterministic mode
    with benchmarking turned off.

    Args:
        seed: Value handed to every seeding function.
    """
    for seeder in (np.random.seed, torch.manual_seed, random.seed):
        seeder(seed)

    # Nothing GPU-related to configure on CPU-only machines.
    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def setup_logging(opt):
    """Configure root logging to write to the console and to a log file.

    Args:
        opt: Options object providing ``output_dir`` and ``log_path``. A falsy
            ``log_path`` falls back to ``<output_dir>/run.log``.

    Returns:
        The ``'GFedCL-ILI'`` logger.
    """
    os.makedirs(opt.output_dir, exist_ok=True)

    log_file_path = opt.log_path or os.path.join(opt.output_dir, 'run.log')

    # An explicit log_path may live outside output_dir; FileHandler does not
    # create missing parent directories, so make sure they exist first.
    log_dir = os.path.dirname(log_file_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # Detach existing root handlers so basicConfig below takes effect, and
    # close them so a previously opened log file is not leaked.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
        handler.close()

    # Set up logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(),
            # mode='w' truncates any log left over from an earlier run.
            logging.FileHandler(log_file_path, mode='w')
        ]
    )

    logger = logging.getLogger('GFedCL-ILI')
    logger.info(f"Logging initialized. Log file: {log_file_path}")

    return logger

def apply_ablation(opt):
    """Adjust ``opt`` in place for the selected ablation and return it.

    The ``"no_dp"`` ablation sets ``opt.dp`` to ``False``. When
    ``opt.output_dir_is_default`` is truthy and the ablation is anything other
    than ``"none"``, the ablation name is appended to ``opt.output_dir`` and
    ``opt.log_path`` is pointed at ``run.log`` inside that directory.

    Args:
        opt: Options object with ``ablation``, ``output_dir`` and, optionally,
            ``output_dir_is_default``.

    Returns:
        The same ``opt`` object.
    """
    variant = opt.ablation
    if variant == "no_dp":
        opt.dp = False

    # Only a default output directory is renamed after the ablation.
    uses_default_dir = getattr(opt, "output_dir_is_default", False)
    if not uses_default_dir or variant == "none":
        return opt

    suffixed_dir = f"{opt.output_dir}_{variant}"
    opt.output_dir = suffixed_dir
    opt.log_path = os.path.join(suffixed_dir, "run.log")
    return opt

def main():
    """Run the GFedCL ILI experiment from option parsing to final summary.

    Parses options, seeds the RNGs, applies the ablation settings, configures
    logging, preprocesses the raw data when the processed pickle is missing,
    trains through ``ParallelServerGFedCL``, then plots and logs the results.
    """
    opt = parse_args()
    set_seed(opt.seed)

    # Ablation handling may rewrite output_dir/log_path, so it runs before
    # logging is configured.
    opt = apply_ablation(opt)
    logger = setup_logging(opt)
    log = logger.info

    # Build the processed dataset only when it is not already on disk.
    processed_dir = os.path.join(opt.data_dir, 'processed')
    processed_file = os.path.join(processed_dir, 'ili_processed.pkl')
    if not os.path.exists(processed_file):
        log("Preprocessing ILI data...")
        raw_file = os.path.join(opt.data_dir, 'state360.txt')
        preprocess_ili_data(input_path=raw_file, output_dir=processed_dir)

    # Record the run configuration.
    log('Initializing GFedCL for ILI time series data...')
    log(f'Dataset: {opt.dataset}')
    log(f'Number of clients: {opt.num_clients}')
    log(f'States per client: {opt.states_per_client}')
    log(f'Number of tasks: {opt.num_task}')
    log(f'Weeks per task: {opt.weeks_per_task}')
    log(f'Sequence length: {opt.sequence_length}')
    log(f'Output directory: {opt.output_dir}')
    log(f'Ablation: {opt.ablation}')

    trainer = ParallelServerGFedCL(opt)

    log('Starting training...')
    accuracy_results, all_tasks_accuracy = trainer.train_GFedCL()
    log('Training completed.')

    # plot_results returns the directory reused for the per-task plot.
    plots_dir = plot_results(accuracy_results, opt.output_dir)
    if all_tasks_accuracy:
        plot_all_tasks_accuracy(opt, all_tasks_accuracy, plots_dir)

    # Final summary.
    log('===== TRAINING SUMMARY =====')
    log(f'Dataset: {opt.dataset}')
    log(f'Number of clients: {opt.num_clients}')
    log(f'Number of tasks: {opt.num_task}')
    log(f'Overall average accuracy: {accuracy_results["overall_avg_acc"]:.2f}%')
    log(f'Results saved to {opt.output_dir}')
    log('===========================')

# Run the experiment only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
