import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss
from federatedscope.nlp.hetero_tasks.dataset.utils import setup_tokenizer
from federatedscope.nlp.loss.label_smooth_loss import \
    LabelSmoothingLoss


class ModelOutput:
    """Simple attribute bag for the results of a model forward pass.

    Every field is optional and defaults to ``None``; the producer fills in
    only the fields that are meaningful for the code path it took.

    Args:
        loss: Total loss.
        regular_loss: Task loss kept separately from ``loss``.
        contrastive_loss: Contrastive loss term.
        logits: Model predictions / scores.
        hidden_states: Hidden representations.
        example_indices: Indices of the examples the output refers to.
    """
    def __init__(self,
                 loss=None,
                 regular_loss=None,
                 contrastive_loss=None,
                 logits=None,
                 hidden_states=None,
                 example_indices=None):
        # Loss terms.
        self.loss, self.regular_loss, self.contrastive_loss = (
            loss, regular_loss, contrastive_loss)
        # Prediction-side outputs.
        self.logits, self.hidden_states, self.example_indices = (
            logits, hidden_states, example_indices)


class ContrastiveHead(nn.Module):
    """Projection head: dropout -> linear -> tanh -> dropout -> linear.

    The same dropout module is applied before each of the two linear layers.

    Args:
        input_dim: Size of the last dimension of the input.
        inner_dim: Output size of the first linear layer.
        out_dim: Output size of the second (projection) linear layer.
        dropout_prob: Dropout probability.
    """
    def __init__(self, input_dim, inner_dim, out_dim, dropout_prob):
        super().__init__()

        # Attribute names are part of the state_dict keys; keep them stable.
        self.dense = nn.Linear(input_dim, inner_dim)
        self.dropout = nn.Dropout(p=dropout_prob)
        self.out_prj = nn.Linear(inner_dim, out_dim)

    def forward(self, x):
        """Project ``x``; only its last dimension changes (to ``out_dim``)."""
        hidden = self.dense(self.dropout(x)).tanh()
        return self.out_prj(self.dropout(hidden))


class ATCModel(nn.Module):
    """Encoder-decoder model with per-task heads and an optional
    contrastive objective.

    The backbone is a Hugging Face ``EncoderDecoderModel`` whose encoder and
    decoder are both loaded from ``config.model_type``. On top of it sit an
    LM prediction head (used for the ``'mlm'`` pretraining task), a linear
    classifier (classification and QA tasks) and, when
    ``config.use_contrastive_loss`` is set, a :class:`ContrastiveHead`.

    Supported values of ``config.task``: ``'pretrain'``, ``'imdb'``,
    ``'agnews'``, ``'squad'``, ``'newsqa'``, ``'cnndm'``, ``'msqg'``.
    """
    def __init__(self, config):
        """Build the backbone and heads.

        Args:
            config: Object exposing ``model_type``, ``task``,
                ``label_smoothing``, ``num_labels``,
                ``use_contrastive_loss`` (plus ``contrast_topk`` and
                ``contrast_temp`` when that flag is true) and the
                generation settings ``max_length``, ``min_length``,
                ``no_repeat_ngram_size``, ``length_penalty``, ``num_beams``.
        """
        super().__init__()
        # Function-scope imports: ``transformers`` is only imported when a
        # model is actually instantiated.
        from transformers.models.encoder_decoder import EncoderDecoderModel
        from transformers.models.bert.modeling_bert import BertLMPredictionHead

        self.model = EncoderDecoderModel.from_encoder_decoder_pretrained(
            config.model_type, config.model_type)
        self.lm_head = BertLMPredictionHead(self.model.encoder.config)

        self.client_id = None
        self.task = config.task
        self.pt_cfg = self.model.encoder.config
        self.vocab_size = self.pt_cfg.vocab_size
        self.hidden_size = self.pt_cfg.hidden_size
        self.dropout_prob = self.pt_cfg.hidden_dropout_prob
        self.dropout = nn.Dropout(self.dropout_prob)

        setup_tokenizer(config.model_type)  # update global token ids
        # Must be imported *after* ``setup_tokenizer`` so that the updated
        # values of the token-id globals are the ones that get bound here.
        from federatedscope.nlp.hetero_tasks.dataset.utils import \
            BOS_TOKEN_ID, EOS_TOKEN_ID, PAD_TOKEN_ID
        self.label_smoothing = config.label_smoothing
        self.padding_idx = PAD_TOKEN_ID
        self.classifier = nn.Linear(self.hidden_size, config.num_labels)

        self.use_contrastive_loss = config.use_contrastive_loss
        if self.use_contrastive_loss:
            self.contrast_topk = config.contrast_topk
            self.contrast_temp = config.contrast_temp
            self.contrast_head = ContrastiveHead(
                input_dim=self.hidden_size,
                inner_dim=self.hidden_size,
                out_dim=self.hidden_size,
                dropout_prob=self.dropout_prob)

        # for eval generation
        gen_cfg = self.model.config
        gen_cfg.decoder_start_token_id = BOS_TOKEN_ID
        gen_cfg.eos_token_id = EOS_TOKEN_ID
        gen_cfg.pad_token_id = PAD_TOKEN_ID
        gen_cfg.vocab_size = self.pt_cfg.vocab_size
        gen_cfg.max_length = config.max_length
        gen_cfg.min_length = config.min_length
        gen_cfg.no_repeat_ngram_size = config.no_repeat_ngram_size
        gen_cfg.length_penalty = config.length_penalty
        gen_cfg.num_beams = config.num_beams

    def update_client_id(self, client_id):
        """Record the client id used to look up contrastive group ids."""
        self.client_id = client_id

    def generate(self, **kwargs):
        """Forward all keyword arguments to the backbone's ``generate``."""
        return self.model.generate(**kwargs)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        start_positions=None,
        end_positions=None,
        labels=None,
        pretrain_task=None,
        contrast_monitor=None,
        in_contrast_prepare=None,
        example_indices=None,
    ):
        """Compute the loss and predictions for the configured task.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids:
                Passed unchanged to the encoder.
            start_positions, end_positions: Answer span targets, used by the
                ``'squad'`` / ``'newsqa'`` tasks.
            labels: Targets for every other task. For the sequence tasks
                (``'denoise'``, ``'cnndm'``, ``'msqg'``) they are also fed to
                the decoder as its input ids.
            pretrain_task: ``'mlm'`` or ``'denoise'``; only read when
                ``self.task == 'pretrain'``.
            contrast_monitor: Object holding the contrastive state
                (``synth_tokens``, ``enc_hidden``, ``dec_hidden``,
                ``all_group_ids``, ``topk_group_ids``).
            in_contrast_prepare: If truthy, skip the task loss and only
                decode the stored synthetic examples (see
                ``_contrast_prepare``).
            example_indices: Iterable of scalar tensors identifying the
                examples of the batch in ``contrast_monitor``.

        Returns:
            ModelOutput: In the regular path ``loss``, ``logits``,
            ``regular_loss`` and ``contrastive_loss`` are set (the last two
            stay ``None`` unless the contrastive objective is active; for QA
            tasks ``logits`` is a ``(start_logits, end_logits)`` tuple). In
            the contrast-prepare path ``logits`` (arg-max token ids),
            ``hidden_states`` and ``example_indices`` are set instead.

        Raises:
            KeyError: If the task or pretraining task is not supported.
        """
        if in_contrast_prepare:  # return dec_hidden_states & dec_out
            return self._contrast_prepare(contrast_monitor, example_indices)

        enc_outputs = self.model.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        )

        regular_loss, contrastive_loss = None, None
        if self.task == 'pretrain':
            loss, logits = self._pretrain_loss(enc_outputs, attention_mask,
                                               labels, pretrain_task)
        else:
            loss, logits = self._finetune_loss(enc_outputs, attention_mask,
                                               labels, start_positions,
                                               end_positions)

            # The contrastive term is only added while training.
            if self.use_contrastive_loss and self.training:
                regular_loss = loss.clone()
                contrastive_loss = self._contrastive_loss(
                    contrast_monitor, example_indices)
                if contrastive_loss is not None:
                    loss = loss + contrastive_loss

        return ModelOutput(loss=loss,
                           regular_loss=regular_loss,
                           contrastive_loss=contrastive_loss,
                           logits=logits)

    # ------------------------------------------------------------------ #
    # Helpers shared by the contrastive paths
    # ------------------------------------------------------------------ #
    @staticmethod
    def _filter_synth_indices(contrast_monitor, example_indices):
        """Keep the indices that have an entry in ``synth_tokens``."""
        return [
            k for k in example_indices
            if k.item() in contrast_monitor.synth_tokens
        ]

    def _decode_synthetic(self, contrast_monitor, example_indices):
        """Run the decoder on the stored synthetic tokens of the examples.

        Args:
            contrast_monitor: Source of ``synth_tokens`` and ``enc_hidden``.
            example_indices: Non-empty list of scalar index tensors.

        Returns:
            tuple: ``(stacked_indices, keys, last_hidden_state)`` where
            ``keys`` are the indices as Python scalars.
        """
        example_indices = torch.stack(example_indices)
        keys = [k.item() for k in example_indices]
        device = self.model.device

        synth_input_ids = torch.stack(
            [contrast_monitor.synth_tokens[k] for k in keys]).to(device)
        enc_hidden = torch.stack(
            [contrast_monitor.enc_hidden[k] for k in keys]).to(device)
        outputs = self.model.decoder.bert(
            input_ids=synth_input_ids,
            encoder_hidden_states=enc_hidden,
        )
        return example_indices, keys, outputs.last_hidden_state

    def _contrast_prepare(self, contrast_monitor, example_indices):
        """Decode the synthetic examples without tracking gradients.

        NOTE: this switches the module to eval mode and does not switch it
        back; callers have to re-enable training mode themselves.
        """
        self.eval()
        with torch.no_grad():
            kept = self._filter_synth_indices(contrast_monitor,
                                              example_indices)
            if not kept:
                return ModelOutput(example_indices=kept)

            example_indices, _, last_hidden = self._decode_synthetic(
                contrast_monitor, kept)
            logits = self.model.decoder.cls(last_hidden)
            # Average over dim 1 (presumably the sequence dimension).
            dec_hidden = self.contrast_head(last_hidden).mean(1)

            return ModelOutput(logits=logits.argmax(-1),
                               hidden_states=dec_hidden,
                               example_indices=example_indices)

    def _contrastive_loss(self, contrast_monitor, example_indices):
        """Contrastive loss between this client's and other clients'
        decoder representations of the synthetic examples.

        Positives are the mean of the stored representations of the clients
        in ``topk_group_ids[1:contrast_topk + 1]`` (index 0 is skipped,
        presumably because it is the client itself -- TODO confirm).
        Negatives are the clients among the last ``contrast_topk`` entries
        of ``all_group_ids`` that are not in ``topk_group_ids``.

        Returns:
            The scalar loss, or ``None`` when no example of the batch has
            synthetic tokens or there is at most one top-k group id.
        """
        kept = self._filter_synth_indices(contrast_monitor, example_indices)
        all_group_ids = contrast_monitor.all_group_ids[self.client_id]
        topk_group_ids = contrast_monitor.topk_group_ids[self.client_id]
        if not kept or len(topk_group_ids) <= 1:
            return None

        _, keys, last_hidden = self._decode_synthetic(contrast_monitor, kept)
        cur_dec_hidden = self.contrast_head(last_hidden).mean(1)

        device = self.model.device
        all_dec_hiddens = contrast_monitor.dec_hidden

        pos_client_ids = list(topk_group_ids[1:self.contrast_topk + 1])
        sim_hiddens = torch.stack([
            torch.stack([all_dec_hiddens[cid][k] for k in keys])
            for cid in pos_client_ids
        ]).mean(0).to(device)
        sim_matrix = F.cosine_similarity(cur_dec_hidden, sim_hiddens, dim=-1)
        numerator = torch.exp(sim_matrix / self.contrast_temp)
        denominator = numerator

        neg_client_ids = [
            x for x in all_group_ids[::-1][:self.contrast_topk]
            if x not in topk_group_ids
        ]
        if neg_client_ids:
            dissim_hiddens = torch.stack([
                torch.stack([all_dec_hiddens[cid][k] for k in keys])
                for cid in neg_client_ids
            ]).to(device)
            # Broadcast the current representation against every negative
            # client, then sum the negatives' contributions.
            dissim_matrix = F.cosine_similarity(cur_dec_hidden.unsqueeze(0),
                                                dissim_hiddens,
                                                dim=-1)
            denominator = denominator + torch.exp(
                dissim_matrix / self.contrast_temp).sum(0)

        return -torch.log(numerator / denominator).mean()

    # ------------------------------------------------------------------ #
    # Task losses; each returns ``(loss, logits)``
    # ------------------------------------------------------------------ #
    def _teacher_forced_logits(self, enc_outputs, attention_mask, labels):
        """Decode ``labels`` against the encoder output.

        The last position is dropped so that the logits line up with the
        shifted targets ``labels[:, 1:]``.
        """
        dec_outputs = self.model.decoder.bert(
            input_ids=labels,
            encoder_hidden_states=enc_outputs.last_hidden_state,
            encoder_attention_mask=attention_mask,
        )
        return self.model.decoder.cls(
            dec_outputs.last_hidden_state)[:, :-1, :]

    def _pretrain_loss(self, enc_outputs, attention_mask, labels,
                       pretrain_task):
        """Loss for the ``'mlm'`` and ``'denoise'`` pretraining tasks."""
        if pretrain_task == 'mlm':
            logits = self.lm_head(enc_outputs.last_hidden_state)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.vocab_size),
                            labels.view(-1))
            return loss, logits

        if pretrain_task == 'denoise':
            logits = self._teacher_forced_logits(enc_outputs, attention_mask,
                                                 labels)
            loss_fct = CrossEntropyLoss(ignore_index=self.padding_idx)
            loss = loss_fct(logits.contiguous().view(-1, self.vocab_size),
                            labels[:, 1:].contiguous().view(-1))
            return loss, logits

        raise KeyError(
            'Unsupported pretrain task: \'{}\''.format(pretrain_task))

    def _finetune_loss(self, enc_outputs, attention_mask, labels,
                       start_positions, end_positions):
        """Dispatch to the loss matching ``self.task``."""
        if self.task in {'imdb', 'agnews'}:
            return self._classification_loss(enc_outputs, labels)
        if self.task in {'squad', 'newsqa'}:
            return self._span_loss(enc_outputs, start_positions,
                                   end_positions)
        if self.task in {'cnndm', 'msqg'}:
            return self._generation_loss(enc_outputs, attention_mask, labels)
        raise KeyError('Unsupported task: \'{}\''.format(self.task))

    def _classification_loss(self, enc_outputs, labels):
        """Cross-entropy on the classifier applied to the pooled output."""
        pooled_output = self.dropout(enc_outputs.pooler_output)
        logits = self.classifier(pooled_output)
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return loss, logits

    def _span_loss(self, enc_outputs, start_positions, end_positions):
        """Mean of the start- and end-position cross-entropy losses."""
        logits = self.classifier(enc_outputs.last_hidden_state)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        # sometimes the start/end positions are outside our model
        # inputs, we ignore these terms
        ignored_index = start_logits.size(1)
        start_positions = start_positions.clamp(0, ignored_index)
        end_positions = end_positions.clamp(0, ignored_index)

        loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        return (start_loss + end_loss) / 2, (start_logits, end_logits)

    def _generation_loss(self, enc_outputs, attention_mask, labels):
        """Sequence loss; label smoothing is applied only while training."""
        logits = self._teacher_forced_logits(enc_outputs, attention_mask,
                                             labels)
        flat_logits = logits.contiguous().view(-1, self.vocab_size)
        flat_targets = labels[:, 1:].contiguous().view(-1)

        label_smoothing = self.label_smoothing if self.training else 0.0
        if label_smoothing > 0:
            loss_fct = LabelSmoothingLoss(
                label_smoothing,
                self.vocab_size,
                ignore_index=self.padding_idx,
            ).to(logits.device)
            # The smoothed loss is normalised by the number of non-padding
            # targets (presumably it is a sum over tokens -- verify against
            # LabelSmoothingLoss). Guard against a batch whose targets are
            # all padding, which would otherwise divide by zero.
            num_tokens = flat_targets.ne(self.padding_idx).sum().item()
            loss = loss_fct(F.log_softmax(flat_logits, dim=-1),
                            flat_targets) / max(num_tokens, 1)
        else:
            loss_fct = CrossEntropyLoss(ignore_index=self.padding_idx)
            loss = loss_fct(flat_logits, flat_targets)
        return loss, logits
