"""
Bonito CTC-CRF Model.
Code is adapted from https://github.com/nanoporetech/bonito/tree/master/bonito
"""

import torch
import numpy as np
from collections import Iterable
from fast_ctc_decode import beam_search
from koi.ctc import SequenceDist, Max, Log, semiring
from koi.ctc import logZ_cu, viterbi_alignments, logZ_cu_sparse, bwd_scores_cu_sparse, fwd_scores_cu_sparse
from torch import nn
from bonito_nn import Module, Convolution, LinearCRFEncoder, Serial, Permute, layers, from_dict
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


def get_stride(m):
    """Return the stride of a layer, or the combined stride of a container.

    Resolution order:
      1. an object exposing a ``stride`` attribute: the attribute itself when
         it is an ``int``, otherwise its first element;
      2. a ``Convolution``: the stride of its wrapped ``conv`` layer;
      3. a ``Serial``: the product of the strides of its children;
      4. anything else: ``1``.
    """
    if hasattr(m, 'stride'):
        stride = m.stride
        if isinstance(stride, int):
            return stride
        # Non-int strides are indexed; only the first entry is used.
        return stride[0]

    if isinstance(m, Convolution):
        return get_stride(m.conv)

    if isinstance(m, Serial):
        child_strides = [get_stride(child) for child in m]
        return int(np.prod(child_strides))

    return 1


class CTC_CRF(SequenceDist):
    """CTC-CRF sequence distribution over k-mer states.

    There are ``n_base ** state_len`` states, where ``n_base`` is the alphabet
    size without its first entry.  Score tensors are unpacked throughout as
    ``(T, N, C)`` and reshaped to ``(T, N, -1, n_base + 1)`` before being
    handed to the ``koi`` kernels.
    """

    def __init__(self, state_len, alphabet):
        super().__init__()
        self.alphabet = alphabet
        self.state_len = state_len
        # The first alphabet entry is excluded from the base count.
        self.n_base = len(alphabet[1:])

        n_states = self.n_base ** self.state_len
        states = torch.arange(n_states)
        first_col = states[:, None]
        other_cols = states.repeat_interleave(self.n_base).reshape(self.n_base, -1).T
        # Index table of shape (n_states, n_base + 1), consumed by the sparse kernels.
        self.idx = torch.cat([first_col, other_cols], dim=1).to(torch.int32)

    def n_score(self):
        """Return the number of scores per step: ``len(alphabet) * n_base ** state_len``."""
        n_states = self.n_base ** self.state_len
        return len(self.alphabet) * n_states

    def logZ(self, scores, S:semiring=Log):
        """Run ``logZ_cu_sparse`` on ``scores`` with both boundaries filled with ``S.one``."""
        T, N, _ = scores.shape
        n_states = self.n_base ** self.state_len
        Ms = scores.reshape(T, N, -1, len(self.alphabet))
        alpha_0 = Ms.new_full((N, n_states), S.one)
        beta_T = Ms.new_full((N, n_states), S.one)
        return logZ_cu_sparse(Ms, self.idx, alpha_0, beta_T, S)

    def normalise(self, scores):
        """Subtract ``logZ(scores) / len(scores)`` from every step of ``scores``."""
        log_partition = self.logZ(scores)
        per_step_share = log_partition[:, None] / len(scores)
        return scores - per_step_share

    def forward_scores(self, scores, S: semiring=Log):
        """Run ``fwd_scores_cu_sparse`` (``K=1``) from an initial vector filled with ``S.one``."""
        T, N, _ = scores.shape
        n_states = self.n_base ** self.state_len
        Ms = scores.reshape(T, N, -1, self.n_base + 1)
        alpha_0 = Ms.new_full((N, n_states), S.one)
        return fwd_scores_cu_sparse(Ms, self.idx, alpha_0, S, K=1)

    def backward_scores(self, scores, S: semiring=Log):
        """Run ``bwd_scores_cu_sparse`` (``K=1``) from a final vector filled with ``S.one``."""
        T, N, _ = scores.shape
        n_states = self.n_base ** self.state_len
        Ms = scores.reshape(T, N, -1, self.n_base + 1)
        beta_T = Ms.new_full((N, n_states), S.one)
        return bwd_scores_cu_sparse(Ms, self.idx, beta_T, S, K=1)

    def compute_transition_probs(self, scores, betas):
        """Turn edge scores plus backward scores into transition probabilities.

        Returns ``(trans_probs, init_state_probs)``: a softmax over the last
        axis of the combined edge scores, and a softmax over the last axis of
        ``betas[0]``.
        """
        T, N, C = scores.shape
        # add bwd scores to edge scores
        edge_scores = scores.reshape(T, N, -1, self.n_base + 1)
        log_trans_probs = edge_scores + betas[1:, :, :, None]
        # transpose from (new_state, dropped_base) to (old_state, emitted_base) layout
        blank_part = log_trans_probs[:, :, :, [0]]
        move_part = log_trans_probs[:, :, :, 1:].transpose(3, 2).reshape(T, N, -1, self.n_base)
        log_trans_probs = torch.cat([blank_part, move_part], dim=-1)
        # convert from log probs to probs by exponentiating and normalising
        trans_probs = torch.softmax(log_trans_probs, dim=-1)
        # convert first bwd score to initial state probabilities
        init_state_probs = torch.softmax(betas[0], dim=-1)
        return trans_probs, init_state_probs

    def reverse_complement(self, scores):
        """Return ``scores`` rearranged for the reverse-complement strand.

        The k-mer axes are permuted into reverse order and the result is
        flipped along the first axis and the state axis (emissions are also
        flipped along the base axis).  The output is reshaped to ``(T, N, -1)``.
        """
        T, N, C = scores.shape
        kmer_dims = (self.n_base,) * self.state_len
        scores = scores.reshape(T, N, *kmer_dims, self.n_base + 1)

        blank_order = (0, 1, *range(self.state_len + 1, 1, -1))
        blanks = scores[..., 0].permute(*blank_order).reshape(T, N, -1, 1)
        blanks = torch.flip(blanks, [0, 2])

        emission_order = (
            0, 1, *range(self.state_len, 1, -1),
            self.state_len + 2,
            self.state_len + 1,
        )
        emissions = scores[..., 1:].permute(*emission_order).reshape(T, N, -1, self.n_base)
        emissions = torch.flip(emissions, [0, 2, 3])

        return torch.cat([blanks, emissions], dim=-1).reshape(T, N, -1)

    def viterbi(self, scores):
        """Return per-step label indices: argmax of the ``Max``-semiring posteriors modulo ``len(alphabet)``."""
        traceback = self.posteriors(scores, Max)
        return traceback.argmax(2) % len(self.alphabet)

    def path_to_str(self, path):
        """Decode a label-index array into a string, dropping every index equal to 0."""
        symbols = np.frombuffer(''.join(self.alphabet).encode(), dtype='u1')
        emitted = path[path != 0]
        return symbols[emitted].tobytes().decode()

    def prepare_ctc_scores(self, scores, targets):
        """Gather the stay and move scores selected by ``targets``.

        Returns ``(stay_scores, move_scores)``, both gathered from ``scores``
        (cast to float32) along its last axis.
        """
        # convert from CTC targets (with blank=0) to zero indexed
        targets = torch.clamp(targets - 1, 0)

        T, N, C = scores.shape
        scores = scores.to(torch.float32)
        n = targets.size(1) - (self.state_len - 1)

        # Base-n_base encoding of each window of state_len consecutive targets.
        state_ids = 0
        for i in range(self.state_len):
            state_ids = state_ids + targets[:, i:n + i] * self.n_base ** (self.state_len - i - 1)
        stay_indices = state_ids * len(self.alphabet)
        move_indices = stay_indices[:, 1:] + targets[:, :n - 1] + 1

        stay_scores = scores.gather(2, stay_indices.expand(T, -1, -1))
        move_scores = scores.gather(2, move_indices.expand(T, -1, -1))
        return stay_scores, move_scores

    def ctc_loss(self, scores, targets, target_lengths, loss_clip=None, reduction='mean', normalise_scores=True):
        """Compute ``-logZ_cu(...) / target_lengths`` for the given targets.

        ``scores`` are first passed through ``normalise`` unless
        ``normalise_scores`` is false.  A truthy ``loss_clip`` clamps the loss
        to ``[0.0, loss_clip]``.  ``reduction`` is ``'mean'`` for the mean
        loss, or ``'none'`` / ``None`` for the unreduced loss.

        Raises:
            ValueError: if ``reduction`` is not one of the values above.
        """
        if normalise_scores:
            scores = self.normalise(scores)

        stay_scores, move_scores = self.prepare_ctc_scores(scores, targets)
        logz = logZ_cu(stay_scores, move_scores, target_lengths + 1 - self.state_len)
        loss = - (logz / target_lengths)
        if loss_clip:
            loss = torch.clamp(loss, 0.0, loss_clip)

        if reduction == 'mean':
            return loss.mean()
        if reduction in ('none', None):
            return loss
        raise ValueError('Unknown reduction type {}'.format(reduction))

    def ctc_viterbi_alignments(self, scores, targets, target_lengths):
        """Return ``viterbi_alignments`` for the stay/move scores selected by ``targets``."""
        stay_scores, move_scores = self.prepare_ctc_scores(scores, targets)
        kmer_lengths = target_lengths + 1 - self.state_len
        return viterbi_alignments(stay_scores, move_scores, kmer_lengths)


def conv(c_in, c_out, ks, stride=1, bias=False, activation=None):
    """Build a ``Convolution`` of kernel size ``ks`` padded by ``ks // 2``."""
    half_kernel = ks // 2
    return Convolution(
        c_in, c_out, ks,
        stride=stride,
        padding=half_kernel,
        bias=bias,
        activation=activation,
    )


def rnn_encoder(n_base, state_len, insize=1, stride=5, winlen=19, activation='swish', rnn_type='lstm', features=768, scale=5.0, blank_score=None, expand_blanks=True, num_layers=5):
    """Assemble the old-style encoder as a ``Serial`` stack.

    The stack is, in order: three convolutions (the last one uses ``winlen``
    and ``stride``), a ``Permute([2, 0, 1])``, ``num_layers`` recurrent layers
    of type ``layers[rnn_type]`` whose ``reverse`` flag alternates, and a
    final ``LinearCRFEncoder``.
    """
    rnn = layers[rnn_type]

    # Layers are created in stack order.
    stack = [
        conv(insize, 4, ks=5, bias=True, activation=activation),
        conv(4, 16, ks=5, bias=True, activation=activation),
        conv(16, features, ks=winlen, stride=stride, bias=True, activation=activation),
    ]
    stack.append(Permute([2, 0, 1]))

    for i in range(num_layers):
        # (num_layers - i) % 2 alternates between 1 and 0 and ends on 1 for the last layer.
        stack.append(rnn(features, features, reverse=(num_layers - i) % 2))

    stack.append(
        LinearCRFEncoder(
            features, n_base, state_len, activation='tanh', scale=scale,
            blank_score=blank_score, expand_blanks=expand_blanks
        )
    )
    return Serial(stack)


class SeqdistModel(Module):
    """Pair an encoder with a sequence distribution used for decoding and loss."""

    def __init__(self, encoder, seqdist):
        super().__init__()
        self.seqdist = seqdist
        self.encoder = encoder
        self.stride = get_stride(encoder)
        self.alphabet = seqdist.alphabet

    def forward(self, x):
        """Return the encoder output for ``x``."""
        return self.encoder(x)

    def decode_batch(self, x):
        """Decode a batch of scores into a list of strings, one per Viterbi path."""
        # The small offset keeps the following log() finite.
        posteriors = self.seqdist.posteriors(x.to(torch.float32)) + 1e-8
        paths = self.seqdist.viterbi(posteriors.log()).to(torch.int16).T
        decoded = []
        for path in paths.cpu().numpy():
            decoded.append(self.seqdist.path_to_str(path))
        return decoded

    def decode(self, x):
        """Decode a single un-batched score tensor by inserting a batch axis at dim 1."""
        batched = x.unsqueeze(1)
        return self.decode_batch(batched)[0]

    def loss(self, scores, targets, target_lengths, **kwargs):
        """Return ``seqdist.ctc_loss`` on float32 ``scores``; ``kwargs`` are forwarded unchanged."""
        scores_fp32 = scores.to(torch.float32)
        return self.seqdist.ctc_loss(scores_fp32, targets, target_lengths, **kwargs)

class BaseCaller(SeqdistModel):
    """``SeqdistModel`` built from a config mapping, with a mode-dependent ``forward``."""

    def __init__(self, config):
        state_len = config['global_norm']['state_len']
        labels = config['labels']['labels']
        seqdist = CTC_CRF(state_len=state_len, alphabet=labels)

        encoder_cfg = config['encoder']
        if 'type' in encoder_cfg:
            # new-style config: the encoder is described by a dict
            encoder = from_dict(encoder_cfg)
        else:
            # old-style config: the entries are keyword arguments of rnn_encoder
            encoder = rnn_encoder(
                seqdist.n_base, seqdist.state_len,
                insize=config['input']['features'], **encoder_cfg
            )

        super().__init__(encoder, seqdist)
        self.config = config

    def forward(self, x, mode='basecalling', return_features=False, return_features_only=False):
        """Run the encoder and shape the result according to ``mode``.

        ``mode='basecalling'``:
            * ``return_features`` truthy: ``(outputs, fmaps[-2])``;
            * else ``return_features_only`` truthy: the encoder called with
              that flag;
            * otherwise: the plain encoder output.
        ``mode='inference'``:
            * the encoder output is passed to ``decode_batch``; when
              ``return_features`` is truthy the features returned by the
              encoder are returned alongside it.  ``return_features_only``
              is not consulted in this mode.

        Raises:
            ValueError: for any other ``mode``.
        """
        if mode == 'basecalling':
            if return_features:
                outputs, fmaps = self.encoder(x, return_features=return_features)
                return outputs, fmaps[-2]
            if return_features_only:
                return self.encoder(x, return_features_only=return_features_only)
            return self.encoder(x)

        if mode == 'inference':
            encoded = self.encoder(x, return_features=return_features)
            if return_features:
                outputs, features = encoded
                return self.decode_batch(outputs), features
            return self.decode_batch(encoded)

        raise ValueError("Invalid mode")


class CTCBaseCaller(nn.Module):
    """Basecaller followed by a linear projection onto its alphabet, decoded by beam search.

    Args:
        basecaller: module producing the features to project; it must expose
            an ``alphabet`` attribute.
        features_dim: size of the last axis of the basecaller output.
    """

    def __init__(self, basecaller, features_dim=320):
        super().__init__()
        self.basecaller = basecaller
        self.projector = nn.Linear(features_dim, len(self.basecaller.alphabet))

    def forward(self, x, mode='basecalling'):
        """Project basecaller features (``'basecalling'``) or decode scores (``'decode'``).

        Raises:
            ValueError: if ``mode`` is neither ``'basecalling'`` nor ``'decode'``.
        """
        if mode == 'basecalling':
            x = self.basecaller(x)
            x = self.projector(x.permute(1, 0, 2)).permute(1, 0, 2)
            return x
        elif mode == 'decode':
            return self.decode_batch(x)
        else:
            raise ValueError("Invalid mode")

    def decode_batch(self, x, beamsize=2, threshold=1e-5):
        '''Beam-search decode every sequence of a batch.

        x: (T, N, C) tensor.  Returns a list with one decoded string per
        batch entry.
        '''
        # detach() first: Tensor.numpy() raises on tensors that still require
        # grad, e.g. when decoding outputs produced outside torch.no_grad().
        x = x.detach().permute(1, 0, 2).cpu().numpy().astype(np.float32)  # Permute to batch first
        seqs = []
        for y in x:
            seq, _ = beam_search(y, self.basecaller.alphabet, beamsize, threshold)
            seqs.append(seq)
        return seqs


class Model(torch.nn.Module):
    """Basecaller combined with a k-mer branch that attends over basecaller features.

    The k-mer branch embeds k-mer indices, runs them through an LSTM, lets the
    result attend over the basecaller features and feeds each read's attended
    positions to ``classifier_head``.
    """

    def __init__(self, basecaller, embedding_dim, lstm_params, attention_params, classifier_head):
        super().__init__()
        self.basecaller = basecaller
        # 1024 kmers + 1 padding
        self.kmer_embedding = nn.Embedding(1025, embedding_dim, padding_idx=0)
        # Output dimension 2 * seq_dim
        self.kmer_lstm = nn.LSTM(**lstm_params)
        # NOTE(review): neko_MultiheadAttention is not imported in the visible
        # part of this file - confirm it is in scope at import time.
        self.multi_head_attn = neko_MultiheadAttention(**attention_params)
        self.classifier_head = classifier_head

        self.alphabet = ['N', 'A', 'C', 'G', 'T']
        self.state_len = 3
        self.n_base = 4

    def forward(self, x, mode=None):
        """Dispatch on ``mode``.

        * ``'basecalling'``: return ``basecaller(x)``.
        * ``'decode'``: return ``decode_batch(x)``.
        * anything else: ``x`` is ``(signals, kmers, kmer_lengths)`` and the
          result is ``(scores, preds)``, where ``preds`` holds one classifier
          output per read.
        """
        if mode == 'basecalling':
            return self.basecaller(x)
        if mode == 'decode':
            return self.decode_batch(x)

        signals, kmers, kmer_lengths = x
        scores, features = self.basecaller(signals, return_features=True)
        seq_features = self.get_seq_features(kmers, kmer_lengths)

        # Per-read boolean mask: rows before the read's k-mer count are False,
        # the remaining (padded) rows are True.
        n_queries = seq_features.shape[0]
        n_keys = features.shape[0]
        per_read_masks = []
        for length in kmer_lengths:
            real_rows = torch.zeros((length, n_keys), dtype=torch.bool)
            padded_rows = torch.ones((n_queries - length, n_keys), dtype=torch.bool)
            per_read_masks.append(torch.cat([real_rows, padded_rows], axis=0))
        mask = torch.stack(per_read_masks).to(signals.device)

        seq_features, _ = self.multi_head_attn(seq_features, features, features, attn_mask=mask)

        # Classify only the un-padded positions of each read.
        preds = []
        for i in range(len(kmer_lengths)):
            read_features = seq_features[:kmer_lengths[i], i, :]
            preds.append(self.classifier_head(read_features).squeeze(1))
        return scores, preds

    def decode_batch(self, x):
        """Delegate batch decoding to the wrapped basecaller."""
        return self.basecaller.decode_batch(x)

    def decode(self, x):
        """Decode a single un-batched input by inserting a batch axis at dim 1."""
        batched = x.unsqueeze(1)
        return self.basecaller.decode_batch(batched)[0]

    def get_seq_features(self, kmers, kmer_lengths):
        """Embed ``kmers`` and run them through the LSTM as a packed sequence.

        ``kmers`` is treated as batch-first when packing; the padded result is
        returned sequence-first (``pad_packed_sequence`` default layout).
        """
        embedded = self.kmer_embedding(kmers)
        lengths_cpu = kmer_lengths.detach().cpu()
        packed = pack_padded_sequence(embedded, lengths_cpu, batch_first=True, enforce_sorted=False)
        packed_out, _ = self.kmer_lstm(packed)
        seq_features, _ = pad_packed_sequence(packed_out)
        return seq_features
