import torch
import pytorch_lightning as pl
from lightning_transformers.task.nlp.language_modeling import (
    LanguageModelingDataModule,
    LanguageModelingTransformer,
)
import transformers
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_warn
import torch.nn as nn
from torch.distributions.gamma import Gamma
from data.utils import discretize_time
from torch.nn import CrossEntropyLoss
from data.dms import DMSData
from data.constants import REF_RBD_SEQ
from models import register_model
import math, logging
from models.criterions.cross_entropy import CEWeightedLoss
from esm import pretrained
from typing import IO, Any, Callable, Dict, Optional, Tuple, Type, Union
from utils.args import str2bool
from transformers import GPT2LMHeadModel
from transformers import AutoConfig, PreTrainedTokenizerBase
from collections import defaultdict
# import transformers.LogitsProcessor as LogitsProcessor

class PrependLogitsProcessor(transformers.LogitsProcessor):
    """Generation-time logits processor that forbids ids >= ``real_vocab_size``.

    Every score column from ``real_vocab_size`` onwards is overwritten with
    ``-inf`` so that ``generate`` can never select those ids.
    NOTE(review): presumably these are the extra property/prefix tokens placed
    after the real vocabulary — confirm against the model config.
    """

    def __init__(self, real_vocab_size) -> None:
        super().__init__()
        # Count of leading vocabulary ids that stay selectable.
        self.real_vocab_size = real_vocab_size

    def __call__(self, input_ids, scores) -> Any:
        """Mask the tail of ``scores`` in place and return the same tensor.

        ``input_ids`` is part of the LogitsProcessor protocol and is unused.
        The mask is applied along dim 1 (assumed to be the vocabulary axis of
        ``(batch, vocab)`` scores).
        """
        scores[:, self.real_vocab_size:] = -torch.inf
        return scores


class GPT2TimeModel(transformers.GPT2LMHeadModel):
    """GPT-2 language model whose logits vary linearly with continuous time.

    ``forward`` returns logits ``rate * t + offset`` where

    * ``rate`` is the logits of this GPT-2 LM head,
    * ``t`` is ``discretize_time(input_time, ...)`` (one value per example,
      repeated when ``input_ids`` has more rows, e.g. beam search), and
    * ``offset`` comes from ``offset_layer``: a second ``GPT2LMHeadModel`` when
      ``config.transformer_offset`` is truthy, otherwise a linear projection of
      the last hidden state.
    """

    def __init__(self, config) -> None:
        super().__init__(config)
        self.build_offset_layer(config)
        self.config = config

    def build_offset_layer(self, config):
        """Create ``self.offset_layer`` according to ``transformer_offset`` (default False)."""
        if getattr(self.config, "transformer_offset", False):
            # Independent GPT-2 LM whose logits are used directly as the offset.
            self.offset_layer = transformers.GPT2LMHeadModel(config)
        else:
            # Linear head on top of the shared last hidden state.
            self.offset_layer = nn.Linear(config.hidden_size, config.vocab_size)

    @classmethod
    def from_config(cls, config):
        """Alternate constructor: build a freshly initialised model from ``config``."""
        return cls(config)

    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
        """Select the keyword arguments forwarded to ``forward`` at each generate step.

        ``input_time`` is required; ``logits_processor`` and ``token_type_ids``
        default to ``None``. No key/value cache is returned, so presumably every
        step re-encodes the whole prefix.
        """
        return {
            "input_ids": input_ids,
            "input_time": kwargs["input_time"],
            "logits_processor": kwargs.get("logits_processor", None),
            "token_type_ids": kwargs.get("token_type_ids", None),
        }

    def get_offset(self, outputs=None, **argv):
        """Return the time-independent offset logits.

        With ``transformer_offset`` the separate GPT-2 in ``offset_layer`` is run
        on ``argv``'s ``input_ids``/``labels``/``attention_mask``. Otherwise the
        linear ``offset_layer`` is applied to ``outputs.hidden_states[-1]``, so
        ``outputs`` must come from a call with ``output_hidden_states=True``.
        """
        # getattr keeps this consistent with build_offset_layer, which treats a
        # missing ``transformer_offset`` as False instead of raising.
        if getattr(self.config, "transformer_offset", False):
            offset_outputs = self.offset_layer(
                input_ids=argv.get("input_ids"),
                labels=argv.get("labels"),
                attention_mask=argv.get("attention_mask"),
                output_hidden_states=True,
            )
            return offset_outputs.logits
        hidden_states = outputs.hidden_states[-1]
        return self.offset_layer(hidden_states)

    def forward(self, input_time, return_hidden_states=False, **argv):
        """Run GPT-2 and turn its logits into time-dependent logits.

        Args:
            input_time: per-example times, normalised by ``discretize_time`` with
                ``config.normalize_time_a`` / ``config.normalize_time_b``.
            return_hidden_states: kept for API compatibility; hidden states are
                always requested (``output_hidden_states=True``) and returned.
            **argv: ``input_ids`` (required), optional ``labels``,
                ``attention_mask`` and ``token_type_ids``; other keys are ignored.

        Returns:
            The GPT-2 output object with ``logits`` replaced by
            ``rate * t + offset``. NOTE: when ``labels`` is given,
            ``outputs.loss`` is the loss GPT-2 computed on the raw ``rate``
            logits, not on the returned time-adjusted logits.
        """
        time = discretize_time(
            input_time,
            one_step=False,
            normalize_time_a=self.config.normalize_time_a,
            normalize_time_b=self.config.normalize_time_b,
            discrete=False)
        # input_ids may have a multiple of input_time's rows (presumably beams
        # during generate); repeat each time value to match.
        beam_size = argv.get("input_ids").size(0) // input_time.size(0)
        time = time.unsqueeze(1).repeat(1, beam_size).view(-1)

        outputs = super().forward(
            input_ids=argv.get("input_ids"),
            labels=argv.get("labels"),
            attention_mask=argv.get("attention_mask"),
            output_hidden_states=True,
            token_type_ids=argv.get("token_type_ids", None))

        rate = outputs.logits
        # One time value per row, broadcast over the sequence and vocabulary axes.
        logits = rate * time.unsqueeze(-1).unsqueeze(-1)
        logits = logits + self.get_offset(outputs, **argv)
        outputs.logits = logits
        return outputs

class GPT2TimeConcatAddGlobalModel(GPT2TimeModel):
    """``GPT2TimeModel`` conditioned on one property through ``token_type_ids``.

    The first name in ``config.data_properties`` is read from the forward
    kwargs, shifted by ``real_vocab_size`` and passed as ``token_type_ids``.
    The last vocabulary id, ``config.vocab_size - 1``, is the "global"
    (property-agnostic) token:

    * training: each example's property token is replaced by the global token
      with probability ``config.train_global_prob``;
    * evaluation: logits of a local pass and a global pass are mixed with
      weight ``config.test_global_prob`` on the global logits.

    Returned logits are truncated to the first ``real_vocab_size`` ids.
    """

    def __init__(self, config) -> None:
        super().__init__(config)
        # Ids >= real_vocab_size are property tokens and must never be predicted.
        self.real_vocab_size = config.real_vocab_size

    @staticmethod
    def _expand_over_sequence(tok, input_ids):
        """Broadcast per-example ids ``tok`` ([B]) to ``input_ids``' [B, L] shape.

        HF GPT-2 reshapes ``token_type_ids`` with ``view(-1, seq_len)`` (per the
        GPT2Model implementation), so a bare [B] tensor does not line up with
        the sequence. Tensors that are not 1-D are returned unchanged.
        """
        if tok.dim() == 1 and input_ids is not None and input_ids.dim() == 2:
            return tok.unsqueeze(-1).expand(-1, input_ids.size(-1)).contiguous()
        return tok

    def forward(self, input_time, return_hidden_states=False, **argv):
        """Run the time model with local/global property token types.

        Args:
            input_time: per-example times forwarded to ``GPT2TimeModel.forward``.
            return_hidden_states: kept for API compatibility; hidden states are
                always returned by the parent.
            **argv: ``input_ids`` plus the property named
                ``config.data_properties[0]`` (per-example label indices,
                presumably shape [B]); the rest is forwarded to the parent.

        Returns:
            The parent's output object with logits limited to the real vocabulary.
        """
        input_ids = argv.get("input_ids")

        prop_tok = argv[self.config.data_properties[0]].long() + self.real_vocab_size  # [B]
        global_prop_tok = torch.full_like(prop_tok, self.config.vocab_size - 1)

        if self.training:
            use_global_mask = torch.rand(prop_tok.size()).to(prop_tok.device) < self.config.train_global_prob
            # Feed the randomly globalised tokens (previously computed but unused).
            mixed_tok = torch.where(use_global_mask, global_prop_tok, prop_tok)
            argv["token_type_ids"] = self._expand_over_sequence(mixed_tok, input_ids)
            outputs = super().forward(input_time, **argv)
        else:
            # Evaluation / generation: mix local and global predictions.
            argv["token_type_ids"] = self._expand_over_sequence(prop_tok, input_ids)
            outputs = super().forward(input_time, **argv)

            argv["token_type_ids"] = self._expand_over_sequence(global_prop_tok, input_ids)
            global_outputs = super().forward(input_time, **argv)

            global_weight = self.config.test_global_prob
            outputs.logits = global_weight * global_outputs.logits + (1 - global_weight) * outputs.logits

        outputs.logits = outputs.logits[:, :, :self.real_vocab_size]
        return outputs

class GPT2TimePrependAddGlobalModel(GPT2TimeModel):
    """``GPT2TimeModel`` conditioned on properties by prepending property tokens.

    For every name in ``config.data_properties`` one token is prepended to
    ``input_ids``. Properties occupy consecutive id blocks starting at
    ``real_vocab_size``; the block of property ``p`` has
    ``len(config.<p>_dict) + 1`` ids and its last id is the "global"
    (property-agnostic) token.

    * training: each prepended token is independently swapped for its global
      token with probability ``config.train_global_prob``;
    * evaluation: logits of a local-prefix pass and a global-prefix pass are
      mixed with weight ``config.test_global_prob`` on the global logits.

    Returned logits are truncated to the first ``real_vocab_size`` ids and the
    prefix positions are dropped, so their length matches the caller's
    ``input_ids``. Requires ``config.pad_idx`` for the attention mask.
    ``prepare_inputs_for_generation`` is inherited from ``GPT2TimeModel``.
    """

    def __init__(self, config) -> None:
        super().__init__(config)
        # Ids >= real_vocab_size are property tokens and must never be predicted.
        self.real_vocab_size = config.real_vocab_size

    def prepend_property_tokens(self, **batch):
        """Build local and global prefix tokens for every configured property.

        ``batch[prop]`` is presumably a [B] tensor of per-example label indices.

        Returns:
            ``(prepend_ids, prepend_global_ids)``, both stacked on dim 1, i.e.
            ``[B, len(config.data_properties)]`` for [B] inputs.
        """
        offset = self.real_vocab_size
        prepend_ids = []  # local tokens
        prepend_global_ids = []  # global tokens
        for prop in self.config.data_properties:
            prop_tok = batch[prop].long() + offset
            # Each property owns len(<prop>_dict) ids plus one extra global id.
            offset += len(getattr(self.config, "%s_dict" % prop)) + 1
            prepend_ids.append(prop_tok)
            # The global id is the last id of this property's block.
            prepend_global_ids.append(torch.full_like(prop_tok, offset - 1))
        return torch.stack(prepend_ids, dim=1), torch.stack(prepend_global_ids, dim=1)

    def _forward_with_prefix(self, input_time, prefix_ids, input_ids, argv):
        """Prepend ``prefix_ids`` to ``input_ids`` and run ``GPT2TimeModel.forward``.

        ``argv`` is updated in place: ``input_ids`` and ``labels`` become the
        prefixed sequence and ``attention_mask`` marks its non-pad positions.
        """
        full_ids = torch.cat([prefix_ids, input_ids], dim=-1)
        argv["input_ids"] = full_ids
        argv["labels"] = full_ids
        argv["attention_mask"] = full_ids != self.config.pad_idx
        return super().forward(input_time, **argv)

    def forward(self, input_time, return_hidden_states=False, **argv):
        """Run the time model on property-prefixed sequences.

        Args:
            input_time: per-example times forwarded to ``GPT2TimeModel.forward``.
            return_hidden_states: kept for API compatibility; hidden states are
                always returned by the parent.
            **argv: ``input_ids`` plus one entry per name in
                ``config.data_properties``; ``labels``/``attention_mask`` are
                rebuilt from the prefixed ids.

        Returns:
            The parent's output object with logits restricted to the real
            vocabulary and to the original (unprefixed) positions.
        """
        prepend_ids, global_prepend_ids = self.prepend_property_tokens(**argv)  # [B, k]
        num_prefix = prepend_ids.size(1)
        input_ids = argv["input_ids"]

        if self.training:
            use_global_mask = torch.rand(prepend_ids.size()).to(prepend_ids.device) < self.config.train_global_prob
            prefix = torch.where(use_global_mask, global_prepend_ids, prepend_ids)
            outputs = self._forward_with_prefix(input_time, prefix, input_ids, argv)
            # Never predict property tokens.
            outputs.logits = outputs.logits[:, :, :self.real_vocab_size]
        else:
            # Evaluation / generation: mix local-prefix and global-prefix logits.
            outputs = self._forward_with_prefix(input_time, prepend_ids, input_ids, argv)
            local_logits = outputs.logits[:, :, :self.real_vocab_size]

            global_logits = self._forward_with_prefix(
                input_time, global_prepend_ids, input_ids, argv).logits[:, :, :self.real_vocab_size]

            global_weight = self.config.test_global_prob
            outputs.logits = global_weight * global_logits + (1 - global_weight) * local_logits

        # Drop the prefix positions so logits align with the caller's input_ids.
        outputs.logits = outputs.logits[:, num_prefix:, :]
        return outputs


class GPT2TimeModelMultiHosts(transformers.GPT2LMHeadModel):
    """Multi-host continuous-time LM built from symmetric host-rate matrices.

    For each (batch, position, vocabulary-id) cell, with ``K = num_component``:

    * ``trans_rates`` holds one GPT-2 LM per upper-triangular entry of a
      symmetric ``K x K`` matrix ``A`` (``K * (K + 1) // 2`` models); their
      jittered, ReLU'd and clamped logits are the entries of ``A``.
    * ``offsets`` holds one GPT-2 LM per host; ``x0[k]`` is the softmax over the
      vocabulary of host ``k``'s logits, taken at this vocabulary id.

    ``forward`` evaluates ``x(t) = U diag(exp(lambda * t)) U^T x0`` using the
    eigendecomposition ``A = U diag(lambda) U^T`` (``torch.linalg.eigh``) and
    returns ``log(x(t)[host] + eps)`` for the host index in ``argv["host"]``.
    """

    def __init__(self, config, num_component, symmetry=True, eps=1e-6, inf=1e6) -> None:
        super().__init__(config)
        self.num_component = num_component
        assert symmetry, "It is more complicated for non-symmetry matrix. Leave it as a future work"
        # eps: jitter scale / lower clamp for rates and log epsilon; inf: upper
        # clamp for rates. Heuristic defaults — tune for your data.
        self.eps = eps
        self.inf = inf
        num_pairs = num_component * (num_component + 1) // 2
        self.trans_rates = nn.ModuleList([transformers.GPT2LMHeadModel(config) for _ in range(num_pairs)])
        self.offsets = nn.ModuleList([transformers.GPT2LMHeadModel(config) for _ in range(num_component)])  # K
        self.config = config

    @staticmethod
    def _pair_index(i, j, n):
        """Row-major index of upper-triangular entry (i, j), i <= j, of an n x n matrix."""
        # Rows 0..i-1 contribute n + (n-1) + ... + (n-i+1) = i*(2n-i+1)/2 entries.
        return i * (2 * n - i + 1) // 2 + (j - i)

    def get_initial_prob(self, input_ids, labels, attention_mask):
        """Return every host's softmaxed offset logits, flattened to [B*L*V, K]."""
        prob_vectors = []
        for offset_model in self.offsets:
            outputs = offset_model(input_ids=input_ids, labels=labels,
                                   attention_mask=attention_mask, output_hidden_states=True)
            prob_vectors.append(torch.softmax(outputs.logits, dim=-1))  # [B, L, V]
        prob_vectors = torch.stack(prob_vectors, dim=-1)  # [B, L, V, K]
        return prob_vectors.view(-1, self.num_component)

    def get_trans_matrix(self, input_ids, labels, attention_mask):
        """Build the per-cell symmetric rate matrices and eigendecompose them.

        Returns:
            ``(eig_value, eig_vector)`` from ``torch.linalg.eigh`` with shapes
            [B*L*V, K] and [B*L*V, K, K].
        """
        n = self.num_component
        rates_matrix = [[None] * n for _ in range(n)]
        for i in range(n):
            for j in range(i, n):
                k = self._pair_index(i, j, n)
                outputs = self.trans_rates[k](input_ids=input_ids, labels=labels,
                                              attention_mask=attention_mask, output_hidden_states=True)
                rate = outputs.logits  # [B, L, V]
                # Small random jitter to avoid an ill-conditioned matrix.
                rate = rate + torch.rand(rate.size()).to(rate.device) * self.eps
                rate = torch.relu(rate)  # TODO: any better choice?
                rates_matrix[i][j] = rate
                rates_matrix[j][i] = rate
        flat = [entry for row in rates_matrix for entry in row]
        rates_matrix = torch.stack(flat, dim=-1)  # [B, L, V, K*K], row-major
        rates_matrix = rates_matrix.view(rates_matrix.size(0), rates_matrix.size(1), rates_matrix.size(2), n, n)

        # Keep rates within [eps, inf] to avoid overflow in exp() later.
        rates_matrix = torch.clamp(rates_matrix, min=self.eps, max=self.inf)
        eig_value, eig_vector = torch.linalg.eigh(rates_matrix.view(-1, n, n))
        return eig_value, eig_vector

    @classmethod
    def from_config(cls, config, num_component=None, **kwargs):
        """Alternate constructor; ``num_component`` defaults to ``config.num_component``."""
        if num_component is None:
            num_component = getattr(config, "num_component", None)
        if num_component is None:
            raise ValueError("GPT2TimeModelMultiHosts needs `num_component` (argument or config attribute).")
        return cls(config, num_component, **kwargs)

    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
        """Forward only ``input_ids`` and the required ``input_time`` at each generate step."""
        return {"input_ids": input_ids, "input_time": kwargs["input_time"]}

    def forward(self, input_time, **argv):
        """Return per-token log-probabilities [B, L, V] for the host in ``argv["host"]``.

        ``argv`` must contain ``input_ids`` and ``host`` (one host index per
        row); ``labels`` and ``attention_mask`` are optional.
        """
        time = discretize_time(
            input_time,
            one_step=False,
            normalize_time_a=self.config.normalize_time_a,
            normalize_time_b=self.config.normalize_time_b,
            discrete=False)
        input_ids = argv.get("input_ids")
        labels = argv.get("labels")
        attention_mask = argv.get("attention_mask")
        beam_size = input_ids.size(0) // input_time.size(0)
        time = time.unsqueeze(1).repeat(1, beam_size).view(-1)

        B, L = input_ids.size()

        eig_value, eig_vecs = self.get_trans_matrix(input_ids, labels, attention_mask)  # [N, K], [N, K, K]
        init_prob = self.get_initial_prob(input_ids, labels, attention_mask)  # [N, K]
        # Vocabulary size of the LM heads, recovered from N = B * L * V.
        V = eig_value.size(0) // (B * L)

        # For symmetric A, U^{-1} = U^T, so the coefficients are U^T x0.
        const = torch.bmm(eig_vecs.transpose(-2, -1), init_prob.unsqueeze(-1)).squeeze(-1)  # [N, K]

        time = time.reshape(-1, 1, 1).expand(-1, L, V).reshape(-1, 1)  # [N, 1]
        p = (const * torch.exp(time * eig_value)).unsqueeze(1) * eig_vecs  # [N, K, K]
        p = p.view(-1, V, self.num_component, self.num_component)  # [B*L, V, K, K]
        p = torch.sum(p, dim=-1)  # [B*L, V, K]
        p = p.view(B, -1, p.size(-1))  # [B, L*V, K]
        host_label = argv["host"].long().unsqueeze(1).repeat(1, p.size(1)).unsqueeze(-1)  # [B, L*V, 1]
        p = torch.gather(p, -1, host_label).squeeze(-1).view(B, L, -1)  # [B, L, V]
        return torch.log(p + self.eps)

@register_model("gpt2_time_new")
# class GPT2TimeNew(LightningModule): 
class GPT2TimeNew(LanguageModelingTransformer):
    def __init__(self, config, alphabet, **kwargs) -> None:
        """Create the task module.

        Args:
            config: experiment namespace with model/training options.
            alphabet: token alphabet; must support ``pad()`` and ``len()``.
            **kwargs: only ``vocab_size`` is read; when given and not None it
                overrides ``len(alphabet)``.
        """
        # Set before the base constructor: initialize_model() reads self.config,
        # and is presumably invoked from the base __init__.
        self.config = config
        self.alphabet = alphabet
        self.pad_idx = alphabet.pad()

        vocab_size = kwargs.get("vocab_size")
        if vocab_size is None:
            vocab_size = len(alphabet)

        super().__init__(
            pretrained_model_name_or_path=config.model_name_or_path,
            load_weights=config.load_weights,
            vocab_size=vocab_size,
            # GPT-2's preset is 1024 positions; long sequences need more.
            max_position_embeddings=config.max_position_embeddings,
            num_hidden_layers=config.num_hidden_layers,
            hidden_size=config.hidden_size,
        )

        pretrain_checkpoint = getattr(config, "load_from_pretrain_checkpoint", None)
        if pretrain_checkpoint:
            self.load_pretrained_model(pretrain_checkpoint)
    

    def initialize_model(self, pretrained_model_name_or_path: str):
        """Build ``self.model`` as a freshly initialised ``GPT2TimeModel``.

        The HF config always comes from the ``"gpt2"`` preset plus
        ``self.model_data_kwargs``; the ``pretrained_model_name_or_path``
        argument is not used. Time-normalisation and offset options are copied
        from ``self.config``. No pretrained weights are loaded here.
        """
        hf_config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path="gpt2", **self.model_data_kwargs
        )
        extra_fields = {
            "normalize_time_a": self.config.normalize_time_a,
            "normalize_time_b": self.config.normalize_time_b,
            "transformer_offset": self.config.transformer_offset,
            "transformer_model_name": getattr(self.config, "transformer_model_name", "GPT2Model"),
        }
        for field_name, field_value in extra_fields.items():
            setattr(hf_config, field_name, field_value)
        self.model = GPT2TimeModel.from_config(hf_config)

    # def build_offset_layer(self, config):
    #     if getattr(self.config, "transformer_offset", False):
    #         self.offset_layer = LanguageModelingTransformer(
    #             pretrained_model_name_or_path=config.model_name_or_path,
    #             load_weights=config.load_weights,
    #             vocab_size=len(self.alphabet),
    #             max_position_embeddings=config.max_position_embeddings,
    #             num_hidden_layers=config.num_hidden_layers,
    #             hidden_size=config.hidden_size) 
    #     else:
    #         self.offset_layer = nn.Linear(config.hidden_size, len(self.alphabet)) # Just a linear layer

    def build_second_order_rate(self, config):
        """Create (but do not attach) the module that predicts second-order rates.

        Returns a separate ``LanguageModelingTransformer`` when
        ``self.config.transformer_second_order_rate`` is truthy, otherwise a
        linear map from ``config.hidden_size`` to ``len(self.alphabet)`` logits.
        """
        if not getattr(self.config, "transformer_second_order_rate", False):
            return nn.Linear(config.hidden_size, len(self.alphabet))  # Just a linear layer
        return LanguageModelingTransformer(
            pretrained_model_name_or_path=config.model_name_or_path,
            load_weights=config.load_weights,
            vocab_size=len(self.alphabet),
            max_position_embeddings=config.max_position_embeddings,
            num_hidden_layers=config.num_hidden_layers,
            hidden_size=config.hidden_size,
        )


    def load_pretrained_model(self, path):
        """Copy matching tensors from a Lightning checkpoint into this model.

        Entries of the checkpoint's ``"state_dict"`` are copied in place when the
        key exists in this model's state dict with the same shape. Shape
        mismatches are logged and skipped; unknown keys are ignored.

        NOTE: ``torch.load`` unpickles the file — only load trusted checkpoints.
        """
        pretrained_model_state_dict = torch.load(path, map_location="cpu")["state_dict"]
        # Build the state dict once (it is rebuilt on every call). Its tensors
        # share storage with the live parameters/buffers, so copy_ writes
        # through to the model.
        own_state = self.state_dict()
        for state, pretrained_tensor in pretrained_model_state_dict.items():
            if state not in own_state:
                continue
            target = own_state[state]
            if target.size() != pretrained_tensor.size():
                logging.warning("The parameter %s of pretrained model (%s) doesn't fit the current model %s." % (state, str(pretrained_tensor.size()), str(target.size())))
            else:
                target.copy_(pretrained_tensor)

    def configure_optimizers(self) -> Dict:
        """AdamW at ``config.learning_rate`` with a per-step linear warmup/decay schedule.

        ``compute_warmup`` (provided by the base task class) turns
        ``num_training_steps=-1`` and the warmup value ``0.1`` into step
        counts — presumably -1 means "infer from the trainer" and 0.1 is a
        fraction of total steps; confirm in the base class.
        """
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.learning_rate)
        total_steps, warmup_steps = self.compute_warmup(
            num_training_steps=-1,
            num_warmup_steps=0.1,
        )
        schedule = transformers.get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
        )
        scheduler_config = {"scheduler": schedule, "interval": "step", "frequency": 1}
        return {"optimizer": optimizer, "lr_scheduler": scheduler_config}

    @classmethod
    def load_from_checkpoint(
        cls,
        checkpoint_path: Union[str, IO],
        map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
        hparams_file: Optional[str] = None,
        strict: bool = True,
        hf_pipeline_kwargs: Optional[Dict] = None,
        args = None,
        **kwargs
    ):
        """Load a checkpoint and patch run-specific settings onto the model.

        Sets ``config.resume_from_checkpoint`` to ``checkpoint_path`` and
        ``config.pred_data_paths`` from ``args`` (``""`` if absent or ``args``
        is None). When ``args`` is given, ``config.test_data_paths`` is copied
        too. Each extra keyword argument overwrites the model attribute of the
        same name (logged at INFO level).

        NOTE(review): ``hf_pipeline_kwargs`` is accepted but not forwarded.
        """
        model = super().load_from_checkpoint(checkpoint_path, map_location, hparams_file, strict)

        model_config = model.config
        model_config.resume_from_checkpoint = checkpoint_path
        model_config.pred_data_paths = getattr(args, "pred_data_paths", "")
        if args is not None:
            model_config.test_data_paths = args.test_data_paths

        for attr_name, new_value in kwargs.items():
            old_value = getattr(model, attr_name, None)
            logging.info("Overwrite model hyperparameter %s:" % attr_name + ", from " + str(old_value) + " to " + str(new_value))
            setattr(model, attr_name, new_value)
        return model

    @classmethod
    def add_argparse_args(cls, parent_parser):
        """Register this model's command-line options on ``parent_parser`` and return it."""
        # Model size / initialisation.
        parent_parser.add_argument('--load_weights', action='store_true')
        parent_parser.add_argument('--num_hidden_layers', type=int, default=12)
        parent_parser.add_argument('--tau', type=float, default=1.0, help="Divide t by tau.")
        parent_parser.add_argument('--hidden_size', type=int, default=768)
        parent_parser.add_argument('--model_name_or_path', type=str, default="gpt2")
        parent_parser.add_argument('--load_from_pretrain_checkpoint', type=str, default=None)
        # Time normalisation and conditioning features.
        parent_parser.add_argument('--normalize_time_a', type=int, default=1,  help="t = (t-b)/a")
        parent_parser.add_argument('--normalize_time_b', type=int, default=0, help="t = (t-b)/a")
        parent_parser.add_argument('--add_location', action='store_true', help="Add the location information.")
        parent_parser.add_argument('--add_lineage', action='store_true', help="Add the lineage information.")
        parent_parser.add_argument('--count_mse_loss', action='store_true', help="Use the count mse loss instead of ce loss.")
        # Loss weighting and the offset layer.
        parent_parser.add_argument('--weight_loss_by_count', type=str2bool, default="false", help="Weight loss of each sample by their counting not frequency")
        parent_parser.add_argument('--no_normalization_in_batch', action='store_true', help="Don't normalize the loss weight within the batch!!")
        parent_parser.add_argument('--zero_offset', action='store_true', help="Set the sequences distribution at offset as 0")
        parent_parser.add_argument('--offset_share_layer', type=int, default=-1, help="Use the hidden state at layer i to output the offset.")
        parent_parser.add_argument('--transformer_offset', action='store_true', help="Use another transformer NN to predict the offset.")
        parent_parser.add_argument('--second_order_rate', action='store_true', help="Add the second order rate in modeling.")
        parent_parser.add_argument('--transformer_second_order_rate', action='store_true', help="Use another transformer NN to predict the second order rate.")
        parent_parser.add_argument('--output_token_losses', type=str2bool, default="false")

        # Generation.
        parent_parser.add_argument('--do_sample', type=str2bool, default="false")
        parent_parser.add_argument('--temperature', type=float, default=1.0)
        parent_parser.add_argument('--num_beams', type=int, default=1)
        parent_parser.add_argument('--num_return_sequences', type=int, default=1)

        # Time overrides.
        parent_parser.add_argument('--zero_time', action='store_true', help="Set the time as zero.")
        parent_parser.add_argument('--set_time', type=float, default=None)

        parent_parser.add_argument('--ensemble', type=str2bool, default="false")
        parent_parser.add_argument('--average_over_time', type=str2bool, default="false")

        # Fine-tuning and time-based loss weighting.
        parent_parser.add_argument('--freeze_params_before_layer', type=int, default=0)
        parent_parser.add_argument('--weight_loss_by_time', type=str2bool, default="false")
        parent_parser.add_argument('--weight_loss_by_time_logistic_x0', type=float, default=None, help="f(x0)=0.5")
        parent_parser.add_argument('--weight_loss_by_time_logistic_k', type=float, default=0.5, help="small k, smoother.")

        parent_parser.add_argument('--debias_sample_weight', type=str2bool, default="false")

        return parent_parser
        
    def nll_loss(self, lm_logits, labels, loss_weight=None, reduce=True, ignore_bos=False):
        """Next-token cross entropy.

        Args:
            lm_logits: logits where position ``i`` predicts token ``i + 1``.
            labels: token ids; padding (and BOS when ``ignore_bos``) is ignored.
            loss_weight: optional per-sequence weights; normalised to sum to 1
                unless ``config.no_normalization_in_batch`` is set.
            reduce: if True return a scalar, otherwise the per-token loss shaped
                like ``labels[..., 1:]`` (ignored positions are 0).
            ignore_bos: also ignore BOS tokens in ``labels``.
        """
        # -100 is CrossEntropyLoss's default ignore_index.
        labels = labels.masked_fill(torch.eq(labels, self.alphabet.pad()), -100)
        if ignore_bos:
            labels = labels.masked_fill(torch.eq(labels, self.alphabet.bos()), -100)
        # Shift so that tokens < n predict n
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # `reduction="none"` replaces the deprecated `reduce=False`.
        loss_fct = CrossEntropyLoss(reduction="none")
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        loss = loss.view(shift_labels.size())

        if not reduce:
            return loss

        # Per-sequence mean over valid targets; clamp avoids 0/0 = NaN for a
        # sequence with no valid target, which would poison the batch loss.
        num_targets = (shift_labels != -100).sum(dim=-1).clamp(min=1)
        loss = loss.sum(dim=-1) / num_targets  # [B]

        if loss_weight is None:
            return loss.mean()
        if not self.config.no_normalization_in_batch:
            loss_weight = loss_weight / loss_weight.sum()
        return torch.sum(loss * loss_weight)
    
    def count_mse_loss(self, lm_logits, labels, total_count, target_count, reduce=True):
        """Squared error between predicted and observed sequence counts.

        The predicted count is ``P(sequence) * total_count``, where
        ``log P(sequence)`` sums next-token log-probabilities over non-pad
        target positions. The error is divided by ``target_count.max() ** 2``.
        Returns the per-sequence losses, or their mean when ``reduce``.
        """
        next_logits = lm_logits[..., :-1, :].contiguous()
        next_labels = labels[..., 1:].contiguous()
        log_probs = torch.log_softmax(next_logits, dim=-1)
        token_log_probs = log_probs.gather(-1, next_labels.unsqueeze(-1)).squeeze(-1)
        not_pad = next_labels != self.alphabet.pad()
        seq_log_prob = (token_log_probs * not_pad).sum(-1)  # one value per sequence
        predicted_count = seq_log_prob.exp() * total_count
        loss = (predicted_count - target_count) ** 2 / (target_count.max() ** 2)
        return loss.mean() if reduce else loss

    def get_second_order_rate(self, batch, outputs=None):
        """Return second-order rate logits.

        With ``config.transformer_second_order_rate`` the ``.model`` of
        ``self.second_order_rate_layer`` is run on the batch; otherwise that
        layer is applied to ``outputs.hidden_states`` at index
        ``config.offset_share_layer`` (default -1).
        NOTE(review): ``__init__`` does not create ``second_order_rate_layer``;
        it must be assigned before this is called.
        """
        if not getattr(self.config, "transformer_second_order_rate", False):
            assert outputs is not None
            layer_index = getattr(self.config, "offset_share_layer", -1)
            return self.second_order_rate_layer(outputs.hidden_states[layer_index])
        rate_model = self.second_order_rate_layer.model
        return rate_model(input_ids = batch["input_ids"], labels = batch["labels"], attention_mask = batch["attention_mask"]).logits

    def get_offset(self, batch, outputs=None):
        """Return the offset logits for ``batch``.

        With ``config.transformer_offset`` set, a standalone model
        (``offset_layer.model``) is run on the raw batch and its logits are
        returned. Otherwise ``offset_layer`` is applied to the hidden state of
        ``outputs`` selected by ``config.offset_share_layer``, which defaults to -1.
        """
        use_separate_transformer = getattr(self.config, "transformer_offset", False)
        if use_separate_transformer:
            offset_outputs = self.offset_layer.model(
                input_ids=batch["input_ids"],
                labels=batch["labels"],
                attention_mask=batch["attention_mask"],
            )
            return offset_outputs.logits
        assert outputs is not None
        shared_layer = getattr(self.config, "offset_share_layer", -1)
        return self.offset_layer(outputs.hidden_states[shared_layer])

    def get_unnorm_nll(self, rate_logits, labels, reduce=True):
        """Pick the raw (un-normalised) score of each label from ``rate_logits``.

        ``-nll_loss`` on scores that are not log-probabilities simply gathers
        ``rate_logits[..., label]``. Positions labelled -100 (the NLL ignore index)
        contribute 0.

        Args:
            rate_logits: scores whose last dimension is the vocabulary and whose
                leading dimensions match ``labels``.
            labels: target ids.
            reduce: when True sum over the last label dimension.

        Returns:
            The per-label scores shaped like ``labels``, or their sum over the last
            dimension when ``reduce`` is True.
        """
        # reduction="none" replaces the deprecated reduce=False argument (same result,
        # no deprecation warning).
        picked = -nn.functional.nll_loss(
            rate_logits.reshape(-1, rate_logits.size(-1)),
            labels.reshape(-1),
            reduction="none",
        )
        picked = picked.view(labels.size())
        return picked.sum(-1) if reduce else picked

    def core(self, batch):
        """Run the backbone on token embeddings plus optional context embeddings.

        Token embeddings come from ``self.model.transformer.wte``. When
        ``config.add_location`` is true (the attribute must exist), a per-sequence
        location embedding is unsqueezed at dim 1 and added, so it broadcasts over
        positions. This assumes ``[B, L, D]`` embeddings — TODO confirm. The lineage
        embedding is handled the same way when ``config.add_lineage`` is set. The
        backbone is asked to return hidden states.
        """
        embeddings = self.model.transformer.wte(batch["input_ids"])
        if self.config.add_location:
            embeddings = embeddings + self.location_embeddings(batch["location"]).unsqueeze(1)
        if getattr(self.config, "add_lineage", False):
            embeddings = embeddings + self.lineage_embeddings(batch["lineage"]).unsqueeze(1)
        return self.model(
            inputs_embeds=embeddings,
            labels=batch["labels"],
            attention_mask=batch["attention_mask"],
            output_hidden_states=True,
        )
    
    def get_rate(self, outputs):
        """Return the backbone logits, which callers treat as per-token rates."""
        rate = outputs.logits
        return rate

    def testing_forward(self, batch, batch_idx, return_rate=False, return_offset=False):
        """Score ``batch`` at every integer time in ``[min_testing_time, max_testing_time]``.

        For each time ``t`` the logits are ``rate * t_norm / tau``, where ``t_norm``
        is ``t`` after ``discretize_time`` normalisation and ``tau`` is
        ``config.tau`` (default 1.0). An offset from ``get_offset`` is added unless
        ``config.zero_offset`` is set; it is the same for every time.

        Returns:
            ``(loss, loss_dict)``. ``loss`` is ``nll_loss(..., reduce=False)``
            reshaped so there is one leading row per evaluated time. ``loss_dict``
            may hold ``"rate"`` and/or ``"offset"``: the summed raw scores of the
            labels under the rate / offset logits.
        """
        loss_weight = batch.get('freq', None)
        max_time, min_time = self.max_testing_time, self.min_testing_time
        input_times = torch.arange(min_time, max_time + 1).to(batch["input_ids"].device)

        # NOTE(review): assumes discretize_time returns a 1-D tensor of length T — confirm.
        time = discretize_time(
            input_times,
            one_step=False,
            normalize_time_a=self.config.normalize_time_a,
            normalize_time_b=self.config.normalize_time_b,
            discrete=False)

        outputs = self.core(batch)
        # rate: [1, B, L, V]; time is reshaped to [T, 1, 1, 1], so logits are [T, B, L, V].
        rate = self.get_rate(outputs).unsqueeze(0)
        logits = rate * time.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) / getattr(self.config, "tau", 1.0)
        if not self.config.zero_offset:
            # The offset does not depend on time; it broadcasts over the T axis.
            offset = self.get_offset(batch, outputs).unsqueeze(0)
            logits = logits + offset

        labels = batch["labels"]  # [B, L]
        labels = labels.masked_fill(torch.eq(labels, self.alphabet.pad()), -100)

        # Fold the time axis into the batch axis so nll_loss sees [T * B, L, V] logits
        # and [T * B, L] labels.
        repeat_labels = labels.unsqueeze(0).repeat(logits.size(0), 1, 1)  # [T, B, L]
        loss = self.nll_loss(logits.view(-1, logits.size(2), logits.size(3)),
                             repeat_labels.view(-1, repeat_labels.size(2)), loss_weight=loss_weight, reduce=False)
        loss = loss.view(logits.size(0), -1)  # one row per evaluated time

        loss_dict = {}
        if return_rate:
            loss_dict["rate"] = self.get_unnorm_nll(rate.squeeze(0), labels)
        if return_offset and not self.config.zero_offset:
            loss_dict["offset"] = self.get_unnorm_nll(offset, labels)

        return loss, loss_dict

    def logistic_time_loss_weight(self, time):
        """Logistic weight ``1 / (1 + exp(-k * (time - x0)))`` for each time value.

        ``k`` (steepness, default 0.1) is read from
        ``config.weight_loss_by_time_logistic_k``. ``x0`` (midpoint, default 50) is
        read from ``config.weight_loss_by_time_logistic_x0``. For positive ``k`` the
        weight is 0.5 at ``x0`` and rises towards 1 for later times.
        """
        # Each parameter is read from its own config key (previously the two keys were swapped).
        k = getattr(self.config, "weight_loss_by_time_logistic_k", 0.1)
        x0 = getattr(self.config, "weight_loss_by_time_logistic_x0", 50)
        return 1 / (1 + torch.exp(-k * (time - x0)))

    def forward(self, batch, batch_idx, reduce=True, return_rate=False, return_offset=False, mode="train"):
        """Compute the (optionally weighted) language-modelling loss for ``batch``.

        Args:
            batch: dict passed as keyword arguments to ``self.model``. It must contain
                ``labels``. ``input_time`` is overwritten in place when
                ``config.zero_time`` or ``config.set_time`` is set. Optional ``freq`` /
                ``bin_size`` entries determine the per-sample loss weight.
            batch_idx: unused; kept for the Lightning step signature.
            reduce: forwarded to the loss function.
            return_rate: unused; kept for interface compatibility.
            return_offset: unused; kept for interface compatibility.
            mode: ``"train"``, ``"val"`` or ``"test"``. Selects the datamodule
                statistics used by ``config.debias_sample_weight``. ``"test"`` also
                makes the NLL ignore the BOS position.

        Returns:
            ``(loss, {})``.
        """
        if getattr(self.config, "zero_time", False):
            batch["input_time"].fill_(0.)

        if getattr(self.config, "set_time", None) is not None:
            batch["input_time"].fill_(self.config.set_time)  # set time bin as a constant.

        logits = self.model(**batch).logits / self.config.temperature

        freq = batch.get('freq', None)
        bin_size = batch.get('bin_size', None)
        if self.config.weight_loss_by_count and freq is not None and bin_size is not None:
            # Weight each sequence by its count (frequency * bin size).
            loss_weight = freq * bin_size
            if getattr(self.config, "debias_sample_weight", False):
                if mode in ("train", "val"):
                    datamodule = self.trainer.datamodule
                    if mode == "train":
                        total_weights = datamodule.total_sample_count_train
                        total_num_of_samples = len(datamodule.train_dataset)
                    else:
                        total_weights = datamodule.total_sample_count_valid
                        total_num_of_samples = len(datamodule.val_dataset)
                    # Normalise by the split's total count, then scale by
                    # (number of samples / batch size).
                    loss_weight = loss_weight / total_weights * total_num_of_samples / batch["input_ids"].size(0)
                else:
                    loss_weight = None  # For testing we don't calculate the loss weight
        elif not self.config.weight_loss_by_count and freq is not None:
            loss_weight = freq
        else:
            loss_weight = 1.0

        # With no weight (debiased test mode) there is nothing to re-weight by time;
        # multiplying None by a tensor would raise TypeError.
        if getattr(self.config, "weight_loss_by_time", False) and loss_weight is not None:
            loss_weight = loss_weight * self.logistic_time_loss_weight(batch["input_time"])

        labels = batch["labels"]

        if getattr(self.config, "count_mse_loss", False):
            loss = self.count_mse_loss(logits, labels, batch["bin_size"], batch["freq"] * batch["bin_size"], reduce=reduce)
        else:
            loss = self.nll_loss(logits, labels, loss_weight=loss_weight, reduce=reduce, ignore_bos=mode == "test")

        return loss, {}

    def training_step(self, batch, batch_idx):
        """Lightning training step.

        Logs ``train_loss`` and ``train_<name>`` for every auxiliary entry returned
        by ``forward``, then returns the loss.
        """
        loss, extras = self.forward(batch, batch_idx)
        self.log("train_loss", loss, prog_bar=True)
        for name, value in extras.items():
            self.log("train_%s" % name, value, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Lightning validation step.

        Runs ``forward`` in ``"val"`` mode, logs ``val_loss`` and ``val_<name>`` for
        each auxiliary entry, and returns the loss.
        """
        loss, extras = self.forward(batch, batch_idx, mode="val")
        self.log("val_loss", loss, prog_bar=True)
        for name, value in extras.items():
            self.log("val_%s" % name, value, prog_bar=True)
        return loss
    
    def test_step(self, batch, batch_idx, dataloader_idx=0):
        loss, loss_dict = self.forward(batch, batch_idx, reduce=False, mode="test")
        # print(loss.size())
        # print(loss)
        # token_num = torch.sum(batch["labels"][..., 1:].contiguous() != self.alphabet.pad(), dim=-1)
        token_num = torch.sum(
            (batch["labels"][..., 1:].contiguous() != self.alphabet.pad()) * 
            (batch["labels"][..., 1:].contiguous() != self.alphabet.eos()) * 
            (batch["labels"][..., 1:].contiguous() != self.alphabet.bos()), dim=-1)

        if "freq" in batch and "bin_size" in batch:
            weight = batch["freq"] * batch["bin_size"]
        else:
            weight = token_num.new_zeros(token_num.size(0)) + 1.0
        # print(token_num)
        # print(weight)
        # exit()
        # exit()
        self.log("test_loss", loss.mean(), prog_bar=True)
        # for key in loss_dict:
            # self.log("test_%s" % key, loss_dict[key].mean(), prog_bar=True)
        return loss, token_num, weight

    def generate(self, text: str, device: torch.device = torch.device("cpu"), **kwargs) -> Any:
        """Generate continuations of the prompt ``text`` with the underlying model.

        The prompt is encoded with ``self.alphabet.encode_line``. No EOS is appended
        and unknown symbols are not added to the alphabet. The ids are batched as a
        single row on the model's device and passed to ``self.model.generate`` along
        with ``kwargs``. Model-specific inputs such as ``input_time`` can be supplied
        through ``kwargs``.

        NOTE(review): ``device`` is unused; inputs follow ``self.model.device``.
        Kept for interface compatibility.
        """
        # Previously `inputs` was used without ever being assigned (UnboundLocalError).
        token_ids = self.alphabet.encode_line(text, add_if_not_exist=False, append_eos=False)
        inputs = token_ids.long().unsqueeze(0).to(self.model.device)
        return self.model.generate(inputs, **kwargs)

    def overwrite_generate_kwargs(self, new_config):
        """Copy the generation options from ``new_config`` onto ``self.config``.

        The copied attributes are ``do_sample``, ``num_beams``, ``temperature``,
        ``num_return_sequences`` and ``output_token_losses``.
        """
        copied_options = ("do_sample", "num_beams", "temperature", "num_return_sequences", "output_token_losses")
        for option in copied_options:
            setattr(self.config, option, getattr(new_config, option))
        
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Generate sequences conditioned on ``batch``.

        Generation options come from ``self.config``:
            - ``temperature`` (default 1.0) and ``do_sample`` (default True).
            - ``num_beams`` (default 1), which is also written to ``self.model.config``.
            - ``num_return_sequences``, raised to at least ``num_beams``.
            - ``generate_max_length``, falling back to ``max_position_embeddings``.
        A trailing EOS is stripped from the prompts when the first row ends with one.

        Returns:
            One dict per generated sequence, holding the decoded ``"prediction"`` and
            the ``"src_time"`` of the prompt it came from.
        """
        generate_kwargs = {}
        generate_kwargs["temperature"] = getattr(self.config, "temperature", 1.0)  # TODO: how to add this in testing?
        generate_kwargs["do_sample"] = getattr(self.config, "do_sample", True)
        # Integer defaults: num_return_sequences feeds Tensor.repeat(), which rejects floats.
        generate_kwargs["num_beams"] = getattr(self.config, "num_beams", 1)
        setattr(self.model.config, "num_beams", generate_kwargs["num_beams"])

        generate_kwargs["num_return_sequences"] = max(getattr(self.config, "num_return_sequences", 1), generate_kwargs["num_beams"])

        max_length = getattr(self.config, "generate_max_length", None)
        generate_kwargs["max_length"] = self.config.max_position_embeddings if max_length is None else max_length
        generate_kwargs["pad_token_id"] = self.alphabet.pad()
        generate_kwargs["eos_token_id"] = self.alphabet.eos()
        generate_kwargs["bos_token_id"] = self.alphabet.bos()

        # Prompts must not end with EOS, otherwise generation stops immediately.
        # Only the first row is inspected.
        if batch["input_ids"][0, -1].item() == self.alphabet.eos():
            batch["input_ids"] = batch["input_ids"][:, :-1]

        output_ids = self.model.generate(**batch, **generate_kwargs)
        # Each prompt yields num_return_sequences consecutive outputs; repeat its time to match.
        input_time = batch["input_time"].unsqueeze(1).repeat(1, generate_kwargs["num_return_sequences"]).view(-1)
        outputs = [{"prediction": self.alphabet.string(x), "src_time": input_time[i].item()} for i, x in enumerate(output_ids)]
        return outputs
        
    def test_epoch_end(self, outputs):
        losses, token_nums, weights = [], [], []
        # print(len(outputs))
        if len(self.config.test_data_paths) == 1:
            outputs = [outputs]

        for dataloader_outputs in outputs:
            for output in dataloader_outputs:
                # outpu[0]: [B, L]
                losses.append(output[0].sum(-1)) # [B]
                token_nums.append(output[1])
                weights.append(output[2])
        losses = torch.cat(losses)
        token_nums = torch.cat(token_nums)
        weights = torch.cat(weights)
        # print("Sum of frequency", torch.exp(-losses).sum())
        print(torch.sum(weights), torch.sum(token_nums * weights))
        # print(losses.size(), token_nums.size(), weights.size())
        # print(outputs[0]) # loss, token_num, weight
        # ppl1 = torch.exp(torch.sum(losses) / torch.sum(token_nums))
        # print(ppl1)
        # ppl2 = torch.exp(torch.sum(weights * losses) / torch.sum(weights * token_nums))
        # print(ppl2)
        ppl = torch.exp(torch.sum(losses * weights) / torch.sum(token_nums * weights))
        nll = torch.sum(weights * losses) / torch.sum(weights)
        # nll = torch.exp(torch.sum(losses * weights))
        # exit()
        # collate data:
        # outputs is a list of dict, or a list of list of dict (for multiple dataloaders)
        # loss = torch.cat(outputs)
        self.log_dict({"perplexity": ppl, "nll": nll, "coverage": torch.exp(-losses).sum()})

        if self.config.output_token_losses:
            self.all_outputs = []
            for dataloader_outputs in outputs:
                for output in dataloader_outputs:
                    # print(output[0].size())
                    # loss = 
                    self.all_outputs.extend([x for x in output[0]])
        else:
            self.all_outputs = []
            # for loss in losses:
            #     # info_dict["prediction"] = loss.item()
            #     self.all_outputs.append(loss.item())
            for loss, tok_num in zip(losses, token_nums):
                # info_dict["prediction"] = loss.item()
                self.all_outputs.append({"prediction": loss.item(), "token_num": tok_num.item()})

        return ppl
    
    def output_testing_results(self, outputs, predict_dataset):
        """Attach dataset metadata (``src_id``, ``src_time``, ``freq``) to each test result.

        ``predict_dataset`` is a list of lists. It is flattened so it lines up
        one-to-one with ``outputs``, and an AssertionError is raised on a length
        mismatch. When ``config.output_token_losses`` is set, each output is an
        iterable of scalar tensors rendered as a space-separated string. Otherwise
        each output is a dict that is updated in place and returned.
        """
        flat_dataset = [record for group in predict_dataset for record in group]
        assert len(outputs) == len(flat_dataset)
        rows = []
        for output_loss, record in zip(outputs, flat_dataset):
            if not self.config.output_token_losses:
                row = output_loss
            else:
                row = {"prediction": " ".join(str(value.item()) for value in output_loss)}
            for key in ("src_id", "src_time", "freq"):
                row[key] = record[key]
            rows.append(row)
        return rows

    def output_predicting_results(self, outputs, predict_dataset, *args, **kwargs):
        """Collect generation results and optionally dump them as FASTA.

        Args:
            outputs: list of dicts with at least a ``"prediction"`` entry.
            predict_dataset: unused; kept for interface compatibility.
            *args: ``args[0]``, if given, is the output path. When it ends in
                ``.csv``, the predictions are also written to the same path with a
                ``.fasta`` suffix, numbered from 0.
            **kwargs: unused; kept for interface compatibility.

        Returns:
            The output dicts, in order.
        """
        results = list(outputs)

        # Tolerate callers that pass no path (previously an IndexError).
        output_path = args[0] if args else None
        if output_path is not None and output_path.endswith(".csv"):
            # Strip only the trailing suffix. split(".csv") would cut at the first
            # ".csv" anywhere in the path, e.g. inside a directory name.
            fasta_path = output_path[: -len(".csv")] + ".fasta"
            logging.info("Writing generations to %s" % fasta_path)
            with open(fasta_path, "w") as fout:
                for i, data in enumerate(results):
                    fout.write(">%d\n%s\n\n" % (i, data["prediction"]))

        return results




@register_model("gpt2_time_ensemble")
class GPT2TimeEnsemble(GPT2TimeNew):
    """Ensemble of several ``GPT2TimeNew`` models scored as one model.

    ``forward`` runs every member's ``model`` on the batch and takes each member's
    log-softmax. It divides the stacked log-probabilities by ``config.temperature``
    and combines them with a log-mean-exp over members. The result is scored with
    the base-class losses.
    """

    def __init__(self, *models) -> None:
        # The base class is built from the first member's config/alphabet; scoring is
        # done by the members held in ``_models``.
        super().__init__(models[0].config, models[0].alphabet)
        self._models = nn.ModuleList(models)
        self.alphabet = models[0].alphabet
        self.config = models[0].config

    @classmethod
    def load_from_checkpoint(cls, paths, **args):
        """Load one ``GPT2TimeNew`` per comma-separated checkpoint in ``paths`` and ensemble them."""
        paths = paths.split(",")
        return cls(*[GPT2TimeNew.load_from_checkpoint(path, **args) for path in paths])

    @classmethod
    def add_argparse_args(cls, parent_parser):
        parent_parser = super(GPT2TimeEnsemble, cls).add_argparse_args(parent_parser)
        parent_parser.add_argument('--ensemble_checkpoints', type=str, nargs="+")
        return parent_parser

    def forward(self, batch, batch_idx, reduce=True, return_rate=False, return_offset=False, mode="train"):
        """Ensemble loss for ``batch``.

        ``mode`` is accepted because ``validation_step``/``test_step`` call
        ``forward(..., mode=...)``; without it those steps raised TypeError. As in the
        other models in this file, ``mode == "test"`` makes the NLL ignore the BOS
        position.

        Returns:
            ``(loss, {})``.
        """
        if getattr(self.config, "zero_time", False):
            batch["input_time"].fill_(0.)

        if getattr(self.config, "set_time", None) is not None:
            batch["input_time"].fill_(self.config.set_time)  # set time bin as a constant.

        member_log_probs = []
        for model in self._models:
            member_logits = model.model(**batch).logits
            member_log_probs.append(nn.functional.log_softmax(member_logits, dim=-1))

        log_probs = torch.stack(member_log_probs, dim=0) / self.config.temperature
        # log(mean_k exp(x_k)) over the member axis, computed stably.
        logits = torch.logsumexp(log_probs, dim=0) - math.log(log_probs.size(0))
        loss_weight = batch.get('freq', None)

        labels = batch["labels"]

        if getattr(self.config, "count_mse_loss", False):
            loss = self.count_mse_loss(logits, labels, batch["bin_size"], batch["freq"] * batch["bin_size"], reduce=reduce)
        else:
            loss = self.nll_loss(logits, labels, loss_weight=loss_weight, reduce=reduce, ignore_bos=(mode == "test"))

        return loss, {}


@register_model("gpt2_time_prepend_property")
class GPT2TimePrependProperty(GPT2TimeNew):
    """GPT-2 time model conditioned on sequence properties through prepended tokens.

    Each property in ``config.data_properties`` owns an id range placed after the
    alphabet, sized by ``len(config.<prop>_dict)``. A sequence's property ids are
    prepended to its tokens. Predictions at the prepended positions are dropped
    before the loss. During generation a ``PrependLogitsProcessor`` masks out every
    id at or beyond ``len(alphabet)``.
    """

    def __init__(self, config, alphabet) -> None:
        # self.config is used to size the vocabulary before the base constructor runs.
        self.config = config
        super().__init__(config, alphabet, vocab_size=len(alphabet) + sum([len(getattr(self.config, "%s_dict" % prop)) for prop in config.data_properties]))

    def on_fit_start(self):
        """Do nothing; in particular, never resize the token embeddings.

        NOTE(review): presumably overrides a base-class hook that would resize
        embeddings to the tokenizer — confirm.
        """

    def _prepend_property_ids(self, batch):
        """Return property token ids, one column per property in ``config.data_properties``.

        Property ``p``'s raw index is shifted past the alphabet and past all earlier
        properties' dictionaries. This assumes each ``batch[p]`` is a ``[B]`` tensor,
        giving a ``[B, P]`` result — TODO confirm.
        """
        offset = len(self.alphabet)
        prepend_ids = []
        for prop in self.config.data_properties:
            prepend_ids.append(batch[prop].long() + offset)
            offset += len(getattr(self.config, "%s_dict" % prop))
        return torch.stack(prepend_ids, dim=1)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Generate sequences conditioned on the batch's property tokens.

        Same options as the base ``predict_step``. The property ids are prepended to
        the prompt, property tokens are masked out of the generation logits, and the
        prepended ids are stripped from the returned sequences.
        """
        generate_kwargs = {}
        generate_kwargs["temperature"] = getattr(self.config, "temperature", 1.0)  # TODO: how to add this in testing?
        generate_kwargs["do_sample"] = getattr(self.config, "do_sample", True)
        # Integer defaults: num_return_sequences feeds Tensor.repeat(), which rejects floats.
        generate_kwargs["num_beams"] = getattr(self.config, "num_beams", 1)
        setattr(self.model.config, "num_beams", generate_kwargs["num_beams"])
        generate_kwargs["num_return_sequences"] = max(getattr(self.config, "num_return_sequences", 1), generate_kwargs["num_beams"])
        max_length = getattr(self.config, "generate_max_length", None)
        generate_kwargs["max_length"] = self.config.max_position_embeddings if max_length is None else max_length
        generate_kwargs["pad_token_id"] = self.alphabet.pad()
        generate_kwargs["eos_token_id"] = self.alphabet.eos()
        generate_kwargs["bos_token_id"] = self.alphabet.bos()

        if batch["input_ids"][0, -1].item() == self.alphabet.eos():
            batch["input_ids"] = batch["input_ids"][:, :-1]

        prepend_ids = self._prepend_property_ids(batch)
        prepend_ids_size = prepend_ids.size(1)
        batch["input_ids"] = torch.cat([prepend_ids, batch["input_ids"]], dim=-1)

        # Never generate property tokens.
        generate_kwargs["logits_processor"] = [PrependLogitsProcessor(len(self.alphabet))]

        output_ids = self.model.generate(**batch, **generate_kwargs)
        output_ids = output_ids[:, prepend_ids_size:]
        input_time = batch["input_time"].unsqueeze(1).repeat(1, generate_kwargs["num_return_sequences"]).view(-1)
        outputs = [{"prediction": self.alphabet.string(x), "src_time": input_time[i].item()} for i, x in enumerate(output_ids)]
        return outputs

    def initialize_config(self,):
        """Build the GPT-2 config with the property/time settings copied from ``self.config``."""
        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path="gpt2", **self.model_data_kwargs
        )
        setattr(config, "prepend_property", True)
        setattr(config, "data_properties", self.config.data_properties)
        setattr(config, "normalize_time_a", self.config.normalize_time_a)
        setattr(config, "normalize_time_b", self.config.normalize_time_b)
        setattr(config, "transformer_offset", self.config.transformer_offset)
        return config

    def initialize_model(self, pretrained_model_name_or_path: str):
        """create and initialize the model to use with this task,
        Feel free to overwrite this method if you are initializing the model in a different way
        """
        config = self.initialize_config()
        self.model = GPT2TimeModel.from_config(config)

    def forward(self, batch, batch_idx, reduce=True, return_rate=False, return_offset=False, mode="train"):
        """Loss of the original tokens, conditioned on the prepended property tokens.

        Returns:
            ``(loss, {})``.
        """
        # Work on a shallow copy so the prepended input_ids/labels/attention_mask do not
        # leak back to the caller. test_step reads batch["labels"] after forward to count
        # scored tokens, and the prepended property ids inflated that count.
        batch = dict(batch)
        ori_input_ids = batch["input_ids"]

        prepend_ids = self._prepend_property_ids(batch)
        prepend_ids_size = prepend_ids.size(1)

        batch["input_ids"] = torch.cat([prepend_ids, ori_input_ids], dim=-1)
        batch["labels"] = batch["input_ids"]
        batch["attention_mask"] = (batch["input_ids"] != self.alphabet.pad())

        if getattr(self.config, "zero_time", False):
            batch["input_time"].fill_(0.)

        if getattr(self.config, "set_time", None) is not None:
            batch["input_time"].fill_(self.config.set_time)  # set time bin as a constant.

        logits = self.model(**batch).logits / self.config.temperature
        # Keep only the predictions made at the original (non-property) positions.
        logits = logits[:, prepend_ids_size:, :]

        if self.config.weight_loss_by_count:
            loss_weight = batch.get('freq', None) * batch.get('bin_size', None)
        else:
            loss_weight = batch.get('freq', None)

        labels = ori_input_ids

        if getattr(self.config, "count_mse_loss", False):
            loss = self.count_mse_loss(logits, labels, batch["bin_size"], batch["freq"] * batch["bin_size"], reduce=reduce)
        else:
            loss = self.nll_loss(logits, labels, loss_weight=loss_weight, reduce=reduce, ignore_bos=(mode == "test"))

        return loss, {}


@register_model("gpt2_time_prepend_property_add_global")
class GPT2TimePrependPropertyAddGlobal(GPT2TimeNew):
    """Prepend-property GPT-2 time model with one extra vocabulary slot.

    The vocabulary is the alphabet, plus one, plus every property dictionary. The
    model is ``GPT2TimePrependAddGlobalModel`` and receives ``train_global_prob`` /
    ``test_global_prob``. NOTE(review): the extra slot is presumably a "global"
    property token handled inside that model (defined elsewhere) — confirm.
    """

    def __init__(self, config, alphabet) -> None:
        # self.config is used to size the vocabulary before the base constructor runs.
        self.config = config
        super().__init__(config, alphabet, vocab_size=len(alphabet) + 1 + sum([len(getattr(self.config, "%s_dict" % prop)) for prop in config.data_properties]))

    def on_fit_start(self):
        """Do nothing at fit start.

        NOTE(review): presumably overrides a base-class hook that would resize the
        token embeddings — confirm.
        """

    def initialize_config(self,):
        """Build the GPT-2 config with property, time and global-token settings from ``self.config``."""
        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path="gpt2", **self.model_data_kwargs
        )
        setattr(config, "prepend_property", True)
        setattr(config, "data_properties", self.config.data_properties)
        setattr(config, "normalize_time_a", self.config.normalize_time_a)
        setattr(config, "normalize_time_b", self.config.normalize_time_b)
        setattr(config, "transformer_offset", self.config.transformer_offset)

        setattr(config, "train_global_prob", self.config.train_global_prob)
        setattr(config, "test_global_prob", self.config.test_global_prob)

        for prop in self.config.data_properties:
            setattr(config, "%s_dict" % prop, getattr(self.config, "%s_dict" % prop))

        config.real_vocab_size = len(self.alphabet)
        config.pad_idx = self.pad_idx
        return config

    @classmethod
    def add_argparse_args(cls, parent_parser):
        parent_parser = super(GPT2TimePrependPropertyAddGlobal, cls).add_argparse_args(parent_parser)
        parent_parser.add_argument('--train_global_prob', type=float, default=0.0)
        parent_parser.add_argument('--test_global_prob', type=float, default=0.0)
        return parent_parser

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Generate sequences conditioned on ``batch``.

        Options come from ``self.config``: ``temperature`` (1.0), ``do_sample`` (True),
        ``num_beams`` (1), ``num_return_sequences`` (1) and ``generate_max_length``,
        falling back to ``max_position_embeddings``.

        Returns:
            One ``{"prediction", "src_time"}`` dict per generated sequence.
        """
        generate_kwargs = {}
        generate_kwargs["temperature"] = getattr(self.config, "temperature", 1.0)  # TODO: how to add this in testing?
        generate_kwargs["do_sample"] = getattr(self.config, "do_sample", True)
        # Integer defaults: num_return_sequences feeds Tensor.repeat(), which rejects floats.
        generate_kwargs["num_beams"] = getattr(self.config, "num_beams", 1)
        generate_kwargs["num_return_sequences"] = getattr(self.config, "num_return_sequences", 1)
        max_length = getattr(self.config, "generate_max_length", None)
        generate_kwargs["max_length"] = self.config.max_position_embeddings if max_length is None else max_length
        generate_kwargs["pad_token_id"] = self.alphabet.pad()
        generate_kwargs["eos_token_id"] = self.alphabet.eos()
        generate_kwargs["bos_token_id"] = self.alphabet.bos()

        if batch["input_ids"][0, -1].item() == self.alphabet.eos():
            batch["input_ids"] = batch["input_ids"][:, :-1]

        output_ids = self.model.generate(**batch, **generate_kwargs)
        input_time = batch["input_time"].unsqueeze(1).repeat(1, generate_kwargs["num_return_sequences"]).view(-1)
        outputs = [{"prediction": self.alphabet.string(x), "src_time": input_time[i].item()} for i, x in enumerate(output_ids)]
        return outputs

    def initialize_model(self, pretrained_model_name_or_path: str):
        """create and initialize the model to use with this task,
        Feel free to overwrite this method if you are initializing the model in a different way
        """
        config = self.initialize_config()
        # Report via logging rather than a bare debug print.
        logging.info("Vocabulary size: %s", config.vocab_size)
        self.model = GPT2TimePrependAddGlobalModel.from_config(config)

    def forward(self, batch, batch_idx, reduce=True, return_rate=False, return_offset=False, mode="train"):
        """Language-modelling loss for ``batch``.

        Honours ``config.zero_time`` / ``config.set_time`` like the other models in
        this file. ``mode == "test"`` makes the NLL ignore the BOS position.

        Returns:
            ``(loss, {})``.
        """
        if getattr(self.config, "zero_time", False):
            batch["input_time"].fill_(0.)

        if getattr(self.config, "set_time", None) is not None:
            batch["input_time"].fill_(self.config.set_time)  # set time bin as a constant.

        logits = self.model(**batch).logits / self.config.temperature

        if self.config.weight_loss_by_count:
            loss_weight = batch.get('freq', None) * batch.get('bin_size', None)
        else:
            loss_weight = batch.get('freq', None)

        if getattr(self.config, "count_mse_loss", False):
            loss = self.count_mse_loss(logits, batch["labels"], batch["bin_size"], batch["freq"] * batch["bin_size"], reduce=reduce)
        else:
            loss = self.nll_loss(logits, batch["labels"], loss_weight=loss_weight, reduce=reduce, ignore_bos=(mode == "test"))

        return loss, {}


@register_model("gpt2_time_concat_property")
class GPT2TimeConcatProperty(GPT2TimeNew):
    """GPT-2 time model conditioned on the first data property through ``token_type_ids``.

    Forces ``config.transformer_model_name = "GPT2ModelWithExtraEmbeddings"``. The
    vocabulary is the alphabet plus the first property's dictionary. The property
    index, offset past the alphabet, is passed as ``token_type_ids``. Logits beyond
    the alphabet are dropped from the loss and masked out during generation.
    """

    def __init__(self, config, alphabet) -> None:
        config.transformer_model_name = "GPT2ModelWithExtraEmbeddings"
        # self.config is used to size the vocabulary before the base constructor runs.
        self.config = config
        super().__init__(config, alphabet, vocab_size=len(alphabet) + len(getattr(self.config, "%s_dict" % config.data_properties[0])) )

    def on_fit_start(self):
        """Do nothing; in particular, never resize the token embeddings.

        NOTE(review): presumably overrides a base-class hook that would resize
        embeddings to the tokenizer — confirm.
        """

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Generate sequences conditioned on the batch's first property.

        Uses the same options as the base ``predict_step``. The property is passed
        as a ``[B, 1]`` ``token_type_ids`` tensor, and property ids are masked out
        of the generation logits.
        """
        generate_kwargs = {}
        generate_kwargs["temperature"] = getattr(self.config, "temperature", 1.0)  # TODO: how to add this in testing?
        generate_kwargs["do_sample"] = getattr(self.config, "do_sample", True)
        # Integer defaults: num_return_sequences feeds Tensor.repeat(), which rejects floats.
        generate_kwargs["num_beams"] = getattr(self.config, "num_beams", 1)
        setattr(self.model.config, "num_beams", generate_kwargs["num_beams"])

        generate_kwargs["num_return_sequences"] = max(getattr(self.config, "num_return_sequences", 1), generate_kwargs["num_beams"])

        max_length = getattr(self.config, "generate_max_length", None)
        generate_kwargs["max_length"] = self.config.max_position_embeddings if max_length is None else max_length
        generate_kwargs["pad_token_id"] = self.alphabet.pad()
        generate_kwargs["eos_token_id"] = self.alphabet.eos()
        generate_kwargs["bos_token_id"] = self.alphabet.bos()

        if batch["input_ids"][0, -1].item() == self.alphabet.eos():
            batch["input_ids"] = batch["input_ids"][:, :-1]

        prop_tok = batch[self.config.data_properties[0]].long() + len(self.alphabet)  # [B]
        batch["token_type_ids"] = prop_tok.unsqueeze(1)

        # Make sure property (location) tokens are never generated.
        generate_kwargs["logits_processor"] = [PrependLogitsProcessor(len(self.alphabet))]
        output_ids = self.model.generate(**batch, **generate_kwargs)
        input_time = batch["input_time"].unsqueeze(1).repeat(1, generate_kwargs["num_return_sequences"]).view(-1)
        outputs = [{"prediction": self.alphabet.string(x), "src_time": input_time[i].item()} for i, x in enumerate(output_ids)]
        return outputs

    def forward(self, batch, batch_idx, reduce=True, return_rate=False, return_offset=False, mode="train"):
        """Loss of the sequence conditioned on its first property.

        The property id is written to ``batch["token_type_ids"]``, and the logits
        are truncated to the alphabet before the loss is computed.

        Returns:
            ``(loss, {})``.
        """
        prop_tok = batch[self.config.data_properties[0]].long() + len(self.alphabet)  # [B]
        batch["token_type_ids"] = prop_tok

        if getattr(self.config, "zero_time", False):
            batch["input_time"].fill_(0.)

        if getattr(self.config, "set_time", None) is not None:
            batch["input_time"].fill_(self.config.set_time)  # set time bin as a constant.

        logits = self.model(**batch).logits / self.config.temperature
        # Only alphabet tokens are valid targets; drop the property-id logits.
        logits = logits[:, :, :len(self.alphabet)]

        if self.config.weight_loss_by_count:
            loss_weight = batch.get('freq', None) * batch.get('bin_size', None)
        else:
            loss_weight = batch.get('freq', None)

        labels = batch["labels"]

        if getattr(self.config, "count_mse_loss", False):
            loss = self.count_mse_loss(logits, labels, batch["bin_size"], batch["freq"] * batch["bin_size"], reduce=reduce)
        else:
            loss = self.nll_loss(logits, labels, loss_weight=loss_weight, reduce=reduce, ignore_bos=(mode == "test"))

        return loss, {}


@register_model("gpt2_time_concat_property_add_global")
class GPT2TimeConcatPropertyAddGlobal(GPT2TimePrependPropertyAddGlobal):
    """``GPT2TimePrependPropertyAddGlobal`` variant that builds a ``GPT2TimeConcatAddGlobalModel``.

    Only model construction is overridden; all other behavior is inherited
    from the parent task class.
    """

    def initialize_model(self, pretrained_model_name_or_path: str):
        """Create ``self.model`` from the task configuration.

        ``pretrained_model_name_or_path`` is accepted for interface
        compatibility but not used: the model is built through
        ``GPT2TimeConcatAddGlobalModel.from_config`` using the config returned
        by ``self.initialize_config()``.
        """
        self.model = GPT2TimeConcatAddGlobalModel.from_config(self.initialize_config())
    
# @register_model("gpt2_time_concat_property_add_global")
# class GPT2TimeConcatPropertyAddGlobal(GPT2TimeConcatProperty):
#     def __init__(self, config, alphabet) -> None:
#         # print(config.data_properties[0])
#         # print(getattr(config, "%s_dict" % config.data_properties[0]))
#         prop_dict = getattr(config, "%s_dict" % config.data_properties[0])
#         prop_list = getattr(config, "%s_list" % config.data_properties[0])
#         # prop_dict["global"]
#         prop_list.append("global")
#         prop_dict["global"] = len(prop_list) - 1
#         # print(prop_dict)
#         # print(config)
#         # prop_dict["global"] = 
#         # setattr(config, "%s_dict" % config.data_properties[0], prop_dict)
#         # setattr(config, "%s_list" % config.data_properties[0], prop_list)
#         # print(getattr(config, "%s_dict" % config.data_properties[0]))
#         # print(getattr(config, "%s_list" % config.data_properties[0]))
#         # exit()
#         super().__init__(config, alphabet)
#         # print(getattr(self.config, "%s_list" % self.config.data_properties[0]))
#         # print(len(getattr(self.config, "%s_dict" % config.data_properties[0])))
#         # exit()

#     def random_set_global(self, property_token):
#         # torch.multinomial([weights], 4, replacement=True)
#         weights = torch.tensor([1 - self.config.train_global_prob, self.config.train_global_prob]).to(property_token.device)
#         mask = torch.multinomial(weights, property_token.size(0), replacement=True).bool()
#         # print(mask)
#         # print(getattr(self.config, "%s_list" % self.config.data_properties[0]))
#         # print(len(getattr(self.config, "%s_list" % self.config.data_properties[0])))
#         property_token_with_global = property_token.clone()
#         property_token_with_global.masked_fill_(mask, len(getattr(self.config, "%s_dict" % self.config.data_properties[0])))
#         # print(property_token_with_global)
#         return property_token_with_global

#     @classmethod
#     def add_argparse_args(cls, parent_parser):
#         parent_parser = super(GPT2TimeConcatPropertyAddGlobal, cls).add_argparse_args(parent_parser)
#         parent_parser.add_argument('--train_global_prob', type=float, default=0.2)
#         return parent_parser

#     def forward(self, batch, batch_idx, reduce=True, return_rate=False, return_offset=False, mode="train"):
#         if mode == "train":
#             prop_token = self.random_set_global(batch[self.config.data_properties[0]])
#             batch[self.config.data_properties[0]] = prop_token
#             return super().forward(batch, batch_idx, reduce, return_rate, return_offset, mode)
#         else:
#             raise NotImplementedError
        
#         input_ids = batch["input_ids"]

#         prop_tok = batch[self.config.data_properties[0]].long() # [B]
#         batch["token_type_ids"] = prop_tok

#         if getattr(self.config, "zero_time", False):
#             batch["input_time"].fill_(0.)

#         if getattr(self.config, "set_time", None) is not None:
#             batch["input_time"].fill_(self.config.set_time) # set time bin as a constant.

#         logits = self.model(**batch).logits / self.config.temperature
#         logits = logits[:, :, :len(self.alphabet)]
        
#         if self.config.weight_loss_by_count:
#             loss_weight = batch.get('freq', None) * batch.get('bin_size', None)
#         else:
#             loss_weight = batch.get('freq', None)

#         labels = batch["labels"]

#         if getattr(self.config, "count_mse_loss", False): # self.config.count_mse_loss:
#             loss = self.count_mse_loss(logits, labels, batch["bin_size"], batch["freq"] * batch["bin_size"], reduce=reduce)
#         else:
#             loss = self.nll_loss(logits, labels, loss_weight=loss_weight, reduce=reduce, ignore_bos=(mode == "test"))
        
#         return loss, {}
