import gzip
import numpy as np
import os
import pandas as pd
import torch


def save_gzip_file(tensor, fname):
    """Serialize ``tensor`` with ``torch.save`` into a gzip-compressed file.

    The tensor is moved to the CPU, made contiguous and copied into its own
    buffer before saving. ``torch.save`` writes a tensor's whole underlying
    storage, so saving a slice/view of a larger tensor without the copy
    would write the entire parent buffer to disk.

    Args:
        tensor: The ``torch.Tensor`` to save.
        fname: Destination path, opened with ``gzip.open`` in ``"wb"`` mode
            at ``compresslevel=9``.
    """
    # .contiguous() is a no-op for an already-contiguous view, so clone()
    # is what guarantees a storage holding exactly this tensor's elements.
    compact = tensor.cpu().contiguous().clone()
    with gzip.open(fname, "wb", compresslevel=9) as f:
        torch.save(compact, f)


def load_gzip_file(fname):
    """Load an object written by ``torch.save`` from a gzip-compressed file.

    Args:
        fname: Path of the gzip file to read.

    Returns:
        The deserialized object. Loading uses ``map_location="cpu"`` and
        ``weights_only=True``.
    """
    with gzip.open(fname, "rb") as compressed:
        loaded = torch.load(compressed, map_location="cpu", weights_only=True)
    return loaded

def process_math_id(id, reverse=False):
    """Convert a math id between its path form and its flat form.

    Forward (``reverse=False``): keep the text before the first ``.json``
    and replace every ``/`` with ``-``.
    Reverse (``reverse=True``): replace every ``-`` with ``/`` and append
    ``.json``.

    Args:
        id: The identifier string to convert.
        reverse: Selects the direction of the conversion.

    Returns:
        The converted identifier string.
    """
    if reverse:
        slashed = id.replace('-', '/')
        return slashed + ".json"
    stem, _, _ = id.partition('.json')
    return stem.replace('/', '-')


def convert_math_data_setting_to_str(l, w, s, i, m, d):
    """Build the ``l.._w.._s.._i.._m.._d..`` tag for a math data setting.

    Each argument is formatted right after its one-letter prefix and the
    six parts are joined with underscores.
    """
    tagged = (("l", l), ("w", w), ("s", s), ("i", i), ("m", m), ("d", d))
    return "_".join(f"{prefix}{value}" for prefix, value in tagged)


def convert_data_setting_to_str(v, r, l, w, s, i, e):
    """Build the ``v.._r.._l.._w.._s.._i.._e..`` tag for a data setting.

    Each argument is formatted right after its one-letter prefix and the
    seven parts are joined with underscores.
    """
    tagged = (("v", v), ("r", r), ("l", l), ("w", w), ("s", s), ("i", i), ("e", e))
    return "_".join(f"{prefix}{value}" for prefix, value in tagged)


def convert_model_setting_to_str(arch, t, n, f, o, p=None, b=None, i=None, u=None, final=None, weight_decay=None, lr=None):
    """Build the string tag describing a model configuration.

    When ``arch == 'transformer'`` the tag is made of ``arch``, ``t``, ``n``,
    ``f``, ``o``, ``b``, ``weight_decay`` and ``lr``. For any other ``arch``
    it is made of ``arch``, ``t``, ``p``, ``b``, the dash-joined ``i``, ``u``
    and ``final`` sequences, ``weight_decay`` and ``lr``.

    Returns:
        The underscore-separated setting string.
    """
    def dash_joined(values):
        # Falsy values (None, empty sequence) are rendered as the text 'None'.
        if not values:
            return 'None'
        return '-'.join(str(item) for item in values)

    if arch == 'transformer':
        parts = [
            f"{arch}", f"t{t}", f"n{n}", f"f{f}", f"o{o}",
            f"b{b}", f"wd{weight_decay}", f"lr{lr}",
        ]
        return "_".join(parts)

    i_part = dash_joined(i)
    u_part = dash_joined(u)
    final_part = dash_joined(final)
    parts = [
        f"{arch}", f"t{t}", f"p{p}", f"b{b}",
        f"i{i_part}", f"u{u_part}", f"f{final_part}",
        f"wd{weight_decay}", f"lr{lr}",
    ]
    return "_".join(parts)
    
def load_val_predictions(root_dir, model_dir, epoch):
    """Load validation metadata and attach the predictions saved for ``epoch``.

    Reads ``<root_dir>/model/<model_dir>/val.info.csv`` into a DataFrame and
    stores the ``predictions`` array from ``val.pred.<epoch>.npz`` (same
    directory) in a new ``pred`` column.

    Args:
        root_dir: Directory that contains the ``model`` folder.
        model_dir: Sub-directory of ``model`` holding this model's outputs.
        epoch: Value interpolated into the prediction file name.

    Returns:
        The DataFrame read from ``val.info.csv`` with an added ``pred``
        column. No label column is attached.
    """
    model_path = os.path.join(root_dir, "model", model_dir)
    val_info = pd.read_csv(os.path.join(model_path, "val.info.csv"))
    # np.load returns a lazy NpzFile for .npz archives; the context manager
    # closes the underlying file once the array has been read.
    with np.load(os.path.join(model_path, f"val.pred.{epoch}.npz")) as archive:
        val_pred = archive['predictions']
    val_info['pred'] = val_pred
    return val_info


def load_test_predictions(root_dir, model_dir, epoch):
    """Load test metadata and attach the predictions saved for ``epoch``.

    Reads ``<root_dir>/model/<model_dir>/test.info.csv`` into a DataFrame and
    stores the ``predictions`` array from ``test.pred.<epoch>.npz`` (same
    directory) in a new ``pred`` column.

    Args:
        root_dir: Directory that contains the ``model`` folder.
        model_dir: Sub-directory of ``model`` holding this model's outputs.
        epoch: Value interpolated into the prediction file name.

    Returns:
        The DataFrame read from ``test.info.csv`` with an added ``pred``
        column. No label column is attached.
    """
    model_path = os.path.join(root_dir, "model", model_dir)
    test_info = pd.read_csv(os.path.join(model_path, "test.info.csv"))
    # np.load returns a lazy NpzFile for .npz archives; the context manager
    # closes the underlying file once the array has been read.
    with np.load(os.path.join(model_path, f"test.pred.{epoch}.npz")) as archive:
        test_pred = archive['predictions']
    test_info['pred'] = test_pred
    return test_info


def load_test_prompt_predictions(root_dir, model_dir, epoch):
    """Load prompt-model test metadata and attach argmax predictions.

    Reads ``<root_dir>/model_prompt/<model_dir>/test.info.csv`` into a
    DataFrame, loads the ``predictions`` array from ``test_preds.npz`` in the
    same directory, selects ``predictions[epoch, :]`` and stores the argmax
    along axis 1 of that slice in a new ``pred`` column.

    Args:
        root_dir: Directory that contains the ``model_prompt`` folder.
        model_dir: Sub-directory of ``model_prompt`` holding the outputs.
        epoch: Index into the first axis of the ``predictions`` array.

    Returns:
        The DataFrame read from ``test.info.csv`` with an added ``pred``
        column. No label column is attached.
    """
    model_path = os.path.join(root_dir, "model_prompt", model_dir)
    test_info = pd.read_csv(os.path.join(model_path, "test.info.csv"))
    # np.load returns a lazy NpzFile for .npz archives; the context manager
    # closes the underlying file once the array has been read.
    with np.load(os.path.join(model_path, "test_preds.npz")) as archive:
        all_preds = archive['predictions']
    # NOTE(review): axis 1 of the per-epoch slice is presumably the
    # class-score axis -- confirm against the code that writes the file.
    epoch_scores = all_preds[epoch, :]
    test_info['pred'] = np.argmax(epoch_scores, axis=1)
    return test_info
