import json
import torch
from collections import OrderedDict
from transformers import AutoTokenizer, BartConfig
from hexa.models.modeling_r2c2 import R2C2ForConditionalGeneration


class PersonaSummarizer(R2C2ForConditionalGeneration):
    """R2C2 conditional-generation model that carries its own tokenizer.

    Keeping the tokenizer on the model lets `batch_respond` go from raw
    message dicts to decoded strings in a single call.
    """

    def __init__(self, config, tokenizer):
        """Initialise the underlying model from `config` and keep `tokenizer`."""
        super().__init__(config)
        self.tokenizer = tokenizer

    def batch_respond(
        self,
        messages,
        num_beams=1,
        do_sample=False,
        min_length=10,
        max_length=64,
        length_penality=0.65,     # same value used in parlai
        no_repeat_ngram_size=3,   # same value used in parlai
        decoder_start_token_id=2,
    ):
        """Generate one decoded output string per input message.

        Args:
            messages: A message (any object indexable by ``'text'``) or a
                list/tuple of such messages. A single message is treated as
                a batch of one.
            num_beams: Forwarded to `generate`.
            do_sample: Forwarded to `generate`.
            min_length: Forwarded to `generate`.
            max_length: Forwarded to `generate`.
            length_penality: Forwarded to `generate` as ``length_penalty``.
                The misspelled name is kept so existing keyword callers
                continue to work.
            no_repeat_ngram_size: Forwarded to `generate`.
            decoder_start_token_id: Token id used as the first decoder input
                for every sequence in the batch.

        Returns:
            A list of decoded strings, one per message, in input order.
            An empty batch yields an empty list.
        """
        if not isinstance(messages, (list, tuple)):
            messages = [messages]
        if not messages:
            # The tokenizer cannot build a tensor from an empty batch.
            return []

        bos_token = self.tokenizer.bos_token
        eos_token = self.tokenizer.eos_token
        texts = [bos_token + m['text'] + eos_token for m in messages]

        encoded = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            return_tensors='pt',
        )
        input_ids = encoded['input_ids'].to(self.device)
        bsz = input_ids.shape[0]

        # Pass the padding mask explicitly instead of relying on `generate`
        # to infer it from the pad token id; without a mask, padded positions
        # of shorter inputs in the batch may be attended to.
        extra_inputs = {}
        attention_mask = encoded.get('attention_mask')
        if attention_mask is not None:
            extra_inputs['attention_mask'] = attention_mask.to(self.device)

        # Every sequence starts decoding from the same single token.
        decoder_input_ids = torch.full(
            (bsz, 1),
            decoder_start_token_id,
            dtype=torch.long,
            device=self.device,
        )

        with torch.no_grad():
            summary_ids = self.generate(
                input_ids,
                num_beams=num_beams,
                do_sample=do_sample,
                min_length=min_length,
                max_length=max_length,
                length_penalty=length_penality,
                no_repeat_ngram_size=no_repeat_ngram_size,
                decoder_input_ids=decoder_input_ids,
                **extra_inputs,
            )

        return self.tokenizer.batch_decode(
            summary_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
    

def build_summarizer(opt):
    """Build a `PersonaSummarizer` loaded with converted ParlAI weights.

    The tokenizer, model config and checkpoint are read from fixed paths.
    The checkpoint's ``'model'`` state dict is renamed to the key layout
    expected by the model before being loaded.

    Args:
        opt: Mapping of options. If it contains ``'device_id'`` the model is
            moved to that device; otherwise ``cuda:0`` is used when CUDA is
            available and the CPU is used as a fallback.

    Returns:
        The model in eval mode on the selected device.

    Raises:
        KeyError: If the checkpoint has no ``embeddings.weight`` entry.
    """

    def _rename_state_dict(
        state_dict,
        offset,
        vocab_size,
        prefix='model',
        remove_str='seq2seq_'
    ):
        """Return a copy of `state_dict` with keys renamed for the model.

        Keys are lowercased, `remove_str` is stripped, the ``start`` entry is
        dropped, encoder/decoder keys receive `prefix`, ``ffn.lin`` becomes
        ``ffn_lin`` and ``embeddings.weight`` becomes ``lm_head.weight``.
        The embedding tensor is also stored as ``model.shared.weight`` and a
        zero ``final_logits_bias`` of shape ``[1, vocab_size]`` is added.

        `offset` is accepted for signature compatibility and is not used.
        """
        new_state_dict = OrderedDict()
        shared_embedding = None
        for k, v in state_dict.items():
            k = k.lower()
            if remove_str is not None:
                k = k.replace(remove_str, '')
            if k == 'start':
                continue
            # A single check so a key can never receive the prefix twice.
            if 'encoder.' in k or 'decoder.' in k:
                k = f'{prefix}.{k}'
            k = k.replace('ffn.lin', 'ffn_lin')
            if k == 'embeddings.weight':
                # Remember the tensor under its normalized name so the tied
                # copies below do not depend on the raw checkpoint key.
                shared_embedding = v
                k = 'lm_head.weight'
            new_state_dict[k] = v
        if shared_embedding is None:
            raise KeyError('embeddings.weight')
        new_state_dict['final_logits_bias'] = torch.zeros([1, vocab_size])
        new_state_dict['model.shared.weight'] = shared_embedding
        return new_state_dict

    if 'device_id' in opt:
        device = torch.device(opt['device_id'])
    else:
        # Prefer the first GPU, but stay usable on CPU-only machines.
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    tokenizer_path = 'hexa/models/tokenizer'
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    model_config_path = 'hexa/models/config/bb3_3b.json'
    model_config = BartConfig.from_json_file(model_config_path)
    # Clear generation-related settings stored in the config file.
    # NOTE(review): presumably so that only the arguments passed to
    # `generate` in `batch_respond` take effect - confirm.
    model_config.forced_eos_token_id = None
    model_config.force_bos_token_to_be_generated = None
    model_config.encoder_no_repeat_ngram_size = None
    model_config.no_repeat_ngram_size = None
    model_config.length_penalty = None

    model = PersonaSummarizer(model_config, tokenizer)

    parlai_model_path = '/models/bb3/persona_summarizer/model'
    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    parlai_state_dict = torch.load(parlai_model_path, map_location='cpu')
    state_dict = _rename_state_dict(
        parlai_state_dict['model'],
        offset=0,
        vocab_size=model_config.vocab_size,
    )
    model.load_state_dict(state_dict)
    model.eval().to(device)
    return model


if __name__ == '__main__':
    # Manual smoke test: build the summarizer on the first GPU and run one
    # batched call over two identical messages.
    config = dict(device_id='cuda:0')
    summarizer = build_summarizer(config)
    msg = [
        dict(
            text='My cats are very funny and cute, their names are Ross and Daniel.',
            episode_done=True,
        )
        for _ in range(2)
    ]
    res = summarizer.batch_respond(msg)
    print(res)
    # NOTE(review): no-op assignment, presumably kept as a convenient line
    # for a debugger breakpoint - confirm before removing.
    debug = True