import sys
import numpy as np
import sglang as sgl
import torch
import torch.nn.functional as F
from copy import deepcopy
from typing import Callable, List
from collections import defaultdict
from transformers import PreTrainedTokenizer
from sglang.lang.interpreter import ProgramState

from search.decoding import residual_mppi_step


def amulet_game(log_ref, log_player, iteration_num=60, alpha=2.0, player_lambda=2.0, eta=10.0):
    """Iteratively refine the player's last-position log-probs against a reference.

    At iteration ``t`` (1-based) the utility ``Q_t = alpha * (player - ref)`` is
    recorded. The player's last position is then replaced by::

        (player_lambda * player_0 + mean(Q_1..Q_t) + player_{t-1} / (eta * t))
        / (player_lambda + 1 / (eta * t))

    where ``player_0`` is the initial player and ``player_{t-1}`` the previous,
    log-softmax-normalised iterate.

    Args:
        log_ref: Array-like indexed as ``[batch, position, token]``. Only the
            last position ``[:, -1, :]`` is read. The caller in this file
            builds it with ``reshape(1, 1, -1)``.
        log_player: Array-like with the same layout; the player's initial
            log-probs. It is copied, never modified in place.
        iteration_num: Number of game iterations. ``0`` returns a copy of
            ``log_player``.
        alpha: Scale of the utility ``player - ref``.
        player_lambda: Weight anchoring each iterate to the initial player.
        eta: Controls the proximal weight ``1 / (eta * t)`` given to the
            previous iterate.

    Returns:
        A new array with the same shape as ``log_player``. The last position
        holds the final iterate. That iterate is returned before
        re-normalisation, which does not change its ranking.
    """
    # Private float copy: the caller's array must not be mutated, and
    # integer input would otherwise truncate the fractional update.
    log_player = np.array(log_player, copy=True)
    if not np.issubdtype(log_player.dtype, np.floating):
        log_player = log_player.astype(np.float64)
    log_ref = np.asarray(log_ref)

    bsz = log_player.shape[0]
    token_size = log_player.shape[-1]

    log_player_0 = log_player[:, -1, :].copy()
    # Q[:, 0] stays zero, so sum(Q[:, :t+1]) / t is the mean of Q_1..Q_t.
    Q = np.zeros((bsz, iteration_num + 1, token_size))
    # log_player_mem[:, t] is the normalised iterate after step t.
    log_player_mem = np.zeros((bsz, iteration_num + 1, token_size))
    log_player_mem[:, 0] = log_player_0

    for cur_iter in range(1, iteration_num + 1):
        # Utility Q_t of the current player relative to the reference.
        Q[:, cur_iter] = alpha * (log_player[:, -1, :] - log_ref[:, -1, :])

        # Mirror-descent style update, anchored to player_0 and to the previous iterate.
        mean_q = np.sum(Q, axis=1) / cur_iter
        log_player[:, -1, :] = (
            player_lambda * log_player_0
            + mean_q
            + log_player_mem[:, cur_iter - 1] / (eta * cur_iter)
        ) / (player_lambda + 1 / (cur_iter * eta))

        if cur_iter == iteration_num:
            return log_player

        # Numerically stable log-softmax: x - max - log(sum(exp(x - max))).
        # Unlike log(exp(.)/sum), this never underflows to -inf for very
        # unlikely tokens.
        shifted = log_player - np.max(log_player, axis=-1, keepdims=True)
        log_player = shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))

        # Remember the normalised iterate for the next proximal term.
        log_player_mem[:, cur_iter] = log_player[:, -1, :]

    return log_player


def _first_token_index(token_texts):
    """Map each distinct token text to the index of its first occurrence, in list order."""
    first = {}
    for idx, text in enumerate(token_texts):
        first.setdefault(text, idx)
    return first


@sgl.function
def amulet(s: ProgramState, prompt, sample, verifier: Callable[[str, List[str]], List[List[float]]], tokenizer: PreTrainedTokenizer, args):
    """Decode a response one token at a time, re-ranking each step with ``amulet_game``.

    Each step queries the top-20 next-token log-probs twice via
    ``residual_mppi_step``:

    * once conditioned on ``prompt`` plus the response so far;
    * once conditioned on ``sample["_non_pref_prompt"]`` plus the same response.

    If the two top-20 lists share at least 5 token texts, those shared tokens
    are scored with ``amulet_game`` and the arg-max is appended. Otherwise the
    top token of the ``prompt``-conditioned query is appended.

    Decoding stops when:

    * the appended text equals ``tokenizer.eos_token``;
    * the step state's ``"completed"`` value is truthy; or
    * ``args.max_tokens`` steps have run.

    Args:
        s: sglang program state the response is appended to.
        prompt: Prompt text; the response is everything in ``s`` after it.
        sample: Mapping that must provide ``"_non_pref_prompt"`` (a string).
        verifier: Not used by this function.
        tokenizer: Only ``tokenizer.eos_token`` is read.
        args: Reads ``args.temperature`` and ``args.max_tokens``.

    Side effects:
        Sets ``s["prediction"]``, ``s["outputs"]``, ``s["scores"]``,
        ``s["agg_scores"]`` and ``s["metric"]``.
    """
    s += prompt
    step = 0

    metrics = defaultdict(list)  # placeholder: nothing is recorded into it yet
    while True:
        step += 1

        prompt_input = s.text()
        generated = prompt_input[len(prompt):]  # response produced so far

        output_state = residual_mppi_step.run(prompt=prompt_input, temperature=args.temperature, max_tokens=1, top_logprobs_num=20)
        output_logprobs = output_state.get_meta_info("result")["output_top_logprobs"][0]
        pref_logprobs = [prob[0] for prob in output_logprobs]  # log-probs of the top-20 tokens
        pref_token_texts = [prob[2] for prob in output_logprobs]  # texts of the top-20 tokens

        # Condition the non-preference query on the SAME prefix. output_state.text()
        # also contains the token just sampled by the query above, which would shift
        # this distribution one position ahead of the one it is compared with.
        nonpref_prompt = sample["_non_pref_prompt"] + generated

        nonpref_state = residual_mppi_step.run(prompt=nonpref_prompt, temperature=args.temperature, max_tokens=1, top_logprobs_num=20)
        output_nonpref_logprobs = nonpref_state.get_meta_info("result")["output_top_logprobs"][0]
        nonpref_logprobs = [prob[0] for prob in output_nonpref_logprobs]  # log-probs of the top-20 tokens
        nonpref_token_texts = [prob[2] for prob in output_nonpref_logprobs]  # texts of the top-20 tokens

        # Shared token texts, ordered by preference-query rank. A set
        # intersection would make argmax tie-breaking depend on hash seed.
        pref_index = _first_token_index(pref_token_texts)
        nonpref_index = _first_token_index(nonpref_token_texts)
        common_tokens = [tok for tok in pref_index if tok in nonpref_index]

        if len(common_tokens) < 5:
            max_token_text = pref_token_texts[0]
        else:
            pref_arr = np.array([pref_logprobs[pref_index[tok]] for tok in common_tokens]).reshape(1, 1, -1)
            nonpref_arr = np.array([nonpref_logprobs[nonpref_index[tok]] for tok in common_tokens]).reshape(1, 1, -1)

            # NOTE(review): the preference-prompt log-probs are the reference and the
            # non-preference ones are the player's initial policy (same positional
            # order as before). amulet_game's utility is alpha * (player - ref), so
            # confirm this ordering is the intended direction.
            log_player = amulet_game(log_ref=pref_arr, log_player=nonpref_arr)
            max_token_text = common_tokens[int(np.argmax(log_player[0, 0]))]

        s += max_token_text
        if max_token_text == tokenizer.eos_token or output_state["completed"] or step >= args.max_tokens:
            break

    response = s.text()[len(prompt):]
    s["prediction"] = response
    s["outputs"] = [response]
    s["scores"] = [[1.0]]
    s["agg_scores"] = [1.0]
    s["metric"] = {
        "step": step,
        **{k: np.mean(v) for k, v in metrics.items()}
    }
