import torch
import torch.nn.functional as F
from data_structure.tensor_cache import CacheType, TensorCache
from model.r2d2_common import (
    LMLossParam, 
    CacheSlots, 
    BOS_CACHE_ID,
    EOS_CACHE_ID
)
import logging
import numpy as np
from collections import deque
import random
from model.tree_encoder import UniLMEncoder
from utils.table_converter import convert_cuda_tables
from utils.tree_utils import get_tree_from_merge_trajectory


logger = logging.getLogger(__name__)


def cuda_default_lm_loss(loss_param: LMLossParam):
    """Token-level language-model loss computed from cached context vectors.

    For every token in ``loss_param.flatten_input_ids`` a pair of cache ids is
    requested from the chart tables (``prepare_bilm``), the matching ``E_IJ``
    representations are gathered from the tensor cache, and ``model.infer``
    turns each pair into logits that are scored against the token id with
    cross entropy.

    NOTE(review): the two ids per token are presumably the left and right
    context of that token -- confirm against ``prepare_bilm``.

    Args:
        loss_param: bundle providing ``model``, ``chart_tables``,
            ``tensor_cache`` and ``flatten_input_ids``.

    Returns:
        The cross-entropy loss between the inferred logits and
        ``flatten_input_ids``.
    """
    model = loss_param.model
    target_ids = loss_param.flatten_input_ids
    num_tokens = target_ids.shape[0]

    bilm_ids = loss_param.chart_tables.prepare_bilm(
        num_tokens, BOS_CACHE_ID, EOS_CACHE_ID
    ).to(model.device)
    # Two cache ids per token; rows beyond the number of tokens are dropped.
    pair_ids = bilm_ids.view(-1, 2)[:num_tokens, :]

    gathered = loss_param.tensor_cache.gather(pair_ids.flatten(), [CacheSlots.E_IJ])
    context = gathered[0].view(*pair_ids.shape, model.input_dim)  # (num_tokens, 2, input_dim)

    return F.cross_entropy(model.infer(context), target_ids)


def estimate_span_loss(model, input_ids_np, span_batch, tables, tensor_cache):
    """Cross-entropy loss for reconstructing masked spans from their context.

    Each selected span ``[st, ed]`` (inclusive) is replaced by a run of
    ``model.mask_token_id`` of the same length. The decoder receives that run
    together with two cached context vectors for the span and must predict the
    original token ids.

    Args:
        model: provides ``mask_token_id``, ``device``, ``input_dim``,
            ``tree_decoder``, ``embedding``, ``cls_dense`` and ``classifier``.
        input_ids_np: per-sentence token ids, indexable as
            ``input_ids_np[batch_i][st:ed + 1]``.
        span_batch: for every sentence, a list of ``(st, ed)`` spans.
        tables: chart tables exposing ``prepare_span_bilm``.
        tensor_cache: cache holding the ``E_IJ`` representations.

    Returns:
        ``0`` when there is nothing to predict, otherwise the cross-entropy
        loss (padding positions, marked with ``-1``, are ignored).
    """
    if len(span_batch) == 0:
        return 0

    mask_id = model.mask_token_id
    masked_inputs, targets, span_index = [], [], []
    for sent_idx, sent_spans in enumerate(span_batch):
        for start, end in sent_spans:
            width = end - start + 1
            masked_inputs.append([mask_id] * width)
            targets.append(list(input_ids_np[sent_idx][start: end + 1]))
            span_index.append([sent_idx, start, end])

    if not targets:
        return 0

    # Right-pad every target to the longest span; -1 is the ignore index below.
    longest = max(len(tgt) for tgt in targets)
    for tgt in targets:
        tgt.extend([-1] * (longest - len(tgt)))

    span_index = torch.tensor(span_index, device=model.device, dtype=torch.int)

    # Two cache ids per span, written in place by prepare_span_bilm.
    cache_ids = torch.full([span_index.shape[0] * 2], 0,
                           dtype=torch.int, device=model.device)
    tables.prepare_span_bilm(cache_ids, span_index, BOS_CACHE_ID, EOS_CACHE_ID)
    pair_ids = cache_ids.view(-1, 2)

    gathered = tensor_cache.gather(pair_ids.flatten(), [CacheSlots.E_IJ])
    context = gathered[0].view(*pair_ids.shape, model.input_dim)  # (N, 2, dim)

    assert isinstance(model.tree_decoder, UniLMEncoder)
    decoded = model.tree_decoder(input_ids=masked_inputs,
                                 memory=context,
                                 embeddings=model.embedding)
    logits = model.classifier(model.cls_dense(decoded))  # (N, max_ids_len, vocab_size)

    target_tensor = torch.tensor(targets, device=model.device)
    # cross_entropy expects the class dimension second: (N, vocab_size, max_ids_len).
    return F.cross_entropy(logits.permute(0, 2, 1), target_tensor, ignore_index=-1)


def _crosses_atom_span(node, atom_spans):
    """Return True if ``[node.i, node.j]`` crosses the boundary of an atom span.

    A node is rejected when it starts before an atom span and ends inside it,
    or starts inside an atom span and ends after it.
    """
    if atom_spans is None:
        return False
    return any(
        node.i < st <= node.j < ed or st < node.i <= ed < node.j
        for st, ed in atom_spans
    )


def _sample_non_overlapping_spans(candidate_spans):
    """Randomly pick mutually non-overlapping spans from weighted candidates.

    Args:
        candidate_spans: list of ``(node, weight)`` pairs; ``node`` exposes the
            inclusive boundaries ``i`` and ``j``.

    Returns:
        A list of ``[i, j]`` pairs sorted by start position.
    """
    remaining = list(candidate_spans)
    selected = []
    while remaining:
        total = sum(weight for _, weight in remaining)
        if total > 0:
            probs = [weight / total for _, weight in remaining]
        else:
            # Every remaining weight is zero: fall back to a uniform draw
            # instead of dividing by zero.
            probs = None
        # Sample an index rather than the node objects themselves so numpy
        # never has to coerce arbitrary objects into an array.
        picked = remaining[np.random.choice(len(remaining), p=probs)][0]
        selected.append([picked.i, picked.j])
        # Keep only candidates lying entirely left or right of the picked span.
        remaining = [
            (node, weight) for node, weight in remaining
            if node.j < picked.i or node.i > picked.j
        ]
    selected.sort(key=lambda span: span[0])
    return selected


def cuda_span_loss(loss_param: LMLossParam, max_span=10, min_span=0, weighted=True):
    """Span-reconstruction loss over constituents sampled from the induced trees.

    One tree per sentence is rebuilt from ``loss_param.s_indices``. Tree nodes
    (except the root) whose ``j - i`` lies in ``[min_span, max_span]`` and that
    do not cross any atom span become candidates. Non-overlapping spans are
    then sampled from the candidates and handed to ``estimate_span_loss``.

    Args:
        loss_param: bundle providing ``model``, ``chart_tables``,
            ``tensor_cache``, ``atom_spans``, ``s_indices``, ``input_ids`` and
            ``seq_lens``.
        max_span: largest allowed value of ``j - i`` for a candidate.
        min_span: smallest allowed value of ``j - i`` for a candidate.
        weighted: if True, a candidate's sampling weight is the sum of the
            context lengths read from ``chart_tables.gather_context_length``
            just left and just right of the span; otherwise every candidate
            has weight 1.

    Returns:
        Whatever ``estimate_span_loss`` returns for the sampled spans.
    """
    model = loss_param.model
    tables = loss_param.chart_tables
    tensor_cache = loss_param.tensor_cache
    input_ids = loss_param.input_ids
    seq_lens = loss_param.seq_lens
    batch_size = input_ids.shape[0]
    max_len = input_ids.shape[1]

    atom_spans = loss_param.atom_spans
    if atom_spans is None:
        atom_spans = [[] for _ in range(batch_size)]

    merge_indices_np = loss_param.s_indices.cpu().data.numpy()
    roots = [
        get_tree_from_merge_trajectory(merge_indices_np[sent_i], seq_lens[sent_i], None)
        for sent_i in range(len(seq_lens))
    ]

    # The context-length table is only needed for weighted sampling.
    context_length_np = None
    if weighted:
        context_length = torch.full([batch_size, max_len, 2], 0,
                                    dtype=torch.int, device=model.device)
        tables.gather_context_length(context_length)
        context_length_np = context_length.cpu().numpy()

    span_batch = []
    for batch_i, root in enumerate(roots):
        candidate_spans = []
        queue = deque([root])
        # Breadth-first walk over the tree.
        while queue:
            parent = queue.popleft()
            # The root is skipped so that a whole sentence is never masked.
            if (min_span <= parent.j - parent.i <= max_span
                    and parent != root
                    and not _crosses_atom_span(parent, atom_spans[batch_i])):
                if weighted:
                    weight = 0
                    if parent.i > 0:
                        weight += context_length_np[batch_i][parent.i - 1][0]
                    if parent.j < seq_lens[batch_i] - 1:
                        weight += context_length_np[batch_i][parent.j + 1][1]
                else:
                    weight = 1
                candidate_spans.append((parent, weight))
            if not parent.is_leaf:
                queue.append(parent.left)
                queue.append(parent.right)

        span_batch.append(_sample_non_overlapping_spans(candidate_spans))

    return estimate_span_loss(model, input_ids.cpu().data.numpy(), span_batch, tables, tensor_cache)

def cuda_generative_loss(loss_param: LMLossParam, max_positions=10):
    model = loss_param.model
    tables = loss_param.chart_tables
    tensor_cache = loss_param.tensor_cache
    atom_spans = loss_param.atom_spans
    input_ids = loss_param.input_ids
    seq_lens = loss_param.seq_lens

    task_id = model.mask_token_id
    input_ids_np = input_ids.to('cpu').data.numpy()
    input_ids_batch = []
    tgt_ids_batch = []
    context_span = []
    if atom_spans is None:
        atom_spans = [[] for _ in range(input_ids.shape[0])]
    for batch_i, spans in enumerate(atom_spans):
        offset = 0
        if spans is None:
            spans = []
        s_spans = sorted(spans, key=lambda sp: sp[0])
        for st, ed in s_spans:
            if ed - st  >= max_positions:
                continue
            while offset < st:
                input_ids_batch.append([task_id])
                tgt_ids_batch.append([input_ids_np[batch_i][offset]])
                context_span.append([batch_i, offset, offset])
                offset += 1
            input_ids = [task_id]
            input_ids.extend(input_ids_np[batch_i][st: ed])
            input_ids_batch.append(input_ids)
            tgt_ids_batch.append(input_ids_np[batch_i][st: ed + 1])
            context_span.append([batch_i, st, ed])
            offset = ed + 1
        while offset < seq_lens[batch_i]:
            input_ids_batch.append([task_id])
            tgt_ids_batch.append([input_ids_np[batch_i][offset]])
            context_span.append([batch_i, offset, offset])
            offset += 1
    
    # padding tgt_ids_batch
    max_ids_len = max(map(len, tgt_ids_batch))
    for tgt_ids in tgt_ids_batch:
        if len(tgt_ids) < max_ids_len:
            tgt_ids.extend([-1] * (max_ids_len - len(tgt_ids)))
    
    context_span = torch.tensor(context_span, device=model.device, dtype=torch.int)
    # gather representation of spans
    cache_ids = torch.full([context_span.shape[0] * 2],
                           0,
                           requires_grad=False,
                           dtype=torch.int,
                           device=model.device)
    tables.prepare_span_bilm(cache_ids, context_span, 
                             BOS_CACHE_ID, EOS_CACHE_ID)

    context_cache_ids = cache_ids.view(-1, 2)

    context = tensor_cache.gather(context_cache_ids.flatten(),
                               [CacheSlots.E_IJ])[0]
    context = context.view(*context_cache_ids.shape, model.input_dim)
    assert isinstance(model.tree_decoder, UniLMEncoder)
    outputs = model.tree_decoder(input_ids=input_ids_batch,
                                 memory=context,
                                 embeddings=model.embedding)
    logits = model.classifier(model.cls_dense(outputs))  # (N, max_ids_len, vocab_size)
    return F.cross_entropy(logits.permute(0, 2, 1), 
                           torch.tensor(tgt_ids_batch, device=model.device),
                           ignore_index=-1)


def cuda_span_cos_loss(loss_param: LMLossParam, min_span=2, max_span=5, rate=0.5):
    """Cosine loss between context-predicted and cached span representations.

    Trees are rebuilt from the dumped chart cells. In every tree, nodes whose
    length ``j - i + 1`` lies in ``[min_span, max_span]`` are candidates, and a
    fraction ``rate`` of them is sampled without replacement. For each sampled
    span, ``model.infer`` predicts a vector from the span's two cached context
    vectors, and that prediction is pulled towards the span's own cached
    ``E_IJ`` representation with a cosine embedding loss (target ``+1``).

    Args:
        loss_param: bundle providing ``model``, ``chart_tables`` and
            ``tensor_cache``.
        min_span: smallest allowed span length (in tokens).
        max_span: largest allowed span length (in tokens).
        rate: fraction of the candidates of each sentence to sample; the count
            is ``int(rate * number_of_candidates)``.

    Returns:
        ``0`` when no span was sampled in the whole batch, otherwise the
        cosine embedding loss.
    """
    model = loss_param.model
    tables = loss_param.chart_tables
    tensor_cache = loss_param.tensor_cache
    # get trees
    trees = convert_cuda_tables(tables.dump_cells(), tensor_cache)
    roots = [t.root.best_node for t in trees]

    selected_nodes = []  # (batch index, tree node) pairs
    for batch_i, root in enumerate(roots):
        span_candidates = []
        queue = deque([root])
        # Breadth-first walk over the tree.
        while queue:
            parent = queue.popleft()
            if min_span <= parent.j - parent.i + 1 <= max_span:
                span_candidates.append(parent)
            if not parent.is_leaf:
                queue.append(parent.left)
                queue.append(parent.right)

        selected_indices = np.random.choice(len(span_candidates),
                                            size=int(rate * len(span_candidates)),
                                            replace=False)
        selected_nodes.extend((batch_i, span_candidates[idx]) for idx in selected_indices)

    # With nothing selected the mean-reduced loss below would run on an empty
    # batch; return 0 instead, like estimate_span_loss does.
    if not selected_nodes:
        return 0

    targets = [node.cache_id for _, node in selected_nodes]
    span_index = [[batch_i, node.i, node.j] for batch_i, node in selected_nodes]
    e_ij_pool = tensor_cache.get_tensor_cache(CacheSlots.E_IJ)

    # Two cache ids per span, written in place by prepare_span_bilm.
    cache_ids = torch.full([len(selected_nodes) * 2],
                           0,
                           requires_grad=False,
                           dtype=torch.int,
                           device=model.device)
    tables.prepare_span_bilm(cache_ids,
                             torch.tensor(span_index, device=model.device, dtype=torch.int),
                             BOS_CACHE_ID, EOS_CACHE_ID)
    context_cache_ids = cache_ids.view(-1, 2)

    context = tensor_cache.gather(context_cache_ids.flatten(),
                                  [CacheSlots.E_IJ])[0]
    context = context.view(*context_cache_ids.shape, model.input_dim)  # (N, 2, dim)
    pred_tensor = model.infer(context, in_logits=True)  # (N, dim)
    target_repr = e_ij_pool[targets]
    # A target of +1 asks for each prediction to be similar to its cached vector.
    target = torch.ones((target_repr.shape[0]), device=e_ij_pool.device)

    return F.cosine_embedding_loss(pred_tensor, target_repr, target)