import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from models.soft_labels import *

########################################################################################################
## Soft Contrastive Losses
########################################################################################################
#------------------------------------------------------------------------------------------#
def instance_contrastive_loss_soft(z1, z2, DTW_L, DTW_R):
    """Soft instance-wise contrastive loss.

    Builds the same per-time-step (2B x 2B) similarity matrices and
    diagonal-free negative log-softmax as ``instance_contrastive_loss_hard``.
    Instead of picking a single positive per anchor, it weights every
    non-self candidate by a soft label. With one-hot labels placed on the hard
    positives, the result equals the hard loss.

    Args:
        z1, z2: the two views, laid out (B, T, C). This is the layout used by
            the hard variant and by the hierarchical wrappers, which max-pool
            dim 1 as time.
        DTW_L: soft labels for the B anchors taken from ``z1``. Must broadcast
            against (T, B, 2B - 1); presumably (B, 2B - 1) as produced by
            ``duplicate_DTW`` -- confirm against that helper.
        DTW_R: soft labels for the B anchors taken from ``z2``, same shape
            requirement.

    Returns:
        Scalar tensor: the label-weighted negative log-likelihood summed over
        anchors and candidates, divided by ``2 * B * T``. Zero when B == 1.
    """
    # Inputs are used as (B, T, C) without permuting, so this term compares
    # the same per-time-step vectors as the hard instance loss.
    B, T = z1.size(0), z1.size(1)
    if B == 1:
        return z1.new_tensor(0.)
    z = torch.cat([z1, z2], dim=0)  # 2B x T x C
    z = z.transpose(0, 1)  # T x 2B x C
    sim = torch.matmul(z, z.transpose(1, 2))  # T x 2B x 2B
    # Drop the diagonal (self-similarity): lower entries keep their column,
    # upper entries shift one column left -> T x 2B x (2B-1).
    logits = torch.tril(sim, diagonal=-1)[:, :, :-1]
    logits += torch.triu(sim, diagonal=1)[:, :, 1:]
    logits = -F.log_softmax(logits, dim=-1)
    i = torch.arange(B, device=z1.device)
    loss = torch.sum(logits[:, i] * DTW_L)
    loss += torch.sum(logits[:, B + i] * DTW_R)
    loss /= (2 * B * T)
    return loss

def temporal_contrastive_loss_soft(z1, z2, TIME_LAG_L, TIME_LAG_R):
    """Soft temporal contrastive loss.

    For each sample, the two views are concatenated along time to form a
    (2T x 2T) similarity matrix. The diagonal is removed and every non-self
    time step is weighted by a soft label instead of a single hard positive.
    With one-hot labels on the hard positives, this equals
    ``temporal_contrastive_loss_hard``.

    Args:
        z1, z2: the two views, laid out (B, T, C). This matches the hard
            variant and the hierarchical wrappers, which build ``TIME_LAG``
            from ``z1.shape[1]``.
        TIME_LAG_L: soft labels for the T anchors from ``z1``. Must broadcast
            against (B, T, 2T - 1); presumably (T, 2T - 1) from
            ``duplicate_DTW`` -- confirm against that helper.
        TIME_LAG_R: soft labels for the T anchors from ``z2``, same shape
            requirement.

    Returns:
        Scalar tensor: the label-weighted negative log-likelihood summed over
        anchors and candidates, divided by ``2 * B * T``. Zero when T == 1.
    """
    # No permute: dim 1 must be time so that T agrees with the label matrices
    # the wrappers generate from z1.shape[1].
    B, T = z1.size(0), z1.size(1)
    if T == 1:
        return z1.new_tensor(0.)
    z = torch.cat([z1, z2], dim=1)  # B x 2T x C
    sim = torch.matmul(z, z.transpose(1, 2))  # B x 2T x 2T
    # Drop the diagonal (self-similarity) -> B x 2T x (2T-1).
    logits = torch.tril(sim, diagonal=-1)[:, :, :-1]
    logits += torch.triu(sim, diagonal=1)[:, :, 1:]
    logits = -F.log_softmax(logits, dim=-1)
    t = torch.arange(T, device=z1.device)
    loss = torch.sum(logits[:, t] * TIME_LAG_L)
    loss += torch.sum(logits[:, T + t] * TIME_LAG_R)
    loss /= (2 * B * T)
    return loss
#------------------------------------------------------------------------------------------#

def hierarchical_contrastive_loss_soft(z1, z2, DTW, tau_temp=2, lambda_=0.5, temporal_unit=0, soft_temporal=False, soft_instance=False, hierarchy=True):
    """Hierarchical contrastive loss with optional soft instance/temporal labels.

    The two views are max-pooled by 2 along dim 1 (treated as time) until a
    single step remains. At level ``d`` the loss adds ``lambda_`` times the
    instance term. When ``d >= temporal_unit`` it also adds ``(1 - lambda_)``
    times the temporal term. One more instance term is added at the
    single-step level. The total is divided by the number of levels visited.

    Args:
        z1, z2: the two views; dim 1 is pooled as time (assumed (B, T, C)).
        DTW: instance soft-label matrix passed through ``torch.tensor`` and
            ``duplicate_DTW``. Only read when ``soft_instance`` is True and
            ``lambda_ != 0``; may be ``None`` otherwise.
        tau_temp: parameter forwarded to ``generate_TIMELAG_sigmoid``;
            multiplied by ``2**d`` at level ``d`` when ``hierarchy`` is True.
        lambda_: weight of the instance term (``1 - lambda_`` for temporal).
        temporal_unit: first pooling level that receives the temporal term.
        soft_temporal: use time-lag soft labels instead of the hard temporal loss.
        soft_instance: use ``DTW`` soft labels instead of the hard instance loss.
        hierarchy: scale ``tau_temp`` by ``2**d`` per level.

    Returns:
        Scalar tensor: summed loss divided by the number of levels.
    """
    # Build instance soft labels only when they are consumed. Converting
    # unconditionally raised for DTW=None and paid a host->device copy even
    # when the hard instance loss was selected.
    if soft_instance and lambda_ != 0:
        DTW = torch.tensor(DTW, device=z1.device)
        DTW_L, DTW_R = duplicate_DTW(DTW)
    loss = torch.tensor(0., device=z1.device)
    d = 0
    while z1.size(1) > 1:
        if lambda_ != 0:
            if soft_instance:
                loss += lambda_ * instance_contrastive_loss_soft(z1, z2, DTW_L, DTW_R)
            else:
                loss += lambda_ * instance_contrastive_loss_hard(z1, z2)
        if d >= temporal_unit and 1 - lambda_ != 0:
            if soft_temporal:
                # Soft time-lag labels for the current resolution.
                tau = tau_temp * (2 ** d) if hierarchy else tau_temp
                TIME_LAG = generate_TIMELAG_sigmoid(z1.shape[1], tau)
                TIME_LAG = torch.tensor(TIME_LAG, device=z1.device)
                TIME_LAG_L, TIME_LAG_R = duplicate_DTW(TIME_LAG)
                loss += (1 - lambda_) * temporal_contrastive_loss_soft(z1, z2, TIME_LAG_L, TIME_LAG_R)
            else:
                loss += (1 - lambda_) * temporal_contrastive_loss_hard(z1, z2)
        d += 1
        # Halve the time resolution for the next level.
        z1 = F.max_pool1d(z1.transpose(1, 2), kernel_size=2).transpose(1, 2)
        z2 = F.max_pool1d(z2.transpose(1, 2), kernel_size=2).transpose(1, 2)

    if z1.size(1) == 1:
        if lambda_ != 0:
            if soft_instance:
                loss += lambda_ * instance_contrastive_loss_soft(z1, z2, DTW_L, DTW_R)
            else:
                loss += lambda_ * instance_contrastive_loss_hard(z1, z2)
        d += 1

    return loss / d


def hierarchical_contrastive_loss_sigmoid_window(z1, z2, DTW, window_ratio, tau_temp=2, lambda_=0.5, temporal_unit=0, soft_temporal=False, soft_instance=False):
    """Hierarchical contrastive loss using windowed sigmoid time-lag labels.

    Same level schedule as ``hierarchical_contrastive_loss_soft``. The
    temporal soft labels come from ``generate_TIMELAG_sigmoid_window`` with
    ``tau_temp * 2**d`` at level ``d``.

    Args:
        z1, z2: the two views; dim 1 is pooled as time (assumed (B, T, C)).
        DTW: instance soft-label matrix. Only read when ``soft_instance`` is
            True and ``lambda_ != 0``; may be ``None`` otherwise.
        window_ratio: forwarded unchanged to ``generate_TIMELAG_sigmoid_window``.
        tau_temp: base parameter for the time-lag labels, scaled by ``2**d``.
        lambda_: weight of the instance term (``1 - lambda_`` for temporal).
        temporal_unit: first pooling level that receives the temporal term.
        soft_temporal: use soft time-lag labels instead of the hard temporal loss.
        soft_instance: use ``DTW`` soft labels instead of the hard instance loss.

    Returns:
        Scalar tensor: summed loss divided by the number of levels.
    """
    # Only build instance soft labels when the soft instance term runs.
    if soft_instance and lambda_ != 0:
        DTW = torch.tensor(DTW, device=z1.device)
        DTW_L, DTW_R = duplicate_DTW(DTW)
    loss = torch.tensor(0., device=z1.device)
    d = 0
    while z1.size(1) > 1:
        if lambda_ != 0:
            if soft_instance:
                loss += lambda_ * instance_contrastive_loss_soft(z1, z2, DTW_L, DTW_R)
            else:
                loss += lambda_ * instance_contrastive_loss_hard(z1, z2)
        if d >= temporal_unit and 1 - lambda_ != 0:
            if soft_temporal:
                TIME_LAG = generate_TIMELAG_sigmoid_window(z1.shape[1], tau_temp * (2 ** d), window_ratio)
                TIME_LAG = torch.tensor(TIME_LAG, device=z1.device)
                TIME_LAG_L, TIME_LAG_R = duplicate_DTW(TIME_LAG)
                loss += (1 - lambda_) * temporal_contrastive_loss_soft(z1, z2, TIME_LAG_L, TIME_LAG_R)
            else:
                loss += (1 - lambda_) * temporal_contrastive_loss_hard(z1, z2)
        d += 1
        # Halve the time resolution for the next level.
        z1 = F.max_pool1d(z1.transpose(1, 2), kernel_size=2).transpose(1, 2)
        z2 = F.max_pool1d(z2.transpose(1, 2), kernel_size=2).transpose(1, 2)

    if z1.size(1) == 1:
        if lambda_ != 0:
            if soft_instance:
                loss += lambda_ * instance_contrastive_loss_soft(z1, z2, DTW_L, DTW_R)
            else:
                loss += lambda_ * instance_contrastive_loss_hard(z1, z2)
        d += 1

    return loss / d

def hierarchical_contrastive_loss_sigmoid_threshold(z1, z2, DTW, threshold, lambda_=0.5, temporal_unit=0, soft_temporal=False, soft_instance=False):
    """Hierarchical contrastive loss using thresholded time-lag labels.

    Same level schedule as ``hierarchical_contrastive_loss_soft``. The
    temporal soft labels come from ``generate_TIMELAG_sigmoid_threshold``,
    called with the same ``threshold`` at every level.

    Args:
        z1, z2: the two views; dim 1 is pooled as time (assumed (B, T, C)).
        DTW: instance soft-label matrix. Only read when ``soft_instance`` is
            True and ``lambda_ != 0``; may be ``None`` otherwise.
        threshold: forwarded unchanged to ``generate_TIMELAG_sigmoid_threshold``.
        lambda_: weight of the instance term (``1 - lambda_`` for temporal).
        temporal_unit: first pooling level that receives the temporal term.
        soft_temporal: use soft time-lag labels instead of the hard temporal loss.
        soft_instance: use ``DTW`` soft labels instead of the hard instance loss.

    Returns:
        Scalar tensor: summed loss divided by the number of levels.
    """
    # Only build instance soft labels when the soft instance term runs.
    if soft_instance and lambda_ != 0:
        DTW = torch.tensor(DTW, device=z1.device)
        DTW_L, DTW_R = duplicate_DTW(DTW)
    loss = torch.tensor(0., device=z1.device)
    d = 0
    while z1.size(1) > 1:
        if lambda_ != 0:
            if soft_instance:
                loss += lambda_ * instance_contrastive_loss_soft(z1, z2, DTW_L, DTW_R)
            else:
                loss += lambda_ * instance_contrastive_loss_hard(z1, z2)
        if d >= temporal_unit and 1 - lambda_ != 0:
            if soft_temporal:
                TIME_LAG = generate_TIMELAG_sigmoid_threshold(z1.shape[1], threshold)
                TIME_LAG = torch.tensor(TIME_LAG, device=z1.device)
                TIME_LAG_L, TIME_LAG_R = duplicate_DTW(TIME_LAG)
                loss += (1 - lambda_) * temporal_contrastive_loss_soft(z1, z2, TIME_LAG_L, TIME_LAG_R)
            else:
                loss += (1 - lambda_) * temporal_contrastive_loss_hard(z1, z2)
        d += 1
        # Halve the time resolution for the next level.
        z1 = F.max_pool1d(z1.transpose(1, 2), kernel_size=2).transpose(1, 2)
        z2 = F.max_pool1d(z2.transpose(1, 2), kernel_size=2).transpose(1, 2)

    if z1.size(1) == 1:
        if lambda_ != 0:
            if soft_instance:
                loss += lambda_ * instance_contrastive_loss_soft(z1, z2, DTW_L, DTW_R)
            else:
                loss += lambda_ * instance_contrastive_loss_hard(z1, z2)
        d += 1

    return loss / d


def hierarchical_contrastive_loss_sigmoid_wo_instance(z1, z2, DTW, tau_temp=2, lambda_=0.5, temporal_unit=0, soft_temporal=False, soft_instance=False):
    """Hierarchical temporal-only contrastive loss (no instance term).

    The views are max-pooled by 2 along dim 1 (treated as time) until one
    step remains. At level ``d`` with ``d >= temporal_unit``, the loss adds
    ``(1 - lambda_)`` times the temporal term. Soft labels come from
    ``generate_TIMELAG_sigmoid`` with ``tau_temp * 2**d``. The single-step
    level adds no loss but still counts in the divisor.

    Args:
        z1, z2: the two views; dim 1 is pooled as time (assumed (B, T, C)).
        DTW: unused. Kept so the signature matches the other wrappers.
        tau_temp: base parameter for the time-lag labels, scaled by ``2**d``.
        lambda_: the temporal term is weighted by ``1 - lambda_``.
        temporal_unit: first pooling level that receives the temporal term.
        soft_temporal: use soft time-lag labels instead of the hard temporal loss.
        soft_instance: unused. Kept for signature compatibility.

    Returns:
        Scalar tensor: summed loss divided by the number of levels.
    """
    # No instance term here, so the DTW soft labels are never converted. The
    # previous unconditional conversion was dead work and raised for DTW=None.
    loss = torch.tensor(0., device=z1.device)
    d = 0
    while z1.size(1) > 1:
        if d >= temporal_unit and 1 - lambda_ != 0:
            if soft_temporal:
                TIME_LAG = generate_TIMELAG_sigmoid(z1.shape[1], tau_temp * (2 ** d))
                TIME_LAG = torch.tensor(TIME_LAG, device=z1.device)
                TIME_LAG_L, TIME_LAG_R = duplicate_DTW(TIME_LAG)
                loss += (1 - lambda_) * temporal_contrastive_loss_soft(z1, z2, TIME_LAG_L, TIME_LAG_R)
            else:
                loss += (1 - lambda_) * temporal_contrastive_loss_hard(z1, z2)
        d += 1
        # Halve the time resolution for the next level.
        z1 = F.max_pool1d(z1.transpose(1, 2), kernel_size=2).transpose(1, 2)
        z2 = F.max_pool1d(z2.transpose(1, 2), kernel_size=2).transpose(1, 2)

    if z1.size(1) == 1:
        d += 1

    return loss / d


def hierarchical_contrastive_loss_gaussian(z1, z2, DTW, tau_temp=2, lambda_=0.5, temporal_unit=0, soft_temporal=False, soft_instance=False, hierarchy=True):
    """Hierarchical contrastive loss using Gaussian time-lag labels.

    Same level schedule as ``hierarchical_contrastive_loss_soft``. The
    temporal soft labels come from ``generate_TIMELAG_gaussian``, whose
    parameter is divided (not multiplied) by ``2**d`` at level ``d`` when
    ``hierarchy`` is True.

    Args:
        z1, z2: the two views; dim 1 is pooled as time (assumed (B, T, C)).
        DTW: instance soft-label matrix. Only read when ``soft_instance`` is
            True and ``lambda_ != 0``; may be ``None`` otherwise.
        tau_temp: base parameter forwarded to ``generate_TIMELAG_gaussian``.
        lambda_: weight of the instance term (``1 - lambda_`` for temporal).
        temporal_unit: first pooling level that receives the temporal term.
        soft_temporal: use soft time-lag labels instead of the hard temporal loss.
        soft_instance: use ``DTW`` soft labels instead of the hard instance loss.
        hierarchy: divide ``tau_temp`` by ``2**d`` per level.

    Returns:
        Scalar tensor: summed loss divided by the number of levels.
    """
    # Only build instance soft labels when the soft instance term runs.
    if soft_instance and lambda_ != 0:
        DTW = torch.tensor(DTW, device=z1.device)
        DTW_L, DTW_R = duplicate_DTW(DTW)
    loss = torch.tensor(0., device=z1.device)
    d = 0
    while z1.size(1) > 1:
        if lambda_ != 0:
            if soft_instance:
                loss += lambda_ * instance_contrastive_loss_soft(z1, z2, DTW_L, DTW_R)
            else:
                loss += lambda_ * instance_contrastive_loss_hard(z1, z2)
        if d >= temporal_unit and 1 - lambda_ != 0:
            if soft_temporal:
                tau = tau_temp / (2 ** d) if hierarchy else tau_temp
                TIME_LAG = generate_TIMELAG_gaussian(z1.shape[1], tau)
                TIME_LAG = torch.tensor(TIME_LAG, device=z1.device)
                TIME_LAG_L, TIME_LAG_R = duplicate_DTW(TIME_LAG)
                loss += (1 - lambda_) * temporal_contrastive_loss_soft(z1, z2, TIME_LAG_L, TIME_LAG_R)
            else:
                loss += (1 - lambda_) * temporal_contrastive_loss_hard(z1, z2)
        d += 1
        # Halve the time resolution for the next level.
        z1 = F.max_pool1d(z1.transpose(1, 2), kernel_size=2).transpose(1, 2)
        z2 = F.max_pool1d(z2.transpose(1, 2), kernel_size=2).transpose(1, 2)

    if z1.size(1) == 1:
        if lambda_ != 0:
            if soft_instance:
                loss += lambda_ * instance_contrastive_loss_soft(z1, z2, DTW_L, DTW_R)
            else:
                loss += lambda_ * instance_contrastive_loss_hard(z1, z2)
        d += 1

    return loss / d


def hierarchical_contrastive_loss_same_interval(z1, z2, DTW, tau_temp=2, lambda_=0.5, temporal_unit=0, soft_temporal=False, soft_instance=False):
    """Hierarchical contrastive loss using "same interval" time-lag labels.

    Same level schedule as ``hierarchical_contrastive_loss_soft``. The
    temporal soft labels come from ``generate_TIMELAG_same_interval`` with
    ``tau_temp / 2**d`` at level ``d``.

    Args:
        z1, z2: the two views; dim 1 is pooled as time (assumed (B, T, C)).
        DTW: instance soft-label matrix. Only read when ``soft_instance`` is
            True and ``lambda_ != 0``; may be ``None`` otherwise.
        tau_temp: base parameter forwarded to ``generate_TIMELAG_same_interval``.
        lambda_: weight of the instance term (``1 - lambda_`` for temporal).
        temporal_unit: first pooling level that receives the temporal term.
        soft_temporal: use soft time-lag labels instead of the hard temporal loss.
        soft_instance: use ``DTW`` soft labels instead of the hard instance loss.

    Returns:
        Scalar tensor: summed loss divided by the number of levels.
    """
    # Only build instance soft labels when the soft instance term runs.
    if soft_instance and lambda_ != 0:
        DTW = torch.tensor(DTW, device=z1.device)
        DTW_L, DTW_R = duplicate_DTW(DTW)
    loss = torch.tensor(0., device=z1.device)
    d = 0
    while z1.size(1) > 1:
        if lambda_ != 0:
            if soft_instance:
                loss += lambda_ * instance_contrastive_loss_soft(z1, z2, DTW_L, DTW_R)
            else:
                loss += lambda_ * instance_contrastive_loss_hard(z1, z2)
        if d >= temporal_unit and 1 - lambda_ != 0:
            if soft_temporal:
                TIME_LAG = generate_TIMELAG_same_interval(z1.shape[1], tau_temp / (2 ** d))
                TIME_LAG = torch.tensor(TIME_LAG, device=z1.device)
                TIME_LAG_L, TIME_LAG_R = duplicate_DTW(TIME_LAG)
                loss += (1 - lambda_) * temporal_contrastive_loss_soft(z1, z2, TIME_LAG_L, TIME_LAG_R)
            else:
                loss += (1 - lambda_) * temporal_contrastive_loss_hard(z1, z2)
        d += 1
        # Halve the time resolution for the next level.
        z1 = F.max_pool1d(z1.transpose(1, 2), kernel_size=2).transpose(1, 2)
        z2 = F.max_pool1d(z2.transpose(1, 2), kernel_size=2).transpose(1, 2)

    if z1.size(1) == 1:
        if lambda_ != 0:
            if soft_instance:
                loss += lambda_ * instance_contrastive_loss_soft(z1, z2, DTW_L, DTW_R)
            else:
                loss += lambda_ * instance_contrastive_loss_hard(z1, z2)
        d += 1

    return loss / d




########################################################################################################
## Hard Contrastive Losses
########################################################################################################

#------------------------------------------------------------------------------------------#
def instance_contrastive_loss_hard(z1, z2):
    """Hard instance-wise contrastive loss.

    At every time step, each of the 2B representations is scored against the
    other 2B - 1. The positive for ``z1[b]`` is ``z2[b]`` and vice versa.

    Args:
        z1, z2: the two views, indexed (batch, time, channel) below; both
            share that shape.

    Returns:
        Scalar tensor: mean negative log-probability of the positives, or a
        zero tensor when the batch holds a single sample.
    """
    n_batch = z1.size(0)
    if n_batch == 1:
        return z1.new_tensor(0.)
    # Stack both views on the batch axis, then put time first so each time
    # step gets its own (2B x 2B) similarity matrix.
    views = torch.cat([z1, z2], dim=0).transpose(0, 1)   # T x 2B x C
    sim = views @ views.transpose(1, 2)                   # T x 2B x 2B
    # Remove self-similarity: strictly-lower entries keep their column,
    # strictly-upper entries move one column left.
    lower = torch.tril(sim, diagonal=-1)[:, :, :-1]
    upper = torch.triu(sim, diagonal=1)[:, :, 1:]
    neg_log_prob = -F.log_softmax(lower + upper, dim=-1)  # T x 2B x (2B-1)

    idx = torch.arange(n_batch, device=z1.device)
    # Row b pairs with B+b (shifted to column B+b-1); row B+b pairs with b.
    loss_from_z1 = neg_log_prob[:, idx, n_batch + idx - 1].mean()
    loss_from_z2 = neg_log_prob[:, n_batch + idx, idx].mean()
    return (loss_from_z1 + loss_from_z2) / 2

def temporal_contrastive_loss_hard(z1, z2):
    """Hard temporal contrastive loss.

    For each sample, the two views are joined along time. Each of the 2T
    steps is scored against the other 2T - 1, and the positive for step ``t``
    of one view is step ``t`` of the other.

    Args:
        z1, z2: the two views, indexed (batch, time, channel) below; both
            share that shape.

    Returns:
        Scalar tensor: mean negative log-probability of the positives, or a
        zero tensor when there is a single time step.
    """
    n_steps = z1.size(1)
    if n_steps == 1:
        return z1.new_tensor(0.)
    seq = torch.cat([z1, z2], dim=1)                      # B x 2T x C
    sim = seq @ seq.transpose(1, 2)                       # B x 2T x 2T
    # Remove self-similarity from every row -> B x 2T x (2T-1).
    off_diag = (torch.tril(sim, diagonal=-1)[:, :, :-1]
                + torch.triu(sim, diagonal=1)[:, :, 1:])
    neg_log_prob = -F.log_softmax(off_diag, dim=-1)

    steps = torch.arange(n_steps, device=z1.device)
    # Step t pairs with T+t (shifted to column T+t-1); step T+t pairs with t.
    loss_from_z1 = neg_log_prob[:, steps, n_steps + steps - 1].mean()
    loss_from_z2 = neg_log_prob[:, n_steps + steps, steps].mean()
    return (loss_from_z1 + loss_from_z2) / 2
#------------------------------------------------------------------------------------------#

def hierarchical_contrastive_loss_hard(z1, z2, lambda_=0.5, temporal_unit=0):
    """Hierarchical contrastive loss with hard instance and temporal terms.

    The views are max-pooled by 2 along dim 1 (treated as time) until one
    step remains. At level ``d`` the loss adds ``lambda_`` times the hard
    instance term. When ``d >= temporal_unit`` it also adds ``(1 - lambda_)``
    times the hard temporal term. One more instance term is added at the
    single-step level. The total is divided by the number of levels.

    Args:
        z1, z2: the two views; dim 1 is pooled as time (assumed (B, T, C)).
        lambda_: weight of the instance term (``1 - lambda_`` for temporal).
        temporal_unit: first pooling level that receives the temporal term.

    Returns:
        Scalar tensor: summed loss divided by the number of levels.
    """
    # Leftover debug prints of z1/z2 shapes were removed: they ran on every
    # level of every training step.
    loss = torch.tensor(0., device=z1.device)
    d = 0
    while z1.size(1) > 1:
        if lambda_ != 0:
            loss += lambda_ * instance_contrastive_loss_hard(z1, z2)
        if d >= temporal_unit and 1 - lambda_ != 0:
            loss += (1 - lambda_) * temporal_contrastive_loss_hard(z1, z2)
        d += 1
        # Halve the time resolution for the next level.
        z1 = F.max_pool1d(z1.transpose(1, 2), kernel_size=2).transpose(1, 2)
        z2 = F.max_pool1d(z2.transpose(1, 2), kernel_size=2).transpose(1, 2)

    if z1.size(1) == 1:
        if lambda_ != 0:
            loss += lambda_ * instance_contrastive_loss_hard(z1, z2)
        d += 1
    return loss / d