import os
import pickle

import numpy as np

def _load_pickle(path):
    """Load and return the single pickled object stored at ``path``.

    NOTE: ``pickle.load`` can execute arbitrary code; only use this on
    trusted, locally generated stats files.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path``, overwriting any existing file."""
    with open(path, 'wb') as wf:
        pickle.dump(obj, wf)


def _concat(first, second):
    """Concatenate two index / task-info sequences end to end.

    Lists (and tuples) are joined with ``+``, exactly as before. If either
    operand is a NumPy array, ``+`` would add element-wise (or fail to
    broadcast) instead of concatenating, so ``np.concatenate`` is used.
    """
    if isinstance(first, np.ndarray) or isinstance(second, np.ndarray):
        return np.concatenate([np.asarray(first), np.asarray(second)])
    return first + second


def make_fpd_split_from_existing_7b_ins(full_arr, *, stats_dir='stats/olmo-7b-ins'):
    """Build the "flan-bin" FPD split pickles for olmo-7b-ins from existing splits.

    Two output files are written into ``stats_dir``:

    * ``fpd-split-flan-bin-olmo-7b-ins-id.pkl``: a copy of
      ``fpd-split-olmo-7b-ins-id.pkl``. Its ``train_mat`` / ``test_mat`` are
      filled by indexing ``full_arr`` with the stored ``train_ocl_idxs`` /
      ``test_ocl_idxs``, and ``pt_task_info['cat']`` is set to ``'flan_v2'``.
    * ``fpd-split-flan-bin-olmo-7b-ins-ood.pkl``: a merge of the TruthfulQA and
      Dolly OOD splits. Train data comes from the TruthfulQA split. Test
      indices and task info are TruthfulQA's followed by Dolly's.
      ``test_mat`` is every row of ``full_arr`` from the TruthfulQA offset
      onward. ``pt_task_info['cat']`` is set to ``'flan_v2'``.

    Args:
        full_arr: Array whose first axis is indexed by the stored OCL indices.
            Presumably a NumPy array, since it is fancy-indexed with the
            index sequences. TODO confirm.
        stats_dir: Directory that holds the input split pickles and receives
            the outputs. The default reproduces the original hard-coded paths.

    Returns:
        None. Results are written to disk only.

    Raises:
        FileNotFoundError: If one of the input split pickles is missing.
        KeyError: If an input pickle lacks one of the expected keys.
    """
    # First TruthfulQA row in full_arr. According to the original offsets,
    # 57 MMLU rows come first, then 27 BBH rows, then TruthfulQA (32 rows),
    # then Dolly. NOTE(review): confirm against how full_arr is assembled.
    truthful_qa_start = 57 + 27

    # --- In-distribution split -------------------------------------------
    base_info = _load_pickle(os.path.join(stats_dir, 'fpd-split-olmo-7b-ins-id.pkl'))
    base_info['train_mat'] = full_arr[base_info['train_ocl_idxs']]
    base_info['test_mat'] = full_arr[base_info['test_ocl_idxs']]
    base_info['pt_task_info']['cat'] = 'flan_v2'
    _dump_pickle(base_info, os.path.join(stats_dir, 'fpd-split-flan-bin-olmo-7b-ins-id.pkl'))

    # --- Out-of-distribution split: merge the TruthfulQA and Dolly OOD splits ---
    ood_tqa = _load_pickle(os.path.join(stats_dir, 'fpd-split-olmo-7b-ins-ood-truthful_qa.pkl'))
    ood_dolly = _load_pickle(os.path.join(stats_dir, 'fpd-split-olmo-7b-ins-ood-dolly.pkl'))

    new_ood_info = {
        'train_mat': full_arr[ood_tqa['train_ocl_idxs']],
        # NOTE(review): unlike the ID split, this slices by offset instead of
        # indexing with the merged test_ocl_idxs. The two agree only if those
        # indices are exactly range(truthful_qa_start, len(full_arr)). Confirm.
        'test_mat': full_arr[truthful_qa_start:],
        'train_ocl_idxs': ood_tqa['train_ocl_idxs'],
        'test_ocl_idxs': _concat(ood_tqa['test_ocl_idxs'], ood_dolly['test_ocl_idxs']),
        'train_ocl_task_info': ood_tqa['train_ocl_task_info'],
        'test_ocl_task_info': _concat(ood_tqa['test_ocl_task_info'], ood_dolly['test_ocl_task_info']),
        # Reuses (and mutates) the freshly loaded TruthfulQA pt_task_info object.
        'pt_task_info': ood_tqa['pt_task_info'],
    }
    new_ood_info['pt_task_info']['cat'] = 'flan_v2'

    _dump_pickle(new_ood_info, os.path.join(stats_dir, 'fpd-split-flan-bin-olmo-7b-ins-ood.pkl'))
    

    