import torch
import yaml 
import os 
import argparse
import re 
from transformers import AutoModelForSeq2SeqLM, AutoConfig, BartConfig

from transformers import PreTrainedTokenizerFast
from .model import load_model

def load_args(save_dir):
    """Read ``params.yaml`` from ``save_dir`` and expose its entries as attributes.

    Args:
        save_dir: Directory that contains the ``params.yaml`` file.

    Returns:
        An ``argparse.Namespace`` whose attributes are the top-level keys of
        the parsed YAML mapping.
    """
    params_path = os.path.join(save_dir, 'params.yaml')
    with open(params_path, 'r') as stream:
        loaded = yaml.safe_load(stream)
    # Unpack the mapping so every top-level key becomes a namespace attribute.
    return argparse.Namespace(**loaded)

def load_config(save_dir):
    """Read ``config.json`` from ``save_dir`` and expose its entries as attributes.

    The file is parsed with ``yaml.safe_load`` (not the ``json`` module).

    Args:
        save_dir: Directory that contains the ``config.json`` file.

    Returns:
        An ``argparse.Namespace`` whose attributes are the top-level keys of
        the parsed mapping.
    """
    json_path = os.path.join(save_dir, 'config.json')
    with open(json_path, 'r') as stream:
        loaded = yaml.safe_load(stream)
    # Unpack the mapping so every top-level key becomes a namespace attribute.
    return argparse.Namespace(**loaded)


def get_checkpoint_id(save_dir):
    """Return the highest checkpoint id found in ``save_dir``.

    Only directory entries whose name contains ``checkpoint-<digits>`` are
    considered. Entries that merely contain the word "checkpoint" (for
    example ``checkpoint_log.txt``) are ignored instead of crashing the
    lookup, and when several checkpoints exist the largest id is returned so
    the result does not depend on the arbitrary order of ``os.listdir``.

    Args:
        save_dir: Directory to scan for ``checkpoint-<id>`` entries.

    Returns:
        The largest ``<id>`` among the matching entries, as an ``int``.

    Raises:
        FileNotFoundError: If ``save_dir`` has no ``checkpoint-<id>`` entry
            (``os.listdir`` raises the same type when ``save_dir`` itself is
            missing).
    """
    matches = (re.search(r'checkpoint-(\d+)', entry) for entry in os.listdir(save_dir))
    checkpoint_ids = [int(m.group(1)) for m in matches if m is not None]
    if not checkpoint_ids:
        raise FileNotFoundError(
            f"no 'checkpoint-<id>' entry found in {save_dir!r}"
        )
    return max(checkpoint_ids)

def load_tokenizer(save_dir, cfg_name=None):
    """Load a ``PreTrainedTokenizerFast`` stored under ``save_dir``.

    Args:
        save_dir: Directory that holds the tokenizer file.
        cfg_name: Optional path, relative to ``save_dir``, of the tokenizer
            file to load. When ``None``, ``tokenizer.json`` is used.

    Returns:
        The tokenizer returned by ``PreTrainedTokenizerFast.from_pretrained``.
    """
    # Choose the file once. Previously the ``cfg_name`` tokenizer was loaded
    # and then unconditionally overwritten by the ``tokenizer.json`` one, so
    # ``cfg_name`` had no effect.
    tokenizer_file = 'tokenizer.json' if cfg_name is None else cfg_name
    return PreTrainedTokenizerFast.from_pretrained(os.path.join(save_dir, tokenizer_file))

def load_pretrained_model(save_dir, tokenizer, model_name, from_checkpoint=False):
    """Load a trained model from ``save_dir`` and put it on the GPU in eval mode.

    Args:
        save_dir: Directory containing ``config.json`` and
            ``model.safetensors``, either directly or inside a
            ``checkpoint-<id>/`` sub-directory.
        tokenizer: Accepted for interface compatibility; not used here.
        model_name: Accepted for interface compatibility; not used here.
        from_checkpoint: When true, read the files from the
            ``checkpoint-<id>/`` sub-directory, with the id resolved by
            ``get_checkpoint_id``; otherwise read them from ``save_dir``
            itself.

    Returns:
        The loaded model after ``eval()`` and ``cuda()`` have been called.
    """
    # NOTE(review): BartForPolynomialSystemGeneration is not imported in the
    # visible part of this file -- confirm where it is brought into scope.
    if from_checkpoint:
        rel_prefix = f'checkpoint-{get_checkpoint_id(save_dir)}/'
    else:
        rel_prefix = ''

    config_path = os.path.join(save_dir, f'{rel_prefix}config.json')
    weights_path = os.path.join(save_dir, f'{rel_prefix}model.safetensors')

    config = BartConfig.from_pretrained(config_path)
    # Only the non-checkpoint path back-fills the embedding-model setting
    # when the stored config lacks it.
    if not from_checkpoint and not hasattr(config, 'continuous_embedding_model'):
        config.continuous_embedding_model = 'ffn'

    model = BartForPolynomialSystemGeneration.from_pretrained(
        weights_path,
        config=config,
        use_safetensors=True,
    )

    # Inference mode, then move to the CUDA device (requires a GPU).
    model.eval()
    model.cuda()
    return model

def load_trained_bag(save_dir, from_checkpoint=False, model_name='bart+'):
    """Load the parameters, tokenizer and model saved in ``save_dir``.

    Args:
        save_dir: Directory produced by training; read by ``load_args``,
            ``load_tokenizer`` and ``load_pretrained_model``.
        from_checkpoint: Forwarded to ``load_pretrained_model`` to select the
            ``checkpoint-<id>/`` weights instead of the top-level ones.
        model_name: Forwarded to ``load_pretrained_model``.

    Returns:
        A dict with the keys ``'model'``, ``'params'`` and ``'tokenizer'``.
    """
    params = load_args(save_dir)
    # ``load_tokenizer`` takes no ``from_checkpoint`` argument; passing it
    # raised TypeError on every call. The tokenizer is always read from
    # ``save_dir`` itself.
    tokenizer = load_tokenizer(save_dir)
    model = load_pretrained_model(
        save_dir,
        tokenizer,
        model_name=model_name,
        from_checkpoint=from_checkpoint,
    )
    model.to_bettertransformer()

    return {'model': model, 'params': params, 'tokenizer': tokenizer}
