import os
import io
import sys
import argparse
import torch

from XLM.src.utils import AttrDict
from XLM.src.utils import bool_flag, initialize_exp
from XLM.src.data.dictionary import Dictionary
from XLM.src.model.transformer import TransformerModel

class Param:
    """Plain attribute container for experiment and translation settings.

    Extra attributes (e.g. dictionary indices, language ids) may be attached
    to an instance after construction.
    """

    def __init__(self, dump_path, exp_name, exp_id, batch_size, model_path, tgt_lang, src_lang, max_vocab=-1, min_count=0):
        self.dump_path = dump_path    # Experiment dump path
        self.exp_name = exp_name      # Experiment name
        self.exp_id = exp_id          # Experiment ID
        self.batch_size = batch_size  # Number of sentences per batch
        self.model_path = model_path  # Model path
        self.src_lang = src_lang      # Source language
        self.tgt_lang = tgt_lang      # Target language
        self.max_vocab = max_vocab    # Maximum vocabulary size (-1 to disable)
        # Minimum vocabulary count. (A stray trailing comma used to store this
        # as a 1-tuple instead of the value itself.)
        self.min_count = min_count


class Trainer:
    """Translate sentences with a pretrained XLM encoder-decoder checkpoint.

    Despite its name, this class only runs inference: it reloads a checkpoint
    in ``__init__`` and exposes :meth:`translate`.
    """

    def __init__(self, model_path, tgt_lang, src_lang, dump_path="./dumped/", exp_name="translate",
                 exp_id="test", batch_size=32, device="cpu"):
        """Reload the checkpoint at ``model_path`` and build the encoder/decoder.

        Args:
            model_path: checkpoint file read with ``torch.load``. The keys read
                from it are ``params``, ``dico_id2word``, ``dico_word2id``,
                ``dico_counts``, ``encoder`` and ``decoder``.
            tgt_lang: target language code; must be a key of the checkpoint's
                ``lang2id``.
            src_lang: source language code; must be a key of the checkpoint's
                ``lang2id``.
            dump_path, exp_name, exp_id: experiment settings handed to
                ``initialize_exp``.
            batch_size: number of sentences per forward pass in ``translate``.
            device: torch device (e.g. ``"cpu"``, ``"cuda"``) that checkpoint
                tensors, models and input batches are placed on. Defaults to CPU.

        Raises:
            AssertionError: if a language is empty or both languages are equal.
            KeyError: if a language is not in the checkpoint's ``lang2id``.
        """
        params = Param(dump_path, exp_name, exp_id, batch_size, model_path, tgt_lang, src_lang)

        assert params.src_lang != '' and params.tgt_lang != '' and params.src_lang != params.tgt_lang, \
            "src_lang and tgt_lang must be non-empty and different"

        # XLM experiment setup; its return value is the logger used below.
        logger = initialize_exp(params)

        # Map every checkpoint tensor onto `device`. With the CPU default, a
        # checkpoint saved from GPU can also load on a machine without a GPU.
        # NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
        self.device = torch.device(device)
        reloaded = torch.load(params.model_path, map_location=self.device)
        model_params = AttrDict(reloaded['params'])
        self.supported_languages = model_params.lang2id.keys()
        logger.info("Supported languages: %s", ", ".join(self.supported_languages))

        # copy the vocabulary size / special-token indices from the checkpoint params
        for name in ['n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index', 'mask_index']:
            setattr(params, name, getattr(model_params, name))

        # build dictionary / build encoder / build decoder / reload weights
        self.dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
        self.encoder = TransformerModel(model_params, self.dico, is_encoder=True, with_output=True).to(self.device).eval()
        self.decoder = TransformerModel(model_params, self.dico, is_encoder=False, with_output=True).to(self.device).eval()
        self.encoder.load_state_dict(reloaded['encoder'])
        self.decoder.load_state_dict(reloaded['decoder'])

        # numeric language ids (KeyError for a language the checkpoint doesn't know)
        params.src_id = model_params.lang2id[params.src_lang]
        params.tgt_id = model_params.lang2id[params.tgt_lang]
        self.model_params = model_params
        self.params = params

    def translate(self, src_sent=None):
        """Translate one sentence or a list of sentences from src_lang to tgt_lang.

        Each sentence is stripped, split on whitespace, and every token is
        looked up with ``self.dico.index``. The input should therefore already
        be tokenized the way the model expects (presumably the same
        tokenization/BPE used in training -- verify against the preprocessing).

        Args:
            src_sent: a single sentence (``str``) or a list of sentences.
                ``None`` (the default) is treated as an empty list.

        Returns:
            A ``str`` when ``src_sent`` is a ``str``; otherwise a list of
            translations, in input order.
        """
        if src_sent is None:
            src_sent = []
        single = isinstance(src_sent, str)
        if single:
            src_sent = [src_sent]

        params = self.params
        tgt_sent = []
        # Inference only: disable autograd so no graph is kept per batch.
        with torch.no_grad():
            for i in range(0, len(src_sent), params.batch_size):
                # prepare batch: (max_len, batch) layout, eos_index at both ends
                # of each sentence, pad_index after the closing eos
                word_ids = [torch.LongTensor([self.dico.index(w) for w in s.strip().split()])
                            for s in src_sent[i:i + params.batch_size]]
                lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
                batch = torch.LongTensor(lengths.max().item(), lengths.size(0)).fill_(params.pad_index)
                batch[0] = params.eos_index
                for j, s in enumerate(word_ids):
                    if lengths[j] > 2:  # if sentence not empty
                        batch[1:lengths[j] - 1, j].copy_(s)
                    batch[lengths[j] - 1, j] = params.eos_index
                langs = batch.clone().fill_(params.src_id)

                batch = batch.to(self.device)
                lengths = lengths.to(self.device)
                langs = langs.to(self.device)

                # encode source batch and translate it
                encoded = self.encoder('fwd', x=batch, lengths=lengths, langs=langs, causal=False)
                encoded = encoded.transpose(0, 1)
                max_len = int(1.5 * lengths.max().item() + 10)
                decoded, _ = self.decoder.generate(encoded, lengths, params.tgt_id, max_len=max_len)

                # convert each generated column back to words
                for j in range(decoded.size(1)):
                    sent = decoded[:, j]
                    # each output starts with eos_index; the text runs up to the
                    # next eos_index, or to the end if none was generated
                    delimiters = (sent == params.eos_index).nonzero().view(-1)
                    assert len(delimiters) >= 1 and delimiters[0].item() == 0
                    sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]]
                    tgt_sent.append(" ".join(self.dico[idx] for idx in sent.tolist()))

        return tgt_sent[0] if single else tgt_sent

