import argparse
import json
import os
import shutil

from vlmq.quantization.utils.utils import cleanup_memory

from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import get_balanced_memory


def format_time(seconds):
    """Format a duration given in seconds as e.g. ``"1h-2min-3.45s"``.

    Hour and minute components are omitted when zero. The seconds component is
    omitted when zero unless it would be the only component, so
    ``format_time(3661.5) == "1h-1min-1.50s"``, ``format_time(120) == "2min"``
    and ``format_time(0) == "0.00s"``.

    The duration is rounded to hundredths of a second *before* it is split
    into components. Otherwise a value such as 119.999 would be rendered as
    "1min-60.00s".

    Args:
        seconds: Non-negative duration in seconds (int or float). Negative
            values are not specially handled.

    Returns:
        The formatted string.
    """
    # Work in whole centiseconds so carries into minutes/hours happen
    # exactly, before any rounding at display time.
    centis = int(round(seconds * 100))
    hours, centis = divmod(centis, 360_000)
    minutes, centis = divmod(centis, 6_000)

    parts = []
    if hours > 0:
        parts.append(f"{hours}h")
    if minutes > 0:
        parts.append(f"{minutes}min")
    # Always emit something: fall back to the seconds part for short durations.
    if centis > 0 or not parts:
        parts.append(f"{centis / 100:.2f}s")

    return "-".join(parts)


def distribute_model(model, no_split_module_classes=None, offload_dir="offload") -> None:
    """Dispatch ``model`` across the available devices with accelerate.

    Steps:
    1. ``get_balanced_memory`` computes a per-device memory budget.
    2. ``infer_auto_device_map`` turns that budget into a device map, which is
       printed for inspection.
    3. ``dispatch_model`` applies the map. Its return value is discarded, so
       callers keep using ``model`` itself afterwards.

    ``cleanup_memory()`` is called at the end.

    Args:
        model: The model to dispatch.
        no_split_module_classes: Names of module classes whose instances must
            stay whole on a single device. Defaults to
            ``['Qwen2VLDecoderLayer', 'Qwen2VisionTransformerPretrainedModel']``.
            Pass other class names for other architectures.
        offload_dir: Directory passed to ``dispatch_model`` as ``offload_dir``.
    """
    if no_split_module_classes is None:
        no_split_module_classes = ['Qwen2VLDecoderLayer', 'Qwen2VisionTransformerPretrainedModel']
    else:
        # Accept any iterable (e.g. a tuple) and never alias the caller's list.
        no_split_module_classes = list(no_split_module_classes)

    max_memory = get_balanced_memory(
        model,
        no_split_module_classes=no_split_module_classes,
    )

    device_map = infer_auto_device_map(
        model, max_memory=max_memory, no_split_module_classes=no_split_module_classes
    )

    print(device_map)
    dispatch_model(
        model,
        device_map=device_map,
        offload_buffers=True,
        offload_dir=offload_dir,
        state_dict=model.state_dict(),
    )

    cleanup_memory()
        
        
def save_args(args, tgt_path, filename="quant_cfg.json"):
    """Write the attributes of ``args`` as indented JSON to ``tgt_path/filename``.

    ``tgt_path`` is created if needed. Write failures (``OSError``) are
    reported with a printed message instead of being raised.

    Args:
        args: Object whose ``vars()`` are saved (typically an
            ``argparse.Namespace``).
        tgt_path: Output directory.
        filename: Name of the JSON file inside ``tgt_path``.

    Raises:
        TypeError: If an attribute value is not JSON-serializable. The
            serialization happens before the file is opened, so no truncated
            file is left behind in that case.
    """
    args_dict = vars(args)
    os.makedirs(tgt_path, exist_ok=True)
    json_path = os.path.join(tgt_path, filename)

    # Serialize up front: json.dump would stream into an already-opened file
    # and leave a partial config behind if a value turned out unserializable.
    payload = json.dumps(args_dict, ensure_ascii=False, indent=4)

    try:
        with open(json_path, 'w', encoding='utf-8') as f:
            f.write(payload)
    except OSError as e:
        print(f"Failed to save {filename} to {json_path}: {e}")
    else:
        print(f"Save {filename} to {json_path}")
        
        
def load_args(json_file_path):
    """Load a JSON object from ``json_file_path`` into an ``argparse.Namespace``.

    Each top-level key of the JSON object becomes an attribute of the returned
    namespace.

    This is best-effort: on a missing file, unreadable file, invalid JSON, or
    a top-level JSON value that is not an object, an error is printed and an
    empty ``argparse.Namespace`` is returned instead of raising.

    Args:
        json_file_path: Path of the JSON configuration file.

    Returns:
        The populated namespace, or an empty one on failure.
    """
    if not os.path.exists(json_file_path):
        print(f"Error: Configuration file '{json_file_path}' does not exist. Returning empty Namespace.")
        return argparse.Namespace()

    # Keep the try body to the read/parse only.
    try:
        with open(json_file_path, 'r', encoding='utf-8') as f:
            config_dict = json.load(f)
    except json.JSONDecodeError as e:
        print(f"Error: Failed to parse JSON file '{json_file_path}': {e}")
        return argparse.Namespace()
    except Exception as e:
        # Deliberate best-effort fallback (e.g. OSError, UnicodeDecodeError).
        print(f"Unknown error reading file '{json_file_path}': {e}")
        return argparse.Namespace()

    # A valid JSON document may be a list/null/number; only an object maps to attributes.
    if not isinstance(config_dict, dict):
        print(
            f"Error: Expected a JSON object in '{json_file_path}', "
            f"got {type(config_dict).__name__}. Returning empty Namespace."
        )
        return argparse.Namespace()

    args = argparse.Namespace(**config_dict)
    print(f"Successfully read configuration from '{json_file_path}'.")
    return args
    
    
def contiguous_params(model):
    """Make every parameter and buffer tensor of ``model`` contiguous, in place.

    Each non-contiguous tensor has its ``.data`` replaced with a contiguous
    copy. Values are unchanged.

    All tensors reachable from ``model`` are visited, not only those under
    ``model.model``. The whole model is serialized afterwards by
    ``save_pretrained``, presumably with safetensors, which refuses
    non-contiguous tensors (TODO confirm). Tensors outside ``model.model``,
    such as an ``lm_head``, would otherwise be missed.

    If ``model`` does not itself expose ``named_parameters`` (a non-module
    wrapper), its ``model.model`` attribute is used instead.
    """
    module = model if hasattr(model, "named_parameters") else model.model

    for param in module.parameters():
        if not param.is_contiguous():
            param.data = param.data.contiguous()

    for buf in module.buffers():
        if not buf.is_contiguous():
            buf.data = buf.data.contiguous()
            
            
def save_model(model, tgt_path):
    """Persist ``model`` into the directory ``tgt_path`` via ``save_pretrained``.

    ``contiguous_params`` runs first to normalise tensor layout. The target
    directory is created if missing (an existing one is reused), and a
    confirmation line is printed once saving returns.
    """
    # Tensor layout is fixed up before anything touches the filesystem.
    contiguous_params(model)

    os.makedirs(tgt_path, exist_ok=True)
    model.save_pretrained(tgt_path)

    done_msg = f'Save quantized model to {tgt_path}'
    print(done_msg)
    
    
def _copy_auxiliary_tree(org_dir, tgt_dir, excluded):
    # Recursive worker for copy_auxiliary_file. `excluded` is the resolved
    # top-level target directory; it is never copied (into itself).
    os.makedirs(tgt_dir, exist_ok=True)

    for item_name in os.listdir(org_dir):
        org_item_path = os.path.join(org_dir, item_name)
        tgt_item_path = os.path.join(tgt_dir, item_name)

        if item_name == ".cache":
            print(f"    Skipping .cache folder: {org_item_path}")
            continue

        if item_name.endswith(".safetensors"):
            print(f"    Skipping .safetensors file: {org_item_path}")
            continue

        if os.path.realpath(org_item_path) == excluded:
            # tgt_path lives inside org_path: descending into it would copy
            # the target into itself and recurse without bound.
            print(f"    Skipping target directory: {org_item_path}")
            continue

        if os.path.isfile(org_item_path):
            # Never overwrite files already present in the target.
            if os.path.exists(tgt_item_path):
                print(f"    Skipping existing file: {tgt_item_path}")
                continue
            try:
                shutil.copy2(org_item_path, tgt_item_path)
            except OSError as e:
                # Best-effort: report and keep copying the remaining files.
                print(f"    Error copying file {org_item_path}: {e}")
            else:
                print(f"    Copied file: {org_item_path} -> {tgt_item_path}")
        elif os.path.isdir(org_item_path):
            print(f"    Entering directory: {org_item_path}")
            _copy_auxiliary_tree(org_item_path, tgt_item_path, excluded)


def copy_auxiliary_file(org_path, tgt_path):
    """Copy everything except weight files from ``org_path`` into ``tgt_path``.

    Rules, applied recursively to subdirectories:
    * entries named ``.cache`` and files ending in ``.safetensors`` are skipped;
    * files that already exist in the target are left untouched;
    * if ``tgt_path`` lies inside ``org_path``, it is not copied into itself.

    Errors on individual files are printed and skipped. If ``org_path`` is
    missing or is not a directory, an error is printed and nothing is copied.

    Args:
        org_path: Source directory.
        tgt_path: Destination directory (created if needed).
    """
    print(f'Copy auxiliary files FROM {org_path} TO {tgt_path}')
    if not os.path.exists(org_path):
        print(f"Error: Original path '{org_path}' does not exist.")
        return
    if not os.path.isdir(org_path):
        print(f"Error: Original path '{org_path}' is not a directory.")
        return

    _copy_auxiliary_tree(org_path, tgt_path, os.path.realpath(tgt_path))