import torch
import pytorch_lightning as pl
from lightning_transformers.task.nlp.language_modeling import (
    LanguageModelingDataModule,
    LanguageModelingTransformer,
)
# from pytorch_lightning.pytorch.utilities import grad_norm
from lightning.pytorch.utilities import grad_norm
from lightning_transformers.task.nlp.translation import TranslationTransformer
from models.utils import calc_time_sample_prob
import transformers
from pytorch_lightning.utilities import rank_zero_warn
import torch.nn as nn
from torch.distributions.gamma import Gamma
from data.utils import discretize_time
from torch.nn import CrossEntropyLoss
from data.dms import DMSData
from data.constants import REF_RBD_SEQ
from models import register_model
import math, logging
from models.criterions.cross_entropy import CEWeightedLoss
from esm import pretrained
from typing import IO, Any, Callable, Dict, Optional, Tuple, Type, Union
from utils.args import str2bool
from transformers import GPT2LMHeadModel
from transformers import AutoConfig, PreTrainedTokenizerBase
from collections import defaultdict, namedtuple
import os
import time as time_time
from functorch import vmap
from models.gpt2_new import GPT2TimeModel
from copy import deepcopy
import torch.nn.functional as F
from torch.nn.functional import normalize
from collections import Counter

def power_iteration_simple(A, num_iterations=1000, tol=1e-6):
    """Estimate the dominant eigenpair of a single square matrix by power iteration.

    Args:
        A: ``[K, K]`` matrix.
        num_iterations: maximum number of power steps.
        tol: stop once successive normalized iterates differ (L2 norm) by less than ``tol``.

    Returns:
        ``(eigenvalue, v)``: a 0-d tensor holding the Rayleigh quotient ``v_prev^T A v_prev``
        of the last iterate fed into ``A``, and the unit-norm ``[K]`` eigenvector estimate.
        With ``num_iterations <= 0`` the random unit start vector and its Rayleigh quotient
        are returned instead of raising.
    """
    K = A.size(0)
    # Draw the start vector in A's dtype so non-float32 matrices can be multiplied.
    v = torch.randn(K).to(device=A.device, dtype=A.dtype).detach()
    v /= torch.norm(v)

    eigenvalue, v_new = None, v
    for _ in range(num_iterations):
        w = A @ v  # [K]
        eigenvalue = torch.sum(w * v, dim=-1)  # scalar Rayleigh quotient of v
        v_new = w / torch.norm(w)
        if torch.norm(v_new - v) < tol:
            break
        v = v_new

    if eigenvalue is None:  # no iteration ran
        eigenvalue = torch.sum((A @ v) * v, dim=-1)
    return eigenvalue, v_new
        
def power_iteration(A, num_iterations=100, tol=1e-4, power_gradient=False):
    """Estimate the dominant eigenpair of every matrix in a batch by power iteration.

    Args:
        A: ``[B, K, K]`` batch of matrices.
        num_iterations: maximum number of power steps.
        tol: stop when the norm (taken over the whole batch) of the change between
            successive normalized iterates falls below ``tol``.
        power_gradient: if True, autograd tracks every iteration. Otherwise the
            iterations run under ``torch.no_grad()`` (training was not stable with
            gradients through the loop) and only one final step is differentiable.

    Returns:
        ``(eigenvalue, v)``: ``eigenvalue`` is the ``[B, 1]`` Rayleigh quotient and
        ``v`` the unit-norm ``[B, K, 1]`` eigenvector estimate.
    """
    B, K = A.size(0), A.size(1)

    # Random unit start vectors, in A's dtype so non-float32 batches work with bmm.
    v = torch.randn(B, K).to(device=A.device, dtype=A.dtype).detach()
    v /= torch.norm(v, dim=-1, keepdim=True)
    v = v.unsqueeze(-1)  # [B, K, 1]

    def _step(vec):
        """One power step: Rayleigh quotient of ``vec`` and the normalized ``A @ vec``."""
        w = torch.bmm(A, vec)  # [B, K, 1]
        eigenvalue = torch.sum(w * vec, dim=1)  # [B, 1]
        return eigenvalue, w / torch.norm(w, dim=1, keepdim=True)

    if power_gradient:
        eigenvalue, v_new = None, v
        for _ in range(num_iterations):
            eigenvalue, v_new = _step(v)
            if torch.norm(v_new - v) < tol:
                break
            v = v_new
        if eigenvalue is None:  # num_iterations <= 0: take a single step
            eigenvalue, v_new = _step(v)
        return eigenvalue, v_new

    # Stop the gradient during the iteration; training would not be stable otherwise.
    with torch.no_grad():
        for _ in range(num_iterations):
            _, v_new = _step(v)
            if torch.norm(v_new - v) < tol:
                break
            v = v_new

    # Only the last step carries gradient. It normalizes by ||A v|| so the returned
    # eigenvector has unit norm.
    return _step(v)

def power_iteration_v2(A, num_iterations=1000, tol=1e-4):
    """Estimate the largest singular value of every matrix in a batch.

    Alternates ``v <- normalize(A^T u)`` and ``u <- normalize(A v)``. The spectral norm
    equals ``u^T A v`` when ``u`` and ``v`` are the first left and right singular
    vectors. The loop runs without gradient; only the final bilinear form is
    differentiable.

    Args:
        A: ``[B, K, K]`` batch of matrices.
        num_iterations: number of alternating steps (there is no early stopping).
        tol: unused; kept for signature compatibility.

    Returns:
        ``(sigma, v)``: ``sigma`` of shape ``[B, 1]`` and the unit-norm right vector
        ``v`` of shape ``[B, K, 1]``.
    """
    B, K = A.size(0), A.size(1)
    eps = 1e-8

    # v is drawn before u. The loop overwrites v, so its random value only matters
    # when num_iterations == 0. Both follow A's dtype so bmm works for float64 too.
    v = torch.randn(B, K, 1).to(device=A.device, dtype=A.dtype).detach()
    v /= torch.norm(v, dim=1, keepdim=True)  # [B, K, 1]
    u = torch.randn(B, K, 1).to(device=A.device, dtype=A.dtype).detach()
    u /= torch.norm(u, dim=1, keepdim=True)  # [B, K, 1]

    with torch.no_grad():
        for _ in range(num_iterations):
            v = normalize(torch.bmm(A.transpose(-2, -1), u), dim=1, eps=eps)
            u = normalize(torch.bmm(A, v), dim=1, eps=eps)

    sigma = torch.bmm(u.transpose(-2, -1), torch.bmm(A, v)).squeeze(-1)  # [B, 1]
    v = v / torch.norm(v, dim=1, keepdim=True)  # [B, K, 1]
    return sigma, v
            
def block_power_method(A, k, num_iters=500, tol=1e-3):
    """
    Approximate k dominant eigenpairs of a batch of matrices by block power (subspace) iteration.

    Each step multiplies the current orthonormal block by ``A`` and re-orthonormalizes it
    with a QR decomposition. Eigenvalue estimates are the diagonal of ``Q^T A Q``. The loop
    stops once the mean residual ``||A q_i - lambda_i q_i||`` drops below ``tol``.
    "Dominant" means largest in magnitude; for symmetric PSD inputs these are the top-k
    eigenvalues.

    Args:
    - A (torch.Tensor): Input matrices of size (B, n, n).
    - k (int): Number of eigenpairs to compute.
    - num_iters (int): Maximum number of iterations.
    - tol (float): Convergence tolerance on the mean residual norm.

    Returns:
    - eigenvalues (torch.Tensor): (B, k) eigenvalue estimates.
    - eigenvectors (torch.Tensor): (B, n, k) orthonormal eigenvector estimates.
    """
    B, n = A.size(0), A.size(1)
    eigenvectors = torch.randn(B, n, k, device=A.device, dtype=A.dtype)
    eigenvectors, _ = torch.linalg.qr(eigenvectors)  # QR columns are already unit-norm

    # A @ Q serves both the residual check and the next power step, so it is computed once.
    AQ = torch.bmm(A, eigenvectors)  # [B, n, k]
    for _ in range(num_iters):
        # Power step followed by re-orthonormalization.
        eigenvectors, _ = torch.linalg.qr(AQ)
        AQ = torch.bmm(A, eigenvectors)

        # Check for convergence.
        eigenvalues = torch.diagonal(torch.bmm(eigenvectors.transpose(-2, -1), AQ), dim1=1, dim2=2)  # [B, k]
        residual = torch.norm(AQ - eigenvectors * eigenvalues.unsqueeze(1), dim=1)  # [B, k]
        mean_residual = torch.mean(residual)
        logging.debug("block_power_method: mean residual %s", mean_residual)
        if mean_residual < tol:
            break

    eigenvalues = torch.diagonal(torch.bmm(eigenvectors.transpose(-2, -1), AQ), dim1=1, dim2=2)
    return eigenvalues, eigenvectors

def gram_schmidt(vectors):
    """Orthonormalize the rows of a (batch of) square matrices with classical Gram-Schmidt.

    Args:
        vectors: ``[B, K, K]`` batch or a single ``[K, K]`` matrix (treated as batch size 1).
            Row ``i`` is the ``i``-th input vector.

    Returns:
        ``[B, K, M]`` tensor whose columns are the orthonormal basis vectors (``M <= K``).
        A residual is dropped when its norm, taken over the whole batch, is <= 1e-10.
        If every residual is dropped, an empty ``[B, K, 0]`` tensor is returned.

    Raises:
        ValueError: if ``vectors`` is not a square 2-D or 3-D tensor.
    """
    if vectors.dim() == 2:
        vectors = vectors.unsqueeze(0)  # batch size == 1 -> [1, K, K]
    if vectors.dim() != 3 or vectors.size(1) != vectors.size(2):
        raise ValueError(
            f"gram_schmidt expects a square [K, K] or [B, K, K] tensor, got shape {tuple(vectors.shape)}")

    B, K = vectors.size(0), vectors.size(1)
    basis = []
    for i in range(K):
        v = vectors[:, i]  # [B, K]
        w = v - sum(torch.sum(v * b, dim=-1, keepdim=True) * b for b in basis)
        if torch.norm(w) > 1e-10:  # Avoid dividing by very small numbers
            basis.append(w / torch.norm(w, dim=-1, keepdim=True))

    if not basis:
        return vectors.new_zeros(B, K, 0)
    return torch.stack(basis, dim=-1)


# Result of GPT2TimeModelMultiHosts.forward: ``logits`` (None when no component index is
# supplied) plus ``info_dict`` holding optional auxiliary tensors.
GPTOutputs = namedtuple("GPTOutputs", "logits info_dict")

class GPT2TimeModelMultiHosts(transformers.GPT2LMHeadModel):
    """GPT-2 based continuous-time model with ``num_component`` coupled components (hosts).

    For every (batch, position, vocabulary) entry the model predicts:
    - a symmetric ``K x K`` rate matrix ``A`` (``K = num_component``);
    - an initial vector ``p0`` over the components.

    It then evaluates ``p(t) = exp(t * A) @ p0`` through the eigendecomposition
    ``A = Q diag(lambda) Q^T``. When ``argv`` contains the key ``config.data_property``,
    the entry for that component index is gathered and returned as per-token logits.
    """

    def __init__(self, config, num_component, symmetry=True, base_models=None, **args) -> None:
        super().__init__(config)
        self.num_component = num_component
        assert symmetry, "It is more complicated for non-symmetry matrix. Leave it as a future work"

        self.share_base = config.share_base
        self.output_layer_type = config.output_layer_type

        self.veigh = vmap(torch.linalg.eigh)
        self.config = config
        self.eps = config.min_rate_value  # lower clamp / noise scale, e.g. 1e-12
        self.inf = config.max_rate_value  # upper clamp, e.g. 1e5

        # Positivity transform applied to the raw rate-matrix entries (None if unknown).
        if config.pos_function == "softplus":
            self.pos_func = torch.nn.Softplus()
        elif config.pos_function == "sigmoid":
            self.pos_func = torch.nn.Sigmoid()
        elif config.pos_function == "relu":
            self.pos_func = torch.nn.ReLU()
        elif config.pos_function == "exp":
            self.pos_func = torch.exp
        elif config.pos_function == "abs":
            self.pos_func = torch.abs
        else:
            self.pos_func = None

        # Transform applied to the raw offset outputs to obtain initial component values.
        if config.offset_pos_function == "softmax":
            self.offset_pos_func = nn.Softmax(dim=-1)
        elif config.offset_pos_function == "softplus":
            self.offset_pos_func = torch.nn.Softplus()
        elif config.offset_pos_function == "relu":
            self.offset_pos_func = torch.nn.ReLU()
        elif config.offset_pos_function == "abs":
            self.offset_pos_func = torch.abs
        else:
            self.offset_pos_func = None

        self.apply_log_softmax = getattr(config, "apply_log_softmax", False)

        self.build_models(config, num_component, symmetry=symmetry, base_models=base_models, **args)

    def build_models(self, config, num_component, symmetry=True, base_models=None, **args):
        """Create the transformer backbones and the rate / offset output heads.

        The rate heads hold one module per upper-triangular (diagonal included) entry of
        the ``K x K`` rate matrix: ``K * (K + 1) // 2`` modules. The offset heads hold
        one module per component.
        """
        if config.share_base:
            # Share the GPT-2 Transformer layers
            if base_models is not None:
                self.trans_base = base_models["trans_base"]
                self.offsets_base = base_models["offsets_base"]
            else:
                self.trans_base = transformers.GPT2LMHeadModel(config)
                if self.config.transformer_offset:
                    self.offsets_base = transformers.GPT2LMHeadModel(config)
                else:
                    self.offsets_base = self.trans_base

            if config.output_layer_type == "linear":
                self.trans_heads = nn.ModuleList([nn.Linear(config.hidden_size, config.vocab_size) for _ in range(num_component * (num_component + 1) // 2)])
                self.offsets_heads = nn.ModuleList([nn.Linear(config.hidden_size, config.vocab_size) for _ in range(num_component)])
            elif config.output_layer_type == "gpt2":
                output_layer_config = args["output_layer_config"]
                self.trans_heads = nn.ModuleList([transformers.GPT2LMHeadModel(output_layer_config) for _ in range(num_component * (num_component + 1) // 2)])
                self.offsets_heads = nn.ModuleList([transformers.GPT2LMHeadModel(output_layer_config) for _ in range(num_component)])
        else:
            self.trans_rates = nn.ModuleList([transformers.GPT2LMHeadModel(config) for _ in range(num_component * (num_component + 1) // 2)])
            self.offsets = nn.ModuleList([transformers.GPT2LMHeadModel(config) for _ in range(num_component)])  # K

        if getattr(config, "add_trans_layer_norm", False):
            self.trans_layer_norm = nn.LayerNorm([num_component, num_component], elementwise_affine=False)

        if getattr(config, "add_geo_info", False):
            self.geo_feats = nn.Parameter(args["geo_info"], requires_grad=False)  # [n_countries, n_feats]
            self.dis_sigma = nn.Parameter(torch.tensor(1000.0), requires_grad=True)  # scalar
            self.dis_ratio = nn.Parameter(torch.tensor(0.1), requires_grad=True)  # scalar

    def get_initial_prob(self, input_ids, labels, attention_mask, cache_hidden_states=None):
        """Return the initial component vectors ``p0`` flattened to ``[-1, num_component]``.

        Each component has its own offset head (shared base) or its own GPT-2 model.
        Every output goes through ``offset_pos_func``, and the results are stacked on a
        trailing component axis.
        """
        prob_vectors = []
        if self.share_base:
            if cache_hidden_states is None:
                # NOTE(review): hidden states come from ``trans_base`` even when a separate
                # ``offsets_base`` was built — confirm this is intended.
                outputs = self.trans_base.forward(
                    input_ids=input_ids, labels=labels,
                    attention_mask=attention_mask, output_hidden_states=True)
                hidden_states = outputs.hidden_states[-1]
            else:
                hidden_states = cache_hidden_states

            for k in range(self.num_component):
                if self.output_layer_type == "linear":
                    offset = self.offsets_heads[k](hidden_states)
                elif self.output_layer_type == "gpt2":
                    offset = self.offsets_heads[k].forward(
                        inputs_embeds=hidden_states,
                        attention_mask=attention_mask,
                        output_hidden_states=True).logits
                else:
                    raise ValueError(f"Unknown output_layer_type: {self.output_layer_type}")
                prob_vectors.append(self.offset_pos_func(offset))
        else:
            for k in range(self.num_component):
                outputs = self.offsets[k].forward(
                    input_ids=input_ids, labels=labels,
                    attention_mask=attention_mask, output_hidden_states=True)
                prob_vectors.append(self.offset_pos_func(outputs.logits))  # logits: [B, L, V]

        prob_vectors = torch.stack(prob_vectors, dim=-1)  # [B, L, V, K]
        return prob_vectors.view(-1, self.num_component)

    def get_rates(self, k, input_ids, labels, attention_mask):
        """Return the raw logits of the ``k``-th per-entry rate model; ``None`` with a shared base."""
        if self.share_base:
            return None
        outputs = self.trans_rates[k].forward(
            input_ids=input_ids, labels=labels,
            attention_mask=attention_mask, output_hidden_states=True)
        return outputs.logits

    def get_trans_matrix(self, input_ids, labels, attention_mask, generation=False, **args):
        """Build symmetric rate matrices from the per-entry GPT-2 models and eigendecompose them.

        ``generation`` is accepted because ``forward`` passes it. Rates are always
        computed for every position here.

        Returns:
            ``(eig_value, eig_vector, rates_matrix)``: eigenvalues ``[N, K]``, eigenvectors
            ``[N, K, K]`` with ``N`` the number of flattened matrices, and the clamped rate
            matrices with two trailing ``K`` axes.
        """
        K = self.num_component
        rates_matrix = [[0] * K for _ in range(K)]
        for i in range(K):
            for j in range(i, K):
                # Row-major index of (i, j) within the upper triangle (diagonal included).
                k = ((2 * K - i + 1) * i) // 2 + j - i
                outputs = self.trans_rates[k].forward(
                    input_ids=input_ids, labels=labels,
                    attention_mask=attention_mask, output_hidden_states=True)
                rate = outputs.logits  # [B, L, V]
                rate = rate + torch.rand_like(rate) * self.eps  # To avoid an ill-defined A.
                rate = self.pos_func(rate)
                rates_matrix[i][j] = rate
                rates_matrix[j][i] = rate

        rates_matrix = torch.stack([r for row in rates_matrix for r in row], dim=-1)  # [B, L, V, K**2]
        rates_matrix = rates_matrix.view(*rates_matrix.shape[:3], K, K)
        assert torch.all(rates_matrix.transpose(-2, -1) == rates_matrix)

        # Clamp to avoid overflow before the eigendecomposition.
        rates_matrix = torch.clamp(rates_matrix, min=self.eps, max=self.inf)
        eig_value, eig_vector = torch.linalg.eigh(rates_matrix.view(-1, K, K))
        return eig_value, eig_vector, rates_matrix

    def get_trans_matrix_from_base(self, input_ids, labels, attention_mask, cache_hidden_states=None, generation=False, **args):
        """Build symmetric rate matrices from the shared base's hidden states and eigendecompose them.

        In generation mode only the last position's hidden state is used, and the
        resulting eigenpairs are repeated over all positions.

        Returns:
            ``(eig_value, eig_vector, rates_matrix)``:
            - eigenvalues ``[N, K']``;
            - eigenvectors ``[N, K, K']``, where ``K' = config.topk_eigen`` when set, else ``K``;
            - the clamped rate matrices with two trailing ``K`` axes.
        """
        K = self.num_component
        if cache_hidden_states is None:
            outputs = self.trans_base.forward(
                input_ids=input_ids, labels=labels,
                attention_mask=attention_mask, output_hidden_states=True)
            hidden_states = outputs.hidden_states[-1]
        else:
            hidden_states = cache_hidden_states

        if generation:
            seq_len = hidden_states.size(1)  # Real length of the sequences
            hidden_states = hidden_states[:, -1:, :]

        rates_matrix = [[0] * K for _ in range(K)]
        for i in range(K):
            for j in range(i, K):
                # Row-major index of (i, j) within the upper triangle (diagonal included).
                k = ((2 * K - i + 1) * i) // 2 + j - i
                if self.output_layer_type == "linear":
                    rate = self.trans_heads[k](hidden_states)
                elif self.output_layer_type == "gpt2":
                    rate = self.trans_heads[k].forward(
                        inputs_embeds=hidden_states,
                        attention_mask=attention_mask, output_hidden_states=True).logits
                else:
                    raise ValueError(f"Unknown output_layer_type: {self.output_layer_type}")
                rate = rate + torch.rand_like(rate) * self.eps  # To avoid an ill-defined A.
                rates_matrix[i][j] = rate
                rates_matrix[j][i] = rate

        rates_matrix = torch.stack([r for row in rates_matrix for r in row], dim=-1)  # [B, L, V, K**2]
        rates_matrix = rates_matrix.view(*rates_matrix.shape[:3], K, K)
        if getattr(self.config, "add_trans_layer_norm", False):
            rates_matrix = self.trans_layer_norm(rates_matrix)
        rates_matrix = self.pos_func(rates_matrix)

        if getattr(self.config, "add_geo_info", False):
            # dis_ratio * exp(-||x_i - x_j||^2 / dis_sigma): positive and symmetric.
            distance = torch.sum((self.geo_feats.unsqueeze(1) - self.geo_feats.unsqueeze(0)) ** 2, dim=-1)  # [K, K]
            geo_feats = self.dis_ratio * torch.exp(-distance / self.dis_sigma)
            # Out-of-place: exp/sigmoid save their output for backward, so an in-place add
            # would break autograd.
            rates_matrix = rates_matrix + geo_feats.view(1, 1, 1, *geo_feats.shape)

        assert torch.all(rates_matrix.transpose(-2, -1) == rates_matrix), rates_matrix[rates_matrix.transpose(-2, -1) != rates_matrix]

        # Clamp to avoid overflow before the eigendecomposition.
        rates_matrix = torch.clamp(rates_matrix, min=self.eps, max=self.inf)
        eig_value, eig_vector = torch.linalg.eigh(rates_matrix.view(-1, K, K))

        topk = getattr(self.config, "topk_eigen", None)
        if topk is not None:
            # eigh returns ascending eigenvalues, so the last ``topk`` are the largest.
            eig_value = eig_value[:, -topk:]
            eig_vector = eig_vector[:, :, -topk:]

        assert not torch.any(torch.isnan(eig_value))
        assert not torch.any(torch.isnan(eig_vector))

        if generation:
            # Broadcast the last-position eigenpairs over all seq_len positions.
            n_eig, n_rows = eig_value.size(-1), eig_vector.size(-2)
            V = self.config.vocab_size
            eig_value = eig_value.view(-1, 1, V, n_eig).repeat(1, seq_len, 1, 1).view(-1, n_eig)
            eig_vector = eig_vector.view(-1, 1, V, n_rows, n_eig).repeat(1, seq_len, 1, 1, 1).view(-1, n_rows, n_eig)
        return eig_value, eig_vector, rates_matrix

    @classmethod
    def from_config(cls, config, num_component, **args):
        """Alternate constructor mirroring ``cls(config, num_component, **args)``."""
        return cls(config, num_component, **args)

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs) -> Dict[str, Any]:
        """Extend the parent's generation inputs with this model's extra ``forward`` arguments.

        - The attention mask (if any) is dropped before delegating to the parent.
        - ``input_time`` and ``config.data_property`` are forwarded, and ``generation=True``
          is set.
        - With ``config.num_beams > 1``, both forwarded inputs are repeated once per beam.
        """
        kwargs.pop("attention_mask", None)
        ret_dict = super().prepare_inputs_for_generation(input_ids, past, **kwargs)
        ret_dict["input_time"] = kwargs["input_time"]
        ret_dict[self.config.data_property] = kwargs[self.config.data_property]
        ret_dict["generation"] = True

        if self.config.num_beams > 1:
            # Row indices [0]*num_beams + [1]*num_beams + ... to expand per-sequence inputs.
            bsz = ret_dict[self.config.data_property].size(0)
            expanded_return_idx = (
                torch.arange(bsz).view(-1, 1).repeat(1, self.config.num_beams).view(-1).to(input_ids.device)
            )
            ret_dict[self.config.data_property] = ret_dict[self.config.data_property].index_select(0, expanded_return_idx)
            ret_dict["input_time"] = ret_dict["input_time"].index_select(0, expanded_return_idx)

        return ret_dict

    def forward_cache_hidden_states(self, input_time, **argv):
        """Hook for subclasses to precompute hidden states; this base version returns no caches."""
        return {}

    @property
    def number_of_component(self):
        """Number of components ``K``."""
        return self.num_component

    def forward(self, input_time, **argv):
        """Compute per-token logits from the time-evolved component probabilities.

        Args:
            input_time: per-sequence times, converted with ``discretize_time`` and repeated
                once per beam (assumed 1-D with one entry per sequence — TODO confirm).
            **argv: must contain ``input_ids`` ``[B, L]``. May also contain:
                - ``labels`` and ``attention_mask``;
                - ``generation``;
                - precomputed ``trans_cache_hidden_states`` / ``offsets_cache_hidden_states``;
                - the component index under ``config.data_property``;
                - the ``return_rates_matrix`` / ``return_init_prob`` /
                  ``return_component_logits`` flags.

        Returns:
            ``GPTOutputs(logits, info_dict)``. ``logits`` is ``[B, L, V]`` when the
            component index is supplied, otherwise ``None``.
        """
        generation = argv.pop("generation", False)

        return_rates_matrix = argv.get("return_rates_matrix", False)
        return_init_prob = argv.get("return_init_prob", False)
        return_component_logits = argv.get("return_component_logits", False)

        caches = self.forward_cache_hidden_states(input_time, **argv)
        # Caches passed explicitly in ``argv`` override the computed ones.
        trans_cache_hidden_states = argv.get("trans_cache_hidden_states", caches.get("trans_cache_hidden_states"))
        offsets_cache_hidden_states = argv.get("offsets_cache_hidden_states", caches.get("offsets_cache_hidden_states"))

        info_dict = {}
        input_ids = argv.get("input_ids")
        labels = argv.get("labels")
        attention_mask = argv.get("attention_mask")

        time = discretize_time(
            input_time,
            one_step=False,
            normalize_time_a=self.config.normalize_time_a,
            normalize_time_b=self.config.normalize_time_b,
            discrete=False)
        # With beam search, input_ids has several rows per entry of input_time.
        beam_size = input_ids.size(0) // input_time.size(0)
        time = time.unsqueeze(1).repeat(1, beam_size).view(-1)

        B, L = input_ids.size()
        V = self.config.vocab_size

        # eig_value: [B*L*V, K'], eig_vecs: [B*L*V, K, K'] (K' <= K when topk_eigen is set)
        if self.share_base:
            eig_value, eig_vecs, rates_matrix = self.get_trans_matrix_from_base(
                input_ids, labels, attention_mask,
                cache_hidden_states=trans_cache_hidden_states, generation=generation)
        else:
            eig_value, eig_vecs, rates_matrix = self.get_trans_matrix(
                input_ids, labels, attention_mask, generation=generation)

        init_prob = self.get_initial_prob(
            input_ids, labels, attention_mask, cache_hidden_states=offsets_cache_hidden_states)  # [B*L*V, K]

        # Symmetric A = Q diag(lambda) Q^T  =>  p(t) = Q diag(exp(t * lambda)) Q^T p0.
        const = torch.bmm(eig_vecs.transpose(-2, -1), init_prob.unsqueeze(-1)).squeeze(-1)  # [B*L*V, K']
        time = time.reshape(-1, 1, 1).expand(-1, L, V).reshape(-1, 1)  # [B*L*V, 1]
        p = (const * torch.exp(time * eig_value)).unsqueeze(1) * eig_vecs  # [B*L*V, K, K']
        p = torch.sum(p, dim=-1)  # [B*L*V, K]

        p = p.view(B, -1, self.number_of_component)  # [B, L*V, K]
        p = torch.clamp(p, min=self.eps, max=self.inf)  # For numerical stability

        if return_component_logits:
            if not self.apply_log_softmax:
                info_dict["component_logits"] = torch.log(p).view(B, L, V, -1)
            else:
                info_dict["component_logits"] = torch.log_softmax(p.view(B, L, V, -1), dim=-2)

        if self.config.data_property not in argv:
            return GPTOutputs(logits=None, info_dict=info_dict)

        # Select the requested component for every (position, token) entry.
        host_label = argv[self.config.data_property].unsqueeze(1).repeat(1, p.size(1)).unsqueeze(-1).long()  # [B, L*V, 1]
        p = torch.gather(p, -1, host_label).squeeze(-1).view(B, L, -1)  # [B, L, V]

        if not self.apply_log_softmax:
            logits = torch.log(p)
        else:
            logits = torch.log_softmax(p, dim=-1)  # [B, L, V]

        if return_rates_matrix:
            info_dict["rates_matrix"] = rates_matrix.view(B, L, V, self.num_component, self.num_component)
        if return_init_prob:
            info_dict["init_prob"] = init_prob.view(B, L, V, self.num_component)  # [B, L, V, K]

        return GPTOutputs(logits=logits, info_dict=info_dict)

class GPT2TimeModelMultiHostsNew(GPT2TimeModelMultiHosts):
    """Multi-host time model with linear heads on a GPT-2 base.

    For each position and vocabulary entry the last hidden state is projected
    to a ``K x K`` matrix (``K = num_component``). The matrix is symmetrised
    (or made +/- semi-definite via ``A @ A^T``), passed through
    ``self.pos_func`` and clamped to ``[self.eps, self.inf]``.
    ``get_trans_matrix_from_base`` returns its eigen-decomposition and
    ``get_initial_prob`` the per-entry initial weights. ``forward`` is
    inherited from the parent class.
    """

    def __init__(self, config, num_component, symmetry=True, base_models=None, **args) -> None:
        super().__init__(config, num_component, symmetry, base_models, **args)
        # dim=-2 is the vocabulary axis of the [B*L, V, K] tensor built in
        # ``get_initial_prob``.
        if config.offset_pos_function == "softmax":
            self.offset_pos_func = nn.Softmax(dim=-2)

    def build_models(self, config, num_component, symmetry=True, base_models=None, **args):
        """Create (or reuse from ``base_models``) the GPT-2 bases and the output heads.

        The linear heads are only created when ``config.output_layer_type``
        is ``"linear"``. With ``config.add_geo_info`` set, ``args["geo_info"]``
        (a ``[n_countries, n_feats]`` tensor) is stored as a frozen parameter
        next to two learnable scalars used to build a distance-based bias.
        """
        if base_models is not None:
            self.trans_base = base_models["trans_base"]
            self.offsets_base = base_models["offsets_base"]
        else:
            self.trans_base = transformers.GPT2LMHeadModel(config)
            self.offsets_base = transformers.GPT2LMHeadModel(config)

        if config.output_layer_type == "linear":
            # One K x K matrix per vocabulary entry.
            self.trans_heads = nn.Linear(config.hidden_size, config.vocab_size * num_component * num_component)
            # One K-vector of initial weights per vocabulary entry.
            self.offsets_heads = nn.Linear(config.hidden_size, config.vocab_size * num_component)

        if getattr(config, "add_geo_info", False):
            self.geo_feats = nn.Parameter(args["geo_info"], requires_grad=False)  # [n_countries, n_feats]
            self.dis_sigma = nn.Parameter(torch.tensor(1000.0), requires_grad=True)  # scalar length scale
            self.dis_ratio = nn.Parameter(torch.tensor(0.1), requires_grad=True)  # scalar bias magnitude

    def get_initial_prob(self, input_ids, labels, attention_mask, cache_hidden_states=None):
        """Return the initial weights for every (position, vocab entry) as ``[B*L*V, K]``.

        Uses ``cache_hidden_states`` when given, otherwise runs a base model.
        """
        if cache_hidden_states is None:
            # NOTE(review): this fallback runs ``trans_base`` although an
            # ``offsets_base`` is built for this purpose -- confirm intended.
            outputs = self.trans_base.forward(input_ids=input_ids, labels=labels,
                                              attention_mask=attention_mask, output_hidden_states=True)
            hidden_states = outputs.hidden_states[-1]
        else:
            hidden_states = cache_hidden_states
        offset = self.offsets_heads(hidden_states).view(-1, self.config.vocab_size, self.num_component)  # [B*L, V, K]
        offset = self.offset_pos_func(offset)
        return offset.view(-1, self.num_component)  # [B*L*V, K]

    def get_trans_matrix_from_base(self, input_ids, labels, attention_mask, cache_hidden_states=None, generation=False, **args):
        """Predict per-token rate matrices and eigen-decompose them.

        Args:
            input_ids, labels, attention_mask: forwarded to ``self.trans_base``
                when ``cache_hidden_states`` is None.
            cache_hidden_states: precomputed last-layer hidden states
                (assumed ``[B, L, H]`` -- TODO confirm) used instead of
                running the base model.
            generation: if True, only the last position is decomposed and the
                eigen results are tiled back to all ``L`` positions.

        Returns:
            ``(eig_value, eig_vector, rates_matrix)`` shaped ``[B*L*V, K']``,
            ``[B*L*V, K, K']`` and ``[B*L*V, K, K]`` (``[B*V, K, K]`` in
            generation mode). ``K' == K`` with ``torch.linalg.eigh``. With
            ``config.power_iteration``, ``K'`` is ``config.topk_eigen_values``.
        """
        if cache_hidden_states is None:
            outputs = self.trans_base.forward(input_ids=input_ids, labels=labels,
                                              attention_mask=attention_mask, output_hidden_states=True)
            hidden_states = outputs.hidden_states[-1]
        else:
            hidden_states = cache_hidden_states

        if generation:
            seq_len = hidden_states.size(1)  # real sequence length, restored below
            hidden_states = hidden_states[:, -1:, :]

        K = self.num_component
        rates_matrix = self.trans_heads(hidden_states).view(-1, K, K)  # [B*L*V, K, K]

        if getattr(self.config, "positive_definite", False):
            rates_matrix = torch.bmm(rates_matrix, rates_matrix.transpose(-2, -1))  # A A^T
        elif getattr(self.config, "negative_definite", False):
            rates_matrix = -torch.bmm(rates_matrix, rates_matrix.transpose(-2, -1))  # -A A^T
        else:  # plain symmetrisation
            rates_matrix = (rates_matrix + rates_matrix.transpose(-2, -1)) / 2

        rates_matrix = self.pos_func(rates_matrix)
        rates_matrix = torch.clamp(rates_matrix, min=self.eps, max=self.inf)

        if getattr(self.config, "add_geo_info", False):
            # Pairwise squared Euclidean distance between country features.
            distance = torch.sum((self.geo_feats.unsqueeze(1) - self.geo_feats.unsqueeze(0)) ** 2, dim=-1)
            geo_bias = self.dis_ratio * torch.exp(-distance / self.dis_sigma)  # symmetric, positive
            # Broadcast over the leading batch axis (assumes n_countries == K).
            # Done out of place: a higher-rank operand cannot be broadcast
            # into this 3-D tensor in place.
            rates_matrix = rates_matrix + geo_bias.unsqueeze(0)

        assert torch.all(rates_matrix.transpose(-2, -1) == rates_matrix), \
            rates_matrix[rates_matrix.transpose(-2, -1) != rates_matrix]

        if getattr(self.config, "power_iteration", False):
            # Top-k eigenpairs by repeated power iteration + deflation.
            X = rates_matrix
            eig_vals = []
            eig_vecs = []
            for _ in range(self.config.topk_eigen_values):
                eig_val, eig_vec = power_iteration(X, num_iterations=getattr(self.config, "power_iteration_num", 100),
                                                   power_gradient=getattr(self.config, "power_gradient", False))  # eig_val: [B*L*V, 1], eig_vec: [B*L*V, K, 1]
                eig_vals.append(eig_val)
                eig_vecs.append(eig_vec)
                # Deflate: remove the component just found.
                X = X - eig_val.unsqueeze(-1) * torch.bmm(eig_vec, eig_vec.transpose(-2, -1))
            eig_value = torch.cat(eig_vals, dim=-1)  # [B*L*V, K']
            eig_vector = torch.cat(eig_vecs, dim=-1)  # [B*L*V, K, K']
        else:
            eig_value, eig_vector = torch.linalg.eigh(rates_matrix.view(-1, K, K))

        assert not torch.any(torch.isnan(eig_value))
        assert not torch.any(torch.isnan(eig_vector))

        if generation:
            # Tile the last-position decomposition to every position.
            eig_value = eig_value.view(-1, 1, self.config.vocab_size, eig_value.size(-1)) \
                .repeat(1, seq_len, 1, 1).view(-1, eig_value.size(-1))
            eig_vector = eig_vector.view(-1, 1, self.config.vocab_size, eig_vector.size(-2), eig_vector.size(-1)) \
                .repeat(1, seq_len, 1, 1, 1).view(-1, eig_vector.size(-2), eig_vector.size(-1))
        return eig_value, eig_vector, rates_matrix

class GPT2TimeModelMultiHostsNew2(GPT2TimeModelMultiHosts):
    """Multi-host time model whose rate matrices are not forced positive.

    Same heads as ``GPT2TimeModelMultiHostsNew``, but the symmetrised matrix
    is only clamped to ``[-self.inf, self.inf]`` (``self.pos_func`` is not
    applied). ``forward`` combines the eigen-decomposition with the initial
    weights to produce per-host log-probabilities.
    """

    def __init__(self, config, num_component, symmetry=True, base_models=None, **args) -> None:
        super().__init__(config, num_component, symmetry, base_models, **args)
        # dim=-2 is the vocabulary axis of the [B*L, V, K] tensor built in
        # ``get_initial_prob``.
        if config.offset_pos_function == "softmax":
            self.offset_pos_func = nn.Softmax(dim=-2)

    def build_models(self, config, num_component, symmetry=True, base_models=None, **args):
        """Create (or reuse from ``base_models``) the GPT-2 bases and the output heads.

        The linear heads are only created when ``config.output_layer_type``
        is ``"linear"``. With ``config.add_geo_info`` set, ``args["geo_info"]``
        (a ``[n_countries, n_feats]`` tensor) is stored as a frozen parameter
        next to two learnable scalars used to build a distance-based bias.
        """
        if base_models is not None:
            self.trans_base = base_models["trans_base"]
            self.offsets_base = base_models["offsets_base"]
        else:
            self.trans_base = transformers.GPT2LMHeadModel(config)
            self.offsets_base = transformers.GPT2LMHeadModel(config)

        if config.output_layer_type == "linear":
            # One K x K matrix per vocabulary entry.
            self.trans_heads = nn.Linear(config.hidden_size, config.vocab_size * num_component * num_component)
            # One K-vector of initial weights per vocabulary entry.
            self.offsets_heads = nn.Linear(config.hidden_size, config.vocab_size * num_component)

        if getattr(config, "add_geo_info", False):
            self.geo_feats = nn.Parameter(args["geo_info"], requires_grad=False)  # [n_countries, n_feats]
            self.dis_sigma = nn.Parameter(torch.tensor(1000.0), requires_grad=True)  # scalar length scale
            self.dis_ratio = nn.Parameter(torch.tensor(0.1), requires_grad=True)  # scalar bias magnitude

    def get_initial_prob(self, input_ids, labels, attention_mask, cache_hidden_states=None):
        """Return the initial weights for every (position, vocab entry) as ``[B*L*V, K]``.

        Uses ``cache_hidden_states`` when given, otherwise runs a base model.
        """
        if cache_hidden_states is None:
            # NOTE(review): this fallback runs ``trans_base`` although an
            # ``offsets_base`` is built for this purpose -- confirm intended.
            outputs = self.trans_base.forward(input_ids=input_ids, labels=labels,
                                              attention_mask=attention_mask, output_hidden_states=True)
            hidden_states = outputs.hidden_states[-1]
        else:
            hidden_states = cache_hidden_states
        offset = self.offsets_heads(hidden_states).view(-1, self.config.vocab_size, self.num_component)  # [B*L, V, K]
        offset = self.offset_pos_func(offset)
        return offset.view(-1, self.num_component)  # [B*L*V, K]

    def get_trans_matrix_from_base(self, input_ids, labels, attention_mask, cache_hidden_states=None, generation=False, **args):
        """Predict per-token symmetric matrices and eigen-decompose them.

        Args:
            input_ids, labels, attention_mask: forwarded to ``self.trans_base``
                when ``cache_hidden_states`` is None.
            cache_hidden_states: precomputed last-layer hidden states
                (assumed ``[B, L, H]`` -- TODO confirm).
            generation: if True, only the last position is decomposed and the
                eigen results are tiled back to all ``L`` positions.

        Returns:
            ``(eig_value, eig_vector, rates_matrix)`` shaped ``[B*L*V, K']``,
            ``[B*L*V, K, K']`` and ``[B*L*V, K, K]`` (``[B*V, K, K]`` in
            generation mode). ``K' == K`` with ``torch.linalg.eigh``. With
            ``config.power_iteration``, ``K'`` is ``config.topk_eigen_values``.
        """
        if cache_hidden_states is None:
            outputs = self.trans_base.forward(input_ids=input_ids, labels=labels,
                                              attention_mask=attention_mask, output_hidden_states=True)
            hidden_states = outputs.hidden_states[-1]
        else:
            hidden_states = cache_hidden_states

        if generation:
            seq_len = hidden_states.size(1)  # real sequence length, restored below
            hidden_states = hidden_states[:, -1:, :]

        K = self.num_component
        rates_matrix = self.trans_heads(hidden_states).view(-1, K, K)  # [B*L*V, K, K]

        if getattr(self.config, "positive_definite", False):
            rates_matrix = torch.bmm(rates_matrix, rates_matrix.transpose(-2, -1))  # A A^T
        elif getattr(self.config, "negative_definite", False):
            rates_matrix = -torch.bmm(rates_matrix, rates_matrix.transpose(-2, -1))  # -A A^T
        else:  # plain symmetrisation
            rates_matrix = (rates_matrix + rates_matrix.transpose(-2, -1)) / 2

        # No positivity transform here; only bound the magnitude.
        rates_matrix = torch.clamp(rates_matrix, min=-self.inf, max=self.inf)

        if getattr(self.config, "add_geo_info", False):
            # Pairwise squared Euclidean distance between country features.
            distance = torch.sum((self.geo_feats.unsqueeze(1) - self.geo_feats.unsqueeze(0)) ** 2, dim=-1)
            geo_bias = self.dis_ratio * torch.exp(-distance / self.dis_sigma)  # symmetric, positive
            # Broadcast over the leading batch axis (assumes n_countries == K).
            # Done out of place: a higher-rank operand cannot be broadcast
            # into this 3-D tensor in place.
            rates_matrix = rates_matrix + geo_bias.unsqueeze(0)

        assert torch.all(rates_matrix.transpose(-2, -1) == rates_matrix), \
            rates_matrix[rates_matrix.transpose(-2, -1) != rates_matrix]

        if getattr(self.config, "power_iteration", False):
            # Top-k eigenpairs by repeated power iteration + deflation.
            X = rates_matrix
            eig_vals = []
            eig_vecs = []
            for _ in range(self.config.topk_eigen_values):
                eig_val, eig_vec = power_iteration(X, num_iterations=getattr(self.config, "power_iteration_num", 100),
                                                   power_gradient=getattr(self.config, "power_gradient", False))  # eig_val: [B*L*V, 1], eig_vec: [B*L*V, K, 1]
                eig_vals.append(eig_val)
                eig_vecs.append(eig_vec)
                # Deflate: remove the component just found.
                X = X - eig_val.unsqueeze(-1) * torch.bmm(eig_vec, eig_vec.transpose(-2, -1))
            eig_value = torch.cat(eig_vals, dim=-1)  # [B*L*V, K']
            eig_vector = torch.cat(eig_vecs, dim=-1)  # [B*L*V, K, K']
        else:
            eig_value, eig_vector = torch.linalg.eigh(rates_matrix.view(-1, K, K))

        assert not torch.any(torch.isnan(eig_value))
        assert not torch.any(torch.isnan(eig_vector))

        if generation:
            # Tile the last-position decomposition to every position.
            eig_value = eig_value.view(-1, 1, self.config.vocab_size, eig_value.size(-1)) \
                .repeat(1, seq_len, 1, 1).view(-1, eig_value.size(-1))
            eig_vector = eig_vector.view(-1, 1, self.config.vocab_size, eig_vector.size(-2), eig_vector.size(-1)) \
                .repeat(1, seq_len, 1, 1, 1).view(-1, eig_vector.size(-2), eig_vector.size(-1))
        return eig_value, eig_vector, rates_matrix

    def forward(self, input_time, **argv):
        """Compute per-host log-probabilities at ``input_time``.

        With ``(lambda, v)`` from ``get_trans_matrix_from_base`` and ``pi`` from
        ``get_initial_prob``, each entry is
        ``p(t) = sum_k' <v_k', pi> * exp(t * lambda_k') * v_k'``. The result is
        passed through ``self.pos_func`` and clamped to
        ``[self.eps, self.inf]``. The component selected by
        ``argv[config.data_property]`` is then gathered.

        Returns:
            ``GPTOutputs(logits, info_dict)``. ``logits`` is ``[B, L, V]`` when
            ``config.data_property`` is in ``argv``, otherwise None.
        """
        generation = argv.pop("generation", False)

        return_rates_matrix = argv.get("return_rates_matrix", False)
        return_init_prob = argv.get("return_init_prob", False)
        return_component_logits = argv.get("return_component_logits", False)

        caches = self.forward_cache_hidden_states(input_time, **argv)
        trans_cache_hidden_states = caches.get("trans_cache_hidden_states")
        offsets_cache_hidden_states = caches.get("offsets_cache_hidden_states")
        # Caller-supplied hidden states take precedence over the computed ones.
        if "trans_cache_hidden_states" in argv:
            trans_cache_hidden_states = argv.get("trans_cache_hidden_states")
        if "offsets_cache_hidden_states" in argv:
            offsets_cache_hidden_states = argv.get("offsets_cache_hidden_states")

        info_dict = {}

        time = discretize_time(
            input_time,
            one_step=False,
            normalize_time_a=self.config.normalize_time_a,
            normalize_time_b=self.config.normalize_time_b,
            discrete=False)
        # input_ids may hold several beams per time value; repeat times to match.
        beam_size = argv.get("input_ids").size(0) // input_time.size(0)
        time = time.unsqueeze(1).repeat(1, beam_size).view(-1)

        B, L = argv.get("input_ids").size()
        V = self.config.vocab_size

        # eig_value: [B*L*V, K'], eig_vecs: [B*L*V, K, K']
        if self.share_base:
            eig_value, eig_vecs, rates_matrix = self.get_trans_matrix_from_base(
                argv.get("input_ids"), argv.get("labels"), argv.get("attention_mask"),
                cache_hidden_states=trans_cache_hidden_states, generation=generation)
        else:
            eig_value, eig_vecs, rates_matrix = self.get_trans_matrix(
                argv.get("input_ids"), argv.get("labels"), argv.get("attention_mask"), generation=generation)

        init_prob = self.get_initial_prob(argv.get("input_ids"), argv.get("labels"), argv.get("attention_mask"),
                                          cache_hidden_states=offsets_cache_hidden_states)  # [B*L*V, K]

        # Project the initial weights onto the eigenbasis.
        const = torch.bmm(eig_vecs.transpose(-2, -1), init_prob.unsqueeze(-1)).squeeze(-1)  # [B*L*V, K']
        time = time.reshape(-1, 1, 1).expand(-1, L, V).reshape(-1, 1)  # [B*L*V, 1]

        p = (const * torch.exp(time * eig_value)).unsqueeze(1) * eig_vecs  # [B*L*V, K, K']
        p = torch.sum(p, dim=-1)  # [B*L*V, K]
        # Use the tensor's own component size rather than a separately named attribute.
        p = p.view(B, -1, p.size(-1))  # [B, L*V, K]
        p = self.pos_func(p)  # make sure p is positive
        p = torch.clamp(p, min=self.eps, max=self.inf)  # numerical stability

        if return_component_logits:
            if not self.apply_log_softmax:
                info_dict["component_logits"] = torch.log(p).view(B, L, V, -1)
            else:
                # Normalise over the vocabulary axis.
                info_dict["component_logits"] = torch.log_softmax(p.view(B, L, V, -1), dim=-2)

        if self.config.data_property in argv:
            # Select, for every sample, the component given by its host label.
            host_label = argv[self.config.data_property].unsqueeze(1).repeat(1, p.size(1)).unsqueeze(-1).long()  # [B, L*V, 1]
            p = torch.gather(p, -1, host_label).squeeze(-1).view(B, L, -1)  # [B, L, V]

            if not self.apply_log_softmax:
                logits = torch.log(p)
            else:
                logits = torch.log_softmax(p, dim=-1)  # [B, L, V]

            if return_rates_matrix:
                info_dict["rates_matrix"] = rates_matrix.view(B, L, V, self.num_component, self.num_component)

            if return_init_prob:
                info_dict["init_prob"] = init_prob.view(B, L, V, self.num_component)  # [B, L, V, K]
        else:
            logits = None

        return GPTOutputs(logits=logits, info_dict=info_dict)


class GPT2TimeModelMultiHostsMatrixExpSample(GPT2TimeModelMultiHosts):
    """Multi-host time model that approximates ``exp(t * Q)`` by sampling.

    Instead of an eigen-decomposition, the matrix exponential of the
    per-token rate matrix ``Q`` is estimated by Poisson importance sampling of
    its Taylor series:
    ``exp(tQ) = E_{n ~ Poisson(lambda)}[ e^lambda * (tQ)^n / lambda^n ]``.
    """

    def __init__(self, config, num_component, symmetry=True, base_models=None, **args) -> None:
        super().__init__(config, num_component, symmetry, base_models, **args)
        # dim=-2 is the vocabulary axis of the [B*L, V, K] tensor built in
        # ``get_initial_prob``.
        if config.offset_pos_function == "softmax":
            self.offset_pos_func = nn.Softmax(dim=-2)

    def build_models(self, config, num_component, symmetry=True, base_models=None, **args):
        """Create (or reuse from ``base_models``) the GPT-2 bases and the output heads.

        The linear heads are only created when ``config.output_layer_type``
        is ``"linear"``. With ``config.add_geo_info`` set, ``args["geo_info"]``
        (a ``[n_countries, n_feats]`` tensor) is stored as a frozen parameter
        next to two learnable scalars used to build a distance-based bias.
        """
        if base_models is not None:
            self.trans_base = base_models["trans_base"]
            self.offsets_base = base_models["offsets_base"]
        else:
            self.trans_base = transformers.GPT2LMHeadModel(config)
            self.offsets_base = transformers.GPT2LMHeadModel(config)

        if config.output_layer_type == "linear":
            # One K x K matrix per vocabulary entry.
            self.trans_heads = nn.Linear(config.hidden_size, config.vocab_size * num_component * num_component)
            # One K-vector of initial weights per vocabulary entry.
            self.offsets_heads = nn.Linear(config.hidden_size, config.vocab_size * num_component)

        if getattr(config, "add_geo_info", False):
            self.geo_feats = nn.Parameter(args["geo_info"], requires_grad=False)  # [n_countries, n_feats]
            self.dis_sigma = nn.Parameter(torch.tensor(1000.0), requires_grad=True)  # scalar length scale
            self.dis_ratio = nn.Parameter(torch.tensor(0.1), requires_grad=True)  # scalar bias magnitude

    def get_initial_prob(self, input_ids, labels, attention_mask, cache_hidden_states=None):
        """Return the initial weights for every (position, vocab entry) as ``[B*L*V, K]``.

        Uses ``cache_hidden_states`` when given, otherwise runs a base model.
        """
        if cache_hidden_states is None:
            # NOTE(review): this fallback runs ``trans_base`` although an
            # ``offsets_base`` is built for this purpose -- confirm intended.
            outputs = self.trans_base.forward(input_ids=input_ids, labels=labels,
                                              attention_mask=attention_mask, output_hidden_states=True)
            hidden_states = outputs.hidden_states[-1]
        else:
            hidden_states = cache_hidden_states
        offset = self.offsets_heads(hidden_states).view(-1, self.config.vocab_size, self.num_component)  # [B*L, V, K]
        offset = self.offset_pos_func(offset)
        return offset.view(-1, self.num_component)  # [B*L*V, K]

    def get_trans_matrix_from_base(self, input_ids, labels, attention_mask, cache_hidden_states=None, generation=False, **args):
        """Predict the per-token rate matrices ``[B*L*V, K, K]``.

        Args:
            input_ids, labels, attention_mask: forwarded to ``self.trans_base``
                when ``cache_hidden_states`` is None.
            cache_hidden_states: precomputed last-layer hidden states
                (assumed ``[B, L, H]`` -- TODO confirm).
            generation: if True, only the last position is computed and the
                result is tiled back to all ``L`` positions, so the output
                always has ``B*L*V`` rows (as ``forward`` expects).
        """
        if cache_hidden_states is None:
            outputs = self.trans_base.forward(input_ids=input_ids, labels=labels,
                                              attention_mask=attention_mask, output_hidden_states=True)
            hidden_states = outputs.hidden_states[-1]
        else:
            hidden_states = cache_hidden_states

        if generation:
            seq_len = hidden_states.size(1)  # real sequence length, restored below
            hidden_states = hidden_states[:, -1:, :]

        K = self.num_component
        rates_matrix = self.trans_heads(hidden_states).view(-1, K, K)  # [B*L*V, K, K]

        rates_matrix = self.pos_func(rates_matrix)
        rates_matrix = torch.clamp(rates_matrix, min=-self.inf, max=self.inf)

        if getattr(self.config, "add_geo_info", False):
            # Pairwise squared Euclidean distance between country features.
            distance = torch.sum((self.geo_feats.unsqueeze(1) - self.geo_feats.unsqueeze(0)) ** 2, dim=-1)
            geo_bias = self.dis_ratio * torch.exp(-distance / self.dis_sigma)  # symmetric, positive
            # Broadcast over the leading batch axis (assumes n_countries == K).
            # Done out of place: a higher-rank operand cannot be broadcast
            # into this 3-D tensor in place.
            rates_matrix = rates_matrix + geo_bias.unsqueeze(0)

        if generation:
            # Tile the last-position matrices to every position so that the
            # row count matches ``time`` and ``init_prob`` in ``forward``.
            rates_matrix = rates_matrix.view(-1, 1, self.config.vocab_size, K, K) \
                .repeat(1, seq_len, 1, 1, 1).view(-1, K, K)

        return rates_matrix

    def forward(self, input_time, **argv):
        """Compute per-host log-probabilities at ``input_time``.

        ``p = exp(tQ) @ pi`` where ``pi`` comes from ``get_initial_prob`` and
        ``exp(tQ)`` is the Poisson importance-sampling estimate described on
        the class. ``p`` is clamped to ``[self.eps, self.inf]``. The
        component selected by ``argv[config.data_property]`` is then gathered.

        Returns:
            ``GPTOutputs(logits, info_dict)``. ``logits`` is ``[B, L, V]`` when
            ``config.data_property`` is in ``argv``, otherwise None.

        Raises:
            ValueError: if ``config.poisson_sample_num`` is not positive.
        """
        generation = argv.pop("generation", False)

        return_rates_matrix = argv.get("return_rates_matrix", False)
        return_init_prob = argv.get("return_init_prob", False)
        return_component_logits = argv.get("return_component_logits", False)

        caches = self.forward_cache_hidden_states(input_time, **argv)
        trans_cache_hidden_states = caches.get("trans_cache_hidden_states")
        offsets_cache_hidden_states = caches.get("offsets_cache_hidden_states")
        # Caller-supplied hidden states take precedence over the computed ones.
        if "trans_cache_hidden_states" in argv:
            trans_cache_hidden_states = argv.get("trans_cache_hidden_states")
        if "offsets_cache_hidden_states" in argv:
            offsets_cache_hidden_states = argv.get("offsets_cache_hidden_states")

        info_dict = {}

        time = discretize_time(
            input_time,
            one_step=False,
            normalize_time_a=self.config.normalize_time_a,
            normalize_time_b=self.config.normalize_time_b,
            discrete=False)
        # input_ids may hold several beams per time value; repeat times to match.
        beam_size = argv.get("input_ids").size(0) // input_time.size(0)
        time = time.unsqueeze(1).repeat(1, beam_size).view(-1)

        B, L = argv.get("input_ids").size()
        V = self.config.vocab_size

        if self.share_base:
            rates_matrix = self.get_trans_matrix_from_base(
                argv.get("input_ids"), argv.get("labels"), argv.get("attention_mask"),
                cache_hidden_states=trans_cache_hidden_states, generation=generation)  # [B*L*V, K, K]
        else:
            rates_matrix = self.get_trans_matrix(
                argv.get("input_ids"), argv.get("labels"), argv.get("attention_mask"), generation=generation)

        init_prob = self.get_initial_prob(argv.get("input_ids"), argv.get("labels"), argv.get("attention_mask"),
                                          cache_hidden_states=offsets_cache_hidden_states)  # [B*L*V, K]

        time = time.reshape(-1, 1, 1).expand(-1, L, V).reshape(-1, 1, 1)  # [B*L*V, 1, 1]

        lbd = self.config.poisson_lambda
        sample_number = int(self.config.poisson_sample_num)
        if sample_number <= 0:
            raise ValueError(f"poisson_sample_num must be positive, got {sample_number}")
        poisson = torch.distributions.poisson.Poisson(torch.tensor([lbd]))
        sampled_k = poisson.sample(sample_shape=(sample_number,)).squeeze(-1).tolist()

        # Monte-Carlo mean of e^lambda * (tQ)^n / lambda^n over the sampled n.
        # Repeated n are grouped so each matrix power is computed once. Each
        # group carries weight count / sample_number, so the estimate is an
        # average rather than a sum.
        terms = []
        for n, count in Counter(sampled_k).most_common():
            n = int(n)
            weight = math.exp(lbd) / (lbd ** n) * count / sample_number
            terms.append(weight * torch.linalg.matrix_power(rates_matrix, n) * (time ** n))  # [B*L*V, K, K]
        rates_matrix_exp = sum(terms)

        p = torch.bmm(rates_matrix_exp, init_prob.unsqueeze(-1)).squeeze(-1)  # [B*L*V, K]
        p = torch.clamp(p, min=self.eps, max=self.inf)  # numerical stability

        if return_component_logits:
            if not self.apply_log_softmax:
                info_dict["component_logits"] = torch.log(p).view(B, L, V, -1)
            else:
                # Normalise over the vocabulary axis.
                info_dict["component_logits"] = torch.log_softmax(p.view(B, L, V, -1), dim=-2)

        if self.config.data_property in argv:
            p = p.view(B, -1, p.size(-1))  # [B, L*V, K]
            # Select, for every sample, the component given by its host label.
            host_label = argv[self.config.data_property].unsqueeze(1).repeat(1, p.size(1)).unsqueeze(-1).long()  # [B, L*V, 1]
            p = torch.gather(p, -1, host_label).squeeze(-1).view(B, L, -1)  # [B, L, V]

            if not self.apply_log_softmax:
                logits = torch.log(p)
            else:
                logits = torch.log_softmax(p, dim=-1)  # [B, L, V]

            if return_rates_matrix:
                info_dict["rates_matrix"] = rates_matrix.view(B, L, V, self.num_component, self.num_component)

            if return_init_prob:
                info_dict["init_prob"] = init_prob.view(B, L, V, self.num_component)  # [B, L, V, K]
        else:
            logits = None

        return GPTOutputs(logits=logits, info_dict=info_dict)


class GPT2TimeModelMultiHostsGlobal(GPT2TimeModelMultiHosts):
    """Single-component variant that ignores the per-sample host label.

    The parent is always built with exactly one component.
    ``num_component`` is accepted only for signature compatibility and is
    ignored. In ``forward`` every host label is replaced by 0, so all
    samples read that single component.
    """

    def __init__(self, config, num_component, symmetry=True, base_models=None, **args) -> None:
        # Force one shared component regardless of the requested count.
        super().__init__(config, 1, symmetry, base_models, **args)

    def forward(self, input_time, **argv):
        """Zero out the host label (if present) and delegate to the parent."""
        label_key = self.config.data_property
        if label_key in argv:
            # Same shape/device as the original labels, dtype long, all zeros.
            argv[label_key] = torch.zeros_like(argv[label_key], dtype=torch.long)
        return super().forward(input_time, **argv)

class GPT2TimeModelMultiHostsDiag(GPT2TimeModelMultiHosts):
    """Variant of ``GPT2TimeModelMultiHosts`` whose rate matrices are diagonal.

    Each output head produces one positive rate per (position, token). With a
    diagonal rate matrix the eigen-values are the rates themselves and the
    eigen-vectors are the identity, so no decomposition is needed.
    """

    def __init__(self, config, num_component, symmetry=True, base_models=None, **args) -> None:
        super().__init__(config, num_component, symmetry, base_models, **args)

    def build_models(self, config, num_component, symmetry=True, base_models=None, **args):
        """Create the transformer backbones and the per-component output heads.

        With ``config.share_base`` a single GPT-2 backbone is used for the rates
        (and one for the offsets), taken from ``base_models`` when given. One
        head per component is stacked on top; the head type is selected by
        ``config.output_layer_type`` ("linear" or "gpt2"). Without
        ``share_base``, every component gets its own full GPT-2 model.
        """
        if not config.share_base:
            self.trans_rates = nn.ModuleList([transformers.GPT2LMHeadModel(config) for _ in range(num_component)])
            self.offsets = nn.ModuleList([transformers.GPT2LMHeadModel(config) for _ in range(num_component)])  # K
            return

        # Shared GPT-2 Transformer layers.
        if base_models is None:
            self.trans_base = transformers.GPT2LMHeadModel(config)
            self.offsets_base = transformers.GPT2LMHeadModel(config)
        else:
            self.trans_base = base_models["trans_base"]
            self.offsets_base = base_models["offsets_base"]

        if config.output_layer_type == "linear":
            def make_head():
                return nn.Linear(config.hidden_size, config.vocab_size)
        elif config.output_layer_type == "gpt2":
            head_config = args["output_layer_config"]

            def make_head():
                return transformers.GPT2LMHeadModel(head_config)
        else:
            # Unknown head type: no heads are created.
            return

        self.trans_heads = nn.ModuleList([make_head() for _ in range(num_component)])
        self.offsets_heads = nn.ModuleList([make_head() for _ in range(num_component)])

    def get_trans_matrix_from_base(self, input_ids, labels, attention_mask, cache_hidden_states=None, generation=False):
        """Compute diagonal rates and return them in eigen-decomposed form.

        Returns ``(eig_value, eig_vector, rates_matrix)``:
        - ``eig_value``: ``[N, K]`` positive rates (the diagonal entries);
        - ``eig_vector``: ``[N, K, K]`` identity matrices;
        - ``rates_matrix``: the stacked rates ``[B, L', V, K]``.

        When ``generation`` is set, only the last position is scored. The
        eigen tensors are then tiled back to the full sequence length, but
        ``rates_matrix`` keeps only that last position.
        """
        if cache_hidden_states is not None:
            hidden_states = cache_hidden_states
        else:
            hidden_states = self.trans_base.forward(input_ids = input_ids, labels = labels, \
                    attention_mask = attention_mask, output_hidden_states=True).hidden_states[-1]

        seq_len = hidden_states.size(1)  # real sequence length before any slicing
        if generation:
            hidden_states = hidden_states[:, -1:, :]

        num_k = self.num_component
        per_component = []
        for k in range(num_k):
            if self.output_layer_type == "linear":
                raw = self.trans_heads[k](hidden_states)
            elif self.output_layer_type == "gpt2":
                raw = self.trans_heads[k].forward(inputs_embeds = hidden_states,\
                    attention_mask = attention_mask, output_hidden_states=True).logits
            # Tiny random jitter keeps the rates distinct (avoids an ill-defined A).
            jitter = torch.rand(raw.size()).to(raw.device) * self.eps
            per_component.append(self.pos_func(raw + jitter))

        rates_matrix = torch.stack(per_component, dim=-1)  # [B, L', V, K]

        # Diagonal matrix: eigen-values are the rates, eigen-vectors the identity.
        eig_value = rates_matrix.view(-1, num_k)
        identity = torch.eye(num_k, num_k, device=eig_value.device)
        eig_vector = identity.view(1, num_k, num_k).repeat(eig_value.size(0), 1, 1)

        if generation:
            vocab = self.config.vocab_size
            eig_value = eig_value.view(-1, 1, vocab, num_k).repeat(1, seq_len, 1, 1).view(-1, num_k)
            eig_vector = eig_vector.view(-1, 1, vocab, num_k, num_k).repeat(1, seq_len, 1, 1, 1).view(-1, num_k, num_k)

        return eig_value, eig_vector, rates_matrix

class GPT2TimeModelMultiHostsV2(transformers.GPT2LMHeadModel):
    """Multi-host GPT-2 that predicts the spectral factors of each rate matrix directly.

    Different from V1, the eigen-vectors and eigen-values (and the coefficients
    ``c``) are produced by output heads instead of by a decomposition. For every
    (position, vocabulary entry) the model predicts K eigen-values ``lambda``,
    a K x K matrix ``U`` whose columns are unit-normalized, and K coefficients
    ``c``. Component j at time t is

        p_j(t) = sum_i c_i * exp(t * lambda_i) * U[j, i]

    This value may be negative, so ``pos_func`` (softplus) makes it positive
    before the log.
    """

    def __init__(self, config, num_component, symmetry=True, base_models=None, **args) -> None:
        super().__init__(config)
        self.num_component = num_component
        assert symmetry, "It is more complicated for non-symmetry matrix. Leave it as a future work"

        self.share_base = config.share_base

        # K * K matrix, but only pass one GPT2 for each component.
        if config.share_base:
            if base_models is not None:
                self.trans_base = base_models["trans_base"]
                self.offsets_base = base_models["offsets_base"]
            else:
                self.trans_base = transformers.GPT2LMHeadModel(config)
                self.offsets_base = transformers.GPT2LMHeadModel(config)

            # Head i*K + j produces entry U[j, i] (see get_trans_matrix).
            self.eig_vecs_out = nn.ModuleList([nn.Linear(config.hidden_size, config.vocab_size) for _ in range(num_component ** 2)])
            self.eig_value_out = nn.ModuleList([nn.Linear(config.hidden_size, config.vocab_size) for _ in range(num_component)])

            self.offsets_out = nn.ModuleList([nn.Linear(config.hidden_size, config.vocab_size) for _ in range(num_component)])  # K
        else:
            self.eig_vecs_base = nn.ModuleList([transformers.GPT2LMHeadModel(config) for _ in range(num_component)])
            self.eig_vecs_out = nn.ModuleList([nn.Linear(config.hidden_size, config.vocab_size) for _ in range(num_component ** 2)])
            self.eig_value = nn.ModuleList([transformers.GPT2LMHeadModel(config) for _ in range(num_component)])
            self.offsets = nn.ModuleList([transformers.GPT2LMHeadModel(config) for _ in range(num_component)])  # K
        self.config = config
        self.eps = 1e-12
        self.inf = 1e5
        self.pos_func = torch.nn.Softplus()

    def get_initial_prob(self, input_ids, labels, attention_mask, cache_hidden_states=None):
        """Return the per-component coefficients, flattened to ``[B*L*V, K]``.

        ``cache_hidden_states`` (shared-base mode only) replaces the forward pass
        of ``offsets_base``.
        """
        prob_vectors = []
        if self.share_base:
            if cache_hidden_states is None:
                outputs = self.offsets_base.forward(input_ids = input_ids, labels = labels, \
                            attention_mask = attention_mask, output_hidden_states=True)
                hidden_states = outputs.hidden_states[-1]
            else:
                hidden_states = cache_hidden_states

            for k in range(self.num_component):
                prob_vectors.append(self.offsets_out[k].forward(hidden_states))
        else:
            for k in range(self.num_component):
                outputs = self.offsets[k].forward(input_ids = input_ids, labels = labels, \
                        attention_mask = attention_mask, output_hidden_states=True)
                offset = outputs.logits  # [B, L, V]
                prob_vectors.append(offset)
        prob_vectors = torch.stack(prob_vectors, dim=-1)  # [B, L, V, K]
        return prob_vectors.view(-1, self.num_component)

    def get_trans_matrix(self, input_ids, labels, attention_mask, cache_hidden_states=None):
        """Return ``(eig_values [B*L*V, K], eig_vectors [B*L*V, K, K])``.

        Each column of the eigen-vector matrix is normalized to unit L2 norm.
        ``cache_hidden_states`` (shared-base mode only) replaces the forward pass
        of ``trans_base``.
        """
        if self.share_base:
            if cache_hidden_states is None:
                outputs = self.trans_base.forward(input_ids = input_ids, labels = labels, \
                        attention_mask = attention_mask, output_hidden_states=True)
                hidden_states = outputs.hidden_states[-1]
            else:
                hidden_states = cache_hidden_states

            eig_vector = [[0 for _ in range(self.num_component)] for _ in range(self.num_component)]
            for i in range(self.num_component):
                for j in range(self.num_component):
                    rate = self.eig_vecs_out[i * self.num_component + j].forward(hidden_states)
                    eig_vector[j][i] = rate
        else:
            # each column is a eigenvector!
            eig_vector = [[0 for _ in range(self.num_component)] for _ in range(self.num_component)]
            for i in range(self.num_component):
                outputs = self.eig_vecs_base[i].forward(input_ids = input_ids, labels = labels, \
                        attention_mask = attention_mask, output_hidden_states=True)
                hidden_states = outputs.hidden_states[-1]
                for j in range(self.num_component):
                    rate = self.eig_vecs_out[i * self.num_component + j].forward(hidden_states)
                    eig_vector[j][i] = rate
        # Row-major flatten of eig_vector[j][i] so the view below yields U[..., j, i].
        eig_vector = [item for row in eig_vector for item in row]
        eig_vector = torch.stack(eig_vector, dim=-1)  # [B, L, V, K**2]
        eig_vector = eig_vector.view(eig_vector.size(0), eig_vector.size(1), eig_vector.size(2), self.num_component, self.num_component)
        # [B, L, V, K, K]; normalize each column (sum over the row index j).
        eig_vector = eig_vector / torch.sum(eig_vector ** 2, dim=-2, keepdims=True).sqrt()

        eig_values = []
        if self.share_base:
            for i in range(self.num_component):
                eig_values.append(self.eig_value_out[i].forward(hidden_states))  # [B, L, V]
        else:
            for i in range(self.num_component):
                outputs = self.eig_value[i].forward(input_ids = input_ids, labels = labels, \
                        attention_mask = attention_mask, output_hidden_states=True)
                eig_values.append(outputs.logits)  # [B, L, V]
        eig_values = torch.stack(eig_values, dim=-1)  # [B, L, V, K]

        return eig_values.view(-1, self.num_component), eig_vector.view(-1, self.num_component, self.num_component)

    @classmethod
    def from_config(cls, config, num_component):
        model = cls(config, num_component)
        return model

    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
        return {"input_ids": input_ids, "input_time": kwargs["input_time"], self.config.data_property: kwargs[self.config.data_property]}

    def forward(self, input_time, **argv):
        """Compute host-specific log-space scores ``[B, L, V]``.

        If ``return_component_logits`` is set, ``info_dict["component_logits"]``
        holds the same log-space transform for every component ``[B, L, V, K]``.
        Selecting host h from it reproduces ``logits``.
        """
        info_dict = {}
        return_component_logits = argv.get("return_component_logits", False)
        trans_cache_hidden_states = argv.get("trans_cache_hidden_states")
        offsets_cache_hidden_states = argv.get("offsets_cache_hidden_states")

        time = discretize_time(
            input_time,
            one_step=False,
            normalize_time_a=self.config.normalize_time_a,
            normalize_time_b=self.config.normalize_time_b,
            discrete=False)
        # input_ids may hold beam_size copies of each sample; replicate the times to match.
        beam_size = argv.get("input_ids").size(0) // input_time.size(0)
        time = time.unsqueeze(1).repeat(1, beam_size).view(-1)

        B, L = argv.get("input_ids").size()
        V = self.config.vocab_size

        # [B*L*V, K], [B*L*V,K,K]
        eig_value, eig_vecs = self.get_trans_matrix(argv.get("input_ids"), argv.get("labels"), argv.get("attention_mask"), cache_hidden_states=trans_cache_hidden_states)
        # [B*L*V, K]
        init_prob = self.get_initial_prob(argv.get("input_ids"), argv.get("labels"), argv.get("attention_mask"), cache_hidden_states=offsets_cache_hidden_states)
        # V2 uses the predicted coefficients directly as c (no projection onto U).
        const = init_prob

        time = time.reshape(-1, 1, 1).expand(-1, L, V).reshape(-1, 1)  # [B*L*V,1]
        # p[n, j, i] = c_i * exp(t * lambda_i) * U[j, i]
        p = (const * torch.exp(time * eig_value)).unsqueeze(1) * eig_vecs  # [B*L*V, K, K]
        p = p.view(-1, V, self.num_component, self.num_component)  # [B*L, V, K, K]
        p = torch.sum(p, dim=-1)  # [B*L, V, K]
        p = p.view(B, -1, p.size(-1))  # [B, L*V, K]
        if return_component_logits:
            # Same positivity map + log as the host-level logits below, so that
            # component_logits[..., h] matches `logits` for host h.
            info_dict["component_logits"] = torch.log(self.pos_func(p) + self.eps).view(B, L, V, -1)

        host_label = argv[self.config.data_property].unsqueeze(1).repeat(1, p.size(1)).unsqueeze(-1).long()  # [B, L*V, 1]
        p = torch.gather(p, -1, host_label).squeeze(-1).view(B, L, -1)  # [B, L, V]
        # Here, p could be negative. To make sure that the (un-normalized) probability is meaningful
        logits = torch.log(self.pos_func(p) + self.eps)
        return GPTOutputs(logits=logits, info_dict=info_dict)

class GPT2TimeModelMultiHostsV2_new(transformers.GPT2LMHeadModel):
    """Multi-host GPT-2 that predicts spectral factors with fused output heads.

    For every (position, vocabulary entry) the model predicts K eigen-values,
    a K x K eigen-vector matrix U and K initial values ``p0``. Component j at
    time t is ``sum_i c_i * exp(t * lambda_i) * U[j, i]``, where:

    - ``c = U^T p0`` when ``symmetry`` is True. U is orthonormalized with QR,
      so ``U^T`` is its inverse.
    - ``c`` solves ``U c = p0`` otherwise. U's columns are only unit-normalized.

    ``config.pos_function`` selects the map applied before the log
    (softplus / sigmoid / relu / exp). Any other value leaves the raw values as
    logits. ``config.offset_pos_function`` selects the map applied to ``p0``
    (softmax over the vocabulary axis / softplus / relu). Any other value
    leaves ``p0`` unchanged.

    Only ``config.share_base=True`` is supported.
    """

    def __init__(self, config, num_component, symmetry=True, base_models=None, **args) -> None:
        super().__init__(config)
        self.num_component = num_component
        self.symmetry = symmetry
        self.share_base = config.share_base

        # K * K matrix, but only pass one GPT2 for each component.
        if config.share_base:
            if base_models is not None:
                self.trans_base = base_models["trans_base"]
                self.offsets_base = base_models["offsets_base"]
            else:
                self.trans_base = transformers.GPT2LMHeadModel(config)
                self.offsets_base = transformers.GPT2LMHeadModel(config)

            self.eig_vecs_out = nn.Linear(config.hidden_size, config.vocab_size * (num_component ** 2))
            self.eig_value_out = nn.Linear(config.hidden_size, config.vocab_size * num_component)
            self.offsets_out = nn.Linear(config.hidden_size, config.vocab_size * num_component)
        else:
            # `raise NotImplemented` would raise TypeError (NotImplemented is not an exception).
            raise NotImplementedError("GPT2TimeModelMultiHostsV2_new only supports config.share_base=True")
        self.config = config
        self.eps = 1e-12
        self.inf = 1e5
        if config.pos_function == "softplus":
            self.pos_func = torch.nn.Softplus()
        elif config.pos_function == "sigmoid":
            self.pos_func = torch.nn.Sigmoid()
        elif config.pos_function == "relu":
            self.pos_func = torch.nn.ReLU()
        elif config.pos_function == "exp":
            self.pos_func = torch.exp
        else:
            # None -> raw values are used as logits (see _to_logits).
            self.pos_func = None

        if config.offset_pos_function == "softmax":
            self.offset_pos_func = nn.Softmax(dim=-2)  # [B, L, V, K]: normalize over V
        elif config.offset_pos_function == "softplus":
            self.offset_pos_func = torch.nn.Softplus()
        elif config.offset_pos_function == "relu":
            self.offset_pos_func = torch.nn.ReLU()
        else:
            # None -> initial values are used unchanged (see get_initial_prob).
            self.offset_pos_func = None

        if config.eig_vecs_layer_norm:
            self.eig_vecs_ln = nn.LayerNorm([num_component, num_component])
        if config.eig_val_layer_norm:
            self.eig_val_ln = nn.LayerNorm(num_component)

    def _to_logits(self, p):
        """Map raw component values to log-space with ``pos_func`` (identity if None)."""
        if self.pos_func is None:
            return p
        return torch.log(self.pos_func(p) + self.eps)

    def get_initial_prob(self, input_ids, labels, attention_mask, cache_hidden_states=None):
        """Return the initial values ``p0`` flattened to ``[B*L*V, K]``.

        ``cache_hidden_states`` replaces the forward pass of ``offsets_base``.
        """
        if self.share_base:
            if cache_hidden_states is None:
                outputs = self.offsets_base.forward(input_ids = input_ids, labels = labels, \
                            attention_mask = attention_mask, output_hidden_states=True)
                hidden_states = outputs.hidden_states[-1]
            else:
                hidden_states = cache_hidden_states

            # [B, L, V, K]
            prob_vectors = self.offsets_out.forward(hidden_states).view(hidden_states.size(0), hidden_states.size(1), -1, self.num_component)
            if self.offset_pos_func is not None:
                prob_vectors = self.offset_pos_func(prob_vectors)
        else:
            # NOTE(review): unreachable for instances built by this __init__
            # (share_base=False raises); relies on `self.offsets`, which __init__ never creates.
            prob_vectors = []
            for k in range(self.num_component):
                outputs = self.offsets[k].forward(input_ids = input_ids, labels = labels, \
                        attention_mask = attention_mask, output_hidden_states=True)
                offset = outputs.logits  # [B, L, V]
                if self.offset_pos_func is not None:
                    offset = self.offset_pos_func(offset)
                prob_vectors.append(offset)
            prob_vectors = torch.stack(prob_vectors, dim=-1)  # [B, L, V, K]

        return prob_vectors.view(-1, self.num_component)

    def get_trans_matrix(self, input_ids, labels, attention_mask, cache_hidden_states=None):
        """Return ``(eig_values [B*L*V, K], eig_vectors [B*L*V, K, K])``.

        With ``symmetry`` the eigen-vector matrix is orthonormalized by QR;
        otherwise each column is scaled to unit L2 norm. Optional LayerNorms
        are applied first when enabled in the config.
        """
        if self.share_base:
            if cache_hidden_states is None:
                outputs = self.trans_base.forward(input_ids = input_ids, labels = labels, \
                        attention_mask = attention_mask, output_hidden_states=True)
                hidden_states = outputs.hidden_states[-1]
            else:
                hidden_states = cache_hidden_states

            eig_vector = self.eig_vecs_out.forward(hidden_states).view(hidden_states.size(0), hidden_states.size(1),
                                                                       self.config.vocab_size, -1)  # [B, L, V, K*K]
        else:
            # NOTE(review): unreachable for instances built by this __init__; relies on
            # per-component modules (`eig_vecs_base`, a ModuleList `eig_vecs_out`) it never creates.
            # each column is a eigenvector!
            eig_vector = [[0 for _ in range(self.num_component)] for _ in range(self.num_component)]
            for i in range(self.num_component):
                outputs = self.eig_vecs_base[i].forward(input_ids = input_ids, labels = labels, \
                        attention_mask = attention_mask, output_hidden_states=True)
                hidden_states = outputs.hidden_states[-1]
                for j in range(self.num_component):
                    rate = self.eig_vecs_out[i * self.num_component + j].forward(hidden_states)
                    eig_vector[j][i] = rate
            eig_vector = [item for row in eig_vector for item in row]
            eig_vector = torch.stack(eig_vector, dim=-1)  # [B, L, V, K**2]

        eig_vector = eig_vector.view(eig_vector.size(0), eig_vector.size(1), eig_vector.size(2), self.num_component, self.num_component)
        # [B, L, V, K, K]

        if self.config.eig_vecs_layer_norm:
            eig_vector = self.eig_vecs_ln(eig_vector)

        if self.symmetry:
            # Orthonormal Q so that U^{-1} = U^T in forward().
            eig_vector, _ = torch.linalg.qr(eig_vector.view(-1, self.num_component, self.num_component))
        else:
            eig_vector = eig_vector.view(-1, self.num_component, self.num_component)
            eig_vector = eig_vector / torch.sum(eig_vector ** 2, dim=-2, keepdims=True).sqrt()

        if self.share_base:
            eig_values = self.eig_value_out.forward(hidden_states).view(hidden_states.size(0), hidden_states.size(1), -1, self.num_component)
        else:
            eig_values = []
            for i in range(self.num_component):
                outputs = self.eig_value[i].forward(input_ids = input_ids, labels = labels, \
                        attention_mask = attention_mask, output_hidden_states=True)
                eig_values.append(outputs.logits)  # [B, L, V]
            eig_values = torch.stack(eig_values, dim=-1)  # [B, L, V, K]

        if self.config.eig_val_layer_norm:
            eig_values = self.eig_val_ln(eig_values)

        return eig_values.view(-1, self.num_component), eig_vector.view(-1, self.num_component, self.num_component)

    @classmethod
    def from_config(cls, config, num_component, symmetry=True):
        # `symmetry` defaults to True, matching __init__; existing positional callers are unaffected.
        model = cls(config, num_component, symmetry=symmetry)
        return model

    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
        return {"input_ids": input_ids, "input_time": kwargs["input_time"], self.config.data_property: kwargs[self.config.data_property]}

    def forward(self, input_time, **argv):
        """Compute host-specific logits ``[B, L, V]``.

        If ``return_component_logits`` is set, ``info_dict["component_logits"]``
        holds the same transform for every component ``[B, L, V, K]``.
        """
        info_dict = {}
        return_component_logits = argv.get("return_component_logits", False)
        trans_cache_hidden_states = argv.get("trans_cache_hidden_states")
        offsets_cache_hidden_states = argv.get("offsets_cache_hidden_states")

        time = discretize_time(
            input_time,
            one_step=False,
            normalize_time_a=self.config.normalize_time_a,
            normalize_time_b=self.config.normalize_time_b,
            discrete=False)
        # input_ids may hold beam_size copies of each sample; replicate the times to match.
        beam_size = argv.get("input_ids").size(0) // input_time.size(0)
        time = time.unsqueeze(1).repeat(1, beam_size).view(-1)

        B, L = argv.get("input_ids").size()
        V = self.config.vocab_size

        # [B*L*V, K], [B*L*V,K,K]
        eig_value, eig_vecs = self.get_trans_matrix(argv.get("input_ids"), argv.get("labels"), argv.get("attention_mask"), cache_hidden_states=trans_cache_hidden_states)
        # [B*L*V, K]
        init_prob = self.get_initial_prob(argv.get("input_ids"), argv.get("labels"), argv.get("attention_mask"), cache_hidden_states=offsets_cache_hidden_states)

        if self.symmetry:
            # U is orthonormal (QR), so U^{-1} p0 = U^T p0.
            const = torch.bmm(eig_vecs.transpose(-2, -1), init_prob.unsqueeze(-1)).squeeze(-1)  # [B*L*V,K]
        else:
            # Solve U c = p0.
            const = torch.linalg.solve(eig_vecs, init_prob.unsqueeze(-1)).squeeze(-1)  # [B*L*V,K]

        time = time.reshape(-1, 1, 1).expand(-1, L, V).reshape(-1, 1)  # [B*L*V,1]
        # p[n, j, i] = c_i * exp(t * lambda_i) * U[j, i]
        p = (const * torch.exp(time * eig_value)).unsqueeze(1) * eig_vecs  # [B*L*V, K, K]
        p = torch.sum(p, dim=-1)  # [B*L*V, K]
        p = p.view(B, -1, p.size(-1))  # [B, L*V, K]
        if return_component_logits:
            info_dict["component_logits"] = self._to_logits(p).view(B, L, V, -1)

        host_label = argv[self.config.data_property].unsqueeze(1).repeat(1, p.size(1)).unsqueeze(-1).long()  # [B, L*V, 1]
        p = torch.gather(p, -1, host_label).squeeze(-1).view(B, L, -1)  # [B, L, V]

        # Here, p could be negative. To make sure that the (un-normalized) probability is meaningful
        logits = self._to_logits(p)

        return GPTOutputs(logits=logits, info_dict=info_dict)
        
class GPT2TimeModelMultiHostsParamShareOLD(transformers.GPT2LMHeadModel):
    """Legacy multi-host model with a linear-in-time logit: ``rate * t + offset``.

    A rate backbone and an offset backbone each feed per-component linear
    heads. The component selected by the host label
    (``argv[config.data_property]``) gives the returned logits.
    """

    def __init__(self, config, num_component, base_models=None, **args) -> None:
        super().__init__(config)
        self.num_component = num_component
        if base_models:
            self.base_rate = base_models["trans_base"]
            self.base_offset = base_models["offsets_base"]
        else:
            self.base_rate = transformers.GPT2LMHeadModel(config)
            self.base_offset = transformers.GPT2LMHeadModel(config)
        # NOTE: K**2 heads are allocated but only the first K are used. The
        # count is kept as-is so existing checkpoints still load.
        self.rate_output_heads = nn.ModuleList([nn.Linear(config.hidden_size, config.vocab_size) for _ in range(num_component ** 2)])
        self.offset_output_heads = nn.ModuleList([nn.Linear(config.hidden_size, config.vocab_size) for _ in range(num_component ** 2)])

    @classmethod
    def from_config(cls, config, num_component):
        model = cls(config, num_component)
        return model

    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
        """
        Implement in subclasses of [`PreTrainedModel`] for custom behavior to prepare inputs in the generate method.
        """
        return {"input_ids": input_ids, "input_time": kwargs["input_time"], self.config.data_property: kwargs[self.config.data_property]}

    def _host_gathered_logits(self, base, heads, cache_hidden_states, argv):
        """Run the first K ``heads`` on ``base``'s last hidden state and pick the host.

        ``cache_hidden_states``, when given, replaces the forward pass of ``base``.
        Returns ``(logits_reduce [B, L, V], logits [B, L, V, K])``.
        """
        if cache_hidden_states is None:
            outputs = base.forward(input_ids = argv.get("input_ids"), labels = argv.get("labels"), \
                attention_mask = argv.get("attention_mask"), output_hidden_states=True)
            hidden_states = outputs.hidden_states[-1]
        else:
            hidden_states = cache_hidden_states
        logits = torch.stack([heads[i](hidden_states) for i in range(self.num_component)], dim=-1)  # [B, L, V, K]
        host_label = argv[self.config.data_property].view(-1, 1, 1, 1).repeat(1, logits.size(1), logits.size(2), 1).long()  # [B, L, V, 1]
        logits_reduce = torch.gather(logits, -1, host_label).squeeze(-1)  # [B, L, V]
        return logits_reduce, logits

    def get_rates(self, **argv):
        """Per-host rate logits from the rate backbone (``trans_cache_hidden_states`` if given)."""
        return self._host_gathered_logits(self.base_rate, self.rate_output_heads,
                                          argv.get("trans_cache_hidden_states"), argv)

    def get_offsets(self, **argv):
        """Per-host offset logits from the offset backbone (``offsets_cache_hidden_states`` if given)."""
        return self._host_gathered_logits(self.base_offset, self.offset_output_heads,
                                          argv.get("offsets_cache_hidden_states"), argv)

    def forward(self, input_time, **argv):
        """Return ``rate * t + offset`` for the host component (and all components on request)."""
        time = discretize_time(
            input_time,
            one_step=False,
            normalize_time_a=self.config.normalize_time_a,
            normalize_time_b=self.config.normalize_time_b,
            discrete=False)
        # input_ids may hold beam_size copies of each sample; replicate the times to match.
        beam_size = argv.get("input_ids").size(0) // input_time.size(0)
        time = time.unsqueeze(1).repeat(1, beam_size).view(-1)

        reduce_rate, rate = self.get_rates(**argv)
        # Bug fix: the offset previously came from get_rates, so the offset
        # backbone/heads were never used and logits collapsed to rate * (t + 1).
        reduce_offset, offset = self.get_offsets(**argv)
        logits_reduce = reduce_rate * time.unsqueeze(-1).unsqueeze(-1) + reduce_offset
        logits_full = rate * time.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + offset
        return GPTOutputs(logits=logits_reduce, info_dict={"component_logits": logits_full} if argv.get("return_component_logits", False) else {})

class GPT2TimeModelMultiHostsParamShare(transformers.GPT2LMHeadModel):
    """GPT-2 time model with per-host rate/offset heads on (optionally shared) backbones.

    For host/component ``k`` and normalised time ``t`` the logits are
    ``rate_k * t + offset_k``. ``rate`` and ``offset`` each come from a single
    linear head of width ``vocab_size * num_component``, applied to the last
    hidden state of ``base_rate`` / ``base_offset`` respectively. The component
    selected by ``argv[config.data_property]`` is returned as the main logits.
    """

    def __init__(self, config, num_component, base_models=None, **args) -> None:
        super().__init__(config)
        self.num_component = num_component
        if base_models:
            # Reuse backbones supplied by the caller.
            self.base_rate = base_models["trans_base"]
            self.base_offset = base_models["offsets_base"]
        else:
            self.base_rate = transformers.GPT2LMHeadModel(config)
            if config.transformer_offset:
                self.base_offset = transformers.GPT2LMHeadModel(config)
            else:
                # Rates and offsets share one backbone; only the heads differ.
                self.base_offset = self.base_rate

        # Each head emits V*K values, reshaped to [..., V, K] in get_rates/get_offsets.
        self.rate_output_heads = nn.Linear(config.hidden_size, config.vocab_size * num_component)
        self.offset_output_heads = nn.Linear(config.hidden_size, config.vocab_size * num_component)

    @classmethod
    def from_config(cls, config, num_component):
        model = cls(config, num_component)
        return model

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs) -> Dict[str, Any]:
        """Build the model inputs for one generation step.

        ``attention_mask`` is removed (if present) before delegating to the
        parent implementation. ``input_time`` and the ``config.data_property``
        host labels are added, and when ``config.num_beams > 1`` both are
        repeated so that every beam of an example gets that example's values.
        """
        # A missing mask must not raise; the original unconditional pop did.
        kwargs.pop("attention_mask", None)
        ret_dict = super().prepare_inputs_for_generation(input_ids, past, **kwargs)
        ret_dict["input_time"] = kwargs["input_time"]
        ret_dict[self.config.data_property] = kwargs[self.config.data_property]
        ret_dict["generation"] = True

        if self.config.num_beams > 1:
            bsz = ret_dict[self.config.data_property].size(0)
            # [0, 0, ..., 1, 1, ...]: each example index repeated num_beams times.
            expanded_return_idx = torch.arange(bsz, device=input_ids.device).repeat_interleave(self.config.num_beams)
            ret_dict[self.config.data_property] = ret_dict[self.config.data_property].index_select(0, expanded_return_idx)
            ret_dict["input_time"] = ret_dict["input_time"].index_select(0, expanded_return_idx)

        return ret_dict

    @staticmethod
    def _last_hidden_state(backbone, cache, argv):
        """Return ``cache`` if given, else the last hidden layer of ``backbone`` on ``argv``."""
        if cache is not None:
            return cache
        outputs = backbone.forward(
            input_ids=argv.get("input_ids"),
            labels=argv.get("labels"),
            attention_mask=argv.get("attention_mask"),
            output_hidden_states=True,
        )
        return outputs.hidden_states[-1]

    def _head_logits(self, head, hidden_states, argv):
        """Apply ``head`` and select each example's host component.

        Returns ``(logits_reduce [B, L, V], logits [B, L, V, K])``.
        """
        B, L = hidden_states.size(0), hidden_states.size(1)
        logits = head(hidden_states).view(B, L, -1, self.num_component)  # [B, L, V, K]
        host_label = (
            argv[self.config.data_property]
            .view(-1, 1, 1, 1)
            .repeat(1, logits.size(1), logits.size(2), 1)
            .long()
        )  # [B, L, V, 1]
        logits_reduce = torch.gather(logits, -1, host_label).squeeze(-1)  # [B, L, V]
        return logits_reduce, logits

    def get_rates(self, **argv):
        """Per-host rate logits; uses ``argv["trans_cache_hidden_states"]`` if provided."""
        hidden_states = self._last_hidden_state(self.base_rate, argv.get("trans_cache_hidden_states"), argv)
        return self._head_logits(self.rate_output_heads, hidden_states, argv)

    def get_offsets(self, **argv):
        """Per-host offset logits; uses ``argv["offsets_cache_hidden_states"]`` if provided."""
        hidden_states = self._last_hidden_state(self.base_offset, argv.get("offsets_cache_hidden_states"), argv)
        return self._head_logits(self.offset_output_heads, hidden_states, argv)

    def forward(self, input_time, **argv):
        """Return ``rate * t + offset`` logits for each example's host.

        ``input_time`` is normalised by ``discretize_time``. When
        ``argv["return_component_logits"]`` is truthy, the logits of all
        components are returned in ``info_dict["component_logits"]``.
        """
        time = discretize_time(
            input_time,
            one_step=False,
            normalize_time_a=self.config.normalize_time_a,
            normalize_time_b=self.config.normalize_time_b,
            discrete=False)
        # input_ids may hold several beams per example; repeat each time value to match.
        beam_size = argv.get("input_ids").size(0) // input_time.size(0)
        time = time.unsqueeze(1).repeat(1, beam_size).view(-1)

        reduce_rate, rate = self.get_rates(**argv)
        # Bug fix: offsets previously came from get_rates as well.
        reduce_offset, offset = self.get_offsets(**argv)

        logits_reduce = reduce_rate * time.view(-1, 1, 1) + reduce_offset

        info_dict = {}
        if argv.get("return_component_logits", False):
            info_dict["component_logits"] = rate * time.view(-1, 1, 1, 1) + offset
        return GPTOutputs(logits=logits_reduce, info_dict=info_dict)

class GPT2TimeModelMultiHostsIndependent(transformers.GPT2LMHeadModel):
    """Multi-host time model with one independent GPT-2 per host rate and per host offset.

    The K x K transition matrix is diagonal: ``get_trans_matrix`` returns the
    per-host rates as eigenvalues and the identity as eigenvectors. Each host's
    probability therefore evolves on its own:
    ``p_i(t) = x0_i * exp(lambda_i * t)``, where ``x0`` is the softmax of the
    corresponding offset model's logits.
    """

    def __init__(self, config, num_component, symmetry=True) -> None:
        super().__init__(config)
        self.num_component = num_component
        # ``symmetry`` is accepted for interface compatibility but is unused:
        # get_trans_matrix always builds a diagonal matrix.
        self.trans_rates = nn.ModuleList([transformers.GPT2LMHeadModel(config) for _ in range(num_component)])
        self.offsets = nn.ModuleList([transformers.GPT2LMHeadModel(config) for _ in range(num_component)])  # K
        self.config = config
        self.eps = 1e-12
        self.inf = 1e5
        self.pos_func = torch.nn.Softplus()

    def get_initial_prob(self, input_ids, labels, attention_mask):
        """Return each host's initial token distribution, flattened to [B*L*V, K]."""
        prob_vectors = []
        for offset_model in self.offsets:
            outputs = offset_model.forward(input_ids=input_ids, labels=labels,
                                           attention_mask=attention_mask, output_hidden_states=True)
            prob_vectors.append(F.softmax(outputs.logits, dim=-1))  # [B, L, V]
        prob_vectors = torch.stack(prob_vectors, dim=-1)  # [B, L, V, K]
        return prob_vectors.view(-1, self.num_component)

    def get_trans_matrix(self, input_ids, labels, attention_mask):
        """Return the eigen-decomposition of the diagonal per-token rate matrices.

        Only the diagonal is learned (off-diagonal rates are zero), so the
        eigenvalues are the softplus-positive rates themselves ([B*L*V, K]) and
        the eigenvectors are identity matrices ([B*L*V, K, K]).
        """
        rates = []
        for rate_model in self.trans_rates:
            outputs = rate_model.forward(input_ids=input_ids, labels=labels,
                                         attention_mask=attention_mask, output_hidden_states=True)
            rate = outputs.logits  # [B, L, V]
            # Tiny noise so that rates are not exactly tied.
            rate = rate + torch.rand(rate.shape, device=rate.device) * self.eps
            rates.append(self.pos_func(rate))
        eig_value = torch.stack(rates, dim=-1).view(-1, self.num_component)  # [B*L*V, K]
        eig_vector = torch.eye(self.num_component, device=eig_value.device).unsqueeze(0).repeat(eig_value.size(0), 1, 1)
        return eig_value, eig_vector

    @classmethod
    def from_config(cls, config, num_component):
        model = cls(config, num_component)
        return model

    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
        """Build the model inputs for one generation step.

        ``forward`` reads ``argv[self.config.data_property]``, so that entry is
        passed through when the caller supplies it; otherwise ``forward`` would
        raise KeyError during generation.
        """
        model_inputs = {"input_ids": input_ids, "input_time": kwargs["input_time"]}
        data_property = getattr(self.config, "data_property", None)
        if data_property is not None and data_property in kwargs:
            model_inputs[data_property] = kwargs[data_property]
        return model_inputs

    def forward(self, input_time, **argv):
        """Return log-probabilities [B, L, V] for each example's host at its time.

        Note: unlike sibling models, this returns a plain tensor rather than
        ``GPTOutputs``.
        """
        time = discretize_time(
            input_time,
            one_step=False,
            normalize_time_a=self.config.normalize_time_a,
            normalize_time_b=self.config.normalize_time_b,
            discrete=False)
        input_ids = argv.get("input_ids")
        labels = argv.get("labels")
        attention_mask = argv.get("attention_mask")

        # input_ids may hold several beams per example; repeat each time value to match.
        beam_size = input_ids.size(0) // input_time.size(0)
        time = time.unsqueeze(1).repeat(1, beam_size).view(-1)

        B, L = input_ids.size()
        V = self.config.vocab_size

        eig_value, eig_vecs = self.get_trans_matrix(input_ids, labels, attention_mask)  # [N, K], [N, K, K]; N = B*L*V
        init_prob = self.get_initial_prob(input_ids, labels, attention_mask)  # [N, K]

        # Coordinates of x0 in the eigenbasis: V^T x0 (V is the identity here).
        const = torch.bmm(eig_vecs.transpose(-2, -1), init_prob.unsqueeze(-1)).squeeze(-1)  # [N, K]

        time = time.reshape(-1, 1, 1).expand(-1, L, V).reshape(-1, 1)  # [N, 1]
        # p_i(t) = sum_j V_ij * c_j * exp(lambda_j * t)
        p = ((const * torch.exp(time * eig_value)).unsqueeze(1) * eig_vecs).sum(dim=-1)  # [N, K]
        p = p.view(B, -1, self.num_component)  # [B, L*V, K]

        host_label = argv[self.config.data_property].unsqueeze(1).repeat(1, p.size(1)).unsqueeze(-1).long()  # [B, L*V, 1]
        p = torch.gather(p, -1, host_label).squeeze(-1).view(B, L, -1)  # [B, L, V]

        # Clamp before log for numerical stability.
        p = torch.clamp(p, min=self.eps, max=self.inf)
        return torch.log(p)

class GPT2TimeModelMultiHostsAndGlobal(GPT2TimeModelMultiHosts):
    """Multi-host model that also produces "global" (host-agnostic) logits.

    * ``add_global_model``: one extra pseudo-component (the last one) is added,
      and its component logits are used as the global logits.
    * Otherwise (``aggregate_global_loss``): the global logits are
      ``log(sum_k softmax(w)_k * exp(logit_k))`` over the host component logits,
      with ``w = pos_func(aggregate_weights)`` learnable.

    When not training, the returned logits are mixed with the global ones:
    ``(1 - alpha) * logits + alpha * global_logits``, where
    ``alpha = argv["test_global_loss_w"]``.
    """

    def __init__(self, config, num_component, symmetry=True, base_models=None, add_global_model=False,
                 aggregate_global_loss=False,
                 **args) -> None:
        # With add_global_model the global model is an extra, last pseudo-component.
        total_components = num_component + 1 if add_global_model else num_component
        super().__init__(config, total_components, symmetry, base_models, **args)

        self.add_global_model = add_global_model
        self.aggregate_global_loss = aggregate_global_loss

        assert add_global_model or aggregate_global_loss
        if aggregate_global_loss:
            # One weight per host as given by the caller (pseudo-component excluded).
            self.aggregate_weights = nn.Parameter(torch.randn(num_component), requires_grad=True)

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs) -> Dict[str, Any]:
        """Parent inputs plus ``test_global_loss_w``, which forward() needs outside training."""
        ret_dict = super().prepare_inputs_for_generation(input_ids, past, **kwargs)
        ret_dict["test_global_loss_w"] = kwargs["test_global_loss_w"]
        return ret_dict

    def forward(self, input_time, **argv):
        """Return the parent's logits (mixed with global logits when not training).

        ``info_dict["global_logits"]`` always holds the [B, L, V] global logits.
        """
        # Take the batch size from input_ids. With beam search it can be a
        # multiple of input_time's batch, and the parent's outputs follow input_ids.
        B, L = argv.get("input_ids").shape[:2]

        argv["return_component_logits"] = True
        outputs = super().forward(input_time, **argv)

        all_logits = outputs.info_dict["component_logits"].view(B, -1, self.num_component)  # [B, L*V, K]

        if self.add_global_model:
            global_logits = all_logits[:, :, -1].view(B, L, -1)  # [B, L, V]
        else:
            agg_weight = self.pos_func(self.aggregate_weights).view(1, 1, 1, -1)
            # Weighted log-mean-exp over hosts with softmax(agg_weight) weights.
            global_logits = (
                torch.logsumexp(all_logits.view(B, L, -1, self.num_component) + agg_weight, dim=-1)
                - torch.logsumexp(agg_weight, dim=-1)
            )

        logits = outputs.logits
        if not self.training:  # generation or inference
            alpha = argv["test_global_loss_w"]
            logits = (1 - alpha) * logits + alpha * global_logits

        return GPTOutputs(logits, info_dict={"global_logits": global_logits})

class GPT2TimeModelMultiHostsMatrixExp(GPT2TimeModelMultiHosts):
    """Multi-host model using a full K x K rate matrix and a matrix exponential.

    For every (example, position, token) the per-host probabilities are
    ``p(t) = expm(R * t) @ x0``. ``R`` is a positive rate matrix produced by
    ``K*K`` heads (or models), and ``x0`` is the initial per-host probability
    from ``get_initial_prob``.
    """

    def __init__(self, config, num_component, symmetry=True, base_models=None, **args) -> None:
        super().__init__(config, num_component, symmetry, base_models, **args)

    def build_models(self, config, num_component, symmetry=True, base_models=None, **args):
        """Create backbones and K*K rate heads plus K offset heads.

        With ``config.share_base``, shared ``trans_base`` / ``offsets_base``
        backbones feed either linear heads or small GPT-2 heads, depending on
        ``config.output_layer_type``. Otherwise each rate entry and each offset
        gets its own full GPT-2 model.
        """
        if config.share_base:
            if base_models is not None:
                self.trans_base = base_models["trans_base"]
                self.offsets_base = base_models["offsets_base"]
            else:
                self.trans_base = transformers.GPT2LMHeadModel(config)
                self.offsets_base = transformers.GPT2LMHeadModel(config)

            if config.output_layer_type == "linear":
                self.trans_heads = nn.ModuleList([nn.Linear(config.hidden_size, config.vocab_size) for _ in range(num_component * num_component)])
                self.offsets_heads = nn.ModuleList([nn.Linear(config.hidden_size, config.vocab_size) for _ in range(num_component)])
            elif config.output_layer_type == "gpt2":
                output_layer_config = args["output_layer_config"]
                self.trans_heads = nn.ModuleList([transformers.GPT2LMHeadModel(output_layer_config) for _ in range(num_component * num_component)])
                self.offsets_heads = nn.ModuleList([transformers.GPT2LMHeadModel(output_layer_config) for _ in range(num_component)])
        else:
            self.trans_rates = nn.ModuleList([transformers.GPT2LMHeadModel(config) for _ in range(num_component * num_component)])
            self.offsets = nn.ModuleList([transformers.GPT2LMHeadModel(config) for _ in range(num_component)])  # K

    def get_trans_matrix(self, input_ids, labels, attention_mask):
        """Rate matrices [B, L, V, K, K] from K*K independent models (row-major i*K+j)."""
        K = self.num_component
        rates = [
            self.trans_rates[k].forward(input_ids=input_ids, labels=labels,
                                        attention_mask=attention_mask, output_hidden_states=True).logits  # [B, L, V]
            for k in range(K * K)
        ]
        rates_matrix = torch.stack(rates, dim=-1)  # [B, L, V, K*K]
        rates_matrix = rates_matrix.view(*rates_matrix.shape[:3], K, K)
        rates_matrix = self.pos_func(rates_matrix)  # ensure it is positive
        return torch.clamp(rates_matrix, min=self.eps, max=self.inf)

    def get_trans_matrix_from_base(self, input_ids, labels, attention_mask, cache_hidden_states=None):
        """Rate matrices [B, L, V, K, K] from the shared ``trans_base`` and K*K heads.

        Raises:
            ValueError: if ``self.output_layer_type`` is neither "linear" nor "gpt2".
        """
        if self.output_layer_type not in ("linear", "gpt2"):
            # Previously this fell through to an UnboundLocalError on ``rate``.
            raise ValueError(f"Unsupported output_layer_type: {self.output_layer_type!r}")

        if cache_hidden_states is None:
            outputs = self.trans_base.forward(input_ids=input_ids, labels=labels,
                                              attention_mask=attention_mask, output_hidden_states=True)
            hidden_states = outputs.hidden_states[-1]
        else:
            hidden_states = cache_hidden_states

        K = self.num_component
        rates = []
        for k in range(K * K):  # k = i * K + j, row-major over (i, j)
            if self.output_layer_type == "linear":
                rate = self.trans_heads[k](hidden_states)
            else:  # "gpt2"
                rate = self.trans_heads[k].forward(inputs_embeds=hidden_states,
                                                   attention_mask=attention_mask, output_hidden_states=True).logits
            rates.append(rate)

        rates_matrix = torch.stack(rates, dim=-1)  # [B, L, V, K*K]
        rates_matrix = rates_matrix.view(*rates_matrix.shape[:3], K, K)
        # Tiny noise so entries are not exactly tied.
        rates_matrix = rates_matrix + torch.rand(rates_matrix.shape, device=rates_matrix.device) * self.eps
        rates_matrix = self.pos_func(rates_matrix)  # ensure it is positive
        return torch.clamp(rates_matrix, min=self.eps, max=self.inf)

    def forward(self, input_time, **argv):
        """Return host-selected log-probabilities [B, L, V] and optional diagnostics.

        Optional ``argv`` flags fill ``info_dict``:
        ``return_component_logits`` -> per-host log-probabilities [B, L, V, K];
        ``return_rates_matrix`` -> rate matrix at each input token [B, L, K, K];
        ``return_init_prob`` -> initial host probabilities at each input token [B, L, K].
        """
        return_rates_matrix = argv.get("return_rates_matrix", False)
        return_init_prob = argv.get("return_init_prob", False)
        return_component_logits = argv.get("return_component_logits", False)

        caches = self.forward_cache_hidden_states(input_time, **argv)
        # Caches passed explicitly in argv take precedence over computed ones.
        trans_cache_hidden_states = argv.get("trans_cache_hidden_states", caches.get("trans_cache_hidden_states"))
        offsets_cache_hidden_states = argv.get("offsets_cache_hidden_states", caches.get("offsets_cache_hidden_states"))

        input_ids = argv.get("input_ids")
        labels = argv.get("labels")
        attention_mask = argv.get("attention_mask")
        info_dict = {}

        time = discretize_time(
            input_time,
            one_step=False,
            normalize_time_a=self.config.normalize_time_a,
            normalize_time_b=self.config.normalize_time_b,
            discrete=False)
        # input_ids may hold several beams per example; repeat each time value to match.
        beam_size = input_ids.size(0) // input_time.size(0)
        time = time.unsqueeze(1).repeat(1, beam_size).view(-1)

        B, L = input_ids.size()
        V = self.config.vocab_size
        K = self.num_component

        if self.share_base:
            rates_matrix = self.get_trans_matrix_from_base(input_ids, labels, attention_mask, cache_hidden_states=trans_cache_hidden_states)
        else:
            rates_matrix = self.get_trans_matrix(input_ids, labels, attention_mask)
        # rates_matrix: [B, L, V, K, K]

        # NOTE(review): assumed to be [B*L*V, K] (it is reshaped that way below).
        init_prob = self.get_initial_prob(input_ids, labels, attention_mask, cache_hidden_states=offsets_cache_hidden_states)

        time = time.reshape(-1, 1, 1).expand(-1, L, V).reshape(-1, 1)  # [B*L*V, 1]

        # p(t) = expm(R * t) @ x0 for every (example, position, token).
        transition = torch.linalg.matrix_exp(rates_matrix.view(-1, K, K) * time.unsqueeze(-1))
        p = torch.bmm(transition, init_prob.unsqueeze(-1)).squeeze(-1)  # [B*L*V, K]

        p = p.view(B, -1, K)  # [B, L*V, K]
        p = torch.clamp(p, min=self.eps, max=self.inf)  # for numerical stability

        if return_component_logits:
            # Log-space, consistent with the returned ``logits``.
            info_dict["component_logits"] = torch.log(p).view(B, L, V, -1)

        host_label = argv[self.config.data_property].unsqueeze(1).repeat(1, p.size(1)).unsqueeze(-1).long()  # [B, L*V, 1]
        p = torch.gather(p, -1, host_label).squeeze(-1).view(B, L, -1)  # [B, L, V]
        logits = torch.log(p)

        if return_rates_matrix:
            token_index = input_ids.view(B, L, 1, 1, 1).expand(-1, -1, -1, K, K)
            info_dict["rates_matrix"] = torch.gather(rates_matrix, 2, token_index).squeeze(2)  # [B, L, K, K]

        if return_init_prob:
            # Gather along the vocab axis (dim 2) using the input token ids.
            token_index = input_ids.view(B, L, 1, 1).expand(-1, -1, -1, K)
            init_prob_seq = torch.gather(init_prob.view(B, L, V, K), 2, token_index)
            info_dict["init_prob"] = init_prob_seq.squeeze(-2)  # [B, L, K]

        return GPTOutputs(logits=logits, info_dict=info_dict)

class GPT2TimeModelMultiHostsGeneral(GPT2TimeModelMultiHosts):
    """Multi-host model with a full rate matrix solved via eigen-decomposition.

    Per (example, position, token) with rate matrix ``R = V diag(lambda) V^-1``:
    ``p(t) = V diag(exp(lambda * t)) V^-1 x0``. The eigen-decomposition is
    computed with ``torch.linalg.eig``, so it may be complex.
    """

    def __init__(self, config, num_component, symmetry=True, base_models=None, **args) -> None:
        # The parent always receives symmetry=True; the caller's flag is stored on
        # self.symmetry and controls symmetrisation in get_trans_matrix_from_base.
        super().__init__(config, num_component, True, base_models, **args)
        self.symmetry = symmetry
        if config.offset_pos_function == "softmax":
            self.offset_pos_func = nn.Softmax(dim=-2)  # normalise over V in [*, V, K]

    def build_models(self, config, num_component, symmetry=True, base_models=None, **args):
        """Create the backbones plus one linear rate head (V*K*K) and one offset head (V*K)."""
        if base_models is not None:
            self.trans_base = base_models["trans_base"]
            self.offsets_base = base_models["offsets_base"]
        else:
            self.trans_base = transformers.GPT2LMHeadModel(config)
            self.offsets_base = transformers.GPT2LMHeadModel(config)

        self.trans_heads = nn.Linear(config.hidden_size, config.vocab_size * num_component * num_component)
        self.offsets_heads = nn.Linear(config.hidden_size, config.vocab_size * num_component)

    def get_trans_matrix_from_base(self, input_ids, labels, attention_mask, cache_hidden_states=None):
        """Return ``(eig_value [N, K], eig_vector [N, K, K], rates_matrix [N, K, K])``, N = B*L*V."""
        if cache_hidden_states is None:
            outputs = self.trans_base.forward(input_ids=input_ids, labels=labels,
                                              attention_mask=attention_mask, output_hidden_states=True)
            hidden_states = outputs.hidden_states[-1]
        else:
            hidden_states = cache_hidden_states

        rates_matrix = self.trans_heads(hidden_states).view(-1, self.num_component, self.num_component)
        if self.symmetry:
            rates_matrix = rates_matrix + rates_matrix.transpose(-2, -1)

        # Tiny noise so the matrix is not degenerate; out-of-place to avoid
        # mutating a view of the head output.
        rates_matrix = rates_matrix + torch.rand(rates_matrix.shape, device=rates_matrix.device,
                                                 dtype=rates_matrix.dtype) * self.eps
        rates_matrix = self.pos_func(rates_matrix)
        rates_matrix = torch.clamp(rates_matrix, min=self.eps, max=self.inf)

        eig_value, eig_vector = torch.linalg.eig(rates_matrix.float())

        assert not torch.isnan(eig_value).any()
        assert not torch.isnan(eig_vector).any()

        return eig_value, eig_vector, rates_matrix

    def get_initial_prob(self, input_ids, labels, attention_mask, cache_hidden_states=None):
        """Return initial per-host probabilities flattened to [B*L*V, K].

        The offset head output is reshaped to [B*L, V, K] and passed through
        ``self.offset_pos_func`` (a softmax over V when
        ``config.offset_pos_function == "softmax"``).
        """
        if cache_hidden_states is None:
            # Offsets come from the offsets backbone (previously trans_base was used here).
            outputs = self.offsets_base.forward(input_ids=input_ids, labels=labels,
                                                attention_mask=attention_mask, output_hidden_states=True)
            hidden_states = outputs.hidden_states[-1]
        else:
            hidden_states = cache_hidden_states
        offset = self.offsets_heads(hidden_states)
        offset = offset.view(-1, self.config.vocab_size, self.num_component)
        x0 = self.offset_pos_func(offset)
        return x0.view(-1, self.num_component)

    def forward(self, input_time, **argv):
        """Return host-selected log-probabilities [B, L, V] and optional diagnostics.

        Optional ``argv`` flags fill ``info_dict``:
        ``return_component_logits`` -> per-host log-probabilities [B, L, V, K];
        ``return_rates_matrix`` -> rate matrix at each input token [B, L, K, K];
        ``return_init_prob`` -> initial host probabilities at each input token [B, L, K].
        """
        return_rates_matrix = argv.get("return_rates_matrix", False)
        return_init_prob = argv.get("return_init_prob", False)
        return_component_logits = argv.get("return_component_logits", False)

        caches = self.forward_cache_hidden_states(input_time, **argv)
        # Caches passed explicitly in argv take precedence over computed ones.
        trans_cache_hidden_states = argv.get("trans_cache_hidden_states", caches.get("trans_cache_hidden_states"))
        offsets_cache_hidden_states = argv.get("offsets_cache_hidden_states", caches.get("offsets_cache_hidden_states"))

        input_ids = argv.get("input_ids")
        labels = argv.get("labels")
        attention_mask = argv.get("attention_mask")
        info_dict = {}

        time = discretize_time(
            input_time,
            one_step=False,
            normalize_time_a=self.config.normalize_time_a,
            normalize_time_b=self.config.normalize_time_b,
            discrete=False)
        # input_ids may hold several beams per example; repeat each time value to match.
        beam_size = input_ids.size(0) // input_time.size(0)
        time = time.unsqueeze(1).repeat(1, beam_size).view(-1)

        B, L = input_ids.size()
        V = self.config.vocab_size
        K = self.num_component

        eig_value, eig_vecs, rates_matrix = self.get_trans_matrix_from_base(
            input_ids, labels, attention_mask, cache_hidden_states=trans_cache_hidden_states)
        init_prob = self.get_initial_prob(input_ids, labels, attention_mask,
                                          cache_hidden_states=offsets_cache_hidden_states)

        if torch.is_complex(eig_vecs):
            # Match solve()'s complex dtype; also works for half-precision init_prob.
            init_prob = init_prob.to(eig_vecs.dtype)

        # Coordinates of x0 in the eigenbasis: solve V c = x0.
        const = torch.linalg.solve(eig_vecs, init_prob)  # [N, K]

        time = time.reshape(-1, 1, 1).expand(-1, L, V).reshape(-1, 1)  # [N, 1]
        # p_i(t) = sum_j V_ij * c_j * exp(lambda_j * t)
        p = ((const * torch.exp(time * eig_value)).unsqueeze(1) * eig_vecs).sum(dim=-1)  # [N, K]
        if torch.is_complex(p):
            p = torch.abs(p)

        p = p.view(B, -1, K)  # [B, L*V, K]
        p = torch.clamp(p, min=self.eps, max=self.inf)  # for numerical stability

        if return_component_logits:
            # Log-space, consistent with the returned ``logits``.
            info_dict["component_logits"] = torch.log(p).view(B, L, V, -1)

        host_label = argv[self.config.data_property].unsqueeze(1).repeat(1, p.size(1)).unsqueeze(-1).long()  # [B, L*V, 1]
        p = torch.gather(p, -1, host_label).squeeze(-1).view(B, L, -1)  # [B, L, V]
        logits = torch.log(p)

        if return_rates_matrix:
            # rates_matrix is flat [B*L*V, K, K]; restore the token axis before gathering.
            token_index = input_ids.view(B, L, 1, 1, 1).expand(-1, -1, -1, K, K)
            info_dict["rates_matrix"] = torch.gather(rates_matrix.view(B, L, V, K, K), 2, token_index).squeeze(2)

        if return_init_prob:
            # Gather along the vocab axis (dim 2) using the input token ids.
            token_index = input_ids.view(B, L, 1, 1).expand(-1, -1, -1, K)
            init_prob_seq = torch.gather(init_prob.view(B, L, V, K), 2, token_index)
            info_dict["init_prob"] = init_prob_seq.squeeze(-2)  # [B, L, K]

        return GPTOutputs(logits=logits, info_dict=info_dict)

class GPT2TimeModelMultiHostsGeneralV2(GPT2TimeModelMultiHostsGeneral):
    """Multi-host time model with a directly parametrised eigen-decomposition.

    For every (position, vocab item) pair the model predicts:

    * ``K * (K - 1) // 2`` free parameters. They are turned into a
      skew-symmetric matrix and then into an orthogonal eigenvector matrix
      ``U`` via the Cayley transform (see ``skew`` / ``cayley_map``).
    * ``K`` eigenvalues ``lambda``. No sign constraint is applied.
    * An initial host vector ``p0`` of size ``K`` (see ``get_initial_prob``).

    The host-resolved value at time ``t`` is
    ``p(t) = U @ (c * exp(t * lambda))`` with ``U @ c = p0``.
    """

    def build_models(self, config, num_component, symmetry=True, base_models=None, **args):
        """Create (or reuse) the GPT-2 trunks and the per-vocab output heads.

        Args:
            config: GPT-2 config; ``hidden_size`` and ``vocab_size`` size the heads.
            num_component: number of hosts/components ``K``.
            symmetry: unused here; kept for interface compatibility.
            base_models: optional dict with ``"trans_base"`` and ``"offsets_base"``
                models to share instead of building new ones.
        """
        # Share the GPT-2 Transformer layers
        if base_models is not None:
            self.trans_base = base_models["trans_base"]
            self.offsets_base = base_models["offsets_base"]
        else:
            self.trans_base = transformers.GPT2LMHeadModel(config)
            self.offsets_base = transformers.GPT2LMHeadModel(config)

        # Strictly-upper-triangular entries of a K x K skew matrix, per vocab item.
        num_skew_params = num_component * (num_component - 1) // 2
        self.eig_vecs_heads = nn.Linear(config.hidden_size, config.vocab_size * num_skew_params)
        self.eig_vals_heads = nn.Linear(config.hidden_size, config.vocab_size * num_component)
        self.offsets_heads = nn.Linear(config.hidden_size, config.vocab_size * num_component)

    def skew(self, vector, K):
        """Build skew-symmetric ``[N, K, K]`` matrices from their upper triangles.

        Args:
            vector: ``[N, K*(K-1)//2]`` strictly-upper-triangular entries, in the
                row-major order produced by ``torch.triu_indices(K, K, 1)``.
            K: matrix size.

        Returns:
            ``T - T^T`` for each row, where ``T`` is the upper-triangular
            matrix filled from that row of ``vector``.
        """
        rows, cols = torch.triu_indices(K, K, 1, device=vector.device)
        m = vector.new_zeros(vector.size(0), K, K)
        m[:, rows, cols] = vector
        return m - m.transpose(-2, -1)

    def cayley_map(self, matrix, K):
        """Map a batch of skew-symmetric matrices to orthogonal matrices.

        Returns ``(I - A)^{-1} (I + A)`` for each ``A`` in the ``[N, K, K]`` batch.
        ``I - A`` is invertible for any real skew-symmetric ``A``.
        """
        # Default dtype is kept on purpose: it matches the original promotion behaviour.
        eye = torch.eye(K, device=matrix.device).unsqueeze(0).expand(matrix.size(0), -1, -1)
        return torch.linalg.solve(eye - matrix, eye + matrix)

    def get_initial_prob(self, input_ids, labels, attention_mask, cache_hidden_states=None):
        """Return the per-(position, vocab) initial host vectors ``p0``, shaped ``[-1, K]``.

        Uses ``cache_hidden_states`` when given, otherwise runs ``self.trans_base``.
        When ``self.offset_pos_func`` is set, it is applied to the ``[-1, V, K]`` view.
        """
        if cache_hidden_states is None:
            # NOTE(review): the uncached path uses trans_base, not offsets_base; confirm intended.
            outputs = self.trans_base.forward(input_ids=input_ids, labels=labels,
                                              attention_mask=attention_mask, output_hidden_states=True)
            hidden_states = outputs.hidden_states[-1]
        else:
            hidden_states = cache_hidden_states

        offset = self.offsets_heads(hidden_states)
        if self.offset_pos_func is not None:
            offset = offset.view(-1, self.config.vocab_size, self.num_component)
            offset = self.offset_pos_func(offset)
        return offset.view(-1, self.num_component)

    def get_trans_matrix_from_base(self, input_ids, labels, attention_mask, cache_hidden_states=None):
        """Predict eigenvalues and orthogonal eigenvectors per (position, vocab) pair.

        Returns:
            eig_values: ``[-1, K]``; no sign constraint is applied.
            eig_vecs: ``[-1, K, K]`` orthogonal matrices (Cayley map of a skew matrix).
            None: this parametrisation never builds a rates matrix.
        """
        if cache_hidden_states is None:
            outputs = self.trans_base.forward(input_ids=input_ids, labels=labels,
                                              attention_mask=attention_mask, output_hidden_states=True)
            hidden_states = outputs.hidden_states[-1]
        else:
            hidden_states = cache_hidden_states

        K = self.num_component
        skew_params = self.eig_vecs_heads(hidden_states).view(-1, K * (K - 1) // 2)
        eig_vecs = self.cayley_map(self.skew(skew_params, K), K)
        eig_values = self.eig_vals_heads(hidden_states).view(-1, K)
        return eig_values, eig_vecs, None

    def forward(self, input_time, **argv):
        """Compute per-token scores for the host selected in ``argv``.

        Args:
            input_time: times passed to ``discretize_time``. ``input_ids`` may hold
                ``beam_size`` consecutive rows per time entry.
            **argv: must contain ``input_ids`` (``[B, L]``) and the key named by
                ``self.config.data_property`` (one host index per row). It may also
                contain ``labels``, ``attention_mask``, the
                ``trans_cache_hidden_states`` / ``offsets_cache_hidden_states``
                overrides, and the flags ``return_component_logits`` /
                ``return_init_prob`` / ``return_rates_matrix``.

        Returns:
            ``GPTOutputs`` whose ``logits`` are ``[B, L, V]``. They are the log of
            the clamped ``pos_func`` output when ``pos_func`` is set, and the raw
            value otherwise. ``info_dict`` holds the optional extras.
        """
        return_rates_matrix = argv.get("return_rates_matrix", False)
        return_init_prob = argv.get("return_init_prob", False)
        return_component_logits = argv.get("return_component_logits", False)

        caches = self.forward_cache_hidden_states(input_time, **argv)
        trans_cache_hidden_states = caches.get("trans_cache_hidden_states")
        offsets_cache_hidden_states = caches.get("offsets_cache_hidden_states")
        # Explicitly supplied caches take precedence over the computed ones.
        if "trans_cache_hidden_states" in argv:
            trans_cache_hidden_states = argv.get("trans_cache_hidden_states")
        if "offsets_cache_hidden_states" in argv:
            offsets_cache_hidden_states = argv.get("offsets_cache_hidden_states")

        info_dict = {}
        input_ids = argv.get("input_ids")

        time = discretize_time(
            input_time,
            one_step=False,
            normalize_time_a=self.config.normalize_time_a,
            normalize_time_b=self.config.normalize_time_b,
            discrete=False)
        beam_size = input_ids.size(0) // input_time.size(0)
        time = time.unsqueeze(1).repeat(1, beam_size).view(-1)

        B, L = input_ids.size()
        V = self.config.vocab_size
        K = self.num_component

        eig_value, eig_vecs, rates_matrix = self.get_trans_matrix_from_base(
            input_ids, argv.get("labels"), argv.get("attention_mask"),
            cache_hidden_states=trans_cache_hidden_states)
        init_prob = self.get_initial_prob(
            input_ids, argv.get("labels"), argv.get("attention_mask"),
            cache_hidden_states=offsets_cache_hidden_states)
        # Expansion coefficients c solving U c = p0.
        const = torch.linalg.solve(eig_vecs, init_prob.float())

        time = time.reshape(-1, 1, 1).expand(-1, L, V).reshape(-1, 1)  # [B*L*V, 1]

        # p[n, i] = sum_j U[n, i, j] * c[n, j] * exp(t_n * lambda[n, j])
        p = (const * torch.exp(time * eig_value)).unsqueeze(1) * eig_vecs  # [B*L*V, K, K]
        p = torch.sum(p, dim=-1)  # [B*L*V, K]
        p = p.view(B, -1, self.number_of_component)  # [B, L*V, K]

        if return_component_logits:
            info_dict["component_logits"] = p.view(B, L, V, -1)

        # Keep only the component of each row's host.
        host_label = argv[self.config.data_property].unsqueeze(1).repeat(1, p.size(1)).unsqueeze(-1).long()  # [B, L*V, 1]
        p = torch.gather(p, -1, host_label).squeeze(-1).view(B, L, -1)  # [B, L, V]

        if self.pos_func is not None:
            p = self.pos_func(p)
            p = torch.clamp(p, min=self.eps, max=self.inf)
            logits = torch.log(p)
        else:
            logits = p

        token_index = input_ids.view(B, L, 1, 1)

        # get_trans_matrix_from_base returns None here, so there is nothing to gather.
        if return_rates_matrix and rates_matrix is not None:
            index = token_index.unsqueeze(-1).expand(-1, -1, -1, K, K)
            info_dict["rates_matrix"] = torch.gather(rates_matrix, 2, index).squeeze(2)  # [B, L, K, K]

        if return_init_prob:
            init_prob_ = init_prob.view(B, L, V, K)  # [B, L, V, K]
            # Select p0 of each observed token, i.e. index the vocab axis (dim=2).
            index = token_index.expand(-1, -1, -1, K)
            info_dict["init_prob"] = torch.gather(init_prob_, 2, index).squeeze(-2)  # [B, L, K]

        return GPTOutputs(logits=logits, info_dict=info_dict)

class GPT2TimeModelMultiHostsSimple(GPT2TimeModelMultiHosts):
    """Log-space multi-host time model without an explicit eigen-decomposition.

    For every (position, vocab item) pair the model predicts ``K`` rates
    ``lambda`` and a ``[K, K]`` matrix ``E`` of log-weights. It scores host
    ``i`` at time ``t`` as ``logsumexp_j(t * lambda_j + E[i, j])``.
    """

    def __init__(self, config, num_component, symmetry=True, base_models=None, **args) -> None:
        # The parent is always built with symmetry=True; the caller's value is
        # only recorded on the instance.
        super().__init__(config, num_component, True, base_models, **args)
        self.symmetry = symmetry
        if config.offset_pos_function == "softmax":
            self.offset_pos_func = nn.Softmax(dim=-2)  # [*, V, K]

    def build_models(self, config, num_component, symmetry=True, base_models=None, **args):
        """Create (or reuse) the GPT-2 trunk(s) and the eigenvalue/eigenvector heads.

        With ``self.config.transformer_offset``, a second trunk ``offsets_base``
        feeds the eigenvector head. Otherwise both heads share ``trans_base``.
        """
        if base_models is not None:
            self.trans_base = base_models["trans_base"]
            if self.config.transformer_offset:
                self.offsets_base = base_models["offsets_base"]
            else:
                self.offsets_base = self.trans_base
        else:
            self.trans_base = transformers.GPT2LMHeadModel(config)
            if self.config.transformer_offset:
                self.offsets_base = transformers.GPT2LMHeadModel(config)
            else:
                self.offsets_base = self.trans_base

        self.eigvecs_heads = nn.Linear(config.hidden_size, config.vocab_size * num_component * num_component)
        self.eigvals_heads = nn.Linear(config.hidden_size, config.vocab_size * num_component)

    def get_trans_matrix_from_base(self, input_ids, labels, attention_mask, cache_hidden_states=None,
                                   offsets_cache_hidden_states=None):
        """Predict per-(position, vocab) rates ``[-1, K]`` and log-weights ``[-1, K, K]``.

        Args:
            cache_hidden_states: precomputed ``trans_base`` hidden states. When
                given, no trunk is run.
            offsets_cache_hidden_states: precomputed ``offsets_base`` hidden
                states. They are used for the eigenvector head when
                ``transformer_offset`` is set, which mirrors the uncached path.
                Otherwise, or when this is None, ``cache_hidden_states`` is used.

        Returns:
            ``(eig_value, eig_vector, None)``.
        """
        if cache_hidden_states is None:
            outputs = self.trans_base.forward(input_ids=input_ids, labels=labels,
                                              attention_mask=attention_mask, output_hidden_states=True)
            hidden_states_eigval = outputs.hidden_states[-1]

            if self.config.transformer_offset:
                outputs = self.offsets_base.forward(input_ids=input_ids, labels=labels,
                                                    attention_mask=attention_mask, output_hidden_states=True)
                hidden_states_eigvec = outputs.hidden_states[-1]
            else:
                hidden_states_eigvec = hidden_states_eigval
        else:
            hidden_states_eigval = cache_hidden_states
            # Keep the cached path consistent with the uncached one: with a
            # separate offsets trunk, eigenvectors come from its hidden states.
            if self.config.transformer_offset and offsets_cache_hidden_states is not None:
                hidden_states_eigvec = offsets_cache_hidden_states
            else:
                hidden_states_eigvec = cache_hidden_states

        eig_value = self.eigvals_heads(hidden_states_eigval).view(-1, self.num_component)
        eig_vector = self.eigvecs_heads(hidden_states_eigvec).view(-1, self.num_component, self.num_component)
        return eig_value, eig_vector, None

    def forward(self, input_time, **argv):
        """Compute per-token log-scores ``[B, L, V]`` for the host selected in ``argv``.

        ``argv`` must contain ``input_ids`` (``[B, L]``) and the key named by
        ``self.config.data_property``. Only ``return_component_logits`` is
        honoured among the ``return_*`` flags.
        """
        return_component_logits = argv.get("return_component_logits", False)

        caches = self.forward_cache_hidden_states(input_time, **argv)
        trans_cache_hidden_states = caches.get("trans_cache_hidden_states")
        offsets_cache_hidden_states = caches.get("offsets_cache_hidden_states")
        # Explicitly supplied caches take precedence over the computed ones.
        if "trans_cache_hidden_states" in argv:
            trans_cache_hidden_states = argv.get("trans_cache_hidden_states")
        if "offsets_cache_hidden_states" in argv:
            offsets_cache_hidden_states = argv.get("offsets_cache_hidden_states")

        info_dict = {}
        input_ids = argv.get("input_ids")

        time = discretize_time(
            input_time,
            one_step=False,
            normalize_time_a=self.config.normalize_time_a,
            normalize_time_b=self.config.normalize_time_b,
            discrete=False)
        beam_size = input_ids.size(0) // input_time.size(0)
        time = time.unsqueeze(1).repeat(1, beam_size).view(-1)

        B, L = input_ids.size()
        V = self.config.vocab_size

        eig_value, eig_vecs, _ = self.get_trans_matrix_from_base(
            input_ids, argv.get("labels"), argv.get("attention_mask"),
            cache_hidden_states=trans_cache_hidden_states,
            offsets_cache_hidden_states=offsets_cache_hidden_states)

        time = time.reshape(-1, 1, 1).expand(-1, L, V).reshape(-1, 1)  # [B*L*V, 1]

        # log p_i(t) = logsumexp_j (t * lambda_j + E[i, j])
        logits = (time * eig_value).unsqueeze(1) + eig_vecs  # [B*L*V, K, K]
        logits = torch.logsumexp(logits, dim=-1)  # [B*L*V, K]

        # NOTE(review): log_softmax(exp(x)) normalises exp(x) treated as scores,
        # not x itself; confirm this is intended.
        if return_component_logits:
            if self.apply_log_softmax:
                info_dict["component_logits"] = torch.log_softmax(torch.exp(logits.view(B, L, V, -1)), dim=-2)
            else:
                info_dict["component_logits"] = logits.view(B, L, V, -1)

        logits = logits.view(B, -1, self.number_of_component)  # [B, L*V, K]
        host_label = argv[self.config.data_property].unsqueeze(1).repeat(1, logits.size(1)).unsqueeze(-1).long()  # [B, L*V, 1]
        logits = torch.gather(logits, -1, host_label).squeeze(-1).view(B, L, -1)  # [B, L, V]

        if self.apply_log_softmax:
            logits = torch.log_softmax(torch.exp(logits), dim=-1)

        return GPTOutputs(logits=logits, info_dict=info_dict)

class GPT2TimeModelMultiHostsPrepend(GPT2TimeModelMultiHosts):
    """Multi-host time model that conditions on the host via a prepended token.

    Each sequence is replicated once per value of ``config.data_property``, and
    a value-specific token (ids ``vocab_size + i``) is prepended to each copy.
    The hidden states of the ``K`` copies give, per (position, vocab item), a
    symmetric ``[K, K]`` rates matrix and a ``K``-vector of initial
    probabilities.
    """

    def __init__(self, config, num_component, symmetry=True, base_models=None, **args) -> None:
        super().__init__(config, num_component, symmetry, base_models, **args)

    def build_models(self, config, num_component, symmetry=True, base_models=None, **args):
        """Build the GPT-2 time model (with enlarged vocabulary) and the output heads.

        The vocabulary grows by one token per value of every property in
        ``config.data_properties``.
        """
        _config = deepcopy(config)
        _config.vocab_size = config.vocab_size + sum(
            len(getattr(config, "%s_dict" % prop)) for prop in config.data_properties)
        self.model = GPT2TimeModel.from_config(_config)
        self.trans_rate_output_layer = nn.Linear(config.hidden_size, config.vocab_size * self.num_component)  # H -> V*K
        self.offsets_output_layer = nn.Linear(config.hidden_size, config.vocab_size)  # H -> V

    def _prepend_tokens(self, **argv):
        """Replicate ``input_ids`` once per property value and prepend its token.

        Example::

            <LOCATION=i> ATCG...
            <LOCATION=j> ATCG...

        Returns:
            dict with ``input_ids`` and ``labels`` (both ``[B*K, L+1]``), plus a
            boolean ``attention_mask`` that is False at ``self.config.padding_idx``.
        """
        input_ids = argv["input_ids"]
        bsz = input_ids.size(0)

        prop = self.config.data_property  # TODO: only a single property is supported.
        n_values = len(getattr(self.config, "%s_dict" % prop))
        # Property tokens are numbered right after the base vocabulary.
        prop_tok = torch.arange(n_values, device=input_ids.device) + self.config.vocab_size
        prop_tok = prop_tok.unsqueeze(0).repeat(bsz, 1).unsqueeze(-1)  # [B, K, 1]
        input_ids_expand = input_ids.unsqueeze(1).repeat(1, n_values, 1)  # [B, K, L]
        input_ids_prepend = torch.cat([prop_tok, input_ids_expand], dim=-1)  # [B, K, L+1]

        flat_ids = input_ids_prepend.view(-1, input_ids_prepend.size(-1))  # [B*K, L+1]
        return {
            "input_ids": flat_ids,
            "labels": flat_ids,
            "attention_mask": flat_ids != self.config.padding_idx,
        }

    def get_initial_prob(self, hidden_states):
        """Return initial host probabilities, flattened to ``[B*L*V, K]``.

        ``hidden_states`` holds the ``K`` host-prefixed copies stacked along the
        batch axis. ``offset_pos_func`` is applied with V as the last dimension.
        """
        L = hidden_states.size(1)
        prob_vectors = self.offsets_output_layer(hidden_states)  # [B*K, L, V]
        prob_vectors = prob_vectors.view(-1, self.num_component, L, self.config.vocab_size)  # [B, K, L, V]
        prob_vectors = self.offset_pos_func(prob_vectors)
        prob_vectors = torch.permute(prob_vectors, (0, 2, 3, 1))  # [B, L, V, K]
        return prob_vectors.reshape(-1, self.num_component)

    def get_trans_matrix_from_base(self, hidden_states, generation=False):
        """Build symmetric per-(position, vocab) rates matrices and diagonalise them.

        Args:
            hidden_states: last-layer states of the ``K`` host-prefixed copies,
                stacked along the batch axis.
            generation: when True, the eigen-pairs are tiled ``L`` times along
                the sequence axis.

        Returns:
            ``(eig_value, eig_vector, rates)`` with ``rates`` of shape
            ``[B, L, V, K, K]``. The eigen-pairs come from ``torch.linalg.eigh``.
        """
        K, V, L = self.num_component, self.config.vocab_size, hidden_states.size(1)
        rates = self.trans_rate_output_layer(hidden_states).view(-1, K, L, V, K)
        rates = torch.permute(rates, (0, 2, 3, 1, 4))  # [B, L, V, K, K]
        # Jitter to avoid an ill-defined A. It is added *before* symmetrising, so
        # the matrices stay exactly symmetric (the assert below and eigh rely on it).
        rates = rates + torch.rand(rates.size(), device=rates.device) * self.eps
        rates = (rates + rates.transpose(-1, -2)) / 2
        rates = self.pos_func(rates)
        assert torch.all(rates.transpose(-2, -1) == rates), rates[rates.transpose(-2, -1) != rates]
        rates = torch.clamp(rates, min=self.eps, max=self.inf)

        eig_value, eig_vector = torch.linalg.eigh(rates.reshape(-1, K, K))

        assert ~torch.any(torch.isnan(eig_value))
        assert ~torch.any(torch.isnan(eig_vector))

        if generation:
            eig_value = eig_value.view(-1, 1, V, K).repeat(1, L, 1, 1).view(-1, K)
            eig_vector = eig_vector.view(-1, 1, V, K, K).repeat(1, L, 1, 1, 1).view(-1, K, K)

        return eig_value, eig_vector, rates

    def forward(self, input_time, **argv):
        """Compute per-token log-probabilities for the host given in ``argv``.

        Returns ``GPTOutputs`` whose ``logits`` are ``[B, L, V]``. ``logits`` is
        None when ``argv`` lacks ``self.config.data_property``.
        ``info_dict`` may hold ``component_logits`` ``[B, L, V, K]``,
        ``rates_matrix`` ``[B, L, K, K]`` and ``init_prob`` ``[B, L, K]``.
        These extras are aligned with the original (un-prepended) positions.
        """
        generation = argv.get("generation", False)
        return_rates_matrix = argv.get("return_rates_matrix", False)
        return_init_prob = argv.get("return_init_prob", False)
        return_component_logits = argv.get("return_component_logits", False)

        # NOTE(review): the cached hidden states are not used by this model; the
        # call is kept so any behaviour of the base implementation is preserved.
        self.forward_cache_hidden_states(input_time, **argv)

        info_dict = {}
        input_ids = argv.get("input_ids")

        time = discretize_time(
            input_time,
            one_step=False,
            normalize_time_a=self.config.normalize_time_a,
            normalize_time_b=self.config.normalize_time_b,
            discrete=False)
        beam_size = input_ids.size(0) // input_time.size(0)
        time = time.unsqueeze(1).repeat(1, beam_size).view(-1)

        B = input_ids.size(0)
        V = self.config.vocab_size
        K = self.num_component
        new_argv = self._prepend_tokens(**argv)
        full_len = new_argv.get("input_ids").size(1)  # L + 1 (host token prepended)

        outputs = self.model.forward(input_time, **new_argv, return_hidden_states=True)
        hidden_states = outputs.hidden_states[-1]

        eig_value, eig_vecs, rates_matrix = self.get_trans_matrix_from_base(hidden_states, generation=generation)
        init_prob = self.get_initial_prob(hidden_states)
        # eigh returns orthonormal eigenvectors, so U^T p0 are the expansion coefficients.
        const = torch.bmm(eig_vecs.transpose(-2, -1), init_prob.unsqueeze(-1)).squeeze(-1)
        time = time.reshape(-1, 1, 1).expand(-1, full_len, V).reshape(-1, 1)
        p = (const * torch.exp(time * eig_value)).unsqueeze(1) * eig_vecs
        p = torch.sum(p, dim=-1)
        p = p.view(B, -1, self.number_of_component)  # [B, (L+1)*V, K]
        p = torch.clamp(p, min=self.eps, max=self.inf)  # For numerical stability.
        # Drop the prediction made at the prepended host token.
        p = p.view(B, full_len, V, K)
        p = p[:, 1:, :, :].reshape(B, -1, K)
        L = full_len - 1

        if return_component_logits:
            info_dict["component_logits"] = torch.log(p).view(B, L, V, -1)

        if self.config.data_property not in argv:
            return GPTOutputs(logits=None, info_dict=info_dict)

        host_label = argv[self.config.data_property].unsqueeze(1).repeat(1, p.size(1)).unsqueeze(-1).long()  # [B, L*V, 1]
        p = torch.gather(p, -1, host_label).squeeze(-1).view(B, L, -1)  # [B, L, V]
        logits = torch.log(p)

        token_index = input_ids.view(B, L, 1, 1)

        if return_rates_matrix:
            # Drop the host-token position so the matrices line up with ``p``.
            rates_seq = rates_matrix[:, 1:]  # [B, L, V, K, K]
            index = token_index.unsqueeze(-1).expand(-1, -1, -1, K, K)
            info_dict["rates_matrix"] = torch.gather(rates_seq, 2, index).squeeze(2)  # [B, L, K, K]

        if return_init_prob:
            # init_prob covers all L+1 positions; drop the host token, then pick
            # each observed token's vector on the vocab axis.
            init_prob_ = init_prob.view(B, full_len, V, K)[:, 1:]  # [B, L, V, K]
            index = token_index.expand(-1, -1, -1, K)
            info_dict["init_prob"] = torch.gather(init_prob_, 2, index).squeeze(-2)  # [B, L, K]

        return GPTOutputs(logits=logits, info_dict=info_dict)

class GPT2TimeModelMultiHostsPrependV2(GPT2TimeModelMultiHostsPrepend):
    """Host-prefixed model that predicts eigenvalues and eigenvectors directly.

    No rates matrix is formed or diagonalised. For each (position, vocab item),
    host-prefixed copy ``k`` supplies eigenvalue ``k`` and row ``k`` of the
    ``[K, K]`` eigenvector matrix.
    """

    def __init__(self, config, num_component, symmetry=True, base_models=None, **args) -> None:
        super().__init__(config, num_component, symmetry, base_models, **args)

    def build_models(self, config, num_component, symmetry=True, base_models=None, **args):
        """Build the GPT-2 time model (with enlarged vocabulary) and the output heads."""
        _config = deepcopy(config)
        _config.vocab_size = config.vocab_size + sum(
            len(getattr(config, "%s_dict" % prop)) for prop in config.data_properties)
        self.model = GPT2TimeModel.from_config(_config)
        # Each host copy yields V eigenvalues; the K copies together make V*K.
        self.eig_value_output_layer = nn.Linear(config.hidden_size, config.vocab_size)  # H -> V
        self.eig_vectors_output_layer = nn.Linear(config.hidden_size, config.vocab_size * self.num_component)  # H -> V*K
        self.offsets_output_layer = nn.Linear(config.hidden_size, config.vocab_size)  # H -> V

    def get_trans_matrix_from_base(self, hidden_states, generation=False):
        """Read eigenvalues ``[B*L*V, K]`` and eigenvectors ``[B*L*V, K, K]`` off the heads.

        ``generation`` is accepted for interface compatibility but not used yet.
        The third return value is None (no rates matrix).
        """
        K, V, L = self.num_component, self.config.vocab_size, hidden_states.size(1)

        eig_vector = self.eig_vectors_output_layer(hidden_states).view(-1, K, L, V, K)  # [B, K, L, V, K]
        eig_value = self.eig_value_output_layer(hidden_states).view(-1, K, L, V)  # [B, K, L, V]

        eig_vector = torch.permute(eig_vector, (0, 2, 3, 1, 4)).reshape(-1, K, K)  # [B*L*V, K, K]
        eig_value = torch.permute(eig_value, (0, 2, 3, 1)).reshape(-1, K)  # [B*L*V, K]
        # TODO: tile along the sequence axis when generation=True, as the parent does.
        return eig_value, eig_vector, None

    def forward(self, input_time, **argv):
        """Compute per-token log-probabilities for the host given in ``argv``.

        Returns ``GPTOutputs`` whose ``logits`` are ``[B, L, V]``. ``logits`` is
        None when ``argv`` lacks ``self.config.data_property``.
        ``info_dict`` may hold ``component_logits`` ``[B, L, V, K]`` and
        ``init_prob`` ``[B, L, K]``. No ``rates_matrix`` is ever produced.
        """
        generation = argv.get("generation", False)
        return_rates_matrix = argv.get("return_rates_matrix", False)
        return_init_prob = argv.get("return_init_prob", False)
        return_component_logits = argv.get("return_component_logits", False)

        # NOTE(review): the cached hidden states are not used by this model; the
        # call is kept so any behaviour of the base implementation is preserved.
        self.forward_cache_hidden_states(input_time, **argv)

        info_dict = {}
        input_ids = argv.get("input_ids")

        time = discretize_time(
            input_time,
            one_step=False,
            normalize_time_a=self.config.normalize_time_a,
            normalize_time_b=self.config.normalize_time_b,
            discrete=False)
        beam_size = input_ids.size(0) // input_time.size(0)
        time = time.unsqueeze(1).repeat(1, beam_size).view(-1)

        B = input_ids.size(0)
        V = self.config.vocab_size
        K = self.num_component
        new_argv = self._prepend_tokens(**argv)
        full_len = new_argv.get("input_ids").size(1)  # L + 1 (host token prepended)

        outputs = self.model.forward(input_time, **new_argv, return_hidden_states=True)
        hidden_states = outputs.hidden_states[-1]

        eig_value, eig_vecs, rates_matrix = self.get_trans_matrix_from_base(hidden_states, generation=generation)
        init_prob = self.get_initial_prob(hidden_states)
        # NOTE(review): U^T p0 equals the expansion coefficients only for an
        # orthogonal U, which nothing enforces for these learned eigenvectors.
        const = torch.bmm(eig_vecs.transpose(-2, -1), init_prob.unsqueeze(-1)).squeeze(-1)
        time = time.reshape(-1, 1, 1).expand(-1, full_len, V).reshape(-1, 1)
        p = (const * torch.exp(time * eig_value)).unsqueeze(1) * eig_vecs
        p = torch.sum(p, dim=-1)
        p = self.pos_func(p)
        p = p.view(B, -1, self.number_of_component)  # [B, (L+1)*V, K]
        p = torch.clamp(p, min=self.eps, max=self.inf)  # For numerical stability.
        # Drop the prediction made at the prepended host token.
        p = p.view(B, full_len, V, K)
        p = p[:, 1:, :, :].reshape(B, -1, K)
        L = full_len - 1

        if return_component_logits:
            info_dict["component_logits"] = torch.log(p).view(B, L, V, -1)

        if self.config.data_property not in argv:
            return GPTOutputs(logits=None, info_dict=info_dict)

        host_label = argv[self.config.data_property].unsqueeze(1).repeat(1, p.size(1)).unsqueeze(-1).long()  # [B, L*V, 1]
        p = torch.gather(p, -1, host_label).squeeze(-1).view(B, L, -1)  # [B, L, V]
        logits = torch.log(p)

        token_index = input_ids.view(B, L, 1, 1)

        # get_trans_matrix_from_base returns None here, so there is nothing to gather.
        if return_rates_matrix and rates_matrix is not None:
            rates_seq = rates_matrix[:, 1:]
            index = token_index.unsqueeze(-1).expand(-1, -1, -1, K, K)
            info_dict["rates_matrix"] = torch.gather(rates_seq, 2, index).squeeze(2)  # [B, L, K, K]

        if return_init_prob:
            # init_prob covers all L+1 positions; drop the host token, then pick
            # each observed token's vector on the vocab axis.
            init_prob_ = init_prob.view(B, full_len, V, K)[:, 1:]  # [B, L, V, K]
            index = token_index.expand(-1, -1, -1, K)
            info_dict["init_prob"] = torch.gather(init_prob_, 2, index).squeeze(-2)  # [B, L, K]

        return GPTOutputs(logits=logits, info_dict=info_dict)

class GPT2TimeModelMultiHostsEmbedding(GPT2TimeModelMultiHosts):
    """Multi-host time model whose rate matrices come from per-token local embeddings.

    ``trans_base`` (GPT-2) produces hidden states and ``local_embedding_ffn``
    maps every position to ``V * K * H'`` values, where ``V = config.vocab_size``,
    ``K = num_component`` and ``H' = config.local_embedding_size``. For each
    (position, token) pair the ``[K, H']`` slice ``E`` yields a rate matrix
    ``pos_func(E @ E^T)`` and, through ``offsets_ffn_heads``, the initial
    component weights. ``pos_func`` / ``offset_pos_func`` / ``eps`` / ``inf`` /
    ``apply_log_softmax`` are presumably provided by the parent class.
    """

    def __init__(self, config, num_component, symmetry=True, base_models=None, **args) -> None:
        super().__init__(config, num_component, symmetry, base_models, **args)

    def build_models(self, config, num_component, symmetry=True, base_models=None, **args):
        """Create (or reuse from ``base_models``) the sub-modules of this model.

        ``trans_base`` and ``local_embedding_ffn`` may be shared through
        ``base_models``; the two heads are always created fresh.
        """
        if base_models is not None:
            self.trans_base = base_models["trans_base"]
            self.local_embedding_ffn = base_models["local_embedding_ffn"]
        else:
            self.trans_base = transformers.GPT2LMHeadModel(config)
            # One H'-dimensional embedding per (token, component) pair.
            self.local_embedding_ffn = nn.Linear(
                config.hidden_size, config.vocab_size * num_component * config.local_embedding_size)

        # NOTE(review): trans_ffn_heads is not used by the methods of this class.
        self.trans_ffn_heads = nn.Linear(config.local_embedding_size, config.local_embedding_size) # R=(HW)^T (HW)
        self.offsets_ffn_heads = nn.Linear(config.local_embedding_size, 1) # P0=HD

    def get_trans_matrix_from_base(self, input_ids, labels, attention_mask, cache_local_embeddings=None, cache_hidden_states=None, generation=False):
        """Build the per-(position, token) rate matrices and eigen-decompose them.

        Args:
            input_ids, labels, attention_mask: forwarded to ``trans_base`` when no
                cache is supplied.
            cache_local_embeddings: optional precomputed local embeddings, either
                ``[B, L, V*K*H']`` or ``[B, L, V, K, H']``.
            cache_hidden_states: optional precomputed last-layer hidden states.
            generation: if True only the last position is decomposed and its
                eigen-pairs are repeated over all ``L`` positions.

        Returns:
            ``(eig_value [B*L*V, K], eig_vector [B*L*V, K, K], rates_matrix)``;
            ``rates_matrix`` is ``[B*L*V, K, K]``, or ``[B*V, K, K]`` (last
            position only) when ``generation`` is True.
        """
        if cache_local_embeddings is None:
            if cache_hidden_states is None:
                outputs = self.trans_base.forward(input_ids = input_ids, labels = labels, \
                        attention_mask = attention_mask, output_hidden_states=True)
                hidden_states = outputs.hidden_states[-1]
            else:
                hidden_states = cache_hidden_states
            local_embeddings = self.local_embedding_ffn(hidden_states) # [B, L, V*K*H']
        else:
            local_embeddings = cache_local_embeddings # [B, L, V, K, H']

        if generation:
            L = local_embeddings.size(1) # Real len of sequences
            local_embeddings = local_embeddings[:, -1:, :]

        # reshape rather than view: the generation slice above is not contiguous
        # when L > 1, and .view() would raise for batch sizes > 1.
        local_embeddings = local_embeddings.reshape(-1, self.num_component, self.config.local_embedding_size) # [B*L*V, K, H']
        rates_matrix = torch.bmm(local_embeddings, local_embeddings.transpose(-2, -1)) # [B*L*V, K, K]
        rates_matrix = self.pos_func(rates_matrix)

        # E @ E^T is symmetric, which is what eigh requires.
        assert torch.all(rates_matrix == rates_matrix.transpose(-2, -1))

        eig_value, eig_vector = torch.linalg.eigh(rates_matrix)

        assert ~torch.any(torch.isnan(eig_value))
        assert ~torch.any(torch.isnan(eig_vector))

        if generation:
            K, V = self.num_component, self.config.vocab_size
            # Reuse the last position's decomposition for every position.
            eig_value = eig_value.view(-1, 1, V, K).repeat(1, L, 1, 1).view(-1, K)
            eig_vector = eig_vector.view(-1, 1, V, K, K).repeat(1, L, 1, 1, 1).view(-1, K, K)

        return eig_value, eig_vector, rates_matrix

    def get_initial_prob(self, input_ids, labels, attention_mask, cache_local_embeddings=None, cache_hidden_states=None):
        """Return the initial component weights, shape ``[B*L*V, K]``.

        ``offset_pos_func`` is applied to the ``[B*L, V, K]`` view of the raw
        head outputs.
        """
        if cache_local_embeddings is None:
            if cache_hidden_states is None:
                outputs = self.trans_base.forward(input_ids = input_ids, labels = labels, \
                        attention_mask = attention_mask, output_hidden_states=True)
                cache_hidden_states = outputs.hidden_states[-1]
            local_embeddings = self.local_embedding_ffn(cache_hidden_states) # [B, L, V*K*H']
        else:
            local_embeddings = cache_local_embeddings # [B, L, V, K, H']

        # Split off the H' axis first: offsets_ffn_heads is Linear(H', 1), so
        # applying it to the flat [B, L, V*K*H'] output would be a size mismatch.
        local_embeddings = local_embeddings.reshape(-1, self.num_component, self.config.local_embedding_size) # [B*L*V, K, H']
        p0 = self.offsets_ffn_heads(local_embeddings).squeeze(-1) # [B*L*V, K]
        p0 = self.offset_pos_func(p0.view(-1, self.config.vocab_size, self.num_component))
        return p0.view(-1, self.num_component)

    def forward(self, input_time, **argv):
        """Compute per-token log-probabilities at ``input_time`` for the selected host.

        ``argv`` must contain ``input_ids``; ``labels``/``attention_mask`` and
        optional caches are forwarded. Flags ``return_rates_matrix``,
        ``return_init_prob`` and ``return_component_logits`` add entries to
        ``info_dict``. ``logits`` is None when ``config.data_property`` is not
        in ``argv``.
        """
        generation = argv.pop("generation", False)

        return_rates_matrix = argv.get("return_rates_matrix", False)
        return_init_prob = argv.get("return_init_prob", False)
        return_component_logits = argv.get("return_component_logits", False)

        caches = self.forward_cache_hidden_states(input_time, **argv)
        trans_cache_hidden_states = caches.get("trans_cache_hidden_states")
        offsets_cache_hidden_states = caches.get("offsets_cache_hidden_states")
        # Explicitly passed caches take precedence.
        if "trans_cache_hidden_states" in argv:
            trans_cache_hidden_states = argv.get("trans_cache_hidden_states")
        if "offsets_cache_hidden_states" in argv:
            offsets_cache_hidden_states = argv.get("offsets_cache_hidden_states")

        info_dict = {}

        time = discretize_time(
            input_time,
            one_step=False,
            normalize_time_a=self.config.normalize_time_a,
            normalize_time_b=self.config.normalize_time_b,
            discrete=False)
        input_ids = argv.get("input_ids")
        # Every beam of a sample shares that sample's time.
        beam_size = input_ids.size(0) // input_time.size(0)
        time = time.unsqueeze(1).repeat(1, beam_size).view(-1)

        B, L = input_ids.size()
        V = self.config.vocab_size
        K = self.num_component

        # [B*L*V, K], [B*L*V, K, K]
        eig_value, eig_vecs, rates_matrix = self.get_trans_matrix_from_base(
            input_ids, argv.get("labels"), argv.get("attention_mask"),
            cache_hidden_states=trans_cache_hidden_states, generation=generation, cache_local_embeddings=argv.get("cache_local_embeddings", None))

        init_prob = self.get_initial_prob(
            input_ids, argv.get("labels"), argv.get("attention_mask"), cache_hidden_states=offsets_cache_hidden_states,
            cache_local_embeddings=argv.get("cache_local_embeddings", None))

        # p(t) = U diag(exp(eig * t)) U^T p0, with eigenvectors as the columns of U.
        const = torch.bmm(eig_vecs.transpose(-2, -1), init_prob.unsqueeze(-1)).squeeze(-1) # [B*L*V, K]
        time = time.reshape(-1, 1, 1).expand(-1, L, V).reshape(-1, 1) # [B*L*V, 1]
        p = (const * torch.exp(time * eig_value)).unsqueeze(1) * eig_vecs # [B*L*V, K, K]
        p = torch.sum(p, dim=-1) # [B*L*V, K]
        p = p.view(B, -1, K) # [B, L*V, K]
        p = torch.clamp(p, min=self.eps, max=self.inf) # For numerical stability

        if return_component_logits:
            if not self.apply_log_softmax:
                info_dict["component_logits"] = torch.log(p).view(B, L, V, -1)
            else:
                info_dict["component_logits"] = torch.log_softmax(p.view(B, L, V, -1), dim=-2)

        if self.config.data_property in argv:
            host_label = argv[self.config.data_property].unsqueeze(1).repeat(1, p.size(1)).unsqueeze(-1).long() # [B, L*V, 1]
            p = torch.gather(p, -1, host_label).squeeze(-1).view(B, L, -1) # [B, L, V]

            if not self.apply_log_softmax:
                logits = torch.log(p)
            else:
                logits = torch.log_softmax(p, dim=-1) # [B, L, V]

            if return_rates_matrix:
                # rates_matrix is flat ([B*L*V, K, K], or [B*V, K, K] for the last
                # position only in generation): unflatten to [B, L, V, K, K]
                # before picking each position's input token.
                rates_full = rates_matrix.reshape(B, -1, V, K, K).expand(-1, L, -1, -1, -1)
                token_index = input_ids.view(B, L, 1, 1, 1).expand(-1, -1, -1, K, K)
                info_dict["rates_matrix"] = torch.gather(rates_full, 2, token_index).squeeze(2) # [B, L, K, K]

            if return_init_prob:
                init_prob_ = init_prob.view(B, L, V, K) # [B, L, V, K]
                # Token ids index the vocabulary axis (dim 2), not the positions.
                token_index = input_ids.view(B, L, 1, 1).expand(-1, -1, -1, K)
                info_dict["init_prob"] = torch.gather(init_prob_, 2, token_index).squeeze(-2) # [B, L, K]

        else:
            logits = None
        return GPTOutputs(logits=logits, info_dict=info_dict)

@register_model("gpt2_time_multi_hosts")
class GPT2TimeMultiHosts(LanguageModelingTransformer):
    def __init__(self, config, alphabet) -> None:
        """Store the run config and alphabet, then build the transformer task.

        ``config``/``alphabet``/``pad_idx`` are assigned before the parent
        constructor runs, presumably because it calls back into
        ``initialize_model`` (which reads ``self.config``) — confirm against
        the base class.
        """
        self.config = config
        self.alphabet = alphabet
        self.pad_idx = alphabet.pad()
        backbone_kwargs = dict(
            pretrained_model_name_or_path=config.model_name_or_path,  # e.g. "gpt2"
            load_weights=config.load_weights,
            vocab_size=len(alphabet),  # the alphabet must already be built
            max_position_embeddings=config.max_position_embeddings,
            num_hidden_layers=config.num_hidden_layers,
            hidden_size=config.hidden_size,
        )
        super().__init__(**backbone_kwargs)
        
    def set_transformer_config(self, config):
        """Copy hyper-parameters from ``self.config`` onto the HF model ``config``.

        Required settings are read without a fallback (``AttributeError`` if
        missing); optional settings use the defaults listed below.
        """
        src = self.config

        # Required settings.
        for key in ("num_hidden_layers", "hidden_size",
                    "normalize_time_a", "normalize_time_b", "transformer_offset"):
            setattr(config, key, getattr(src, key))
        setattr(config, "data_property", src.data_properties[0])

        # Optional model / rate-function settings.
        for key, fallback in (
            ("share_base", False),
            ("output_layer_type", "linear"),
            ("offset_pos_function", "softmax"),
            ("pos_function", "softplus"),
            ("max_rate_value", 1e5),
            ("min_rate_value", 1e-12),
            ("eig_val_layer_norm", False),
            ("eig_vecs_layer_norm", False),
            ("add_trans_layer_norm", False),
        ):
            setattr(config, key, getattr(src, key, fallback))

        setattr(config, "padding_idx", self.alphabet.pad())

        # One "<property>_dict" entry per data property (required).
        setattr(config, "data_properties", src.data_properties)
        for prop in src.data_properties:
            dict_name = "%s_dict" % prop
            setattr(config, dict_name, getattr(src, dict_name))

        # Optional eigen-approximation / constraint / sampling settings.
        for key, fallback in (
            ("add_geo_info", False),
            ("lobpcg", False),
            ("lobpcg_k", 1),
            ("power_iteration", False),
            ("topk_eigen_values", 1),
            ("power_iteration_num", 100),
            ("power_gradient", False),
            ("positive_definite", False),
            ("negative_definite", False),
            ("topk_eigen", None),
            ("poisson_lambda", 4.0),
            ("poisson_sample_num", 32),
        ):
            setattr(config, key, getattr(src, key, fallback))


    def initialize_model(self, pretrained_model_name_or_path: str):
        """Create ``self.model`` according to ``self.config.implement_version``.

        The HF config is always built from the ``"gpt2"`` preset (the
        ``pretrained_model_name_or_path`` argument is not used here) and then
        extended by ``set_transformer_config``. The number of hosts is
        ``config.num_host`` if truthy, else the size of the first data
        property's dictionary. An unknown ``implement_version`` is logged as a
        warning, and no model is built in that case.
        """
        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path="gpt2", **self.model_data_kwargs
        )
        self.set_transformer_config(config)

        if self.config.num_host:
            num_host = self.config.num_host
        else:
            num_host = len(getattr(self.config, "%s_dict" % self.config.data_properties[0]))

        logging.info("num_host: %d" % num_host)

        # Extra config for the output layer; only versions 1 and 13 pass it on.
        output_layer_config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path="gpt2",
            vocab_size=config.vocab_size,
            max_position_embeddings=config.max_position_embeddings,
            hidden_size=config.hidden_size,
            num_hidden_layers=getattr(self.config, "output_layer_num_hidden_layers", self.config.num_hidden_layers)
        )

        version = self.config.implement_version
        if version == 1:
            if getattr(self.config, "power_iteration", False):
                self.model = GPT2TimeModelMultiHostsNew.from_config(
                    config, num_host, output_layer_config=output_layer_config,
                    add_global_model=self.config.add_global_model, aggregate_global_loss=self.config.aggregate_global_loss
                    )
            elif getattr(self.config, "add_global_model", False) or getattr(self.config, "aggregate_global_loss", False):
                self.model = GPT2TimeModelMultiHostsAndGlobal.from_config(
                    config, num_host, output_layer_config=output_layer_config,
                    add_global_model=self.config.add_global_model, aggregate_global_loss=self.config.aggregate_global_loss
                    )
            else:
                self.model = GPT2TimeModelMultiHosts.from_config(config, num_host, output_layer_config=output_layer_config)
        elif version == 2:
            self.model = GPT2TimeModelMultiHostsV2.from_config(config, num_host)
        elif version == 3:
            assert len(self.config.data_properties) == 1, "Could only set one property!"
            self.model = GPT2TimeModelMultiHostsParamShare.from_config(config, num_host)
        elif version == 4:
            self.model = GPT2TimeModelMultiHostsGeneralV2.from_config(config, num_host)
        elif version == 5: # non-symmetric rate matrix
            self.model = GPT2TimeModelMultiHostsGeneral.from_config(config, num_host, symmetry=getattr(self.config, "symmetry_rate_matrix", False))
        elif version == 6: # simple model
            self.model = GPT2TimeModelMultiHostsSimple.from_config(config, num_host)
        elif version == 7: # global model (a single host)
            self.model = GPT2TimeModelMultiHostsGlobal.from_config(config, 1)
        elif version == 8: # diagonal
            self.model = GPT2TimeModelMultiHostsDiag.from_config(config, num_host)
        elif version == 9:
            self.model = GPT2TimeModelMultiHostsPrepend.from_config(config, num_host)
        elif version == 10:
            self.model = GPT2TimeModelMultiHostsPrependV2.from_config(config, num_host)
        elif version == 11:
            self.model = GPT2TimeModelMultiHostsV2_new.from_config(config, num_host, symmetry=getattr(self.config, "symmetry_rate_matrix", True))
        elif version == 12:
            self.model = GPT2TimeModelMultiHostsEmbedding.from_config(config, num_host)
        elif version == 13:
            self.model = GPT2TimeModelMultiHostsNew.from_config(
                    config, num_host, output_layer_config=output_layer_config,
                    add_global_model=self.config.add_global_model, aggregate_global_loss=self.config.aggregate_global_loss
                    )
        elif version == 14: # + negative/positive definite
            self.model = GPT2TimeModelMultiHostsNew2.from_config(config, num_host)
        elif version == 15: # matrix exponential
            self.model = GPT2TimeModelMultiHostsMatrixExpSample.from_config(config, num_host)
        else:
            # Previously a silent fall-through that left self.model unset and
            # failed much later with an unrelated-looking AttributeError.
            logging.warning("Unknown implement_version %s: no model was built in initialize_model." % str(version))
        
             
    # def load_pretrained_model(self, path):
    #     pretrained_model_state_dict = torch.load(path)["state_dict"]
    #     for state in pretrained_model_state_dict:
    #         if state in self.state_dict():
    #             if self.state_dict()[state].size() != pretrained_model_state_dict[state].size():
    #                 logging.warning("The parameter %s of pretrained model (%s) doesn't fit the current model %s." % (state, str(pretrained_model_state_dict[state].size()), str(self.state_dict()[state].size())))
    #             else:
    #                 self.state_dict()[state].copy_(pretrained_model_state_dict[state])

    def configure_optimizers(self) -> Dict:
        """AdamW at ``config.learning_rate`` with a linear warmup/decay schedule.

        ``compute_warmup`` is asked for the total step count (``-1``) and a
        warmup of ``0.1`` (presumably a fraction of the total steps — see the
        base class). The scheduler is stepped after every optimizer step.
        """
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.learning_rate)
        total_steps, warmup_steps = self.compute_warmup(num_training_steps=-1, num_warmup_steps=0.1)
        scheduler_cfg = {
            "scheduler": transformers.get_linear_schedule_with_warmup(
                optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
            ),
            "interval": "step",
            "frequency": 1,
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_cfg}

    def on_after_backward(self) -> None:
        """Optionally discard a step whose gradients contain NaN or Inf.

        When ``config.zero_gradient_for_nan`` is truthy (True if the attribute
        is absent) and any parameter gradient has a NaN/Inf entry, all
        gradients are zeroed and the parent hook is skipped. Otherwise the
        parent hook runs as usual.
        """
        if not getattr(self.config, "zero_gradient_for_nan", True):
            return super().on_after_backward()

        # any() short-circuits on the first offending gradient.
        found_bad_grad = any(
            torch.isnan(param.grad).any() or torch.isinf(param.grad).any()
            for _, param in self.named_parameters()
            if param.grad is not None
        )
        if found_bad_grad:
            self.zero_grad()
            return None
        return super().on_after_backward()

    @classmethod
    def load_from_checkpoint(
        cls,
        checkpoint_path: Union[str, IO],
        map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
        hparams_file: Optional[str] = None,
        strict: bool = True,
        hf_pipeline_kwargs: Optional[Dict] = None,
        args = None,
        **kwargs
    ):
        """Load a checkpoint, then patch run-time paths and attribute overrides.

        ``hf_pipeline_kwargs`` is accepted but not forwarded to the parent
        loader. After loading:

        - ``config.resume_from_checkpoint`` is set to ``checkpoint_path``.
        - ``config.pred_data_paths`` is taken from ``args`` (``""`` if absent).
        - ``config.test_data_paths`` is taken from ``args`` when ``args`` is given.
        - Every extra keyword argument is set as an attribute on the model, and
          the change is logged.
        """
        model = super().load_from_checkpoint(checkpoint_path, map_location, hparams_file, strict)

        cfg = model.config
        cfg.resume_from_checkpoint = checkpoint_path
        cfg.pred_data_paths = getattr(args, "pred_data_paths", "")
        if args is not None:
            cfg.test_data_paths = args.test_data_paths

        for name, new_value in kwargs.items():
            old_value = getattr(model, name, None)
            logging.info("Overwrite model hyperparameter %s:" % name + ", from " + str(old_value) + " to " + str(new_value))
            setattr(model, name, new_value)
        return model

    @classmethod
    def add_argparse_args(cls, parent_parser):
        """Register this model's command-line options on ``parent_parser``.

        The options are added in place; the same parser object is returned.
        """
        add = parent_parser.add_argument

        # Backbone / checkpoint.
        add('--load_weights', action='store_true')
        add('--num_hidden_layers', type=int, default=12)
        add('--tau', type=float, default=1.0, help="Devide t by tau.")
        add('--hidden_size', type=int, default=768)
        add('--model_name_or_path', type=str, default="gpt2")
        add('--load_from_pretrain_checkpoint', type=str, default=None)

        # Time normalisation and extra inputs.
        add('--normalize_time_a', type=int, default=1,  help="t = (t-b)/a")
        add('--normalize_time_b', type=int, default=0, help="t = (t-b)/a")
        add('--add_location', action='store_true', help="Add the location information.")
        add('--add_lineage', action='store_true', help="Add the lineage information.")
        add('--count_mse_loss', action='store_true', help="Use the count mse loss instead of ce loss.")

        # Loss weighting and offset layer.
        add('--weight_loss_by_count', type=str2bool, default="true", help="Weight loss of each sample by their counting not frequency")
        add('--no_normalization_in_batch', action='store_true', help="Don't normalize the loss weight within the batch!!")
        add('--zero_offset', action='store_true', help="Set the sequences distribution at offset as 0")
        add('--offset_share_layer', type=int, default=-1, help="Use the hidden state at layer i to output the offset.")
        add('--transformer_offset', action='store_true', help="Use another transformer NN to predict the offset.")
        add('--second_order_rate', action='store_true', help="Add the second order rate in modeling.")
        add('--transformer_second_order_rate', action='store_true', help="Add the second order rate in modeling.")
        add('--output_token_losses', type=str2bool, default="false")

        # Decoding.
        add('--do_sample', type=str2bool, default="false")
        add('--temperature', type=float, default=1.0)
        add('--num_beams', type=int, default=1)
        add('--num_return_sequences', type=int, default=1)

        add('--zero_time', action='store_true', help="Set the time as zero.")
        add('--set_time', type=float, default=None)

        add('--ensemble', type=str2bool, default="false")
        add('--average_over_time', type=str2bool, default="false")

        # Model variant selection.
        add('--num_host', type=int, default=None)
        add('--implement_version', type=int, default=1)
        add('--independent', type=str2bool, default="false")

        add('--share_base', type=str2bool, default="false")
        add('--output_layer_type', type=str, default="linear", choices=["linear", "gpt2"])
        add('--output_layer_num_hidden_layers', type=int, default=2)

        add('--return_rates_matrix', type=str2bool, default="false")
        add('--return_init_prob', type=str2bool, default="false")

        add('--fail_retry_number', type=int, default=3)

        # Optional global model.
        add('--add_global_model', type=str2bool, default="false")
        add('--aggregate_global_loss', type=str2bool, default="false")
        add('--global_loss_w', type=float, default=1.0)
        add('--test_global_loss_w', type=float, default=0.6)

        add('--zero_gradient_for_nan', type=str2bool, default="false") # Don't use this! Problematic

        # Positivity functions and rate bounds.
        add('--offset_pos_function', type=str, default="softmax", choices=["softmax", "softplus", "relu", "none", "exp", "abs"])
        add('--pos_function', type=str, default="softplus", choices=["sigmoid", "softplus", "relu", "none", "exp", "abs"])
        add('--max_rate_value', type=float, default=1e5)
        add('--min_rate_value', type=float, default=1e-12)

        # Time-based loss weighting.
        add('--weight_loss_by_time', type=str2bool, default="false")
        add('--weight_loss_by_time_logistic_x0', type=float, default=None, help="f(x0)=0.5")
        add('--weight_loss_by_time_logistic_k', type=float, default=0.5, help="small k, smoother.")

        add('--debias_sample_weight', type=str2bool, default="false")

        add('--symmetry_rate_matrix', type=str2bool, default="true")

        add('--add_random_shuffle_loss', type=str, default=None, choices=["shuffle", "uniform"])
        add('--random_shuffle_loss_weight', type=float, default=0.5)

        add('--eig_vecs_layer_norm', type=str2bool, default="false")
        add('--eig_val_layer_norm', type=str2bool, default="false")
        add('--add_trans_layer_norm', type=str2bool, default="false")

        # Eigen-decomposition approximations.
        add('--lobpcg', type=str2bool, default="false")
        add('--topk_eigen_values', type=int, default=1)
        add('--power_iteration', type=str2bool, default="false")
        add('--power_gradient', type=str2bool, default="false")
        add('--power_iteration_num', type=int, default=100)

        # Constraints on the transition-rate matrix.
        add('--positive_definite', type=str2bool, default="false")
        add('--negative_definite', type=str2bool, default="false")
        # Keep only the top-k eigenvalues/eigenvectors.
        add('--topk_eigen', type=int, default=None)

        add('--poisson_lambda', type=float, default=4.0)
        add('--poisson_sample_num', type=int, default=32)

        return parent_parser
        
    def nll_loss(self, lm_logits, labels, loss_weight=None, reduce=True, ignore_bos=False):
        """Temperature-scaled next-token cross-entropy.

        Args:
            lm_logits: logits whose position ``i`` predicts ``labels[..., i + 1]``;
                they are divided by ``config.temperature``.
            labels: token ids; padding (and BOS when ``ignore_bos``) is excluded.
            loss_weight: optional per-sample weights, used only when ``reduce``.
                They are normalised to sum to 1 unless
                ``config.no_normalization_in_batch`` is set.
            reduce: if True return a scalar. Otherwise return the per-token loss,
                shaped like the shifted labels (0 at ignored positions).
            ignore_bos: also exclude BOS tokens from the targets.
        """
        # -100 is CrossEntropyLoss's default ignore_index.
        labels = labels.masked_fill(torch.eq(labels, self.alphabet.pad()), -100)
        if ignore_bos:
            labels = labels.masked_fill(torch.eq(labels, self.alphabet.bos()), -100)

        # Shift so that tokens < n predict n
        shift_logits = lm_logits[..., :-1, :].contiguous() / self.config.temperature
        shift_labels = labels[..., 1:].contiguous()

        # reduction="none" replaces the deprecated (and warning-emitting) reduce=False.
        loss_fct = CrossEntropyLoss(reduction="none")
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        loss = loss.view(shift_labels.size())

        if reduce:
            # Mean over the non-ignored targets of each sample. The clamp avoids
            # 0/0 = NaN for a sample whose targets are all ignored.
            num_targets = (shift_labels != -100).sum(dim=-1).clamp(min=1)
            loss = loss.sum(dim=-1) / num_targets # [B]

            if loss_weight is not None:
                if not self.config.no_normalization_in_batch:
                    loss_weight = loss_weight / loss_weight.sum()
                loss = torch.sum(loss * loss_weight)
            else:
                loss = loss.mean()
        return loss
    
    def count_mse_loss(self, lm_logits, labels, total_count, target_count, reduce=True):
        """Squared error between predicted and observed sequence counts.

        The model's sequence probability is the product of next-token
        probabilities over non-pad targets. It is multiplied by
        ``total_count``, compared with ``target_count``, and the squared error
        is divided by ``target_count.max() ** 2``. The loss is averaged when
        ``reduce`` is True.
        """
        next_logits = lm_logits[..., :-1, :].contiguous()
        next_targets = labels[..., 1:].contiguous()

        token_log_prob = torch.gather(
            torch.log_softmax(next_logits, dim=-1), -1, next_targets.unsqueeze(-1)
        ).squeeze(-1)
        non_pad = next_targets != self.alphabet.pad()
        seq_log_prob = (token_log_prob * non_pad).sum(-1)

        predicted_count = seq_log_prob.exp() * total_count
        loss = (predicted_count - target_count) ** 2 / (target_count.max() ** 2)
        return loss.mean() if reduce else loss

    def logistic_time_loss_weight(self, time):
        """Logistic loss weight ``1 / (1 + exp(-k * (time - x0)))``.

        ``k`` is read from ``config.weight_loss_by_time_logistic_k`` (fallback
        0.1; a smaller k gives a smoother curve). ``x0`` is read from
        ``config.weight_loss_by_time_logistic_x0`` (fallback 50), the time at
        which the weight is 0.5. A missing attribute or a value of ``None``
        uses the fallback. The argparse default for ``x0`` is ``None``.
        """
        # The two config keys were previously read into each other's variable.
        k = getattr(self.config, "weight_loss_by_time_logistic_k", None)
        if k is None:
            k = 0.1
        x0 = getattr(self.config, "weight_loss_by_time_logistic_x0", None)
        if x0 is None:
            x0 = 50
        return 1 / (1 + torch.exp(-k * (time - x0)))

    def forward(self, batch, batch_idx, reduce=True, mode="train"):
        """Compute the loss for one batch.

        Steps:
          1. Optionally overwrite ``batch["input_time"]`` in place
             (``config.zero_time`` / ``config.set_time``).
          2. Build loss weights from ``freq`` / ``bin_size``. In train/val mode
             they are optionally de-biased with dataset totals, and optionally
             multiplied by a logistic weight on ``input_time``.
          3. Run the model and compute the count-MSE loss or the NLL loss.
          4. Optionally add a loss under shuffled/random host components and/or
             a global-model loss.

        Args:
            batch: dict of batch tensors; mutated in place by step 1.
            batch_idx: index of the batch (unused here).
            reduce: passed through to the loss functions.
            mode: ``"train"``, ``"val"`` or ``"test"``. Controls weight
                de-biasing, loss mixing and BOS handling.

        Returns:
            ``(loss, loss_dict)``. ``loss_dict`` holds auxiliary loss terms and,
            when enabled in the config, rates matrices / initial probabilities.

        Raises:
            ValueError: if ``config.add_random_shuffle_loss`` is set to something
                other than ``"shuffle"`` or ``"uniform"``.
        """
        if getattr(self.config, "zero_time", False):
            batch["input_time"].fill_(0.)

        if getattr(self.config, "set_time", None) is not None:
            batch["input_time"].fill_(self.config.set_time)  # set time bin as a constant.

        freq = batch.get('freq', None)
        bin_size = batch.get('bin_size', None)
        if self.config.weight_loss_by_count and freq is not None and bin_size is not None:
            loss_weight = freq * bin_size
            # De-biasing needs dataset-level totals, which exist only for the
            # train and val splits. Other modes (e.g. "test") keep the raw
            # weights instead of failing on undefined totals.
            if getattr(self.config, "debias_sample_weight", False) and mode in ("train", "val"):
                datamodule = self.trainer.datamodule
                if mode == "train":
                    total_weights = datamodule.total_sample_count_train
                    total_num_of_samples = len(datamodule.train_dataset)
                else:
                    total_weights = datamodule.total_sample_count_valid
                    total_num_of_samples = len(datamodule.val_dataset)
                loss_weight = loss_weight / total_weights * total_num_of_samples / batch["input_ids"].size(0)
        elif not self.config.weight_loss_by_count and freq is not None:
            loss_weight = freq  # otherwise, weight by the frequency
        else:
            loss_weight = 1.0

        if getattr(self.config, "weight_loss_by_time", False):
            loss_weight = loss_weight * self.logistic_time_loss_weight(batch["input_time"])

        labels = batch["labels"]

        model_outputs = self.model(
            **batch, return_rates_matrix=self.config.return_rates_matrix,
            return_init_prob=self.config.return_init_prob, return_component_logits=True,
            test_global_loss_w=getattr(self.config, "test_global_loss_w", None),  # when add global model
        )

        if getattr(self.config, "count_mse_loss", False):
            loss = self.count_mse_loss(model_outputs.logits, labels, batch["bin_size"], batch["freq"] * batch["bin_size"], reduce=reduce)
        else:
            loss = self.nll_loss(model_outputs.logits, labels, loss_weight=loss_weight, reduce=reduce,
                                 ignore_bos=(mode == "test"))

        loss_dict = {}

        shuffle_mode = getattr(self.config, "add_random_shuffle_loss", None)
        if shuffle_mode is not None:
            logp = model_outputs.info_dict["component_logits"]  # [B, L, V, K]
            properties = batch[self.config.data_properties[0]]

            if shuffle_mode == "shuffle":
                # Permute the observed labels across the batch.
                shuffle_index = torch.randperm(len(properties)).to(properties.device)
                shuffle_properties = properties[shuffle_index]
            elif shuffle_mode == "uniform":
                # Draw labels uniformly from range(num_host).
                shuffle_properties = torch.randint(self.config.num_host, (properties.size(0),)).to(properties.device)
            else:
                raise ValueError("Unknown add_random_shuffle_loss mode: %r" % (shuffle_mode,))

            # Select, at every position/vocab entry, the component given by the shuffled label.
            shuffle_host_label = shuffle_properties.view(-1, 1, 1, 1).repeat(1, logp.size(1), logp.size(2), 1).long()  # [B, L, V, 1]
            shuffle_logp = torch.gather(logp, -1, shuffle_host_label).squeeze(-1)  # [B, L, V]
            shuffle_loss = self.nll_loss(shuffle_logp, labels, loss_weight=loss_weight, reduce=reduce)
            loss_dict["ori_loss"] = loss
            loss_dict["shuffle_loss"] = shuffle_loss

            # The shuffle term only contributes to the optimised loss during training.
            if mode == "train":
                shuffle_w = self.config.random_shuffle_loss_weight
                loss = (1 - shuffle_w) * loss + shuffle_w * shuffle_loss

        if getattr(self.config, "add_global_model", False) or getattr(self.config, "aggregate_global_loss", False):
            global_logits = model_outputs.info_dict["global_logits"]
            global_loss = self.nll_loss(global_logits, labels, loss_weight=loss_weight, reduce=reduce)
            loss_dict["global_loss"] = global_loss
            loss_dict["local_loss"] = loss
            if mode != "test":
                loss = loss + self.config.global_loss_w * global_loss

        if self.config.return_rates_matrix:
            loss_dict["rates_matrix"] = model_outputs.info_dict["rates_matrix"]
            loss_dict["input_ids"] = batch["input_ids"]  # output this for debugging
        if self.config.return_init_prob:
            loss_dict["init_prob"] = model_outputs.info_dict["init_prob"]
        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        """Run ``forward`` on a training batch and log the loss plus every auxiliary term."""
        total_loss, extras = self.forward(batch, batch_idx)
        self.log("train_loss", total_loss, prog_bar=True)
        for name, value in extras.items():
            self.log("train_%s" % name, value, prog_bar=True)
        return total_loss

    def validation_step(self, batch, batch_idx):
        """Run ``forward`` on a validation batch and log the loss plus every auxiliary term.

        ``mode="val"`` is passed explicitly. With the default ``"train"``,
        ``forward`` would de-bias sample weights with the *training* totals and
        mix the train-only shuffle loss into ``val_loss``.
        """
        loss, loss_dict = self.forward(batch, batch_idx, mode="val")
        self.log("val_loss", loss, prog_bar=True)
        for key, value in loss_dict.items():
            self.log("val_%s" % key, value, prog_bar=True)
        return loss
    
    def test_step(self, batch, batch_idx, dataloader_idx=0):
        loss, loss_dict = self.forward(batch, batch_idx, reduce=False, mode="test")
        token_num = torch.sum(
            (batch["labels"][..., 1:].contiguous() != self.alphabet.pad()) * (batch["labels"][..., 1:].contiguous() != self.alphabet.eos()), dim=-1)

        if "freq" in batch and "bin_size" in batch:
            weight = batch["freq"] * batch["bin_size"]
        else:
            weight = token_num.new_zeros(token_num.size(0)) + 1.0
        self.log("test_loss", loss.mean(), prog_bar=True)
        # for key in loss_dict:
            # self.log("test_%s" % key, loss_dict[key].mean(), prog_bar=True)
        return loss, token_num, weight, loss_dict

    def overwrite_generate_kwargs(self, new_config):
        """Copy generation/evaluation options from ``new_config`` onto ``self.config``.

        Options are copied in a fixed order and read as attributes of
        ``new_config``. A missing option raises ``AttributeError`` after the
        preceding ones have been copied.
        """
        copied_options = (
            "do_sample",
            "num_beams",
            "temperature",
            "num_return_sequences",
            "output_token_losses",
            "return_rates_matrix",
            "return_init_prob",
            "fail_retry_number",
            "test_global_loss_w",
        )
        for option in copied_options:
            setattr(self.config, option, getattr(new_config, option))

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Generate sequences conditioned on the prompts in ``batch``.

        Generation options come from ``self.config``:
          * ``temperature`` (default 1.0), ``do_sample`` (default True) and
            ``num_beams`` (default 1).
          * ``num_return_sequences``, raised to at least ``num_beams``.
          * ``generate_max_length``, falling back to ``max_position_embeddings``.
          * ``test_global_loss_w``, when the config has it.

        ``self.model.config.num_beams`` is updated as a side effect.

        Args:
            batch: dict passed to ``self.model.generate``; must contain
                ``input_ids`` and ``input_time``. If row 0 of ``input_ids`` ends
                in EOS, the last column is dropped (``batch`` is updated).

        Returns:
            A list with one dict per generated sequence:
            ``{"prediction": alphabet.string(ids), "src_time": <time scalar>}``.
        """
        # Both counts must be ints: float defaults (1.0) break ``Tensor.repeat``
        # below and are invalid for ``generate``.
        num_beams = int(getattr(self.config, "num_beams", 1))
        num_return_sequences = max(int(getattr(self.config, "num_return_sequences", 1)), num_beams)

        generate_kwargs = {
            "temperature": getattr(self.config, "temperature", 1.0),
            "do_sample": getattr(self.config, "do_sample", True),
            "num_beams": num_beams,
            "num_return_sequences": num_return_sequences,
        }
        setattr(self.model.config, "num_beams", num_beams)

        if hasattr(self.config, "test_global_loss_w"):
            generate_kwargs["test_global_loss_w"] = self.config.test_global_loss_w

        max_length = getattr(self.config, "generate_max_length", None)
        generate_kwargs["max_length"] = self.config.max_position_embeddings if max_length is None else max_length
        generate_kwargs["pad_token_id"] = self.alphabet.pad()
        generate_kwargs["eos_token_id"] = self.alphabet.eos()
        generate_kwargs["bos_token_id"] = self.alphabet.bos()

        # Drop a trailing EOS column from the prompts.
        # NOTE(review): only row 0 is inspected; presumably all prompts in a batch
        # have the same length -- confirm.
        if batch["input_ids"][0, -1].item() == self.alphabet.eos():
            batch["input_ids"] = batch["input_ids"][:, :-1]

        output_ids = self.model.generate(**batch, **generate_kwargs)

        # Assumes generate returns num_return_sequences consecutive rows per prompt
        # (Hugging Face convention); repeat each prompt's time to line up.
        input_time = batch["input_time"].unsqueeze(1).repeat(1, num_return_sequences).view(-1)
        return [
            {"prediction": self.alphabet.string(x), "src_time": input_time[i].item()}
            for i, x in enumerate(output_ids)
        ]
        
    def test_epoch_end(self, outputs):
        losses, token_nums, weights = [], [], []
        other_info = defaultdict(list)
        # print(len(outputs))
        if len(self.config.test_data_paths) == 1:
            outputs = [outputs]

        for dataloader_outputs in outputs:
            for output in dataloader_outputs:
                # outpu[0]: [B, L]
                losses.append(output[0].sum(-1)) # [B]
                token_nums.append(output[1])
                weights.append(output[2])
                
                for key in output[3]:
                    if isinstance(output[3][key], list) and isinstance(output[3][key][0], torch.Tensor):
                        other_info[key].extend([x.mean(dim=0).cpu() for x in output[3][key]])
                    
                    elif isinstance(output[3][key], torch.Tensor):
                        other_info[key].append(output[3][key].cpu())
                        # other_info[key].append(output[3][key].mean(1).cpu())
        losses = torch.cat(losses)
        token_nums = torch.cat(token_nums)
        weights = torch.cat(weights)
        # exit()
        # print("Sum of frequency", torch.exp(-losses).sum())
        # print(torch.sum(weights), torch.sum(token_nums * weights))
        # print(losses.size(), token_nums.size(), weights.size())
        # print(outputs[0]) # loss, token_num, weight
        # ppl1 = torch.exp(torch.sum(losses) / torch.sum(token_nums))
        # print(ppl1)
        # ppl2 = torch.exp(torch.sum(weights * losses) / torch.sum(weights * token_nums))
        # print(ppl2)
        ppl = torch.exp(torch.sum(losses * weights) / torch.sum(token_nums * weights))
        nll = torch.sum(weights * losses) / torch.sum(weights)
        # nll = torch.exp(torch.sum(losses * weights))
        # exit()
        # collate data:
        # outputs is a list of dict, or a list of list of dict (for multiple dataloaders)
        # loss = torch.cat(outputs)
        self.log_dict({"perplexity": ppl, "nll": nll, "coverage": torch.exp(-losses).sum()})

        if self.config.output_token_losses:
            self.all_outputs = []
            for dataloader_outputs in outputs:
                for output in dataloader_outputs:
                    # print(output[0].size())
                    # loss = 
                    self.all_outputs.extend([x for x in output[0]])
        else:
            self.all_outputs = []
            for loss, tok_num in zip(losses, token_nums):
                # info_dict["prediction"] = loss.item()
                self.all_outputs.append({"prediction": loss.item(), "token_num": tok_num.item()})

        # Output the other information
        if not os.path.exists(self.trainer.logger.log_dir):
            os.makedirs(self.trainer.logger.log_dir)
        for key in other_info:
            print(key, len(other_info[key]), other_info[key][0].size())
            # other_info[key] = torch.cat(other_info[key]).cpu() # in case the shapes are not the same
            # print(key, other_info[key].size())
            output_path = os.path.join(self.trainer.logger.log_dir, "%s.pkl" % key)
            # print(output_path)
            logging.info("Saving %s to %s" % (key, output_path))
            torch.save(other_info[key], output_path)

        return ppl
    
    def output_testing_results(self, outputs, predict_dataset):
        """Attach dataset metadata (``src_id``, ``src_time``, ``freq``) to test outputs.

        Args:
            outputs: one entry per test sample, in dataset order. Each entry is a
                sequence of token-loss tensors when ``config.output_token_losses``
                is set, otherwise a dict that is updated in place.
            predict_dataset: list of batches (lists) of record dicts. It is
                flattened and paired one-to-one with ``outputs``.

        Returns:
            List of result dicts, one per output.
        """
        flat_dataset = [record for group in predict_dataset for record in group]
        assert len(outputs) == len(flat_dataset)

        results = []
        for output_loss, record in zip(outputs, flat_dataset):
            if self.config.output_token_losses:
                output_dict = {"prediction": " ".join(str(token_loss.item()) for token_loss in output_loss)}
            else:
                output_dict = output_loss  # reuse (and extend) the caller's dict
            output_dict["src_id"] = record["src_id"]
            output_dict["src_time"] = record["src_time"]
            output_dict["freq"] = record["freq"]
            results.append(output_dict)
        return results

    def output_predicting_results(self, outputs, predict_dataset, *args, **kwargs):
        """Collect generated sequences and optionally write them as FASTA.

        Args:
            outputs: list of dicts from ``predict_step`` (each with a
                ``"prediction"`` string).
            predict_dataset: unused; kept for interface compatibility.
            *args: ``args[0]``, if given, is the output path. When it ends with
                ``.csv``, the generations are also written to a ``.fasta`` file
                with the same path stem (records ``>index`` / sequence).
            **kwargs: ignored.

        Returns:
            A new list holding the same dicts as ``outputs``.
        """
        results = list(outputs)

        output_path = args[0] if args else None
        if output_path is not None and output_path.endswith(".csv"):
            # Strip only the trailing suffix; splitting on ".csv" would cut at
            # the first occurrence anywhere in the path.
            fasta_path = output_path[: -len(".csv")] + ".fasta"
            logging.info("Writing generations to %s" % fasta_path)
            with open(fasta_path, "w", encoding="utf-8") as fout:
                for i, data in enumerate(results):
                    fout.write(">%d\n%s\n\n" % (i, data["prediction"]))

        return results


# For debugging
@register_model("gpt2_time_multi_hosts_debug")
class GPT2TimeMultiHostsDebug(GPT2TimeMultiHosts):
    """Debug variant of ``GPT2TimeMultiHosts`` that perturbs the host property.

    ``--debug_mode`` chooses how ``batch[config.data_properties[0]]`` is altered,
    or how per-component logits are scored, before the loss is computed.
    See ``forward``.
    """

    @classmethod
    def add_argparse_args(cls, parent_parser):
        """Add ``--debug_mode`` on top of the parent model's arguments."""
        parent_parser = super(GPT2TimeMultiHostsDebug, cls).add_argparse_args(parent_parser)
        parent_parser.add_argument('--debug_mode', type=str, choices=["shuffle_property", "fix_property", "enumerate_property", "random_property"])
        return parent_parser

    def forward(self, batch, batch_idx, reduce=True, mode="train"):
        """Compute the loss after applying the configured property perturbation.

        Debug modes (``config.debug_mode``):
          * ``shuffle_property``: permute the property values across the batch.
          * ``fix_property``: set every property value to 0 (in place).
          * ``random_property``: draw property values uniformly from
            ``range(config.num_host)``.
          * ``enumerate_property``: request per-component logits and score the
            labels under every component. Outside training only the loss
            under component 0 is returned.

        Args:
            batch: dict of batch tensors; mutated in place (times / properties).
            batch_idx: index of the batch (unused here).
            reduce: passed through to ``nll_loss``.
            mode: ``"train"``, ``"val"`` or ``"test"``.

        Returns:
            ``(loss, loss_dict)``. ``loss_dict`` holds ``global_loss`` and
            ``local_loss`` only when a global model is enabled.
        """
        if getattr(self.config, "zero_time", False):
            batch["input_time"].fill_(0.)

        if getattr(self.config, "set_time", None) is not None:
            batch["input_time"].fill_(self.config.set_time) # set time bin as a constant.

        # Perturb the first data property according to debug_mode.
        if self.config.debug_mode == "shuffle_property": # Shuffle the property
            properties = batch[self.config.data_properties[0]]
            shuffle_index = torch.randperm(len(properties)).to(properties.device)
            batch[self.config.data_properties[0]] = properties[shuffle_index]
        elif self.config.debug_mode == "fix_property": # Train on one property
            properties = batch[self.config.data_properties[0]]
            # fill_ is in place: the tensor stored in the caller's batch is overwritten with 0.
            properties.fill_(0)
            batch[self.config.data_properties[0]] = properties
        elif self.config.debug_mode == "random_property":
            properties = batch[self.config.data_properties[0]]
            # Replace each value with one drawn uniformly from range(num_host).
            random_properties = torch.randint(self.config.num_host, (properties.size(0), )).to(properties.device)
            batch[self.config.data_properties[0]] = random_properties

        # Per-component logits are only requested when every component is scored.
        if self.config.debug_mode == "enumerate_property":
            return_component_logits = True
        else:
            return_component_logits = False

        # NOTE(review): unlike the parent forward, test_global_loss_w is not passed to the model here.
        model_outputs = self.model(**batch, return_rates_matrix=self.config.return_rates_matrix, return_init_prob = self.config.return_init_prob, return_component_logits=return_component_logits)

        # NOTE(review): unlike the parent forward there is no fallback to 1.0. With
        # weight_loss_by_count set, a missing 'freq' or 'bin_size' (None) raises TypeError here.
        if self.config.weight_loss_by_count:
            loss_weight = batch.get('freq', None) * batch.get('bin_size', None)
        else: # otherwise, using the frequency
            loss_weight = batch.get('freq', None)
        
        if getattr(self.config, "weight_loss_by_time", False):
            loss_weight = loss_weight * self.logistic_time_loss_weight(batch["input_time"])

        if self.config.debug_mode == "enumerate_property":
            all_logits = model_outputs.info_dict["component_logits"]
            # Move the last (component) axis to the front: [..., K] -> [K, ...].
            all_logits = all_logits.view(-1, all_logits.size(-1)).transpose(0, 1).view(all_logits.size(-1), *all_logits.size()[:-1])
            # Repeat weights and labels once per component, then flatten the
            # component and batch axes so every (component, sample) pair is scored.
            loss_weight = loss_weight.unsqueeze(0).repeat(all_logits.size(0), 1)
            labels = batch["labels"].unsqueeze(0).repeat(all_logits.size(0), 1, 1)
            loss = self.nll_loss(all_logits.reshape(-1, *all_logits.size()[2:]), \
                labels.reshape(-1, *labels.size()[2:]), loss_weight=loss_weight.reshape(-1, *loss_weight.size()[2:]), reduce=reduce)
            if mode != "train":
                # Keep only the loss under component 0.
                # NOTE(review): this view presumably needs reduce=False (per-token losses);
                # a reduced scalar loss cannot be viewed as [K, B, -1] in general.
                loss = loss.view(*all_logits.size()[:2], -1) # [K, B, L]
                loss = loss[0]
                # Alternative: average the likelihood over components:
                # loss = -(torch.logsumexp(-loss, dim=0) - math.log(loss.size(0)))
        else:
            labels = batch["labels"]
            loss = self.nll_loss(model_outputs.logits, labels, loss_weight=loss_weight, reduce=reduce)

        loss_dict = {}

        if getattr(self.config, "add_global_model", False) or getattr(self.config, "aggregate_global_loss", False):
            global_logits = model_outputs.info_dict["global_logits"]
            # NOTE(review): in enumerate_property mode, labels/loss_weight were repeated per
            # component above and may not line up with global_logits -- confirm.
            global_loss = self.nll_loss(global_logits, labels, loss_weight=loss_weight, reduce=reduce)
            loss_dict["global_loss"] = global_loss
            loss_dict["local_loss"] = loss
            if mode != "test":
                loss = loss + self.config.global_loss_w * global_loss
            else:
                # Mixture NLL with w = test_global_loss_w:
                # -log((1 - w) * exp(-loss) + w * exp(-global_loss)).
                # math.log raises ValueError if w is exactly 0 or 1.
                loss = - (torch.logsumexp(torch.stack([-loss + math.log(1 - self.config.test_global_loss_w), -global_loss + math.log(self.config.test_global_loss_w)], dim=-1), axis=-1))

        return loss, loss_dict