import os
import pickle

import numpy as np

from medium_rl.envs.amp import AMP_DICT
from medium_rl.envs.gfp import GFP_DICT
from medium_rl.envs.utr import UTR_DICT


def load_amp():
    """Load the AMP peptide dataset as padded token arrays with binary labels.

    Reads ``neg_amp.pkl`` and ``pos_amp.pkl`` from the ``amp`` directory next to
    this module and tokenizes every sequence with ``AMP_DICT`` as
    ``[CLS, *residues, EOS]``. Each sequence is then right-padded with ``PAD``
    to a fixed width of 62 tokens.

    Returns:
        x: integer array of shape ``(len(neg) + len(pos), 62)``; the negative
            sequences come first, followed by the positive ones.
        y: float array of shape ``(len(neg) + len(pos),)``; ``0.0`` for the
            negative rows and ``1.0`` for the positive rows.

    Raises:
        ValueError: if a sequence has more than 60 residues (62 tokens with
            CLS and EOS).
    """
    data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "amp")
    # NOTE: pickle.load can execute arbitrary code. These files are read from
    # the package directory and must come from a trusted source.
    with open(os.path.join(data_dir, "neg_amp.pkl"), "rb") as f:
        neg_amp = pickle.load(f)

    with open(os.path.join(data_dir, "pos_amp.pkl"), "rb") as f:
        pos_amp = pickle.load(f)

    max_len = 60 + 2  # + 2 for CLS and EOS

    def tokenize(seq_list):
        # Pre-fill with PAD so short sequences end up right-padded. This also
        # lets an empty list yield a (0, max_len) array instead of failing
        # inside np.stack.
        out = np.full((len(seq_list), max_len), AMP_DICT["PAD"], dtype=np.int_)
        for row, seq in enumerate(seq_list):
            ids = [AMP_DICT["CLS"], *(AMP_DICT[char] for char in seq), AMP_DICT["EOS"]]
            if len(ids) > max_len:
                # Fail loudly rather than silently producing a ragged or
                # over-wide array.
                raise ValueError(
                    f"AMP sequence #{row} has {len(ids) - 2} residues; "
                    f"at most {max_len - 2} are supported"
                )
            out[row, : len(ids)] = ids
        return out

    # np.concatenate (not np.concat, which only exists in NumPy >= 2.0) keeps
    # this loader working on NumPy 1.x, like the other loaders in this file.
    x = np.concatenate([tokenize(neg_amp), tokenize(pos_amp)], axis=0)
    y = np.zeros(x.shape[0])
    y[len(neg_amp) :] = 1  # rows after the negatives are the positives

    return x, y


def load_gfp():
    """Load the GFP dataset and wrap every sequence in CLS/EOS tokens.

    Reads ``gfp_x.npy`` and ``gfp_y.npy`` from the ``gfp`` directory next to
    this module. The stored ids in ``x`` are shifted up by 3. The ``GFP_DICT``
    CLS id is then prepended as a new first column and the EOS id appended as
    a new last column.

    Returns:
        A pair ``(x, y)``. ``x`` is an ``int32`` 2-D array with two more
        columns than the stored one. ``y`` is returned exactly as loaded.
    """
    data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "gfp")
    raw_x = np.load(os.path.join(data_dir, "gfp_x.npy"))
    labels = np.load(os.path.join(data_dir, "gfp_y.npy"))

    # Shift stored ids to make room for the special tokens. Presumably the ids
    # below 3 are reserved for special tokens in GFP_DICT — TODO confirm.
    shifted = raw_x + 3
    n_rows = shifted.shape[0]
    cls_col = np.full((n_rows, 1), GFP_DICT["CLS"], dtype=np.float64)
    eos_col = np.full((n_rows, 1), GFP_DICT["EOS"], dtype=np.float64)
    tokens = np.hstack([cls_col, shifted, eos_col])
    return tokens.astype(np.int32), labels


def load_utr():
    """Load the UTR dataset and wrap every sequence in CLS/EOS tokens.

    Reads ``utr_x.npy`` and ``utr_y.npy`` from the ``utr`` directory next to
    this module. The stored ids in ``x`` are shifted up by 3. The ``UTR_DICT``
    CLS id is then prepended as a new first column and the ``UTR_DICT`` EOS id
    appended as a new last column.

    Returns:
        A pair ``(x, y)``. ``x`` is an ``int32`` 2-D array with two more
        columns than the stored one. ``y`` is returned exactly as loaded.
    """
    data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "utr")
    x = np.load(os.path.join(data_dir, "utr_x.npy"))
    y = np.load(os.path.join(data_dir, "utr_y.npy"))

    # Shift stored ids to make room for the special tokens (same +3 offset as
    # load_gfp). Presumably ids below 3 are reserved in UTR_DICT — TODO confirm.
    x = x + 3
    n_rows = x.shape[0]
    # Both special tokens come from the UTR vocabulary. CLS was previously
    # taken from GFP_DICT by copy-paste.
    x = np.concatenate(
        [UTR_DICT["CLS"] * np.ones((n_rows, 1)), x, UTR_DICT["EOS"] * np.ones((n_rows, 1))],
        axis=1,
    )
    return x.astype(np.int32), y
