import sys 
sys.path.append("..")
from arguments import ModelArguments, DataArguments
from arguments import DenseTrainingArguments as TrainingArguments

from .BiGPTDenseRetriever import (BiGPTDenseModel, BiGPTDenseModelForInference)
from .BiGPTJDenseRetriever import (BiGPTJDenseModel, BiGPTJDenseModelForInference)
from .VanillaDenseRetriever import (VanillaDenseModel, VanillaDenseModelForInference)


from .TopDenseRetriever import (TopDenseModel, TopDenseModelForInference)

## ********************************************
## Debug
from .BiGPTJDenseRetriever_DropDim import (BiGPTJDenseModel_DropDim, BiGPTJDenseModelForInference_DropDim)

## ********************************************

from collections import OrderedDict

import logging
logger = logging.getLogger(__name__)

try:
    from opendelta import PrefixModel, BitFitModel
    _opendelta_available = True
except ModuleNotFoundError:
    _opendelta_available = False

def _build_dense_model(
    train_cls,
    inference_cls,
    model_args,
    data_args,
    training_args,
    config,
    tokenizer,
    do_train,
):
    """Build ``train_cls`` when ``do_train`` is true, otherwise ``inference_cls``.

    Only the training builder receives ``training_args`` (positionally, after
    ``model_args`` and ``data_args``). Both builders receive ``config``,
    ``tokenizer`` and ``model_args.cache_dir`` as keyword arguments.
    """
    if do_train:
        return train_cls.build(
            model_args,
            data_args,
            training_args,
            config=config,
            tokenizer=tokenizer,
            cache_dir=model_args.cache_dir,
        )
    return inference_cls.build(
        model_args=model_args,
        data_args=data_args,
        config=config,
        tokenizer=tokenizer,
        cache_dir=model_args.cache_dir,
    )


def get_network(
    model_args: ModelArguments,
    data_args: DataArguments,
    training_args: TrainingArguments,
    config: OrderedDict,
    tokenizer: OrderedDict,
    cache_dir: str,
    do_train: bool,
):
    """Build the dense-retrieval model selected by flags on ``model_args``.

    Selection order (first match wins):
      1. ``model_args.no_use_gpt``                -> Vanilla dense model
      2. ``model_args.top_gpt``                   -> Top dense model
      3. ``model_args.residual_encoder_name_or_path`` is truthy:
           * ``"gpt-j"`` occurs in ``model_args.model_name_or_path``
             -> BiGPTJ dense model
           * otherwise -> BiGPT (GPT2) dense model

    Within each family the training class is built when ``do_train`` is
    true; otherwise the corresponding ``...ForInference`` class is built.

    Side effects: sets ``tokenizer.pad_token`` / ``tokenizer.pad_token_id``
    to the tokenizer's EOS token and copies that id into
    ``model_args.pad_token_id``.

    Args:
        model_args: Model arguments; read for the selection flags above,
            ``model_name_or_path`` and ``cache_dir``.
        data_args: Forwarded to the selected builder.
        training_args: Forwarded only to the training builders.
        config: Model config forwarded to the selected builder.
        tokenizer: Tokenizer object exposing ``eos_token`` / ``eos_token_id``
            (NOTE(review): annotated ``OrderedDict`` but used as a tokenizer).
        cache_dir: Currently unused; builders receive ``model_args.cache_dir``.
        do_train: Build the training class if true, else the inference class.

    Returns:
        The object returned by the selected class's ``build``.

    Raises:
        ValueError: If no selection flag on ``model_args`` picks a model.
    """
    # Use EOS as the padding token and propagate its id to the model args so
    # the model builders can see it.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model_args.pad_token_id = tokenizer.pad_token_id

    build_kwargs = dict(
        model_args=model_args,
        data_args=data_args,
        training_args=training_args,
        config=config,
        tokenizer=tokenizer,
        do_train=do_train,
    )

    if model_args.no_use_gpt:
        logger.info("Vanilla-Transformers! %s", model_args.model_name_or_path)
        return _build_dense_model(
            VanillaDenseModel, VanillaDenseModelForInference, **build_kwargs
        )

    if model_args.top_gpt:
        logger.info("Top-Transformers! %s", model_args.model_name_or_path)
        return _build_dense_model(
            TopDenseModel, TopDenseModelForInference, **build_kwargs
        )

    # Residual (hybrid) GPT encoders: GPT-J vs. GPT2 chosen by model name.
    if model_args.residual_encoder_name_or_path:
        if "gpt-j" in model_args.model_name_or_path:
            logger.info(
                "Hybrid-Transformers! GPT-J pad token id: %s", model_args.pad_token_id
            )
            return _build_dense_model(
                BiGPTJDenseModel, BiGPTJDenseModelForInference, **build_kwargs
            )

        logger.info(
            "Hybrid-Transformers! GPT2 pad token id: %s", model_args.pad_token_id
        )
        return _build_dense_model(
            BiGPTDenseModel, BiGPTDenseModelForInference, **build_kwargs
        )

    # Without this, the function would implicitly return None and the caller
    # would fail later with a far less informative error.
    raise ValueError(
        "No model selected: set one of `no_use_gpt`, `top_gpt`, or "
        "`residual_encoder_name_or_path` on model_args "
        "(model_name_or_path={!r}).".format(model_args.model_name_or_path)
    )

    

def get_delta_model_class(model_type):
    """Return the OpenDelta delta-model class registered under ``model_type``.

    Args:
        model_type: Registry key, ``'bitfit'`` or ``'prefix'``.

    Returns:
        ``BitFitModel`` for ``'bitfit'``, ``PrefixModel`` for ``'prefix'``.

    Raises:
        ValueError: If the ``opendelta`` package could not be imported.
        KeyError: If ``model_type`` is not one of the supported keys; the
            message lists the valid choices.
    """
    if not _opendelta_available:
        raise ValueError(
            'Opendelta package not available. You can obtain it from https://github.com/thunlp/OpenDelta.')
    delta_models = {
        'bitfit': BitFitModel,
        'prefix': PrefixModel,
    }
    try:
        return delta_models[model_type]
    except KeyError:
        # Keep the KeyError type for existing callers, but make it actionable.
        raise KeyError(
            "Unknown delta model type {!r}; expected one of {}.".format(
                model_type, sorted(delta_models)
            )
        ) from None

