from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import pandas as pd
from tqdm import tqdm
from time import time
import math
import torch.nn.functional as F

def sampling(model, tokenizer, prompts, answers, encoding, max_sampling, conf_threshold, sampling_params, dataset_handler):
    """Draw ``max_sampling`` generations per prompt and record an anchor confidence.

    Each round calls ``model.generate`` once on the whole batch, decodes the
    newly generated tokens, extracts an answer from every response and reads a
    confidence value from the logits of the step that emitted the ``<ANC>``
    token.

    Args:
        model: Object exposing ``generate(**encoding, **sampling_params, ...)``
            whose result can be indexed with ``'sequences'`` and ``'logits'``.
        tokenizer: Object exposing ``convert_tokens_to_ids`` and
            ``batch_decode``.
        prompts: Prompt strings, one per batch row.
        answers: Ground-truth answers, paired with ``prompts`` by position.
        encoding: Mapping passed to ``generate``; must contain ``'input_ids'``.
        max_sampling: Number of generation rounds to run.
        conf_threshold: Accepted for interface compatibility; not used here.
        sampling_params: Extra keyword arguments forwarded to ``generate``.
        dataset_handler: Object exposing ``extract_answer(response)``.

    Returns:
        A list with one dict per prompt: ``{"prompt", "gt_answer", "samples"}``.
        Every entry of ``samples`` is ``{"response", "answer", "confidence"}``;
        ``confidence`` is ``-1`` when the response contains no anchor token.
        An empty ``prompts`` yields an empty list without calling the model.
    """
    # Initialize storage for each prompt.
    results = [
        {"prompt": prompt, "gt_answer": answer, "samples": []}
        for prompt, answer in zip(prompts, answers)
    ]
    batch_size = len(prompts)
    if batch_size == 0:
        return results

    anchor_token = '<ANC>'
    # Kept as a plain id so the comparison below works on whatever device the
    # generated ids live on (no hard-coded 'cuda').
    anchor_id = tokenizer.convert_tokens_to_ids(anchor_token)
    input_length = encoding['input_ids'].size(1)

    for _ in range(max_sampling):
        start_time = time()
        with torch.no_grad():
            generated_outputs = model.generate(**encoding, **sampling_params, return_dict_in_generate=True, output_logits=True)
        generated_ids = generated_outputs['sequences']
        completion_ids = generated_ids[:, input_length:]
        generated_texts = tokenizer.batch_decode(
            completion_ids,
            skip_special_tokens=True)

        model_answers = [dataset_handler.extract_answer(response) for response in generated_texts]
        # Indexed below as generated_logits[step][row].
        generated_logits = generated_outputs['logits']

        # Search only the generated continuation: an anchor inside the prompt
        # has no generation step and would otherwise give a negative index.
        anchor_rows, anchor_steps = torch.where(completion_ids == anchor_id)
        first_anchor_step = {}
        for row, step in zip(anchor_rows.tolist(), anchor_steps.tolist()):
            # Indices come back in row-major order, so the first hit per row
            # is its earliest anchor; later anchors in the same row are ignored.
            first_anchor_step.setdefault(row, step)

        for item in range(batch_size):
            step = first_anchor_step.get(item)
            if step is None:
                pred_conf = -1
            else:
                pred_logits = generated_logits[step][item]
                # Probability of the last vocabulary entry at the anchor step
                # (presumably '<ANC>' is the last token id -- TODO confirm).
                pred_conf = F.softmax(pred_logits, dim=-1)[-1].item()

            results[item]['samples'].append({
                "response": generated_texts[item],
                "answer": model_answers[item],
                "confidence": pred_conf
            })
        total_time = time() - start_time
        print(f'time used for one time sapling: {total_time}')

    return results


def sample_noconf(model, tokenizer, prompts, answers, encoding, max_sampling, conf_threshold, sampling_params, dataset_handler):
    """Draw ``max_sampling`` generations per prompt without any confidence score.

    Each round calls ``model.generate`` once on the whole batch, decodes the
    tokens that follow the prompt and extracts an answer from every response.

    Args:
        model: Object exposing ``generate(**encoding, **sampling_params)`` that
            returns the generated id tensor (prompt plus continuation).
        tokenizer: Object exposing ``batch_decode``.
        prompts: Prompt strings, one per batch row.
        answers: Ground-truth answers, paired with ``prompts`` by position.
        encoding: Mapping passed to ``generate``; must contain ``'input_ids'``.
        max_sampling: Number of generation rounds to run.
        conf_threshold: Accepted for interface compatibility; not used here.
        sampling_params: Extra keyword arguments forwarded to ``generate``.
        dataset_handler: Object exposing ``extract_answer(response)``.

    Returns:
        A list with one dict per prompt: ``{"prompt", "gt_answer", "samples"}``.
        Every entry of ``samples`` is ``{"response", "answer"}``. An empty
        ``prompts`` yields an empty list without calling the model.
    """
    # Initialize storage for each prompt.
    results = [
        {"prompt": prompt, "gt_answer": answer, "samples": []}
        for prompt, answer in zip(prompts, answers)
    ]
    batch_size = len(prompts)
    if batch_size == 0:
        return results

    input_length = encoding['input_ids'].size(1)
    for _ in range(max_sampling):
        start_time = time()
        with torch.no_grad():
            generated_ids = model.generate(**encoding, **sampling_params)
        # Drop the prompt tokens so only the continuation is decoded.
        generated_texts = tokenizer.batch_decode(
            generated_ids[:, input_length:],
            skip_special_tokens=True)

        model_answers = [dataset_handler.extract_answer(response) for response in generated_texts]
        for item in range(batch_size):
            results[item]['samples'].append({
                "response": generated_texts[item],
                "answer": model_answers[item],
            })
        total_time = time() - start_time
        print(f'time used for one time sapling: {total_time}')

    return results
