import os
import json
from util import is_equiv, extract_math_answer
import random
import numpy as np
import math

try:
    import ujson as json
except ImportError:
    try:
        import simplejson as json
    except ImportError:
        import json

import numpy as np
import random
from collections import Counter
from tqdm import tqdm


class AdapEstimator:
    """Track per-question difficulty observations and derive selection weights.

    For every question the estimator keeps a list of difficulty observations
    for the chosen metric, the answers recorded so far, and a counter of how
    many answers were recorded. `get_distribution` turns the observations
    into scores and normalized weights; `get_rank` orders questions by score.

    Args:
        keys: Question identifiers to track. Callers in this file pass the
            keys of `gradnorm_data`.
        gradnorm_data: Mapping question -> record. Each record must contain
            `'input_metrics'` with `'each_token_gradient_norm'`,
            `'entropy_loss'` and `'length'`.
        metric: Name of the difficulty metric: `'Entropy'`, `'VoG'`,
            `'Length'`, `'GradNorm'` or `'CS'`.
    """

    # Metrics whose mean observation is inverted (1 / value) when scored.
    # Every other metric is scored directly.
    _REVERSED_METRICS = frozenset({'Length'})

    def __init__(self, keys, gradnorm_data, metric):
        self.keys = keys
        self.metric = metric
        # Always defined, including for metric names outside the known set.
        self.reversed = metric in self._REVERSED_METRICS

        # Prompt-side features, computed once per question.
        self.input_features = {}
        for question, data in gradnorm_data.items():
            metrics = data['input_metrics']
            # The last per-token gradient norm is excluded from the statistics.
            input_grad_norm = np.array(metrics['each_token_gradient_norm'])[:-1]
            self.input_features[question] = {
                'VoG': input_grad_norm.var(),
                'GradNorm': input_grad_norm.mean(),
                'Entropy': metrics['entropy_loss'],
                'Length': metrics['length'],
            }

        if metric == 'CS':
            # 'CS' has no prompt-side prior: start with no observations.
            self._difficulty_all = {question: [] for question in keys}
        else:
            # Seed each question with its prompt-side feature scaled by 1/10.
            self._difficulty_all = {
                question: [features[metric] / 10]
                for question, features in self.input_features.items()
            }

        self._scores = {question: 0 for question in keys}
        self.all_selected_answers = {question: [] for question in keys}
        self.gen_idx = {question: 0 for question in keys}

    def add_answer(self, question, answer, output_metrics, input_len):
        """Record one sampled answer and its difficulty observation.

        Args:
            question: Question the answer belongs to.
            answer: The sampled answer; stored as-is.
            output_metrics: Metrics of the generated sample. The key that is
                read depends on the estimator's metric.
            input_len: Prompt length, subtracted from `output_metrics['length']`
                for the `'Length'` metric.
        """
        if self.metric == 'VoG':
            _difficulty = np.array(output_metrics['each_token_gradient_norm']).var()
        elif self.metric == 'GradNorm':
            _difficulty = np.array(output_metrics['each_token_gradient_norm']).mean()
        elif self.metric == 'Entropy':
            _difficulty = output_metrics['entropy_loss']
        elif self.metric == 'Length':
            _difficulty = output_metrics['length'] - input_len
        elif self.metric == 'CS':
            # 'CS' scores come from answer agreement; the observation only
            # marks that a sample exists.
            _difficulty = 0
        else:
            _difficulty = np.array(output_metrics[self.metric]).mean()
        self._difficulty_all[question].append(_difficulty)
        self.all_selected_answers[question].append(answer)
        self.gen_idx[question] += 1

    def pop(self, question):
        """Stop tracking `question`; a no-op if it is not tracked.

        `gen_idx` keeps its entry for the question.
        """
        if question in self.all_selected_answers:
            self.all_selected_answers.pop(question)
            self._scores.pop(question)
            self._difficulty_all.pop(question)
            self.input_features.pop(question)

    def get_rank(self, if_reverse=False):
        """Return tracked questions sorted by their current score.

        Scores are only refreshed by `get_distribution`.

        Args:
            if_reverse: Sort descending (highest score first) when true.
        """
        return sorted(self._scores, key=self._scores.get, reverse=if_reverse)

    def get_distribution(self):
        """Refresh all scores and return them normalized to sum to 1.

        Scoring per question:
          * `'CS'`: share of recorded answers equal to the most frequent
            answer; 100 while no answer has been recorded.
          * other metrics: mean observation raised to
            `len(self._scores) / 50`, inverted for reversed metrics.

        Returns:
            Mapping question -> weight. Empty when nothing is tracked. If all
            scores sum to zero the weights are uniform, so callers never
            receive NaN weights or a ZeroDivisionError.
        """
        for question, _score_list in self._difficulty_all.items():
            if _score_list:
                ##### Only count majority answer
                if self.metric == 'CS':
                    answers = self.all_selected_answers[question]
                    mv_cnt = most_frequent(answers)[1]
                    if self.reversed:
                        self._scores[question] = 1 / (mv_cnt / len(answers))  # not used
                    else:
                        self._scores[question] = mv_cnt / len(answers)
                else:
                    if self.reversed:
                        self._scores[question] = 1 / np.mean(_score_list) ** (len(self._scores) / 50)
                    else:
                        self._scores[question] = np.mean(_score_list) ** (len(self._scores) / 50)
            else:
                if self.metric == 'CS':
                    self._scores[question] = 100
                elif self.metric == 'Length':
                    self._scores[question] = 1

        if not self._scores:
            return {}

        sum_value = sum(self._scores.values())
        if sum_value == 0:
            # No question carries any weight: fall back to a uniform choice.
            uniform_weight = 1 / len(self._scores)
            return {question: uniform_weight for question in self._scores}
        return {question: score / sum_value for question, score in self._scores.items()}


def easyhard_metric_experiment(metric_name, easy2hard, BUDGET_OPTIONS, gradnorm_data, samples_data, INIT_ALLOC, n_seeds=3, n_samples=500):
    """Measure coverage when questions are attempted greedily by score.

    At every step the question at the front of the estimator's ranking
    receives the next sample, until it is solved, runs out of samples, or the
    budget is spent.

    Args:
        metric_name: Metric passed to `AdapEstimator`.
        easy2hard: Forwarded as `if_reverse` to `AdapEstimator.get_rank`; when
            true the question with the highest current score is attempted
            first, otherwise the one with the lowest.
        BUDGET_OPTIONS: Iterable of budgets. A run may draw up to
            `B + INIT_ALLOC * len(gradnorm_data)` samples.
        gradnorm_data: Mapping question -> record with `'ground_truth'`,
            `'input_metrics'` and `'output_metrics_<i>'` entries.
        samples_data: Mapping question -> answers indexed by sample index.
        INIT_ALLOC: Extra per-question budget added to every option.
        n_seeds: Number of shuffled sample orders (seeds `0..n_seeds-1`).
        n_samples: Number of pre-generated samples available per question.

    Returns:
        One mean coverage ratio (solved questions / all questions, averaged
        over seeds) per entry of `BUDGET_OPTIONS`.
    """
    # One shuffled sample order per seed, reused for every budget.
    all_shuffled_indices = []
    for seed in range(n_seeds):
        random.seed(seed)
        gen_idxs = list(range(n_samples))
        random.shuffle(gen_idxs)
        all_shuffled_indices.append(gen_idxs)

    _coverages_mean = []
    for B in tqdm(BUDGET_OPTIONS, desc="Processing budgets"):
        total_budget = B + INIT_ALLOC * len(gradnorm_data)
        seed_coverage_ratio = []

        for generation_idxs in all_shuffled_indices:
            _Estimator = AdapEstimator(list(gradnorm_data.keys()), gradnorm_data, metric=metric_name)
            _coverage = {question: 0 for question in gradnorm_data}

            used_B = 0
            active_questions = set(_Estimator.all_selected_answers.keys())
            while used_B < total_budget and active_questions:
                # Called for its side effect: it refreshes the scores that
                # get_rank sorts by. The returned weights are not needed here.
                _Estimator.get_distribution()
                question = _Estimator.get_rank(if_reverse=easy2hard)[0]

                idx = _Estimator.gen_idx[question]
                if idx >= len(generation_idxs):
                    # Every pre-generated sample was used: retire the question
                    # without charging the budget.
                    _Estimator.pop(question)
                    active_questions.remove(question)
                    continue

                sample_idx = generation_idxs[idx]
                answer = samples_data[question][sample_idx]
                output_metrics = gradnorm_data[question][f'output_metrics_{sample_idx}']
                input_len = gradnorm_data[question]['input_metrics']['length']
                _Estimator.add_answer(question, answer, output_metrics, input_len)
                used_B += 1

                # Solved questions leave the estimator immediately, so a
                # solved question can never be selected again.
                if is_equiv(gradnorm_data[question]['ground_truth'], answer):
                    _coverage[question] = 1
                    _Estimator.pop(question)
                    active_questions.remove(question)

            seed_coverage_ratio.append(sum(_coverage.values()) / len(gradnorm_data))

        _coverages_mean.append(np.mean(seed_coverage_ratio))
    return _coverages_mean

def run_adaptive_metric_experiment(metric_name, BUDGET_OPTIONS, gradnorm_data, samples_data, INIT_ALLOC, n_seeds=3, n_samples=500):
    """Measure coverage when questions are sampled by estimator weights.

    At every step a question is drawn with probability given by
    `AdapEstimator.get_distribution` and receives its next sample, until the
    budget is spent or no unsolved question with samples left remains.

    Args:
        metric_name: Metric passed to `AdapEstimator`.
        BUDGET_OPTIONS: Iterable of budgets. A run may draw up to
            `B + INIT_ALLOC * len(gradnorm_data)` samples.
        gradnorm_data: Mapping question -> record with `'ground_truth'`,
            `'input_metrics'` and `'output_metrics_<i>'` entries.
        samples_data: Mapping question -> answers indexed by sample index.
        INIT_ALLOC: Extra per-question budget added to every option.
        n_seeds: Number of seeded repetitions (seeds `0..n_seeds-1`).
        n_samples: Number of pre-generated samples available per question.

    Returns:
        One mean coverage ratio (solved questions / all questions, averaged
        over seeds) per entry of `BUDGET_OPTIONS`.
    """
    # One shuffled sample order per seed, reused for every budget. The RNG
    # state right after the shuffle is kept so that every (budget, seed) run
    # starts its weighted draws from the same point: seed -> shuffle -> draw.
    # Without this, a run's draws would depend on how much randomness the
    # previously processed budgets and seeds consumed.
    seeded_runs = []
    for seed in range(n_seeds):
        random.seed(seed)
        gen_idxs = list(range(n_samples))
        random.shuffle(gen_idxs)
        seeded_runs.append((gen_idxs, random.getstate()))

    _coverages_mean = []
    for B in tqdm(BUDGET_OPTIONS, desc="Processing budgets"):
        total_budget = B + INIT_ALLOC * len(gradnorm_data)
        seed_coverage_ratio = []

        for generation_idxs, rng_state in seeded_runs:
            random.setstate(rng_state)
            _Estimator = AdapEstimator(list(gradnorm_data.keys()), gradnorm_data, metric=metric_name)
            _coverage = {question: 0 for question in gradnorm_data}

            used_B = 0
            active_questions = set(_Estimator.all_selected_answers.keys())
            while used_B < total_budget and active_questions:
                # Weights are recomputed every step because each new sample
                # changes the scores.
                cs_dist = _Estimator.get_distribution()
                questions = list(cs_dist.keys())
                weights = list(cs_dist.values())
                question = random.choices(questions, weights=weights)[0]

                idx = _Estimator.gen_idx[question]
                if idx >= len(generation_idxs):
                    # Every pre-generated sample was used: retire the question
                    # without charging the budget.
                    _Estimator.pop(question)
                    active_questions.remove(question)
                    continue

                sample_idx = generation_idxs[idx]
                answer = samples_data[question][sample_idx]
                output_metrics = gradnorm_data[question][f'output_metrics_{sample_idx}']
                input_len = gradnorm_data[question]['input_metrics']['length']
                _Estimator.add_answer(question, answer, output_metrics, input_len)
                used_B += 1

                # Solved questions leave the estimator immediately, so a
                # solved question can never be drawn again.
                if is_equiv(gradnorm_data[question]['ground_truth'], answer):
                    _coverage[question] = 1
                    _Estimator.pop(question)
                    active_questions.remove(question)

            seed_coverage_ratio.append(sum(_coverage.values()) / len(gradnorm_data))

        _coverages_mean.append(np.mean(seed_coverage_ratio))

    return _coverages_mean

def pass_at_k(n, c, k):
    """Return the pass@k estimate for `c` correct answers among `n` samples.

    Computes `1 - prod(1 - k / i)` for `i` in `n-c+1 .. n`, i.e. one minus
    the chance that `k` samples drawn without replacement are all incorrect.
    Returns 1.0 outright when fewer than `k` samples are incorrect.
    """
    n_incorrect = n - c
    if k > n_incorrect:
        return 1.0
    denominators = np.arange(n_incorrect + 1, n + 1)
    all_miss_probability = np.prod(1.0 - k / denominators)
    return 1.0 - all_miss_probability


def most_frequent(lst):
    """Return `(element, count)` for the most common item of `lst`.

    An empty (falsy) input yields `(None, 0)`.
    """
    if lst:
        (winner, occurrences), = Counter(lst).most_common(1)
        return winner, occurrences
    return None, 0


# Input files. Exactly one samples path and one allocation path are active;
# the commented lines are the alternative model/dataset combinations.
# samples_json_path = "Qwen2.5_Math_1.5B_MATH500_samples/samples_parsed.json"
# samples_json_path = "Qwen2.5_Math_1.5B_GSM8K_samples/samples_parsed.json"
# samples_json_path = "Qwen2.5_1.5B_MATH500_samples/samples_parsed.json"
samples_json_path = "Qwen2.5_1.5B_GSM8K_samples/samples_parsed.json"
# Explicit UTF-8 so decoding does not depend on the platform's default locale.
with open(samples_json_path, "r", encoding="utf-8") as f:
    # question -> answers, indexed by sample index further down
    samples_data = json.load(f)
# gradnorm_json_path = "Qwen2.5_Math_1.5B_MATH500_allocation/output_lengths.json"
# gradnorm_json_path = "Qwen2.5_Math_1.5B_GSM8K_allocation/output_lengths.json"
# gradnorm_json_path = "Qwen2.5_1.5B_MATH500_allocation/output_lengths.json"
gradnorm_json_path = "Qwen2.5_1.5B_GSM8K_allocation/output_lengths.json"
with open(gradnorm_json_path, "r", encoding="utf-8") as f:
    # question -> record; the code below reads 'ground_truth',
    # 'input_metrics' and 'output_metrics_<i>' from each record
    gradnorm_data = json.load(f)

### Set Budget ###
# Extra per-question budget added on top of every budget option.
INIT_ALLOC = 0
# Total budgets of 0.25x, 0.5x, 1x, ... 32x the number of questions.
BUDGET_OPTIONS = [int(2**(i-2) * len(gradnorm_data)) for i in range(8)]
# strategy name -> mean coverage per budget option
all_coverage = {}


### Oracle DIPA ###
# pass@k for k = 1..500 of every question, from the share of its samples
# that match the ground truth.
pass_dict = {}
for question, all_answers in samples_data.items():
    gt = gradnorm_data[question]['ground_truth']
    acc_count = sum(1 for answer in all_answers if is_equiv(gt, answer))
    n_answers = len(all_answers)
    pass_dict[question] = {
        k: pass_at_k(n_answers, acc_count, k) for k in range(1, 501)
    }

# Smallest k whose pass@k reaches 0.99; 1000 marks questions that never do.
min_pass_k = {}
for question in gradnorm_data:
    min_pass_k[question] = next(
        (k for k, value in pass_dict[question].items() if value >= 0.99),
        1000,
    )


# Oracle allocation: work through the questions in ascending order of the
# number of samples they need, giving each one samples until it is solved
# or the budget is gone.
sorted_keys = sorted(min_pass_k, key=min_pass_k.get)
oracle_coverages_mean = []
for B in BUDGET_OPTIONS:
    total_B = B + INIT_ALLOC * len(gradnorm_data)
    # Reset per budget; the per-seed shuffles below are applied in place, one
    # on top of the other.
    generation_idxs = list(range(500))
    seed_coverage_ratio = []
    for seed in (0, 1, 2):
        random.seed(seed)
        random.shuffle(generation_idxs)
        solved_questions = set()
        used_B = 0
        for question in sorted_keys:
            if used_B >= total_B:
                break
            gt = gradnorm_data[question]['ground_truth']
            question_samples = samples_data[question]
            for idx in generation_idxs:
                if used_B >= total_B:
                    break
                used_B += 1
                if is_equiv(gt, question_samples[idx]):
                    solved_questions.add(question)
                    break
        seed_coverage_ratio.append(len(solved_questions) / len(gradnorm_data))
    oracle_coverages_mean.append(np.mean(seed_coverage_ratio))
all_coverage['oracle'] = oracle_coverages_mean

### Uniform Allocation ###
# Round-robin baseline: in each round every question (in shuffled order)
# receives the same sample index, until the budget is spent.
baseline_coverages = []
for B in BUDGET_OPTIONS:
    total_B = B + INIT_ALLOC * len(gradnorm_data)
    _coverages = []
    # Reset per budget; the per-seed shuffles below are applied in place.
    generation_idxs = list(range(500))
    for seed in range(3):
        random.seed(seed)
        random.shuffle(generation_idxs)
        coverage_uniform = {question: 0 for question in gradnorm_data}
        shuffled_keys = list(gradnorm_data.keys())
        random.shuffle(shuffled_keys)
        used_B = 0
        start_gen_idx = 0
        # Also stop once every pre-generated sample index has been handed
        # out. Without this bound a budget above 500 rounds indexes past the
        # end of generation_idxs, and an empty question list with a positive
        # budget would never terminate.
        while used_B < total_B and start_gen_idx < len(generation_idxs):
            selected_idx = generation_idxs[start_gen_idx]
            for question in shuffled_keys:
                if used_B >= total_B:
                    break
                gt = gradnorm_data[question]['ground_truth']
                if is_equiv(gt, samples_data[question][selected_idx]):
                    coverage_uniform[question] = 1
                used_B += 1
            start_gen_idx += 1
        covered = sum(1 for v in coverage_uniform.values() if v > 0)
        _coverages.append(covered / len(gradnorm_data))
    baseline_coverages.append(np.mean(_coverages))

all_coverage['uniform'] = baseline_coverages


### DIPA (const) ###
# Baseline that picks an unsolved question uniformly at random at every step
# and gives it its next sample, until the budget is spent or no unsolved
# question with samples left remains.

random_coverages_mean = []
for B in BUDGET_OPTIONS:
    total_B = B + INIT_ALLOC * len(gradnorm_data)
    # Reset per budget; the per-seed shuffles below are applied in place.
    generation_idxs = list(range(500))
    seed_coverage_ratio = []
    for seed in [0, 1, 2]:
        coverage_uniform = {question: 0 for question in gradnorm_data}
        gen_idx = {question: 0 for question in gradnorm_data}
        random.seed(seed)
        random.shuffle(generation_idxs)
        used_B = 0
        # A list in file order rather than a set: iterating a set of strings
        # follows hash order, which changes between interpreter runs and would
        # make the seeded draws irreproducible. It also avoids rebuilding a
        # list from the set on every draw.
        active_questions = list(gradnorm_data.keys())

        while used_B < total_B and active_questions:
            question = random.choices(active_questions)[0]

            idx = gen_idx[question]
            if idx >= len(generation_idxs):
                # Every pre-generated sample was used: retire the question.
                # No sample is drawn, so the budget is not charged.
                active_questions.remove(question)
                continue

            sample_idx = generation_idxs[idx]
            answer = samples_data[question][sample_idx]
            gen_idx[question] += 1
            used_B += 1

            # Solved questions are removed immediately, so a solved question
            # can never be drawn again.
            if is_equiv(gradnorm_data[question]['ground_truth'], answer):
                coverage_uniform[question] = 1
                active_questions.remove(question)

        coverage_ratio = sum(coverage_uniform.values()) / len(gradnorm_data)
        seed_coverage_ratio.append(coverage_ratio)

    random_coverages_mean.append(np.mean(seed_coverage_ratio))
all_coverage['random'] = random_coverages_mean


### DIPA ###
# Arguments shared by every experiment call below.
experiment_args = (BUDGET_OPTIONS, gradnorm_data, samples_data, INIT_ALLOC)
for metric in ['Length']:
    all_coverage['Generation_Length'] = run_adaptive_metric_experiment(
        metric, *experiment_args, n_seeds=3)


### Easy to Hard ###
for label, easy_first in (('e2h', True), ('h2e', False)):
    all_coverage[label] = easyhard_metric_experiment(
        'Length', easy_first, *experiment_args, n_seeds=3)


# Round every coverage value to at most three decimal places.
for name in list(all_coverage):
    all_coverage[name] = [round(value, 3) for value in all_coverage[name]]

print(all_coverage)
