import os, sys, random
import numpy as np 
import matplotlib.pyplot as plt 

# Repository checkout. Its `code` directory is appended to sys.path so the
# project-local `helpers.io` module below can be imported; ROOT is also the
# base of the directory reward_histograms() saves its figure under.
ROOT = "/u/audreyh/workspace/test-code"
# Base directory for data: main() reads generations and reward arrays from
# OUT_ROOT/data/<task>/<policy>/.
OUT_ROOT = "/work/hdd/bdkj/audreyh"
# Must run before the `helpers.io` import that depends on it.
sys.path.append(os.path.join(ROOT, 'code'))
import helpers.io as io  # NOTE(review): binds `io` to helpers.io, shadowing the stdlib `io` name in this file

# Reward-model identifiers. Each name is used as the subdirectory and filename
# component of its saved reward array (see main()), as the dict key under which
# collate() stores scores on each output, and as a subplot title in
# reward_histograms().
REWARDS=[
    "oasst-rm", 
    "rm-gemma-2b", 
    "grm-llama-3b",
    "eurus-rm-7b",
    "rm-mistral-7b", 
    "beaver-7b", 
    "armo-rm", 
    "fsfairx-8b", 
]

def collate(outputs, rf, key):
    """Attach per-example reward scores from a ``.npy`` file to ``outputs``.

    Each ``outputs[i]`` gets ``outputs[i][key] = rewards[i]``, and any
    ``'response'`` entry is removed from it. NOTE(review): presumably the
    response text is dropped to keep the merged records small; confirm
    nothing downstream needs it.

    Args:
        outputs: list of dicts; mutated in place and also returned.
        rf: path to an array readable by ``np.load``; element ``i`` is the
            reward for ``outputs[i]``.
        key: dict key under which each reward is stored.

    Returns:
        The same ``outputs`` list.

    Raises:
        ValueError: if the number of rewards differs from ``len(outputs)``.
            Checked before any mutation: too few rewards would otherwise crash
            half-way through, and too many usually means the reward file does
            not correspond to these generations.
    """
    print(f"Loading rewards from {rf}")
    rewards = np.load(rf)
    if len(rewards) != len(outputs):
        raise ValueError(
            f"{rf} holds {len(rewards)} rewards but there are "
            f"{len(outputs)} outputs to attach them to"
        )
    for output, reward in zip(outputs, rewards):
        output[key] = reward
        output.pop('response', None)
    return outputs

def reward_histograms(outputs, **kwargs):
    """Plot per-reward-model histograms of rewards for correct vs. incorrect outputs.

    One subplot is drawn for each name in ``REWARDS`` that is present as a key
    on at least one output. ``main`` skips missing reward files, so a key may
    be absent from some or all outputs. Each subplot shows density histograms
    of that model's rewards, divided by the largest absolute reward for the
    model, split on ``output['correct']``. A dashed vertical line marks each
    group's mean.

    Args:
        outputs: list of dicts with a ``'correct'`` entry and reward entries
            keyed by reward-model name.
        **kwargs: must contain ``policy`` (used in the title and filename) and
            ``cov`` (coverage value shown in the title).

    Side effects:
        Writes ``<ROOT>/scratch_notebooks/piref-out/<policy>-rewards.png``,
        creating the directory if needed, then closes the figure.
    """
    policy = kwargs['policy']
    cov = kwargs['cov']
    available = [r for r in REWARDS if any(r in output for output in outputs)]
    if not available:
        print("No reward scores found on outputs; skipping histograms")
        return

    # squeeze=False keeps `axs` 2-D even when only one reward is plotted.
    fig, axs = plt.subplots(
        len(available), 1, figsize=(8, len(available) * 4), squeeze=False
    )
    fig.suptitle(f"GSM8k reward distributions under {policy} with C={cov}")
    for reward, ax in zip(available, axs[:, 0]):
        scored = [output for output in outputs if reward in output]
        rmax = np.abs([output[reward] for output in scored]).max()
        if rmax == 0:
            rmax = 1.0  # all rewards are zero: avoid 0/0 -> NaN, plot unscaled
        plotted = False
        for label in (True, False):
            rewards = np.array(
                [o[reward] for o in scored if o['correct'] == label]
            ) / rmax
            if rewards.size == 0:
                continue  # mean() of an empty group would be NaN
            ax.hist(rewards, bins=50, alpha=1.0 if label else 0.5, density=True,
                    label='correct' if label else 'incorrect')
            ax.axvline(rewards.mean(), color='k', linestyle='dashed',
                       linewidth=1, alpha=1.0 if label else 0.2)
            plotted = True
        ax.set_title(f"{reward}")
        if plotted:
            ax.legend()

    out_dir = os.path.join(ROOT, 'scratch_notebooks/piref-out')
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, f"{policy}-rewards.png"))
    plt.close(fig)  # release the figure; pyplot otherwise keeps it alive

def get_first_correct(outputs):
    """Return the index of the first output whose ``'correct'`` entry equals True.

    Also prints where the first correct output was found, or that none was.

    Args:
        outputs: sequence of dicts, each with a ``'correct'`` entry.

    Returns:
        The index of the first entry with ``correct == True``. If there is none,
        returns ``len(outputs)``, i.e. "one past the end".
    """
    corrects = [output['correct'] for output in outputs]
    try:
        index = corrects.index(True)
    except ValueError:  # list.index raises ValueError when the value is absent
        print('No corrects found')
        return len(corrects)
    print(f'Correct found at {index}')
    return index

def _generation_prefix(task, policy, seed):
    """Return the filename prefix shared by one seed's generation and reward files."""
    return f"{task}-{policy}--shots-0-seed-{seed}"


def _load_outputs(root, task, policy, seeds):
    """Load every seed's generations and attach whichever reward scores exist.

    For each seed, reads ``<root>/generations/<prefix>-generations.json``. Then,
    for each name in ``REWARDS`` whose
    ``<root>/<reward>/<prefix>-<reward>-rewards.npy`` exists, it stores those
    scores on the outputs via ``collate``. Missing reward files are skipped, so
    not every output necessarily carries every reward key.
    """
    outputs = []
    for seed in seeds:
        prefix = _generation_prefix(task, policy, seed)
        gf = os.path.join(root, "generations", f"{prefix}-generations.json")
        seed_outputs = io.json_load(gf)
        for reward in REWARDS:
            rf = os.path.join(root, reward, f"{prefix}-{reward}-rewards.npy")
            if os.path.exists(rf):
                seed_outputs = collate(seed_outputs, rf, reward)
        outputs += seed_outputs
    return outputs


def _estimate_coverage(outputs, k):
    """Average the 1-based position of the first correct output over random chunks of ``k``.

    A shuffled copy of ``outputs`` is cut into consecutive chunks of ``k``
    (the last chunk may be shorter). A chunk with no correct output counts as
    ``len(chunk) + 1``. ``outputs`` itself is not reordered. The shuffle uses
    the global ``random`` state.

    Raises:
        ValueError: if ``outputs`` is empty.
    """
    shuffled = list(outputs)
    random.shuffle(shuffled)
    chunks = [shuffled[i:i + k] for i in range(0, len(shuffled), k)]
    if not chunks:
        raise ValueError("cannot estimate coverage from zero outputs")
    # Divide by the number of chunks (the original divided by the last chunk
    # index, i.e. one fewer).
    return sum(get_first_correct(chunk) + 1 for chunk in chunks) / len(chunks)


def main():
    """Estimate coverage for one policy/task and plot its reward histograms."""
    seeds = [102, 103, 104, 105, 106]
    policy = 'gemma-2-2b'
    task = 'gsm8k'
    k = 20
    root = os.path.join(OUT_ROOT, f"data/{task}/{policy}")

    # Load each generation file once; collate() only adds reward keys and drops
    # 'response', so the 'correct' flags used for coverage are unaffected.
    outputs = _load_outputs(root, task, policy, seeds)
    cov = _estimate_coverage(outputs, k)
    print(f"Coverage: {cov}")

    reward_histograms(outputs, policy=policy, cov=cov)


if __name__ == "__main__":
    main()