import torch
import torch.nn as nn
import math
import logging 
log = logging.getLogger(__name__)

class PositionalEncoding(nn.Module):
    '''
    Fixed sinusoidal positional encoding that is added to a [batch, len, dim] input.

    The table is precomputed once and stored in the (non-trainable) buffer ``pe``
    of shape [1, max_len, d_model].
    '''
    def __init__(self, d_model, max_len=5000):
        '''
        From https://discuss.pytorch.org/t/how-to-modify-the-positional-encoding-in-torch-nn-transformer/104308/2

        Args:
            d_model: width of the encoding (must match the last dim of the inputs to forward).
                Both even and odd widths are supported.
            max_len: number of positions to precompute.
        '''
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # When d_model is odd there is one fewer odd column than even column, so the
        # frequency vector is truncated for the cosine half (a no-op for even d_model).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, seq):
        '''
        Add the encoding of positions 0..len-1 to seq.

        Args:
            seq: tensor of shape [batch, len, dim].

        Returns:
            (out, pos_enc): ``out = seq + pos_enc`` and ``pos_enc`` is the
            [1, len, dim] slice of the table that was added.
        '''
        #seq is [batch, len, dim]
        assert len(seq.shape) == 3
        pos_enc = self.pe[:, :seq.size(1), :]
        out = seq + pos_enc
        return out, pos_enc

class BrainPositionalEncoding(nn.Module):
    '''
    Placeholder positional encoding whose table is all zeros.

    The buffer ``pe`` has shape [1, max_len, d_model] and is never filled in
    (marked TODO by the original author), so adding it leaves the values of the
    input unchanged.
    '''
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        '''
        From https://discuss.pytorch.org/t/how-to-modify-the-positional-encoding-in-torch-nn-transformer/104308/2

        Args:
            d_model: width of the encoding.
            dropout: not used by this module; kept for signature compatibility.
            max_len: number of positions in the table.
        '''
        super(BrainPositionalEncoding, self).__init__()

        pe = torch.zeros(max_len, d_model)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe) #TODO

    def forward(self, seq, positions=None):
        '''
        Add the (all-zero) encoding to seq.

        Args:
            seq: tensor of shape [batch, len, dim].
            positions: ignored. Accepted so this module can be called the same way
                as the position-aware encoders in this file (``positions=...``).

        Returns:
            (out, pos_enc): ``out = seq + pos_enc`` and ``pos_enc`` is the
            [1, len, dim] slice of the table.
        '''
        #seq is [batch, len, dim]
        assert len(seq.shape) == 3
        pos_enc = self.pe[:, :seq.size(1), :]
        out = seq + pos_enc
        return out, pos_enc

class TimePositionEncoding(nn.Module):
    '''
    Sinusoidal encoding built from five index streams per token.

    The embedding of a token is the concatenation of five table look-ups (three
    coordinate axes, a sequence id and a time id), each of width
    ``pe_dim = d_model // 5``, zero-padded on the right up to the width of the input.
    '''
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        '''
        Args:
            d_model: width of the sequences passed to forward.
            dropout: not used by this module; kept for signature compatibility.
            max_len: number of rows in the table, i.e. the exclusive upper bound on
                every index passed in ``positions``.
        '''
        super(TimePositionEncoding, self).__init__()

        pe_dim = int(d_model/5) #The idea is that each one of the XYZ will get their own position embedding
        pe = torch.zeros(max_len, pe_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, pe_dim, 2).float() * (-math.log(10000.0) / pe_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        # pe_dim is frequently odd (e.g. 768 // 5 == 153); there is then one fewer odd
        # column than even column, so truncate the frequencies for the cosine half.
        pe[:, 1::2] = torch.cos(position * div_term[: pe_dim // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        self.max_len = max_len

    def forward(self, seq, positions):
        '''
        Args:
            seq: tensor of shape [batch, len, dim]; row 0 along ``len`` is treated as
                the cls token.
            positions: tuple ``(coords, seq_id, time_id)`` of integer index tensors.
                ``coords`` is [batch, len-1, 3]; ``seq_id`` and ``time_id`` are
                [batch, len-1].

        Returns:
            (out, input_embeddings): ``out = seq + input_embeddings`` where
            ``input_embeddings`` is [batch, len, dim].
        '''
        #seq is [batch, len, dim]
        assert len(seq.shape) == 3
        coords, seq_id, time_id = positions
        batch_size, _, d_embed = seq.shape

        #self.pe is [1, max_len, d] size
        #coords is [batch, seq_len-1, 3] size
        p_embed = self.pe[0, coords]#2 axis is the XYZ axis
        n_batch, seq_len, n_axes, d_p_embed = p_embed.shape
        p_embed = p_embed.reshape(n_batch, seq_len, n_axes*d_p_embed)#flatten the last two dims into one position vector

        seq_embed = self.pe[0, seq_id]
        time_embed = self.pe[0, time_id]
        token_embed = torch.cat([p_embed, seq_embed, time_embed], dim=-1)

        # The cls token reuses the encoding of index 0 for each of the five slots.
        cls_embed = torch.unsqueeze(self.pe[0, 0].repeat(batch_size, 5), 1) #[batch_size, 1, 5*pe_dim]
        input_embeddings = torch.cat([cls_embed, token_embed], dim=1)

        # Zero-pad up to the input width. The gap is dim - 5*pe_dim (i.e. dim % 5),
        # which is not the same as pe_dim % 5.
        n_remainder = d_embed - input_embeddings.size(-1)
        if n_remainder > 0:
            padding = input_embeddings.new_zeros(n_batch, seq_len + 1, n_remainder)
            input_embeddings = torch.cat([input_embeddings, padding], dim=-1)#[batch, seq_len, d]

        out = seq + input_embeddings
        return out, input_embeddings

class TwinsPositionEncoding(nn.Module):
    '''
    Sinusoidal encoding built from three coordinate indices per token.

    Each of the three axes is looked up in a table of width ``pe_dim = d_model // 3``;
    the three look-ups are concatenated and zero-padded on the right up to the width
    of the input.
    '''
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        '''
        Args:
            d_model: width of the sequences passed to forward.
            dropout: not used by this module; kept for signature compatibility.
            max_len: number of rows in the table, i.e. the exclusive upper bound on
                every coordinate index.
        '''
        super(TwinsPositionEncoding, self).__init__()

        pe_dim = int(d_model/3) #The idea is that each one of the XYZ will get their own position embedding
        pe = torch.zeros(max_len, pe_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, pe_dim, 2).float() * (-math.log(10000.0) / pe_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd pe_dim there is one fewer odd column than even column, so the
        # frequencies are truncated for the cosine half (no-op for even pe_dim).
        pe[:, 1::2] = torch.cos(position * div_term[: pe_dim // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        self.max_len = max_len

    def forward(self, seq, positions):
        '''
        Args:
            seq: tensor of shape [batch, len, dim]; row 0 along ``len`` is treated as
                the cls token.
            positions: integer index tensor of coordinates, [batch, len-1, 3].

        Returns:
            (out, input_embeddings): ``out = seq + input_embeddings`` where
            ``input_embeddings`` is [batch, len, dim].
        '''
        #seq is [batch, len, dim]
        assert len(seq.shape) == 3
        coords = positions
        batch_size, _, d_embed = seq.shape

        #self.pe is [1, max_len, d] size
        #coords is [batch, seq_len-1, 3] size
        p_embed = self.pe[0, coords]#2 axis is the XYZ axis
        n_batch, seq_len, n_axes, d_p_embed = p_embed.shape
        p_embed = p_embed.reshape(n_batch, seq_len, n_axes*d_p_embed)#flatten the last two dims into one position vector

        # The cls token reuses the encoding of index 0 for each of the three axes.
        cls_embed = torch.unsqueeze(self.pe[0, 0].repeat(batch_size, 3), 1) #[batch_size, 1, 3*pe_dim]
        input_embeddings = torch.cat([cls_embed, p_embed], dim=1)

        # Zero-pad up to the input width. The gap is dim - 3*pe_dim (i.e. dim % 3),
        # which is not the same as pe_dim % 3 (e.g. dim=768 needs no padding at all).
        n_remainder = d_embed - input_embeddings.size(-1)
        if n_remainder > 0:
            padding = input_embeddings.new_zeros(n_batch, seq_len + 1, n_remainder)
            input_embeddings = torch.cat([input_embeddings, padding], dim=-1)#[batch, seq_len, d]

        out = seq + input_embeddings
        return out, input_embeddings

class EmptyEncoding(nn.Module):
    '''
    No-op encoding: the input sequence is passed through untouched.
    '''
    def __init__(self):
        super().__init__()

    def forward(self, seq, positions):
        '''
        Args:
            seq: tensor of shape [batch, len, dim].
            positions: accepted for interface parity with the other encoders; not read.

        Returns:
            (seq, []): the very same ``seq`` object and an empty list in place of
            the embeddings.
        '''
        # Only rank-3 inputs ([batch, len, dim]) are accepted.
        n_dims = len(seq.shape)
        assert n_dims == 3

        return seq, []

class RegionPositionEncoding(nn.Module):
    '''
    Learned embedding of an integer id per token, concatenated with a sinusoidal
    sequence-id encoding.

    With ``pe_dim = d_model // 4``, the learned part (``input_embeddings``) is
    ``3 * pe_dim`` wide and the sinusoidal part (buffer ``pe``) is ``pe_dim`` wide.
    '''
    def __init__(self, d_model, dropout=0.1, max_len=5000, dictionary_size=100, init_embedding=False, init_embedding_std=0.02):
        '''
        Args:
            d_model: width of the sequences passed to forward; must be divisible by 4.
            dropout: not used by this module; kept for signature compatibility.
            max_len: number of rows in the sinusoidal table (bound on ``seq_id``).
            dictionary_size: number of rows in the learned embedding (bound on ``coords``).
            init_embedding: if True, re-initialise the learned embedding from a normal
                distribution with mean 0 and std ``init_embedding_std``.
            init_embedding_std: std used when ``init_embedding`` is True.
        '''
        super(RegionPositionEncoding, self).__init__()

        assert d_model%4==0
        pe_dim = d_model // 4 #The idea is that each one of the XYZ + seq id will get their own position embedding
        pe = torch.zeros(max_len, pe_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, pe_dim, 2).float() * (-math.log(10000.0) / pe_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        # pe_dim can be odd (e.g. d_model=12 -> pe_dim=3); there is then one fewer odd
        # column than even column, so truncate the frequencies for the cosine half.
        pe[:, 1::2] = torch.cos(position * div_term[: pe_dim // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        self.max_len = max_len

        self.input_embeddings = nn.Embedding(dictionary_size, pe_dim*3)
        if init_embedding:
            self.input_embeddings.weight.data.normal_(mean=0.0, std=init_embedding_std)

        # Not read by forward; kept so the module's attributes stay unchanged.
        self.register_buffer(
            "input_ids", torch.arange(dictionary_size).expand((1, -1)), persistent=False
        ) # input_ids is [1, dictionary_size]

    def forward(self, seq, positions):
        '''
        Args:
            seq: tensor of shape [batch, len, dim]; row 0 along ``len`` is treated as
                the cls token.
            positions: tuple ``(coords, seq_id)`` of integer index tensors, both
                [batch, len-1]. ``coords`` indexes the learned embedding and
                ``seq_id`` indexes the sinusoidal table.

        Returns:
            (out, input_embeddings): ``out = seq + input_embeddings`` where
            ``input_embeddings`` is [batch, len, dim].
        '''
        #seq is [batch, len, dim]
        assert len(seq.shape) == 3
        coords, seq_id = positions

        p_embed = self.input_embeddings(coords)
        seq_embed = self.pe[0, seq_id]
        token_embed = torch.cat([p_embed, seq_embed], dim=-1)

        # The cls token uses the sinusoidal encoding of index 0 in all four slots.
        batch_size = seq.size(0)
        cls_embed = torch.unsqueeze(self.pe[0, 0].repeat(batch_size, 4), 1) #[batch_size, 1, d]

        input_embeddings = torch.cat([cls_embed, token_embed], dim=1)#[batch, seq_len, d]

        out = seq + input_embeddings
        return out, input_embeddings

class MultiSubjBrainPositionalEncoding(nn.Module):
    '''
    Sinusoidal encoding built from three coordinate indices plus a sequence id.

    Each of the four indices is looked up in a table of width
    ``pe_dim = d_model // 4`` and the four look-ups are concatenated.
    '''
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        '''
        Args:
            d_model: width of the sequences passed to forward; must be divisible by 4.
            dropout: not used by this module; kept for signature compatibility.
            max_len: number of rows in the table, i.e. the exclusive upper bound on
                every index passed in ``positions``.
        '''
        super(MultiSubjBrainPositionalEncoding, self).__init__()

        assert d_model%4==0
        pe_dim = d_model // 4 #The idea is that each one of the XYZ + seq id will get their own position embedding
        pe = torch.zeros(max_len, pe_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, pe_dim, 2).float() * (-math.log(10000.0) / pe_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        # pe_dim can be odd (e.g. d_model=12 -> pe_dim=3); there is then one fewer odd
        # column than even column, so truncate the frequencies for the cosine half.
        pe[:, 1::2] = torch.cos(position * div_term[: pe_dim // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        self.max_len = max_len

    def forward(self, seq, positions):
        '''
        Args:
            seq: tensor of shape [batch, len, dim]; row 0 along ``len`` is treated as
                the cls token.
            positions: tuple ``(coords, seq_id)`` of integer index tensors.
                ``coords`` is [batch, len-1, 3] and ``seq_id`` is [batch, len-1].

        Returns:
            (out, input_embeddings): ``out = seq + input_embeddings`` where
            ``input_embeddings`` is [batch, len, dim].
        '''
        #seq is [batch, len, dim]
        assert len(seq.shape) == 3
        coords, seq_id = positions

        #self.pe is [1, max_len, d] size
        #coords is [batch, seq_len-1, 3] size
        p_embed = self.pe[0, coords]#2 axis is the XYZ axis
        n_batch, seq_len, n_axes, d_p_embed = p_embed.shape
        p_embed = p_embed.reshape(n_batch, seq_len, n_axes*d_p_embed)#flatten the last two dims into one position vector
        seq_embed = self.pe[0, seq_id]
        token_embed = torch.cat([p_embed, seq_embed], dim=-1)

        # The cls token reuses the encoding of index 0 in all four slots.
        batch_size = seq.size(0)
        cls_embed = torch.unsqueeze(self.pe[0, 0].repeat(batch_size, 4), 1) #[batch_size, 1, d]

        input_embeddings = torch.cat([cls_embed, token_embed], dim=1)#[batch, seq_len, d]

        out = seq + input_embeddings
        return out, input_embeddings

class InputEmbedding(nn.Module):
    '''
    Learned embedding looked up by integer id and added to the input sequence.
    '''
    def __init__(self, d_model, dictionary_size=300, init_embedding=False, init_embedding_std=0.02):
        '''
        Args:
            d_model: width of each embedding vector.
            dictionary_size: number of distinct ids that can be embedded.
            init_embedding: if True, re-initialise the embedding from a normal
                distribution with mean 0 and std ``init_embedding_std``.
            init_embedding_std: std used when ``init_embedding`` is True.
        '''
        super(InputEmbedding, self).__init__()
        self.input_embeddings = nn.Embedding(dictionary_size, d_model)
        if init_embedding:
            self.input_embeddings.weight.data.normal_(mean=0.0, std=init_embedding_std)

        self.register_buffer(
            "input_ids", torch.arange(dictionary_size).expand((1, -1)), persistent=False
        ) # input_ids is [1, dictionary_size]

    def forward(self, seq, positions):
        '''
        Args:
            seq: tensor of shape [batch, len, dim].
            positions: integer indices into the dictionary, one per element of
                ``seq`` along the first two dims, i.e. [batch, len].

        Returns:
            (out, input_embeddings): ``out = seq + input_embeddings`` where
            ``input_embeddings`` holds the embedding row selected by each position.
        '''
        #seq is [batch, len, dim]
        assert len(seq.shape) == 3
        # Only `positions` decides which rows are used. No lookup is made over
        # range(len), so sequences longer than the dictionary are fine as long as
        # every id in `positions` is valid.
        input_embeddings = self.input_embeddings(self.input_ids[:, positions])# [1,n_batch, n_chan, d_embed]
        assert input_embeddings.size(0)==1
        input_embeddings = input_embeddings[0]

        out = seq + input_embeddings
        return out, input_embeddings


class TransformerEncoderInput(nn.Module):
    '''
    Input stage: linear projection, positional encoding, layer norm,
    optional "sensor dropout" and regular dropout.
    '''
    def __init__(self, cfg, dropout=0.1):
        '''
        Args:
            cfg: config object supporting ``in``, attribute access and ``.get``.
                Reads ``input_dim``, ``hidden_dim``, the optional ``position_encoding``
                (selects the encoder; anything unrecognised or missing falls back to
                ``PositionalEncoding``) and, in forward, the optional ``sensor_dropout``.
            dropout: probability for the final ``nn.Dropout``.
        '''
        super(TransformerEncoderInput, self).__init__()
        self.cfg = cfg
        self.in_proj = nn.Linear(in_features=cfg.input_dim, out_features=cfg.hidden_dim)

        # Only touch cfg.position_encoding when the key is present.
        encoding_name = self.cfg.position_encoding if "position_encoding" in self.cfg else None
        hidden_dim = self.cfg.hidden_dim
        if encoding_name == "brain_position_encoding":
            self.positional_encoding = BrainPositionalEncoding(hidden_dim)
        elif encoding_name == "multi_subj_brain_position_encoding":
            self.positional_encoding = MultiSubjBrainPositionalEncoding(hidden_dim)
        elif encoding_name == "twins_position_encoding":
            self.positional_encoding = TwinsPositionEncoding(hidden_dim)
        elif encoding_name == "single_subject_position_encoding":
            self.positional_encoding = InputEmbedding(hidden_dim)
        elif encoding_name == "time_encoding":
            self.positional_encoding = TimePositionEncoding(hidden_dim)
        elif encoding_name == "empty_encoding":
            self.positional_encoding = EmptyEncoding()
        elif encoding_name == "region_encoding":
            self.positional_encoding = RegionPositionEncoding(hidden_dim)
        else:
            self.positional_encoding = PositionalEncoding(hidden_dim)

        log.info(f"Using {type(self.positional_encoding).__name__} for positional encoding")

        self.layer_norm = nn.LayerNorm(cfg.hidden_dim)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, input_specs, positions=None):
        '''
        Args:
            input_specs: tensor whose last dim is ``cfg.input_dim``; after projection
                it is passed to the positional encoder as [batch, len, hidden_dim].
            positions: position information forwarded to position-aware encoders;
                not used by ``PositionalEncoding`` / ``BrainPositionalEncoding``.

        Returns:
            (input_specs, pos_enc): the processed sequence and whatever the
            positional encoder returned as its second output.
        '''
        input_specs = self.in_proj(input_specs)
        # These two encoders depend only on the index along the sequence and take no
        # `positions`; passing it to BrainPositionalEncoding used to raise a TypeError.
        if isinstance(self.positional_encoding, (PositionalEncoding, BrainPositionalEncoding)):
            input_specs, pos_enc = self.positional_encoding(input_specs)
        else:
            input_specs, pos_enc = self.positional_encoding(input_specs, positions=positions)
        input_specs = self.layer_norm(input_specs)

        ## Sensor dropout
        # Training only, and only when there are more than two rows along dim 1.
        # One mask is drawn per row of dim 1 and shared by the whole batch.
        if self.training and input_specs.size(1) > 2:
            apply_mask = torch.rand(input_specs.size(1)) < self.cfg.get("sensor_dropout", 0) ## true if need to perform dropout
            apply_mask[0] = False ## cls token is never dropped
            input_specs[:, apply_mask, :] = 0

        input_specs = self.dropout(input_specs)
        return input_specs, pos_enc
