import torch
import torch.utils
import vllm
from typing import Dict


def _copy_param(model, name, source, truncate_rows=False):
    """Copy ``source`` in place into ``model``'s parameter called ``name``.

    ``source`` is cast to the target's dtype and moved to the target's device
    before the copy. With ``truncate_rows`` set, only the first
    ``target.shape[0]`` rows of ``source`` are used.
    """
    target = model.get_parameter(name)
    if truncate_rows:
        source = source[:target.shape[0]]
    target.copy_(source.to(dtype=target.dtype, device=target.device))


@torch.no_grad()
def load_hf_params_to_vllm(param: Dict, llm: vllm.LLM) -> None:
    """Load weights from a HF transformer state dict into a vLLM model in place.

    The target model stores the attention q/k/v projections as one fused
    ``qkv_proj`` matrix and the MLP gate/up projections as one fused
    ``gate_up_proj`` matrix. The corresponding HF tensors are therefore
    concatenated along dim 0 before copying. Every other tensor is copied
    under the same name.

    The function runs under ``torch.no_grad()`` so that the in-place
    ``copy_`` into parameters that require grad does not raise.

    Args:
        param: Mapping from HF parameter names to tensors. The expected names
            are ``model.embed_tokens.weight``, ``lm_head.weight``,
            ``model.norm.weight`` and the per-layer ``model.layers.{i}.*``
            weights.
        llm: The vLLM engine. Its driver worker's model receives the weights.

    Raises:
        KeyError: If ``param`` (a dict) lacks one of the expected names.
    """
    model = llm.llm_engine.model_executor.driver_worker.model_runner.model
    num_layers = model.config.num_hidden_layers

    # Embedding and LM head: take only as many rows as the vLLM parameter has.
    _copy_param(model, 'model.embed_tokens.weight',
                param['model.embed_tokens.weight'], truncate_rows=True)
    _copy_param(model, 'lm_head.weight', param['lm_head.weight'],
                truncate_rows=True)

    # Final layernorm.
    _copy_param(model, 'model.norm.weight', param['model.norm.weight'])

    for i in range(num_layers):
        prefix = f'model.layers.{i}'
        # Fused attention input projection: rows stacked in q, k, v order.
        _copy_param(
            model,
            f'{prefix}.self_attn.qkv_proj.weight',
            torch.cat(
                [param[f'{prefix}.self_attn.{p}_proj.weight']
                 for p in ('q', 'k', 'v')],
                dim=0,
            ),
        )
        # Fused MLP input projection: rows stacked in gate, up order.
        _copy_param(
            model,
            f'{prefix}.mlp.gate_up_proj.weight',
            torch.cat(
                [param[f'{prefix}.mlp.{p}_proj.weight']
                 for p in ('gate', 'up')],
                dim=0,
            ),
        )
        # The remaining tensors have identical names in both models.
        for suffix in ('self_attn.o_proj.weight',
                       'mlp.down_proj.weight',
                       'input_layernorm.weight',
                       'post_attention_layernorm.weight'):
            name = f'{prefix}.{suffix}'
            _copy_param(model, name, param[name])


def eval_model(vllm_model, evaluator, ix=None):
    """Evaluate ``vllm_model`` with ``evaluator`` and report the aggregate metrics.

    Args:
        vllm_model: Model object handed straight to ``evaluator.evaluate``.
        evaluator: Object exposing ``evaluate(model, sample_ids=...)``.
        ix: Optional sample ids, forwarded as ``sample_ids``.

    Returns:
        The object returned by ``evaluator.evaluate``. Its
        ``aggregate_metrics`` attribute has already been printed.
    """
    outcome = evaluator.evaluate(vllm_model, sample_ids=ix)
    summary = outcome.aggregate_metrics
    print(summary)
    return outcome



def compose_new_params(
        policy,
        param_name,
        decomposed_params,
        learnable_params,
        convert_dtype=True,
        ):
    """Rebuild a weight matrix from its SVD factors with masked singular values.

    Computes ``U @ diag(S * mask) @ V.T``. The result is then rescaled by
    ``sum(S) / sum(S * mask)``, so the total singular-value mass matches the
    unmasked decomposition.

    Args:
        policy: Object whose ``get_mask(learnable, convert_dtype=...)``
            returns the mask multiplied into the singular values.
        param_name: Base name. The factors are read from ``decomposed_params``
            under ``f'{param_name}.U'``, ``f'{param_name}.S'`` and
            ``f'{param_name}.V'``.
        decomposed_params: Mapping holding the U, S and V factors.
        learnable_params: Mapping from ``param_name`` to the tensor passed to
            ``policy.get_mask``.
        convert_dtype: Forwarded to ``policy.get_mask``. Only ``True`` is
            supported.

    Returns:
        The recomposed tensor. It stays differentiable with respect to the
        mask when autograd is enabled.

    Raises:
        NotImplementedError: If ``convert_dtype`` is false.
    """
    if not convert_dtype:
        raise NotImplementedError
    mm = policy.get_mask(
        learnable_params[param_name], convert_dtype=convert_dtype)
    u = decomposed_params[f'{param_name}.U']
    s = decomposed_params[f'{param_name}.S']
    v = decomposed_params[f'{param_name}.V']
    masked_s = s * mm
    # Scaling the columns of U by masked_s is exactly U @ diag(masked_s).
    # It avoids building a k x k diagonal matrix and the extra O(m*k^2)
    # matmul. unsqueeze(-2) keeps diag_embed's leading-batch-dim semantics.
    composed = (u * masked_s.unsqueeze(-2)) @ v.T
    # NOTE(review): an all-zero mask makes this ratio divide by zero.
    return composed * (s.sum() / masked_s.sum())


@torch.no_grad()
def forward(policy, model, base_params, decomposed_params, learnable_params):
    """Write policy-composed MLP weights into ``model`` without tracking grads.

    Keys of ``base_params`` that contain ``'mlp'`` are rebuilt with
    ``compose_new_params``. Each rebuilt tensor is copied in place into the
    model parameter of the same name. All other keys pass through unchanged
    and the model is not touched for them.

    Returns:
        Dict mapping every key of ``base_params`` to either its composed
        tensor or its original value.
    """
    composed = {}
    for name in base_params:
        if 'mlp' not in name:
            composed[name] = base_params[name]
            continue
        weight = compose_new_params(
            policy, name, decomposed_params, learnable_params)
        composed[name] = weight
        model.get_parameter(name).copy_(weight)
    return composed

@torch.no_grad()
def load_base_params(model, base_params,):
    """Copy the original MLP weights from ``base_params`` back into ``model``.

    Every key containing ``'mlp'`` is copied in place into the model
    parameter of the same name. All other keys are ignored.

    ``Tensor.copy_`` performs any needed device and dtype conversion. The
    source tensors may therefore live on CPU or any device, and no
    temporary full-size copy is allocated on the GPU first.
    """
    for name in base_params:
        if 'mlp' in name:
            model.get_parameter(name).copy_(base_params[name])


def backward(policy, model, base_params, decomposed_params, learnable_params,):
    """Push the gradients stored on the model's MLP weights back through composition.

    For each key of ``base_params`` that contains ``'mlp'``, the weight is
    recomposed with ``compose_new_params``. Then ``.backward`` is called on it
    with the ``.grad`` currently held by the model parameter of the same name.
    This accumulates gradients into whatever tensors the composition depends
    on, presumably the learnable params and/or the policy.

    The autograd graph is retained for every key except the last, whose
    backward call releases it. This presumably matters because the per-key
    graphs may share nodes through ``policy``.

    If no key contains ``'mlp'``, nothing is done.

    NOTE(review): a model parameter whose ``.grad`` is None will make the
    non-scalar ``.backward`` call raise.
    """
    keys_to_backprop = [k for k in base_params if 'mlp' in k]
    last_index = len(keys_to_backprop) - 1
    for i, k in enumerate(keys_to_backprop):
        compose_new_params(
            policy, k, decomposed_params, learnable_params
        ).backward(
            model.get_parameter(k).grad,
            # Keep the graph until the final key; the last call releases it.
            retain_graph=i < last_index,
        )