import os
import torch
import numpy as np
from argparse import Namespace

from large_rl.embedding.base import BaseEmbedding
from large_rl.commons.seeds import set_randomSeed


def train(model, emb: "BaseEmbedding", args: Namespace, lr: float = 1e-3):
    """Fit ``model`` on mini-batches drawn from ``emb`` and return it.

    Runs ``args.num_epochs`` optimisation steps. Each step:

    1. draws ``args.batch_size`` rows with ``emb.sample(..., if_np=False)``;
    2. runs them through ``model``;
    3. minimises ``model.loss_function(*model_output)`` with Adam.

    Args:
        model: torch module exposing ``forward`` and ``loss_function``.
        emb: source of training rows. Only
            ``sample(num_samples=..., if_np=False)`` is used here.
        args: namespace providing ``num_epochs`` and ``batch_size``.
        lr: Adam learning rate.

    Returns:
        The same ``model`` object, trained in place and left in train mode.
    """
    print("=== Train Model ===")
    model.train()
    # Computing the loss alone does not change any weights; an optimiser plus
    # backward/step is what actually trains the model.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(args.num_epochs):
        _in = emb.sample(num_samples=args.batch_size, if_np=False)
        out = model(_in)
        loss = model.loss_function(*out)
        # NOTE(review): some VAE implementations return a dict of loss terms
        # with the total under "loss"; accept either that or a scalar tensor.
        loss_value = loss["loss"] if isinstance(loss, dict) else loss

        optimizer.zero_grad()
        loss_value.backward()
        optimizer.step()

        if (epoch + 1) % 500 == 0:
            print(f"epoch: {epoch} loss: {loss_value.item()}")
    return model


def main(args):
    """Train a VAE on pretrained item attributes and save its outputs.

    Loads ``item_attr.npy`` from ``args.recsim_data_dir`` into a
    ``BaseEmbedding`` and trains ``VanillaVAE`` on it. It then writes two
    files under ``args.save_dir``:

    * ``item.npy``: the embedding returned by
      ``model(all_items, return_embedding=True)``;
    * ``vae_weight.pkl``: the trained model's ``state_dict``.

    Args:
        args: namespace read for ``seed``, ``save_dir``, ``num_all_actions``,
            ``recsim_data_dir``, ``device`` and ``recsim_dim_embed``, plus
            the fields consumed by ``train``.
    """
    # Set the random seed
    set_randomSeed(seed=args.seed)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.save_dir, exist_ok=True)

    item_attr = BaseEmbedding(num_embeddings=args.num_all_actions,
                              embed_type="pretrained",
                              embed_path=os.path.join(args.recsim_data_dir, "item_attr.npy"),
                              device=args.device)
    from large_rl.embedding.VAE.vanilla_vae import VanillaVAE as Model
    # Put the model on the same device the embedding was built for.
    # Presumably the sampled batches live on args.device; otherwise a
    # non-CPU device would cause a device mismatch.
    model = Model(in_channels=item_attr.shape[-1], latent_dim=args.recsim_dim_embed).to(args.device)
    model = train(model=model, emb=item_attr, args=args)

    # Inference only: switch off train-mode behaviour and autograd so the
    # result is a plain tensor that can be converted to NumPy.
    model.eval()
    with torch.no_grad():
        _in = item_attr.get_all(if_np=False)
        _, item_embedding = model(_in, return_embedding=True)
    print(f"Pretrained Embedding: {item_embedding.shape}, Mean: {item_embedding.mean()}, Std: {item_embedding.std()}")

    # np.save cannot serialise a grad-tracking or non-CPU tensor directly.
    if isinstance(item_embedding, torch.Tensor):
        item_embedding = item_embedding.detach().cpu().numpy()
    np.save(file=os.path.join(args.save_dir, "item"), arr=item_embedding)
    torch.save(model.state_dict(), os.path.join(args.save_dir, "vae_weight.pkl"))


if __name__ == '__main__':
    from large_rl.commons.args import get_all_args, add_args

    # NOTE(review): DATASET_PATH was referenced here without ever being
    # defined or imported, which raised NameError. It is now taken from the
    # DATASET_PATH environment variable and defaults to the current directory.
    # Confirm this matches the project's dataset root.
    DATASET_PATH = os.environ.get("DATASET_PATH", ".")

    # Get the hyper-params
    args = get_all_args()

    # =========== DEBUG =======================
    args.env_name = "recsim"
    args = add_args(args=args)
    args.device = "cpu"
    args.num_epochs = 1000
    # =========== DEBUG =======================

    data_dir = "data/movielens/ml_100k/ml-100k"
    args.batch_size = 64
    args.recsim_data_dir = os.path.join(DATASET_PATH, data_dir)
    # Derive the other paths from the already-rooted data dir. Joining
    # DATASET_PATH again would duplicate the prefix when the root is relative.
    args.user_embedding_path = os.path.join(args.recsim_data_dir, "user_attr.npy")
    args.save_dir = os.path.join(args.recsim_data_dir, "trained_weight")
    print(args)

    main(args=args)
