import utils
import model_utils
import quant_utils
import torch
import os
import logging
from tqdm import tqdm


class _FirstLayerInputCaptured(Exception):
    """Private signal used to abort a forward pass once the first decoder layer is reached."""


def _fit_batch_dim(tensor, batch_size):
    """Tile or truncate ``tensor`` along dim 0 so it has exactly ``batch_size`` rows."""
    rows = tensor.shape[0]
    if rows == batch_size:
        return tensor
    if rows > batch_size:
        return tensor[:batch_size]
    # Ceil-divide so the tiled tensor has at least ``batch_size`` rows, then trim.
    repeat_times = (batch_size + rows - 1) // rows
    return tensor.repeat(repeat_times, *([1] * (tensor.dim() - 1)))[:batch_size]


def _fit_attention_mask(attention_mask, batch_size, seq_len):
    """Adapt a captured 2-D or 4-D attention mask to one batch of hidden states.

    Masks of any other rank (and ``None``) are returned unchanged.
    """
    if attention_mask is None:
        return None
    if attention_mask.dim() == 4:
        return _fit_batch_dim(attention_mask, batch_size)[:, :, :seq_len, :seq_len]
    if attention_mask.dim() == 2:
        return _fit_batch_dim(attention_mask, batch_size)[:, :seq_len]
    return attention_mask


def _fit_position_ids(position_ids, batch_size, seq_len):
    """Adapt captured position ids to one batch of hidden states (``None`` passes through)."""
    if position_ids is None:
        return None
    return _fit_batch_dim(position_ids, batch_size)[:, :seq_len]


def _move_input_modules(model, opt_type, device):
    """Move the modules that run before the first decoder layer to ``device``."""
    if opt_type:
        decoder = model.model.decoder
        decoder.embed_tokens = decoder.embed_tokens.to(device)
        decoder.embed_positions = decoder.embed_positions.to(device)
        if getattr(decoder, 'project_out', None) is not None:
            decoder.project_out = decoder.project_out.to(device)
        if getattr(decoder, 'project_in', None) is not None:
            decoder.project_in = decoder.project_in.to(device)
    else:
        model.model.embed_tokens = model.model.embed_tokens.to(device)


def _layerwise_perplexity(model, testenc, dev, args, opt_type):
    """Run the layer-by-layer evaluation and return the perplexity as a float."""
    llama_type = not opt_type

    # Convert the whole text of evaluation dataset into batches of sequences.
    input_ids = testenc.input_ids  # (1, text_len)
    nsamples = input_ids.numel() // model.seqlen  # The tail is truncated.
    if nsamples == 0:
        raise ValueError(
            f'Evaluation text has {input_ids.numel()} tokens, fewer than one '
            f'sequence of length {model.seqlen}'
        )
    input_ids = input_ids[:, :nsamples * model.seqlen].view(nsamples, model.seqlen).to(dev)  # (nsamples, seqlen)
    input_ids = [input_ids[i:i + args.bsz] for i in range(0, nsamples, args.bsz)]
    nbatches = len(input_ids)

    layers = model.model.decoder.layers if opt_type else model.model.layers
    _move_input_modules(model, opt_type, dev)
    layers[0] = layers[0].to(dev)

    # Inputs of the first decoder layer, one entry per batch. They are filled
    # in by the Catcher below, which replaces layer 0 for the capture pass.
    inps = [0] * nbatches
    cache = {'i': 0, 'attention_mask': None}

    class Catcher(torch.nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs['attention_mask']
            if llama_type:
                cache['position_ids'] = kwargs['position_ids']
            # Abort the forward pass: only the layer-0 input is needed here.
            raise _FirstLayerInputCaptured

    layers[0] = Catcher(layers[0])
    try:
        for batch in input_ids:
            try:
                model(batch)
            except _FirstLayerInputCaptured:
                # Expected: a dedicated exception type is used so that genuine
                # errors raised by the model are not swallowed.
                pass
    finally:
        # Always unwrap, even if the capture pass failed.
        layers[0] = layers[0].module
    layers[0] = layers[0].cpu()
    _move_input_modules(model, opt_type, 'cpu')

    torch.cuda.empty_cache()
    outs = [0] * nbatches
    attention_mask = cache['attention_mask']
    # Only the LLaMA path records position ids; OPT layers are called without them.
    position_ids = cache.get('position_ids')

    for i in tqdm(range(len(layers)), desc="(Eval) Layers"):
        layer = layers[i].to(dev)

        # Dump the layer input and output
        if args.capture_layer_io and args.layer_idx == i:
            captured_io = model_utils.capture_layer_io(model_utils.get_model_type(model), layer, inps)
            save_path = model_utils.get_layer_io_save_path(args)
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            torch.save(captured_io, save_path)
            logging.info(f'Dumped layer input and output to: {save_path}')

        for j in range(nbatches):
            # The captured mask / position ids come from the last captured
            # batch, so they are re-fitted to the shape of each batch here.
            cur_bsz, cur_len = inps[j].shape[0], inps[j].shape[1]
            attn_mask = _fit_attention_mask(attention_mask, cur_bsz, cur_len)
            if opt_type:
                outs[j] = layer(inps[j], attention_mask=attn_mask)[0]
            else:
                pos_ids = _fit_position_ids(position_ids, cur_bsz, cur_len)
                outs[j] = layer(inps[j], attention_mask=attn_mask, position_ids=pos_ids)[0]
        layers[i] = layer.cpu()
        del layer
        torch.cuda.empty_cache()
        # The outputs of this layer are the inputs of the next one.
        inps, outs = outs, inps

    if opt_type:
        if model.model.decoder.final_layer_norm is not None:
            model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(dev)
        if model.model.decoder.project_out is not None:
            model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
    else:
        if model.model.norm is not None:
            model.model.norm = model.model.norm.to(dev)

    model.lm_head = model.lm_head.to(dev)
    nlls = []
    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
    for i in range(nbatches):
        hidden_states = inps[i]
        if opt_type:
            if model.model.decoder.final_layer_norm is not None:
                hidden_states = model.model.decoder.final_layer_norm(hidden_states)
            if model.model.decoder.project_out is not None:
                hidden_states = model.model.decoder.project_out(hidden_states)
        else:
            if model.model.norm is not None:
                hidden_states = model.model.norm(hidden_states)
        lm_logits = model.lm_head(hidden_states)
        # Position t predicts token t + 1.
        shift_logits = lm_logits[:, :-1, :]
        shift_labels = input_ids[i][:, 1:]
        loss = loss_fct(shift_logits.permute(0, 2, 1), shift_labels)
        # Mean negative log-likelihood per sequence.
        nlls.append(loss.float().mean(dim=1))
    return torch.exp(torch.cat(nlls).mean()).item()


@torch.no_grad()
def evaluator(model, testenc, dev, args):
    """Compute perplexity of ``model`` on ``testenc``, one decoder layer at a time.

    The token stream is cut into sequences of ``model.seqlen`` tokens (the tail
    is dropped) and grouped into batches of ``args.bsz``. The input to the
    first decoder layer is captured for every batch, then each decoder layer is
    moved to ``dev``, applied to all batches and moved back to the CPU.

    Args:
        model: Causal LM exposing ``seqlen``, ``config.use_cache``, ``lm_head``
            and either ``model.decoder`` (OPT) or ``model.layers`` (LLaMA).
        testenc: Object whose ``input_ids`` holds the tokenized evaluation text.
        dev: Device on which layers are executed.
        args: Namespace providing ``model``, ``bsz``, ``capture_layer_io``,
            ``layer_idx`` and ``eval_dataset``.

    Returns:
        The perplexity as a Python float.

    Raises:
        ValueError: If ``args.model`` contains neither ``'opt'`` nor ``'meta'``,
            or if the text is shorter than one sequence.
    """
    model.eval()

    if 'opt' in args.model:
        opt_type = True
    elif 'meta' in args.model:
        opt_type = False
    else:
        raise ValueError(f'Unknown model {args.model}')

    use_cache = model.config.use_cache
    model.config.use_cache = False
    try:
        ppl = _layerwise_perplexity(model, testenc, dev, args, opt_type)
    finally:
        # Restore the caller's setting even when evaluation fails.
        model.config.use_cache = use_cache

    logging.info(f'\n{args.eval_dataset.upper()} PPL: {ppl:.3f}')
    return ppl
