import argparse
from functools import partial 
from pathlib import Path

from hydra.utils import instantiate
from omegaconf import OmegaConf
import torch
from torch.utils.data import DataLoader

from agent import Agent
from data import BatchSampler, collate_segments_to_batch, EpisodeDataset
from envs import Env, WorldModelEnv
from game import Game
from game.keymap import get_keymap_and_action_names
from game import PlayEnv
from models.actor_critic import ActorCritic
from models.diffuser import WorldModel


def parse_args(argv=None):
    """Parse the command-line options of the play script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            ``argparse`` reads ``sys.argv[1:]``, exactly as before; passing a
            list makes the parser usable programmatically and in tests.

    Returns:
        An ``argparse.Namespace`` with:
            ``fps`` (int, default 5) -- forwarded to ``Game`` as its frame rate.
            ``no_header`` (bool, default False) -- when set, ``Game`` is created
            with ``verbose=False``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--fps', type=int, default=5,
                        help='frame rate passed to the game loop (default: 5)')
    parser.add_argument('--no-header', action='store_true',
                        help='run the game with verbose output disabled')
    return parser.parse_args(argv)


@torch.no_grad()
def main():
    """Launch an interactive play session over real and imagined environments.

    Loads ``config/trainer.yaml`` (resolved relative to the current working
    directory) and builds:
      * a train and a test ``Env`` (one env each) from ``cfg.env``;
      * an ``Agent`` combining a ``WorldModel`` and an ``ActorCritic``, put in
        eval mode on the selected device;
      * a ``WorldModelEnv`` that draws conditioning segments from the training
        ``EpisodeDataset`` via a ``DataLoader``.

    The three environments are handed to ``PlayEnv`` and run through ``Game``.
    The whole function runs under ``torch.no_grad()``; no gradients are built.
    """
    from collections import namedtuple

    args = parse_args()
    cfg = OmegaConf.load('config/trainer.yaml')
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    train_env = Env(partial(instantiate, config=cfg.env.train), num_envs=1, device=device)
    test_env = Env(partial(instantiate, config=cfg.env.test), num_envs=1, device=device)

    # Window size: frames are upscaled by an integer factor so the window is at
    # most ``max_window_px`` tall. NOTE(review): the 64x64 frame size is
    # hard-coded rather than read from cfg -- confirm it matches the env output.
    frame_h, frame_w = 64, 64
    max_window_px = 800
    multiplier = max_window_px // frame_h
    size = [frame_h * multiplier, frame_w * multiplier]

    static_path = cfg.collection.path_to_static_dataset
    dataset_dir = Path(static_path).expanduser() if static_path is not None else Path('dataset')
    train_dataset = EpisodeDataset(directory=dataset_dir / 'train', name='train_dataset')
    # NOTE(review): test_dataset is not used below; it is kept only in case
    # constructing an EpisodeDataset has side effects -- drop it if it does not.
    test_dataset = EpisodeDataset(directory=dataset_dir / 'test', name='test_dataset')

    wm_cfg = instantiate(cfg.world_model)
    wm_cfg.num_actions = test_env.num_actions
    wm = WorldModel(wm_cfg)
    ac = ActorCritic(test_env.num_actions)
    agent = Agent(wm, ac).to(device)
    agent.eval()

    # Pinned host memory only speeds up host-to-GPU copies; on a CPU-only
    # machine it is useless and PyTorch emits a warning, so pin only with CUDA.
    batch_sampler = BatchSampler(
        dataset=train_dataset,
        batch_size=1,
        sequence_length=cfg.world_model.num_steps_conditioning,
        can_sample_beyond_end=False,
    )
    dl = DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,
        collate_fn=collate_segments_to_batch,
        pin_memory=device.type == 'cuda',
    )
    wm_env = WorldModelEnv(wm, dl, horizon=cfg.training.actor_critic.imagination_horizon)

    # PlayEnv receives (name, env) pairs, in this order.
    NamedEnv = namedtuple('NamedEnv', 'name env')
    envs = [
        NamedEnv('test', test_env),
        NamedEnv('train', train_env),
        NamedEnv('wm', wm_env),
    ]

    keymap, action_names = get_keymap_and_action_names(cfg.env.keymap)
    play_env = PlayEnv(agent, envs, action_names)

    game = Game(play_env, keymap, action_names, size=size, fps=args.fps,
                verbose=not args.no_header, record_mode=False)
    game.run()


if __name__ == "__main__":
    # Start the interactive session only when run as a script, not on import.
    main()
