import numpy as np
import torch
from torch.nn import functional as F
from torch.nn import MSELoss
from torch import nn
from koi.ctc import SequenceDist, Max, Log, semiring
from koi.ctc import logZ_cu, viterbi_alignments, logZ_cu_sparse, bwd_scores_cu_sparse, fwd_scores_cu_sparse


class SegmentationLoss(nn.Module):
    """Mean-squared-error loss between predicted and target segment distributions.

    Each prediction is turned into a distribution with a softmax and compared
    against the L1-normalised target row, truncated to that sample's length.
    """

    def __init__(self, signal_length=3600):
        super().__init__()
        # Stored for callers; not read by ``forward``.
        self.signal_length = signal_length
        self.criterion = MSELoss(reduction='mean')

    def forward(self, outputs, segments, segment_lengths):
        """Return the batch-averaged MSE between softmaxed outputs and targets.

        Args:
            outputs: sequence of per-sample score tensors; softmax is taken
                over dim 0 of each one.
            segments: tensor of targets, one row per sample; every row is
                L1-normalised along dim 1 before being sliced.
            segment_lengths: per-sample number of leading target entries to
                compare against the matching output.

        Returns:
            Sum of the per-sample MSE values divided by ``len(outputs)``.
        """
        assert(len(outputs) == len(segments))
        batch_size = len(outputs)
        # Normalise over the full row first, then cut to the valid prefix.
        targets = F.normalize(segments.float(), p=1, dim=1)
        total = sum(
            self.criterion(F.softmax(prediction, dim=0), target[:length])
            for prediction, target, length in zip(outputs, targets, segment_lengths)
        )
        return total / batch_size

class CTCLoss(nn.Module):
    """CTC loss over CRF transition scores, evaluated with the koi kernels.

    Scores are handled as ``(T, N, C)`` tensors and viewed as
    ``(T, N, n_states, n_base + 1)``, where ``n_states = n_base ** state_len``
    and the last axis holds one blank entry followed by ``n_base`` base
    entries.
    """

    def __init__(self, idx):
        """Store the alphabet/state configuration and the sparse index table.

        Args:
            idx: index tensor handed unchanged to the koi sparse kernels.
                NOTE(review): presumably built like the reference
                construction kept in the comment below -- confirm with the
                caller.
        """
        super(CTCLoss, self).__init__()
        # Index 0 is treated as the blank symbol (dropped in `path_to_str`).
        self.alphabet = ['N', 'A', 'C', 'G', 'T']
        self.state_len = 3
        self.n_base = 4
        # Reference construction of `idx`, kept for documentation only:
        # self.idx = torch.cat([
        #     torch.arange(self.n_base**(self.state_len))[:, None],
        #     torch.arange(
        #         self.n_base**(self.state_len)
        #     ).repeat_interleave(self.n_base).reshape(self.n_base, -1).T
        # ], dim=1).to(torch.int32)
        self.idx = idx

    def n_score(self):
        """Return the size of the score axis: ``len(alphabet) * n_base ** state_len``."""
        return len(self.alphabet) * self.n_base**(self.state_len)

    def logZ(self, scores, S:semiring=Log):
        """Return the per-batch-item partition function under semiring ``S``.

        Args:
            scores: tensor of shape ``(T, N, C)``.
            S: koi semiring; ``S.one`` fills the boundary vectors.
        """
        T, N, _ = scores.shape
        Ms = scores.reshape(T, N, -1, len(self.alphabet))
        alpha_0 = Ms.new_full((N, self.n_base**(self.state_len)), S.one)
        beta_T = Ms.new_full((N, self.n_base**(self.state_len)), S.one)
        return logZ_cu_sparse(Ms, self.idx, alpha_0, beta_T, S)

    def posteriors(self, scores, S: semiring=Log):
        """Return the gradient of ``logZ(scores, S).sum()`` w.r.t. ``scores``.

        `viterbi` relies on this method, which this class previously did not
        define (it derives from `nn.Module`, which does not provide it), so
        every `viterbi` call raised ``AttributeError``.

        NOTE(review): modelled on the posterior computation of koi's
        `SequenceDist`; confirm against the installed koi version.

        Args:
            scores: tensor of shape ``(T, N, C)``.
            S: koi semiring forwarded to `logZ`.

        Returns:
            Tensor with the same shape as ``scores``.
        """
        # Differentiate w.r.t. a detached copy so the caller's graph (and any
        # surrounding no_grad context) is left untouched.
        scores = scores.detach().requires_grad_()
        with torch.enable_grad():
            total = self.logZ(scores, S).sum()
        return torch.autograd.grad(total, scores)[0].detach()

    def normalise(self, scores):
        """Subtract ``logZ / T`` from every score so each item is globally normalised."""
        return (scores - self.logZ(scores)[:, None] / len(scores))

    def forward_scores(self, scores, S: semiring=Log):
        """Return the forward scores from the koi sparse forward kernel."""
        T, N, _ = scores.shape
        Ms = scores.reshape(T, N, -1, self.n_base + 1)
        alpha_0 = Ms.new_full((N, self.n_base**(self.state_len)), S.one)
        return fwd_scores_cu_sparse(Ms, self.idx, alpha_0, S, K=1)

    def backward_scores(self, scores, S: semiring=Log):
        """Return the backward scores from the koi sparse backward kernel."""
        T, N, _ = scores.shape
        Ms = scores.reshape(T, N, -1, self.n_base + 1)
        beta_T = Ms.new_full((N, self.n_base**(self.state_len)), S.one)
        return bwd_scores_cu_sparse(Ms, self.idx, beta_T, S, K=1)

    def compute_transition_probs(self, scores, betas):
        """Convert scores and backward scores into transition probabilities.

        Args:
            scores: tensor of shape ``(T, N, C)``.
            betas: backward scores; ``betas[0]`` gives the initial-state
                logits and ``betas[1:]`` is added to the edge scores.
                NOTE(review): presumably shaped ``(T + 1, N, n_states)`` as
                returned by `backward_scores` -- confirm.

        Returns:
            Tuple ``(trans_probs, init_state_probs)``, both softmax-normalised
            over their last axis.
        """
        T, N, _ = scores.shape
        # add bwd scores to edge scores
        log_trans_probs = (scores.reshape(T, N, -1, self.n_base + 1) + betas[1:, :, :, None])
        # transpose from (new_state, dropped_base) to (old_state, emitted_base) layout
        log_trans_probs = torch.cat([
            log_trans_probs[:, :, :, [0]],
            log_trans_probs[:, :, :, 1:].transpose(3, 2).reshape(T, N, -1, self.n_base)
        ], dim=-1)
        # convert from log probs to probs by exponentiating and normalising
        trans_probs = torch.softmax(log_trans_probs, dim=-1)
        # convert first bwd score to initial state probabilities
        init_state_probs = torch.softmax(betas[0], dim=-1)
        return trans_probs, init_state_probs

    def reverse_complement(self, scores):
        """Return the scores re-indexed for the reverse-complement strand.

        The score axis is expanded to one axis per state base plus the
        blank/emission axis; the base axes are permuted into reverse order
        and flipped, and the time axis is flipped as well.

        Args:
            scores: tensor of shape ``(T, N, C)``.

        Returns:
            Tensor of shape ``(T, N, C)``.
        """
        T, N, _ = scores.shape
        expand_dims = T, N, *(self.n_base for _ in range(self.state_len)), self.n_base + 1
        scores = scores.reshape(*expand_dims)
        # Blank entries: reverse the order of the state-base axes, then flip
        # time (dim 0) and the flattened state axis (dim 2).
        blanks = torch.flip(scores[..., 0].permute(
            0, 1, *range(self.state_len + 1, 1, -1)).reshape(T, N, -1, 1), [0, 2]
        )
        # Emission entries: same idea, with the emitted-base axis taking part
        # in the permutation and also being flipped (dim 3).
        emissions = torch.flip(scores[..., 1:].permute(
            0, 1, *range(self.state_len, 1, -1),
            self.state_len +2,
            self.state_len + 1).reshape(T, N, -1, self.n_base), [0, 2, 3]
        )
        return torch.cat([blanks, emissions], dim=-1).reshape(T, N, -1)

    def viterbi(self, scores):
        """Return the best-path symbol index per time step and batch item.

        Uses `posteriors` under the ``Max`` semiring and reduces the arg-max
        score index modulo the alphabet size.
        """
        traceback = self.posteriors(scores, Max)
        paths = traceback.argmax(2) % len(self.alphabet)
        return paths

    def path_to_str(self, path):
        """Decode a NumPy array of symbol indices into a string, dropping blanks (index 0)."""
        alphabet = np.frombuffer(''.join(self.alphabet).encode(), dtype='u1')
        seq = alphabet[path[path != 0]]
        return seq.tobytes().decode()

    def prepare_ctc_scores(self, scores, targets):
        """Gather the stay/move scores along the target sequences.

        Args:
            scores: tensor of shape ``(T, N, C)``.
            targets: integer tensor of CTC targets with blank = 0, one row
                per batch item.

        Returns:
            Tuple ``(stay_scores, move_scores)`` gathered from dim 2 of
            ``scores`` for every time step.
        """
        # convert from CTC targets (with blank=0) to zero indexed
        targets = torch.clamp(targets - 1, 0)

        T, _, _ = scores.shape

        # Each state spans `state_len` consecutive target bases, so a target
        # row of length L yields n = L - (state_len - 1) states.
        n = targets.size(1) - (self.state_len - 1)

        # Encode every window of `state_len` bases as a base-`n_base` number
        # (first base most significant), then scale by the alphabet size to
        # land on that state's blank ("stay") entry in the score axis.
        stay_indices = sum(
            targets[:, i:n + i] * self.n_base ** (self.state_len - i - 1)
            for i in range(self.state_len)
        ) * len(self.alphabet)

        # A move into state k+1 is the entry of that state offset by
        # 1 + (first base of the state being left).
        move_indices = stay_indices[:, 1:] + targets[:, :n - 1] + 1
        stay_scores = scores.gather(2, stay_indices.expand(T, -1, -1))
        move_scores = scores.gather(2, move_indices.expand(T, -1, -1))
        return stay_scores, move_scores

    def ctc_viterbi_alignments(self, scores, targets, target_lengths):
        """Return koi's Viterbi alignments of ``scores`` to ``targets``."""
        stay_scores, move_scores = self.prepare_ctc_scores(scores, targets)
        return viterbi_alignments(stay_scores, move_scores, target_lengths + 1 - self.state_len)

    def forward(self, scores, targets, target_lengths, loss_clip=None, reduction='mean', normalise_scores=True):
        """Compute the length-normalised CTC loss.

        Args:
            scores: tensor of shape ``(T, N, C)``.
            targets: integer CTC targets with blank = 0.
            target_lengths: per-item target lengths; used both to derive the
                number of states and to normalise the loss.
            loss_clip: if truthy, per-item losses are clamped to
                ``[0, loss_clip]``.
            reduction: ``'mean'`` returns the batch mean; ``'none'`` or
                ``None`` returns the per-item losses.
            normalise_scores: when true, `normalise` is applied to the
                scores first.

        Raises:
            ValueError: if ``reduction`` is not one of the supported values.
        """
        if normalise_scores:
            scores = self.normalise(scores)
        stay_scores, move_scores = self.prepare_ctc_scores(scores, targets)
        logz = logZ_cu(stay_scores, move_scores, target_lengths + 1 - self.state_len)
        loss = - (logz / target_lengths)
        if loss_clip:
            loss = torch.clamp(loss, 0.0, loss_clip)
        if reduction == 'mean':
            return loss.mean()
        elif reduction in ('none', None):
            return loss
        else:
            raise ValueError('Unknown reduction type {}'.format(reduction))


class SimCLRLoss(nn.Module):
    """Multi-positive SimCLR (NT-Xent style) contrastive loss per transcript.

    For every anchor embedding the loss is
    ``-log(sum_pos exp(s / temp) / sum_{others} exp(s / temp))`` where the
    positives are the other views of the same position and the denominator
    runs over every embedding of the same transcript except the anchor.
    """

    def __init__(self, temp=0.1):
        super(SimCLRLoss, self).__init__()
        self.temp = temp
        # Kept for backward compatibility; not used by ``forward``.
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, features, n_views=1):
        """Return the contrastive loss averaged over anchors and transcripts.

        Args:
            features: tensor of shape
                ``(n_transcripts, n_positions * n_views, dim)``. The second
                axis is organised view by view:
                ``x_1..x_P, x'_1..x'_P, x''_1..x''_P, ...``. Features are
                expected to be L2-normalised already.
            n_views: number of views per position. At least 2 is required,
                since an anchor needs at least one positive.

        Returns:
            Scalar tensor holding the mean loss.

        Raises:
            ValueError: if ``n_views < 2`` or the second dimension of
                ``features`` is not a multiple of ``n_views``.
        """
        if n_views < 2:
            raise ValueError(
                'SimCLRLoss needs n_views >= 2 (got {}): every anchor requires '
                'at least one positive view'.format(n_views)
            )
        n_samples = features.shape[1]
        if n_samples % n_views != 0:
            raise ValueError(
                'features.shape[1] ({}) must be a multiple of n_views ({})'.format(
                    n_samples, n_views)
            )
        n_positions = n_samples // n_views
        device = features.device

        # Position id of every sample; samples sharing an id are views of the
        # same position. The masks are (n_samples, n_samples) and broadcast
        # over the transcript axis.
        position_ids = torch.arange(n_samples, device=device) % n_positions
        self_mask = torch.eye(n_samples, dtype=torch.bool, device=device)
        positive_mask = (position_ids.unsqueeze(0) == position_ids.unsqueeze(1)) & ~self_mask

        # Pairwise similarities: (n_transcripts, n_samples, n_samples).
        logits = torch.matmul(features, features.transpose(1, 2)) / self.temp

        # Remove self-similarity from the softmax by sending it to -inf.
        neg_inf = float('-inf')
        logits = logits.masked_fill(self_mask, neg_inf)

        # log(sum_pos / sum_all) evaluated in log space: identical in exact
        # arithmetic to log(softmax(...)[positives].sum()) but does not
        # underflow to log(0) when positives are much less similar than
        # negatives at small temperatures.
        log_all = torch.logsumexp(logits, dim=2)
        log_positive = torch.logsumexp(logits.masked_fill(~positive_mask, neg_inf), dim=2)

        return (log_all - log_positive).mean(dim=1).mean(dim=0)
