import argparse
import gc
import os
from collections.abc import Iterator

import numpy as np
import pandas as pd
import torch
from torch import Tensor

from typing import cast
from optimization.optimizers import OPTIMIZERS_NAMES
from models import PC
from scripts.logger import Logger
from scripts.utils import (
    retrieve_tboard_runs,
    set_global_seed,
    setup_model,
    num_parameters,
)
from utilities import PCS_MODELS
from optimization.optimizers import OPTIMIZERS_NAMES, setup_optimizer

# Command-line interface of the benchmarking script.
#
# NOTE(review): --backprop, --metric and the three --*-bubble-radius options are
# parsed but not read anywhere in the visible part of this script; they are kept
# so that existing invocations passing them keep working.
parser = argparse.ArgumentParser(description="Benchmarking script")

# The output directory, the input shape and the model have no sensible default;
# requiring them makes argparse report a clear usage error instead of the script
# failing later on a ``None`` value.
parser.add_argument(
    "--path", type=str, required=True, help="The path to save the results"
)
parser.add_argument(
    "--input-shape",
    type=int,
    nargs="+",
    required=True,
    help="The shape of the input data (without batch)",
)
parser.add_argument(
    "--num-units",
    type=int,
    default=32,
    help="The number of units in each layer to benchmark",
)
parser.add_argument(
    "--model",
    type=str,
    choices=PCS_MODELS,
    required=True,
    help="The PC to benchmark",
)
parser.add_argument(
    "--num-iterations",
    type=int,
    default=50,
    help="The number of iterations",
)
parser.add_argument(
    "--burnin-iterations",
    type=int,
    default=5,
    help="Burnin iterations (additional to --num-iterations)",
)
parser.add_argument("--device", type=str, default="cuda", help="The device id")
parser.add_argument("--batch-size", type=int, default=512, help="The batch size")
parser.add_argument(
    "--complex",
    action="store_true",
    help="Whether to use complex parameters",
)
parser.add_argument(
    "--mono-num-units",
    type=int,
    default=8,
    help="The number of units in monotonic PCs, for ExpSOS models only",
)
parser.add_argument(
    "--min-bubble-radius", type=float, default=20.0, help="Bubble sizes minimum"
)
parser.add_argument(
    "--scale-bubble-radius", type=float, default=1.0, help="Bubble sizes scaler"
)
parser.add_argument(
    "--exp-bubble-radius",
    type=float,
    default=2.0,
    help="The exponent for computing the bubble sizes",
)
parser.add_argument(
    "--backprop",
    action="store_true",
    help="Whether to benchmark also backpropagation",
)
parser.add_argument("--metric", type=str, default="bpd", help="The test metric to log")
parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")
parser.add_argument(
    "--optimizer",
    choices=OPTIMIZERS_NAMES,
    default=OPTIMIZERS_NAMES[0],
    help="Optimiser to use",
)
parser.add_argument(
    "--use-tucker",
    action="store_true",
    help="Whether to use Tucker layers for OSOS",
)


def run_benchmark(
    model: PC,
    batch_shape: torch.Size,
    *,
    device: torch.device,
    num_iterations: int,
    burnin_iterations: int = 1,
    optimizer: torch.optim.Optimizer | None = None,
    partition_function_only: bool = False,
) -> tuple[list[float], list[float]]:
    def infinite_dataloader() -> Iterator[list[Tensor] | tuple[Tensor] | Tensor]:
        while True:
            yield torch.randn(batch_shape).flatten(start_dim=1)

    model = model.to(device)

    elapsed_times = list()
    gpu_memory_peaks = list()
    for i, batch in enumerate(infinite_dataloader()):
        if i == num_iterations + burnin_iterations:
            break
        if isinstance(batch, (tuple, list)):
            batch = batch[0]
        # Run GC manually and then disable it
        gc.collect()
        gc.disable()
        # Reset peak memory usage statistics
        torch.cuda.reset_peak_memory_stats(device)
        # torch.cuda.synchronize(device)  # Synchronize CUDA operations
        batch = batch.to(device)
        # torch.cuda.synchronize(device)  # Make sure the batch is already loaded (do not take into account this!)
        # start_time = time.perf_counter()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record(torch.cuda.current_stream(device))

        if partition_function_only:
            lls = model.log_partition()
        else:
            lls = model.log_likelihood(batch)
        
        if optimizer is not None:
            loss = -lls.mean()
            loss.backward(retain_graph=False)  # Free the autodiff graph
            optimizer.step()

        end.record(torch.cuda.current_stream(device))
        torch.cuda.synchronize(device)  # Synchronize CUDA Kernels before measuring time
        # end_time = time.perf_counter()
        gpu_memory_peaks.append(torch.cuda.max_memory_allocated(device))

        if optimizer is not None:
            optimizer.zero_grad()  # Free gradients tensors
            
        gc.enable()  # Enable GC again
        gc.collect()  # Manual GC
        # elapsed_times.append(end_time - start_time)
        elapsed_times.append(start.elapsed_time(end) * 1e-3)

    # Discard burnin iterations and compute averages
    elapsed_times = elapsed_times[burnin_iterations:]
    gpu_memory_peaks = gpu_memory_peaks[burnin_iterations:]
    return elapsed_times, gpu_memory_peaks


if __name__ == "__main__":
    # Benchmark a single model configuration and store the outcome as a
    # one-row CSV file inside the directory given by --path.
    args = parser.parse_args()
    set_global_seed(args.seed)

    args.input_shape = tuple(args.input_shape)
    logger = Logger("benchmark", verbose=True)
    batch_shape = torch.Size([args.batch_size, *args.input_shape])
    device = torch.device(args.device)

    metadata = {
        "type": "image",
        "interval": [0, args.num_units - 1],
        "image_shape": args.input_shape,
        "num_variables": np.prod(args.input_shape).item(),
    }

    model = None
    optimizer = None
    num_params = -1
    # NaN marks a run that failed or that produced no measured iteration.
    mu_time, peak_gpu_memory = np.nan, np.nan
    try:
        model = setup_model(
            args.model,
            metadata,
            logger,
            region_graph="qt",
            num_components=1,
            num_units=args.num_units,
            mono_num_units=args.mono_num_units,
            mono_clamp=args.model in ["MPC", "ExpSOS"],
            complex=args.complex,
            seed=args.seed,
            use_tucker=args.use_tucker,
            benchmark=True,  # Does not instantiate the circuit for log Z
        )
        model = cast(PC, model)

        # NOTE(review): --optimizer always has a default, so this condition is
        # always true and the backward pass is always benchmarked; --backprop is
        # not consulted. Confirm whether that is intended.
        if args.optimizer is not None:
            optimizer = setup_optimizer(
                model.parameters(), args.optimizer, learning_rate=1e-4
            )
        num_params = num_parameters(model)

        elapsed_times, gpu_memory_peaks = run_benchmark(
            model,
            batch_shape,
            device=device,
            num_iterations=args.num_iterations,
            burnin_iterations=args.burnin_iterations,
            optimizer=optimizer,
            partition_function_only=False,
        )
        # With no measured iterations there is nothing to aggregate; keep NaN
        # instead of letting np.max fail on an empty list.
        if elapsed_times:
            mu_time = np.mean(elapsed_times)
            peak_gpu_memory = np.max(gpu_memory_peaks)
    except RuntimeError as exc:
        # CUDA out-of-memory errors derive from RuntimeError, so they are
        # handled here too. The run is still recorded (with NaN measurements),
        # but the reason is reported rather than silently dropped.
        print(f"Benchmark failed, recording NaN results: {exc!r}")

    # Drop the references to the model and the optimizer state.
    del model, optimizer

    benchmark_results = {
        "model": args.model,
        "exp_alias": (
            ("complex" if args.complex else "real")
            if "SOS" in args.model
            else ""
        ),
        "time": mu_time,
        "gpu_memory": peak_gpu_memory,
        "num_components": 1,
        "num_units": args.num_units,
        "optimizer": args.optimizer,
        "batch_size": args.batch_size,
        "input_shape": ",".join(str(x) for x in args.input_shape),
        "seed": args.seed,
        "num_params": num_params,
        "use_tucker": args.use_tucker,
    }

    path = args.path
    os.makedirs(path, exist_ok=True)
    # The file name encodes the configuration only, not the measured values.
    measured_keys = ("time", "gpu_memory", "num_params")
    filename = "-".join(
        f"{k}_{v}" for k, v in benchmark_results.items() if k not in measured_keys
    )
    filepath = os.path.join(path, f"{filename}.csv")
    pd.DataFrame([benchmark_results]).to_csv(filepath)
