import argparse
import os
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import scienceplots  # noqa
import torch
import torch.nn as nn
import torchvision
from litgpt.config import Config
from matplotlib.cm import viridis

from layer_freeze.model_agnostic_freezing import FrozenModel
from saws.model import GPT_Scales, GPT_Scales_Detached

# Publication-quality defaults: apply the "science" style first, then override
# individual rcParams on top of it (the overrides are applied in the order listed).
plt.style.use("science")
plt.rcParams.update(
    {
        # Fonts, grid and figure geometry
        "font.family": "serif",
        "axes.grid": True,
        "grid.alpha": 0.3,
        "figure.figsize": (5, 4),
        # Larger text for axis labels, titles, ticks and legend
        "axes.labelsize": 20,
        "axes.titlesize": 20,
        "xtick.labelsize": 20,
        "ytick.labelsize": 20,
        "legend.fontsize": 20,
        # Marker size and line thickness shared by every plotted line
        "lines.markersize": 5,
        "lines.linewidth": 1.5,
    }
)

# Colormaps available for colouring multiple models
COLOR_MAPS = [viridis]


# Display names for the supported model identifiers; looked up by
# plot_models_in_single_plot to build the legend labels.
MODEL_NAME_MAPS = {
    "resnet18": "ResNet-18",
    "resnet34": "ResNet-34",
    "resnet50": "ResNet-50",
    "resnet101": "ResNet-101",
    "resnet152": "ResNet-152",
    "open_llama_3b": "Open-LLaMA-3B",
    "pythia_1.4b": "Pythia-1.4B",
    "l24": "GPT-2 (112M)",
    "14m": "GPT-2 (14M)",
}

# Constructor arguments for the transformer models, keyed by model name.
# get_model unpacks each entry as keyword arguments into GPT_Scales_Detached,
# so every entry supplies "config", "share_embeddings" and "mup_init".
configs = {
    # Labelled "GPT-2 (112M)" in MODEL_NAME_MAPS.
    "l24": {
        "config": Config(
            n_embd=512,
            n_layer=24,
            n_head=16,
            block_size=1024,
            vocab_size=50257,
            norm_class_name="LayerNorm",
            bias=True,
        ),
        "share_embeddings": True,
        "mup_init": False,
    },
    # Labelled "GPT-2 (14M)" in MODEL_NAME_MAPS.
    "14m": {
        "config": Config(
            n_embd=128,
            n_layer=8,
            n_head=2,
            block_size=1024,
            vocab_size=50257,
            norm_class_name="LayerNorm",
            bias=True,
        ),
        "share_embeddings": True,
        "mup_init": False,
    },
    # Labelled "Open-LLaMA-3B" in MODEL_NAME_MAPS.
    "open_llama_3b": {
        "config": Config(
            block_size=2048,
            vocab_size=32000,
            padding_multiple=64,
            n_layer=26,
            n_embd=3200,
            rotary_percentage=1.0,
            parallel_residual=False,
            bias=False,
            norm_class_name="RMSNorm",
            norm_eps=1e-6,
            mlp_class_name="LLaMAMLP",
            # NOTE(review): the value is explicitly halved (8640 // 2 == 4320);
            # presumably intentional for this model class -- confirm.
            intermediate_size=8640 // 2,
        ),
        "share_embeddings": True,
        "mup_init": False,
    },
    # Labelled "Pythia-1.4B" in MODEL_NAME_MAPS.
    "pythia_1.4b": {
        "config": Config(
            block_size=2048,
            n_layer=24,
            n_embd=2048,
            n_head=16,
            padding_multiple=128,
        ),
        "share_embeddings": True,
        "mup_init": False,
    },
}


def get_model(model_name):
    """Build a new model instance for a supported model name.

    ResNet names are built through the torchvision constructor of the same name
    with ``weights=None``; the transformer names are built by unpacking the
    matching ``configs`` entry into ``GPT_Scales_Detached``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    resnet_names = ("resnet18", "resnet34", "resnet50", "resnet101", "resnet152")
    transformer_names = ("open_llama_3b", "pythia_1.4b", "l24", "14m")

    for known in resnet_names:
        if model_name == known:
            # The torchvision constructor carries the same name as the model.
            build = getattr(torchvision.models, known)
            return build(weights=None)

    for known in transformer_names:
        if model_name == known:
            return GPT_Scales_Detached(**configs[known])

    raise ValueError(f"Model {model_name} not supported")


def get_trainable_params(model):
    """Return the total element count of the parameters that require gradients."""
    count = 0
    for param in model.parameters():
        if param.requires_grad:
            count += param.numel()
    return count


def get_total_params(model):
    """Return the total element count of all parameters, trainable or not."""
    count = 0
    for param in model.parameters():
        count += param.numel()
    return count


def get_max_fidelity(model):
    """
    Return the ``max_fidelity`` reported by a ``FrozenModel`` wrapped around ``model``.

    The wrapper is built with a single trainable layer purely to read that
    attribute; the wrapper itself is discarded.
    """
    # Module types passed as ``unwrap``: the model's own class plus the container types.
    unwrap_types = (model.__class__, nn.ModuleDict, nn.ModuleList, GPT_Scales)
    return FrozenModel(
        n_trainable=1,
        base_model=model,
        print_summary=False,
        unwrap=unwrap_types,
    ).max_fidelity


def get_csv_path(model_name, data_dir="plots/params_layers"):
    """Return the CSV path for ``model_name`` inside ``data_dir``.

    The directory (including missing parents) is created as a side effect, so
    the returned path is always writable-to even when the file does not exist yet.
    """
    target_dir = Path(data_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    return target_dir.joinpath(f"{model_name}_params_vs_layers.csv")


def save_data_to_csv(df, model_name, data_dir="plots/params_layers"):
    """Write ``df`` (without its index) to the model's CSV file and return the path."""
    destination = get_csv_path(model_name, data_dir)
    df.to_csv(destination, index=False)
    print(f"Data saved to {destination}")
    return destination


def load_data_from_csv(model_name, data_dir="plots/params_layers"):
    """Return the model's cached CSV as a DataFrame, or ``None`` if no file exists."""
    source = get_csv_path(model_name, data_dir)
    if not source.exists():
        return None

    print(f"Loading existing data from {source}")
    return pd.read_csv(source)


def generate_data_with_frozen_model(
    model_name, force_regenerate=False, data_dir="plots/params_layers"
):
    """
    Measure how many parameters are trainable for each number of trainable layers.

    For every ``n_trainable`` from 1 to the model's ``max_fidelity`` a fresh model
    is wrapped in ``FrozenModel`` and its trainable parameters are counted. The
    result is saved as a CSV in ``data_dir``; later calls load that CSV instead of
    measuring again unless ``force_regenerate`` is set.

    Args:
        model_name: Model name understood by ``get_model``.
        force_regenerate: If True, ignore an existing CSV and measure again.
        data_dir: Directory that holds the per-model CSV files.

    Returns:
        DataFrame with one row per ``n_trainable`` and the columns
        ``n_trainable_layers``, ``n_trainable_params``, ``n_total_params``,
        ``perc_trainable_params`` and ``max_layers``. When loaded from an
        existing CSV, the file's contents are returned as-is.

    Raises:
        ValueError: Propagated from ``get_model`` for unsupported model names.
    """
    columns = [
        "n_trainable_layers",
        "n_trainable_params",
        "n_total_params",
        "perc_trainable_params",
        "max_layers",
    ]

    # Reuse cached measurements when allowed
    if not force_regenerate:
        cached = load_data_from_csv(model_name, data_dir)
        if cached is not None:
            return cached

    print(f"Generating data for {model_name} (this may take a while)...")

    # A reference instance is only needed for the total parameter count and the
    # maximum fidelity; release it before the loop so it does not sit in memory
    # next to the per-iteration models.
    reference_model = get_model(model_name)
    total_params = get_total_params(reference_model)
    max_fidelity = get_max_fidelity(reference_model)
    del reference_model

    rows = []

    for n_trainable in range(1, max_fidelity + 1):
        print(f"Measuring with n_trainable={n_trainable}/{max_fidelity}")

        # Build a fresh model for every setting so each measurement starts from
        # an untouched instance.
        model = get_model(model_name)
        frozen_model = FrozenModel(
            n_trainable=n_trainable,
            base_model=model,
            print_summary=False,
            unwrap=(model.__class__, nn.ModuleDict, nn.ModuleList, GPT_Scales),
        )

        trainable_params = get_trainable_params(frozen_model)
        perc_trainable = (trainable_params / total_params) * 100

        rows.append(
            {
                "n_trainable_layers": n_trainable,
                "n_trainable_params": trainable_params,
                "n_total_params": total_params,
                "perc_trainable_params": perc_trainable,
                "max_layers": max_fidelity,
            }
        )

        # Drop this iteration's model before the next one is allocated
        del frozen_model
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Passing the columns explicitly keeps the header even when no rows were
    # measured, so the saved CSV can still be read back on a later run.
    df = pd.DataFrame(rows, columns=columns)
    save_data_to_csv(df, model_name, data_dir)

    return df


def plot_models_in_single_plot(model_dfs, output_dir):
    """
    Plot every model's trainable-parameter percentage in one figure and save it as a PDF.

    The x-axis is normalised per model to the percentage of trainable layers
    (``n_trainable_layers`` divided by the model's ``max_layers``), so models of
    different depth share one axis. The figure is written to
    ``<output_dir>/params_vs_layers/params_vs_layers_<model names joined by "_">.pdf``.

    The plot is drawn on its own figure, which is closed once saved, so repeated
    calls do not draw on top of each other or accumulate open figures.

    Args:
        model_dfs: Dictionary mapping model names to DataFrames with the columns
            ``n_trainable_layers``, ``max_layers`` and ``perc_trainable_params``.
            Every name must be a key of ``MODEL_NAME_MAPS``.
        output_dir: Directory under which the ``params_vs_layers`` folder is created.

    Raises:
        KeyError: If a model name has no entry in ``MODEL_NAME_MAPS``.
    """
    markers = ("o", "s", "^", "D", "v")
    # Spread the models over the full colormap; with a single model the divisor
    # falls back to 1 so the division stays defined.
    color_steps = max(len(model_dfs) - 1, 1)

    fig, ax = plt.subplots()
    try:
        for i, (model_name, df) in enumerate(model_dfs.items()):
            # Normalize x-axis to percentage of total layers
            x_data = (df["n_trainable_layers"] / df["max_layers"].iloc[0]) * 100

            ax.plot(
                x_data,
                df["perc_trainable_params"],
                marker=markers[i % len(markers)],  # cycle markers across models
                linestyle="-",
                color=viridis(i / color_steps),
                label=MODEL_NAME_MAPS[model_name],
            )

        # The escaped percent signs are kept exactly as in the axis label strings.
        ax.set_xlabel(r"Trainable Layers (\%)")
        ax.set_ylabel(r"Trainable Param. (\%)")

        ax.legend(loc="best")
        ax.grid(True, linestyle="--")

        # Set x-axis to 0-100%
        ax.set_xlim(0, 100)

        # Ensure directory exists
        output_path = Path(output_dir) / "params_vs_layers"
        output_path.mkdir(parents=True, exist_ok=True)

        # Create a filename that includes the model names
        model_names_str = "_".join(model_dfs)
        filepath = output_path / f"params_vs_layers_{model_names_str}.pdf"
        fig.savefig(filepath, bbox_inches="tight", dpi=500)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)


def parse_args():
    """Parse the command-line options of the script and return the namespace.

    Options: ``--models`` (one or more names), ``--output_dir``, ``--data_dir``
    and the ``--force_regenerate`` flag.
    """
    cli = argparse.ArgumentParser(
        description="Plot percentage of trainable parameters vs number of trainable layers"
    )

    # (flag, keyword arguments) pairs, registered in this order.
    option_table = (
        (
            "--models",
            {
                "nargs": "+",
                "default": ["resnet18"],
                "help": "Model names (e.g., resnet18 resnet50)",
            },
        ),
        (
            "--output_dir",
            {"default": "plots", "help": "Directory to save the plot"},
        ),
        (
            "--data_dir",
            {
                "default": "plots/params_layers",
                "help": "Directory to save/load generated data",
            },
        ),
        (
            "--force_regenerate",
            {
                "action": "store_true",
                "help": "Force regeneration of data even if CSV exists",
            },
        ),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()


def main():
    """Collect the data for every requested model and plot them together.

    Models whose DataFrame is empty are reported and left out; if nothing is
    left, an error is printed and no plot is produced.
    """
    cli_args = parse_args()

    frames_by_model = {}
    for name in cli_args.models:
        # Measured with FrozenModel, or loaded from the CSV cache
        frame = generate_data_with_frozen_model(
            name,
            force_regenerate=cli_args.force_regenerate,
            data_dir=cli_args.data_dir,
        )

        if frame.empty:
            print(f"Error: No valid data for {name}")
            continue

        frames_by_model[name] = frame
        print(f"Data available for {name}")

    if frames_by_model:
        # All models share a single plot
        plot_models_in_single_plot(frames_by_model, cli_args.output_dir)
    else:
        print("Error: No valid data generated for any model")


# Run the command-line entry point only when the file is executed as a script.
if __name__ == "__main__":
    main()
