import os
import pickle

from medium_rl.envs.proxies.load_data import load_amp, load_gfp, load_utr
from medium_rl.envs.proxies.train_proxy import train_model
from medium_rl.utils import save_model

# Integer indices of the three special tokens. PAD is handed to every proxy
# model below as ``pad_token_idx``; CLS and EOS are not referenced again in
# this file (presumably used by the data/encoding code -- verify).
CLS, PAD, EOS = 0, 1, 2


def save_proxy(model_cfg, best_params, val_stats, path):
    """Persist a trained proxy and its validation statistics under *path*.

    Two files are written into the directory ``path``:

    * ``proxy.pkl`` -- written by ``save_model(model_cfg, best_params, ...)``.
    * ``val_stats.pkl`` -- ``val_stats`` serialized with :mod:`pickle`.

    The directory (including missing parents) is created first, so a fresh
    output location does not make the save fail after training has finished.

    Args:
        model_cfg: Model configuration, forwarded unchanged to ``save_model``.
        best_params: Model parameters, forwarded unchanged to ``save_model``.
        val_stats: Picklable object holding the validation statistics.
        path: Output directory. An empty string targets the current working
            directory, exactly as ``os.path.join`` treats it.
    """
    # ``os.makedirs("")`` raises, and an empty path already means "current
    # directory" for the joins below, so only create non-empty directories.
    if path:
        os.makedirs(path, exist_ok=True)

    proxy_path = os.path.join(path, "proxy.pkl")
    val_stats_path = os.path.join(path, "val_stats.pkl")

    save_model(model_cfg, best_params, proxy_path)
    with open(val_stats_path, "wb") as f:
        pickle.dump(val_stats, f)


def train_amp(path):
    """Train the AMP proxy in classification mode and save it under *path*.

    Loads the dataset with ``load_amp``, fits a model through ``train_model``
    and hands the returned config, best parameters and validation statistics
    to ``save_proxy``.

    Args:
        path: Output directory forwarded to ``save_proxy``.
    """
    inputs, targets = load_amp()

    architecture = dict(
        num_tokens=20 + 3,  # 20 amino acids + CLS + PAD + EOS
        embed_dim=64,
        hid_dim=64,
        output_dim=1,  # one output: binary-classification logit or regression value
        num_layers=4,
        num_head=8,
        pad_token_idx=PAD,
        dropout=0,
    )
    hyperparams = dict(
        type="classification",
        batch_size=256,
        learning_rate=1e-4,
        max_epochs=250,
        val_percent=0.2,
        weight_decay=1e-6,
        patience=15,
        seed=0,
    )

    # train_model returns its own copy/version of the config; that returned
    # config (not the local one) is what gets saved. The second element of the
    # result is not needed here.
    final_cfg, _trained_state, params, stats = train_model(
        inputs, targets, architecture, **hyperparams
    )
    save_proxy(final_cfg, params, stats, path)


def train_gfp(path):
    """Train the GFP proxy in regression mode and save it under *path*.

    Loads the dataset with ``load_gfp``, fits a model through ``train_model``
    and hands the returned config, best parameters and validation statistics
    to ``save_proxy``.

    Args:
        path: Output directory forwarded to ``save_proxy``.
    """
    inputs, targets = load_gfp()

    architecture = dict(
        num_tokens=20 + 3,  # 20 amino acids + CLS + PAD + EOS
        embed_dim=128,
        hid_dim=128,
        output_dim=1,  # one output: binary-classification logit or regression value
        num_layers=3,
        num_head=8,
        pad_token_idx=PAD,
        dropout=0,
    )
    hyperparams = dict(
        type="regression",
        batch_size=128,
        learning_rate=1e-5,
        max_epochs=250,
        val_percent=0.2,
        weight_decay=1e-6,
        patience=5,
        seed=0,
    )

    # train_model returns its own copy/version of the config; that returned
    # config (not the local one) is what gets saved. The second element of the
    # result is not needed here.
    final_cfg, _trained_state, params, stats = train_model(
        inputs, targets, architecture, **hyperparams
    )
    save_proxy(final_cfg, params, stats, path)


def train_utr(path):
    """Train the UTR proxy in regression mode and save it under *path*.

    Loads the dataset with ``load_utr``, fits a model through ``train_model``
    and hands the returned config, best parameters and validation statistics
    to ``save_proxy``.

    Args:
        path: Output directory forwarded to ``save_proxy``.
    """
    inputs, targets = load_utr()

    architecture = dict(
        # Same vocabulary size as the other proxies: 20 amino acids + CLS +
        # PAD + EOS. NOTE(review): confirm this vocabulary is the intended
        # one for the UTR data.
        num_tokens=20 + 3,
        embed_dim=64,
        hid_dim=64,
        output_dim=1,  # one output: binary-classification logit or regression value
        num_layers=4,
        num_head=8,
        pad_token_idx=PAD,
        dropout=0,
    )
    hyperparams = dict(
        type="regression",
        batch_size=128,
        learning_rate=1e-4,
        max_epochs=250,
        val_percent=0.2,
        weight_decay=1e-6,
        patience=15,
        seed=0,
    )

    # train_model returns its own copy/version of the config; that returned
    # config (not the local one) is what gets saved. The second element of the
    # result is not needed here.
    final_cfg, _trained_state, params, stats = train_model(
        inputs, targets, architecture, **hyperparams
    )
    save_proxy(final_cfg, params, stats, path)


if __name__ == "__main__":
    # Train and save each proxy in turn. Output paths are relative to the
    # current working directory. Trailing notes are the validation results
    # recorded by the original author.
    jobs = (
        (train_utr, "src/medium_rl/envs/proxies/utr/"),  # 0.14 val loss
        (train_amp, "src/medium_rl/envs/proxies/amp/"),  # 0.91 val accuracy
        (train_gfp, "src/medium_rl/envs/proxies/gfp/"),
    )
    for train_fn, out_dir in jobs:
        train_fn(out_dir)
