import json
import numpy as np
import os
import pandas as pd
import random
from collections import Counter
from config import MATH_PROBE_FREQ
from tqdm import tqdm


def generate_perfect_predictions(data, interval, stable, probe_freq=None):
    """
    Generate the oracle ("perfect") prediction for every probe of every chain.

    Pass 1 visits, per (unique_id, chain_id) chain, the rows whose rounded
    token count is at least ``interval`` past the previously visited one (the
    last row is always visited). A visited row is labelled 1 when its answer
    equals the ground truth and the answers of up to ``stable`` rows starting
    at it are all identical; the earliest such token count per chain is kept
    as ``mst``.

    Pass 2 labels every row not labelled in pass 1: ``1 / (mst - t)`` when
    ``t < mst``, 1 when ``t >= mst``, and 0 when the chain has no ``mst``.

    Args:
        data: DataFrame with ``unique_id``, ``chain_id``, ``tokens``,
            ``curr_answer`` and ``answer`` columns.
        interval: Minimum token distance between two visited probes.
        stable: Number of consecutive rows (starting at the probe) whose
            answers must agree.
        probe_freq: Grid to which token counts are rounded up in pass 1.
            Defaults to ``MATH_PROBE_FREQ`` from ``config``.

    Returns:
        DataFrame with ``unique_id``, ``chain_id``, ``tokens`` and ``pred``.
    """
    if probe_freq is None:
        probe_freq = MATH_PROBE_FREQ

    df = data.sort_values(by=['unique_id', 'chain_id', 'tokens'])
    grb = df.groupby(by=['unique_id', 'chain_id'], sort=False)
    predict = dict()
    min_stable_t = dict()

    # Pass 1: label stable, correct probes with 1 and remember the earliest one.
    for (uid, cid), question_df in grb:
        tokens = np.ceil(question_df.tokens.values / probe_freq).astype(int) * probe_freq
        raw_answers = question_df.curr_answer.values
        if raw_answers.dtype.kind in 'biufc':
            # Numeric answers keep their dtype so numeric ground truths still
            # compare equal with ``a == gt``.
            answers = np.array(raw_answers)
        else:
            # Replace missing answers (None / NaN / NA) by the string 'nan' so
            # the array is homogeneous and np.unique can sort it. The former
            # check ``v != np.nan`` was always True and never replaced anything.
            answers = np.array(['nan' if pd.isna(v) else v for v in raw_answers])
        gt = question_df.answer.values[0]
        last_probe = 0
        last_idx = len(tokens) - 1
        for idx, (t, a) in enumerate(zip(tokens, answers)):
            if (t - last_probe >= interval) or (idx == last_idx):
                last_probe = t
                # ``answers[idx:][:stable]`` is the probe row plus the next
                # rows, at most ``stable`` in total.
                if a == gt and np.unique(answers[idx:][:stable]).shape[0] == 1:
                    predict[(uid, cid, t)] = 1
                    min_stable_t[(uid, cid)] = min(
                        t,
                        min_stable_t.get((uid, cid), np.inf)
                    )

    # Pass 2: fill every remaining row relative to the earliest stable probe.
    # NOTE(review): pass 1 keys use token counts rounded up to ``probe_freq``
    # while this pass uses the raw counts, so a row whose count is off the grid
    # can appear under both keys. Left as-is because downstream code may rely
    # on the raw counts — confirm the intended grid.
    for (uid, cid), question_df in grb:
        mst = min_stable_t.get((uid, cid))
        for t in question_df.tokens.values:
            key = (uid, cid, t)
            if key in predict:
                continue
            if mst is None:
                predict[key] = 0
            elif mst > t:
                predict[key] = 1 / (mst - t)
            else:
                predict[key] = 1

    return pd.DataFrame([
        {"unique_id": uid, "chain_id": cid, "tokens": t, "pred": p}
        for (uid, cid, t), p in predict.items()
    ])



def baseline(data, n_chains, resample_chains=True):
    """
    Majority-vote baseline over the final answers of ``n_chains`` chains.

    Only rows with ``type == 'final'`` are used. Per question, chain positions
    are drawn with replacement via ``random.choices`` when ``resample_chains``
    is True, otherwise the first ``n_chains`` positions are taken. The token
    cost is the sum of the selected chains' ``tokens``; the vote winner is the
    first of the most frequent answers in ``np.unique`` order.

    Returns:
        DataFrame with ``unique_id``, ``tokens``, ``accuracy`` (1 if the vote
        winner equals ``answer``), ``sampled_chains`` (selected positions) and
        ``sampled_votes`` (their answers).
    """
    finals = data[data.type == 'final']
    rows = []
    for uid, group in finals.groupby(by='unique_id'):
        chain_tokens = group.tokens.values
        chain_votes = group.curr_answer.values.tolist()
        candidates = list(range(len(chain_tokens)))
        if resample_chains:
            picked = random.choices(candidates, k=n_chains)
        else:
            picked = candidates[:n_chains]
        picked_votes = [chain_votes[pos] for pos in picked]
        labels, counts = np.unique(picked_votes, return_counts=True)
        winner = labels[np.argmax(counts)]
        rows.append({
            "unique_id": uid,
            "tokens": np.sum(chain_tokens[picked]),
            "accuracy": int(winner == group.answer.values[0]),
            "sampled_chains": picked,
            "sampled_votes": picked_votes,
        })
    return pd.DataFrame(rows)


def short_m(data, n_chains, m_frac, resample_chains=True):
    """
    Short-m@k baseline: majority vote over the ``m`` shortest sampled chains.

    Only rows with ``type == 'final'`` are used. Per question, ``n_chains``
    chains are drawn with replacement via ``random.choices`` when
    ``resample_chains`` is True, otherwise the first ``n_chains`` are taken
    (possibly fewer if the question has fewer chains). With
    ``m = max(1, min(int(m_frac * n_chains), n_chains))``, generation is
    modelled as stopping once the ``m`` shortest sampled chains finish: those
    chains are charged their own length and every other sampled chain is
    charged ``token_cap``, the length of the m-th shortest. All sampled chains
    no longer than ``token_cap`` (ties included) vote.

    Args:
        data: DataFrame with ``unique_id``, ``type``, ``tokens``,
            ``curr_answer`` and ``answer`` columns.
        n_chains: Number of chains sampled per question.
        m_frac: Fraction of ``n_chains`` that must finish before voting.
        resample_chains: Sample with replacement instead of taking the first
            ``n_chains`` chains.

    Returns:
        DataFrame with ``unique_id``, ``tokens`` and ``accuracy`` (1 if the
        majority vote equals ``answer``, else 0).
    """
    data_final = data[data.type == 'final']
    df = list()
    m = max(1, min(int(m_frac * n_chains), n_chains))
    for uid, question_df in data_final.groupby(by='unique_id'):
        gt = question_df.answer.values[0]
        tokens = question_df.tokens.values
        votes = question_df.curr_answer.values
        all_idxs = list(range(len(tokens)))
        if resample_chains:
            sampled_idxs = random.choices(all_idxs, k=n_chains)
        else:
            sampled_idxs = all_idxs[:n_chains]
        tokens_resampled = tokens[sampled_idxs]
        votes_resampled = votes[sampled_idxs]
        # Without resampling a question may have fewer than n_chains chains;
        # only the chains actually sampled may be charged for.
        n_sampled = len(sampled_idxs)
        m_eff = min(m, n_sampled)
        shortest = np.sort(tokens_resampled)[:m_eff]
        token_cap = np.amax(shortest)
        tokens_used = token_cap * (n_sampled - m_eff) + np.sum(shortest)
        sampled_votes = votes_resampled[tokens_resampled <= token_cap]
        unique_votes, counters = np.unique(sampled_votes, return_counts=True)
        consensus = unique_votes[np.argmax(counters)]
        is_correct = int(consensus == gt)
        df.append(
            {"unique_id": uid, "tokens": tokens_used, "accuracy": is_correct})

    return pd.DataFrame(df)


def _dynasor_early_exit(tokens, answers, interval, patience):
    """
    Return ``(token, answer)`` at the first probe where a stability counter
    reaches ``patience``, or ``None`` if it never does.

    Rows are visited from index 1 on; a row is probed only when its token
    count is at least ``interval`` past the last probed row. The counter
    starts at 1, is incremented when the probed answer equals the answer of
    the immediately preceding row and is reset to 1 otherwise.
    """
    last_t, counter = 0, 1
    for idx in range(1, tokens.shape[0]):
        if tokens[idx] - last_t < interval:
            continue
        last_t = tokens[idx]
        counter = counter + 1 if answers[idx] == answers[idx - 1] else 1
        if counter == patience:
            return tokens[idx], answers[idx]
    return None


def dynasor(data, n_chains, interval, patience, resample_chains=True, record_dir=None,
            probe_tokens=10):
    """
    Dynasor-style early exit followed by a majority vote.

    Each sampled chain stops at the probe found by ``_dynasor_early_exit``,
    or at its ``type == 'final'`` row if none is found. The question's answer
    is the most frequent stopping answer (first in ``np.unique`` order on
    ties). Each chain is charged its stopping token count plus
    ``probe_tokens * floor(tokens / interval)`` probing overhead.

    Args:
        data: DataFrame with ``unique_id``, ``chain_id``, ``tokens``,
            ``type``, ``curr_answer`` and ``answer`` columns.
        n_chains: Number of chains per question.
        interval: Minimum token distance between probes.
        patience: Stability counter value that triggers early exit.
        resample_chains: Sample ``n_chains`` chains per question with
            replacement (``DataFrame.sample``); otherwise use the ``n_chains``
            lowest chain ids, like ``duchess``.
        record_dir: If given, one JSON file per question is written there
            (chain ids are cast to ``int``).
        probe_tokens: Tokens charged per probe (default 10).

    Returns:
        DataFrame with ``unique_id``, ``tokens``, ``accuracy``,
        ``sampled_chains`` and ``sampled_votes``.
    """
    if record_dir is not None:
        os.makedirs(record_dir, exist_ok=True)
    subsamples = data[['unique_id', 'chain_id']].drop_duplicates()
    if resample_chains:
        subsamples_wdup = subsamples.groupby('unique_id', group_keys=False).apply(
            lambda x: x.sample(n=n_chains, replace=True)
        ).reset_index(drop=True)
    else:
        # Previously every chain was used here and ``n_chains`` was ignored.
        subsamples_wdup = (
            subsamples.sort_values(by=['unique_id', 'chain_id'])
            .groupby('unique_id', sort=False)
            .head(n_chains)
            .reset_index(drop=True)
        )
    subsamples = subsamples_wdup.drop_duplicates()
    subsamples = data.merge(subsamples, on=['unique_id', 'chain_id'], how='inner')
    max_tokens = subsamples[subsamples.type == 'final']
    df = list()
    for uid, question_df in subsamples.groupby(by='unique_id'):
        chain2tv = dict()
        for cid, chain_df in question_df.groupby(by='chain_id'):
            df2 = chain_df.sort_values(by='tokens', ascending=True)
            tv = _dynasor_early_exit(
                df2.tokens.values, df2.curr_answer.values, interval, patience
            )
            if tv is None:
                # No early exit: fall back to the chain's final row.
                end_stat = max_tokens[(max_tokens.unique_id == uid) & (max_tokens.chain_id == cid)]
                tv = end_stat.tokens.values[0], end_stat.curr_answer.values[0]
            chain2tv[cid] = tv
        votes, token_ct = list(), 0
        # Duplicated chains (from resampling) vote and are charged once each.
        sampled_chains = subsamples_wdup[subsamples_wdup.unique_id == uid].chain_id.values
        for sc in sampled_chains:
            token_usage_raw, vote = chain2tv[sc]
            votes.append(vote)
            token_ct += token_usage_raw + probe_tokens * np.floor(token_usage_raw / interval)
        unique_votes, counters = np.unique(votes, return_counts=True)
        consensus = unique_votes[np.argmax(counters)]
        is_correct = int(consensus == question_df.answer.values[0])
        df.append({"unique_id": uid, "tokens": token_ct, "accuracy": is_correct,
                   "sampled_chains": sampled_chains, "sampled_votes": votes})
        if record_dir is not None:
            # Built only when recording: ``int(sc)`` fails for non-integer ids.
            record_log = {
                "probe_interval": int(interval),
                "sampled_chains": [int(sc) for sc in sampled_chains],
                "max_tokens": [int(chain2tv[sc][0]) for sc in sampled_chains],
            }
            with open(os.path.join(record_dir, f"{uid}.json"), "w") as f:
                json.dump(record_log, f, indent=4)
    return pd.DataFrame(df)


def duchess(
    data_df,
    chains_to_use,
    threshold_high=0.9,
    threshold_low=0.1,
    warm_up = 2,
    patience_high=2,
    patience_low=2,
    interval=16,
    consensus_frac=0.6,
    voting_frac=1,
    branch_strategy='random',
    progress=True,
    resample_chains=True,
    record_dir=None,
    sample_lambda=1,
):
    """
    Simulate confidence-driven early termination and branching over
    ``chains_to_use`` chain slots per question, using precomputed probes.

    At each evaluated token budget, every slot's ``pred`` value updates two
    counters. One counts consecutive probes with ``pred >= threshold_high``;
    the other counts consecutive probes with ``pred <= threshold_low``.

    A slot terminates in one of three ways:
    - EOS: the budget reaches its chain's final token count. Its final answer
      votes unless ``str(answer) == 'Invalid'``.
    - High confidence: the high counter reaches ``patience_high``. Its current
      answer votes and 10 probing tokens are charged.
    - Low confidence: the low counter reaches ``patience_low``. It does not vote.

    The question stops at the first budget where one of these holds:
    - the number of votes reaches ``max(3, int(chains_to_use * voting_frac))``;
    - the majority answer's count reaches
      ``max(2, int(chains_to_use * consensus_frac))``;
    - no slot is left running.

    Otherwise, terminated slots are re-pointed ("branched") to a running
    slot's chain, copying its counters, according to ``branch_strategy``.

    Args:
        data_df: DataFrame with at least ``unique_id``, ``chain_id``,
            ``tokens``, ``type`` ('intermediate' / 'final'), ``curr_answer``,
            ``answer`` and ``pred`` columns. ``pred`` is read from the
            intermediate rows only.
        chains_to_use: Number of chains (and slots) per question.
        threshold_high: ``pred`` threshold for the high-confidence counter.
        threshold_low: ``pred`` threshold for the low-confidence counter.
        warm_up: Budgets ``<= warm_up * interval`` are skipped.
        patience_high: Consecutive high-confidence probes that terminate a slot.
        patience_low: Consecutive low-confidence probes that terminate a slot.
        interval: Only budgets that are multiples of ``interval`` are evaluated.
        consensus_frac: Fraction defining the consensus threshold above.
        voting_frac: Fraction defining the voting threshold above.
        branch_strategy: How a terminated slot picks a running slot to copy:
            - 'greedy': the running slot with the highest ``pred``;
            - 'random': uniformly at random;
            - 'greedy_prob': sampled with weights
              ``exp(pred ** (1 / sample_lambda))``.
            Any other value raises ``ValueError`` once a branch is needed.
        progress: Wrap the per-question loop in ``tqdm``.
        resample_chains: Draw chains with replacement via ``random.choices``.
            Otherwise use the ``chains_to_use`` lowest chain ids.
        record_dir: If given, a per-question JSON trace is written there.
        sample_lambda: Exponent parameter of the 'greedy_prob' strategy.

    Returns:
        DataFrame with one row per question and these columns:
        - ``unique_id``.
        - ``tokens``: termination budget * ``chains_to_use`` + probing tokens.
        - ``accuracy``: 1 if the majority vote equals ``answer``; 0 otherwise,
          including when nothing voted.
        - ``reason``: 'Voting', 'Consensus', 'Exhausted', or 'NA' when the loop
          ran out of budgets.
    """
    df = list()
    data = data_df.sort_values(by=['unique_id', 'chain_id', 'tokens'], ascending=True)
    # NOTE(review): rounds to a fixed grid of 16 regardless of ``interval``.
    data['token_budget'] = np.ceil(data['tokens'] / 16).astype(int) * 16  # TODO: parameterize probing frequency

    consensus_threshold = max(2, int(chains_to_use * consensus_frac))
    voting_threshold = max(3, int(chains_to_use * voting_frac))
    grb_q = data.groupby('unique_id')

    if record_dir is not None:
        os.makedirs(record_dir, exist_ok=True)
    
    if progress:
        iterator = tqdm(grb_q)
    else:
        iterator = grb_q

    for q_idx, question_df in iterator:
        if question_df.shape[0] == 0:
            continue

        gt = question_df.answer.values[0]
        
        # These are chain *ids* (possibly repeated when resampling); row i of
        # ans_mat / proba_mat below corresponds to chosen_chain_idxs[i].
        chain_ids = question_df.chain_id.unique().tolist()
        if resample_chains:
            chosen_chain_idxs = random.choices(chain_ids, k=chains_to_use)
        else:
            chosen_chain_idxs = sorted(chain_ids)[:chains_to_use]

        grb_chains = question_df.groupby(by='chain_id')
        ans_mat = list()
        proba_mat = list()
        max_tokens = list()
        eos_answers = list()
        token_budgets = None

        for chain_id in chosen_chain_idxs:
            for chain_id2, chain_df in grb_chains:
                if chain_id2 == chain_id:

                    chain_intermediate = chain_df[chain_df.type == 'intermediate']
                    curr_token_budgets = chain_intermediate.token_budget.values
                    curr_answers = chain_intermediate.curr_answer.values
                    predictions = chain_intermediate.pred.values

                    # Keep the budget axis of the chain with the most probes.
                    if (token_budgets is None) or (token_budgets.shape[0] < curr_token_budgets.shape[0]):
                        token_budgets = curr_token_budgets
                    
                    chain_final = chain_df[chain_df.type == 'final']
                    max_token = chain_final.tokens.values[0]
                    final_answer = chain_final.curr_answer.values[0]

                    ans_mat.append(curr_answers)
                    proba_mat.append(predictions)
                    max_tokens.append(max_token)
                    eos_answers.append(final_answer)
        
        # NOTE(review): the loop below iterates range(max_len), so this extra
        # budget is never visited there; it is only used as the cost when
        # reason == 'NA'. Chains whose final token count exceeds the last
        # intermediate budget may therefore never hit the EOS branch — confirm.
        if np.amax(token_budgets) < np.amax(max_tokens):
            token_budgets = np.append(token_budgets, np.amax(token_budgets) + interval)
        
        if record_dir is not None:
            record_log = {
                "unique_id": str(q_idx),
                "max_tokens": [int(m) for m in max_tokens],
                "warm_up_tokens": int(warm_up * interval),
                "iterations": list(),
            }
            last_token_budget = 0
        
        # Pad shorter chains to a common length: answers with the chain's final
        # answer, predictions with 0.5.
        max_len = np.amax([a.shape[0] for a in ans_mat])
        for i in range(len(ans_mat)):
            assert(ans_mat[i].shape[0] == proba_mat[i].shape[0])
            curr_len = ans_mat[i].shape[0]
            if curr_len < max_len:
                padding = np.full(max_len - curr_len, eos_answers[i], dtype=ans_mat[i].dtype)
                ans_mat[i] = np.concatenate([ans_mat[i], padding])

                padding = np.full(max_len - curr_len, 0.5, dtype=proba_mat[i].dtype)
                proba_mat[i] = np.concatenate([proba_mat[i], padding])
        
        ans_mat, proba_mat = np.array(ans_mat), np.array(proba_mat)
        
        # slot index -> row of ans_mat/proba_mat currently followed by the slot.
        # NOTE(review): with resample_chains=False and fewer chains than
        # chains_to_use, ans_mat has fewer rows than slots and indexing below
        # would raise IndexError.
        slot2chain = list(range(chains_to_use))
        # slot index -> (consecutive high-confidence, consecutive low-confidence)
        slot2counter = dict()

        final_answers = list()
        tb_at_termination = 0
        majority_ans, ans_counter = None, dict()
        probing_tokens = 0

        reason = 'NA'
        for iter_idx in range(max_len):

            t = token_budgets[iter_idx]
            if t % interval != 0:
                continue

            # Warm-up: no decisions for the first warm_up * interval tokens.
            if t <= warm_up * interval:
                continue

            if record_dir is not None:
                iter_log = {
                    "token_budget": int(t),
                    "step_sizes": list(),
                    "probed_chains": list(),
                    "branch_outs": list()
                }

            # Update each slot's consecutive high/low confidence counters.
            for slot_idx, chain_idx in enumerate(slot2chain):

                high_count, low_count = slot2counter.get(slot_idx, (0, 0))
                proba = proba_mat[chain_idx, iter_idx]

                if proba >= threshold_high:
                    high_count += 1
                else:
                    high_count = 0
                if proba <= threshold_low:
                    low_count += 1
                else:
                    low_count = 0

                slot2counter[slot_idx] = (high_count, low_count)

            terminated_slot_idx = list()
            remaining_slot_idx, remaining_slot_proba = list(), list()
            if record_dir is not None:
                branch_outs = list()
            for slot_idx, chain_idx in enumerate(slot2chain):
                chain_reason = None
                new_final_ans = None
                # EOS takes precedence over the confidence counters.
                if t >= max_tokens[chain_idx]:
                    terminated_slot_idx.append(slot_idx)
                    if str(eos_answers[chain_idx]) != 'Invalid':
                        new_final_ans = eos_answers[chain_idx]
                    chain_reason = 'EOS'
                else:
                    high_count, low_count = slot2counter.get(slot_idx, (0, 0))
                    if high_count >= patience_high:
                        terminated_slot_idx.append(slot_idx)
                        new_final_ans = ans_mat[chain_idx, iter_idx]
                        chain_reason = 'high confidence'
                        # Fixed cost charged for eliciting the answer.
                        probing_tokens += 10
                        if record_dir is not None:
                            iter_log["probed_chains"].append(slot_idx)
                    elif low_count >= patience_low:
                        # Discarded without a vote.
                        terminated_slot_idx.append(slot_idx)
                        chain_reason = 'low confidence'
                    else:
                        remaining_slot_idx.append(slot_idx)
                        remaining_slot_proba.append(proba_mat[chain_idx, iter_idx])
                
                # last_token_budget is only bound when record_dir is set.
                if record_dir is not None:
                    iter_log["step_sizes"].append(int(min(t - last_token_budget, max_tokens[chain_idx] - last_token_budget)))

                # Incrementally track vote counts and the current majority answer.
                if new_final_ans is not None:
                    new_final_ans_count = ans_counter.get(new_final_ans, 0) + 1
                    if new_final_ans_count > ans_counter.get(majority_ans, 0):
                        majority_ans = new_final_ans
                    ans_counter[new_final_ans] = new_final_ans_count
                    final_answers.append(new_final_ans)
                
                # Disabled debug output (kept as a no-op string expression).
                '''
                if chain_reason is not None:
                    print(f"Terminated chain {chain_idx} in slot {slot_idx}")
                    print(f"High count: {high_count}, Low count: {low_count}")
                    print(f"At token budget {t}")
                    print(f"Reason: {chain_reason}")
                '''
            
            # check if termination condition is met
            if len(final_answers) >= voting_threshold:
                tb_at_termination = t
                reason = 'Voting'
                break
            
            elif ans_counter.get(majority_ans, 0) >= consensus_threshold:
                tb_at_termination = t
                reason = 'Consensus'
                break
            
            elif len(terminated_slot_idx) > 0:

                if len(remaining_slot_idx) == 0:
                    tb_at_termination = t
                    reason = 'Exhausted'
                    break

                # branch-out: a terminated slot copies a running slot's chain
                # and counters, so both follow the same trajectory afterwards.
                if branch_strategy == 'greedy':
                    best_remaining_slot_idx = remaining_slot_idx[np.argmax(remaining_slot_proba)]
                    for ti in terminated_slot_idx:
                        slot2chain[ti] = slot2chain[best_remaining_slot_idx]
                        slot2counter[ti] = slot2counter[best_remaining_slot_idx]
                        if record_dir is not None:
                            iter_log["branch_outs"].append((int(ti), int(best_remaining_slot_idx)))
                elif branch_strategy == 'random':
                    for ti in terminated_slot_idx:
                        selection = random.choice(remaining_slot_idx)
                        slot2chain[ti] = slot2chain[selection]
                        slot2counter[ti] = slot2counter[selection]
                        if record_dir is not None:
                            iter_log["branch_outs"].append((int(ti), int(selection)))
                elif branch_strategy == 'greedy_prob':
                    # Softmax-like weights over the running slots' predictions.
                    weights = np.exp([
                        r ** (1 / sample_lambda) for r in remaining_slot_proba
                    ])
                    weights = weights / np.sum(weights)
                    
                    for ti in terminated_slot_idx:
                        selection = np.random.choice(remaining_slot_idx, p=weights)
                        slot2chain[ti] = slot2chain[selection]
                        slot2counter[ti] = slot2counter[selection]
                        if record_dir is not None:
                            iter_log["branch_outs"].append((int(ti), int(selection)))
                else:
                    raise ValueError(f'Invalid branch strategy: {branch_strategy}')
        
            if record_dir is not None:
                record_log["iterations"].append(iter_log)
                last_token_budget = t

        # No stop condition fired: charge the last (possibly appended) budget.
        if reason == 'NA':
            tb_at_termination = token_budgets[-1]

        # Every slot is charged the termination budget, plus probing overhead.
        q_log = {"unique_id": q_idx}
        q_log['tokens'] = tb_at_termination * chains_to_use + probing_tokens
        if len(final_answers) > 0:
            q_log['accuracy'] = int(majority_ans == gt)
        else:
            q_log['accuracy'] = 0
        q_log['reason'] = reason
        df.append(q_log)

        if record_dir is not None:
            with open(os.path.join(record_dir, f"{q_idx}.json"), "w") as f:
                json.dump(record_log, f, indent=4)
    
    return pd.DataFrame(df)
