import numpy as np
import torch
import torch.nn.functional as F
import math
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class Model(torch.nn.Module):
    """Predict per-position logits from raw read signals and reference sequences.

    Signals are encoded by ``feature_extractor`` (a basecaller, per the
    ``basecall`` method), sequence contexts by an embedding + LSTM + linear
    stack. For every group (a transcript, per the original comments) the reads
    are attended over twice: first over the time axis of each read, then over
    the reads covering each position. The pooled vectors are projected to
    two logits per position.

    Args:
        feature_extractor: Module mapping a batch of signals to features laid
            out as ``S x N x C`` (time, batch, channels).
        feature_dim: Channel size ``C`` of the extractor output.
        embedding_dim: Size of the token embedding fed to the LSTM.
        lstm_params: Keyword arguments forwarded verbatim to ``nn.LSTM``.
        seq_dim: Input size of ``seq_out``; must equal the flattened LSTM
            output size (sequence length times LSTM output width).
        sample_reads: If a positive integer, at most this many reads are
            randomly drawn per group; ``None`` or ``0`` disables sampling.
        fine_tune_feature_extractor: If ``False`` the extractor parameters are
            frozen and the extractor always runs without gradients.
    """

    def __init__(self, feature_extractor, feature_dim, embedding_dim, lstm_params, seq_dim, sample_reads=None,
                 fine_tune_feature_extractor=False):
        super().__init__()
        self.feature_extractor = feature_extractor

        self.seq_dim = seq_dim

        # Vocabulary of 4 token ids (presumably the four nucleotides -- confirm against the data loader).
        self.embedding = nn.Embedding(4, embedding_dim)
        self.lstm = nn.LSTM(**lstm_params)
        # The output is split into two halves of size feature_dim: one used as the
        # query over signal time steps, one as the query over reads.
        self.seq_out = nn.Linear(seq_dim, 2 * feature_dim)
        self.projector = nn.Linear(feature_dim, 2)
        self.sample_reads = sample_reads
        self.fine_tune_feature_extractor = fine_tune_feature_extractor

        if not self.fine_tune_feature_extractor:
            for p in self.feature_extractor.parameters():
                p.requires_grad = False

    def extract_signal_features(self, signals, positions_masks, signals_lengths, batch_size=256, device='cuda', mode='train'):
        """Run the feature extractor over all reads and regroup the result.

        Gradients flow through the extractor only when fine-tuning is enabled
        and ``mode == 'train'``; otherwise the extractor runs in eval mode
        under ``torch.no_grad()`` and its outputs are moved to the CPU.

        Args:
            signals: Tensor holding all reads, concatenated along dim 0.
            positions_masks: One mask per group with one row per read
                (entries may be ``None``).
            signals_lengths: Number of reads in each group, in order.
            batch_size: Number of reads passed to the extractor at once.
            device: Device the extractor runs on.
            mode: ``'train'`` enables gradient flow when fine-tuning.

        Returns:
            ``(read_features, positions_masks)`` where ``read_features`` is a
            list with one ``N_i x S x C`` tensor per group and
            ``positions_masks`` contains only the rows of the reads that were
            kept (unchanged when sampling is disabled).
        """
        # `sample_reads` defaults to None, which must not be compared with `>`.
        if self.sample_reads is not None and self.sample_reads > 0:
            grouped_signals = self.group_features(signals, signals_lengths)

            # Sample without replacement, capped at the number of available
            # reads so that small groups do not make np.random.choice raise.
            indices = [
                np.random.choice(len(group), min(self.sample_reads, len(group)), replace=False)
                for group in grouped_signals
            ]
            signals = torch.cat([group[idx] for group, idx in zip(grouped_signals, indices)])
            positions_masks = [
                mask if mask is None else mask[idx]
                for mask, idx in zip(positions_masks, indices)
            ]
            signals_lengths = [len(idx) for idx in indices]

        if self.fine_tune_feature_extractor and mode == 'train':
            self.feature_extractor.train()
            read_features = self._run_feature_extractor(
                signals, signals_lengths, batch_size, device, offload=False)
        else:
            self.feature_extractor.eval()
            with torch.no_grad():
                read_features = self._run_feature_extractor(
                    signals, signals_lengths, batch_size, device, offload=True)

        return read_features, positions_masks

    def _run_feature_extractor(self, signals, signals_lengths, batch_size, device, offload):
        """Apply the extractor in mini-batches and regroup the features per group.

        When ``offload`` is true each batch of features is detached and moved
        to the CPU right away so that only one batch lives on ``device``.
        """
        batches = []
        for signal in signals.split(batch_size):
            features = self.feature_extractor(signal.to(device))  # S x N x C
            features = features.permute(1, 0, 2)  # N x S x C
            if offload:
                features = features.detach().cpu()
            batches.append(features)
        return self.group_features(torch.cat(batches), signals_lengths)

    def group_features(self, signal_features, signals_lengths):
        """Split ``signal_features`` along dim 0 into consecutive chunks.

        Args:
            signal_features: Sliceable container holding all rows back to back.
            signals_lengths: Length of each consecutive chunk.

        Returns:
            List with one slice per entry of ``signals_lengths``.
        """
        grouped_features = []
        start = 0
        for length in signals_lengths:
            grouped_features.append(signal_features[start:start + length])
            start += length
        return grouped_features

    def extract_seq_features(self, seqs, seqs_lengths):
        """Encode sequence contexts and regroup them per group.

        Args:
            seqs: Integer token ids, one row per position (``N x L``).
            seqs_lengths: Number of positions in each group, in order.

        Returns:
            List with one ``M_i x (2 * feature_dim)`` tensor per group.
        """
        seq_features = self.embedding(seqs)
        # Assumes the LSTM was configured with batch_first=True so that dim 0
        # stays the position axis -- TODO confirm against lstm_params.
        seq_features, _ = self.lstm(seq_features)  # N x L x D
        seq_features = seq_features.flatten(1, 2)  # N x (L * D)
        seq_features = self.seq_out(seq_features)  # N x (2 * feature_dim)
        return self.group_features(seq_features, seqs_lengths)

    def extract_position_rep_from_signals(self, signal, seq, pos_mask):
        """Pool the read features of one group into one vector per position.

        Args:
            signal: ``N x L x D`` features of the N reads in the group.
            seq: ``M x 2D`` sequence features of the M positions; the first
                half queries the time steps of each read, the second half
                scores the reads against each other.
            pos_mask: Optional boolean ``N x M`` tensor whose ``(i, j)`` entry
                is True when read ``i`` covers position ``j``.

        Returns:
            ``M x D`` tensor with one pooled representation per position.
        """
        half = seq.shape[-1] // 2
        seq, seq_mil = seq[:, :half], seq[:, half:]
        seq = seq / math.sqrt(seq.shape[-1])  # scaled attention

        # (N x L x D) x (M x D) -> (N x L x M), normalised over the time axis L
        attn = torch.einsum('ijk, mk->ijm', signal, seq)
        attn = F.softmax(attn, dim=1)

        # (N x M x L) x (N x L x D) -> N x M x D:
        # entry (i, j) is the representation of read i with respect to position j
        output = torch.bmm(attn.permute(0, 2, 1), signal)

        # Score every read for every position: (N x M x D) x (M x D) -> N x M
        read_attn = torch.einsum('ijk, jk->ij', output, seq_mil)
        if pos_mask is not None:
            # A large negative offset drives the softmax weight of reads that
            # do not cover a position to (numerically) zero.
            penalty = torch.zeros_like(pos_mask, dtype=torch.float)
            penalty.masked_fill_(~pos_mask, -1e9)
            read_attn = read_attn + penalty.to(output.device)

        read_attn = F.softmax(read_attn, dim=0).unsqueeze(-1)  # N x M x 1

        # Weighted sum over reads: (N x M x D) * (N x M x 1) -> M x D
        return (output * read_attn).sum(0)

    def basecall(self, signals):
        """Run the feature extractor in its ``'basecall'`` mode on ``signals``."""
        return self.feature_extractor(signals, mode='basecall')

    def forward(self, signals, signals_lengths, positions_masks, sequences, sequences_length, mode='train'):
        """Compute two logits for every position of every group.

        Args:
            signals: All reads concatenated along dim 0 (a channel axis is
                inserted at dim 1 before they reach the extractor).
            signals_lengths: Number of reads in each group.
            positions_masks: One read-by-position mask per group.
            sequences: Token ids of all positions, concatenated along dim 0.
            sequences_length: Number of positions in each group.
            mode: Forwarded to ``extract_signal_features``.

        Returns:
            Tensor of shape ``(total number of positions) x 2``.
        """
        signal_features, positions_masks = self.extract_signal_features(
            signals.unsqueeze(1), positions_masks, signals_lengths,
            batch_size=256, device=sequences.device, mode=mode)
        sequence_features = self.extract_seq_features(sequences, sequences_length)

        all_features = []
        for i, read_features in enumerate(signal_features):
            seq = sequence_features[i]
            # Features may have been offloaded to the CPU by the extractor step.
            read_features = read_features.to(seq.device)
            all_features.append(
                self.extract_position_rep_from_signals(read_features, seq, positions_masks[i]))

        return self.projector(torch.cat(all_features))
