import argparse
import csv
import gc
import itertools
import os
import time
from dataclasses import dataclass

import git
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

from layer_freeze.model_agnostic_freezing import FrozenModel


@dataclass
class ModelPerfStats:
    """Summary statistics of training-step measurements for one model configuration.

    Filled in by ``measure_model_training_hw_metrics`` and written out, one CSV row
    per configuration, by ``save_stats_to_csv``. Timings are in milliseconds (as
    returned by ``torch.cuda.Event.elapsed_time``); memory is in MiB.
    """

    # Median / std of the forward pass (model call + cross-entropy loss), in ms.
    median_fwd_time: float
    std_fwd_time: float
    # Median / std of ``loss.backward()``, in ms.
    median_bwd_time: float
    std_bwd_time: float
    # Median of one full iteration (host-to-device copy, fwd, bwd, optimizer), in ms.
    median_loop_time: float
    # Device memory recorded during measurement, in MiB (bytes / 1024**2).
    memory: float
    # Fraction (0..1) of parameter elements that have requires_grad=True.
    frac_trainable_params: float
    # Median of ``optimizer.step()`` + ``optimizer.zero_grad()``, in ms.
    optimizer_time: float


def frac_trainable_params(model: nn.Module) -> float:
    """Return the fraction of ``model``'s parameter elements that require gradients.

    Elements (``numel``) are counted rather than parameter tensors, so a frozen
    bias weighs less than a frozen weight matrix. A model without any parameters
    yields ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    total = 0
    trainable = 0
    # Single pass over the parameters instead of iterating them twice.
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    return trainable / total if total else 0.0


class _RegionTimer:
    """Measure the wall-clock duration of a code region in milliseconds.

    On a CUDA device a pair of ``torch.cuda.Event`` objects is recorded around the
    region, and ``stop`` synchronizes before reading the elapsed time so that
    asynchronously launched kernels are included. On any other device
    ``time.perf_counter`` is used instead, because CUDA events are not usable there
    (e.g. on CPU-only PyTorch builds).
    """

    def __init__(self, device: torch.device) -> None:
        self._use_cuda = device.type == "cuda"
        self._start_event = None
        self._start_time = 0.0

    def start(self) -> None:
        """Mark the beginning of the region."""
        if self._use_cuda:
            self._start_event = torch.cuda.Event(enable_timing=True)
            self._start_event.record()
        else:
            self._start_time = time.perf_counter()

    def stop(self) -> float:
        """Return the milliseconds elapsed since the last ``start`` call."""
        if self._use_cuda:
            end_event = torch.cuda.Event(enable_timing=True)
            end_event.record()
            torch.cuda.synchronize()
            return self._start_event.elapsed_time(end_event)
        return (time.perf_counter() - self._start_time) * 1000.0


def measure_model_training_hw_metrics(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    warmup_passes: int = 100,
    device: torch.device | None = None,
) -> ModelPerfStats:
    """Measure the model's forward, backward, optimizer and full-step times.

    First runs ``warmup_passes`` untimed training steps, then one timed training
    step (cross-entropy loss) per batch of ``dataloader``.

    Args:
        model: Model to train. It is moved to ``device`` (``nn.Module.to`` moves
            parameters in place).
        dataloader: Yields ``(inputs, targets)`` batches.
        optimizer: Optimizer for ``model``'s parameters. It is stepped, so the
            model's weights change during measurement.
        warmup_passes: Number of untimed steps run before measuring.
        device: Device to run on; defaults to CPU.

    Returns:
        ModelPerfStats with times in milliseconds: medians for every phase, plus
        std for forward/backward. ``memory`` is the peak allocated CUDA memory
        during the timed steps in MiB, or ``nan`` on non-CUDA devices where it is
        not tracked.
    """
    if device is None:
        device = torch.device("cpu")
    use_cuda = device.type == "cuda"

    model = model.to(device)

    # Untimed warm-up so one-off costs (lazy initialisation, allocator growth,
    # state created on the first step of stateful optimizers) stay out of the
    # measurements. islice runs exactly ``warmup_passes`` steps, or fewer if the
    # loader is shorter.
    warmup_batches = itertools.islice(dataloader, max(warmup_passes, 0))
    for x, y in tqdm(warmup_batches, desc="Warming up model", total=warmup_passes):
        x = x.to(device)
        y = y.to(device)
        loss = torch.nn.functional.cross_entropy(model(x), y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    fwd_times: list[float] = []
    bwd_times: list[float] = []
    loop_times: list[float] = []
    optimizer_times: list[float] = []

    if use_cuda:
        # Peak stats are per process; reset so the peak reflects only this
        # configuration's timed steps (not warm-up or previous configurations).
        torch.cuda.reset_peak_memory_stats(device)

    loop_timer = _RegionTimer(device)
    phase_timer = _RegionTimer(device)
    for x, y in tqdm(dataloader, desc="Measuring forward times"):
        loop_timer.start()
        x = x.to(device)
        y = y.to(device)

        phase_timer.start()
        pred = model(x)
        loss = torch.nn.functional.cross_entropy(pred, y)
        fwd_times.append(phase_timer.stop())

        phase_timer.start()
        loss.backward()
        bwd_times.append(phase_timer.stop())

        phase_timer.start()
        optimizer.step()
        optimizer.zero_grad()
        optimizer_times.append(phase_timer.stop())

        loop_times.append(loop_timer.stop())

    # max_memory_allocated captures transient activations/gradients that a
    # memory_allocated() sample taken after zero_grad() would miss.
    memory_mib = (
        torch.cuda.max_memory_allocated(device) / (1024**2) if use_cuda else float("nan")
    )

    return ModelPerfStats(
        median_fwd_time=np.median(fwd_times),
        std_fwd_time=np.std(fwd_times),
        median_bwd_time=np.median(bwd_times),
        std_bwd_time=np.std(bwd_times),
        median_loop_time=np.median(loop_times),
        memory=memory_mib,
        frac_trainable_params=frac_trainable_params(model),
        optimizer_time=np.median(optimizer_times),
    )


def get_max_fidelity(model: nn.Module, extra_unwrap=None) -> float:
    """Return the ``max_fidelity`` that FrozenModel reports for ``model``.

    A throwaway FrozenModel with a single trainable unit is constructed only to
    read this value. ``model``'s own class is always unwrapped; ``extra_unwrap``,
    when given, is unwrapped in addition (e.g. a nested container class).
    """
    unwrap_targets = model.__class__
    if extra_unwrap is not None:
        unwrap_targets = (unwrap_targets, extra_unwrap)

    probe = FrozenModel(
        base_model=model,
        n_trainable=1,
        unwrap=unwrap_targets,
        print_summary=False,
    )
    return probe.max_fidelity


def get_model(model: str) -> nn.Module:
    """Build a randomly initialised (``weights=None``) torchvision model by name.

    Raises:
        ValueError: if ``model`` is not one of the supported architectures.
    """
    supported = (
        "resnet18",
        "resnet34",
        "resnet50",
        "resnet101",
        "resnet152",
        "vit_b_16",
        "vit_l_16",
    )
    if model not in supported:
        raise ValueError(f"Model {model} not supported")

    # Each supported name is also the torchvision.models builder function's name.
    builder = getattr(torchvision.models, model)
    return builder(weights=None)


def save_stats_to_csv(stats: list[ModelPerfStats], model_name: str, batch_size: int) -> str:
    """Write ``stats`` to ``<repo root>/output/model_perf/csv_data/{model_name}_bs{batch_size}.csv``.

    The file is overwritten if it already exists. Row ``k`` (0-based) of ``stats``
    is labelled ``n_trainable = k + 1``, matching the script's sweep, which starts
    at ``n_trainable = 1``.

    Returns:
        Path of the written CSV file.
    """
    stat_columns = [
        "frac_trainable_params",
        "median_fwd_time",
        "std_fwd_time",
        "median_bwd_time",
        "std_bwd_time",
        "median_loop_time",
        "memory",
        "optimizer_time",
    ]
    columns = ["n_trainable", *stat_columns]

    repo_root = git.Repo(".", search_parent_directories=True).working_tree_dir
    out_dir = os.path.join(repo_root, "output", "model_perf", "csv_data")
    os.makedirs(out_dir, exist_ok=True)
    csv_path = os.path.join(out_dir, f"{model_name}_bs{batch_size}.csv")

    with open(csv_path, "w", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        for n_trainable, stat in enumerate(stats, start=1):
            row = {"n_trainable": n_trainable}
            row.update({column: getattr(stat, column) for column in stat_columns})
            writer.writerow(row)

    print(f"Data saved to {csv_path}")
    return csv_path


if __name__ == "__main__":
    # Collect per-n_trainable training metrics for one model / batch size and save them as CSV.
    from torchvision.models.vision_transformer import Encoder

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Model to collect metrics for (e.g., resnet18, vit_b_16)",
    )
    parser.add_argument(
        "--batch_size", type=int, required=True, help="Batch size to use for data collection"
    )
    parser.add_argument(
        "--num_dataloader_workers", type=int, default=8, help="Number of workers for data loading"
    )
    parser.add_argument("--device", type=str, default="cuda", help="Device to run on (cuda or cpu)")
    parser.add_argument(
        "--warmup_passes", type=int, default=100, help="Number of warmup passes before measurement"
    )
    args = parser.parse_args()

    print(f"Collecting metrics for {args.model} with batch size {args.batch_size}")

    # CIFAR-10 images, resized and center-cropped to 224x224 inputs.
    transform = transforms.Compose(
        [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]
    )
    dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
    # Use ~200 batches, capped at the 50,000-image training split.
    dataset = torch.utils.data.Subset(dataset, range(min(200 * args.batch_size, 50_000)))
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_dataloader_workers,
        # DataLoader raises if persistent_workers=True is combined with num_workers=0.
        persistent_workers=args.num_dataloader_workers > 0,
    )

    # Both supported ViTs (vit_b_16, vit_l_16) nest their transformer blocks in an
    # ``Encoder`` module, which is unwrapped too -- presumably so FrozenModel can
    # reach the individual encoder layers.
    is_vit = args.model.startswith("vit_")

    # Get the model and measure its performance
    base_model = get_model(args.model)
    max_fidelity = get_max_fidelity(base_model, Encoder if is_vit else None)
    unwrap = (Encoder, base_model.__class__) if is_vit else base_model.__class__

    device = torch.device(args.device)
    stats = []
    for n_trainable in range(1, max_fidelity):
        print(f"Measuring with n_trainable={n_trainable}/{max_fidelity - 1}")
        frozen_model = FrozenModel(
            n_trainable=n_trainable,
            base_model=base_model,
            print_summary=True,
            unwrap=unwrap,
        )
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, frozen_model.parameters()), lr=0.001
        )
        stats.append(
            measure_model_training_hw_metrics(
                frozen_model, dataloader, optimizer, device=device, warmup_passes=args.warmup_passes
            )
        )
        # Drop references and collect garbage first so the CUDA caching allocator
        # can actually release the freed blocks afterwards.
        del frozen_model
        del optimizer
        gc.collect()
        if device.type == "cuda":
            # ipc_collect() initialises CUDA and fails on CPU-only setups.
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    # Save the collected stats to CSV
    csv_path = save_stats_to_csv(stats, args.model, args.batch_size)

    print(f"Data collection complete. Results saved to: {csv_path}")
