import numpy as np
from transformers import AutoModelForCausalLM, AutoConfig
from accelerate import init_empty_weights
import numpy as np
import os
import json

# ============ Target / Draft model config ============ #

def get_target_model_info(target_model_path):
    '''
    Count the target model's parameters, split into expert and non-expert parts.
    The counts are used to compute the initial guess of the model parameters.

    The model is built from its config (not from a checkpoint) inside
    accelerate's `init_empty_weights()` context, and only `numel()` of each
    parameter is read.

    Arguments:
        target_model_path: the model path of the target model

    Returns:
        expert_params_count: total element count of parameters whose name contains 'experts'
        other_params_count: total element count of all remaining parameters
    '''
    model_config = AutoConfig.from_pretrained(target_model_path)
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(model_config)

    # Accumulate both totals in a single pass over the named parameters.
    expert_params_count = 0
    other_params_count = 0
    for param_name, tensor in model.named_parameters():
        size = tensor.numel()
        if 'experts' in param_name:
            expert_params_count += size
        else:
            other_params_count += size

    return expert_params_count, other_params_count


def get_draft_model_volume(draft_model_path):
    '''
    Count the total number of parameters of the draft model.
    The count is used to compute the initial guess of the model parameters.

    The model is built from its config (not from a checkpoint) inside
    accelerate's `init_empty_weights()` context, and only `numel()` of each
    parameter is read.

    Arguments:
        draft_model_path: the model path of the draft model

    Returns:
        param_count: the number of parameters in the draft model
    '''
    model_config = AutoConfig.from_pretrained(draft_model_path)
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(model_config)

    # Sum the element counts of every parameter tensor.
    param_count = 0
    for tensor in model.parameters():
        param_count += tensor.numel()

    return param_count



# ============ Modeling Framework ============ #

def g(t, rp, s):
    '''
    Execution time growth factor derived from the roofline model.
    It is mentioned by Equation 11 in Section 3.3 of the paper.

    The function is piecewise: exponential (s ** t) below the ridge point,
    and the tangent line of that exponential at t = rp from the ridge point
    onwards. The comparison `t < rp` means t is handled as a scalar.

    Arguments:
        t: input token counts
        rp: roofline model ridge point
        s: slope of the roofline model

    Returns:
        model_g: modeled execution time growth
    '''
    if t < rp:
        # Below the ridge point: pure exponential growth in the token count.
        return s ** t

    # At or past the ridge point: continue linearly with value s ** rp and
    # slope (s ** rp) * ln(s), i.e. the exponential's value and derivative at rp.
    ridge_value = s ** rp
    return ridge_value * (1 + np.log(s) * (t - rp))


def N(t, K, E):
    '''
    Expected number of activated experts for a given token count.
    It is mentioned by Equation 8 in Section 3.2 of the paper.

    Computes E * (1 - ((E - K) / E) ** t).

    Arguments:
        t: input token counts
        K: number of activated experts per token
        E: total number of experts

    Returns:
        model_N: the expected number of activated experts
    '''
    # Fraction of experts that a single token does not activate.
    unselected_fraction = (E - K) / E
    # Raised to the token count, as in the closed form above.
    never_selected = unselected_fraction ** t
    return E * (1 - never_selected)


def target_forward(params, x, K, E):
    '''
    Modeled execution time of one target model forward pass.
    It is mentioned by Line 6 and 8 in Algorithm 1 of the paper.

    The result is the sum of four terms: a constant bias, a term in g(x),
    a term linear in the expected number of activated experts N(x, K, E),
    and a term in g applied to the average token count per activated expert.

    Arguments:
        params: model parameters, unpacked as (bias, k1, k2, k3, rp, s)
        x: input token counts
        K: number of activated experts per token
        E: total number of experts

    Returns:
        model_target: modeled execution time of the target model
    '''
    bias, k1, k2, k3, rp, s = params

    active_experts = N(x, K, E)
    # Total expert assignments (x * K) spread over the activated experts.
    load_per_expert = x * K / active_experts

    full_batch_term = k1 * g(x, rp, s)
    expert_count_term = k2 * active_experts
    per_expert_term = k3 * g(load_per_expert, rp, s)

    return bias + full_batch_term + expert_count_term + per_expert_term


def draft_forward(params, x):
    '''
    Modeled execution time of one draft model forward pass.
    It is mentioned by Line 9 in Algorithm 1 of the paper.

    Computed as draft_bias + draft_k * g(x, rp, s).

    Arguments:
        params: model parameters, unpacked as (draft_bias, draft_k, rp, s)
        x: input token counts

    Returns:
        model_draft: modeled execution time of the draft model
    '''
    draft_bias, draft_k, rp, s = params
    growth = g(x, rp, s)
    return draft_bias + draft_k * growth

def reject_forward(params, x):
    '''
    Modeled execution time of rejection sampling.
    It is mentioned by Line 10 in Algorithm 1 of the paper.

    Computed as a linear function of the token count:
    reject_bias + reject_k * x.

    Arguments:
        params: model parameters, unpacked as (reject_bias, reject_k)
        x: input token counts

    Returns:
        model_reject: modeled execution time of rejection sampling
    '''
    intercept, per_token_cost = params
    return intercept + per_token_cost * x


def ComputeSpeedup(params, B, gamma, K, E, sigma):
    '''
    Compute SD speedup given the model parameters and input data.
    It is mentioned by line 3 in Algorithm 1 of the paper.

    B, gamma, K, E and sigma are indexed in lockstep: entry i of each one
    describes data point i, for i in range(len(B)).

    Arguments:
        params: model parameters, unpacked as (bias, k1, k2, k3, draft_bias,
                draft_k, reject_bias, reject_k, rp, s)
        B: batch size
        gamma: draft length
        K: number of activated experts per token
        E: total number of experts
        sigma: the ratio of accepted tokens count to the maximum token count if all draft tokens are accepted

    Returns:
        model_speedup: modeled speedup, a numpy array with len(B) entries
    '''
    bias, k1, k2, k3, draft_bias, draft_k, reject_bias, reject_k, rp, s = params

    # The target and draft models share the roofline parameters (rp, s).
    target_params = [bias, k1, k2, k3, rp, s]
    draft_params = [draft_bias, draft_k, rp, s]
    reject_params = [reject_bias, reject_k]

    n_points = len(B)
    model_speedup = np.zeros(n_points)
    for idx in range(n_points):
        batch = B[idx]
        draft_len = gamma[idx]
        experts_per_token = K[idx]
        total_experts = E[idx]

        # Baseline: one target forward pass over the batch.
        baseline_time = target_forward(target_params, batch, experts_per_token, total_experts)
        # Verification: one target forward pass over batch * draft_len tokens.
        verify_time = target_forward(target_params, batch * draft_len, experts_per_token, total_experts)
        # Drafting: draft_len draft forward passes over the batch.
        drafting_time = draft_len * draft_forward(draft_params, batch)
        rejection_time = reject_forward(reject_params, batch)

        round_time = verify_time + drafting_time + rejection_time
        model_speedup[idx] = sigma[idx] * (draft_len + 1) * (baseline_time / round_time)

    return model_speedup


def residual(params, B, gamma, K, E, sigma, y):
    '''
    Compute the residual between the modeled speedup and the real speedup.
    It is mentioned by line 13 in Algorithm 1 of the paper.

    Arguments:
        params: model parameters
        B: batch size
        gamma: draft length
        K: number of activated experts per token
        E: total number of experts
        sigma: the ratio of accepted tokens count to the maximum token count if all draft tokens are accepted
        y: real speedup

    Returns:
        residual: the modeled speedup minus the real speedup
    '''
    predicted_speedup = ComputeSpeedup(params, B, gamma, K, E, sigma)
    return predicted_speedup - y
