import torch 
torch.set_default_device("cuda")
import os 
import argparse
from ming.model.builder import load_molora_pretrained_model
from ming.utils import client, get_model_name_from_path
from ming.model.modeling_molora_qwen import MoLoRAQwenForCausalLM
from ming.conversations import conv_templates
from tqdm import tqdm 
from transformers import AutoTokenizer

def get_prompt(query: str, response: str):
    """Render one query/response exchange using the 'qwen' conversation template.

    Args:
        query: Text appended under the template's first role.
        response: Text appended under the template's second role.

    Returns:
        Whatever the template's ``get_prompt()`` produces for the two-turn
        conversation.
    """
    # Work on a copy so the shared template registry entry is never mutated.
    conversation = conv_templates['qwen'].copy()
    first_role = conversation.roles[0]
    second_role = conversation.roles[1]
    conversation.append_message(first_role, query)
    conversation.append_message(second_role, message=response)
    return conversation.get_prompt()

def calc_distance(hidden_states: torch.Tensor, opt_hidden_states: torch.Tensor, padding_mask: torch.Tensor):
    """Return the layer-averaged RMS per-token distance between two activation stacks.

    For every layer and every sequence, the squared Euclidean distance between
    the two representations of each token is averaged over the non-padding
    tokens and square-rooted; the per-layer values are then averaged.

    Args:
        hidden_states: Activations of shape [L, B, N, D]
            (layers, batch, sequence length, feature size).
        opt_hidden_states: Reference activations with the same shape.
        padding_mask: Mask of shape [B, N] where a non-zero / ``True`` entry
            marks a padding position. Bool, integer and float masks are
            accepted.

    Returns:
        Tensor of shape [B] with one distance per sequence. A sequence that
        consists only of padding yields 0.
    """
    is_padding = padding_mask.bool()  # [B, N]

    # Do the arithmetic in at least float32 so that squaring half-precision
    # activations cannot overflow.
    compute_dtype = torch.promote_types(
        torch.promote_types(hidden_states.dtype, opt_hidden_states.dtype),
        torch.float32,
    )
    diff = hidden_states.to(compute_dtype) - opt_hidden_states.to(compute_dtype)  # [L, B, N, D]

    # Squared Euclidean distance of every token, summed over the feature axis.
    token_sq_dist = (diff ** 2).sum(dim=-1)  # [L, B, N]

    # Overwrite (rather than multiply) padded positions: multiplying by zero
    # would keep NaN/inf values that may live in padded activations.
    token_sq_dist = token_sq_dist.masked_fill(is_padding.unsqueeze(0), 0.0)

    seq_sq_dist = token_sq_dist.sum(dim=-1)  # [L, B]

    # Number of real tokens per sequence; clamp to 1 so fully padded
    # sequences divide 0 by 1 instead of 0 by 0.
    token_count = (~is_padding).sum(dim=-1).clamp(min=1).unsqueeze(0)  # [1, B]

    # Root of the mean squared distance per layer, then average over layers.
    return torch.sqrt(seq_sq_dist / token_count).mean(dim=0)  # [B]

def calc_dist_hidden_states_with_opt(model: MoLoRAQwenForCausalLM, opt_model: MoLoRAQwenForCausalLM, args: argparse.Namespace, tokenizer: AutoTokenizer):
    """Measure how far ``model``'s activations are from ``opt_model``'s.

    Records are read from ``args.output_jsonl_path``, turned into prompts from
    their ``'prompt'`` and ``'text'`` fields, and processed in batches of
    ``args.batch_size``. Both models are run on the same tokens and the
    per-sequence distance (see ``calc_distance``) is computed for three output
    fields: ``hidden_states``, ``attn_hidden_states`` and ``mlp_hidden_states``.

    Args:
        model: Model under evaluation; ``model.model(...)`` must expose the
            three output fields listed above.
        opt_model: Reference model exposing the same output fields.
        args: Namespace providing ``output_jsonl_path`` and ``batch_size``.
        tokenizer: Tokenizer used to encode and pad each batch of prompts.

    Returns:
        A tuple ``(hidden_state_dist, attn_hidden_state_dist,
        mlp_hidden_state_dist)`` of scalar tensors, each the mean of the
        per-sequence distances over all records.

    Raises:
        ValueError: If the input file contains no records.
    """
    output_data = client.read(args.output_jsonl_path)
    if len(output_data) == 0:
        raise ValueError(f"No records found in {args.output_jsonl_path}")

    # Split the records into batches and render each record as a prompt.
    prompt_chunks = [
        [get_prompt(record['prompt'], record['text']) for record in output_data[start:start + args.batch_size]]
        for start in range(0, len(output_data), args.batch_size)
    ]

    field_names = ("hidden_states", "attn_hidden_states", "mlp_hidden_states")
    distances = {name: [] for name in field_names}

    for each_prompt in tqdm(prompt_chunks):
        encoded = tokenizer(each_prompt, return_tensors="pt", truncation=True,
                            padding=True)
        input_ids = encoded.input_ids
        attention_mask = encoded.attention_mask
        # Derive padding from the attention mask instead of comparing against
        # pad_token_id: that comparison mis-labels real tokens that share the
        # pad id and fails when the tokenizer has no pad_token_id.
        padding_mask = (attention_mask == 0).float()
        with torch.inference_mode():
            # Pass the attention mask so padded positions are not attended to.
            outputs = model.model(input_ids, attention_mask=attention_mask)
            opt_outputs = opt_model.model(input_ids, attention_mask=attention_mask)

        # Handle one field at a time so only a single pair of stacked
        # activation tensors is alive at once.
        for name in field_names:
            stacked = torch.stack(getattr(outputs, name), dim=0)
            opt_stacked = torch.stack(getattr(opt_outputs, name), dim=0)
            distances[name].append(calc_distance(stacked, opt_stacked, padding_mask))

    # Concatenate the per-batch [B] tensors. torch.stack would fail here
    # whenever the final batch is smaller than args.batch_size.
    hidden_state_dist = torch.cat(distances["hidden_states"], dim=0).mean()
    attn_hidden_state_dist = torch.cat(distances["attn_hidden_states"], dim=0).mean()
    mlp_hidden_state_dist = torch.cat(distances["mlp_hidden_states"], dim=0).mean()
    return hidden_state_dist, attn_hidden_state_dist, mlp_hidden_state_dist
        

if __name__ == "__main__":
    # Command-line entry point: load the evaluated model and the reference
    # ("opt") model, then print the three activation distances between them.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_base", type=str, default="/mnt/petrelfs/usr/models/models--Qwen--Qwen1.5-1.8B-Chat")
    parser.add_argument("--opt_model_path", type=str, default="/mnt/petrelfs/usr/checkpoints/ming1.8b-4x1-topk-mathtest-r16")
    # Both paths are mandatory: without --model_path the script used to crash
    # with a TypeError inside os.path.expanduser(None).
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--output_jsonl_path", type=str, required=True)
    parser.add_argument("--load_molora", action="store_true")
    parser.add_argument("--unload_lora", action="store_true")

    parser.add_argument("--batch_size", type=int, default=128)

    args = parser.parse_args()
    args.model_path = os.path.expanduser(args.model_path)

    # Model under evaluation; its tokenizer is not used.
    model_name = get_model_name_from_path(args.model_path)
    _, model, _, _ = load_molora_pretrained_model(args.model_path, args.model_base,
                                                  model_name, args.load_molora,
                                                  unload_lora=args.unload_lora)

    # Reference model, always loaded with MoLoRA enabled and LoRA kept; the
    # tokenizer returned here is the one used for encoding.
    opt_model_name = get_model_name_from_path(args.opt_model_path)
    tokenizer, opt_model, _, _ = load_molora_pretrained_model(args.opt_model_path, args.model_base,
                                                              opt_model_name, True,
                                                              unload_lora=False)

    # Ask both models to return their per-layer activations.
    model.config.output_hidden_states = True
    opt_model.config.output_hidden_states = True

    hidden_state_dist, attn_hidden_state_dist, mlp_hidden_state_dist = calc_dist_hidden_states_with_opt(model, opt_model, args, tokenizer=tokenizer)
    print(f"Hidden states distance: {hidden_state_dist}")
    print(f"Attention hidden states distance: {attn_hidden_state_dist}")
    print(f"MLP hidden states distance: {mlp_hidden_state_dist}")
    