# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import numpy as np
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM


from transformers import AutoConfig, AutoModelForCausalLM
import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch



def load_safetensors_model(model_path):
    """Load a causal language model and prepare it for half-precision CUDA inference.

    The model is created with ``AutoModelForCausalLM.from_pretrained``,
    switched to eval mode, cast to ``torch.float16`` and then moved to the
    ``'cuda'`` device.

    Args:
        model_path: Location handed to ``from_pretrained``.

    Returns:
        The eval-mode, float16 model placed on ``'cuda'``.
    """
    print("Loading original model")
    loaded = AutoModelForCausalLM.from_pretrained(model_path)
    loaded.eval()
    # Cast first, then transfer, exactly as two separate steps.
    half_precision = loaded.to(torch.float16)
    return half_precision.to('cuda')


# Per-family loading settings.
#   lm_head_on_last_gpu: put ``lm_head`` next to the final norm instead of on GPU 0.
#   dtype: value forwarded to ``load_checkpoint_and_dispatch``.
_MODEL_FAMILY_SPECS = {
    # lm_head stays on GPU 0 next to the embeddings to avoid a tie_weights error.
    # dtype can be changed to torch.bfloat16 or torch.float16 as needed.
    "gemma2": {"lm_head_on_last_gpu": False, "dtype": None},
    "llama": {"lm_head_on_last_gpu": True, "dtype": torch.float16},
    # Qwen is assumed to follow the Gemma2 placement strategy; adjust if different.
    "qwen": {"lm_head_on_last_gpu": False, "dtype": torch.float16},
}

# Version-specific spellings accepted for the families above.
_MODEL_TYPE_ALIASES = {
    "llama2": "llama",
    "llama3": "llama",
    "qwen2": "qwen",
    "qwen3": "qwen",
}


def _build_device_map(num_layers, num_gpus, lm_head_on_last_gpu):
    """Map embeddings, decoder layers, final norm and lm_head to GPU indices."""
    # Never spread over more GPUs than there are layers; this also keeps
    # layers_per_gpu >= 1 so the division below cannot hit zero.
    usable_gpus = max(1, min(num_gpus, num_layers))
    last_gpu = usable_gpus - 1
    layers_per_gpu = max(1, num_layers // usable_gpus)

    device_map = {
        "model.embed_tokens": 0,
        "lm_head": last_gpu if lm_head_on_last_gpu else 0,
        "model.norm": last_gpu,
        "model.rotary_emb": last_gpu,
    }
    for i in range(num_layers):
        # Clamp so the remainder layers (when num_layers is not a multiple of
        # the GPU count) land on the last GPU instead of a non-existent one.
        device_map[f"model.layers.{i}"] = min(i // layers_per_gpu, last_gpu)
    return device_map


def load_safetensors_model_parallel(model_path, model_type="llama3", num_gpus=None):
    """
    General loader for safetensors weight models, supporting gemma2, llama, and qwen architectures.
    Distributes the decoder layers across multiple GPUs in contiguous chunks.

    Args:
        model_path (str): Path to the model
        model_type (str): Model family, case-insensitive: "gemma2", "llama" or "qwen".
            The aliases "llama2", "llama3", "qwen2" and "qwen3" are also accepted.
            Default is "llama3".
        num_gpus (int or None): Number of GPUs to use. If None, auto-detects available GPUs.

    Returns:
        model (torch.nn.Module): Loaded and device-distributed model in eval mode

    Raises:
        RuntimeError: If no GPU is available (``num_gpus`` resolves to 0).
        ValueError: If ``num_gpus`` is negative or ``model_type`` is unsupported.
    """

    model_type = model_type.lower()

    if num_gpus is None:
        num_gpus = torch.cuda.device_count()
    if num_gpus < 0:
        raise ValueError(f"num_gpus must be a positive integer, got {num_gpus}")
    if num_gpus == 0:
        raise RuntimeError("No GPU devices detected. Cannot distribute model.")

    family = _MODEL_TYPE_ALIASES.get(model_type, model_type)
    spec = _MODEL_FAMILY_SPECS.get(family)
    if spec is None:
        raise ValueError(f"Unsupported model type: {model_type}, only support ['gemma2', 'llama', 'qwen']")

    config = AutoConfig.from_pretrained(model_path)

    # Build the module skeleton without allocating real weight storage.
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)

    model.tie_weights()

    decoder_layers = model.model.layers
    num_layers = len(decoder_layers)
    device_map = _build_device_map(num_layers, num_gpus, spec["lm_head_on_last_gpu"])

    # accelerate expects class *names* here; taking them from the instantiated
    # model avoids importing architecture-specific decoder classes.
    no_split_module_classes = sorted({type(layer).__name__ for layer in decoder_layers})

    print(f"Device map for {model_type}:", device_map)

    model = load_checkpoint_and_dispatch(
        model,
        model_path,
        device_map=device_map,
        no_split_module_classes=no_split_module_classes,
        dtype=spec["dtype"],
    )

    model.eval()
    return model


def print_model_summary(model: nn.Module):
    """
    Print every submodule, a per-type module tally, and per-parameter details.

    The report lists each entry of ``model.named_modules()`` with its class
    name, then how many modules of each class were seen, then each entry of
    ``model.named_parameters()`` with shape, dtype and byte size, and finally
    the total parameter count and the summed parameter size in MB.

    Args:
        model: PyTorch model to describe
    """
    type_tally = {}
    for module_name, submodule in model.named_modules():
        kind = type(submodule).__name__
        type_tally[kind] = type_tally.get(kind, 0) + 1
        print(f"Layer: {module_name} | Type: {kind}")

    print("\nSummary of model layer counts:")
    for kind, occurrences in type_tally.items():
        print(f"{kind}: {occurrences}")

    print("\nSummary of model parameters:")

    param_total = 0
    byte_total = 0
    for param_name, tensor in model.named_parameters():
        n_elements = tensor.numel()
        n_bytes = n_elements * tensor.element_size()  # elements x bytes per element
        param_total += n_elements
        byte_total += n_bytes
        print(f"Name: {param_name} | Shape: {tensor.shape} | Data type: {tensor.dtype} | Bytes: {n_bytes} bytes")

    print(f"\nTotal number of parameters: {param_total:,}")
    size_in_mb = byte_total / (1024 ** 2)  # bytes -> MB
    print(f"Estimated model size: {size_in_mb:.2f} MB")


# Save quantized model
def save_model(model, save_path):
    """Serialize the model's ``state_dict`` to ``save_path`` and report where it went.

    Args:
        model: Object exposing ``state_dict()``.
        save_path: Destination passed to ``torch.save``.
    """
    weights = model.state_dict()
    torch.save(weights, save_path)
    print(f"Model saved to: {save_path}")

