import argparse
import csv
import gc
import itertools
import os
import time
from dataclasses import dataclass

import git
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
from litgpt.config import Config
from torch.utils.data import DataLoader
from tqdm import tqdm

from layer_freeze.model_agnostic_freezing import FrozenModel
from saws.model import GPT_Scales, GPT_Scales_Detached


@dataclass
class ModelPerfStats:
    """Summary statistics of one training-step benchmark run.

    Produced by `measure_model_training_hw_metrics`. Time fields are in
    milliseconds and `memory` is in MiB.
    """

    # Median and standard deviation of the forward pass (model call + loss).
    median_fwd_time: float
    std_fwd_time: float
    # Median and standard deviation of `loss.backward()`.
    median_bwd_time: float
    std_bwd_time: float
    # Median duration of a whole step (device copy, forward, backward, optimizer).
    median_loop_time: float
    # Device memory figure recorded while measuring, in MiB.
    memory: float
    # Share of parameter elements that have requires_grad=True.
    frac_trainable_params: float
    # Median duration of optimizer.step() followed by optimizer.zero_grad().
    optimizer_time: float


def frac_trainable_params(model: nn.Module) -> float:
    """Return the fraction of `model`'s parameter elements that require gradients.

    Parameters reachable through several submodules (e.g. tied embeddings) are
    counted once, because `nn.Module.parameters()` de-duplicates them.
    A model without any parameters yields 0.0 instead of raising
    ZeroDivisionError.
    """
    trainable = 0
    total = 0
    # Single pass over the parameters: accumulate both counts at once.
    for p in model.parameters():
        n = p.numel()
        total += n
        if p.requires_grad:
            trainable += n
    return trainable / total if total else 0.0


def _lm_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Cross-entropy of `logits` (last dim = classes) against integer `targets`.

    Every leading position is treated as an independent sample: logits of shape
    (..., C) are flattened to (N, C) and targets of matching leading shape to (N,).
    """
    return torch.nn.functional.cross_entropy(
        logits.reshape(-1, logits.size(-1)), targets.reshape(-1)
    )


def _start_timer(use_cuda: bool, device: torch.device):
    """Open a timing interval and return an opaque token for `_elapsed_ms`."""
    if use_cuda:
        event = torch.cuda.Event(enable_timing=True)
        # Record on the stream of the benchmarked device, not whatever device is current.
        event.record(torch.cuda.current_stream(device))
        return event
    return time.perf_counter()


def _elapsed_ms(start, use_cuda: bool, device: torch.device) -> float:
    """Close the interval opened by `_start_timer` and return its length in ms.

    On CUDA this blocks until all queued work on `device` has finished, so the
    interval covers the kernels launched since `start`, not just their launch.
    """
    if use_cuda:
        end = torch.cuda.Event(enable_timing=True)
        end.record(torch.cuda.current_stream(device))
        torch.cuda.synchronize(device)
        return start.elapsed_time(end)
    return (time.perf_counter() - start) * 1000.0


def _warmup(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    warmup_passes: int,
    device: torch.device,
) -> None:
    """Run exactly `warmup_passes` full training steps (none if <= 0).

    `itertools.cycle` replays batches it has already produced, so a dataloader
    shorter than `warmup_passes` still gives the full count without re-spawning
    its workers. At most `warmup_passes` batches are kept alive for the replay,
    and they are released when this helper returns.
    """
    if warmup_passes <= 0:
        return
    for x, y in itertools.islice(itertools.cycle(dataloader), warmup_passes):
        loss = _lm_loss(model(x.to(device)), y.to(device))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()


def measure_model_training_hw_metrics(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    warmup_passes: int = 100,
    device: torch.device | None = None,
) -> ModelPerfStats:
    """Measure the model's forward, backward, optimizer and whole-step times.

    The model is moved to `device` and trained for `warmup_passes` steps. Then
    one full pass over `dataloader` is timed step by step. Every step is a real
    training step, so the model's weights and the optimizer state are updated.

    Args:
        model: Model whose output is scored with token-level cross-entropy
            against the dataloader's targets.
        dataloader: Yields `(inputs, targets)` batches.
        optimizer: Optimizer over (a subset of) `model`'s parameters.
        warmup_passes: Number of untimed training steps run before measuring.
        device: Device to run on (a `torch.device` or device string); defaults
            to CPU. CUDA devices are timed with CUDA events and other devices
            with `time.perf_counter`.

    Returns:
        ModelPerfStats with all times in milliseconds. `memory` is the peak
        CUDA memory allocated during the timed pass, in MiB; it is NaN on
        non-CUDA devices, where no comparable counter exists.
    """
    device = torch.device("cpu") if device is None else torch.device(device)
    use_cuda = device.type == "cuda"

    model = model.to(device)
    _warmup(model, dataloader, optimizer, warmup_passes, device)

    if use_cuda:
        # Drain warmup work, then track the allocation peak of the timed pass only.
        torch.cuda.synchronize(device)
        torch.cuda.reset_peak_memory_stats(device)

    fwd_times: list[float] = []
    bwd_times: list[float] = []
    loop_times: list[float] = []
    optimizer_times: list[float] = []
    for x, y in dataloader:
        loop_start = _start_timer(use_cuda, device)
        x = x.to(device)
        y = y.to(device)

        fwd_start = _start_timer(use_cuda, device)
        loss = _lm_loss(model(x), y)
        fwd_times.append(_elapsed_ms(fwd_start, use_cuda, device))

        bwd_start = _start_timer(use_cuda, device)
        loss.backward()
        bwd_times.append(_elapsed_ms(bwd_start, use_cuda, device))

        opt_start = _start_timer(use_cuda, device)
        optimizer.step()
        optimizer.zero_grad()
        optimizer_times.append(_elapsed_ms(opt_start, use_cuda, device))

        loop_times.append(_elapsed_ms(loop_start, use_cuda, device))

    # Peak rather than a post-step sample: after zero_grad the gradients and
    # activations are already gone, which hides exactly what freezing saves.
    memory = (
        torch.cuda.max_memory_allocated(device) / (1024**2) if use_cuda else float("nan")
    )
    return ModelPerfStats(
        median_fwd_time=np.median(fwd_times),
        std_fwd_time=np.std(fwd_times),
        median_bwd_time=np.median(bwd_times),
        std_bwd_time=np.std(bwd_times),
        median_loop_time=np.median(loop_times),
        memory=memory,
        frac_trainable_params=frac_trainable_params(model),
        optimizer_time=np.median(optimizer_times),
    )


def get_max_fidelity(model: nn.Module, extra_unwrap=None) -> float:
    """Return `FrozenModel(...).max_fidelity` for `model` with one trainable unit.

    The model's own class is always passed to FrozenModel as `unwrap`. When
    `extra_unwrap` is given, it is paired with that class in a 2-tuple.
    NOTE(review): presumably `max_fidelity` is the largest valid `n_trainable`;
    confirm against FrozenModel.
    """
    if extra_unwrap is None:
        unwrap = model.__class__
    else:
        unwrap = (model.__class__, extra_unwrap)
    probe = FrozenModel(
        n_trainable=1,
        base_model=model,
        print_summary=False,
        unwrap=unwrap,
    )
    return probe.max_fidelity


def save_stats_to_csv(
    stats: list[ModelPerfStats],
    model_name: str,
    batch_size: int,
) -> str:
    """Write `stats` as CSV rows under `<git repo root>/output/model_perf/csv_data`.

    The file is named `{model_name}_bs{batch_size}.csv` and overwritten if it
    exists. The i-th entry of `stats` (0-based) is labelled `n_trainable = i + 1`.
    Returns the path of the written file.
    """
    repo_root = git.Repo(".", search_parent_directories=True).working_tree_dir
    out_dir = os.path.join(repo_root, "output", "model_perf", "csv_data")
    os.makedirs(out_dir, exist_ok=True)
    csv_path = os.path.join(out_dir, f"{model_name}_bs{batch_size}.csv")

    # CSV column order after n_trainable; each name is also a ModelPerfStats attribute.
    stat_columns = [
        "frac_trainable_params",
        "median_fwd_time",
        "std_fwd_time",
        "median_bwd_time",
        "std_bwd_time",
        "median_loop_time",
        "memory",
        "optimizer_time",
    ]
    rows = (
        {"n_trainable": n_trainable, **{col: getattr(stat, col) for col in stat_columns}}
        for n_trainable, stat in enumerate(stats, start=1)
    )

    with open(csv_path, "w", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=["n_trainable", *stat_columns])
        writer.writeheader()
        writer.writerows(rows)

    print(f"Data saved to {csv_path}")
    return csv_path


class RandomDataset(data.Dataset):
    """Map-style dataset of `size` random (input_ids, target) token sequences.

    Each item is freshly sampled on access: two int64 tensors of shape
    (seq_length,) with values in [0, vocab_size).
    """

    def __init__(self, size=1000, seq_length=256, vocab_size=50257):
        self.size = size
        self.seq_length = seq_length
        self.vocab_size = vocab_size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Bounds check so plain iteration (`for item in ds`, `list(ds)`) stops:
        # without __iter__, Python iterates via __getitem__ until IndexError.
        if idx >= self.size or idx < -self.size:
            raise IndexError(f"index {idx} out of range for dataset of size {self.size}")
        # Random input ids and random target
        input_ids = torch.randint(0, self.vocab_size, (self.seq_length,))
        # Target should be a single class index for each token position
        target = torch.randint(0, self.vocab_size, (self.seq_length,))
        return input_ids, target


def _benchmark_entry(*, share_embeddings: bool, batch_size: int, **config_kwargs) -> dict:
    """Bundle a freshly built litgpt `Config` with the GPT_Scales benchmark options.

    Every call constructs its own `Config`, so no two entries share a config object.
    """
    return {
        "config": Config(**config_kwargs),
        "share_embeddings": share_embeddings,
        "mup_init": False,
        "batch_size": batch_size,
    }


# Architecture hyper-parameters shared by the tied and untied variant of each family.
_L24_ARCH = dict(
    n_embd=512,
    n_layer=24,
    n_head=16,
    block_size=1024,
    vocab_size=50257,
    norm_class_name="LayerNorm",
    bias=True,
)
_14M_ARCH = dict(
    n_embd=128,
    n_layer=8,
    n_head=2,
    block_size=1024,
    vocab_size=50257,
    norm_class_name="LayerNorm",
    bias=True,
)
_OPEN_LLAMA_3B_ARCH = dict(
    block_size=2048,
    vocab_size=32000,
    padding_multiple=64,
    n_layer=26,
    n_embd=3200,
    rotary_percentage=1.0,
    parallel_residual=False,
    bias=False,
    norm_class_name="RMSNorm",
    norm_eps=1e-6,
    mlp_class_name="LLaMAMLP",
    intermediate_size=8640 // 2,
)
_PYTHIA_1_4B_ARCH = dict(
    block_size=2048,
    n_layer=24,
    n_embd=2048,
    n_head=16,
    padding_multiple=128,
)

# Benchmark settings selectable via --config_name; "batch_size" is a per-model suggestion.
configs = {
    "l24_tied": _benchmark_entry(share_embeddings=True, batch_size=13, **_L24_ARCH),
    "l24_untied": _benchmark_entry(share_embeddings=False, batch_size=13, **_L24_ARCH),
    "14m_tied": _benchmark_entry(share_embeddings=True, batch_size=9, **_14M_ARCH),
    "14m_untied": _benchmark_entry(share_embeddings=False, batch_size=9, **_14M_ARCH),
    "open_llama_3b_untied": _benchmark_entry(
        share_embeddings=False, batch_size=1, **_OPEN_LLAMA_3B_ARCH
    ),
    "open_llama_3b_tied": _benchmark_entry(
        share_embeddings=True, batch_size=1, **_OPEN_LLAMA_3B_ARCH
    ),
    "pythia_1.4b_untied": _benchmark_entry(
        share_embeddings=False, batch_size=1, **_PYTHIA_1_4B_ARCH
    ),
    "pythia_1.4b_tied": _benchmark_entry(
        share_embeddings=True, batch_size=1, **_PYTHIA_1_4B_ARCH
    ),
}


# Module types FrozenModel descends into when splitting the model into freezable units.
_FREEZE_UNWRAP = (GPT_Scales, nn.ModuleList, nn.ModuleDict)


def _build_model(model_class, config, share_embeddings, mup_init):
    """Instantiate a GPT_Scales-family model set up for 1024-token sequences."""
    model = model_class(config=config, share_embeddings=share_embeddings, mup_init=mup_init)
    model.max_seq_length = 1024
    model.vocab_size = config.vocab_size
    return model


def main() -> None:
    """Benchmark a training step for every number of trainable units and save a CSV."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_dataloader_workers", type=int, default=8, help="Number of workers for data loading"
    )
    parser.add_argument("--device", type=str, default="cuda", help="Device to run on (cuda or cpu)")
    parser.add_argument(
        "--warmup_passes", type=int, default=100, help="Number of warmup passes before measurement"
    )
    parser.add_argument(
        "--config_name", type=str, default="l24_tied", choices=list(configs), help="Config to use"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size (each config also lists a suggested 'batch_size')",
    )
    parser.add_argument(
        "--num_batches", type=int, default=1, help="Number of random batches timed per setting"
    )
    args = parser.parse_args()

    spec = configs[args.config_name]
    config = spec["config"]
    share_embeddings = spec["share_embeddings"]
    mup_init = spec["mup_init"]
    batch_size = args.batch_size

    # Random token data sized to exactly `num_batches` batches.
    dataset = RandomDataset(
        size=batch_size * args.num_batches, seq_length=1024, vocab_size=config.vocab_size
    )
    dataloader = data.DataLoader(
        dataset, batch_size=batch_size, num_workers=args.num_dataloader_workers, shuffle=True
    )

    # A throwaway model, wrapped only to learn the upper bound for n_trainable.
    probe = _build_model(GPT_Scales, config, share_embeddings, mup_init)
    max_fidelity = FrozenModel(
        base_model=probe,
        n_trainable=1,
        unwrap=_FREEZE_UNWRAP,
        print_summary=False,
    ).max_fidelity
    del probe

    device = torch.device(args.device)
    stats = []
    for n_trainable in range(1, max_fidelity + 1):
        model_class = GPT_Scales_Detached if n_trainable != max_fidelity else GPT_Scales
        model = _build_model(model_class, config, share_embeddings, mup_init)
        print(f"Measuring with n_trainable={n_trainable}/{max_fidelity}")
        frozen_model = FrozenModel(
            n_trainable=n_trainable,
            base_model=model,
            print_summary=False,
            unwrap=_FREEZE_UNWRAP,
        )
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, frozen_model.parameters()), lr=0.001
        )
        stats.append(
            measure_model_training_hw_metrics(
                frozen_model, dataloader, optimizer, device=device, warmup_passes=args.warmup_passes
            )
        )
        # Drop every reference (the base model too) before collecting. gc.collect must run
        # before empty_cache so that freed tensors are actually returned to the driver.
        del frozen_model, optimizer, model
        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    # Save the collected stats to CSV
    csv_path = save_stats_to_csv(stats, f"gpt_scales_{args.config_name}", batch_size)

    print(f"Data collection complete. Results saved to: {csv_path}")


if __name__ == "__main__":
    main()
