from src.models.common import CausalLM, get_qkv

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from copy import deepcopy

@torch.no_grad()
def check_transplant(linear_model: CausalLM, origin_model: CausalLM,
                     dataset: Dataset,
                     data_collator: nn.Module,
                     batch_size_per_device: int = 8,
                     num_batch_for_check: int = 10):
    """Print diagnostics showing whether weights were transplanted correctly.

    Two checks are run. Each prints norms of differences between the two
    models; 0 means identical.

    1. Attention projections: for up to ``num_batch_for_check`` batches,
       ``get_qkv`` is evaluated on both models for every layer. Both get the
       same input: ``origin_model``'s ``hidden_states[layer_idx]``
       (presumably the input of layer ``layer_idx`` — confirm against
       ``get_qkv``). The query/key/value differences are printed.
    2. Everything except attention: a deep copy of ``linear_model`` has its
       attention modules replaced by ``origin_model``'s. That copy is then
       compared with ``origin_model`` in two ways: first parameter by
       parameter, then by logits and per-layer hidden states on up to
       ``num_batch_for_check`` batches.

    Args:
        linear_model: model that received the transplanted weights.
        origin_model: reference model the weights came from. Inputs are
            moved to ``origin_model.device``.
        dataset: source of examples, drawn in shuffled order.
        data_collator: used as ``collate_fn``. Its output must contain
            ``"input_ids"`` and ``"attention_mask"``.
        batch_size_per_device: examples per visible GPU in each batch.
        num_batch_for_check: number of batches inspected by each check.
    """
    # Local import keeps this block self-contained. islice stops after the
    # last batch we need, so no extra batch is loaded and collated.
    from itertools import islice

    # DataParallel splits each batch across all visible GPUs. On a CPU-only
    # machine device_count() is 0, which would give batch_size=0 and make
    # DataLoader raise ValueError, so count at least one device.
    n_device = max(torch.cuda.device_count(), 1)
    whole_batch_size = batch_size_per_device * n_device

    print("=== Checking whether tranplanting was done correctly ...", flush=True)

    print("Checking the attention layers ...", flush=True)
    # get_qkv is called on the bare linear model, so only eval mode is
    # needed for it; wrapping it in DataParallel would buy nothing.
    linear_model.eval()
    p_origin_model = nn.DataParallel(origin_model)
    p_origin_model.eval()
    dataloader = DataLoader(dataset, batch_size=whole_batch_size,
                            collate_fn=data_collator, shuffle=True)
    device = origin_model.device
    n_layer = len(origin_model.transformer.h)

    for batch_idx, batch in enumerate(islice(dataloader, num_batch_for_check)):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        out_origin = p_origin_model(input_ids=input_ids,
                                    attention_mask=attention_mask,
                                    use_cache=False,
                                    output_hidden_states=True)

        for layer_idx in range(n_layer):
            # Feed both models the same hidden states so that only the
            # attention projections themselves are compared.
            hidden_states = out_origin.hidden_states[layer_idx]
            origin_query, origin_key, origin_value = get_qkv(origin_model, input_ids, attention_mask, hidden_states, layer_idx)
            linear_query, linear_key, linear_value = get_qkv(linear_model, input_ids, attention_mask, hidden_states, layer_idx)
            message = f"batch_idx: {batch_idx} / {num_batch_for_check}, layer_idx: {layer_idx} "
            message += f"QueryDiff = {(linear_query - origin_query).norm().item()} "
            message += f"KeyDiff = {(linear_key - origin_key).norm().item()} "
            message += f"ValueDiff = {(linear_value - origin_value).norm().item()}"
            print(message, flush=True)

    print("Checking the layers except for the attention layers ...", flush=True)

    # Copy the linear model and put the original attention modules (shared,
    # not copied) back in. Any remaining difference from origin_model must
    # then come from non-attention weights.
    tmp_model: CausalLM = deepcopy(linear_model).to(device)
    for i, block in enumerate(tmp_model.transformer.h):
        block.attn = origin_model.transformer.h[i].attn
    p_tmp_model = nn.DataParallel(tmp_model)
    p_tmp_model.eval()

    # Build the state dict once: state_dict() walks the whole module tree,
    # so calling it per parameter would be quadratic.
    tmp_state = tmp_model.state_dict()
    for param_name, param_origin in origin_model.named_parameters():
        param_tmp = tmp_state.get(param_name)
        if param_tmp is None:
            print(f"param_name: {param_name}, missing in the transplanted model")
            continue
        if param_tmp.shape != param_origin.shape:
            print(f"param_name: {param_name}, shape mismatch: "
                  f"{tuple(param_origin.shape)} vs {tuple(param_tmp.shape)}")
            continue
        if (param_origin != param_tmp).any().item():
            print(f"param_name: {param_name}, param_diff = {(param_origin - param_tmp).norm().item()}")

    for batch_idx, batch in enumerate(islice(dataloader, num_batch_for_check)):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        # Gradients are already disabled by the @torch.no_grad() decorator.
        output_tmp = p_tmp_model(input_ids=input_ids,
                                 attention_mask=attention_mask,
                                 use_cache=False,
                                 output_hidden_states=True)
        output_origin = p_origin_model(input_ids=input_ids,
                                       attention_mask=attention_mask,
                                       use_cache=False,
                                       output_hidden_states=True)
        message = f"batch_idx: {batch_idx} / {num_batch_for_check} "
        message += f"LogitsDiff = {(output_tmp.logits - output_origin.logits).norm().item()}"
        print(message, flush=True)
        for layer_idx, (hidden_tmp, hidden_origin) in enumerate(zip(output_tmp.hidden_states,
                                                                    output_origin.hidden_states)):
            message = f"batch_idx: {batch_idx} / {num_batch_for_check} layer_idx: {layer_idx} "
            message += f"HiddenStatesDiff = {(hidden_tmp - hidden_origin).norm().item()}"
            print(message, flush=True)

    print("... Done ===", flush=True)
            