import numpy as np
from data_utils.utils import deterministic_random_indices
from utils.config import load_configs
from data_utils import load_raw_ds
import re
import string
import pickle

def filter_nan(arr):
    """Return the entries of ``arr`` that are not NaN.

    Boolean-mask indexing is used, so the result is a flat array holding
    only the non-NaN values.
    """
    is_valid = np.logical_not(np.isnan(arr))
    return arr[is_valid]

def get_base_ppl_arr(model_name, return_ss=False):
    """Load the stored base-model ppl array for ``model_name``.

    Args:
        model_name: Key identifying the model; must be one of the keys of
            the path table below (``'7b-10k'``/``'olmo-7b-10k'`` and
            ``'1b-10k'``/``'olmo-1b-10k'`` are aliases of each other).
        return_ss: When true, keep only the entries selected by
            ``deterministic_random_indices(arr.shape[0], 10000)``.

    Returns:
        The element-wise absolute value of the loaded (and optionally
        subsampled) array.

    Raises:
        NotImplementedError: If ``model_name`` is not a known model.
    """
    olmo_7b_path = 'runs/stats/stats-olmo-7b-ft/flan-1k/task_0/pt-base_ppl_results.pkl.npy'
    olmo_1b_path = 'runs/stats/stats-olmo-1b-ft/flan-1k/task_0/pt-base_ppl_results.pkl.npy'
    base_ppl_paths = {
        '7b-10k': olmo_7b_path,
        'olmo-7b-10k': olmo_7b_path,
        '1b-10k': olmo_1b_path,
        'olmo-1b-10k': olmo_1b_path,
        '7b-ins': 'runs/stats/stats-olmo-7b-ins-ft-test/ft/dolly/task_0/pt-base_ppl_results.pkl.npy',
        'mpt-7b': 'runs/stats/stats-mpt-7b-ft-lr2e-6/flan-1k/task_0/pt-subsample10000_ppl_results.pkl.npy',
        'olmo2-7b': 'runs/stats/stats-olmo2-1124-7b-ft/flan-1k-lr2e-6/task_0/pt-base_ppl_results.pkl.npy',
        'olmo2-13b': 'runs/stats/stats-olmo2-1124-13b-ft/flan-1k-lr2e-6/task_0/pt-base_ppl_results.pkl.npy',
        'pythia-1b': 'runs/stats/stats-pythia/1b/tulu_train-1k-lr2e-6/task_0/pt-base_ppl_results.pkl.npy',
        'pythia-7b': 'runs/stats/stats-pythia/6.9b/tulu_train-1k-lr2e-6/task_0/pt-base_ppl_results.pkl.npy',
        'pythia-12b': 'runs/stats/stats-pythia/12b/tulu_train-1k-lr2e-6/task_0/pt-base_ppl_results.pkl.npy',
    }

    try:
        ppl_path = base_ppl_paths[model_name]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments, so every unsupported value
        # surfaces as the same exception type.
        raise NotImplementedError(f'unsupported model_name: {model_name!r}') from None

    ppl_base_arr = np.load(ppl_path)

    if return_ss:
        rand_idxs = deterministic_random_indices(ppl_base_arr.shape[0], 10000)
        ppl_base_arr = ppl_base_arr[rand_idxs]

    return np.abs(ppl_base_arr)

def get_ft_ppl_arr(model_name, task_cat, task_id, return_ss=False):
    """Load the stored ppl array of a fine-tuned run.

    Args:
        model_name: Key identifying the model; must be one of the keys of
            the template table below (``'7b-10k'``/``'olmo-7b-10k'`` and
            ``'1b-10k'``/``'olmo-1b-10k'`` are aliases of each other).
        task_cat: Task category, substituted into the results path.
        task_id: Task index, substituted into the results path.
        return_ss: When true, keep only the entries selected by
            ``deterministic_random_indices(arr.shape[0], 10000)``.

    Returns:
        The element-wise absolute value of the loaded (and optionally
        subsampled) array.

    Raises:
        NotImplementedError: If ``model_name`` is not a known model
            (previously this surfaced as an ``UnboundLocalError``).
    """
    olmo_7b = 'runs/stats/stats-olmo-7b-ft/{task_cat}-1k-lr2e-6/task_{task_id}/pt_ppl_results.pkl.npy'
    olmo_1b = 'runs/stats/stats-olmo-1b-ft/{task_cat}-1k-lr2e-6/task_{task_id}/pt_ppl_results.pkl.npy'
    path_templates = {
        '7b-10k': olmo_7b,
        'olmo-7b-10k': olmo_7b,
        '1b-10k': olmo_1b,
        'olmo-1b-10k': olmo_1b,
        '7b-ins': 'runs/stats/stats-olmo-7b-ins-ft-test/ft/{task_cat}/task_{task_id}/pt_ppl_results.pkl.npy',
        'olmo2-7b': 'runs/stats/stats-olmo2-1124-7b-ft/{task_cat}-1k-lr2e-6/task_{task_id}/pt_ppl_results.pkl.npy',
        'olmo2-13b': 'runs/stats/stats-olmo2-1124-13b-ft/{task_cat}-1k-lr2e-6/task_{task_id}/pt_ppl_results.pkl.npy',
        'pythia-1b': 'runs/stats/stats-pythia/1b/{task_cat}-1k-lr2e-6/task_{task_id}/pt_ppl_results.pkl.npy',
        'pythia-7b': 'runs/stats/stats-pythia/6.9b/{task_cat}-1k-lr2e-6/task_{task_id}/pt_ppl_results.pkl.npy',
        'pythia-12b': 'runs/stats/stats-pythia/12b/{task_cat}-1k-lr2e-6/task_{task_id}/pt_ppl_results.pkl.npy',
    }

    try:
        template = path_templates[model_name]
    except (KeyError, TypeError):
        # Fail the same way get_base_ppl_arr does instead of hitting an
        # unbound local further down.
        raise NotImplementedError(f'unsupported model_name: {model_name!r}') from None

    ppl_arr = np.load(template.format(task_cat=task_cat, task_id=task_id))

    if return_ss:
        rand_idxs = deterministic_random_indices(ppl_arr.shape[0], 10000)
        ppl_arr = ppl_arr[rand_idxs]

    return np.abs(ppl_arr)
    

def get_ft_ppl_inc_pos(model_name, task_cat, task_id):
    """Return a boolean mask of entries whose fine-tuned ppl exceeds the base ppl.

    The mask is ``(ft - base) > 0`` computed element-wise on the arrays
    returned by ``get_ft_ppl_arr`` and ``get_base_ppl_arr``.
    """
    finetuned = get_ft_ppl_arr(model_name, task_cat, task_id)
    baseline = get_base_ppl_arr(model_name)
    return (finetuned - baseline) > 0

def compute_ppl_where_ft_inc(model_name, task_cat, task_id, path):
    """Return the non-NaN absolute ppl values from ``path`` at positions where fine-tuning raised ppl.

    Positions are selected with the mask from ``get_ft_ppl_inc_pos``.
    """
    abs_ppl = np.abs(np.load(path))
    increased_mask = get_ft_ppl_inc_pos(model_name, task_cat, task_id)
    return filter_nan(abs_ppl[increased_mask])


def compute_ppl(path, reduce='mean'):
    """Load the array stored at ``path`` and return its absolute values.

    With ``reduce='mean'`` the mean over the non-NaN entries is returned;
    any other value of ``reduce`` returns the full absolute-valued array.

    NOTE: an identical ``compute_ppl`` is defined again further down in this
    module; at import time that later definition rebinds the name.
    """
    abs_ppl = np.abs(np.load(path))
    if reduce != 'mean':
        return abs_ppl
    return filter_nan(abs_ppl).mean()

def compute_ppl_inc(model_name,  path, reduce='mean'):
    """Return the ppl increase of the array at ``path`` over the base model's ppl.

    The increase is ``abs(loaded) - base`` element-wise. With
    ``reduce='mean'`` the mean over the non-NaN entries is returned;
    otherwise the full difference array is returned.
    """
    measured = np.abs(np.load(path))
    baseline = get_base_ppl_arr(model_name)
    delta = measured - baseline
    if reduce != 'mean':
        return delta
    return filter_nan(delta).mean()

def compute_reduced_forgetting_perc_verbose(model_name, task_cat, task_id, path):
    """Relate the mean ppl increase at ``path`` to the mean ppl increase of plain fine-tuning.

    Returns:
        A 4-tuple ``(fgt_perc, ppl_inc, ft_ppl_inc, base_ppl)`` where
        ``ppl_inc`` is the mean increase from ``compute_ppl_inc``,
        ``ft_ppl_inc`` is the element-wise ``ft - base`` array,
        ``base_ppl`` is the base array, and ``fgt_perc`` is ``ppl_inc``
        divided by the non-NaN mean of ``ft_ppl_inc``.
    """
    mean_inc = compute_ppl_inc(model_name, path)
    finetuned = get_ft_ppl_arr(model_name, task_cat, task_id)
    baseline = get_base_ppl_arr(model_name)
    finetuned_inc = finetuned - baseline

    # Ratio of the two mean increases.
    ratio = mean_inc / filter_nan(finetuned_inc).mean()

    return ratio, mean_inc, finetuned_inc, baseline
    
def compute_ppl(path, reduce='mean'):
    """Return the absolute values of the array saved at ``path``.

    ``reduce='mean'`` collapses the result to the mean of its non-NaN
    entries; every other value returns the whole array.

    NOTE: this rebinds an identical ``compute_ppl`` defined earlier in this
    module.
    """
    values = np.load(path)
    values = np.abs(values)
    return filter_nan(values).mean() if reduce == 'mean' else values
    
def compute_ppl_inc_ft_positive_only(model_name, task_cat, task_id, path, reduce='mean'):
    """Return the ppl increase at ``path``, restricted to entries where fine-tuning raised ppl.

    The increase is ``abs(loaded) - base``; entries are kept where
    ``ft - base > 0``. With ``reduce='mean'`` the mean over the non-NaN
    kept entries is returned; otherwise the kept entries themselves.
    """
    measured = np.abs(np.load(path))

    finetuned = get_ft_ppl_arr(model_name, task_cat, task_id)
    baseline = get_base_ppl_arr(model_name)

    keep = (finetuned - baseline) > 0
    kept_inc = (measured - baseline)[keep]

    if reduce != 'mean':
        return kept_inc
    return filter_nan(kept_inc).mean()


def compute_ppl_ft_positive_only(model_name, task_cat, task_id, path, reduce='mean'):
    """Return the absolute ppl at ``path``, restricted to entries where fine-tuning raised ppl.

    Entries are kept where ``ft - base > 0``. With ``reduce='mean'`` the
    mean over the non-NaN kept entries is returned; otherwise the kept
    entries themselves.
    """
    measured = np.abs(np.load(path))

    finetuned = get_ft_ppl_arr(model_name, task_cat, task_id)
    baseline = get_base_ppl_arr(model_name)

    kept_ppl = measured[(finetuned - baseline) > 0]

    if reduce != 'mean':
        return kept_ppl
    return filter_nan(kept_ppl).mean()



def compute_relative_forgetting(model_name, arr_or_path):
    """Return the mean ppl increase of ``arr_or_path`` over the base model's ppl.

    Args:
        model_name: Model key forwarded to ``get_base_ppl_arr``.
        arr_or_path: Either an array of ppl values, or a path (``str`` or
            ``os.PathLike``) to a file readable by ``np.load``.

    Returns:
        The mean of ``abs(arr) - abs(base)`` over its non-NaN entries.
    """
    import os  # local import: only needed for the PathLike check below

    # isinstance also accepts str subclasses and pathlib-style paths, which
    # an exact ``type(...) is str`` comparison would treat as arrays.
    if isinstance(arr_or_path, (str, os.PathLike)):
        arr = np.load(arr_or_path)
    else:
        arr = arr_or_path

    ppl_inc = np.abs(arr) - np.abs(get_base_ppl_arr(model_name))
    valid_inc = ppl_inc[~np.isnan(ppl_inc)]

    return valid_inc.mean(0)

def pretty_print(*numbers):
    """Print the arguments on one line, each formatted to four decimal places.

    ``None`` arguments are passed through and printed as ``None``.
    """
    formatted = []
    for value in numbers:
        formatted.append(None if value is None else f'{value:.4f}')
    print(*formatted)

def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""
    lowered = s.lower()
    # Drop every character listed in string.punctuation in a single pass.
    without_punct = lowered.translate(str.maketrans('', '', string.punctuation))
    # Replace the standalone words a/an/the with a space.
    article_re = re.compile(r'\b(a|an|the)\b', re.UNICODE)
    without_articles = re.sub(article_re, ' ', without_punct)
    # Collapse whitespace runs and trim both ends.
    return ' '.join(without_articles.split())

def get_tokens(s):
    """Return the whitespace-separated tokens of the normalized ``s``; falsy input yields ``[]``."""
    return normalize_answer(s).split() if s else []

def compute_exact(a_gold, a_pred):
    """Return 1 if the two answers are equal after ``normalize_answer``, else 0."""
    gold_norm = normalize_answer(a_gold)
    pred_norm = normalize_answer(a_pred)
    return int(gold_norm == pred_norm)


def evaluate_flan_em_arr(lm_outputs, ds=None):
    """Score each LM output against its reference with exact match.

    Args:
        lm_outputs: Sequence whose items expose the generated text as
            ``item.outputs[0].text``.
        ds: Sequence whose items carry the reference answer at index 1.
            When ``None``, it is loaded via ``load_configs`` and
            ``load_raw_ds`` with the hard-coded config paths below.

    Returns:
        An integer array with one ``compute_exact`` score per example.
        Each prediction is cut at its first blank line (``'\\n\\n'``)
        before scoring.
    """
    if ds is None:
        config = load_configs('configs/defaults.yaml', 'configs/llm/ocl_ins/stat_7b_on_flan_v2.yaml')
        ds = load_raw_ds('train', config, None, 'tulu')

    references = [example[1] for example in ds]
    generations = [output.outputs[0].text for output in lm_outputs]

    assert len(references) == len(generations)

    scores = [
        compute_exact(reference, generation.split('\n\n')[0])
        for reference, generation in zip(references, generations)
    ]
    return np.array(scores, dtype=int)
        
def get_7b_ins_base_output_res():
    """Unpickle and return the base-model output results stored at the fixed path below."""
    with open('runs/stats/stats-olmo-7b-ins-ft-test/flan-v2/ft/mmlu/task_0/pt-base_output_results.pkl', 'rb') as handle:
        return pickle.load(handle)

def get_all_7b_ins_res(filter_tasks=None, replay_method=None, trunc=False, temp=None, max_task=-1, peft=False):
    """Unpickle the per-task output results and return them as one flat list.

    Args:
        filter_tasks: When not ``None``, only the task category equal to
            this value is loaded.
        replay_method: When ``None``, plain fine-tuning results are read;
            otherwise results from the matching replay directory.
        trunc: With a replay method, read from the ``truncate_replay``
            directory instead of the plain replay one.
        temp: With ``trunc`` set, selects the ``truncate_replay_t{temp}``
            directory when not ``None``.
        max_task: Each category's loop stops once ``task_id`` equals this
            value; the default ``-1`` never matches.
        peft: Without a replay method, read from the ``peft`` stats
            directory instead of the ``ft`` one.

    Returns:
        A list of unpickled objects, ordered by category and then task id.
        The name of each category is printed as it is processed.
    """
    tasks_per_category = {
        'mmlu': 57,
        'bbh': 27,
        'truthful_qa': 32,
        'dolly': 8
    }
    if filter_tasks is not None:
        tasks_per_category = {
            category: count
            for category, count in tasks_per_category.items()
            if category == filter_tasks
        }

    def result_path(task, task_id):
        # Pick the results file for one task according to the flags above.
        if replay_method is None:
            if peft:
                return f'runs/stats/stats-olmo-7b-ins-peft-test/flan-v2/ft/{task}/task_{task_id}/pt_output_results.pkl'
            return f'runs/stats/stats-olmo-7b-ins-ft-test/flan-v2/ft/{task}/task_{task_id}/pt_output_results.pkl'
        if not trunc:
            return f'runs/stats/stats-olmo-7b-ins-ft-test/replay/flan_v2/{replay_method}_mix_0.125/{task}/task_{task_id}/pt_output_results.pkl'
        if temp is not None:
            return f'runs/stats/stats-olmo-7b-ins-ft-test/replay/flan_v2_truncate_replay_t{temp}/{replay_method}_mix_0.125/{task}/task_{task_id}/pt_output_results.pkl'
        return f'runs/stats/stats-olmo-7b-ins-ft-test/replay/flan_v2_truncate_replay/{replay_method}_mix_0.125/{task}/task_{task_id}/pt_output_results.pkl'

    collected = []
    for task, task_count in tasks_per_category.items():
        print(task)
        for task_id in range(task_count):
            if task_id == max_task:
                break
            with open(result_path(task, task_id), 'rb') as handle:
                collected.append(pickle.load(handle))
    return collected

def get_7b_ins_fgt_arr():
    """Return per-task forgetting indicators over examples the base outputs got right.

    Exact-match scores are computed for the base outputs and for every
    object returned by ``get_all_7b_ins_res()``. Columns where the base
    score is 1 are kept, and the result is ``1 - score`` for those columns,
    with one row per loaded task result.
    """
    base_em = evaluate_flan_em_arr(get_7b_ins_base_output_res())
    finetuned_em = np.array([evaluate_flan_em_arr(outputs) for outputs in get_all_7b_ins_res()])
    # Keep only the examples the base outputs answered correctly.
    kept_em = finetuned_em[:, base_em == 1]
    return 1 - kept_em