import torch
import torch.nn as nn
import torch.nn.functional as F
import os

from fairseq import search
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import (
    TransformerModel,
    transformer_wmt_en_de,
    transformer_iwslt_de_en,
)
import logging

from fairseq.sequence_generator import SequenceGenerator
from fairseq.my_generator import MySequenceGenerator

logger = logging.getLogger(__name__)


@register_model("transformer_repro")
class TransformerCLGModel(TransformerModel):
    """Transformer wrapper registered as ``transformer_repro``.

    Wraps an already-built ``TransformerModel`` (``generator``) and shares its
    encoder and decoder. Although the model was originally described as
    target <-> source contrastive learning (SimCLR-style), ``forward`` below
    performs plain teacher-forced decoding and computes no contrastive term.
    The extra command-line options registered in ``add_args`` are not read by
    any code in this class.
    """

    def __init__(self, generator, tgt_dict, cfg, args):
        """
        Args:
            generator: a built ``TransformerModel``; its ``encoder`` and
                ``decoder`` become this model's encoder and decoder.
            tgt_dict: target-side dictionary (``task.target_dictionary``).
            cfg: task config (``task.cfg``).
            args: parsed model arguments.
        """
        super().__init__(args, generator.encoder, generator.decoder)
        self.cfg = cfg
        self.tgt_dict = tgt_dict
        self.args = args
        # NOTE(review): ``generator`` is registered as a submodule, so the
        # shared encoder/decoder are also reachable under ``generator.*``.
        self.generator = generator
        # NOTE(review): assumes ``args.pad`` is defined by the caller's
        # argument set; ``tgt_dict.pad()`` may be a more reliable source.
        self.pad_id = self.args.pad
        self.hidden_size = self.args.encoder_embed_dim

    @classmethod
    def build_model(cls, args, task):
        """Build the underlying TransformerModel and wrap it in ``cls``."""
        # Apply the ``transformer_repro_iwslt`` architecture defaults
        # (presumably only fills attributes not already set -- confirm in
        # fairseq's ``transformer_iwslt_de_en``).
        transformer_clg(args)
        transformer_model = TransformerModel.build_model(args, task)
        # Use ``cls`` so subclasses get an instance of their own type.
        return cls(transformer_model, task.target_dictionary, task.cfg, args)

    @classmethod
    def add_args(cls, parser):
        """Add model-specific arguments to the parser."""
        # Register the parent Transformer arguments first, then our extras.
        super().add_args(parser)
        parser.add_argument('--lenpen', default=0.1, type=float)
        parser.add_argument('--max_len_a', default=1.0, type=float)
        parser.add_argument('--max_len_b', default=50.0, type=float)
        parser.add_argument('--diverse_bias', default=3.5, type=float)
        parser.add_argument('--max_sample_num', default=16, type=int)
        parser.add_argument('--samples_from_batch', default=0, type=int)
        parser.add_argument('--max_sample_len', default=48, type=int)
        parser.add_argument('--skip_warmup_ckpt', default=None)
        parser.add_argument('--cl_loss', default="infoNCE", choices=['ranking', "infoNCE"])

    def forward(self, src_tokens, src_lengths, prev_output_tokens):
        """Encode the source, decode with teacher forcing, return logits.

        Args:
            src_tokens: source token ids, passed to the encoder.
            src_lengths: source lengths, passed to the encoder.
            prev_output_tokens: previous output tokens fed to the decoder
                (presumably the right-shifted target).

        Returns:
            A tuple shaped like the decoder's output, whose first element is
            replaced by ``decoder.output_layer`` applied to the decoder
            features. No contrastive loss is computed here.
        """
        encoder = self.generator.encoder
        decoder = self.generator.decoder
        encoder_out = encoder(src_tokens, src_lengths)
        decoder_out = decoder(prev_output_tokens, encoder_out, features_only=True)
        # Expected (batch, tgt_len, hidden) decoder features.
        decoder_hidden_states = decoder_out[0]
        lm_logits = decoder.output_layer(decoder_hidden_states)
        # Swap the features for the projected logits; keep the extras.
        return (lm_logits, *decoder_out[1:])


@register_model_architecture("transformer_repro", "transformer_repro_wmt")
def transformer_cl_wmt_en_de(args):
    """Architecture ``transformer_repro_wmt``.

    Applies fairseq's ``transformer_wmt_en_de`` hyper-parameter defaults to
    ``args`` in place; nothing model-specific is added here.
    """
    transformer_wmt_en_de(args)


@register_model_architecture("transformer_repro", "transformer_repro_iwslt")
def transformer_clg(args):
    """Architecture ``transformer_repro_iwslt``.

    Applies fairseq's ``transformer_iwslt_de_en`` hyper-parameter defaults to
    ``args`` in place. ``TransformerCLGModel.build_model`` also calls this
    unconditionally before building the underlying Transformer.
    """
    transformer_iwslt_de_en(args)
