import pandas as pd
import sqlite3
import torch
import torch.nn.functional as F
import os
import re


def get_score(output, target_model, input_len):
    """Return the mean log-probability ``target_model`` assigns to the continuation in ``output``.

    The first ``input_len`` positions of ``output`` are the prompt. The
    log-probabilities of the later tokens are averaged. In the encoder-decoder
    branch, the first continuation token is used as the decoder start and is
    not itself scored.

    Args:
        output (torch.Tensor): integer token ids indexed as ``output[:, position]``
            (presumably (batch, seq_len) -- TODO confirm against callers).
        target_model: model with a ``config.is_encoder_decoder`` flag whose call
            returns an object exposing ``.logits`` (appears to be a Hugging Face
            model; verify).
        input_len (int): number of prompt tokens at the start of ``output``.

    Returns:
        torch.Tensor: 0-d tensor holding the mean token log-probability over the
        batch and the scored positions.
    """
    with torch.no_grad():
        if not target_model.config.is_encoder_decoder:
            # Decoder-only: logits at position t predict token t + 1, so drop the
            # last position and score tokens 1..end.
            log_probs = F.log_softmax(target_model(output).logits[:, :-1, :], dim=-1)
            token_log_probs = torch.gather(log_probs, dim=-1, index=output[:, 1:, None])
            if token_log_probs.isnan().any():
                # Diagnostic output when the model produced NaN log-probs.
                print(token_log_probs.size())
                print(token_log_probs)
            # Shifted index input_len - 1 scores output[:, input_len], the first
            # continuation token. Clamp at 0: with input_len == 0 a start of -1
            # would wrap around and score only the final token.
            start = max(input_len - 1, 0)
            return torch.mean(token_log_probs[:, start:, :])

        # Encoder-decoder: the prompt goes to the encoder and the continuation is
        # fed as decoder input; decoder position t predicts
        # output[:, input_len + t + 1].
        logits = target_model(output[:, :input_len], decoder_input_ids=output[:, input_len:]).logits
        log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
        token_log_probs = torch.gather(log_probs, dim=-1, index=output[:, input_len + 1:, None])
        return torch.mean(token_log_probs)


def top_k_top_p_filter(logits: torch.Tensor, top_k: int = 0, top_p: float = 0.0):
    """Mask logits outside the top-k and/or nucleus (top-p) set with ``-inf``.

    The input tensor is never modified. A new tensor is returned whenever a
    filter is applied; otherwise ``logits`` itself is returned.

    Args:
        logits (torch.Tensor): logits whose last dimension is the vocabulary,
            e.g. (batch, vocab).
        top_k (int, optional): keep only the ``top_k`` largest logits (values
            tied with the k-th largest are kept). ``0`` or ``None`` disables.
            Defaults to 0.
        top_p (float, optional): keep the shortest highest-probability prefix
            whose cumulative probability exceeds ``top_p``. Disabled when
            ``None`` or outside the open interval (0, 1). Defaults to 0.0.

    Returns:
        torch.Tensor: logits with filtered entries set to ``-inf``. They are not
        renormalized; apply a softmax to obtain probabilities.
    """
    if top_k is not None and top_k > 0:
        # Value of the k-th largest logit, kept with a trailing dim for broadcasting.
        kth_largest = torch.topk(logits, min(top_k, logits.size(-1)))[0][..., -1:]
        logits = logits.masked_fill(logits < kth_largest, float('-inf'))
    # top_p >= 1 means "keep everything". Skipping it avoids dropping tail tokens
    # when the float32 cumulative sum drifts slightly above 1.0.
    if top_p is not None and 0.0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_to_remove = cumulative_probs > top_p
        # Shift right so the token that first crosses the threshold is kept,
        # and always keep the most likely token.
        sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
        sorted_to_remove[..., 0] = False
        # Map the mask from sorted order back to vocabulary order.
        to_remove = sorted_to_remove.scatter(sorted_to_remove.dim() - 1, sorted_indices, sorted_to_remove)
        logits = logits.masked_fill(to_remove, float('-inf'))
    return logits


def norm_logits(logits: torch.Tensor, temperature: float, top_k: int, top_p: float) -> torch.Tensor:
    """Turn raw logits into a sampling distribution.

    Divides by ``temperature``, masks tokens with ``top_k_top_p_filter``, then
    applies a softmax over the last dimension.

    Args:
        logits (torch.Tensor): raw logits, normalized over the last dimension
            (presumably (batch, vocab) -- TODO confirm against callers).
        temperature (float): divisor applied to the logits. A value of 0
            produces non-finite values and triggers the error below.
        top_k (int): keep only the ``top_k`` largest logits; 0 or None disables.
        top_p (float): nucleus threshold; 0 or None disables.

    Returns:
        torch.Tensor: probabilities with the same shape as ``logits``; filtered
        entries get probability 0.

    Raises:
        RuntimeError: if the resulting probabilities contain NaN or inf.
    """
    logits = logits / temperature
    logits = top_k_top_p_filter(logits, top_k=top_k, top_p=top_p)
    probs = torch.log_softmax(logits, dim=-1).exp()
    # exp(log_softmax) lies in [0, 1] when well-defined, so the only possible
    # failure is a non-finite value from degenerate inputs.
    if not torch.isfinite(probs).all():
        print(torch.logical_not(logits.isinf()).any())
        print(logits[probs.isnan()])
        print(logits[probs.isinf()])
        raise RuntimeError('norm logits error')
    return probs


def _bracket_balance(text):
    """Return the count of opening minus closing brackets in ``text``.

    Naive: brackets inside string literals or comments are counted too.
    """
    return sum(text.count(ch) for ch in '([{') - sum(text.count(ch) for ch in ')]}')


def _code_ends_with_colon(line):
    """Return True if ``line`` ends with ':' once a trailing ``#`` comment is cut off.

    Naive: the cut happens at the first '#', even one inside a string literal.
    """
    return line.split('#', 1)[0].rstrip().endswith(':')


def extract_first_function(code_string):
    """Extract the first ``def`` (plus preceding imports) from Python source text.

    This is a line-based heuristic, not a parser. The result contains, in
    order:
      * every ``import``/``from`` line that appears before the function
        (other statements before it are dropped);
      * the decorators directly above the first ``def``/``async def``;
      * the (possibly multi-line) signature;
      * every following line indented deeper than the ``def``, with blank
        lines included.

    Extraction stops at the first non-blank line indented at or to the left of
    the ``def``.

    Args:
        code_string (str): source text, e.g. model-generated code.

    Returns:
        str: the collected lines joined with newlines. If no ``def`` is found,
        only the collected imports are returned.
    """
    lines = code_string.split('\n')
    n_lines = len(lines)
    function_lines = []
    pending_decorators = []

    # Phase 1: scan for the first def, keeping imports and the decorators right above it.
    i = 0
    while i < n_lines:
        line = lines[i]
        stripped = line.strip()
        if not stripped:
            i += 1
            continue
        if stripped.startswith(('from ', 'import ')):
            function_lines.append(line)
            i += 1
            continue
        if stripped.startswith('@'):
            # A decorator call may span several lines while its brackets are open.
            depth = _bracket_balance(line)
            pending_decorators.append(line)
            i += 1
            while depth > 0 and i < n_lines:
                pending_decorators.append(lines[i])
                depth += _bracket_balance(lines[i])
                i += 1
            continue
        if re.match(r'^\s*(?:async\s+)?def\s+\w+\s*[(\[]', line):
            break
        # Any other statement separates earlier decorators from the def.
        pending_decorators = []
        i += 1

    if i >= n_lines:
        # No function found: return only the collected imports.
        return '\n'.join(function_lines)

    def_line = lines[i]
    base_indent = len(def_line) - len(def_line.lstrip())
    function_lines.extend(pending_decorators)

    # Phase 2: the signature. It may wrap onto lines indented at or left of the
    # def (e.g. a closing "):" at column 0), so consume lines regardless of
    # indentation until the brackets balance or a line ends with ':'.
    depth = 0
    while i < n_lines:
        line = lines[i]
        function_lines.append(line)
        depth += _bracket_balance(line)
        i += 1
        if depth <= 0 or _code_ends_with_colon(line):
            break

    # Phase 3: the body -- every line indented deeper than the def; blank lines are kept.
    # NOTE: a line inside a multi-line string that starts at or left of the def's
    # indentation ends extraction early, because this scan does not tokenize.
    while i < n_lines:
        line = lines[i]
        if line.strip() and len(line) - len(line.lstrip()) <= base_indent:
            break
        function_lines.append(line)
        i += 1

    return '\n'.join(function_lines)




