# coding=utf-8
# Copyright (c) 2021 Ant Group
# Author: Xiang Hu

from unicodedata import bidirectional
import torch
import torch.nn.functional as F
from model.r2d2_common import (
    LMLossParam, 
    CacheSlots, 
    BOS_CACHE_ID,
    EOS_CACHE_ID
)
import logging
import numpy as np
from collections import deque
import random
from model.tree_encoder import UniLMEncoder
from utils.table_converter import convert_cuda_tables
from utils.tree_utils import get_tree_from_merge_trajectory
from utils.vocab_builder import convert_tree_to_wordtree


logger = logging.getLogger(__name__)


def cuda_default_lm_loss(loss_param: LMLossParam):
    """Cross-entropy LM loss predicting each token from a pair of cached contexts.

    For every entry of ``loss_param.flatten_input_ids`` two cache ids are
    obtained from ``chart_tables.prepare_bilm``; the corresponding ``E_IJ``
    representations are gathered, reshaped to ``(N, 2, model.input_dim)`` and
    fed to ``model.infer`` to produce the logits.

    Args:
        loss_param: bundle providing ``model``, ``chart_tables``,
            ``tensor_cache`` and ``flatten_input_ids``.

    Returns:
        The value of ``F.cross_entropy(logits, flatten_input_ids)``.
    """
    model = loss_param.model
    tables = loss_param.chart_tables
    tensor_cache = loss_param.tensor_cache
    token_ids = loss_param.flatten_input_ids
    num_tokens = token_ids.shape[0]

    # Two slots per token. The buffer is handed to prepare_bilm and read back
    # afterwards, so it is presumably filled in place -- TODO confirm.
    cache_ids = torch.full(
        [num_tokens * 2],
        0,
        requires_grad=False,
        dtype=torch.int,
        device=model.device,
    )
    tables.prepare_bilm(cache_ids, BOS_CACHE_ID, EOS_CACHE_ID)
    paired_cache_ids = cache_ids.view(-1, 2)[:num_tokens, :]

    gathered = tensor_cache.gather(paired_cache_ids.flatten(), [CacheSlots.E_IJ])
    context = gathered[0].view(*paired_cache_ids.shape, model.input_dim)  # (N, 2, dim)
    return F.cross_entropy(model.infer(context), token_ids)


def estimate_span_loss(model, input_ids_np, span_batch, seq_lens, tables, tensor_cache):
    """Cross-entropy loss for reconstructing masked spans from their context.

    Every selected span ``[st, ed]`` (inclusive) is replaced by
    ``ed - st + 1`` copies of ``model.mask_token_id`` and decoded by
    ``model.tree_decoder`` with the span's two context representations as
    memory; the targets are the original token ids of the span.

    Args:
        model: provides ``mask_token_id``, ``device``, ``input_dim``,
            ``tree_decoder`` (must be a ``UniLMEncoder``), ``embedding``,
            ``cls_dense`` and ``classifier``.
        input_ids_np: per-sentence token ids, indexable as
            ``input_ids_np[batch_i][st: ed + 1]``.
        span_batch: one list of ``[st, ed]`` spans per sentence.
        seq_lens: unused here; kept for interface compatibility.
        tables: chart tables exposing ``prepare_span_bilm``.
        tensor_cache: cache exposing ``gather``.

    Returns:
        ``0`` when there is no span to predict at all, otherwise the
        cross-entropy loss tensor (padding positions, marked ``-1``, ignored).
    """
    # No span anywhere: either the batch is empty or every sentence yielded an
    # empty selection. The latter used to crash on max() of an empty sequence.
    if not any(len(spans) > 0 for spans in span_batch):
        return 0

    task_id = model.mask_token_id
    input_ids_batch = []
    tgt_ids_batch = []
    context_span = []
    for batch_i, spans in enumerate(span_batch):
        for st, ed in spans:
            input_ids_batch.append([task_id] * (ed - st + 1))
            tgt_ids_batch.append(list(input_ids_np[batch_i][st: ed + 1]))
            context_span.append([batch_i, st, ed])

    # Right-pad targets with -1 so they form a rectangular tensor; -1 is the
    # ignore_index passed to cross_entropy below.
    max_ids_len = max(map(len, tgt_ids_batch))
    for tgt_ids in tgt_ids_batch:
        tgt_ids.extend([-1] * (max_ids_len - len(tgt_ids)))

    context_span = torch.tensor(context_span, device=model.device, dtype=torch.int)
    # Two cache ids per span, written by prepare_span_bilm.
    cache_ids = torch.full([context_span.shape[0] * 2],
                           0,
                           requires_grad=False,
                           dtype=torch.int,
                           device=model.device)
    tables.prepare_span_bilm(cache_ids, context_span,
                             BOS_CACHE_ID, EOS_CACHE_ID)

    context_cache_ids = cache_ids.view(-1, 2)

    context = tensor_cache.gather(context_cache_ids.flatten(),
                                  [CacheSlots.E_IJ])[0]
    context = context.view(*context_cache_ids.shape, model.input_dim)  # (N, 2, dim)
    assert isinstance(model.tree_decoder, UniLMEncoder)
    outputs = model.tree_decoder(input_ids=input_ids_batch,
                                 memory=context,
                                 embeddings=model.embedding,
                                 bidirectional_pos=True)
    logits = model.classifier(model.cls_dense(outputs))  # (N, max_ids_len, vocab_size)
    # cross_entropy expects class-index targets as int64, so pin the dtype
    # instead of relying on inference from the numpy scalars.
    targets = torch.tensor(tgt_ids_batch, device=model.device, dtype=torch.long)
    return F.cross_entropy(logits.permute(0, 2, 1), targets, ignore_index=-1)


def cuda_span_loss(loss_param: LMLossParam, max_span=10, min_span=0):
    """Mask randomly chosen, non-overlapping tree spans and score their reconstruction.

    A tree is obtained per sentence via ``get_tree_from_merge_trajectory``
    from ``loss_param.s_indices``. Nodes whose width ``j - i`` lies in
    ``[min_span, max_span]``, that are not the root, and that do not partially
    overlap any atom span become candidates. Candidates are then drawn at
    random, discarding everything overlapping each pick, until none remain.

    Args:
        loss_param: bundle providing ``model``, ``chart_tables``,
            ``tensor_cache``, ``atom_spans``, ``s_indices``, ``input_ids``
            and ``seq_lens``.
        max_span: largest allowed value of ``j - i`` for a candidate node.
        min_span: smallest allowed value of ``j - i`` for a candidate node.

    Returns:
        Whatever ``estimate_span_loss`` returns for the selected spans.
    """
    model = loss_param.model
    tables = loss_param.chart_tables
    tensor_cache = loss_param.tensor_cache
    atom_spans = loss_param.atom_spans
    merge_indices = loss_param.s_indices
    input_ids = loss_param.input_ids
    seq_lens = loss_param.seq_lens
    if atom_spans is None:
        atom_spans = [[] for _ in range(input_ids.shape[0])]

    trajectories = merge_indices.cpu().data.numpy()
    roots = [
        get_tree_from_merge_trajectory(trajectories[sent_i], seq_lens[sent_i], None)
        for sent_i in range(len(seq_lens))
    ]

    span_batch = []
    for batch_i, root in enumerate(roots):
        # Breadth-first walk collecting every admissible node.
        pending = deque([root])
        candidates = []
        while pending:
            node = pending.popleft()
            if min_span <= node.j - node.i <= max_span:
                atoms = atom_spans[batch_i]
                # A node is rejected when it partially overlaps an atom span.
                crosses_atom = atoms is not None and any(
                    node.i < st <= node.j < ed or st < node.i <= ed < node.j
                    for st, ed in atoms
                )
                # The root is excluded so a whole sentence is never masked.
                if not crosses_atom and node != root:
                    candidates.append(node)
            if not node.is_leaf:
                pending.append(node.left)
                pending.append(node.right)

        # Draw spans one at a time, keeping only candidates disjoint from the pick.
        chosen = []
        while candidates:
            pick = random.choice(candidates)
            chosen.append([pick.i, pick.j])
            candidates = [
                other for other in candidates
                if other.j < pick.i or other.i > pick.j
            ]

        chosen.sort(key=lambda span: span[0])
        span_batch.append(chosen)

    return estimate_span_loss(model, input_ids.cpu().data.numpy(), span_batch, seq_lens, tables, tensor_cache)

def cuda_generative_loss(loss_param: LMLossParam, max_positions=10):
    """Generative loss: decode every token / atom span from its context.

    Each sentence is split into units: an atom span ``[st, ed]`` (inclusive)
    with ``ed - st < max_positions`` is one unit, every other position is a
    single-token unit. A unit is decoded by ``model.tree_decoder`` from
    ``[mask_token_id] + tokens[st:ed]`` and must produce ``tokens[st:ed + 1]``.

    Args:
        loss_param: bundle providing ``model``, ``chart_tables``,
            ``tensor_cache``, ``atom_spans``, ``input_ids`` and ``seq_lens``.
        max_positions: atom spans with ``ed - st >= max_positions`` are not
            decoded as a unit; their tokens fall back to single-token units.

    Returns:
        ``0`` when there is nothing to predict, otherwise the cross-entropy
        loss tensor (padding positions, marked ``-1``, are ignored).
    """
    model = loss_param.model
    tables = loss_param.chart_tables
    tensor_cache = loss_param.tensor_cache
    atom_spans = loss_param.atom_spans
    input_ids = loss_param.input_ids
    seq_lens = loss_param.seq_lens

    task_id = model.mask_token_id
    input_ids_np = input_ids.to('cpu').data.numpy()
    input_ids_batch = []
    tgt_ids_batch = []
    context_span = []

    def add_single_token(batch_i, pos):
        # One-token unit: predict the token at `pos` from the task token alone.
        input_ids_batch.append([task_id])
        tgt_ids_batch.append([input_ids_np[batch_i][pos]])
        context_span.append([batch_i, pos, pos])

    if atom_spans is None:
        atom_spans = [[] for _ in range(input_ids.shape[0])]
    for batch_i, spans in enumerate(atom_spans):
        offset = 0
        if spans is None:
            spans = []
        for st, ed in sorted(spans, key=lambda sp: sp[0]):
            if ed - st >= max_positions:
                continue
            while offset < st:
                add_single_token(batch_i, offset)
                offset += 1
            decoder_inputs = [task_id]
            decoder_inputs.extend(input_ids_np[batch_i][st: ed])
            input_ids_batch.append(decoder_inputs)
            # Must be a Python list: the padding step below calls .extend(),
            # which numpy arrays do not provide.
            tgt_ids_batch.append(list(input_ids_np[batch_i][st: ed + 1]))
            context_span.append([batch_i, st, ed])
            offset = ed + 1
        while offset < seq_lens[batch_i]:
            add_single_token(batch_i, offset)
            offset += 1

    # Nothing to decode (e.g. all sequence lengths are zero); max() below
    # would raise on an empty sequence.
    if not tgt_ids_batch:
        return 0

    # Right-pad targets with -1, the ignore_index used by cross_entropy below.
    max_ids_len = max(map(len, tgt_ids_batch))
    for tgt_ids in tgt_ids_batch:
        tgt_ids.extend([-1] * (max_ids_len - len(tgt_ids)))

    context_span = torch.tensor(context_span, device=model.device, dtype=torch.int)
    # Two cache ids per unit, written by prepare_span_bilm.
    cache_ids = torch.full([context_span.shape[0] * 2],
                           0,
                           requires_grad=False,
                           dtype=torch.int,
                           device=model.device)
    tables.prepare_span_bilm(cache_ids, context_span,
                             BOS_CACHE_ID, EOS_CACHE_ID)

    context_cache_ids = cache_ids.view(-1, 2)

    context = tensor_cache.gather(context_cache_ids.flatten(),
                                  [CacheSlots.E_IJ])[0]
    context = context.view(*context_cache_ids.shape, model.input_dim)  # (N, 2, dim)
    assert isinstance(model.tree_decoder, UniLMEncoder)
    outputs = model.tree_decoder(input_ids=input_ids_batch,
                                 memory=context,
                                 embeddings=model.embedding)
    logits = model.classifier(model.cls_dense(outputs))  # (N, max_ids_len, vocab_size)
    # cross_entropy expects int64 class-index targets.
    targets = torch.tensor(tgt_ids_batch, device=model.device, dtype=torch.long)
    return F.cross_entropy(logits.permute(0, 2, 1), targets, ignore_index=-1)


def cuda_span_cos_loss(loss_param: LMLossParam, min_span=2, max_span=5, rate=0.5):
    """Cosine loss between a span's predicted and cached representation.

    Trees are rebuilt from the dumped chart cells. For each sentence, nodes
    covering between ``min_span`` and ``max_span`` tokens (``j - i + 1``) are
    collected and ``int(rate * count)`` of them are sampled without
    replacement. For each sampled node the model predicts a vector from the
    node's two context representations, which is pulled towards the node's own
    ``E_IJ`` cache entry with ``F.cosine_embedding_loss`` (all targets ``+1``).

    Args:
        loss_param: bundle providing ``model``, ``chart_tables`` and
            ``tensor_cache``.
        min_span: minimum number of tokens a candidate node must cover.
        max_span: maximum number of tokens a candidate node may cover.
        rate: fraction of candidates sampled per sentence.

    Returns:
        ``0`` when no span was sampled in the whole batch, otherwise the
        cosine embedding loss tensor.
    """
    model = loss_param.model
    tables = loss_param.chart_tables
    tensor_cache = loss_param.tensor_cache
    # get trees
    trees = convert_cuda_tables(tables.dump_cells(), tensor_cache)
    roots = [t.root.best_node for t in trees]

    targets = []      # cache id of every sampled node
    spans_batch = []  # [batch_i, i, j] of every sampled node
    for batch_i, root in enumerate(roots):
        queue = deque([root])
        span_candidates = []
        while queue:
            parent = queue.popleft()
            if min_span <= parent.j - parent.i + 1 <= max_span:
                span_candidates.append(parent)
            if not parent.is_leaf:
                queue.append(parent.left)
                queue.append(parent.right)

        selected_indices = np.random.choice(len(span_candidates),
                                            size=(int(rate * len(span_candidates))),
                                            replace=False)
        for idx in selected_indices:
            node = span_candidates[idx]
            targets.append(node.cache_id)
            spans_batch.append([batch_i, node.i, node.j])

    # int(rate * n) can be 0 for every sentence (short inputs); a mean-reduced
    # loss over zero rows would be NaN, so report "no loss" instead.
    if not spans_batch:
        return 0

    e_ij_pool = tensor_cache.get_tensor_cache(CacheSlots.E_IJ)

    # Two cache ids per span, written by prepare_span_bilm.
    cache_ids = torch.full([len(spans_batch) * 2],
                           0,
                           requires_grad=False,
                           dtype=torch.int,
                           device=model.device)
    tables.prepare_span_bilm(cache_ids,
                             torch.tensor(spans_batch, device=model.device, dtype=torch.int),
                             BOS_CACHE_ID, EOS_CACHE_ID)
    context_cache_ids = cache_ids.view(-1, 2)

    context = tensor_cache.gather(context_cache_ids.flatten(),
                                  [CacheSlots.E_IJ])[0]
    context = context.view(*context_cache_ids.shape, model.input_dim)  # (N, 2, dim)
    pred_tensor = model.infer(context, in_logits=True)  # (N, dim)
    target_repr = e_ij_pool[targets]
    # +1 labels: every (prediction, target) pair should be similar.
    target = torch.ones((target_repr.shape[0]), device=e_ij_pool.device)

    return F.cosine_embedding_loss(pred_tensor, target_repr, target)


def cuda_vocab_mask_loss(loss_param: LMLossParam, basic_vocab=None, external_vocab=None, max_span=10):
    """Mask vocabulary-backed spans of the induced tree and score their reconstruction.

    A tree is obtained per sentence via ``get_tree_from_merge_trajectory``,
    converted with ``convert_tree_to_wordtree`` and annotated through
    ``tokens_and_segments``. A span ``[i, j]`` is valid when its node has
    ``hit_vocab`` set and ``j - i < max_span``, or when it is a single token;
    the whole-sentence span is always invalidated. The sentence is then
    tiled left to right: at each offset one valid span starting there is drawn
    at random, and the offset jumps past it.

    NOTE(review): the traversal that fills ``valid_spans`` had been commented
    out, which left the matrix all-zero and made ``np.random.choice`` raise on
    an empty candidate list for every sentence. It is restored here as
    written by the original author -- confirm that ``hit_vocab`` is still
    what ``tokens_and_segments`` populates.

    Args:
        loss_param: bundle providing ``model``, ``chart_tables``,
            ``tensor_cache``, ``s_indices``, ``atom_spans``, ``input_ids``
            and ``seq_lens``.
        basic_vocab: forwarded to ``tokens_and_segments``.
        external_vocab: forwarded to ``tokens_and_segments``.
        max_span: upper bound (exclusive) on ``j - i`` for multi-token spans
            and on the look-ahead window when tiling.

    Returns:
        Whatever ``estimate_span_loss`` returns for the selected spans.
    """
    model = loss_param.model
    tables = loss_param.chart_tables
    tensor_cache = loss_param.tensor_cache
    merge_indices = loss_param.s_indices
    atom_spans = loss_param.atom_spans
    input_ids = loss_param.input_ids
    roots = []
    seq_lens = loss_param.seq_lens
    if atom_spans is None:
        atom_spans = [[] for _ in range(input_ids.shape[0])]
    merge_indices_np = merge_indices.cpu().data.numpy()
    for sent_i in range(len(seq_lens)):
        roots.append(get_tree_from_merge_trajectory(merge_indices_np[sent_i],
                                                    seq_lens[sent_i], None))
    input_ids_np = input_ids.cpu().data.numpy()
    span_batch = []
    for sent_i, root in enumerate(roots):
        valid_spans = np.zeros((input_ids.shape[1], input_ids.shape[1]))
        word_tree_root = convert_tree_to_wordtree(root)
        word_tree_root.tokens_and_segments(input_ids_np[sent_i], basic_vocab, external_vocab)
        queue = deque([word_tree_root])
        while queue:
            current = queue.popleft()
            if (current.hit_vocab and current.j - current.i < max_span) or \
                    current.i == current.j:
                valid_spans[current.i, current.j] = 1
            if current.left is not None and current.right is not None:
                queue.append(current.left)
                queue.append(current.right)

        # Never mask the whole sentence.
        if root.j > root.i:
            valid_spans[root.i][root.j] = 0

        offset = 0
        seq_len = root.j - root.i + 1
        selected_spans = []
        while offset < seq_len:
            spans_candidates = [
                span_end
                for span_end in range(offset, min(offset + max_span, seq_len))
                if valid_spans[offset][span_end] == 1
            ]
            if spans_candidates:
                span_end = int(np.random.choice(spans_candidates, size=1)[0])
            else:
                # No valid span starts here: fall back to the single token so
                # the tiling always advances instead of raising on an empty
                # candidate list.
                span_end = offset
            selected_spans.append([offset, span_end])
            offset = span_end + 1
        span_batch.append(selected_spans)
    return estimate_span_loss(model, input_ids_np, span_batch, seq_lens, tables, tensor_cache)