import argparse
from collections import namedtuple
from functools import partial 

from hydra.utils import instantiate
from omegaconf import OmegaConf
import torch
from torch.utils.data import DataLoader

from agent import Agent
from data import BatchSampler, collate_segments_to_batch, EpisodeDataset
from envs import Env, WorldModelEnv as DWMWorldModelEnv
from game import Game
from game.keymap import get_keymap_and_action_names
from game import PlayEnv
from iris.gpt import GPT
from iris.tokenizer import Encoder, Decoder, Tokenizer
from iris.world_model_env import WorldModelEnv as IrisWorldModelEnv
from models.actor_critic import ActorCritic
from models.diffuser import WorldModel


def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the behavior callers that pass no argument already rely on.

    Returns:
        argparse.Namespace with the attributes ``fps`` (int, default 10),
        ``horizon`` (int, default 1000) and ``no_header`` (bool, default
        False).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--fps', type=int, default=10,
        help='fps value handed to the Game (default: %(default)s)',
    )
    parser.add_argument(
        '--horizon', type=int, default=1000,
        help='horizon handed to the world-model environments (default: %(default)s)',
    )
    parser.add_argument(
        '--no-header', action='store_true',
        help='run the Game with verbose=False',
    )
    return parser.parse_args(argv)


@torch.no_grad()
def main():
    """Assemble the real and world-model environments and run the game.

    Steps, in order:
      1. Parse CLI options and pick ``cuda:0`` when available, else CPU.
      2. Load ``config/config.yaml`` and build the real ``Env``.
      3. Build the actor-critic, the DWM world model, and the IRIS
         tokenizer + GPT, then load their weights from ``checkpoints/``.
      4. Wrap the models in an ``Agent`` (moved to the device, eval mode).
      5. Build a data loader over ``dataset`` that is handed to both
         world-model environments.
      6. Expose the DWM, IRIS and real environments through ``PlayEnv`` and
         start ``Game``.

    The whole function runs under ``torch.no_grad()``.
    """
    args = parse_args()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    cfg = OmegaConf.load('config/config.yaml')
    env = Env(partial(instantiate, config=cfg.env.train), num_envs=1, device=device)

    # Window size handed to Game: the 64x64 base size (presumably the frame
    # resolution -- confirm) scaled by the largest integer factor that keeps
    # the height at or below 800.
    h, w = 64, 64
    multiplier = 800 // h
    size = [h * multiplier, w * multiplier]

    ac = ActorCritic(env.num_actions)

    # DWM
    wm_cfg = instantiate(cfg.world_model)
    wm_cfg.num_actions = env.num_actions
    wm = WorldModel(wm_cfg)

    # IRIS
    tok_cfg = cfg.tokenizer
    edc = instantiate(tok_cfg.enc_dec_config)
    tok = Tokenizer(
        vocab_size=tok_cfg.vocab_size,
        embed_dim=tok_cfg.embed_dim,
        encoder=Encoder(edc),
        decoder=Decoder(edc),
        with_lpips=False,
    ).to(device)
    gpt = GPT(
        obs_vocab_size=tok_cfg.vocab_size,
        act_vocab_size=env.num_actions,
        config=instantiate(cfg.gpt),
    ).to(device)

    ac.load_state_dict(torch.load('checkpoints/expert.pt', map_location=device))
    wm.load_state_dict(torch.load('checkpoints/dwm.pt', map_location=device))
    tok.load_state_dict(torch.load('checkpoints/iris_tok.pt', map_location=device))
    gpt.load_state_dict(torch.load('checkpoints/iris_gpt.pt', map_location=device))

    agent = Agent(wm, ac, tok, gpt).to(device)
    agent.eval()

    dataset = EpisodeDataset(directory='dataset')
    bs = BatchSampler(
        dataset=dataset,
        batch_size=1,
        sequence_length=cfg.world_model.num_steps_conditioning,
        can_sample_beyond_end=False,
    )
    dl = DataLoader(
        dataset,
        batch_sampler=bs,
        collate_fn=collate_segments_to_batch,
        # Pinned host memory only helps host-to-GPU copies; on the CPU
        # fallback it is useless and recent PyTorch versions warn about it.
        pin_memory=device.type == 'cuda',
    )

    dwm_wm_env = DWMWorldModelEnv(wm, dl, horizon=args.horizon)
    iris_wm_env = IrisWorldModelEnv(tok, gpt, dl, horizon=args.horizon)

    NamedEnv = namedtuple('NamedEnv', 'name env')

    envs = [
        NamedEnv('DWM', dwm_wm_env),
        NamedEnv('IRIS', iris_wm_env),
        NamedEnv('Real', env),
    ]

    keymap, action_names = get_keymap_and_action_names(cfg.env.keymap)

    play_env = PlayEnv(agent, envs, action_names)

    game = Game(
        play_env,
        keymap,
        action_names,
        size=size,
        fps=args.fps,
        verbose=not args.no_header,
        record_mode=False,
    )
    game.run()


# Script entry point: only run main() when this file is executed directly,
# not when it is imported.
if __name__ == '__main__':
    main()
