import os
import json
from util import is_equiv, extract_math_answer
import random
import numpy as np
import math

try:
    import ujson as json
except ImportError:
    try:
        import simplejson as json
    except ImportError:
        import json

import numpy as np
import random
from collections import Counter
from tqdm import tqdm


class AdapEstimator:
    """Adaptive per-question sampling distribution driven by a difficulty metric.

    Every question keeps a history of difficulty values for ``metric``. For all
    metrics except ``'CS'`` the history is seeded with a feature computed from
    the question's ``input_metrics``; each sampled answer appends one more value
    (see :meth:`add_answer`). :meth:`get_distribution` turns the histories into
    normalized sampling weights over the questions that have not been popped.
    """

    # True  -> weight is the inverse of the (sharpened) mean difficulty, so
    #          larger metric values are sampled less often.
    # False -> weight grows with the (sharpened) mean difficulty.
    _REVERSED = {
        'Entropy': True,
        'VoG': False,
        'Length': True,
        'GradNorm': False,
        'CS': False,
    }

    def __init__(self, keys, gradnorm_data, metric):
        """Build the estimator.

        Args:
            keys: Question identifiers to allocate budget over.
            gradnorm_data: Mapping question -> record whose ``'input_metrics'``
                provides ``'each_token_gradient_norm'``, ``'entropy_loss'`` and
                ``'length'``.
            metric: One of ``'Entropy'``, ``'VoG'``, ``'Length'``,
                ``'GradNorm'`` or ``'CS'``.

        Raises:
            ValueError: If ``metric`` is not one of the supported names.
        """
        if metric not in self._REVERSED:
            raise ValueError(
                f"Unsupported metric {metric!r}; expected one of {sorted(self._REVERSED)}"
            )
        self.keys = keys
        self.metric = metric
        self.reversed = self._REVERSED[metric]

        # Per-question features computed from the prompt ("input") metrics.
        self.input_features = {}
        for question, data in gradnorm_data.items():
            metrics = data['input_metrics']
            # The final token's gradient norm is dropped before aggregating.
            input_grad_norm = np.array(metrics['each_token_gradient_norm'])[:-1]
            self.input_features[question] = {
                'VoG': input_grad_norm.var(),
                'GradNorm': input_grad_norm.mean(),
                'Entropy': abs(metrics['entropy_loss']),
                'Length': metrics['length'],
            }

        if metric == 'CS':
            # Consistency has no input-side feature; its score is derived from
            # the answers collected in all_selected_answers.
            self._difficulty_all = {question: [] for question in keys}
        else:
            # Seed each question's history with its input-side feature.
            self._difficulty_all = {
                question: [features[metric]]
                for question, features in self.input_features.items()
            }

        self._scores = {question: 0 for question in keys}
        self.all_selected_answers = {question: [] for question in keys}
        # Number of answers drawn so far per question (position in the sample order).
        self.gen_idx = {question: 0 for question in keys}

    def add_answer(self, question, answer, output_metrics, input_len):
        """Record one sampled answer for ``question`` and its difficulty value.

        Args:
            question: Question identifier (must not have been popped).
            answer: The sampled answer; kept for the ``'CS'`` majority vote.
            output_metrics: Metrics of the sampled generation.
            input_len: Prompt length; subtracted from ``output_metrics['length']``
                for the ``'Length'`` metric.
        """
        if self.metric == 'VoG':
            _difficulty = np.array(output_metrics['each_token_gradient_norm']).var()
        elif self.metric == 'GradNorm':
            _difficulty = np.array(output_metrics['each_token_gradient_norm']).mean()
        elif self.metric == 'Entropy':
            _difficulty = abs(output_metrics['entropy_loss'])
        elif self.metric == 'Length':
            _difficulty = np.array(output_metrics['length'] - input_len).mean()
        elif self.metric == 'CS':
            # CS ignores per-sample values; only the answers matter.
            _difficulty = 0
        else:
            _difficulty = np.array(output_metrics[self.metric]).mean()
        self._difficulty_all[question].append(_difficulty)
        self.all_selected_answers[question].append(answer)
        self.gen_idx[question] += 1

    def pop(self, question):
        """Stop tracking ``question`` so it no longer receives sampling weight.

        Popping an unknown or already-popped question is a no-op.
        """
        if question in self.all_selected_answers:
            self.all_selected_answers.pop(question)
            self._scores.pop(question)
            self._difficulty_all.pop(question)
            self.input_features.pop(question)

    def get_distribution(self):
        """Return normalized sampling weights for all active questions.

        For ``'CS'`` the weight is the share of the majority answer among the
        sampled answers, or 100 before any answer has been drawn. For the other
        metrics it is ``mean(history) ** (n_active / 50)``, or its inverse when
        ``self.reversed``. Here ``n_active`` is the number of questions not yet
        popped, so the weighting sharpens while many questions remain.

        Returns:
            dict mapping question -> probability. If the total weight is not
            positive (all zero, or NaN), a uniform distribution is returned
            instead of dividing by zero.
        """
        for question, history in self._difficulty_all.items():
            if history:
                if self.metric == 'CS':
                    # Only the majority answer's frequency is used.
                    answers = self.all_selected_answers[question]
                    majority_share = most_frequent(answers)[1] / len(answers)
                    if self.reversed:
                        self._scores[question] = 1 / majority_share  # not used: CS is never reversed
                    else:
                        self._scores[question] = majority_share
                else:
                    exponent = len(self._scores) / 50
                    mean_difficulty = np.mean(history)
                    if self.reversed:
                        self._scores[question] = 1 / mean_difficulty ** exponent
                    else:
                        self._scores[question] = mean_difficulty ** exponent
            elif self.metric == 'CS':
                # Unsampled questions get a large weight so they are tried first.
                self._scores[question] = 100
            elif self.metric == 'Length':
                self._scores[question] = 1

        total = sum(self._scores.values())
        if self._scores and not total > 0:
            # All-zero (or NaN) weights would make random.choices raise; fall
            # back to sampling uniformly among the remaining questions.
            uniform = 1 / len(self._scores)
            return {question: uniform for question in self._scores}
        return {question: score / total for question, score in self._scores.items()}


def run_adaptive_metric_experiment(metric_name, BUDGET_OPTIONS, gradnorm_data, samples_data, INIT_ALLOC, n_seeds=3, n_samples=500):
    """Measure question coverage of adaptive budget allocation under ``metric_name``.

    For every budget ``B`` and every seed, questions are drawn one sample at a
    time from the distribution returned by ``AdapEstimator.get_distribution``.
    A question is retired once a sampled answer matches its ground truth (per
    ``is_equiv``) or once all of its samples have been used.

    Args:
        metric_name: Metric name passed to ``AdapEstimator``.
        BUDGET_OPTIONS: Budgets (number of samples drawn) to evaluate.
        gradnorm_data: question -> record with ``'ground_truth'``,
            ``'input_metrics'`` and ``'output_metrics_{sample_idx}'`` entries.
        samples_data: question -> sequence of sampled answers indexed by sample id.
        INIT_ALLOC: Extra per-question budget added on top of each ``B``.
        n_seeds: Number of seeds averaged per budget.
        n_samples: Number of pre-generated samples per question; sample ids
            ``0 .. n_samples - 1`` are drawn. Defaults to 500.

    Returns:
        List with the mean coverage ratio (fraction of solved questions) for
        each entry of ``BUDGET_OPTIONS``, in order.
    """
    mean_coverages = []
    total_questions = len(gradnorm_data)

    # One fixed sample order per seed, reused for every budget (and for every
    # question within a run), so different budgets see the same draws.
    sample_orders = []
    for seed in range(n_seeds):
        random.seed(seed)
        order = list(range(n_samples))
        random.shuffle(order)
        sample_orders.append(order)

    for B in tqdm(BUDGET_OPTIONS, desc="Processing budgets"):
        budget = B + INIT_ALLOC * total_questions
        seed_coverage_ratio = []

        for seed in range(n_seeds):
            generation_idxs = sample_orders[seed]
            estimator = AdapEstimator(list(gradnorm_data.keys()), gradnorm_data, metric=metric_name)
            coverage = {question: 0 for question in gradnorm_data}
            active_questions = set(estimator.all_selected_answers.keys())
            used_B = 0

            while used_B < budget and active_questions:
                # Recomputed every step: scores change after each sample and
                # retired questions drop out of the distribution.
                dist = estimator.get_distribution()
                question = random.choices(list(dist.keys()), weights=list(dist.values()))[0]

                idx = estimator.gen_idx[question]
                if idx >= len(generation_idxs):
                    # Every sample for this question has been used; retire it.
                    estimator.pop(question)
                    active_questions.remove(question)
                    continue

                sample_idx = generation_idxs[idx]
                answer = samples_data[question][sample_idx]
                output_metrics = gradnorm_data[question][f'output_metrics_{sample_idx}']
                input_len = gradnorm_data[question]['input_metrics']['length']
                estimator.add_answer(question, answer, output_metrics, input_len)
                used_B += 1

                # A solved question is retired immediately, so it can never be
                # drawn again.
                if is_equiv(gradnorm_data[question]['ground_truth'], answer):
                    coverage[question] = 1
                    estimator.pop(question)
                    active_questions.remove(question)

            seed_coverage_ratio.append(sum(coverage.values()) / total_questions)

        mean_coverages.append(np.mean(seed_coverage_ratio))

    return mean_coverages

def pass_at_k(n, c, k):
    """Unbiased pass@k estimate from ``n`` samples of which ``c`` are correct.

    Evaluates ``1 - C(n - c, k) / C(n, k)`` as a running product, which avoids
    computing large binomial coefficients. Returns 1.0 when there are fewer
    than ``k`` incorrect samples (any draw of ``k`` must contain a correct one).
    """
    incorrect = n - c
    if incorrect < k:
        return 1.0
    denominators = np.arange(incorrect + 1, n + 1)
    all_miss_probability = np.prod(1.0 - k / denominators)
    return 1.0 - all_miss_probability


def most_frequent(lst):
    """Return ``(element, count)`` for the most common item in *lst*.

    Ties go to the element encountered first, following
    ``Counter.most_common``. An empty input yields ``(None, 0)``.
    """
    if not lst:
        return None, 0
    [(winner, frequency)] = Counter(lst).most_common(1)
    return winner, frequency


# samples_json_path = "Qwen2.5_Math_1.5B_MATH500_samples/samples_parsed.json"
# samples_json_path = "Qwen2.5_Math_1.5B_GSM8K_samples/samples_parsed.json"
# samples_json_path = "Qwen2.5_1.5B_MATH500_samples/samples_parsed.json"
# Alternative model/dataset runs are kept commented out above and below. The
# samples file and the allocation file must describe the same questions, because
# questions from gradnorm_data are used as keys into samples_data.
samples_json_path = "Qwen2.5_1.5B_GSM8K_samples/samples_parsed.json"
# JSON is UTF-8 by specification; don't rely on the platform default encoding.
with open(samples_json_path, "r", encoding="utf-8") as f:
    samples_data = json.load(f)
# gradnorm_json_path = "Qwen2.5_Math_1.5B_MATH500_allocation/output_lengths.json"
# gradnorm_json_path = "Qwen2.5_Math_1.5B_GSM8K_allocation/output_lengths.json"
# gradnorm_json_path = "Qwen2.5_1.5B_MATH500_allocation/output_lengths.json"
gradnorm_json_path = "Qwen2.5_1.5B_GSM8K_allocation/output_lengths.json"
with open(gradnorm_json_path, "r", encoding="utf-8") as f:
    gradnorm_data = json.load(f)

# Extra per-question budget added on top of every budget option (0 = none).
INIT_ALLOC = 0
# Budgets from 1/16x to 16x the number of questions, in powers of two.
BUDGET_OPTIONS = [int(2**(i-4) * len(gradnorm_data)) for i in range(9)]
# Method name -> list of mean coverage ratios, one per budget option.
all_coverage = {}


# Uniform Baseline
def run_uniform_baseline(BUDGET_OPTIONS, gradnorm_data, samples_data, INIT_ALLOC, n_seeds=3, n_samples=500):
    """Measure question coverage when the budget is spread uniformly.

    Questions are visited round-robin in a per-seed random order. Round ``r``
    uses the same sample id for every question. A question counts as covered
    once any of its drawn samples matches the ground truth (per ``is_equiv``).

    Args:
        BUDGET_OPTIONS: Budgets (number of samples drawn) to evaluate.
        gradnorm_data: question -> record with a ``'ground_truth'`` entry.
        samples_data: question -> sequence of sampled answers indexed by sample id.
        INIT_ALLOC: Extra per-question budget added on top of each budget.
        n_seeds: Number of seeds averaged per budget.
        n_samples: Number of pre-generated samples per question.

    Returns:
        List with the mean coverage ratio for each entry of ``BUDGET_OPTIONS``.
    """
    coverages = []
    total_questions = len(gradnorm_data)
    for B in BUDGET_OPTIONS:
        budget = B + INIT_ALLOC * total_questions
        seed_coverages = []
        for seed in range(n_seeds):
            random.seed(seed)
            # Shuffle a fresh list for each seed. This reproduces the per-seed
            # sample order used by run_adaptive_metric_experiment, so both
            # methods see the same draws. The original reshuffled the previous
            # seed's permutation in place instead.
            sample_order = list(range(n_samples))
            random.shuffle(sample_order)
            question_order = list(gradnorm_data.keys())
            random.shuffle(question_order)

            covered = set()
            used_B = 0
            # Stopping when sample_order runs out avoids an IndexError when the
            # budget exceeds n_samples rounds.
            for selected_idx in sample_order:
                if used_B >= budget:
                    break
                for question in question_order:
                    if used_B >= budget:
                        break
                    # Budget is spent even on covered questions (uniform
                    # allocation); only the redundant answer check is skipped.
                    if question not in covered and is_equiv(
                        gradnorm_data[question]['ground_truth'],
                        samples_data[question][selected_idx],
                    ):
                        covered.add(question)
                    used_B += 1
            seed_coverages.append(len(covered) / total_questions)
        coverages.append(np.mean(seed_coverages))
    return coverages


baseline_coverages = run_uniform_baseline(BUDGET_OPTIONS, gradnorm_data, samples_data, INIT_ALLOC)
all_coverage['uniform'] = baseline_coverages
# Evaluate every adaptive metric on the same budgets and seeds as the baseline.
adaptive_metrics = ('VoG', 'GradNorm', 'Entropy', 'Length', 'CS')
for metric_name in adaptive_metrics:
    all_coverage[metric_name] = run_adaptive_metric_experiment(
        metric_name, BUDGET_OPTIONS, gradnorm_data, samples_data, INIT_ALLOC, n_seeds=3
    )

# Round coverage ratios to 3 decimals for a compact report.
all_coverage = {
    method: [round(ratio, 3) for ratio in ratios]
    for method, ratios in all_coverage.items()
}
print(all_coverage)
