# main_experiment.py
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import time
import statistics
from pathlib import Path

# Custom modules
import config_utils
import data_utils
from nc_models import get_model  # Models are defined in nc_models.py
import metrics_utils
import trainer_module


def _format_nc_value(value):
    """Format one neural-collapse metric for a log line.

    Args:
        value: A scalar metric (a single-element tensor or a plain number),
            or ``None`` when the metric was not computed.

    Returns:
        The value rendered with four decimal places, or ``"NaN"`` for ``None``.
    """
    if value is None:
        return "NaN"
    # float() accepts single-element tensors as well as Python/NumPy scalars.
    return f"{float(value):.4f}"


def main():
    """Run the experiment described by the command-line arguments.

    Steps:
      1. Parse arguments, build the run config, set up logging and seeds.
      2. Load the dataset once and build train/test loaders shared by all runs.
      3. For each of ``args.repeat`` repetitions (seed ``args.seed + i``):
         build a fresh model, optimizer and StepLR scheduler, train via
         ``trainer_module.train_model``, then log the final test accuracy and
         NC metrics (measured on the training loader), optionally saving the
         model weights.
      4. Aggregate per-run metrics CSVs when more than one run was requested,
         log mean/std of the per-run best test accuracies, and write
         ``summary_experiment.txt`` into ``config['output_dir']``.

    Raises:
        ValueError: If ``args.optimizer`` is neither ``'adam'`` nor ``'sgd'``
            (raised when the first run builds its optimizer).
    """
    args = config_utils.parse_arguments()
    config = config_utils.create_run_config(args)

    logger = config_utils.setup_logging(config['log_file'])
    # Seed once for any setup done before the per-run loop; each run re-seeds below.
    config_utils.set_random_seeds(args.seed)

    logger.info(f"Starting Experiment Run: {config['run_name']}")
    logger.info(f"Full Configuration: {config}")
    logger.info(f"Parsed Arguments: {args}")
    logger.info(f"PyTorch version: {torch.__version__}")
    # getattr guards against torchvision builds that do not expose __version__.
    logger.info(
        f"Torchvision version: {getattr(torchvision, '__version__', 'N/A')}")

    device = config_utils.get_device()
    output_dir = Path(config['output_dir'])

    # --- Data Preparation ---
    transform, input_ch, dataset_cls = data_utils.get_dataset_transforms(config['dataset'])

    train_dataset, test_dataset = data_utils.load_datasets(dataset_cls, transform)
    num_classes = data_utils.get_num_classes(train_dataset)
    logger.info(
        f"Dataset: {config['dataset']}, Detected Classes: {num_classes}, Input Channels for Model: {input_ch}")
    logger.info(f"Train samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")

    train_loader, test_loader = data_utils.create_data_loaders(
        train_dataset, test_dataset, config['batch_size'], args.workers, device.type
    )

    # --- Experiment Repetitions ---
    all_runs_metrics_files = []
    all_runs_best_accuracies = []
    all_runs_successful_seeds = []
    # perf_counter is monotonic, so the elapsed time is immune to clock changes.
    experiment_start_time = time.perf_counter()

    for i_repeat in range(args.repeat):
        current_run_seed = args.seed + i_repeat  # Distinct seed per repetition
        run_name_info = f"Run {i_repeat + 1}/{args.repeat} (Seed {current_run_seed})"
        logger.info(f"--- Starting {run_name_info} ---")
        config_utils.set_random_seeds(current_run_seed)

        model = get_model(
            config['model'],
            input_channels=input_ch,  # Channel count reported by the dataset transforms
            num_classes=num_classes,
            device=device,
            layer_norm=args.layer_norm
        )
        n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        logger.info(
            f"Model '{config['model']}' initialized with {n_trainable:,} trainable parameters.")

        criterion = nn.CrossEntropyLoss()
        optimizer_name = args.optimizer.lower()
        if optimizer_name == 'adam':
            optimizer = optim.Adam(model.parameters(), lr=args.lr)
        elif optimizer_name == 'sgd':
            optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.95)
        else:
            logger.error(f"Unsupported optimizer: {args.optimizer}")
            raise ValueError(f"Unsupported optimizer: {args.optimizer}")

        scheduler = optim.lr_scheduler.StepLR(
            optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma
        )
        logger.info(f"Using StepLR scheduler: step_size={args.lr_step_size}, gamma={args.lr_gamma}")

        # Only the config entries the trainer needs; rebuilt per run so any
        # mutation inside the trainer cannot leak into the next repetition.
        trainer_config = {
            'num_epochs': config['num_epochs'],
            'log_interval': config['log_interval'],
            'repeat': args.repeat  # Lets trainer logs show "Run x/N"
        }

        # --- Train one run ---
        # No retry logic here; a failed run is logged and skipped.
        success, best_acc_this_run, metrics_file_path_this_run = trainer_module.train_model(
            model, train_loader, test_loader, criterion, optimizer, scheduler,
            config=trainer_config, device=device, num_classes=num_classes,
            output_dir=config['output_dir'],
            eval_fn=lambda m, dl, desc: metrics_utils.evaluate_accuracy(m, dl, device, desc),
            nc_fn=lambda m, dl: metrics_utils.measure_neural_collapse(m, dl, num_classes, device),
            save_metrics_fn=metrics_utils.save_metrics_to_csv,
            repeat_idx=i_repeat, run_seed=current_run_seed
        )

        if not success:
            logger.error(f"{run_name_info} failed.")
            # Failed runs are excluded from aggregation and summary statistics.
            continue

        all_runs_metrics_files.append(str(metrics_file_path_this_run))  # str for downstream CSV readers
        all_runs_best_accuracies.append(best_acc_this_run)
        all_runs_successful_seeds.append(current_run_seed)

        # --- Final Evaluation for this run (after all epochs) ---
        logger.info(f"{run_name_info} - Final Evaluation (model state after last epoch):")
        final_run_test_acc = metrics_utils.evaluate_accuracy(
            model, test_loader, device, f"{run_name_info} - Final Test")
        # NC metrics are measured on the training loader.
        final_run_nc1, final_run_nc2, final_run_nc3 = metrics_utils.measure_neural_collapse(
            model, train_loader, num_classes, device)
        logger.info(f"  Final Test Accuracy for run: {final_run_test_acc:.2f}%")
        if final_run_nc1 is not None:
            # Format each value separately: applying ':.4f' to a 'NaN'
            # placeholder string would raise ValueError.
            logger.info(
                f"  Final NC1: {_format_nc_value(final_run_nc1)}, "
                f"NC2: {_format_nc_value(final_run_nc2)}, "
                f"NC3: {_format_nc_value(final_run_nc3)}")
        logger.info(f"  Best Pre-Training Test Accuracy during this run: {best_acc_this_run:.2f}%")

        if args.save_model:
            model_save_path = output_dir / f"model_run_{i_repeat}_seed_{current_run_seed}.pth"
            torch.save(model.state_dict(), model_save_path)
            logger.info(f"Model for {run_name_info} saved to {model_save_path}")

    # --- Aggregation and Summary (after all repeats) ---
    if args.repeat > 1 and all_runs_metrics_files:
        logger.info("Aggregating results from all successful/attempted runs...")
        aggregated_metrics_df = metrics_utils.aggregate_run_metrics(all_runs_metrics_files, config['num_epochs'])
        if not aggregated_metrics_df.empty:
            agg_metrics_file_path = output_dir / 'metrics_aggregated.csv'
            aggregated_metrics_df.to_csv(agg_metrics_file_path, index=False)
            logger.info(f"Aggregated metrics saved to {agg_metrics_file_path}")
        else:
            logger.warning("Aggregation resulted in an empty DataFrame. No aggregated metrics saved.")

    # Log overall best accuracies (stdev needs at least two samples).
    if all_runs_best_accuracies:
        mean_best_acc = statistics.mean(all_runs_best_accuracies)
        std_best_acc = statistics.stdev(all_runs_best_accuracies) if len(all_runs_best_accuracies) > 1 else 0.0
        logger.info(
            f"Average Best Pre-Training Test Accuracy across {len(all_runs_best_accuracies)} run(s): "
            f"{mean_best_acc:.2f}% (±{std_best_acc:.2f}%)"
        )

    total_experiment_time = time.perf_counter() - experiment_start_time
    logger.info(
        f"Total experiment runtime: {total_experiment_time:.2f} seconds ({total_experiment_time / 3600:.2f} hours)")

    # --- Generate Summary File ---
    summary_file_path = output_dir / 'summary_experiment.txt'
    # Explicit UTF-8: the accuracy line contains the non-ASCII character '±',
    # which the platform default encoding may not be able to represent.
    with open(summary_file_path, 'w', encoding='utf-8') as f:
        f.write(f"Experiment Run Name: {config['run_name']}\n")
        f.write(f"Model: {config['model']}\n")
        f.write(f"Dataset: {config['dataset']}\n")
        f.write(f"Input Channels (for model): {input_ch}\n")
        f.write(f"Number of Classes: {num_classes}\n")
        f.write(f"Layer Normalization: {args.layer_norm}\n")
        f.write(f"Epochs per run: {config['num_epochs']}\n")
        f.write(f"Number of runs (repeats): {args.repeat}\n")
        f.write(f"Initial Seed: {args.seed}\n")
        f.write(f"Successful Seeds: {all_runs_successful_seeds}\n")
        f.write(f"Optimizer: {args.optimizer}, LR: {args.lr}\n")
        f.write(f"LR Scheduler: StepLR (step_size={args.lr_step_size}, gamma={args.lr_gamma})\n")
        if all_runs_best_accuracies:
            f.write(f"Avg Best Pre-Train Test Acc: {mean_best_acc:.2f}% (±{std_best_acc:.2f}%)\n")
        else:
            f.write("No successful runs to report average accuracy.\n")
        f.write(f"Total Runtime: {total_experiment_time / 60:.2f} minutes\n")
        f.write(f"Output Directory: {config['output_dir']}\n")

    logger.info(f"Experiment summary saved to {summary_file_path}")
    logger.info(f"All experiment outputs are in: {config['output_dir']}")
    logger.info("--- Experiment Finished ---")


if __name__ == "__main__":
    # Run the experiment only when this file is executed directly, not when imported.
    main()