import argparse
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle as pkl
import policy
import warnings
from config import *
from create_tensor import load_metadata
from tqdm import tqdm
from utils import load_test_predictions, load_val_predictions


def plot_gsm8k(model_id):
    """Run the GSM8K policy sweeps and cache each one as a pickle.

    Despite the name, nothing is drawn here: for every policy the function
    collects a fixed number of trial results per chain budget and dumps the
    nested list (outer index = chain budget, inner index = trial) under
    ``figs/gsm8k/<model_name>/``.  A sweep whose pickle already exists is
    skipped, so the file has to be deleted to force a re-run.

    Args:
        model_id: Key into ``MODEL_IDS``; also forwarded to
            ``load_metadata``.
    """
    # Installs a process-wide filter, not one scoped to this call.
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    model_name = MODEL_IDS[model_id]
    os.makedirs(f'figs/gsm8k/{model_name}', exist_ok=True)

    metadata = load_metadata('gsm8k', model_id, for_training=False)
    held_out = metadata[metadata.train == 0]

    trials_per_budget = 10
    chain_budgets = list(range(5, 11, 1))

    def sweep(label, run_once):
        # One inner list of trial results per chain budget, in budget order.
        return [
            [run_once(budget) for _ in range(trials_per_budget)]
            for budget in tqdm(chain_budgets, desc=label)
        ]

    def dump(results, path):
        with open(path, 'wb') as handle:
            pkl.dump(results, handle)

    # Baseline
    baseline_pkl = f'figs/gsm8k/{model_name}/baseline_exp.pkl'
    if not os.path.exists(baseline_pkl):
        dump(sweep('Baseline', lambda k: policy.baseline(held_out, k)), baseline_pkl)

    # Dynasor (trailing positional hyper-parameters: see policy.dynasor)
    dynasor_pkl = f'figs/gsm8k/{model_name}/dynasor_exp.pkl'
    if not os.path.exists(dynasor_pkl):
        dump(sweep('Dynasor', lambda k: policy.dynasor(held_out, k, 64, 5)), dynasor_pkl)

    # Short-m@k
    shortm_pkl = f'figs/gsm8k/{model_name}/shortm_exp.pkl'
    if not os.path.exists(shortm_pkl):
        dump(sweep('Short-m@k', lambda k: policy.short_m(held_out, k, 0.6)), shortm_pkl)

    # Our method
    ours_pkl = f'figs/gsm8k/{model_name}/greedyprob_exp.pkl'
    if not os.path.exists(ours_pkl):
        # Load predictions for the chosen checkpoint.
        # Alternative kept for reference: selected_model, selected_epoch = "L20_test_mlp", 3
        selected_model, selected_epoch = "L14_mlp", 2
        pred_dir = os.path.join(GSM8K_DIR, model_name)
        val_preds = load_val_predictions(pred_dir, selected_model, selected_epoch)
        test_preds = load_test_predictions(pred_dir, selected_model, selected_epoch)

        # Rows without a test prediction fall back to the validation median;
        # the threshold is the 70th percentile of validation predictions.
        fallback_pred = val_preds.pred.median()
        threshold = np.percentile(val_preds.pred.values, 70)
        with_preds = held_out.merge(
            test_preds, on=['unique_id', 'chain_id', 'tokens'], how='left')
        with_preds['pred'] = with_preds['pred'].fillna(fallback_pred)

        def run_duchess(k):
            # Positional hyper-parameters are passed through as-is; their
            # meaning is defined by policy.duchess.
            return policy.duchess(
                with_preds, k, threshold, -1, 3, 2, 1, 16, 0.6, 0.8, 'greedy_prob', False)

        dump(sweep('Ours (greedy proba)', run_duchess), ours_pkl)


def plot_mmlu(model_id):
    """Run the MMLU policy sweeps and cache each one as a pickle.

    Despite the name, nothing is drawn here: for every policy the function
    collects a fixed number of trial results per chain budget and dumps the
    nested list (outer index = chain budget, inner index = trial) under
    ``figs/mmlu/<model_name>/``.  A sweep whose pickle already exists is
    skipped, so the file has to be deleted to force a re-run.

    Args:
        model_id: Key into ``MODEL_IDS``; also forwarded to
            ``load_metadata``.
    """
    # Installs a process-wide filter, not one scoped to this call.
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    model_name = MODEL_IDS[model_id]
    os.makedirs(f'figs/mmlu/{model_name}', exist_ok=True)

    metadata = load_metadata('mmlu', model_id, for_training=False)
    held_out = metadata[metadata.train == 0]

    trials_per_budget = 10
    chain_budgets = list(range(5, 11, 1))

    def sweep(label, run_once):
        # One inner list of trial results per chain budget, in budget order.
        return [
            [run_once(budget) for _ in range(trials_per_budget)]
            for budget in tqdm(chain_budgets, desc=label)
        ]

    def dump(results, path):
        with open(path, 'wb') as handle:
            pkl.dump(results, handle)

    # Baseline
    baseline_pkl = f'figs/mmlu/{model_name}/baseline_exp.pkl'
    if not os.path.exists(baseline_pkl):
        dump(sweep('Baseline', lambda k: policy.baseline(held_out, k)), baseline_pkl)

    # Dynasor (trailing positional hyper-parameters: see policy.dynasor)
    dynasor_pkl = f'figs/mmlu/{model_name}/dynasor_exp.pkl'
    if not os.path.exists(dynasor_pkl):
        dump(sweep('Dynasor', lambda k: policy.dynasor(held_out, k, 64, 5)), dynasor_pkl)

    # Short-m@k
    shortm_pkl = f'figs/mmlu/{model_name}/shortm_exp.pkl'
    if not os.path.exists(shortm_pkl):
        dump(sweep('Short-m@k', lambda k: policy.short_m(held_out, k, 0.6)), shortm_pkl)

    # Our method
    ours_pkl = f'figs/mmlu/{model_name}/greedyprob_exp.pkl'
    if not os.path.exists(ours_pkl):
        # Load predictions for the chosen checkpoint.
        selected_model, selected_epoch = "L14_mlp", 2
        pred_dir = os.path.join(MMLU_DIR, model_name)
        val_preds = load_val_predictions(pred_dir, selected_model, selected_epoch)
        test_preds = load_test_predictions(pred_dir, selected_model, selected_epoch)

        # Rows without a test prediction fall back to the validation median;
        # the threshold is the 80th percentile of validation predictions.
        fallback_pred = val_preds.pred.median()
        threshold = np.percentile(val_preds.pred.values, 80)
        with_preds = held_out.merge(
            test_preds, on=['unique_id', 'chain_id', 'tokens'], how='left')
        with_preds['pred'] = with_preds['pred'].fillna(fallback_pred)

        def run_duchess(k):
            # Positional hyper-parameters are passed through as-is; their
            # meaning is defined by policy.duchess.
            return policy.duchess(
                with_preds, k, threshold, -1, 0, 2, 1, 80, 0.4, 1, 'greedy_prob', False,
                sample_lambda=0.8)

        dump(sweep('Ours (greedy proba)', run_duchess), ours_pkl)


def plot_math(model_id):
    """Run the MATH policy sweeps and cache each one as a pickle.

    Despite the name, nothing is drawn here: for every policy the function
    collects a fixed number of trial results per chain budget and dumps the
    nested list (outer index = chain budget, inner index = trial) under
    ``figs/math/<model_name>/``.  A sweep whose pickle already exists is
    skipped, so the file has to be deleted to force a re-run.

    Args:
        model_id: Key into ``MODEL_IDS``; also forwarded to
            ``load_metadata``.
    """
    # Installs a process-wide filter, not one scoped to this call.
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    model_name = MODEL_IDS[model_id]
    os.makedirs(f'figs/math/{model_name}', exist_ok=True)

    metadata = load_metadata('math', model_id, for_training=False)
    held_out = metadata[metadata.train == 0]

    trials_per_budget = 10
    chain_budgets = list(range(5, 11, 1))

    def sweep(label, run_once):
        # One inner list of trial results per chain budget, in budget order.
        return [
            [run_once(budget) for _ in range(trials_per_budget)]
            for budget in tqdm(chain_budgets, desc=label)
        ]

    def dump(results, path):
        with open(path, 'wb') as handle:
            pkl.dump(results, handle)

    # Baseline
    baseline_pkl = f'figs/math/{model_name}/baseline_exp.pkl'
    if not os.path.exists(baseline_pkl):
        dump(sweep('Baseline', lambda k: policy.baseline(held_out, k)), baseline_pkl)

    # Dynasor (trailing positional hyper-parameters: see policy.dynasor)
    dynasor_pkl = f'figs/math/{model_name}/dynasor_exp.pkl'
    if not os.path.exists(dynasor_pkl):
        dump(sweep('Dynasor', lambda k: policy.dynasor(held_out, k, 64, 7)), dynasor_pkl)

    # Short-m@k
    # NOTE(review): the third argument is 3 here but 0.6 in the GSM8K and
    # MMLU runners -- confirm against policy.short_m that this is intended.
    shortm_pkl = f'figs/math/{model_name}/shortm_exp.pkl'
    if not os.path.exists(shortm_pkl):
        dump(sweep('Short-m@k', lambda k: policy.short_m(held_out, k, 3)), shortm_pkl)

    # Our method
    ours_pkl = f'figs/math/{model_name}/greedyprob_exp.pkl'
    if not os.path.exists(ours_pkl):
        # Load predictions for the chosen checkpoint.
        selected_model, selected_epoch = "L14_mlp", 2
        pred_dir = os.path.join(MATH_DIR, model_name)
        val_preds = load_val_predictions(pred_dir, selected_model, selected_epoch)
        test_preds = load_test_predictions(pred_dir, selected_model, selected_epoch)

        # Rows without a test prediction fall back to the validation median;
        # the threshold is the 80th percentile of validation predictions.
        fallback_pred = val_preds.pred.median()
        threshold = np.percentile(val_preds.pred.values, 80)
        with_preds = held_out.merge(
            test_preds, on=['unique_id', 'chain_id', 'tokens'], how='left')
        with_preds['pred'] = with_preds['pred'].fillna(fallback_pred)

        def run_duchess(k):
            # Positional hyper-parameters are passed through as-is; their
            # meaning is defined by policy.duchess.
            return policy.duchess(
                with_preds, k, threshold, -1, 0, 2, 2, 80, 0.6, 0.8, 'greedy_prob', False,
                sample_lambda=0.8)

        dump(sweep('Ours (greedy proba)', run_duchess), ours_pkl)



def plot(dataset, model_id):
    """Dispatch to the experiment runner for ``dataset``.

    Args:
        dataset: One of ``'mmlu'``, ``'gsm8k'`` or ``'math'``.
        model_id: Forwarded unchanged to the selected runner.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
            (Previously this surfaced as an opaque
            ``TypeError: 'NoneType' object is not callable``.)
    """
    supported = ('mmlu', 'gsm8k', 'math')
    # Validate before touching any runner so a typo fails fast and clearly.
    if dataset not in supported:
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected one of: {', '.join(supported)}")

    if dataset == 'mmlu':
        func = plot_mmlu
    elif dataset == 'gsm8k':
        func = plot_gsm8k
    else:
        func = plot_math

    func(model_id)


if __name__ == "__main__":
    # CLI entry point: run the experiment sweeps for the requested
    # dataset/model pair, or for every known combination when 'all' is given.
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--dataset", type=str, default='all', required=False)
    parser.add_argument("-m", "--model", type=str, default='all', required=False)

    # Create the output root without shelling out: portable across platforms
    # and raises on failure instead of silently returning a non-zero status.
    os.makedirs('figs', exist_ok=True)

    args = parser.parse_args()
    if args.dataset == 'all':
        datasets = ['mmlu', 'gsm8k', 'math']
    else:
        datasets = [args.dataset]

    if args.model == 'all':
        models = ['deepseek-ai/DeepSeek-R1-Distill-Llama-8B',]
    else:
        models = [args.model]

    for dataset in datasets:
        for model in models:
            plot(dataset, model)
