import importlib
import os
import random
import numpy as np
import json
import torch
from typing import Dict
from numpy.random import default_rng

# Default number of decimal places (passed to ``round``) used by the
# number / MACs / FLOPS / bytes / params string formatters in this module.
DEFAULT_PRECISION = 2


def generate_transformer_input(model_tokenizer, input_shape, device):
    """Build a dummy, right-padded transformer input batch for profiling.

    Every row is the tokenizer's encoding of the empty string (with special
    tokens), padded on the right to ``input_shape[1]`` positions.

    Args:
        model_tokenizer: object exposing ``encode_plus(...)`` and
            ``pad_token_id`` (presumably a Hugging Face tokenizer -- confirm
            against callers).
        input_shape: ``[batch_size, seq_len]``; ``None`` means ``[1, 128]``.
        device: passed to ``Tensor.to`` for every returned tensor.

    Returns:
        dict holding whichever of ``"input_ids"``, ``"attention_mask"``,
        ``"token_type_ids"`` and ``"position_ids"`` the tokenizer produced,
        each a ``(batch_size, seq_len)`` tensor on ``device``. The dict is
        empty when ``batch_size`` is 0.

    Raises:
        AssertionError: if an encoded row is already longer than ``seq_len``
            (padding cannot shorten it).
    """
    if input_shape is None:
        input_shape = [1, 128]

    batch_size, max_length = input_shape[0], input_shape[1]

    # Explicit None check: 0 is a legitimate pad id, not a "missing" one.
    pad_token = model_tokenizer.pad_token_id
    if pad_token is None:
        pad_token = 0
    pad_token_segment_id = 0

    # Insertion order here fixes the key order of the returned dict.
    rows_by_key = {
        "input_ids": [],
        "attention_mask": [],
        "token_type_ids": [],
        "position_ids": [],
    }

    inp_seq = ""
    for _ in range(batch_size):
        encoded = model_tokenizer.encode_plus(
            inp_seq,
            add_special_tokens=True,
            truncation_strategy='longest_first',
        )
        origin_length = len(encoded["input_ids"])
        padding_length = max_length - origin_length

        for key in encoded.keys():
            if key == "input_ids":
                row = list(encoded["input_ids"]) + [pad_token] * padding_length
            elif key == "attention_mask":
                # Rebuilt from the length: 1 for real tokens, 0 for padding.
                row = [1] * origin_length + [0] * padding_length
            elif key == "token_type_ids":
                row = list(encoded["token_type_ids"]) + [pad_token_segment_id] * padding_length
            elif key == "position_ids":
                # Build a new list instead of appending into the tokenizer's
                # own output, which may be shared between calls.
                row = list(encoded["position_ids"]) + list(range(origin_length, max_length))
            else:
                continue
            assert len(row) == max_length, f"len({key}) must equal max_length"
            rows_by_key[key].append(row)

    return {
        key: torch.tensor(rows).to(device)
        for key, rows in rows_by_key.items()
        if rows
    }


def number_to_string(num, units=None, precision=DEFAULT_PRECISION):
    """Format ``num`` with a metric magnitude suffix, e.g. ``"1.23 M"``.

    Args:
        num: value to format; may be negative.
        units: force a scale with one of ``"T"`` (1e12), ``"G"`` (1e9),
            ``"M"`` (1e6), ``"K"`` (1e3), ``"m"`` (1e-3) or ``"u"`` (1e-6).
            ``None`` picks the largest of these whose magnitude does not
            exceed ``abs(num)`` (values in ``[1, 1e3)`` and zero get no
            suffix). Any other string scales by 1 and is printed verbatim.
        precision: decimal places kept by ``round`` before ``:g`` formatting.

    Returns:
        ``"<scaled value> <units>"``. The separating space is present even
        when the suffix is empty (``number_to_string(0) == "0 "``).
    """
    scales = {"T": 1e12, "G": 1e9, "M": 1e6, "K": 1e3, "m": 1e-3, "u": 1e-6}
    if units is None:
        # Pick the scale from the absolute value so negative numbers get the
        # same suffix as their magnitude instead of falling through to "u".
        size = abs(num)
        if size == 0 or 1 <= size < 1e3:
            magnitude, units = 1, ""
        else:
            units = next(
                (suffix for suffix in ("T", "G", "M", "K", "m") if size >= scales[suffix]),
                "u",
            )
            magnitude = scales[units]
    else:
        magnitude = scales.get(units, 1)
    return f"{round(num / magnitude, precision):g} {units}"


def macs_to_string(macs, units=None, precision=DEFAULT_PRECISION):
    """Render a multiply-accumulate count, e.g. ``"2.5 GMACs"``.

    ``units`` and ``precision`` are forwarded unchanged to ``number_to_string``.
    """
    scaled = number_to_string(macs, units=units, precision=precision)
    return scaled + "MACs"


def flops_to_string(flops, units=None, precision=DEFAULT_PRECISION):
    """Render a floating-point operation count, e.g. ``"3 TFLOPS"``.

    ``units`` and ``precision`` are forwarded unchanged to ``number_to_string``.
    """
    scaled = number_to_string(flops, units=units, precision=precision)
    return scaled + "FLOPS"


def bytes_to_string(b, units=None, precision=DEFAULT_PRECISION):
    """Render a byte count with a decimal (powers of 1000) suffix, e.g. ``"2.05 KB"``.

    ``units`` and ``precision`` are forwarded unchanged to ``number_to_string``.
    """
    scaled = number_to_string(b, units=units, precision=precision)
    return scaled + "B"


def params_to_string(params_num, units=None, precision=DEFAULT_PRECISION):
    """Render a parameter count, writing billions as ``"B"`` rather than ``"G"``.

    ``units="B"`` is accepted as an alias for the 1e9 (``"G"``) scale. The
    result is whitespace-stripped, so a suffix-less count has no trailing space.
    """
    if units:
        # Map the user-facing "B" onto the scale name number_to_string knows.
        units = units.replace("B", "G")
    text = number_to_string(params_num, units=units, precision=precision)
    return text.replace("G", "B").strip()


def _trainable_density(module):
    """Return the fraction of non-zero entries in ``module``'s trainable parameters.

    Parameters are gathered recursively via ``module.parameters()``. A module
    without trainable parameters has nothing pruned, so its density is 1.0
    (previously a ``0 / 1e-8`` division zeroed its count instead).
    """
    params = [p for p in module.parameters() if p.requires_grad]
    total = sum(p.numel() for p in params)
    if total == 0:
        return 1.0
    nonzero = sum(p.count_nonzero().item() for p in params)
    return nonzero / total


def _accumulate_counter(module, attr, is_sparse):
    """Sum attribute ``attr`` over ``module`` and all its descendants.

    With ``is_sparse`` each module's own count is scaled by its trainable
    parameter density before its children's totals are added.
    """
    own = getattr(module, attr)
    total = own * _trainable_density(module) if is_sparse else own
    for child in module.children():
        total += _accumulate_counter(child, attr, is_sparse)
    return total


def get_module_flops(module, is_sparse=False):
    """Return the total ``__flops__`` recorded on ``module`` and its submodules.

    Args:
        module: a module tree where every node carries a ``__flops__`` count
            (presumably attached by the profiler hooks -- confirm).
        is_sparse: scale each module's own count by the fraction of non-zero
            trainable parameters.
    """
    return _accumulate_counter(module, "__flops__", is_sparse)


def get_module_macs(module, is_sparse=False):
    """Return the total ``__macs__`` recorded on ``module`` and its submodules.

    Args:
        module: a module tree where every node carries a ``__macs__`` count
            (presumably attached by the profiler hooks -- confirm).
        is_sparse: scale each module's own count by the fraction of non-zero
            trainable parameters.
    """
    return _accumulate_counter(module, "__macs__", is_sparse)


def convert_bytes(size):
    """Express a byte count in the largest binary (1024-based) unit below 1024.

    Units step through bytes, KB, MB, GB, TB; anything still >= 1024 after TB
    is reported in PB. The value is rounded to two decimals.
    """
    suffixes = ("bytes", "KB", "MB", "GB", "TB")
    level = 0
    while level < len(suffixes) and not size < 1024.0:
        size /= 1024.0
        level += 1
    suffix = suffixes[level] if level < len(suffixes) else "PB"
    return f"{round(size, 2)} {suffix}"


def _is_package_available(pkg_name):
    
    package_exists = importlib.util.find_spec(pkg_name) is not None
    if package_exists:
        try:
            _ = importlib.metadata.metadata(pkg_name)
            return True
        except importlib.metadata.PackageNotFoundError:
            return False

def eval_task_to_list(task):
    """Normalize an evaluation-task specification into a list of task names.

    Args:
        task: a comma-separated string (``"arc_easy, hellaswag"``), a tuple,
            or a list. Other types are returned unchanged.

    Returns:
        For a string: the comma-separated names with surrounding whitespace
        stripped and empty entries (e.g. from a trailing comma) dropped.
        For a tuple: an equivalent list. For a list: the same object.
    """
    if isinstance(task, str):
        task = [name.strip() for name in task.split(',') if name.strip()]
    elif isinstance(task, tuple):
        task = list(task)
    elif isinstance(task, list):
        print("Do not need to modify the type of task")
    print("eval_task:", task)
    return task


def load_opt_config(config_path: str) -> Dict:
    """Read and parse the JSON document at ``config_path``.

    The file is decoded as UTF-8 (the JSON standard encoding) rather than the
    platform's locale encoding.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def save_opt_config(config: Dict, config_path: str) -> None:
    """Write ``config`` as indented JSON to ``config_path``, creating parent dirs.

    A bare filename (no directory component) is written to the current
    working directory.
    """
    parent = os.path.dirname(config_path)
    # os.makedirs("") raises FileNotFoundError, so only create a real parent.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(config_path, 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=2)
        
def set_project_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch RNGs and pin cuDNN to deterministic mode.

    Seeds ``random``, the legacy global ``np.random`` state, the torch CPU
    generator and all CUDA generators with the same ``seed``. It then turns
    cuDNN autotuning off and deterministic algorithms on.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def _stochastic_round(arr: np.ndarray, rng):
    frac = arr - np.floor(arr)
    return np.floor(arr + (rng.random(len(arr)) < frac)).astype(int)

def gen_smooth_staircase(
    num_layers:int         = 32,
    min_experts:int        = 2,
    max_experts:int        = 8,
    target_total:int       = None,
    adapt_mode:str         = "up",
    base_noise:float       = 0.3,        
    max_step:int           = 2,          
    rng_seed:int | None    = None,
):
    assert adapt_mode in ("up", "down")
    rng = default_rng(rng_seed)

    
    idx = np.arange(num_layers)
    base = min_experts + (max_experts - min_experts) * idx / (num_layers - 1)

    
    
    alpha = base_noise * (max_experts - min_experts) * idx / (num_layers - 1)
    noise = rng.uniform(-alpha, alpha)
    noisy = np.clip(base + noise, min_experts, max_experts)

    
    step_experts = _stochastic_round(noisy, rng)

    
    step_experts.sort()
    for i in range(1, num_layers):
        if step_experts[i] - step_experts[i-1] > max_step:
            step_experts[i] = step_experts[i-1] + max_step

    
    diff = target_total - step_experts.sum()
    inc_order = np.argsort(np.abs(idx - num_layers/2))   
    while diff != 0:
        for j in inc_order:
            if diff == 0:
                break
            if diff > 0 and step_experts[j] < max_experts:
                step_experts[j] += 1
                diff -= 1
            elif diff < 0 and step_experts[j] > min_experts:
                step_experts[j] -= 1
                diff += 1

    
    if adapt_mode == "down":
        step_experts = step_experts[::-1]

    
    assert step_experts.sum() == target_total
    assert step_experts.min() >= min_experts and step_experts.max() <= max_experts
    assert np.all(np.diff(step_experts) <= max_step)     
    return step_experts
