import torch
import pytorch_lightning as pl
from lightning_transformers.task.nlp.language_modeling import (
    LanguageModelingDataModule,
    LanguageModelingTransformer,
)
# from pytorch_lightning.pytorch.utilities import grad_norm
import time
from lightning_transformers.task.nlp.translation import TranslationTransformer
import transformers
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from models import register_model
import math, logging
from models.criterions.cross_entropy import CEWeightedLoss
from esm import pretrained
from typing import IO, Any, Callable, Dict, Optional, Tuple, Type, Union
from utils.args import str2bool
from transformers import GPT2LMHeadModel
from transformers import AutoConfig, PreTrainedTokenizerBase
from collections import defaultdict, namedtuple
import os
from copy import deepcopy
import torch.nn.functional as F
from models.gpt2_multihosts import GPT2TimeMultiHosts
import pickle as pkl
from models.gpt2_multihosts import GPTOutputs, GPT2TimeModelMultiHosts, GPT2TimeModelMultiHostsV2, GPT2TimeModelMultiHostsParamShare, GPT2TimeModelMultiHostsSimple, GPT2TimeModelMultiHostsEmbedding, GPT2TimeModelMultiHostsDiag

class GPT2TimeModelMultiHostsHierarchyBase(transformers.GPT2LMHeadModel):
    """Abstract hierarchical (group -> sub-group) time-conditioned GPT-2 model.

    Subclasses implement the ``build_*`` hooks (called from ``__init__`` in a fixed
    order) and the ``get_*_outputs`` hooks; ``forward`` returns the sum of the local
    logits and the cross-block transmission logits.
    """

    def __init__(self, config, num_components, symmetry=True, **args) -> None:
        super().__init__(config)
        # Order matters: later builders may use attributes created by earlier ones
        # (e.g. the backbones created by build_base_model).
        for builder in (
            self.build_base_model,
            self.build_local_models,
            self.build_block_trans_models,
            self.build_other_models,
        ):
            builder(config, num_components)

    @classmethod
    def from_config(cls, config, num_components, **args):
        """Alternate constructor equivalent to ``cls(config, num_components, **args)``."""
        return cls(config, num_components, **args)

    def build_base_model(self, config, num_components):
        """Create the shared backbone model(s); must be overridden."""
        raise NotImplementedError()

    def build_other_models(self, config, num_components):
        """Create auxiliary models (for example, approximation models); must be overridden."""
        raise NotImplementedError()

    def build_local_models(self, config, num_components):
        """Default: one ``GPT2TimeModelMultiHosts`` per entry of ``num_components``."""
        per_group = [GPT2TimeModelMultiHosts(config, size) for size in num_components]
        self.local_models = nn.ModuleList(per_group)

    def build_block_trans_models(self, config, num_components):
        """Create the cross-block transmission parameters; must be overridden."""
        raise NotImplementedError

    def get_local_outputs(self, input_time, *args, **argv):
        """Return the within-group outputs; must be overridden."""
        raise NotImplementedError()

    def get_block_trans_outputs(self, input_time, *args, **argv):
        """Return the cross-block transmission outputs; must be overridden."""
        raise NotImplementedError

    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
        """Build the model inputs for one ``generate`` step.

        Copies ``input_time`` and every ``config.data_properties`` entry from
        ``kwargs`` and sets ``generation=True``. With beam search
        (``config.num_beams > 1``) each of those per-sequence tensors is repeated
        ``num_beams`` times along dim 0 (``[a, b] -> [a, a, ..., b, b, ...]``).
        """
        properties = self.config.data_properties
        inputs = {"input_ids": input_ids, "input_time": kwargs["input_time"]}
        inputs.update({prop: kwargs[prop] for prop in properties})
        inputs["generation"] = True

        num_beams = self.config.num_beams
        if num_beams > 1:
            batch_size = inputs[properties[0]].size(0)
            beam_index = torch.arange(batch_size).repeat_interleave(num_beams).to(input_ids.device)
            for key in (*properties, "input_time"):
                inputs[key] = inputs[key].index_select(0, beam_index)

        return inputs

    def get_country_long_index(self, countries, continents):
        """Map (continent, within-continent country) indices to a flat country index.

        Countries are numbered consecutively across continents in the order of
        ``self.num_components``.
        """
        offsets = torch.cumsum(torch.tensor([0] + self.num_components[:-1]), dim=0)
        return countries + offsets.to(countries.device)[continents]

    def forward(self, input_time, **argv):
        """Return ``GPTOutputs`` whose logits are local logits + cross-block logits."""
        local_logits = self.get_local_outputs(input_time, **argv).logits
        cross_block_logits = self.get_block_trans_outputs(input_time, **argv).logits
        return GPTOutputs(logits=local_logits + cross_block_logits)
        
class GPT2TimeModelMultiHostsHierarchy(GPT2TimeModelMultiHostsHierarchyBase):
    def __init__(self, config, num_components, symmetry=True, **args) -> None:
        """Pick the per-group model class and the positivity function, then build.

        Args:
            config: model config. ``implement_version`` (1, 2, 3, 4 or 8) selects the
                per-group model class. ``trans_w_pos_function`` (``"softplus"``,
                ``"relu"`` or ``"sigmoid"``) selects the function applied to the
                transmission weights. Unrecognised values leave the corresponding
                attribute (``_model_class`` / ``pos_func``) unset.
            num_components: number of countries in each continent.
            symmetry: forwarded to the base constructor.
        """
        # These two are read by the build_* hooks that the base constructor invokes,
        # so they must exist before super().__init__ (neither is an nn.Module).
        self.num_components = num_components
        chosen_class = {
            1: GPT2TimeModelMultiHosts,
            2: GPT2TimeModelMultiHostsV2,
            3: GPT2TimeModelMultiHostsParamShare,
            4: GPT2TimeModelMultiHostsEmbedding,
            8: GPT2TimeModelMultiHostsDiag,  # ablation: removes in-block transmissions
        }.get(config.implement_version)
        if chosen_class is not None:
            self._model_class = chosen_class

        super().__init__(config, num_components, symmetry, **args)

        # nn.Module attributes can only be assigned after nn.Module.__init__.
        pos_func_class = {
            "softplus": torch.nn.Softplus,
            "relu": torch.nn.ReLU,
            "sigmoid": torch.nn.Sigmoid,
        }.get(config.trans_w_pos_function)
        if pos_func_class is not None:
            self.pos_func = pos_func_class()

        self.eps = 1e-12
        
    def build_country_to_continent_indexing_matrix(self, num_components):
        """Return the 0/1 country-to-continent membership matrix.

        Countries are numbered consecutively across continents in the order of
        ``num_components``. Entry ``[i, j]`` is 1.0 iff flat country ``j`` belongs to
        continent ``i``.

        Args:
            num_components: number of countries in each continent.

        Returns:
            Float tensor of shape ``[len(num_components), sum(num_components)]``.
        """
        total_country_number = sum(num_components)
        total_continent_number = len(num_components)
        m = torch.zeros(total_continent_number, total_country_number)
        # Offsets come from the argument (not self.num_components), so the matrix
        # always matches the sizes it was asked to describe.
        start = 0
        for i, c in enumerate(num_components):
            m[i, start:start + c] = 1.0
            start += c
        return m

    def build_base_model(self, config, num_components):
        """Create the GPT-2 backbone(s) for the transition and offset parts.

        ``self.base_models`` maps ``"trans_base"`` and ``"offsets_base"`` to
        backbones. With ``config.transformer_offset`` they are two independent
        models. Otherwise a single backbone, registered as ``self._base``, serves
        both roles.
        """
        if not config.transformer_offset:
            self._base = transformers.GPT2LMHeadModel(config)
            trans_base = offsets_base = self._base
        else:
            # NOTE(review): these two backbones are held only in the plain dict
            # below, which does not register them as submodules of this model.
            # Presumably the local/continent models register them via
            # ``base_models`` -- verify.
            trans_base = transformers.GPT2LMHeadModel(config)
            offsets_base = transformers.GPT2LMHeadModel(config)
        self.base_models = {"trans_base": trans_base, "offsets_base": offsets_base}

    def build_local_models(self, config, num_components):
        """Build one country-level model per continent into ``self.local_models``.

        Each local model is a ``self._model_class`` built from a copy of ``config``
        whose ``data_property`` is ``config.data_properties[1]``, and it receives
        ``self.base_models``.

        When ``config.add_geo_info`` is set, per-country feature vectors are loaded
        from the pickle at ``config.geo_feats_path`` and passed as ``geo_info``.
        Countries missing from that map (and any NaN feature entries) are filled
        with the column-wise mean over the continent's fully-known countries.
        """
        country_model_config = deepcopy(config)
        country_model_config.data_property = config.data_properties[1]

        if not getattr(config, "add_geo_info", False):
            self.local_models = nn.ModuleList([
                self._model_class(country_model_config, num_component, base_models=self.base_models)
                for num_component in num_components
            ])
            return

        # NOTE(review): pickle.load can execute arbitrary code; geo_feats_path must
        # point to a trusted file.
        with open(config.geo_feats_path, "rb") as f:
            geo_feats_map = pkl.load(f)
        feats_size = len(next(iter(geo_feats_map.values())))

        local_models = []
        # NOTE(review): iteration order is assumed to match the continent indices /
        # the order of num_components -- confirm against the config producer.
        for _, countries in config.contient2country:
            countries_feats = torch.Tensor([
                geo_feats_map[country] if country in geo_feats_map else [torch.nan] * feats_size
                for country in countries
            ])
            known_rows = ~torch.isnan(torch.sum(countries_feats, dim=-1))
            ave_feats = torch.mean(countries_feats[known_rows], dim=0)  # [feats_size]
            # Element-wise fill, valid for zero, one or many missing countries.
            # A boolean-mask assignment of the single [feats_size] vector would only
            # broadcast when exactly feats_size entries are NaN.
            # If no country in the continent is known, ave_feats is NaN and so are
            # the filled entries.
            countries_feats = torch.where(torch.isnan(countries_feats), ave_feats, countries_feats)

            local_models.append(
                self._model_class(
                    country_model_config,
                    len(countries),
                    base_models=self.base_models,
                    geo_info=countries_feats,
                )
            )
        self.local_models = nn.ModuleList(local_models)

    def build_block_trans_models(self, config, num_components):
        """Create the parameters of the configured cross-block transmission model.

        ``config.block_trans_model`` selects the parameterisation:

        * ``"country_to_continent"``: per-country weight ``[hidden, D * C]`` and bias.
        * ``"continent_to_continent"``: per-continent weight ``[hidden, D * C]`` and
          bias, plus a per-country weight ``[hidden, D]`` and bias.
        * ``"country_to_country"``: per-country weight ``[hidden, V * N]`` and bias.
        * ``"prepend"``: a linear head ``hidden -> D * C``.
        * ``"embedding"``: a fixed continent/country membership matrix, plus a
          linear head producing ``block_trans_embedding_size``-dim embeddings per
          (country, vocab token).

        ``C`` is the number of continents, ``N`` the total number of countries and
        ``V`` the vocabulary size. ``D`` is ``V`` when
        ``trans_group_weight_rely_on_aa`` is set, else 1.

        A dedicated GPT-2 backbone with ``block_trans_model_n_layer`` layers is built
        unless ``reuse_transformer_for_cross_block_trans`` is set; the ``"prepend"``
        mode always builds it.

        Raises:
            NotImplementedError: for an unknown ``block_trans_model``.
        """
        cfg = self.config
        if not cfg.reuse_transformer_for_cross_block_trans or cfg.block_trans_model == "prepend":
            backbone_config = deepcopy(config)
            backbone_config.n_layer = config.block_trans_model_n_layer
            self.cross_block_trans_base_model = transformers.GPT2LMHeadModel(backbone_config)

        n_countries = sum(num_components)
        n_continents = len(num_components)
        hidden = config.hidden_size
        vocab = config.vocab_size
        D = vocab if cfg.trans_group_weight_rely_on_aa else 1

        def trainable(*shape):
            # Standard-normal initialised trainable parameter.
            return nn.Parameter(torch.randn(*shape), requires_grad=True)

        mode = cfg.block_trans_model
        if mode == "country_to_continent":
            self.country_to_continent_trans_weight = trainable(n_countries, hidden, D * n_continents)
            self.country_to_continent_trans_bias = trainable(n_countries, D * n_continents)
        elif mode == "continent_to_continent":
            self.continent_to_continent_trans_weight = trainable(n_continents, hidden, D * n_continents)
            self.continent_to_continent_trans_bias = trainable(n_continents, D * n_continents)
            self.continent_to_country_trans_weight = trainable(n_countries, hidden, D)
            self.continent_to_country_trans_bias = trainable(n_countries, D)
        elif mode == "country_to_country":
            self.country_to_country_trans_weight = trainable(n_countries, hidden, vocab * n_countries)
            self.country_to_country_trans_bias = trainable(n_countries, vocab * n_countries)
        elif mode == "prepend":
            # <continent> <country> ABC -> hidden -> D * C weights
            self.country_to_continent_trans_prepend = nn.Linear(hidden, D * n_continents)
        elif mode == "embedding":
            self.country_to_continent_indexing = nn.Parameter(
                self.build_country_to_continent_indexing_matrix(num_components), requires_grad=False
            )
            self.country_embeddings_size = config.block_trans_embedding_size
            self.country_embeddings = nn.Linear(
                hidden, n_countries * self.country_embeddings_size * vocab
            )
        else:
            raise NotImplementedError
    
    def block_trans_embeddings(self, hidden_states, continent, country, continent_logits, return_trans_weight=False):
        """Cross-continent contribution from bilinear country/continent embeddings ("embedding" mode).

        ``self.country_embeddings`` maps each position's hidden state to an
        H'-dim embedding per (vocab token, country). A continent's embedding is the
        sum of its countries' embeddings, obtained by multiplying with
        ``self.country_to_continent_indexing``. The weight from continent k to a
        sequence is the dot product of the sequence's own country embedding with
        continent k's embedding, passed through ``self.pos_func``. The sequence's
        own continent is masked out before weighting ``continent_logits``.

        Args:
            hidden_states: ``[B, L, hidden]`` features.
            continent: continent index of each sequence (B elements).
            country: country index within its continent (B elements).
            continent_logits: per-continent logits; expected to broadcast with
                ``[B, L, V, C]`` -- TODO confirm against the caller.
            return_trans_weight: if True, ``info`` also holds the unmasked weights.

        Returns:
            ``(cross_continent_logits [B, L, V], info)``. ``info`` always has
            ``"cross_group_trans_weight"`` (weights with the own continent zeroed)
            and, when requested, ``"trans_weight"``.
        """
        V = self.config.vocab_size
        B, L = hidden_states.size(0), hidden_states.size(1)
        H = self.country_embeddings_size

        # Flat (cross-continent) index of each sequence's country.
        num_countries_in_continents = torch.tensor([0] + self.num_components[:-1]).to(hidden_states.device)
        country_long_index = country + torch.cumsum(num_countries_in_continents, dim=0)[continent]

        all_country_embeddings = self.country_embeddings(hidden_states) # [B, L, H'*country_num*V]
        all_country_embeddings = all_country_embeddings.view(B, L, V, H, -1) # [B, L, V, H', country_num]

        # Sum country embeddings per continent via the 0/1 membership matrix (transposed to [country_num, continent_number]).
        all_contient_embeddings = torch.mm( # [B, L, V, H', continent_number]
            all_country_embeddings.view(-1, all_country_embeddings.size(-1)), self.country_to_continent_indexing.T).view(B, L, V, H, -1)

        # Pick each sequence's own country embedding: [B, L, V, H]
        country_embeddings = torch.gather(all_country_embeddings, dim=-1, index=country_long_index.view(-1, 1, 1, 1, 1).expand(B, L, V, H, 1)).squeeze(-1)

        # Dot product of the own-country embedding with every continent embedding.
        _trans_weight = torch.matmul(country_embeddings.view(-1, 1, H), 
                                    all_contient_embeddings.view(-1, H, all_contient_embeddings.size(-1))).view(B, L, V, -1) # [B, L, V, continent_number]
        _trans_weight = self.pos_func(_trans_weight) # keep the weights positive / non-negative

        # Zero out the sequence's own continent so only cross-continent terms remain.
        continent_mask = 1.0 - F.one_hot(continent, num_classes=len(self.num_components))
        continent_mask = continent_mask.to(hidden_states.device) # [B, continent_number]

        cross_continent_logits = torch.sum((continent_logits * _trans_weight) * continent_mask.view(B, 1, 1, -1), dim=-1) # [B, L, V]
        if return_trans_weight:
            return cross_continent_logits, {"trans_weight": _trans_weight,"cross_group_trans_weight": _trans_weight * continent_mask.view(B, 1, 1, -1)}
        return cross_continent_logits, {"cross_group_trans_weight": _trans_weight * continent_mask.view(B, 1, 1, -1)}

    def block_trans_continent_to_continent(self, hidden_states, continent, country, continent_logits, return_trans_weight=False):
        """Cross-continent contribution from continent-to-continent plus continent-to-country weights.

        The weight from source continent k to a sequence is the sum of:

        * a term selected by the sequence's continent (``continent_to_continent_trans_*``);
        * a term selected by its flat country index (``continent_to_country_trans_*``).

        Both terms are linear in ``hidden_states``, and the sum is passed through
        ``self.pos_func``. The sequence's own continent is masked out before
        weighting ``continent_logits``.

        Supports both per-token weights (``trans_group_weight_rely_on_aa``, D = V)
        and vocabulary-shared weights (D = 1); the latter broadcast over the
        vocabulary axis.

        Args:
            hidden_states: ``[B, L, hidden]`` features.
            continent: continent index of each sequence (B elements).
            country: country index within its continent (B elements).
            continent_logits: per-continent logits; expected to broadcast with
                ``[B, L, V, C]`` -- TODO confirm against the caller.
            return_trans_weight: if True, ``info`` also holds the intermediate weights.

        Returns:
            ``(cross_continent_logits [B, L, V], info)``. ``info`` always has
            ``"cross_group_trans_weight"`` and, when requested, also
            ``"continent_to_continent"``, ``"continent_to_country"`` and ``"trans_weight"``.
        """
        B, L = hidden_states.size(0), hidden_states.size(1)
        C = len(self.num_components)
        num_countries_in_continents = torch.tensor([0] + self.num_components[:-1]).to(hidden_states.device)
        # Flat (cross-continent) index of each sequence's country.
        country_long_index = country + torch.cumsum(num_countries_in_continents, dim=0)[continent]

        # [B, L, D * C]: weights from every source continent, chosen by the target continent.
        _cross_continents_trans_weight = torch.bmm(
            hidden_states, self.continent_to_continent_trans_weight[continent]
        ) + self.continent_to_continent_trans_bias[continent].unsqueeze(1)
        # [B, L, D]: adjustment chosen by the target country.
        _continents_to_country_trans_weight = torch.bmm(
            hidden_states, self.continent_to_country_trans_weight[country_long_index]
        ) + self.continent_to_country_trans_bias[country_long_index].unsqueeze(1)

        # Use -1 for the token axis (as block_trans_country_to_continent does) so that
        # D == 1 works too; for D == V this equals view(B, L, V, C) / view(B, L, V, 1).
        continent_to_continent = _cross_continents_trans_weight.view(B, L, -1, C)  # [B, L, D, C]
        _trans_weight = continent_to_continent + _continents_to_country_trans_weight.view(B, L, -1, 1)
        _trans_weight = self.pos_func(_trans_weight)  # keep the weights positive / non-negative

        # Zero out the sequence's own continent so only cross-continent terms remain.
        continent_mask = (1.0 - F.one_hot(continent, num_classes=C)).to(hidden_states.device).view(B, 1, 1, -1)
        cross_continent_logits = torch.sum((continent_logits * _trans_weight) * continent_mask, dim=-1)  # [B, L, V]
        cross_group_trans_weight = _trans_weight * continent_mask

        if return_trans_weight:
            info = {
                "continent_to_continent": continent_to_continent,
                "continent_to_country": _continents_to_country_trans_weight,
                "trans_weight": _trans_weight,
                "cross_group_trans_weight": cross_group_trans_weight,
            }
            return cross_continent_logits, info

        return cross_continent_logits, {"cross_group_trans_weight": cross_group_trans_weight}
    
    
    # def get_country_embeddings(self, country_long_index, hidden_states):
    #     country_embeddings_weight = self.country_embeddings_weight(country_long_index).view(-1, self.config.hidden_size, self.config.hidden_size) # [B, H, H']
    #     country_embeddings_bias = self.country_embeddings_weight(country_long_index).view(-1, self.config.hidden_size) # [B, H']
    #     country_embeddings = torch.sum(country_embeddings_weight.unsqueeze(1) * hidden_states.unsqueeze(-1), dim=-2) + country_embeddings_bias.unsqueeze(1) # [B, L, H']
    #     return country_embeddings

    def block_trans_country_to_country(self, hidden_states, continent, country, all_country_logits, return_trans_weight=False):
        """Cross-country contribution with per-country weights over every source country.

        Args:
            hidden_states: ``[B, L, hidden]`` features.
            continent: continent index of each sequence (B elements).
            country: country index within its continent (B elements).
            all_country_logits: logits for every country; expected to broadcast with
                ``[B, L, V, total_country_number]`` -- TODO confirm against the caller.
            return_trans_weight: accepted for interface parity. The weights are
                returned regardless.

        Returns:
            ``(cross_logits [B, L, V], {"trans_weight": weights})``.
        """
        vocab = self.config.vocab_size
        batch, length = hidden_states.size(0), hidden_states.size(1)
        group_sizes = torch.tensor([0] + self.num_components[:-1]).to(hidden_states.device)
        flat_country = country + torch.cumsum(group_sizes, dim=0)[continent]

        weight = self.country_to_country_trans_weight[flat_country]  # [B, H, V * total_country_number]
        bias = self.country_to_country_trans_bias[flat_country].unsqueeze(1)
        raw = (torch.bmm(hidden_states, weight) + bias).view(batch, length, vocab, -1)
        trans_weight = self.pos_func(raw)  # keep the weights non-negative

        # NOTE: transmissions within the same continent are not masked out here.
        cross_logits = (all_country_logits * trans_weight).sum(dim=-1)  # [B, L, V]
        return cross_logits, {"trans_weight": trans_weight}
    
    def block_trans_country_to_continent(self, hidden_states, continent, country, continent_logits, return_trans_weight=False):
        """Cross-continent contribution with per-country weights over source continents.

        The weights come from a per-country linear map of ``hidden_states`` and are
        passed through ``self.pos_func``. They are per vocab token when
        ``trans_group_weight_rely_on_aa`` is set (D = V), otherwise shared across
        the vocabulary (D = 1). The sequence's own continent is masked out before
        weighting ``continent_logits``.

        Args:
            hidden_states: ``[B, L, hidden]`` features.
            continent: continent index of each sequence (B elements).
            country: country index within its continent (B elements).
            continent_logits: per-continent logits; expected to broadcast with
                ``[B, L, V, C]`` -- TODO confirm against the caller.
            return_trans_weight: if True, ``info`` also holds the unmasked weights.

        Returns:
            ``(cross_continent_logits [B, L, V], info)``. ``info`` always has
            ``"cross_group_trans_weight"`` (weights with the own continent zeroed)
            and, when requested, ``"trans_weight"``.
        """
        B, L = hidden_states.size(0), hidden_states.size(1)
        C = len(self.num_components)
        num_countries_in_continents = torch.tensor([0] + self.num_components[:-1]).to(hidden_states.device)
        # Flat (cross-continent) index of each sequence's country.
        country_long_index = country + torch.cumsum(num_countries_in_continents, dim=0)[continent]

        # [B, L, hidden] x [B, hidden, D * C] -> [B, L, D * C]
        _trans_weight = torch.bmm(
            hidden_states, self.country_to_continent_trans_weight[country_long_index]
        ) + self.country_to_continent_trans_bias[country_long_index].unsqueeze(1)
        _trans_weight = self.pos_func(_trans_weight)  # keep the weights positive / non-negative
        _trans_weight = _trans_weight.view(B, L, -1, C)  # [B, L, V, C] or [B, L, 1, C]

        # Zero out the sequence's own continent so only cross-continent terms remain.
        continent_mask = (1.0 - F.one_hot(continent, num_classes=C)).to(hidden_states.device).view(B, 1, 1, -1)
        cross_continent_logits = torch.sum((continent_logits * _trans_weight) * continent_mask, dim=-1)  # [B, L, V]
        cross_group_trans_weight = _trans_weight * continent_mask

        if return_trans_weight:
            return cross_continent_logits, {
                "trans_weight": _trans_weight,
                "cross_group_trans_weight": cross_group_trans_weight,
            }
        return cross_continent_logits, {"cross_group_trans_weight": cross_group_trans_weight}

    def block_trans_country_to_continent_prepend(self, hidden_states, continent, country, continent_logits, return_trans_weight=False):
        """Cross-continent contribution for the "prepend" mode.

        The weights are a linear head (``country_to_continent_trans_prepend``) applied
        to hidden states that already saw the continent/country tokens. They pass
        through ``self.pos_func``, and the sequence's own continent is masked out.

        Args:
            hidden_states: ``[B, L, hidden]`` features.
            continent: continent index of each sequence (B elements).
            country: unused here; kept for a uniform block-transmission signature.
            continent_logits: per-continent logits; expected to broadcast with
                ``[B, L, V, C]`` -- TODO confirm against the caller.
            return_trans_weight: if True, ``info`` also holds the unmasked weights.

        Returns:
            ``(cross_continent_logits [B, L, V], info)``.
        """
        n_groups = len(self.num_components)
        batch, length = hidden_states.size(0), hidden_states.size(1)

        weights = self.pos_func(self.country_to_continent_trans_prepend(hidden_states))  # [B, L, D * C]
        weights = weights.view(batch, length, -1, n_groups)  # [B, L, V, C] or [B, L, 1, C]

        keep = (1.0 - F.one_hot(continent, num_classes=n_groups)).to(hidden_states.device)
        keep = keep.view(batch, 1, 1, -1)
        cross_logits = torch.sum(continent_logits * weights * keep, dim=-1)  # [B, L, V]
        masked_weights = weights * keep

        info = {"cross_group_trans_weight": masked_weights}
        if return_trans_weight:
            info = {"trans_weight": weights, "cross_group_trans_weight": masked_weights}
        return cross_logits, info

    def build_other_models(self, config, num_components):
        """Build ``self.continent_evolution_model``, the continent-level model.

        Skipped when ``block_trans_model == "country_to_country"``. The model
        conditions on ``config.data_properties[0]`` and has one component per
        continent. It shares ``self.base_models`` only when
        ``config.continent_share_base_models`` is set.
        """
        if self.config.block_trans_model == "country_to_country":
            return

        num_component_continent = len(num_components)
        base_models = self.base_models if config.continent_share_base_models else None

        continent_config = deepcopy(config)
        continent_config.data_property = config.data_properties[0]  # should be the continent
        continent_config.add_geo_info = False  # TODO: geo features for continents

        model_class = (
            GPT2TimeModelMultiHostsSimple
            if self.config.use_simple_continent_model
            else self._model_class
        )
        self.continent_evolution_model = model_class(
            continent_config, num_component_continent, base_models=base_models
        )

    def _split_batch(self, mask, input_time, **argv):
        """Select the batch rows where ``mask`` is True.

        Tensor values in ``argv`` are indexed with ``mask``; non-tensor values are
        passed through unchanged.

        Returns:
            ``(input_time[mask], sub_argv)``.
        """
        sub_argv = {
            key: value[mask] if isinstance(value, torch.Tensor) else value
            for key, value in argv.items()
        }
        return input_time[mask], sub_argv

    def get_continent_evolution_outputs(self, input_time, trans_cache_hidden_states=None, offsets_cache_hidden_states=None, **argv):
        """Run the continent-level model with ``return_component_logits=True``.

        The cached backbone hidden states are forwarded only when
        ``config.continent_share_base_models`` is set; otherwise they are dropped.
        """
        shared_cache = {}
        if self.config.continent_share_base_models:
            shared_cache = {
                "trans_cache_hidden_states": trans_cache_hidden_states,
                "offsets_cache_hidden_states": offsets_cache_hidden_states,
            }
        return self.continent_evolution_model.forward(
            input_time, **argv, **shared_cache, return_component_logits=True
        )

    def cross_block_trans_base_model_forward(self, argv, trans_cache_hidden_states):
        """Return the hidden states ``[B, L, H]`` used for cross-block transmission weights.

        * ``"prepend"`` mode runs ``cross_block_trans_base_model`` on
          ``<continent> <country> input_ids``. It then drops the two prepended
          positions, so position ``i`` of the result is the final-layer state
          after reading ``input_ids[:, i]``. This aligns it with the main backbone's
          hidden states.
        * Other modes reuse ``trans_cache_hidden_states`` when
          ``reuse_transformer_for_cross_block_trans`` is set. Otherwise they run
          the dedicated cross-block backbone on ``input_ids``.
        """
        if self.config.block_trans_model == "prepend":
            input_ids = argv["input_ids"]  # [B, L]
            continents = argv.get(self.config.data_properties[0])  # [B]
            countries = argv.get(self.config.data_properties[1])  # [B]
            # NOTE(review): the continent/country indices are fed as token ids and
            # therefore share the vocabulary's id space.
            prepend_input_ids = torch.cat(
                [continents.unsqueeze(1), countries.unsqueeze(1), input_ids], dim=1
            )  # [B, L + 2]

            attention_mask = argv.get("attention_mask")
            if attention_mask is not None:
                # The two prepended tokens are always attended to.
                attention_mask = torch.cat(
                    [attention_mask.new_ones(attention_mask.size(0), 2), attention_mask], dim=1
                )  # [B, L + 2]

            cross_block_hidden_states = self.cross_block_trans_base_model.forward(
                input_ids=prepend_input_ids,
                labels=prepend_input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
            ).hidden_states[-1]  # [B, L + 2, H]
            # The prepended tokens sit at the FRONT, so drop the first two positions.
            # Slicing [:, :-2] would drop the last two real tokens instead, and would
            # leave each position seeing only up to token i - 2.
            return cross_block_hidden_states[:, 2:, :]

        if self.config.reuse_transformer_for_cross_block_trans:
            return trans_cache_hidden_states

        return self.cross_block_trans_base_model.forward(
            input_ids=argv["input_ids"],
            labels=argv.get("labels"),
            attention_mask=argv.get("attention_mask"),
            output_hidden_states=True,
        ).hidden_states[-1]

    def forward(self, input_time, **argv):
        """Compute hierarchical (block -> country) next-token log-scores.

        Each sample is routed to the local (country-level) model of its block
        (``continents``), the block-level evolution model is run on the whole
        batch, and a cross-block transmission term is added on the probability
        scale: ``log(exp(country_logits) + cross_continent_logits + self.eps)``.

        Args:
            input_time: time input forwarded to the local and block-level
                models (its shape is defined by those models, not visible here).
            **argv: must contain ``input_ids`` and the two grouping keys named
                by ``self.config.data_properties`` (index 0 is compared against
                block ids ``0..len(self.num_components)-1``; index 1 is the
                country id). Optional keys read here: ``labels``,
                ``attention_mask``, ``return_rates_matrix``,
                ``return_init_prob``, ``return_cross_block_trans``.

        Returns:
            GPTOutputs whose ``info_dict`` holds ``continent_logits``,
            ``sum_of_country_logits`` and ``cross_group_trans_weight`` (gathered
            at the observed token when ``return_cross_block_trans``), plus
            ``country_trans_rates`` / ``country_init_probs`` when requested.

        Raises:
            ValueError: if ``self.config.block_trans_model`` is not supported.
        """
        return_rates_matrix = argv.get("return_rates_matrix", False)
        return_init_prob = argv.get("return_init_prob", False)
        return_cross_block_trans = argv.get("return_cross_block_trans", False)
        input_ids = argv["input_ids"]

        # 1. Shared transformer backbones (last hidden layer only).
        trans_cache_hidden_states = self.base_models["trans_base"].forward(
            input_ids=input_ids, labels=argv.get("labels"),
            attention_mask=argv.get("attention_mask"), output_hidden_states=True).hidden_states[-1]
        if self.config.transformer_offset:
            offsets_cache_hidden_states = self.base_models["offsets_base"].forward(
                input_ids=input_ids, labels=argv.get("labels"),
                attention_mask=argv.get("attention_mask"), output_hidden_states=True).hidden_states[-1]
        else:
            offsets_cache_hidden_states = trans_cache_hidden_states

        continents = argv.get(self.config.data_properties[0])
        countries = argv.get(self.config.data_properties[1])

        # 2. Country-level outputs: run each block's local model on its sub-batch.
        all_country_logits = []
        all_sum_of_country_logits = []
        all_country_trans_rates = []  # transition rate matrices, one per sample
        all_country_init_probs = []  # initial probabilities, one per sample
        indices = []  # original batch positions, in the order sub-batches are processed

        for i in range(len(self.num_components)):
            mask = continents == i
            if torch.sum(mask).item() <= 0:
                continue

            indices.append(torch.nonzero(mask))
            sub_input_time, sub_argv = self._split_batch(
                mask, input_time,
                trans_cache_hidden_states=trans_cache_hidden_states,
                offsets_cache_hidden_states=offsets_cache_hidden_states,
                **argv)
            country_outputs = self.local_models[i].forward(
                sub_input_time, **sub_argv, return_component_logits=True)  # [B, L, V, K]

            if return_rates_matrix:
                all_country_trans_rates.extend(list(country_outputs.info_dict["rates_matrix"]))
            if return_init_prob:
                all_country_init_probs.extend(list(country_outputs.info_dict["init_prob"]))

            component_logits = country_outputs.info_dict["component_logits"]
            # logsumexp over the countries of this block; with apply_log_softmax the
            # extra -log(K) turns the sum into a mean over the K components.
            all_sum_of_country_logits.append(
                torch.logsumexp(component_logits, dim=-1)
                - (math.log(component_logits.size(-1)) if getattr(self.config, "apply_log_softmax", False) else 0.0)
            )
            all_country_logits.append(country_outputs.logits)

        # Undo the block-wise grouping so everything follows the original batch order.
        indices = torch.cat(indices, dim=0).squeeze(-1)
        order = torch.argsort(indices)
        country_logits = torch.cat(all_country_logits, dim=0)[order]  # [B, L, V]
        sum_of_country_logits = torch.cat(all_sum_of_country_logits, dim=0)[order]

        seq_len = input_ids.size(1)
        if all_country_trans_rates:
            reordered = [all_country_trans_rates[j.item()] for j in order]
            # At each position keep only the slice (along dim 1) indexed by that
            # position's input token id.
            all_country_trans_rates = [
                torch.gather(
                    rates, 1,
                    input_ids[b].view(seq_len, 1, 1, 1).expand(-1, -1, rates.size(-2), rates.size(-1)),
                ).squeeze(1)
                for b, rates in enumerate(reordered)
            ]
        if all_country_init_probs:
            reordered = [all_country_init_probs[j.item()] for j in order]
            all_country_init_probs = [
                torch.gather(
                    probs, 1,
                    input_ids[b].view(seq_len, 1, 1).expand(-1, -1, probs.size(-1)),
                ).squeeze(1)
                for b, probs in enumerate(reordered)
            ]

        # 3. Block-level (continent) outputs on the full batch.
        continent_outputs = self.get_continent_evolution_outputs(
            input_time,
            trans_cache_hidden_states=trans_cache_hidden_states,
            offsets_cache_hidden_states=offsets_cache_hidden_states, **argv)
        continent_logits = continent_outputs.logits

        # 4. Transmission from other blocks into each sample's country.
        cross_block_hidden_states = self.cross_block_trans_base_model_forward(argv, trans_cache_hidden_states)

        block_trans_methods = {
            "continent_to_continent": "block_trans_continent_to_continent",
            "country_to_continent": "block_trans_country_to_continent",
            "embedding": "block_trans_embeddings",
            "prepend": "block_trans_country_to_continent_prepend",
        }
        method_name = block_trans_methods.get(self.config.block_trans_model)
        if method_name is None:
            raise ValueError(
                f"Unsupported block_trans_model: {self.config.block_trans_model!r}; "
                f"expected one of {sorted(block_trans_methods)}"
            )
        cross_continent_logits, cross_block_trans = getattr(self, method_name)(
            cross_block_hidden_states, continents, countries,
            continent_logits=torch.exp(continent_outputs.info_dict["component_logits"]),
            return_trans_weight=return_cross_block_trans)

        # Mix on the probability scale; eps keeps the log finite.
        total_logits = torch.log(torch.exp(country_logits) + cross_continent_logits + self.eps)

        info_dict = {
            "continent_logits": continent_logits,
            "sum_of_country_logits": sum_of_country_logits,
            "cross_group_trans_weight": cross_block_trans["cross_group_trans_weight"],
        }

        if return_cross_block_trans:
            # Reduce every entry to the value at the observed token (dim 2 is the vocab axis).
            for key in list(info_dict):
                value = info_dict[key]
                gather_size = list(value.size())
                gather_size[2] = 1
                index = input_ids.view(input_ids.size() + (1,) * (value.dim() - 2)).expand(*gather_size)
                info_dict[key] = torch.gather(value, 2, index).squeeze(2)

        if return_rates_matrix:
            info_dict["country_trans_rates"] = all_country_trans_rates
        if return_init_prob:
            info_dict["country_init_probs"] = all_country_init_probs

        return GPTOutputs(logits=total_logits, info_dict=info_dict)


class GPT2TimeModelMultiHostsHierarchyRemoveCrossBlocks(GPT2TimeModelMultiHostsHierarchy):
    """Hierarchical variant with no block-level or cross-block components.

    ``build_block_trans_models`` and ``build_other_models`` build nothing, and
    ``forward`` returns the country-level logits unchanged.
    """

    def __init__(self, config, num_components, symmetry=True, **args) -> None:
        super().__init__(config, num_components, symmetry, **args)

    def build_block_trans_models(self, config, num_components):
        # Intentionally empty: no cross-block transmission in this variant.
        return None

    def build_other_models(self, config, num_components):
        # Intentionally empty: no block-level / approximation model in this variant.
        return None

    def forward(self, input_time, **argv):
        """Return country-level log-scores only.

        Args:
            input_time: time input forwarded to the local models.
            **argv: must contain ``input_ids`` and the key named by
                ``self.config.data_properties[0]`` (block id per sample).
                Optional keys read here: ``labels``, ``attention_mask``,
                ``return_rates_matrix``, ``return_init_prob``.

        Returns:
            GPTOutputs with the country logits and an ``info_dict`` holding
            ``sum_of_country_logits`` (plus ``country_trans_rates`` /
            ``country_init_probs`` when requested).
        """
        return_rates_matrix = argv.get("return_rates_matrix", False)
        return_init_prob = argv.get("return_init_prob", False)
        input_ids = argv["input_ids"]

        trans_cache_hidden_states = self.base_models["trans_base"].forward(
            input_ids=input_ids, labels=argv.get("labels"),
            attention_mask=argv.get("attention_mask"), output_hidden_states=True).hidden_states[-1]
        if self.config.transformer_offset:
            offsets_cache_hidden_states = self.base_models["offsets_base"].forward(
                input_ids=input_ids, labels=argv.get("labels"),
                attention_mask=argv.get("attention_mask"), output_hidden_states=True).hidden_states[-1]
        else:
            offsets_cache_hidden_states = trans_cache_hidden_states

        continents = argv.get(self.config.data_properties[0])

        all_country_logits = []
        all_sum_of_country_logits = []
        all_country_trans_rates = []  # transition rate matrices, one per sample
        all_country_init_probs = []  # initial probabilities, one per sample
        indices = []  # original batch positions, in processing order

        for i in range(len(self.num_components)):
            mask = continents == i
            if torch.sum(mask).item() <= 0:
                continue

            indices.append(torch.nonzero(mask))
            sub_input_time, sub_argv = self._split_batch(
                mask, input_time,
                trans_cache_hidden_states=trans_cache_hidden_states,
                offsets_cache_hidden_states=offsets_cache_hidden_states,
                **argv)
            country_outputs = self.local_models[i].forward(
                sub_input_time, **sub_argv, return_component_logits=True)  # [B, L, V, K]

            if return_rates_matrix:
                all_country_trans_rates.extend(list(country_outputs.info_dict["rates_matrix"]))
            if return_init_prob:
                all_country_init_probs.extend(list(country_outputs.info_dict["init_prob"]))

            component_logits = country_outputs.info_dict["component_logits"]
            # With apply_log_softmax the -log(K) turns the logsumexp into a mean.
            all_sum_of_country_logits.append(
                torch.logsumexp(component_logits, dim=-1)
                - (math.log(component_logits.size(-1)) if getattr(self.config, "apply_log_softmax", False) else 0.0)
            )
            all_country_logits.append(country_outputs.logits)

        # Restore the original batch order.
        indices = torch.cat(indices, dim=0).squeeze(-1)
        order = torch.argsort(indices)
        country_logits = torch.cat(all_country_logits, dim=0)[order]  # [B, L, V]
        sum_of_country_logits = torch.cat(all_sum_of_country_logits, dim=0)[order]

        seq_len = input_ids.size(1)
        if all_country_trans_rates:
            reordered = [all_country_trans_rates[j.item()] for j in order]
            # At each position keep the slice (dim 1) indexed by the input token id.
            all_country_trans_rates = [
                torch.gather(
                    rates, 1,
                    input_ids[b].view(seq_len, 1, 1, 1).expand(-1, -1, rates.size(-2), rates.size(-1)),
                ).squeeze(1)
                for b, rates in enumerate(reordered)
            ]
        if all_country_init_probs:
            reordered = [all_country_init_probs[j.item()] for j in order]
            all_country_init_probs = [
                torch.gather(
                    probs, 1,
                    input_ids[b].view(seq_len, 1, 1).expand(-1, -1, probs.size(-1)),
                ).squeeze(1)
                for b, probs in enumerate(reordered)
            ]

        # Without cross-block terms the prediction is the country-level output alone.
        total_logits = country_logits

        info_dict = {"sum_of_country_logits": sum_of_country_logits}
        if return_rates_matrix:
            info_dict["country_trans_rates"] = all_country_trans_rates
        if return_init_prob:
            info_dict["country_init_probs"] = all_country_init_probs

        return GPTOutputs(logits=total_logits, info_dict=info_dict)


class GPT2TimeModelMultiHostsHierarchyRandomBlocks(GPT2TimeModelMultiHostsHierarchy):
    """Hierarchical variant whose blocks are random regroupings of all countries.

    All ``num_all_countries`` country embeddings are produced by one linear
    head. On every forward pass they are randomly permuted and split into
    ``num_blocks`` blocks of ``num_country_in_each_block`` countries each.
    """

    def __init__(self, config, num_components, symmetry=True, **args) -> None:
        # Set before super().__init__ because the build_* hooks below read them
        # (presumably invoked from the parent constructor, as in
        # GPT2TimeModelMultiHostsHierarchyBase.__init__ — confirm).
        self.num_blocks = getattr(config, "num_blocks", len(num_components))
        self.num_country_in_each_block = math.ceil(sum(num_components) / self.num_blocks)
        # Rounded up so every block has the same number of (possibly padded) countries.
        self.num_all_countries = self.num_country_in_each_block * self.num_blocks
        logging.info(
            "num_blocks: %d, num_country_in_each_block: %d, num_all_countries: %d",
            self.num_blocks, self.num_country_in_each_block, self.num_all_countries,
        )
        super().__init__(config, num_components, symmetry, **args)

    def build_base_model(self, config, num_components):
        """Build the shared GPT-2 backbone and the per-country embedding head."""
        if config.transformer_offset:
            raise NotImplementedError
        self._base = transformers.GPT2LMHeadModel(config)
        # One H-dim embedding per (token, country) at each position.
        self._local_embedding_ffn = nn.Linear(
            config.hidden_size, config.vocab_size * self.num_all_countries * config.block_trans_embedding_size)
        self.base_models = {
            "trans_base": self._base,
            "offsets_base": self._base,
            "local_embedding_ffn": self._local_embedding_ffn,
        }

    def build_block_trans_models(self, config, num_components):
        """Build the block-level evolution model over ``num_blocks`` components."""
        block_model_config = deepcopy(config)
        setattr(block_model_config, "data_property", config.data_properties[0])
        setattr(block_model_config, "local_embedding_size", config.block_trans_embedding_size)
        self.continent_evolution_model = GPT2TimeModelMultiHostsEmbedding(
            block_model_config, self.num_blocks, base_models=self.base_models)

    def build_local_models(self, config, num_components):
        """Build a single country-level model shared by all (random) blocks."""
        country_model_config = deepcopy(config)
        setattr(country_model_config, "data_property", config.data_properties[1])
        setattr(country_model_config, "local_embedding_size", config.block_trans_embedding_size)
        self.local_model = GPT2TimeModelMultiHostsEmbedding(
            country_model_config, self.num_country_in_each_block, base_models=self.base_models)

    def block_trans_embeddings(self, continent, continent_logits, country_embeddings, all_contient_embeddings, return_trans_weight=False):
        """Transmission into each sample's country from all *other* blocks.

        Args:
            continent: block index per sample, values in ``[0, num_blocks)``.
            continent_logits: per-block probabilities, broadcastable to
                ``[B, L, V, num_blocks]``.
            country_embeddings: ``[B, L, V, H]``.
            all_contient_embeddings: ``[B, L, V, H, num_blocks]``.
            return_trans_weight: also return the unmasked weights.

        Returns:
            ``(cross_continent_logits [B, L, V], dict)``. The dict always holds
            ``cross_group_trans_weight`` (own block zeroed) and, when requested,
            ``trans_weight``.
        """
        B, L, V, H = country_embeddings.size()
        _trans_weight = torch.matmul(
            country_embeddings.view(-1, 1, H),
            all_contient_embeddings.view(-1, H, all_contient_embeddings.size(-1)),
        ).view(B, L, V, -1)  # [B, L, V, num_blocks]
        _trans_weight = self.pos_func(_trans_weight)  # make weights positive

        # Zero out the sample's own block. Size must match the random blocks
        # (num_blocks), not the number of original continents.
        continent_mask = 1.0 - F.one_hot(continent, num_classes=self.num_blocks)
        continent_mask = continent_mask.to(continent.device)  # [B, num_blocks]
        continent_mask = continent_mask.view(B, 1, 1, -1)
        cross_continent_logits = torch.sum((continent_logits * _trans_weight) * continent_mask, dim=-1)  # [B, L, V]
        cross_group_trans_weight = _trans_weight * continent_mask
        if return_trans_weight:
            return cross_continent_logits, {"trans_weight": _trans_weight, "cross_group_trans_weight": cross_group_trans_weight}
        return cross_continent_logits, {"cross_group_trans_weight": cross_group_trans_weight}

    def forward(self, input_time, **argv):
        """Hierarchical forward pass with a fresh random block assignment.

        ``argv`` must contain ``input_ids`` and both keys named by
        ``self.config.data_properties``. Those two keys are replaced by the
        randomized block/country ids before the sub-models are called.
        ``return_rates_matrix`` / ``return_init_prob`` are not supported here.
        """
        return_cross_block_trans = argv.get("return_cross_block_trans", False)

        countries = argv.pop(self.config.data_properties[1])
        real_continents = argv.pop(self.config.data_properties[0])
        country_long_index = self.get_country_long_index(countries, real_continents)

        input_ids = argv["input_ids"]
        B, L = input_ids.size(0), input_ids.size(1)
        V, C, H = self.config.vocab_size, self.num_all_countries, self.config.block_trans_embedding_size
        K = self.num_country_in_each_block

        trans_cache_hidden_states = self.base_models["trans_base"].forward(
            input_ids=input_ids, labels=argv.get("labels"),
            attention_mask=argv.get("attention_mask"), output_hidden_states=True).hidden_states[-1]
        local_embeddings = self._local_embedding_ffn(trans_cache_hidden_states)  # [B, L, V*C*H]
        local_embeddings = local_embeddings.view(B * L * V, H, C)

        # Each sample's own country embedding: pick its column on the country
        # axis (the LAST dim) at every (L, V, H) position. Gathering on dim 1
        # would instead return the single scalar x[b, country, 0] everywhere.
        local_embeddings_in_country = torch.gather(
            local_embeddings.view(B, L * V * H, C), -1,
            country_long_index.view(B, 1, 1).expand(B, L * V * H, 1),
        ).view(B, L, V, H)

        # Random regrouping: position p of the permutation holds country
        # country_shuffle_index[p]; block = p // K, slot in block = p % K.
        country_shuffle_index = torch.randperm(self.num_all_countries).to(local_embeddings.device)
        reverse_index = torch.argsort(country_shuffle_index)
        country_long_index_shuffled = reverse_index[country_long_index]
        random_continent = country_long_index_shuffled // K
        random_country = country_long_index_shuffled % K

        # Reorder the country axis by the permutation, then split it into blocks.
        local_embeddings_shuffle = local_embeddings.view(-1, C).T[country_shuffle_index].T.view(B * L * V, H, C)
        local_embeddings_shuffle = local_embeddings_shuffle.view(B * L * V, H, self.num_blocks, K)
        local_embeddings_in_continent = torch.gather(
            local_embeddings_shuffle.view(B, L * V * H, self.num_blocks, K),
            -2, random_continent.view(B, 1, 1, 1).expand(B, L * V * H, 1, K),
        ).squeeze(-2)
        local_embeddings_in_continent = local_embeddings_in_continent.view(B, L, V, H, -1).transpose(-2, -1)  # [B, L, V, K, H]

        # Country-level outputs within the sample's random block.
        argv[self.config.data_properties[1]] = random_country
        country_outputs = self.local_model.forward(
            input_time, **argv, return_component_logits=True,
            cache_local_embeddings=local_embeddings_in_continent)
        component_logits = country_outputs.info_dict["component_logits"]
        sum_of_country_logits = torch.logsumexp(component_logits, dim=-1) - (
            math.log(component_logits.size(-1)) if getattr(self.config, "apply_log_softmax", False) else 0.0)

        # Block-level outputs; a block embedding is the mean of its countries' embeddings.
        block_embeddings = torch.mean(local_embeddings_shuffle, dim=-1).transpose(-2, -1).view(B, L, V, self.num_blocks, H)
        argv[self.config.data_properties[0]] = random_continent
        continent_outputs = self.get_continent_evolution_outputs(
            input_time, cache_local_embeddings=block_embeddings, **argv)
        continent_logits = continent_outputs.logits

        # Cross-block transmissions.
        cross_continent_logits, cross_block_trans = self.block_trans_embeddings(
            continent=random_continent,
            country_embeddings=local_embeddings_in_country,
            all_contient_embeddings=block_embeddings.transpose(-2, -1),
            continent_logits=torch.exp(continent_outputs.info_dict["component_logits"]),
            return_trans_weight=return_cross_block_trans)

        total_logits = torch.log(torch.exp(country_outputs.logits) + cross_continent_logits + self.eps)

        info_dict = {
            "continent_logits": continent_logits,
            "sum_of_country_logits": sum_of_country_logits,
            "cross_group_trans_weight": cross_block_trans["cross_group_trans_weight"],
        }

        if return_cross_block_trans:
            # Replace each cross-block entry by its value at the observed token (dim 2).
            for key, value in cross_block_trans.items():
                gather_size = list(value.size())
                gather_size[2] = 1
                index = input_ids.view(input_ids.size() + (1,) * (value.dim() - 2)).expand(*gather_size)
                info_dict[key] = torch.gather(value, 2, index).squeeze(2)

        return GPTOutputs(logits=total_logits, info_dict=info_dict)

    
class GPT2TimeModelMultiHostsHierarchyBruteForce(GPT2TimeModelMultiHostsHierarchy):
    """Hierarchical variant that gets block scores by running every local model.

    Instead of a separate block-level model, each block's local model is run on
    the full batch. Its per-country component logits are combined with
    logsumexp to form the block scores used by the cross-block term.
    """

    def __init__(self, config, num_components, symmetry=True, **args) -> None:
        super().__init__(config, num_components, symmetry, **args)

    def build_other_models(self, config, num_components):
        # No global model is needed: block scores come from the local models.
        pass

    def forward(self, input_time, **argv):
        """Compute log-scores from country-level and cross-block terms.

        ``argv`` must contain ``input_ids`` and both keys named by
        ``self.config.data_properties``. Optional keys: ``labels``,
        ``attention_mask``.

        Raises:
            ValueError: if ``self.config.block_trans_model`` is not supported.
        """
        input_ids = argv["input_ids"]
        trans_cache_hidden_states = self.base_models["trans_base"].forward(
            input_ids=input_ids, labels=argv.get("labels"),
            attention_mask=argv.get("attention_mask"), output_hidden_states=True).hidden_states[-1]
        if self.config.transformer_offset:
            offsets_cache_hidden_states = self.base_models["offsets_base"].forward(
                input_ids=input_ids, labels=argv.get("labels"),
                attention_mask=argv.get("attention_mask"), output_hidden_states=True).hidden_states[-1]
        else:
            offsets_cache_hidden_states = trans_cache_hidden_states

        continents = argv.get(self.config.data_properties[0])
        countries = argv.get(self.config.data_properties[1])

        all_country_logits = []
        all_all_country_logits = []
        all_sum_of_country_logits = []
        indices = []

        for i in range(len(self.num_components)):
            in_continent = continents == i
            indices.append(torch.nonzero(in_continent))
            # Samples outside this block get placeholder country 0; their rows of
            # `.logits` are discarded below.
            masked_countries = countries.new_zeros(countries.size())
            masked_countries[in_continent] = countries[in_continent]
            argv[self.config.data_properties[1]] = masked_countries

            country_outputs = self.local_models[i].forward(
                input_time, **argv, return_component_logits=True,
                trans_cache_hidden_states=trans_cache_hidden_states,
                offsets_cache_hidden_states=offsets_cache_hidden_states)  # [B, L, V, K]

            component_logits = country_outputs.info_dict["component_logits"]
            all_sum_of_country_logits.append(torch.logsumexp(component_logits, dim=-1))
            all_country_logits.append(country_outputs.logits[in_continent])
            all_all_country_logits.append(component_logits)  # [B, L, V, # countries in block i]

        indices = torch.cat(indices, dim=0).squeeze(-1)
        country_logits = torch.cat(all_country_logits, dim=0)[torch.argsort(indices)]  # [B, L, V]
        all_all_country_logits = torch.cat(all_all_country_logits, dim=-1)  # [B, L, V, total countries]
        continents_logits = torch.stack(all_sum_of_country_logits, dim=-1)  # [B, L, V, num blocks]

        # Transmission from other blocks into each sample's country.
        if self.config.reuse_transformer_for_cross_block_trans:
            cross_block_hidden_states = trans_cache_hidden_states
        else:
            cross_block_hidden_states = self.cross_block_trans_base_model.forward(
                input_ids=input_ids, labels=argv.get("labels"),
                attention_mask=argv.get("attention_mask"), output_hidden_states=True).hidden_states[-1]

        block_trans_model = self.config.block_trans_model
        if block_trans_model == "continent_to_continent":
            cross_continent_logits, _ = self.block_trans_continent_to_continent(
                cross_block_hidden_states, continents, countries,
                continent_logits=torch.exp(continents_logits))
        elif block_trans_model == "country_to_continent":
            cross_continent_logits, _ = self.block_trans_country_to_continent(
                cross_block_hidden_states, continents, countries,
                continent_logits=torch.exp(continents_logits))
        elif block_trans_model == "embedding":
            cross_continent_logits, _ = self.block_trans_embeddings(
                cross_block_hidden_states, continents, countries,
                continent_logits=torch.exp(continents_logits))
        elif block_trans_model == "country_to_country":
            cross_continent_logits, _ = self.block_trans_country_to_country(
                cross_block_hidden_states, continents, countries, torch.exp(all_all_country_logits))
        else:
            raise ValueError(
                f"Unsupported block_trans_model for {type(self).__name__}: {block_trans_model!r}")

        # eps keeps the log finite, matching the other hierarchical forwards.
        total_logits = torch.log(torch.exp(country_logits) + cross_continent_logits + self.eps)

        return GPTOutputs(logits=total_logits, info_dict={})

@register_model("gpt2_time_multi_hosts_hierarchy")
class GPT2TimeMultiHostsHierarchy(GPT2TimeMultiHosts):
    def __init__(self, config, alphabet) -> None:
        """Construct the task wrapper.

        No hierarchy-specific state is created here; construction is entirely
        delegated to the ``GPT2TimeMultiHosts`` parent with the same
        ``config`` and ``alphabet`` arguments.
        """
        super(GPT2TimeMultiHostsHierarchy, self).__init__(config, alphabet)

        
    def initialize_model(self, pretrained_model_name_or_path: str):
        """Create the hierarchical multi-host GPT-2 model and store it in ``self.model``.

        The Hugging Face config is always built from the ``"gpt2"`` preset plus
        ``self.model_data_kwargs``. ``pretrained_model_name_or_path`` is part of
        the overridable interface but is not read by this implementation.

        Selected hyper-parameters from ``self.config`` are copied onto that
        config. Optional ones fall back to a default; required ones raise
        ``AttributeError`` when missing. A second ``"gpt2"`` config is built for
        the output layer. The concrete model class is then chosen from the
        ``random_blocks`` / ``brute_force_mean_field`` /
        ``remove_cross_block_trans`` switches, and instantiated via ``from_config``.
        """
        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path="gpt2", **self.model_data_kwargs
        )
        hparams = self.config
        # Marker for options that have no fallback and must exist on ``hparams``.
        _required = object()

        # (option name, fallback) pairs copied from ``hparams`` onto ``config``,
        # in this order. Note the ``block_trans_model_n_layer`` fallback is read
        # eagerly, so ``hparams.num_hidden_layers`` must always exist.
        forwarded_options = (
            ("offset_pos_function", "softmax"),
            ("pos_function", "softplus"),
            ("trans_w_pos_function", "softplus"),
            ("max_rate_value", 1e5),
            ("min_rate_value", 1e-12),
            ("apply_log_softmax", False),
            ("normalize_time_a", _required),
            ("normalize_time_b", _required),
            ("transformer_offset", _required),
            ("data_properties", _required),
            ("share_base", False),
            ("output_layer_type", "linear"),
            ("block_trans_model", _required),
            ("block_trans_model_n_layer", hparams.num_hidden_layers),
            ("implement_version", _required),
            ("reuse_transformer_for_cross_block_trans", False),
            ("use_simple_continent_model", False),
            ("continent_share_base_models", True),
            ("block_trans_embedding_size", 128),
            # Include extra geographical features, for example, coordinates
            ("contient2country", _required),
            ("add_geo_info", False),
            ("geo_feats_path", None),
            ("trans_group_weight_rely_on_aa", True),
            ("prepend_cross_block_trans", False),
        )
        for option, fallback in forwarded_options:
            if fallback is _required:
                value = getattr(hparams, option)
            else:
                value = getattr(hparams, option, fallback)
            setattr(config, option, value)

        output_layer_config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path="gpt2",
            vocab_size=config.vocab_size,
            max_position_embeddings=config.max_position_embeddings,
            hidden_size=config.hidden_size,
            num_hidden_layers=getattr(
                hparams, "output_layer_num_hidden_layers", hparams.num_hidden_layers
            ),
        )

        # Block sizes: length of the second field of each ``contient2country``
        # entry (presumably the list of countries in that continent — confirm).
        components_per_block = [len(entry[1]) for entry in hparams.contient2country]

        if getattr(hparams, "random_blocks", False):
            model_cls = GPT2TimeModelMultiHostsHierarchyRandomBlocks
        elif hparams.brute_force_mean_field == True:  # noqa: E712 (kept as explicit equality)
            model_cls = GPT2TimeModelMultiHostsHierarchyBruteForce
        elif getattr(hparams, "remove_cross_block_trans", False):  # ablation
            model_cls = GPT2TimeModelMultiHostsHierarchyRemoveCrossBlocks
        else:  # approximation (default)
            model_cls = GPT2TimeModelMultiHostsHierarchy
        self.model = model_cls.from_config(
            config, components_per_block, output_layer_config=output_layer_config
        )
    # def load_pretrained_model(self, path):
    #     pretrained_model_state_dict = torch.load(path)["state_dict"]
    #     for state in pretrained_model_state_dict:
    #         if state in self.state_dict():
    #             if self.state_dict()[state].size() != pretrained_model_state_dict[state].size():
    #                 logging.warning("The parameter %s of pretrained model (%s) doesn't fit the current model %s." % (state, str(pretrained_model_state_dict[state].size()), str(self.state_dict()[state].size())))
    #             else:
    #                 self.state_dict()[state].copy_(pretrained_model_state_dict[state])

    
    @classmethod
    def load_from_checkpoint(
        cls,
        checkpoint_path: Union[str, IO],
        map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
        hparams_file: Optional[str] = None,
        strict: bool = True,
        hf_pipeline_kwargs: Optional[Dict] = None,
        args = None,
        **kwargs
    ):
        """Restore a model from ``checkpoint_path`` and patch run-time settings onto it.

        Args:
            checkpoint_path: checkpoint to load; also recorded on
                ``model.config.resume_from_checkpoint``.
            map_location: forwarded to the parent loader.
            hparams_file: forwarded to the parent loader.
            strict: NOTE(review): currently ignored. The parent loader is
                always called with ``strict=False`` (a debugging leftover), so
                mismatched state-dict keys are tolerated.
            hf_pipeline_kwargs: accepted for signature compatibility; unused here.
            args: optional namespace. ``pred_data_paths`` is copied onto the
                config (``""`` when absent or when ``args`` is None);
                ``test_data_paths`` is copied when ``args`` is not None.
            **kwargs: each ``key=value`` is set as an attribute on the returned
                model (not on ``model.config``), with the change logged.

        Returns:
            The restored model instance.
        """
        # TODO: honour ``strict`` once debugging is finished; it is hard-coded to False.
        model = super().load_from_checkpoint(checkpoint_path, map_location, hparams_file, strict=False)

        model.config.resume_from_checkpoint = checkpoint_path
        # getattr on None returns the default, so this is safe without ``args``.
        model.config.pred_data_paths = getattr(args, "pred_data_paths", "")
        if args is not None:
            model.config.test_data_paths = args.test_data_paths

        for key, value in kwargs.items():
            logging.info(
                "Overwrite model hyperparameter %s: from %s to %s",
                key, getattr(model, key, None), value,
            )
            setattr(model, key, value)
        return model

    @classmethod
    def add_argparse_args(cls, parent_parser):
        """Extend the parent's CLI parser with the hierarchy-model options.

        Boolean switches use ``type=str2bool`` with the string defaults
        ``'true'``/``'false'``. argparse applies ``type`` to string defaults,
        so those defaults are parsed exactly like user-supplied values.

        Returns:
            The parser returned by the parent implementation, with the options
            below registered in the listed order.
        """
        parser = super(GPT2TimeMultiHostsHierarchy, cls).add_argparse_args(parent_parser)

        def switch(default):
            # add_argument kwargs for a str2bool flag with a string default.
            return {"type": str2bool, "default": default}

        option_specs = [
            ("--block_trans_model", {
                "type": str,
                "default": "country_to_country",
                "choices": ["country_to_continent", "continent_to_continent",
                            "country_to_country", "embedding", "prepend"],
            }),
            ("--continent_loss_weight", {"type": float, "default": 0}),
            ("--reuse_transformer_for_cross_block_trans", switch('false')),
            ("--brute_force_mean_field", switch('false')),
            ("--random_blocks", switch('false')),
            ("--continent_loss_target", {
                "type": str, "default": 'label', "choices": ["label", "logits"],
            }),
            ("--block_trans_embedding_size", {"type": int, "default": 128}),
            ("--use_simple_continent_model", switch('false')),
            ("--continent_share_base_models", switch('true')),
            ("--apply_log_softmax", switch('false')),
            ("--cross_continent_reg", {"type": float, "default": 0.0}),
            ("--output_info", {
                **switch('false'),
                "help": "Output info like rate matrix in testing time.",
            }),
            ("--add_geo_info", switch('false')),
            ("--geo_feats_path", {"type": str, "default": ""}),
            ("--detach_continent_loss", switch('false')),
            ("--trans_group_weight_rely_on_aa", switch('true')),
            ("--remove_cross_block_trans", switch('false')),
            ("--trans_w_pos_function", {
                "type": str, "default": "softplus", "choices": ["softplus", "sigmoid", "relu"],
            }),
            ("--block_trans_model_n_layer", {"type": int, "default": 4}),
        ]
        for flag, options in option_specs:
            parser.add_argument(flag, **options)
        return parser
        
    def calc_continent_loss(self, continent_logits, labels, loss_weight, reduce):
        """Auxiliary loss on the continent-level prediction.

        ``self.config.continent_loss_target`` selects the objective:

        * ``"label"``: delegate to ``self.nll_loss`` against ``labels``.
        * ``"logits"``: per-sample mean over the flattened non-batch dims of
          either an MSE between ``continent_logits`` and ``labels``, or, when
          ``apply_log_softmax`` is set, ``exp(labels) * (labels - continent_logits)``
          (a KL-style divergence, assuming both are log-probabilities — confirm).
          With ``reduce`` the per-sample values are averaged, or weighted by
          ``loss_weight`` (normalised to sum 1 unless
          ``config.no_normalization_in_batch``). Without ``reduce`` the
          per-sample vector is returned.

        ``labels`` is detached first when ``config.detach_continent_loss`` is set.

        Raises:
            NotImplementedError: for any other ``continent_loss_target``.
        """
        if self.config.detach_continent_loss:
            labels = labels.detach()

        target_kind = self.config.continent_loss_target
        if target_kind == "label":
            return self.nll_loss(continent_logits, labels, loss_weight=loss_weight, reduce=reduce)
        if target_kind != "logits":
            raise NotImplementedError

        batch_size = continent_logits.size(0)
        assert continent_logits.size() == labels.size()
        predicted = continent_logits.view(batch_size, -1)
        target = labels.view(batch_size, -1)

        if getattr(self.config, "apply_log_softmax", False):
            per_sample = torch.mean(torch.exp(target) * (target - predicted), dim=-1)  # [B]
        else:
            per_sample = torch.mean((predicted - target) ** 2, dim=-1)  # [B]

        if not reduce:
            return per_sample
        if loss_weight is None:
            return per_sample.mean()
        if not self.config.no_normalization_in_batch:
            loss_weight = loss_weight / loss_weight.sum()
        return torch.sum(per_sample * loss_weight)

    def forward(self, batch, batch_idx, reduce=True, mode="train"):
        """Run the hierarchical model on ``batch`` and assemble the loss.

        Args:
            batch: mapping passed as keyword arguments to ``self.model``. It is
                also read for ``"labels"`` and the optional per-sample weight
                factors ``"freq"`` and ``"bin_size"``. Its ``"input_time"``
                entry is overwritten in place when ``config.zero_time`` or
                ``config.set_time`` is set.
            batch_idx: not used by this method.
            reduce: forwarded to ``nll_loss`` / ``calc_continent_loss``.
            mode: ``"test"`` skips the auxiliary losses and passes
                ``ignore_bos=True`` to the country NLL.

        Returns:
            ``(loss, loss_dict)``. ``loss_dict`` always has ``"country_loss"``.
            It may also have ``"continent_loss"`` and
            ``"cross_continent_reg_loss"``, plus every entry of the model's
            ``info_dict`` when ``config.output_info`` is set.
        """
        if getattr(self.config, "zero_time", False):
            batch["input_time"].fill_(0.)

        if getattr(self.config, "set_time", None) is not None:
            batch["input_time"].fill_(self.config.set_time)  # set time bin as a constant.

        output_info = self.config.output_info
        model_outputs = self.model(**batch,
                                   return_rates_matrix=output_info,
                                   return_init_prob=output_info,
                                   return_cross_block_trans=output_info)
        info_dict = model_outputs.info_dict

        # Per-sample weight = freq * bin_size. Each factor is optional: a missing
        # one counts as 1, and with neither present the loss is unweighted (None).
        # Previously a missing key produced ``None * tensor`` -> TypeError.
        loss_weight = None
        for weight_key in ("freq", "bin_size"):
            factor = batch.get(weight_key)
            if factor is not None:
                loss_weight = factor if loss_weight is None else loss_weight * factor

        labels = batch["labels"]
        country_loss = self.nll_loss(model_outputs.logits, labels, loss_weight=loss_weight,
                                     reduce=reduce, ignore_bos=(mode == "test"))
        loss_dict = {"country_loss": country_loss}
        loss = country_loss

        training = mode != "test"
        if training and self.config.continent_loss_weight > 0 and "continent_logits" in info_dict:
            if self.config.continent_loss_target == "logits":
                continent_target = info_dict["sum_of_country_logits"]
            else:
                # "label"; any unsupported target is rejected by
                # calc_continent_loss with NotImplementedError instead of
                # leaving ``continent_loss`` unbound here.
                continent_target = labels
            continent_loss = self.calc_continent_loss(
                info_dict["continent_logits"], continent_target, loss_weight, reduce)
            loss_dict["continent_loss"] = continent_loss
            loss = loss + self.config.continent_loss_weight * continent_loss

        if (training and getattr(self.config, "cross_continent_reg", 0.0) > 0.0
                and "cross_group_trans_weight" in info_dict):
            # L2 penalty (mean of squares) on the cross-group transition weights.
            cross_group_l2_loss = torch.mean(info_dict["cross_group_trans_weight"] ** 2)
            loss_dict["cross_continent_reg_loss"] = cross_group_l2_loss
            loss = loss + self.config.cross_continent_reg * cross_group_l2_loss

        if output_info:  # collect information for analysis
            loss_dict.update(info_dict)

        return loss, loss_dict

    def overwrite_generate_kwargs(self, new_config):
        """Apply the parent's generation-time overrides from ``new_config``.

        Afterwards, also copy ``new_config.output_info`` onto ``self.config``.
        """
        super().overwrite_generate_kwargs(new_config)
        self.config.output_info = new_config.output_info

        # setattr(self.config, "do_sample", new_config.do_sample)
    #     setattr(self.config, "num_beams", new_config.num_beams)
    #     setattr(self.config, "temperature", new_config.temperature)
    #     setattr(self.config, "num_return_sequences", new_config.num_return_sequences)
    #     setattr(self.config, "output_token_losses", new_config.output_token_losses)

    #     setattr(self.config, "return_rates_matrix", new_config.return_rates_matrix)
    #     setattr(self.config, "return_init_prob", new_config.return_init_prob)
