import tyro
import torch
import random
import numpy as np
from trainer import Trainer
from configs import Args, MAMUJOCO_ENV_NAMES, SMACV1_ENV_NAMES, SMACV2_ENV_NAMES


def main(args: "Args | None" = None):
    """Seed the RNGs, build the replay buffer and imitator, and run training.

    Args:
        args: Run configuration exposing at least ``env_name``, ``exsize``,
            ``seed`` and ``use_llm``; it is also forwarded as ``config`` to
            the algorithm and the trainer. When omitted it is parsed from
            the command line with ``tyro.cli(Args)``, so a bare ``main()``
            call works without relying on a module-level ``args`` global
            (which only exists when this file is executed as a script).

    Raises:
        NotImplementedError: If ``args.env_name`` is in none of
            ``MAMUJOCO_ENV_NAMES``, ``SMACV1_ENV_NAMES`` or
            ``SMACV2_ENV_NAMES``.
    """
    if args is None:
        args = tyro.cli(Args)

    env_name = args.env_name
    exsize = args.exsize
    seed = args.seed
    use_llm = args.use_llm

    # Seed every RNG in use (Python, NumPy, PyTorch CPU and CUDA).
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Ask cuDNN for deterministic kernels and turn off its autotuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Restrict PyTorch intra-op parallelism to a single CPU thread.
    torch.set_num_threads(1)

    # Pick the buffer and the algorithm in one dispatch so both always match
    # the same environment family. Imports stay lazy so only the selected
    # variant is loaded.
    if env_name in MAMUJOCO_ENV_NAMES:
        from buffers.buffer_continuous import ReplayBuffer
        from algos.misodice_continuous import DemoDICE as Method
    elif env_name in SMACV1_ENV_NAMES or env_name in SMACV2_ENV_NAMES:
        from buffers.buffer_discrete import ReplayBuffer
        from algos.misodice_discrete import DemoDICE as Method
    else:
        raise NotImplementedError(f"Environment {env_name} not supported.")

    buffer = ReplayBuffer.from_h5py(env_name, exsize, use_llm=use_llm)
    st_dim = buffer.st_dim
    ob_dim = buffer.ob_dim
    ac_dim = buffer.ac_dim
    n_agents = buffer.n_agents

    imitator = Method(st_dim, ob_dim, ac_dim, n_agents, config=args)

    print(f"st_dim: {st_dim}, ob_dim: {ob_dim}, ac_dim: {ac_dim}, n_agents: {n_agents}")

    trainer = Trainer(imitator, config=args)
    trainer.train(buffer)
    
    
if __name__ == "__main__":
    # Script entry point: parse the command line into an ``Args`` instance,
    # bound at module level as ``args``, then run the training pipeline.
    args = tyro.cli(Args)
    main()