import sys
import numpy as np
from copy import deepcopy
from typing import Callable, List, Any
from collections import defaultdict
from transformers import PreTrainedTokenizer

from search.decoding import residual_mppi_step, SearchResult
from search.utils import compute_divergence

def amulet_game(log_ref, log_player, iteration_num=60, alpha=2.0, player_lambda=2.0, eta=10.0):
    """Iteratively refine the last-position scores of ``log_player`` against ``log_ref``.

    Both inputs are indexed as ``[:, -1, :]``, i.e. they are rank-3 arrays laid out as
    ``(batch, position, token)``; only the last position is updated by the game.

    Each iteration ``t`` (1-based):

    1. stores ``Q[t] = alpha * (log_player[last] - log_ref[last])``;
    2. replaces the last position of ``log_player`` with::

           (player_lambda * initial + sum(Q) / t + previous / (eta * t))
           / (player_lambda + 1 / (t * eta))

       where ``initial`` is the last position of the input and ``previous`` is the
       (renormalised) last position recorded after the preceding iteration;
    3. unless this is the final iteration, renormalises the whole array with a
       log-softmax over the token axis.

    Args:
        log_ref: Reference scores; only ``log_ref[:, -1, :]`` is read.
        log_player: Initial player scores. It is not modified.
        iteration_num: Number of update iterations. With ``0`` the input values are
            returned unchanged.
        alpha: Scale applied to the player/reference gap when filling ``Q``.
        player_lambda: Weight of the initial player scores in every update.
        eta: Scale of the pull towards the previous iterate; its weight is ``1 / (eta * t)``.

    Returns:
        A new array shaped like ``log_player``. The last position holds the result of the
        final update, which is *not* renormalised; other positions have been passed through
        the log-softmax of every earlier iteration.
    """
    # Work on a private copy: the in-place update below previously leaked into the caller's
    # array (and into ``log_ref`` as well when both arguments aliased the same buffer).
    log_player = np.array(log_player)
    bsz = log_player.shape[0]
    token_size = log_player.shape[-1]

    initial_last = log_player[:, -1, :].copy()

    # Row 0 of Q stays zero; rows 1..iteration_num are filled as the game proceeds.
    Q = np.zeros((bsz, iteration_num + 1, token_size))
    log_player_mem = np.zeros((bsz, iteration_num + 1, token_size))
    log_player_mem[:, 0] = initial_last

    for cur_iter in range(1, iteration_num + 1):
        Q[:, cur_iter] = alpha * (log_player[:, -1, :] - log_ref[:, -1, :])

        # Unfilled rows of Q are zero, so summing the whole axis equals summing rows 1..cur_iter.
        numerator = (
            player_lambda * initial_last
            + np.sum(Q, axis=1) / cur_iter
            + log_player_mem[:, cur_iter - 1] / (eta * cur_iter)
        )
        log_player[:, -1, :] = numerator / (player_lambda + 1 / (cur_iter * eta))

        if cur_iter == iteration_num:
            # The final iterate is returned as-is, without renormalisation.
            return log_player

        # Log-softmax over the token axis; subtracting the max keeps exp() from overflowing.
        exp_logits = np.exp(log_player - np.max(log_player, axis=-1, keepdims=True))
        softmax_probs = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
        log_player = np.log(softmax_probs)
        log_player_mem[:, cur_iter] = log_player[:, -1, :]

    # Only reached when iteration_num <= 0.
    return log_player

def amulet(llm: Any, prompt: str, sample: Any, verifier: Callable[[str, List[str]], List[List[float]]], tokenizer: PreTrainedTokenizer, args: Any) -> SearchResult:
    """Decode one token at a time, choosing each token via ``amulet_game``.

    At every step the model is queried twice for its top-20 next-token candidates: once on
    ``prompt`` plus the text generated so far, and once on ``sample["_non_pref_prompt"]`` plus
    the same generated text. The candidates present in both lists are scored with
    ``amulet_game`` and the highest-scoring one is appended. When fewer than five candidates
    are shared, the first candidate of the ``prompt`` query is appended instead.

    Generation stops when either query returns no candidates, when the chosen token equals
    ``tokenizer.eos_token``, when the ``prompt`` query reports ``completed``, or after
    ``args.max_tokens`` steps.

    Args:
        llm: Handle forwarded to ``residual_mppi_step``.
        prompt: Prompt used for the first query of each step.
        sample: Mapping that provides the ``"_non_pref_prompt"`` string.
        verifier: Unused by this function.
        tokenizer: Only ``eos_token`` is read.
        args: Must expose ``temperature`` and ``max_tokens``.

    Returns:
        A ``SearchResult`` whose ``outputs``/``prediction`` hold the generated text (the chosen
        EOS token text included, if one was selected), with constant scores of ``1.0`` and the
        number of steps taken under ``metrics["step"]``.
    """
    current_text = prompt
    step = 0

    while True:
        step += 1

        output_state = residual_mppi_step(llm=llm, prompt=current_text, temperature=args.temperature, max_tokens=1, top_logprobs_num=20)
        output_res = output_state.meta_info.get("result", {})
        output_logprobs = output_res.get("output_top_logprobs", [])
        # An empty candidate list for the step used to fall through to
        # ``pref_token_texts[0]`` and raise IndexError; treat it as "nothing to decode".
        if not output_logprobs or not output_logprobs[0]:
            break

        # assumes each entry is (logprob, <unused>, token_text) — TODO confirm against
        # the structure returned by residual_mppi_step
        first_step_logprobs = output_logprobs[0]
        pref_logprobs = [entry[0] for entry in first_step_logprobs]  # logprobs of the top-20 tokens
        pref_token_texts = [entry[2] for entry in first_step_logprobs]  # texts of the top-20 tokens

        # Same generated continuation, but behind the alternative prompt.
        nonpref_prompt = sample["_non_pref_prompt"] + current_text[len(prompt):]

        nonpref_state = residual_mppi_step(llm=llm, prompt=nonpref_prompt, temperature=args.temperature, max_tokens=1, top_logprobs_num=20)
        output_nonpref_res = nonpref_state.meta_info.get("result", {})
        output_nonpref_logprobs = output_nonpref_res.get("output_top_logprobs", [])
        if not output_nonpref_logprobs:
            break

        first_step_nonpref_logprobs = output_nonpref_logprobs[0]
        nonpref_logprobs = [entry[0] for entry in first_step_nonpref_logprobs]  # logprobs of the top-20 tokens
        nonpref_token_texts = [entry[2] for entry in first_step_nonpref_logprobs]  # texts of the top-20 tokens

        # First position of every token text in each list (same result as list.index()).
        token_to_pref_idx = {}
        for idx, tok in enumerate(pref_token_texts):
            token_to_pref_idx.setdefault(tok, idx)
        token_to_nonpref_idx = {}
        for idx, tok in enumerate(nonpref_token_texts):
            token_to_nonpref_idx.setdefault(tok, idx)

        # Keep the shared tokens in the order of the first query rather than in set order:
        # set iteration order over strings changes between interpreter runs, which made the
        # argmax tie-break below non-reproducible.
        common_tokens = [tok for tok in token_to_pref_idx if tok in token_to_nonpref_idx]

        if len(common_tokens) < 5:
            # Too little overlap to play the game; fall back to the first candidate.
            max_token_text = pref_token_texts[0]
        else:
            pref_logprobs_arr = np.array([pref_logprobs[token_to_pref_idx[tok]] for tok in common_tokens]).reshape(1, 1, -1)
            nonpref_logprobs_arr = np.array([nonpref_logprobs[token_to_nonpref_idx[tok]] for tok in common_tokens]).reshape(1, 1, -1)

            # NOTE(review): the scores from ``prompt`` are passed as ``log_ref`` and the
            # ``_non_pref_prompt`` scores as ``log_player``, so the game starts from and
            # updates the non-preference distribution. Confirm this orientation is intended.
            log_player = amulet_game(pref_logprobs_arr, nonpref_logprobs_arr)
            max_token_id = int(np.argmax(log_player[0, 0]))
            max_token_text = common_tokens[max_token_id]

        current_text += max_token_text
        if max_token_text == tokenizer.eos_token or output_state.completed or step >= args.max_tokens:
            break

    response = current_text[len(prompt):]
    return SearchResult(
        outputs=[response],
        prediction=response,
        scores=[[1.0]],
        agg_scores=[1.0],
        metrics={"step": step},
    )
