import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import os
from models.soft_labels import *

########################################################################################################
## Soft Contrastive Losses
########################################################################################################
#------------------------------------------------------------------------------------------#

# Element-wise KL divergence with no reduction: the output keeps the input's shape.
kl_loss = torch.nn.KLDivLoss(reduction='none')

def get_logits(z1, z2):
    """Return instance-wise contrastive log-probabilities.

    The two views are stacked along the batch axis and the time axis is moved
    to the front. At every timestamp, each of the 2B representations is
    compared with the other 2B - 1 (its self-similarity is dropped), and the
    similarities are turned into log-probabilities.

    Args:
        z1, z2: 3-D tensors (B, T, C) holding the two views.

    Returns:
        Tensor (T, 2B, 2B - 1) of ``log_softmax`` values over the last dim,
        or a scalar ``0.`` tensor when B == 1.
    """
    batch = z1.size(0)
    if batch == 1:
        return z1.new_tensor(0.)
    reps = torch.cat((z1, z2), dim=0).transpose(0, 1)  # T x 2B x C
    sim = reps @ reps.transpose(1, 2)  # T x 2B x 2B
    # Remove the diagonal: entries left of it stay in place, entries right of
    # it move one column left, giving T x 2B x (2B-1).
    below = torch.tril(sim, diagonal=-1)[..., :-1]
    above = torch.triu(sim, diagonal=1)[..., 1:]
    return F.log_softmax(below + above, dim=-1)

def get_logits2(z1, z2):
    """Return temporal (within-sequence) contrastive log-probabilities.

    The two views are concatenated along time, so each sequence contributes
    2T timestamps. For every timestamp, its similarities to the other 2T - 1
    timestamps of the same sequence (self-similarity removed) are turned into
    log-probabilities.

    Args:
        z1, z2: 3-D tensors (B, T, C) holding the two views.

    Returns:
        Tensor (B, 2T, 2T - 1) of ``log_softmax`` values over the last dim.
        There is deliberately no B == 1 short-circuit: the temporal task only
        compares timestamps within one sequence, so a batch of one is valid.
        For T == 1 the result is a (B, 2, 1) tensor of zeros.
    """
    z = torch.cat([z1, z2], dim=1)  # B x 2T x C
    sim = torch.matmul(z, z.transpose(1, 2))  # B x 2T x 2T
    # Remove the diagonal: entries left of it stay in place, entries right of
    # it move one column left.
    logits = torch.tril(sim, diagonal=-1)[:, :, :-1]  # B x 2T x (2T-1)
    logits += torch.triu(sim, diagonal=1)[:, :, 1:]
    return F.log_softmax(logits, dim=-1)  # B x 2T x (2T-1)

def instance_contrastive_loss_soft(z1, z2, DTW_L, DTW_R):
    """Instance-wise contrastive loss against soft (DTW-derived) targets.

    For every timestamp and each of the 2B rows produced by ``get_logits``,
    computes the cross-entropy between a row-normalized soft target and the
    predicted log-probabilities. The results are summed and divided by 2*B*T.

    Args:
        z1, z2: 3-D tensors (B, T, C) holding the two views.
        DTW_L, DTW_R: soft-label weights for the first B and the last B rows.
            ``torch.vstack([DTW_L, DTW_R])`` must broadcast against the
            (T, 2B, 2B-1) output of ``get_logits``. It is presumably
            (2B, 2B-1), as produced by ``duplicate_DTW`` (TODO confirm).
            Rows are normalized to sum to 1 here; the inputs are not modified.

    Returns:
        Scalar tensor. Returns ``0.`` when B == 1, matching
        ``instance_contrastive_loss_hard``.
    """
    B, T = z1.size(0), z1.size(1)
    if B == 1:
        # get_logits returns a bare scalar for B == 1 and there are no
        # negatives to contrast against, so the loss is zero.
        return z1.new_tensor(0.)
    logits = get_logits(z1, z2)  # T x 2B x (2B-1), already log-probabilities
    targets = torch.vstack([DTW_L, DTW_R]).to(dtype=logits.dtype)
    targets = targets / targets.sum(dim=-1, keepdim=True)
    # The previous expression KL(targets || p) - targets * log(targets) reduces
    # to the soft cross-entropy -targets * log(p). Computing it directly:
    #   - avoids 0 * log(0) = NaN for zero targets;
    #   - lets the (2B, 2B-1) targets broadcast over T instead of copying them T times.
    loss = -(targets * logits).sum()
    return loss / (2 * B * T)

def temporal_contrastive_loss_soft(z1, z2, TIME_LAG_L, TIME_LAG_R):
    """Temporal contrastive loss against soft time-lag targets.

    For every sequence and each of the 2T timestamps produced by
    ``get_logits2``, computes the cross-entropy between a row-normalized soft
    target (plus a 1e-6 floor) and the predicted log-probabilities. The
    results are summed and divided by 2*B*T.

    Args:
        z1, z2: 3-D tensors (B, T, C) holding the two views.
        TIME_LAG_L, TIME_LAG_R: soft-label weights for the first T and the
            last T timestamps. ``torch.vstack([TIME_LAG_L, TIME_LAG_R])`` must
            broadcast against the (B, 2T, 2T-1) output of ``get_logits2``.
            It is presumably (2T, 2T-1), as produced by ``duplicate_DTW``
            (TODO confirm). The inputs are not modified.

    Returns:
        Scalar tensor.
    """
    B, T = z1.size(0), z1.size(1)
    logits = get_logits2(z1, z2)  # B x 2T x (2T-1), already log-probabilities
    targets = torch.vstack([TIME_LAG_L, TIME_LAG_R]).to(dtype=logits.dtype)
    targets = targets / targets.sum(dim=-1, keepdim=True)
    # Floor applied after normalization, kept so loss values match the
    # previous formulation.
    targets = targets + 1e-6
    # KL(targets || p) - targets * log(targets) == -targets * log(p).
    # Computing it directly lets the targets broadcast over B instead of
    # being copied B times, and drops the debug-only retain_grad() call,
    # which raised under torch.no_grad().
    loss = -(targets * logits).sum()
    return loss / (2 * B * T)

#------------------------------------------------------------------------------------------#

def hierarchical_contrastive_loss_soft(z1, z2, DTW, tau_temp=2, lambda_=0.5, temporal_unit=0, soft_temporal=False, soft_instance=False, temporal_hierarchy=True):
    """Multi-scale contrastive loss with optional soft instance/temporal targets.

    At each scale the two views get an instance-wise term (weight ``lambda_``)
    and, from depth ``temporal_unit`` onward, a temporal term (weight
    ``1 - lambda_``). Both views are then max-pooled by 2 along dim 1. When
    the time axis reaches length 1, only the instance term is added.

    Soft temporal targets come from ``generate_TIMELAG_sigmoid``:
      - with ``tau_temp * 2**depth`` when ``temporal_hierarchy`` is set;
      - with ``tau_temp`` at every scale otherwise.

    Args:
        z1, z2: 3-D tensors (B, T, C); dim 1 (time) is halved at every scale.
        DTW: instance soft-label matrix, converted with ``torch.tensor`` and
            split with ``duplicate_DTW``.

    Returns:
        The accumulated loss divided by the number of scales visited.
    """
    device = z1.device
    dtw_left, dtw_right = duplicate_DTW(torch.tensor(DTW, device=device))
    w_inst, w_temp = lambda_, 1 - lambda_

    def instance_term(a, b):
        if soft_instance:
            return w_inst * instance_contrastive_loss_soft(a, b, dtw_left, dtw_right)
        return w_inst * instance_contrastive_loss_hard(a, b)

    def temporal_term(a, b, depth):
        if not soft_temporal:
            return w_temp * temporal_contrastive_loss_hard(a, b)
        tau = tau_temp * (2 ** depth) if temporal_hierarchy else tau_temp
        lag = torch.tensor(generate_TIMELAG_sigmoid(a.shape[1], tau), device=device)
        lag_left, lag_right = duplicate_DTW(lag)
        return w_temp * temporal_contrastive_loss_soft(a, b, lag_left, lag_right)

    def pool(x):
        # Max-pool with kernel 2 along the time axis (dim 1).
        return F.max_pool1d(x.transpose(1, 2), kernel_size=2).transpose(1, 2)

    total = torch.tensor(0., device=device)
    depth = 0
    while z1.size(1) > 1:
        if w_inst != 0:
            total += instance_term(z1, z2)
        if depth >= temporal_unit and w_temp != 0:
            total += temporal_term(z1, z2, depth)
        depth += 1
        z1, z2 = pool(z1), pool(z2)

    if z1.size(1) == 1:
        if w_inst != 0:
            total += instance_term(z1, z2)
        depth += 1

    return total / depth


def hierarchical_contrastive_loss_sigmoid_window(z1, z2, DTW, window_ratio, tau_temp=2, lambda_=0.5, temporal_unit=0, soft_temporal=False, soft_instance=False):
    """Multi-scale contrastive loss using windowed-sigmoid soft temporal targets.

    At each scale the two views get an instance-wise term (weight ``lambda_``)
    and, from depth ``temporal_unit`` onward, a temporal term (weight
    ``1 - lambda_``). Both views are then max-pooled by 2 along dim 1.

    Soft temporal targets at depth ``d`` come from
    ``generate_TIMELAG_sigmoid_window(T_d, tau_temp * 2**d, window_ratio)``.

    Returns:
        The accumulated loss divided by the number of scales visited.
    """
    dev = z1.device
    inst_l, inst_r = duplicate_DTW(torch.tensor(DTW, device=dev))
    temporal_weight = 1 - lambda_

    def pooled(x):
        # Max-pool with kernel 2 along the time axis (dim 1).
        return F.max_pool1d(x.transpose(1, 2), kernel_size=2).transpose(1, 2)

    def instance_part(a, b):
        if soft_instance:
            return lambda_ * instance_contrastive_loss_soft(a, b, inst_l, inst_r)
        return lambda_ * instance_contrastive_loss_hard(a, b)

    def temporal_part(a, b, depth):
        if not soft_temporal:
            return temporal_weight * temporal_contrastive_loss_hard(a, b)
        lag = generate_TIMELAG_sigmoid_window(a.shape[1], tau_temp * (2 ** depth), window_ratio)
        lag_l, lag_r = duplicate_DTW(torch.tensor(lag, device=dev))
        return temporal_weight * temporal_contrastive_loss_soft(a, b, lag_l, lag_r)

    acc = torch.tensor(0., device=dev)
    depth = 0
    while z1.size(1) > 1:
        if lambda_ != 0:
            acc += instance_part(z1, z2)
        if depth >= temporal_unit and temporal_weight != 0:
            acc += temporal_part(z1, z2, depth)
        depth += 1
        z1, z2 = pooled(z1), pooled(z2)

    if z1.size(1) == 1:
        if lambda_ != 0:
            acc += instance_part(z1, z2)
        depth += 1

    return acc / depth

def hierarchical_contrastive_loss_sigmoid_threshold(z1, z2, DTW, threshold, lambda_=0.5, temporal_unit=0, soft_temporal=False, soft_instance=False):
    """Multi-scale contrastive loss using thresholded-sigmoid soft temporal targets.

    At each scale the two views get an instance-wise term (weight ``lambda_``)
    and, from depth ``temporal_unit`` onward, a temporal term (weight
    ``1 - lambda_``). Both views are then max-pooled by 2 along dim 1.

    Soft temporal targets come from
    ``generate_TIMELAG_sigmoid_threshold(T_d, threshold)``; unlike the other
    variants, they do not depend on the depth.

    Returns:
        The accumulated loss divided by the number of scales visited.
    """
    where = z1.device
    soft_l, soft_r = duplicate_DTW(torch.tensor(DTW, device=where))
    inst_on = lambda_ != 0
    temp_on = 1 - lambda_ != 0
    running = torch.tensor(0., device=where)
    scales = 0

    def inst_loss(a, b):
        value = (instance_contrastive_loss_soft(a, b, soft_l, soft_r) if soft_instance
                 else instance_contrastive_loss_hard(a, b))
        return lambda_ * value

    while z1.size(1) > 1:
        if inst_on:
            running += inst_loss(z1, z2)
        if scales >= temporal_unit and temp_on:
            if soft_temporal:
                lag = torch.tensor(generate_TIMELAG_sigmoid_threshold(z1.shape[1], threshold), device=where)
                lag_l, lag_r = duplicate_DTW(lag)
                running += (1 - lambda_) * temporal_contrastive_loss_soft(z1, z2, lag_l, lag_r)
            else:
                running += (1 - lambda_) * temporal_contrastive_loss_hard(z1, z2)
        scales += 1
        # Halve the time axis (dim 1) for the next scale.
        z1 = F.max_pool1d(z1.transpose(1, 2), kernel_size=2).transpose(1, 2)
        z2 = F.max_pool1d(z2.transpose(1, 2), kernel_size=2).transpose(1, 2)

    if z1.size(1) == 1:
        if inst_on:
            running += inst_loss(z1, z2)
        scales += 1

    return running / scales


def hierarchical_contrastive_loss_sigmoid_wo_instance(z1, z2, DTW, tau_temp=2, lambda_=0.5, temporal_unit=0, soft_temporal=False, soft_instance=False):
    """Multi-scale temporal-only contrastive loss (no instance-wise term).

    From depth ``temporal_unit`` onward, each scale adds a temporal term
    weighted by ``1 - lambda_``. Both views are then max-pooled by 2 along
    dim 1. Soft temporal targets at depth ``d`` come from
    ``generate_TIMELAG_sigmoid(T_d, tau_temp * 2**d)``.

    ``DTW`` and ``soft_instance`` are accepted only for signature
    compatibility with the other hierarchical losses and are ignored. The
    instance soft labels used to be converted and duplicated here without
    ever being used.

    Returns:
        The accumulated loss divided by the number of scales visited.
    """
    loss = torch.tensor(0., device=z1.device)
    d = 0
    while z1.size(1) > 1:
        if d >= temporal_unit and 1 - lambda_ != 0:
            if soft_temporal:
                TIME_LAG = generate_TIMELAG_sigmoid(z1.shape[1], tau_temp * (2 ** d))
                TIME_LAG = torch.tensor(TIME_LAG, device=z1.device)
                TIME_LAG_L, TIME_LAG_R = duplicate_DTW(TIME_LAG)
                loss += (1 - lambda_) * temporal_contrastive_loss_soft(z1, z2, TIME_LAG_L, TIME_LAG_R)
            else:
                loss += (1 - lambda_) * temporal_contrastive_loss_hard(z1, z2)
        d += 1
        # Halve the time axis (dim 1) for the next scale.
        z1 = F.max_pool1d(z1.transpose(1, 2), kernel_size=2).transpose(1, 2)
        z2 = F.max_pool1d(z2.transpose(1, 2), kernel_size=2).transpose(1, 2)

    if z1.size(1) == 1:
        # NOTE(review): the final length-1 scale adds no term here but is
        # still counted in d, like the sibling hierarchical losses (which add
        # an instance term there). Confirm this dilution is intended.
        d += 1

    return loss / d


def hierarchical_contrastive_loss_gaussian(z1, z2, DTW, tau_temp=2, lambda_=0.5, temporal_unit=0, soft_temporal=False, soft_instance=False, temporal_hierarchy=True):
    """Multi-scale contrastive loss using Gaussian soft temporal targets.

    At each scale the two views get an instance-wise term (weight ``lambda_``)
    and, from depth ``temporal_unit`` onward, a temporal term (weight
    ``1 - lambda_``). Both views are then max-pooled by 2 along dim 1.

    Soft temporal targets come from ``generate_TIMELAG_gaussian``:
      - with ``tau_temp / 2**depth`` when ``temporal_hierarchy`` is set;
      - with ``tau_temp`` at every scale otherwise.

    Returns:
        The accumulated loss divided by the number of scales visited.
    """
    dev = z1.device
    left, right = duplicate_DTW(torch.tensor(DTW, device=dev))
    w_i = lambda_
    w_t = 1 - lambda_

    def downsample(x):
        # Max-pool with kernel 2 along the time axis (dim 1).
        return F.max_pool1d(x.transpose(1, 2), kernel_size=2).transpose(1, 2)

    def inst(a, b):
        if soft_instance:
            return w_i * instance_contrastive_loss_soft(a, b, left, right)
        return w_i * instance_contrastive_loss_hard(a, b)

    def temp(a, b, level):
        if not soft_temporal:
            return w_t * temporal_contrastive_loss_hard(a, b)
        sigma = tau_temp / (2 ** level) if temporal_hierarchy else tau_temp
        lag_l, lag_r = duplicate_DTW(torch.tensor(generate_TIMELAG_gaussian(a.shape[1], sigma), device=dev))
        return w_t * temporal_contrastive_loss_soft(a, b, lag_l, lag_r)

    total = torch.tensor(0., device=dev)
    level = 0
    while z1.size(1) > 1:
        if w_i != 0:
            total += inst(z1, z2)
        if level >= temporal_unit and w_t != 0:
            total += temp(z1, z2, level)
        level += 1
        z1, z2 = downsample(z1), downsample(z2)

    if z1.size(1) == 1:
        if w_i != 0:
            total += inst(z1, z2)
        level += 1

    return total / level


def hierarchical_contrastive_loss_same_interval(z1, z2, DTW, tau_temp=2, lambda_=0.5, temporal_unit=0, soft_temporal=False, soft_instance=False):
    """Multi-scale contrastive loss using same-interval soft temporal targets.

    At each scale the two views get an instance-wise term (weight ``lambda_``)
    and, from depth ``temporal_unit`` onward, a temporal term (weight
    ``1 - lambda_``). Both views are then max-pooled by 2 along dim 1.

    Soft temporal targets at depth ``d`` come from
    ``generate_TIMELAG_same_interval(T_d, tau_temp / 2**d)``.

    Returns:
        The accumulated loss divided by the number of scales visited.
    """
    device = z1.device
    lbl_left, lbl_right = duplicate_DTW(torch.tensor(DTW, device=device))
    temporal_w = 1 - lambda_

    def half_time(x):
        # Max-pool with kernel 2 along the time axis (dim 1).
        return F.max_pool1d(x.transpose(1, 2), kernel_size=2).transpose(1, 2)

    def instance_loss(a, b):
        if soft_instance:
            return lambda_ * instance_contrastive_loss_soft(a, b, lbl_left, lbl_right)
        return lambda_ * instance_contrastive_loss_hard(a, b)

    def temporal_loss(a, b, depth):
        if soft_temporal:
            lag = generate_TIMELAG_same_interval(a.shape[1], tau_temp / (2 ** depth))
            lag_left, lag_right = duplicate_DTW(torch.tensor(lag, device=device))
            return temporal_w * temporal_contrastive_loss_soft(a, b, lag_left, lag_right)
        return temporal_w * temporal_contrastive_loss_hard(a, b)

    accumulated = torch.tensor(0., device=device)
    depth = 0
    while z1.size(1) > 1:
        if lambda_ != 0:
            accumulated += instance_loss(z1, z2)
        if depth >= temporal_unit and temporal_w != 0:
            accumulated += temporal_loss(z1, z2, depth)
        depth += 1
        z1, z2 = half_time(z1), half_time(z2)

    if z1.size(1) == 1:
        if lambda_ != 0:
            accumulated += instance_loss(z1, z2)
        depth += 1

    return accumulated / depth




########################################################################################################
## Hard Contrastive Losses
########################################################################################################

#------------------------------------------------------------------------------------------#
def instance_contrastive_loss_hard(z1, z2):
    """Instance-wise contrastive loss with one-hot targets.

    At every timestamp, each of the 2B stacked representations must pick out
    its counterpart in the other view among the remaining 2B - 1. The loss is
    the mean negative log-probability of that counterpart.

    Args:
        z1, z2: 3-D tensors (B, T, C) holding the two views.

    Returns:
        Scalar tensor; ``0.`` when B == 1.
    """
    n = z1.size(0)
    if n == 1:
        return z1.new_tensor(0.)
    feats = torch.cat((z1, z2), dim=0).transpose(0, 1)  # T x 2B x C
    sim = feats @ feats.transpose(1, 2)  # T x 2B x 2B
    # Drop the diagonal -> T x 2B x (2B-1).
    lower = torch.tril(sim, diagonal=-1)[..., :-1]
    upper = torch.triu(sim, diagonal=1)[..., 1:]
    nll = -F.log_softmax(lower + upper, dim=-1)

    idx = torch.arange(n, device=z1.device)
    # For row i (< B) the positive B+i sits one column left after removing the
    # diagonal; for row B+i the positive i precedes the diagonal, so no shift.
    first_half = nll[:, idx, n + idx - 1].mean()
    second_half = nll[:, n + idx, idx].mean()
    return (first_half + second_half) / 2

def temporal_contrastive_loss_hard(z1, z2):
    """Temporal contrastive loss with one-hot targets.

    Within each sequence, every one of the 2T timestamps (both views
    concatenated along time) must pick out the same timestamp in the other
    view among the remaining 2T - 1. The loss is the mean negative
    log-probability of that counterpart.

    Args:
        z1, z2: 3-D tensors (B, T, C) holding the two views.

    Returns:
        Scalar tensor; ``0.`` when T == 1.
    """
    steps = z1.size(1)
    if steps == 1:
        return z1.new_tensor(0.)
    seq = torch.cat((z1, z2), dim=1)  # B x 2T x C
    sim = seq @ seq.transpose(1, 2)  # B x 2T x 2T
    # Drop the diagonal -> B x 2T x (2T-1).
    lower = torch.tril(sim, diagonal=-1)[..., :-1]
    upper = torch.triu(sim, diagonal=1)[..., 1:]
    nll = -F.log_softmax(lower + upper, dim=-1)

    t = torch.arange(steps, device=z1.device)
    # Row t's positive T+t shifts one column left once the diagonal is gone.
    forward = nll[:, t, steps + t - 1].mean()
    backward = nll[:, steps + t, t].mean()
    return (forward + backward) / 2
#------------------------------------------------------------------------------------------#

def hierarchical_contrastive_loss_hard(z1, z2, lambda_=0.5, temporal_unit=0):
    """Multi-scale contrastive loss with one-hot (hard) targets.

    At each scale the two views get an instance-wise term (weight ``lambda_``)
    and, from depth ``temporal_unit`` onward, a temporal term (weight
    ``1 - lambda_``). Both views are then max-pooled by 2 along dim 1. When
    the time axis reaches length 1, only the instance term is added.

    Args:
        z1, z2: 3-D tensors (B, T, C); dim 1 (time) is halved at every scale.
        lambda_: weight of the instance term; the temporal term gets 1 - lambda_.
        temporal_unit: first depth at which the temporal term is applied.

    Returns:
        The accumulated loss divided by the number of scales visited.
    """
    loss = torch.tensor(0., device=z1.device)
    d = 0
    while z1.size(1) > 1:
        if lambda_ != 0:
            loss += lambda_ * instance_contrastive_loss_hard(z1, z2)
        if d >= temporal_unit and 1 - lambda_ != 0:
            loss += (1 - lambda_) * temporal_contrastive_loss_hard(z1, z2)
        d += 1
        # Halve the time axis (dim 1) for the next scale. The debug prints of
        # z1/z2 shapes that used to run here on every forward pass are removed.
        z1 = F.max_pool1d(z1.transpose(1, 2), kernel_size=2).transpose(1, 2)
        z2 = F.max_pool1d(z2.transpose(1, 2), kernel_size=2).transpose(1, 2)

    if z1.size(1) == 1:
        if lambda_ != 0:
            loss += lambda_ * instance_contrastive_loss_hard(z1, z2)
        d += 1
    return loss / d