import torch
from torch import nn
from models.tcc_loss import compute_tcc_loss
from models.d2tw.dtw_loss import compute_alignment_loss as compute_dtw_loss
from models.dropdtw.drop_dtw import compute_alignment_loss as compute_drop_dtw_loss
from models.ot.vava import compute_vava_loss


class AlignmentLoss(nn.Module):
    """Dispatch to one of several alignment losses based on ``args.loss``.

    The loss is selected by substring match on ``args.loss``; the checks run
    in the order ``'tcc'``, ``'d2tw'``, ``'dropdtw'``, ``'vava'`` and the first
    match wins.
    """

    # Datasets for which the action is flagged as non-cyclic.
    _NON_CYCLIC_DATASETS = ('break_egg', 'tennis')

    def __init__(self, args):
        """Store the configuration and derive the cyclic-action flag.

        Args:
            args: Configuration object. ``args.dataset`` and ``args.loss`` are
                always read; ``tcc_temp``, ``drop_percent`` and ``drop_l2norm``
                are read only by the corresponding branch of ``forward``.
        """
        super().__init__()
        self.args = args
        self.cyclic = self.args.dataset not in self._NON_CYCLIC_DATASETS
        print(f'Cyclic action {self.cyclic}')

    def forward(self, embeddings, steps=None, seq_lens=None, pos_steps=None, global_step=0):
        """Compute the alignment loss selected by ``args.loss``.

        Args:
            embeddings: Passed unchanged to the selected loss function.
            steps: Used only by the TCC branch.
            seq_lens: Used only by the TCC branch.
            pos_steps: Used only by the D2TW branch.
            global_step: Used only by the VAVA branch.

        Returns:
            For the TCC, Drop-DTW and VAVA branches, a tuple
            ``(loss, torch.zeros_like(loss))``. For the D2TW branch, whatever
            ``compute_dtw_loss`` returns, forwarded as-is (presumably a pair
            matching the other branches -- TODO confirm against the callee).

        Raises:
            NotImplementedError: If ``args.loss`` matches none of the
                supported loss names.
        """
        loss_name = self.args.loss

        if 'tcc' in loss_name:
            loss = compute_tcc_loss(embeddings, steps, seq_lens, self.args.tcc_temp)
        elif 'd2tw' in loss_name:
            # This branch forwards the callee's result directly instead of
            # pairing the loss with a zero tensor like the other branches.
            return compute_dtw_loss(self.args, embeddings, pos_steps,
                                    alignment_type=loss_name,
                                    cyclic_action=self.cyclic)
        elif 'dropdtw' in loss_name:
            loss = compute_drop_dtw_loss(embeddings, distractors=None,
                                         keep_percentile=self.args.drop_percent,
                                         l2_normalize=self.args.drop_l2norm)
        elif 'vava' in loss_name:
            loss = compute_vava_loss(embeddings, global_step=global_step)
        else:
            # Name the offending value so a mistyped config is easy to spot.
            raise NotImplementedError(
                f'Unsupported alignment loss: {loss_name!r}')

        return loss, torch.zeros_like(loss)

