import argparse
import os
import random
import time

import pandas as pd
import numpy as np
import torch

from run_benchmark import MODEL_CLASSES, COMPLEX_MODELS, ANALYTICAL_MODELS
from utils.data_utils import get_dataloaders
from utils.training_utils import (
    create_log_dir,
    get_device,
    get_dist_grid_codes,
    setup_pytorch,
)

def parse_args():
    """Parse the command-line options of the model-statistics script.

    Options:
        --data_dir      (required) dataset directory handed to the data loaders.
        --model         model key to evaluate, or "ALL"; defaults to "n-gnn".
        --save_results  flag; when set, statistics are written to a CSV file.
        --repeat        integer number of timing passes; defaults to 10.
        --cpu           flag; when set, inference is benchmarked on the CPU.

    Returns:
        argparse.Namespace holding the parsed values.
    """
    # Accepted values for --model, in the order they are reported to the user.
    model_choices = [
        "ALL",
        "dist-flow",
        "lin-dist-flow",
        "n-gnn",
        "n-gnn-residuals",
        "n-gnn-loss-supervised",
        "n-gnn-complex",
        "n-gnn-complex-loss",
        "n-gnn-complex-residuals",
        "n-gnn-residuals-loss",
        "n-gat",
        "n-gat-residuals",
        "n-gat-loss-supervised",
        "n-gat-complex",
        "n-gat-wide",
        "n-gat-wide-residuals",
        "n-gat-wide-loss-supervised",
        "n-gat-wide-complex",
        "dc-pf",
        "dc-pf-slack",
    ]

    cli = argparse.ArgumentParser()
    cli.add_argument("--data_dir", required=True)
    cli.add_argument("--model", default="n-gnn", choices=model_choices)
    cli.add_argument("--save_results", action="store_true")
    cli.add_argument("--repeat", type=int, default=10)
    cli.add_argument("--cpu", action="store_true")
    return cli.parse_args()

def format_num(num):
    """
    Formats a number into a human-readable string (e.g., 1.2M, 4.5B).

    Values whose magnitude is below 1000 are returned via ``str(num)``
    unchanged. Larger magnitudes are scaled by powers of 1000 and rendered
    with one decimal place and a K/M/B/T/P suffix. The sign is preserved,
    so ``-5_000_000`` becomes ``"-5.0M"``.

    Args:
        num: The int or float to format.

    Returns:
        The formatted string.
    """
    # Compare magnitudes so that large negative values are abbreviated too.
    if abs(num) < 1000:
        return str(num)

    for unit in ('K', 'M', 'B', 'T'):
        num /= 1000.0
        # Test the value as it will be printed (one decimal): 999.99 would
        # otherwise be rendered as "1000.0K" instead of rolling over to "1.0M".
        if abs(round(num, 1)) < 1000:
            return f"{num:.1f}{unit}"

    # Still >= 1000 trillions: scale once more so the P suffix is accurate.
    num /= 1000.0
    return f"{num:.1f}P"  # Handles Quadrillions (Peta) just in case

def _measure_inference_ms(model, dataset, device, repeat, warmup_steps=10):
    """Return the mean single-sample forward-pass latency of *model* in ms.

    Args:
        model: Callable model, already moved to *device* and in eval mode.
        dataset: Collection of samples; it must support ``len()`` and integer
            indexing (needed to draw warm-up samples) as well as iteration.
            Each sample must expose ``.to(device)``.
        device: Device the samples are moved to before the forward pass.
        repeat: Number of full passes over *dataset* that are timed.
        warmup_steps: Number of untimed forward passes run beforehand.

    Returns:
        Mean latency in milliseconds over ``repeat * len(dataset)`` forward
        passes, or NaN when nothing was timed (e.g. ``repeat <= 0``).
    """
    # Only pay for CUDA synchronisation when the model is not pinned to the
    # CPU; the availability check is hoisted out of the hot loop.
    needs_sync = torch.cuda.is_available() and str(device) != 'cpu'

    # A private RNG yields the same warm-up samples as seeding the global
    # generator with 123, without clobbering global random state.
    rng = random.Random(123)

    timings = []
    with torch.no_grad():
        # Warm up (crucial!) - dummy passes to wake up the GPU/CPU caches.
        # Run under no_grad so warm-up matches the measured configuration.
        for sample in rng.choices(dataset, k=warmup_steps):
            model(sample.to(device))

        for _ in range(repeat):
            for sample in dataset:
                # Host-to-device transfer is deliberately excluded from timing.
                sample = sample.to(device)
                # Drain pending GPU work so it is not billed to this sample.
                if needs_sync:
                    torch.cuda.synchronize()
                start = time.perf_counter()

                model(sample)

                # Wait for the asynchronous kernels before stopping the clock.
                if needs_sync:
                    torch.cuda.synchronize()
                timings.append(time.perf_counter() - start)

    if not timings:
        return float('nan')
    return np.mean(timings) * 1000


def get_model_stats(args):
    """Report parameter count and inference latency for the selected models.

    For the "all grids" case and for every leave-one-out case built from
    ``get_dist_grid_codes(scenario=1)``, each selected model class is
    instantiated untrained, its trainable parameters are counted, and its
    forward pass is timed on every sample of the test dataset. Results are
    printed and, when ``args.save_results`` is set, appended per test case to
    ``model_stats.csv`` inside the directory returned by ``create_log_dir()``.

    Args:
        args: Namespace with ``data_dir``, ``model``, ``save_results``,
            ``repeat`` and ``cpu`` attributes (see ``parse_args``).
    """
    # Argument parsing and validation
    data_dir = args.data_dir
    save_results = args.save_results
    batch_size = 1  # Testing inference for single grid
    repeat = args.repeat  # How often to repeat the test (for inference speed)

    # Set up training, logging, and experiment cases
    setup_pytorch()

    log_dir = None
    if save_results:
        # One log directory per invocation, shared by all evaluated models
        log_dir = create_log_dir()

    grids_to_compare = get_dist_grid_codes(scenario=1)
    test_cases = [(grids_to_compare, None)]  # All grids scenario
    # Leave-one-out scenarios: hold out each grid in turn as the testing grid
    test_cases.extend(
        ([g for g in grids_to_compare if g != grid], grid)
        for grid in grids_to_compare
    )

    # Set up results tracking
    column_names = [
        'model',
        'testing_grid',
        'num_params',
        'inference_time_ms',
        'num_runs',
        'device',
    ]
    results_file = None
    if save_results and log_dir:
        results_file = os.path.join(log_dir, 'model_stats.csv')
        # Write the header row once; data rows are appended per test case
        pd.DataFrame(columns=column_names).to_csv(results_file)
        print(f'\nStats will be saved to: {results_file}', flush=True)

    if args.model.upper() != 'ALL':
        models_to_evaluate = [MODEL_CLASSES[args.model]]
    else:
        models_to_evaluate = list(MODEL_CLASSES.values())

    # Which data flavours are needed depends only on the model selection,
    # so decide once instead of once per test case.
    need_real_valued_data = any(m not in COMPLEX_MODELS for m in models_to_evaluate)
    need_complex_valued_data = any(m in COMPLEX_MODELS for m in models_to_evaluate)

    # Run evaluations
    for training_grids, testing_grid in test_cases:
        # Get data loaders (only the third returned loader is used here)
        loader_test_real = None
        loader_test_complex = None

        # If no real models are being evaluated, skip loading real data
        if need_real_valued_data:
            _, _, loader_test_real = get_dataloaders(
                data_dir, training_grids, testing_grid, batch_size=batch_size
            )

        # If no complex models are being evaluated, skip loading complex data
        if need_complex_valued_data:
            _, _, loader_test_complex = get_dataloaders(
                data_dir, training_grids, testing_grid, batch_size=batch_size, complex=True
            )

        # Keep track of results
        results = []

        for model_class in models_to_evaluate:
            # Use complex data loaders for complex models
            if model_class in COMPLEX_MODELS:
                loader_test = loader_test_complex
            else:
                loader_test = loader_test_real
            print('\n--------------------------------------------------', flush=True)
            print(f'\nEvaluating model: {model_class.__name__} | Testing grid: {testing_grid if testing_grid else "All Grids"}', flush=True)

            # NOTE: WE DO NOT NEED A TRAINED MODEL TO GATHER STATS.
            #       INFERENCE TIME AND MODEL CAPACITY ARE THE SAME FOR AN UNTRAINED MODEL.

            # 1. Calculating model capacity
            model = model_class()
            num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            print(f"Num params: {format_num(num_params)}")

            # 2. Calculating inference speed (latency)
            # Analytical models are always benchmarked on the CPU.
            device = 'cpu' if args.cpu or model_class in ANALYTICAL_MODELS else get_device()
            model = model.to(device)
            model.eval()

            dataset = loader_test.dataset
            inference_time_ms = _measure_inference_ms(model, dataset, device, repeat)
            print(f"Inference time: {inference_time_ms:.4f} ms (averaged over {repeat} runs of {len(dataset)} graphs)")

            # 3. Record results
            results.append(
                (
                    model_class.__name__,
                    testing_grid if testing_grid else 'all',
                    num_params,
                    inference_time_ms,
                    repeat,
                    device,
                )
            )

        if results_file is not None:
            # Append to existing results file after each test case
            results_df = pd.DataFrame(results, columns=column_names)
            results_df.to_csv(results_file, mode='a', index=True, header=False)
            print(f'\nAppended results to: {results_file}', flush=True)

        print('\n==================================================', flush=True)

    print('\nAll evaluations completed.', flush=True)

# Script entry point: parse the CLI options and run the benchmark.
if __name__ == '__main__':
    get_model_stats(parse_args())
