import argparse
from typing import Union, Dict, Any, List, Optional

import torch
from torch import nn

from model.transformer_xl import TransformerXL
from model.llama import TrajLlama
from dataloader.code.input_specs import RLTaskInput

class Gato(nn.Module):
    """Gato-style model that delegates all computation to a configurable backbone.

    The backbone is selected by ``config.model``:

    * ``'transformer_xl'`` -> ``TransformerXL(config)``
    * ``'llama'``          -> ``TrajLlama(config)``
    * ``'unoffical_gato'`` / ``'unofficial_gato'`` -> rejected with ``ValueError``
      (the existing implementation has a known attention bug)

    Any other value raises ``NotImplementedError``.
    """

    # The historical config key is misspelled ('unoffical'); accept the correct
    # spelling too so users get the informative rejection instead of a generic
    # NotImplementedError.
    _UNOFFICIAL_GATO_NAMES = ('unoffical_gato', 'unofficial_gato')

    def __init__(self, config: Union[argparse.Namespace, Dict[str, Any]]):
        """Build the backbone chosen by ``config.model``.

        Args:
            config: Run configuration. An ``argparse.Namespace`` is used as-is;
                a plain dict (string keys) is converted to a ``Namespace`` so the
                attribute access below works for both forms the annotation allows.

        Raises:
            ValueError: if ``config.model`` names the unofficial Gato model.
            NotImplementedError: if ``config.model`` is not a supported backbone.
        """
        super(Gato, self).__init__()
        if isinstance(config, dict):
            # Without this, ``config.model`` raises AttributeError for dict input
            # even though the type hint advertises dict support.
            config = argparse.Namespace(**config)
        self.config = config

        self.transformer = None
        model_name = config.model
        if model_name == 'transformer_xl':
            self.transformer = TransformerXL(config)
        elif model_name in self._UNOFFICIAL_GATO_NAMES:
            raise ValueError('The currently implemented unofficial gato model has a bug in the attention mechanism')
        elif model_name == 'llama':
            self.transformer = TrajLlama(config)
        else:
            raise NotImplementedError(
                f"Unknown model {model_name!r}; expected 'transformer_xl' or 'llama'"
            )

    def forward(
        self,
        inputs: "RLTaskInput",
        compute_loss: bool = True,
        mems=None,
        batch_dataset_name: Optional[List] = None,
        batch_raw_obs: Optional[Union[List, dict]] = None,
    ):
        """Run the selected backbone on a batch.

        Every argument is forwarded positionally and unchanged to
        ``self.transformer``.

        Args:
            inputs: Batch of task inputs (``RLTaskInput``).
            compute_loss: Passed through to the backbone; presumably toggles loss
                computation -- confirm against the backbone implementation.
            mems: Backbone memory state from a previous call, or ``None``
                (assumed to be recurrent memory for Transformer-XL -- TODO confirm).
            batch_dataset_name: Passed through; presumably per-sample dataset names.
            batch_raw_obs: Passed through; raw observations for the batch.

        Returns:
            ``(logits, loss, loss_datasets, new_mems)`` exactly as produced by the
            backbone. Unpacking enforces that the backbone returns four values.
        """
        logits, loss, loss_datasets, new_mems = self.transformer(
            inputs, compute_loss, mems, batch_dataset_name, batch_raw_obs
        )
        return logits, loss, loss_datasets, new_mems