import warnings
import torch.nn.functional as F
import time

warnings.filterwarnings("ignore")

import torch

def get_layer_context(model, tokenizer, input_ids, layer_idx, print_context=False):
    """Assemble a reduced prompt from the token positions stored on one layer.

    Reads ``self_attn.indecies[0, 0, :]`` from decoder layer ``layer_idx``,
    sorts those positions in ascending order, gathers the matching entries of
    ``input_ids`` and appends the last ``window_size`` entries of ``input_ids``.

    Args:
        model: Object exposing ``model.layers``; every layer used here needs a
            ``self_attn`` carrying ``config`` and ``indecies``.
        tokenizer: Only used (``decode``) when ``print_context`` is true.
        input_ids: Token ids without a batch dimension (expected to be 1-D,
            since it is indexed with ``gather(0, ...)``).
        layer_idx: Index of the decoder layer whose ``indecies`` are read.
        print_context: When true, print the decoded reduced prompt.

    Returns:
        The reduced token ids with a leading batch dimension added.
    """
    # The window size is always taken from layer 0's attention config;
    # 16 is the fallback when that config is falsy.
    config = model.model.layers[0].self_attn.config
    window_size = config['window_size'] if config else 16

    selected = model.model.layers[layer_idx].self_attn.indecies[0, 0, :]
    positions, _ = torch.sort(selected)
    # gather() needs the index on the same device as its input. Following
    # input_ids (instead of a hard-coded 'cuda:0') keeps this working on CPU
    # and on any other GPU.
    positions = positions.to(input_ids.device)

    select_input_ids = input_ids.gather(0, positions)
    new_input_ids = torch.cat([select_input_ids, input_ids[-window_size:]])
    if print_context:
        print(tokenizer.decode(new_input_ids))
    return new_input_ids.unsqueeze(0)

def get_layer_context_arocss(model, tokenizer, input_ids, layer_idx, print_context=False, return_score=False):
    """Print attention-score concentration statistics across the model's layers.

    Collects ``self_attn.attention_socres`` from every layer currently in
    ``model.model.layers``, averages the stacked scores along each of the two
    stacking axes, and prints which entries hold the largest share of the total.

    No token selection is performed: ``input_ids`` is returned unchanged apart
    from the added batch dimension. ``tokenizer``, ``layer_idx`` and
    ``print_context`` are accepted for signature compatibility and are unused.

    Args:
        model: Object exposing ``model.layers``; each layer needs
            ``self_attn.attention_socres`` (tensors of identical shape).
        tokenizer: Unused.
        input_ids: Token ids without a batch dimension.
        layer_idx: Unused.
        print_context: Unused.
        return_score: When true, also return the stacked per-layer scores.

    Returns:
        ``input_ids.unsqueeze(0)``, or the tuple
        ``(input_ids.unsqueeze(0), stacked_scores)`` when ``return_score`` is
        true, where ``stacked_scores`` stacks the per-layer tensors on dim 0.
    """
    # Iterate over the layers that are actually present. A fixed count of 32
    # raises IndexError on a model truncated by reduce_layer() or on any model
    # with fewer layers.
    attn_layers = [layer.self_attn.attention_socres for layer in model.model.layers]
    stacked_scores = torch.stack(attn_layers, dim=0)

    a = stacked_scores.mean(0)                     # averaged over layers
    b = torch.stack(attn_layers, dim=1).mean(0)    # averaged over the other axis

    select_num = 8
    # Clamp k so topk() does not fail when fewer than select_num entries exist.
    top_heads = torch.topk(a, min(select_num, a.shape[-1]))
    top_layers = torch.topk(b, min(select_num, b.shape[-1]))
    import_heads = top_heads.values.sum() / a.sum()
    import_layers = top_layers.values.sum() / b.sum()
    print(f'select_num:{select_num}, layers:{b.shape[0]}, heads:{a.shape[0]}.')
    print(f'score:{import_layers}. import_layers:', top_layers.indices)
    print(f'score:{import_heads}. import_heads:', top_heads.indices)

    if return_score:
        return input_ids.unsqueeze(0), stacked_scores
    return input_ids.unsqueeze(0)

def reduce_layer(model, layer_idx):
    """Truncate the decoder stack to layers ``0..layer_idx`` (inclusive).

    The model is modified in place. The untruncated layer container is
    returned alongside it so the caller can restore it afterwards.

    Returns:
        ``(model, original_layers)``.
    """
    inner = model.model
    full_stack = inner.layers
    inner.layers = full_stack[:layer_idx + 1]
    return model, full_stack

def recover_layer(model, layers):
    """Install ``layers`` as the model's decoder stack and return the model.

    This is the counterpart of ``reduce_layer``: pass it the original layer
    container that ``reduce_layer`` handed back.
    """
    inner = model.model
    inner.layers = layers
    return model

def set_topk(model, topk, mode='gemfilter'):
    """Write ``topk`` into the per-layer attribute that belongs to ``mode``.

    * ``'gemfilter'``: ``self_attn.topk = topk``
    * ``'snapkv'``: ``self_attn.config.max_capacity_prompt = topk``
    * ``'h2o'``: ``kv_cache.cache_max_size = topk`` and
      ``kv_cache.hh_size = topk - kv_cache.recent_size``

    Raises:
        NotImplementedError: for any other ``mode``. The check happens per
            layer, so nothing is raised when the model has no layers.
    """
    for layer in model.model.layers:
        if mode == 'gemfilter':
            layer.self_attn.topk = topk
            continue
        if mode == 'snapkv':
            layer.self_attn.config.max_capacity_prompt = topk
            continue
        if mode == 'h2o':
            cache = layer.self_attn.kv_cache
            cache.hh_size = topk - cache.recent_size
            cache.cache_max_size = topk
            continue
        raise NotImplementedError

def set_select_mode(model, mode):
    """Assign ``mode`` to ``self_attn.select_mode`` on every decoder layer."""
    for layer in model.model.layers:
        layer.self_attn.select_mode = mode

def set_select_layer(model, select_layer_idx):
    """Set, or look up, the ``select_layer_idx`` stored on the attention modules.

    With ``select_layer_idx=None`` nothing is written and the value currently
    stored on layer 0 is returned. Otherwise the given index is written to
    every decoder layer and returned unchanged.
    """
    if select_layer_idx is None:
        return model.model.layers[0].self_attn.select_layer_idx

    for layer in model.model.layers:
        layer.self_attn.select_layer_idx = select_layer_idx
    return select_layer_idx

def set_config(model, config):
    """Assign the same ``config`` object to ``self_attn.config`` on every layer."""
    for layer in model.model.layers:
        layer.self_attn.config = config

def set_probe_context(model, probe_context):
    """Store ``probe_context`` on every layer's attention module.

    A ``None`` probe context is a no-op: the model is not touched and ``None``
    is returned. Otherwise ``probe_context`` is written to
    ``self_attn.probe_context`` on each decoder layer and returned.
    """
    if probe_context is None:
        return None

    for layer in model.model.layers:
        layer.self_attn.probe_context = probe_context
    return probe_context

@torch.no_grad()
def my_greedy_generate(model, tokenizer, pred_token_idx, past_key_values, max_gen_len):
    """Greedy decoding loop that continues from an already-predicted token.

    The returned list starts with ``pred_token_idx`` itself and then holds up
    to ``max_gen_len`` further tokens, each the argmax of the last-position
    logits. Decoding stops early when a newly predicted token equals
    ``tokenizer.eos_token_id``; that token is not appended. The initial
    ``pred_token_idx`` is not checked against the EOS id.

    Args:
        model: Callable taking ``input_ids`` and ``past_key_values`` keywords
            and returning an object with ``logits`` and ``past_key_values``.
        tokenizer: Provides ``eos_token_id``.
        pred_token_idx: Single-element tensor with the first generated token.
        past_key_values: Cache passed to the first decoding step.
        max_gen_len: Maximum number of decoding steps.

    Returns:
        List of generated token ids as Python ints.
    """
    current_token = pred_token_idx
    cache = past_key_values
    tokens = [current_token.item()]

    for _ in range(max_gen_len):
        step = model(input_ids=current_token, past_key_values=cache)
        cache = step.past_key_values
        last_logits = step.logits[:, -1, :]
        current_token = last_logits.argmax(dim=-1).unsqueeze(1)
        if current_token == tokenizer.eos_token_id:
            break
        tokens.append(current_token.item())

    return tokens


@torch.no_grad()
def my_greedy_generate_selection(input_ids, attn_mask, model, tokenizer, max_gen_len=50, select_layer_idx=None, print_context=False):
    """Generate greedily from a prompt reduced via one layer's selected tokens.

    Steps:
      1. Turn select mode on and truncate the model to layers
         ``0..select_layer_idx`` (``None`` means: use the index already
         stored on layer 0).
      2. Run the truncated model on the full prompt, then build the reduced
         prompt with ``get_layer_context``.
      3. Restore the full layer stack, turn select mode off, prefill the full
         model on the reduced prompt and decode greedily.

    Args:
        input_ids: Batched prompt token ids; only row 0 is used for selection.
        attn_mask: Attention mask forwarded to both forward passes.
        model: Model whose ``model.layers`` can be truncated and restored.
        tokenizer: Used to decode the response and the reduced prompt.
        max_gen_len: Maximum number of decoding steps after the first token.
        select_layer_idx: Layer used for selection, or ``None``.
        print_context: When true, the reduced prompt is printed.

    Returns:
        ``(response, reduced_prompt_text)``.
    """
    set_select_mode(model, True)
    select_layer_idx = set_select_layer(model, select_layer_idx)
    model, original_layers = reduce_layer(model, select_layer_idx)
    try:
        model(input_ids, attention_mask=attn_mask, output_last_logits_only=True)
        new_input_ids = get_layer_context(
            model, tokenizer, input_ids[0], select_layer_idx, print_context=print_context)
    finally:
        # Always undo the truncation and select mode. Otherwise an exception in
        # the filter pass (e.g. out-of-memory) leaves the shared model cut down
        # to the first layers for every later call.
        model = recover_layer(model, original_layers)
        set_select_mode(model, False)

    # NOTE(review): attn_mask was built for the full-length prompt but is
    # reused for the shorter new_input_ids - confirm the model tolerates that.
    outputs = model(new_input_ids, attention_mask=attn_mask, output_last_logits_only=True)

    past_key_values = outputs.past_key_values
    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)

    output_ids = my_greedy_generate(model, tokenizer, pred_token_idx, past_key_values, max_gen_len=max_gen_len)
    response = tokenizer.decode(output_ids, skip_special_tokens=True).strip()
    return response, tokenizer.decode(new_input_ids[0])

@torch.no_grad()
def my_greedy_generate_config(input_ids, attn_mask, model, tokenizer, config, max_gen_len=50, print_context=False):
    """Config-driven variant of selection-based generation with timing.

    ``config`` is installed on every attention module and ``config['layer']``
    selects the layer used for the filter pass. The filter pass runs on the
    model truncated to that layer, the reduced prompt comes from
    ``get_layer_context``, and the restored full model is then prefilled on
    the reduced prompt and decoded greedily.

    Timing (wall clock, ``time.time``):
      * prefill   = filter pass + context building + second prefill
      * ``TTFT``    = prefill + duration of one extra single-step decode
      * ``Latency`` = prefill + duration of the full decode loop
      * ``TPOT``    = decode-loop duration / number of generated tokens

    NOTE(review): no device synchronization happens before the timestamps, so
    on an asynchronous backend the split between prefill and decode may be
    skewed - confirm whether ``torch.cuda.synchronize()`` is wanted here.

    Args:
        input_ids: Batched prompt token ids; only row 0 is used for selection.
        attn_mask: Attention mask forwarded to both forward passes.
        model: Model whose ``model.layers`` can be truncated and restored.
        tokenizer: Used to decode the response and the reduced prompt.
        config: Mapping with at least a ``'layer'`` entry.
        max_gen_len: Maximum number of decoding steps after the first token.
        print_context: When true, the reduced prompt is printed.

    Returns:
        ``(response, reduced_prompt_text, time_dict)`` where ``time_dict``
        has the keys ``'TTFT'``, ``'TPOT'`` and ``'Latency'``.
    """
    set_select_mode(model, True)
    set_config(model, config)
    model, original_layers = reduce_layer(model, config['layer'])

    t_start = time.time()
    try:
        model(input_ids, attention_mask=attn_mask, output_last_logits_only=True)
        new_input_ids = get_layer_context(
            model, tokenizer, input_ids[0], config['layer'], print_context=print_context)
    finally:
        # Always undo the truncation and select mode so a failed filter pass
        # cannot leave the shared model cut down to its first layers.
        model = recover_layer(model, original_layers)
        set_select_mode(model, False)

    # NOTE(review): attn_mask was built for the full-length prompt but is
    # reused for the shorter new_input_ids - confirm the model tolerates that.
    outputs = model(new_input_ids, attention_mask=attn_mask, output_last_logits_only=True)
    t_prefill_done = time.time()

    past_key_values = outputs.past_key_values
    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)

    output_ids = my_greedy_generate(model, tokenizer, pred_token_idx, past_key_values, max_gen_len=max_gen_len)
    response = tokenizer.decode(output_ids, skip_special_tokens=True).strip()
    t_decode_done = time.time()

    # One extra single-step decode, timed separately; its output is discarded.
    my_greedy_generate(model, tokenizer, pred_token_idx, past_key_values, max_gen_len=1)
    t_single_step_done = time.time()

    pre_fill_com_time = t_prefill_done - t_start
    # Decode time starts when the second prefill finishes. Measuring it from
    # the end of the filter pass would count the second prefill twice in
    # Latency and inflate TPOT.
    decode_time = t_decode_done - t_prefill_done
    TTFT = pre_fill_com_time + t_single_step_done - t_decode_done
    Latency = pre_fill_com_time + decode_time
    TPOT = decode_time / len(output_ids)
    time_dict = {'TTFT': TTFT, 'TPOT': TPOT, 'Latency': Latency}

    return response, tokenizer.decode(new_input_ids[0]), time_dict



@torch.no_grad()
def my_greedy_generate_with_probe(input_ids, attn_mask, model, tokenizer, max_gen_len=50, select_layer_idx=None, print_context=False, probe_context=None):
    """Greedy generation that also collects per-layer attention scores.

    The model is truncated to layers ``0..select_layer_idx`` and run once on
    the prompt with select mode on. ``get_layer_context_arocss`` then reports
    the attention-score statistics and returns the prompt unchanged. After the
    full layer stack is restored, the model is prefilled and decoded greedily.

    Args:
        input_ids: Batched prompt token ids; only row 0 is passed on.
        attn_mask: Attention mask forwarded to both forward passes.
        model: Model whose ``model.layers`` can be truncated and restored.
        tokenizer: Used to decode the response and the prompt.
        max_gen_len: Maximum number of decoding steps after the first token.
        select_layer_idx: Last layer kept for the first pass, or ``None`` to
            use the index already stored on layer 0.
        print_context: Forwarded to ``get_layer_context_arocss``.
        probe_context: Stored on every attention module when not ``None``;
            also controls whether attention scores are returned.

    Returns:
        ``(response, prompt_text, att_scores)``; ``att_scores`` is ``None``
        when ``probe_context`` is ``None``.
    """
    set_select_mode(model, True)
    set_probe_context(model, probe_context)
    select_layer_idx = set_select_layer(model, select_layer_idx)
    model, original_layers = reduce_layer(model, select_layer_idx)

    return_score = probe_context is not None
    try:
        model(input_ids, attention_mask=attn_mask, output_last_logits_only=True)
        context = get_layer_context_arocss(
            model, tokenizer, input_ids[0], select_layer_idx, print_context=print_context, return_score=return_score)
    finally:
        # Always undo the truncation and select mode so a failed first pass
        # cannot leave the shared model cut down to its first layers.
        model = recover_layer(model, original_layers)
        set_select_mode(model, False)

    # The helper returns a (ids, scores) tuple only when scores were requested.
    # Unpacking its single-tensor result into two names would raise ValueError.
    if return_score:
        new_input_ids, att_scores = context
    else:
        new_input_ids, att_scores = context, None

    outputs = model(new_input_ids, attention_mask=attn_mask, output_last_logits_only=True)

    past_key_values = outputs.past_key_values
    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)

    output_ids = my_greedy_generate(model, tokenizer, pred_token_idx, past_key_values, max_gen_len=max_gen_len)
    response = tokenizer.decode(output_ids, skip_special_tokens=True).strip()
    return response, tokenizer.decode(new_input_ids[0]), att_scores


@torch.no_grad()
def my_greedy_generate_standard(input_ids, attn_mask, model, tokenizer, max_gen_len=50):
    """Plain greedy generation: one prefill pass, then token-by-token decoding.

    Args:
        input_ids: Batched prompt token ids.
        attn_mask: Attention mask forwarded to the prefill pass.
        model: Model called for the prefill and by ``my_greedy_generate``.
        tokenizer: Used to decode the generated ids.
        max_gen_len: Maximum number of decoding steps after the first token.

    Returns:
        The decoded response with special tokens skipped and surrounding
        whitespace stripped.
    """
    prefill = model(input_ids, attention_mask=attn_mask, output_last_logits_only=True)
    cache = prefill.past_key_values

    # The first generated token is the argmax at the last prompt position.
    last_logits = prefill.logits[:, -1, :]
    first_token = last_logits.argmax(dim=-1).unsqueeze(1)

    token_ids = my_greedy_generate(model, tokenizer, first_token, cache, max_gen_len=max_gen_len)
    text = tokenizer.decode(token_ids, skip_special_tokens=True)
    return text.strip()

@torch.no_grad()
def my_greedy_generate_standard_time(input_ids, attn_mask, model, tokenizer, max_gen_len=50):
    """Plain greedy generation that also reports wall-clock timing.

    Timing (``time.time``):
      * prefill   = prefill forward pass plus picking the first token
      * ``TTFT``    = prefill + duration of one extra single-step decode
      * ``Latency`` = prefill + duration of the full decode loop
      * ``TPOT``    = decode-loop duration / number of generated tokens

    Returns:
        ``(response, time_dict)`` where ``time_dict`` has the keys ``'TTFT'``,
        ``'TPOT'`` and ``'Latency'``.
    """
    started = time.time()
    prefill = model(input_ids, attention_mask=attn_mask, output_last_logits_only=True)

    cache = prefill.past_key_values
    last_logits = prefill.logits[:, -1, :]
    first_token = last_logits.argmax(dim=-1).unsqueeze(1)
    prefill_done = time.time()

    token_ids = my_greedy_generate(model, tokenizer, first_token, cache, max_gen_len=max_gen_len)
    decode_done = time.time()

    # One extra single-step decode, timed separately; its output is discarded.
    my_greedy_generate(model, tokenizer, first_token, cache, max_gen_len=1)
    single_step_done = time.time()

    response = tokenizer.decode(token_ids, skip_special_tokens=True).strip()

    pre_fill_time = prefill_done - started
    time_dict = {
        'TTFT': pre_fill_time + single_step_done - decode_done,
        'TPOT': (decode_done - prefill_done) / len(token_ids),
        'Latency': pre_fill_time + decode_done - prefill_done,
    }
    return response, time_dict

@torch.no_grad()
def origin_generate(input_ids, attn_mask, model, tokenizer, max_gen_len=50):
    """Generate with the model's own ``generate`` method (greedy settings).

    ``generate`` is called with ``num_beams=1``, ``do_sample=False`` and
    ``temperature=1.0``. Only the first returned sequence is used, and the
    first ``input_ids.shape[-1]`` entries of it are dropped before decoding.

    Returns:
        The decoded continuation with special tokens skipped and surrounding
        whitespace stripped.
    """
    sequence = model.generate(
        input_ids=input_ids,
        attention_mask=attn_mask,
        max_new_tokens=max_gen_len,
        num_beams=1,
        do_sample=False,
        temperature=1.0,
    )[0]

    prompt_len = input_ids.shape[-1]
    continuation = sequence[prompt_len:]
    return tokenizer.decode(continuation, skip_special_tokens=True).strip()
