import pickle
import os
import torch
import numpy as np

def cosdist(a, b):
    """Return the pairwise cosine *similarity* matrix between rows of ``a`` and ``b``.

    Despite the name, this is a similarity (1 = same direction, -1 = opposite),
    not ``1 - similarity``.

    Args:
        a: 2-D array of shape (n, d); each row is one vector.
        b: 2-D array of shape (m, d); each row is one vector.

    Returns:
        (n, m) array whose ``[i, j]`` entry is cos(a[i], b[j]). An all-zero row
        has no direction; its similarities are reported as 0 instead of NaN.
    """
    a_len = np.linalg.norm(a, axis=1, keepdims=True)
    b_len = np.linalg.norm(b, axis=1, keepdims=True)
    # Dividing a zero row by 1 keeps it zero instead of producing 0/0 = NaN;
    # non-zero rows are normalized exactly as before.
    a_len[a_len == 0] = 1
    b_len[b_len == 0] = 1
    a_unit = a / a_len
    b_unit = b / b_len
    return np.matmul(a_unit, b_unit.T)

def load_pt_reps(model_size='1b'):
    """Unpickle and return the stored 'pt' representations for one model size.

    The file lives under the ``dolma_sample-2k-fix`` grad-store directory of the
    ``olmo-{model_size}-ft`` run.

    Args:
        model_size: size tag used in the run directory name (e.g. '1b').

    Returns:
        The unpickled object as-is. Callers call ``.numpy()`` on it, so it is
        presumably a torch.Tensor; confirm against the producer.
    """
    # NOTE: pickle can execute arbitrary code; only load trusted run outputs.
    pkl_path = f'runs/olmo-{model_size}-ft/grad-store/dolma_sample-2k-fix/task_/all_repsNone.pkl'
    with open(pkl_path, 'rb') as handle:
        return pickle.load(handle)

def load_ocl_reps(task_type, model_size='1b'):
    """Load each task's pickled representations, average them, and stack the results.

    For every ``task_<int>`` subdirectory of
    ``runs/olmo-{model_size}-ft/grad-store/{task_type}-fix``, this unpickles
    ``all_repsNone.pkl`` and takes ``.mean(0)`` of it. The per-task means are
    stacked with ``torch.stack``.

    Args:
        task_type: name prefix of the ``{task_type}-fix`` grad-store directory.
        model_size: size tag used in the run directory name (e.g. '1b').

    Returns:
        torch.Tensor whose k-th row is the mean representation of the k-th task,
        with tasks ordered by their integer id.

    Raises:
        FileNotFoundError: if the base directory or a task's pickle is missing.
        RuntimeError: if the base directory has no ``task_<int>`` subdirectories.
    """
    base_dir = f'runs/olmo-{model_size}-ft/grad-store/{task_type}-fix'
    prefix = 'task_'

    # Only consider real task_<int> directories. Counting every directory entry
    # broke as soon as a stray file (log, .DS_Store, ...) appeared, because it
    # led to opening a task_<n> that does not exist.
    task_dirs = sorted(
        (
            name
            for name in os.listdir(base_dir)
            if name.startswith(prefix)
            and name[len(prefix):].isdecimal()
            and os.path.isdir(os.path.join(base_dir, name))
        ),
        key=lambda name: int(name[len(prefix):]),  # numeric, not lexicographic, order
    )
    if not task_dirs:
        raise RuntimeError(f'no task_<id> directories found in {base_dir}')

    all_vecs = []
    for task_dir in task_dirs:
        filename = os.path.join(base_dir, task_dir, 'all_repsNone.pkl')
        # NOTE: pickle can execute arbitrary code; only load trusted run outputs.
        with open(filename, 'rb') as f:
            obj = pickle.load(f)
        all_vecs.append(obj.mean(0))
    return torch.stack(all_vecs)

def get_cos_dist_mat(ocl_tasks, model_size='1b'):
    """Compare every per-task OCL representation against the 'pt' representations.

    Args:
        ocl_tasks: iterable of task-type names, each passed to ``load_ocl_reps``.
        model_size: size tag forwarded to both loaders (e.g. '1b').

    Returns:
        numpy array from ``cosdist``. Rows are the concatenated per-task
        representations of all ``ocl_tasks``, in order. Columns are the rows
        returned by ``load_pt_reps``. Values are cosine similarities.
    """
    stacked_ocl = torch.cat([load_ocl_reps(task, model_size) for task in ocl_tasks])
    pretrain_reps = load_pt_reps(model_size)
    # cosdist works on numpy arrays, so convert both tensors before comparing.
    return cosdist(stacked_ocl.numpy(), pretrain_reps.numpy())
# Example task directory read by load_ocl_reps (task_type='flan', model_size='1b'): runs/olmo-1b-ft/grad-store/flan-fix/task_2