import copy
import math

import torch
import torch.nn as nn

from .transformer_components import *

# Pretrain Heads

class MHProjectorMlpSmall(torch.nn.Module):
    """Multi-head projector made of ``num_heads`` small MLP heads.

    A template MLP (Linear -> ReLU -> Dropout -> Linear) is built once and
    deep-copied for every head, so all heads start from identical weights
    while owning separate parameters.

    NOTE: the template stays registered as ``self.mlp`` (its parameters show
    up in the state dict), although ``forward`` only evaluates the copies
    stored in ``self.heads``.

    Args:
        input_size: number of input features of each head.
        hidden_layer: width of the hidden layer.
        dropout: dropout probability applied after the hidden activation.
        output_size: number of output features of each head.
        num_heads: number of heads to create.
    """

    def __init__(self, input_size, hidden_layer, dropout, output_size, num_heads):
        super().__init__()

        template_layers = [
            nn.Linear(input_size, hidden_layer),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_layer, output_size),
        ]
        self.mlp = nn.Sequential(*template_layers)

        # Every head is an independent copy of the template.
        head_copies = [copy.deepcopy(self.mlp) for _ in range(num_heads)]
        self.heads = nn.ModuleList(head_copies)

    def forward(self, x):
        """Apply every head to ``x``; return the outputs as a list in head order."""
        return [projector(x) for projector in self.heads]

class MHProjectorLinearProbe(torch.nn.Module):
    """Multi-head linear probe: ``num_heads`` independent linear layers.

    Each head is a freshly constructed ``nn.Linear(input_size, output_size)``,
    so the heads are initialised independently of each other.

    Args:
        input_size: number of input features of each head.
        output_size: number of output features of each head.
        num_heads: number of linear heads to create.
    """

    def __init__(self, input_size, output_size, num_heads):
        super().__init__()

        probes = []
        for _ in range(num_heads):
            # A new layer per iteration already yields independent parameters.
            probes.append(nn.Linear(input_size, output_size))
        self.heads = nn.ModuleList(probes)

    def forward(self, x):
        """Apply every linear head to ``x``; return the outputs as a list in head order."""
        return [probe(x) for probe in self.heads]

# Finetune heads

class FinetuneHeadMlp(nn.Module):
    """Finetuning head: a three-layer MLP with ReLU and dropout after each hidden layer.

    Layer stack: Linear -> ReLU -> Dropout -> Linear -> ReLU -> Dropout -> Linear.

    Args:
        input_size: number of input features.
        est_head_dim_1: width of the first hidden layer.
        est_head_dim_2: width of the second hidden layer.
        est_head_dropout_1: dropout probability after the first hidden layer.
        est_head_dropout_2: dropout probability after the second hidden layer.
        output_size: number of output features.
    """

    def __init__(self, input_size, est_head_dim_1, est_head_dim_2, est_head_dropout_1, est_head_dropout_2, output_size):
        super().__init__()

        first_stage = [
            nn.Linear(input_size, est_head_dim_1),
            nn.ReLU(inplace=True),
            nn.Dropout(est_head_dropout_1),
        ]
        second_stage = [
            nn.Linear(est_head_dim_1, est_head_dim_2),
            nn.ReLU(inplace=True),
            nn.Dropout(est_head_dropout_2),
        ]
        readout = nn.Linear(est_head_dim_2, output_size)

        self.mlp = nn.Sequential(*first_stage, *second_stage, readout)

    def forward(self, x):
        """Run ``x`` through the MLP and return its output."""
        prediction = self.mlp(x)
        return prediction
    
class FinetuneHeadLinearProbe(nn.Module):
    """Finetuning head consisting of a single linear layer.

    Args:
        input_size: number of input features.
        output_size: number of output features.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear = nn.Linear(in_features=input_size, out_features=output_size)

    def forward(self, x):
        """Return the linear projection of ``x``."""
        projected = self.linear(x)
        return projected

# Encoder Backbone

"""
The Transformer Code is mainly based on the following sources:
https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html
https://github.com/gzerveas/mvts_transformer
"""

class TransformerBackbone(nn.Module):
    """Transformer encoder backbone with a linear read-out over the flattened sequence.

    Pipeline: input projection (scaled by sqrt(d_model)) -> positional
    encoding -> transformer encoder -> optional activation + dropout ->
    flatten all non-batch dimensions -> linear output layer.

    Args:
        feat_dim: number of input features per time step (input size of the projection).
        seq_len: sequence length; sizes the positional encoding and the output layer.
        d_model: model (embedding) dimension.
        n_heads: number of attention heads.
        num_layers: number of encoder layers.
        dim_feedforward: hidden size of the encoder feed-forward blocks.
        est_head_dim: stored on the instance; not used by this class itself.
        output_size: number of output features.
        dropout_rate: dropout used in the encoder and after the output activation.
        dropout_rate_est: stored on the instance; not used by this class itself.
        torch_imp: if True use ``nn.TransformerEncoder``, otherwise the project's
            ``TransformerEncoder``.
        activation: ``"relu"`` or ``"gelu"``; any other value falls back to ReLU
            for the output activation (a message is printed).
        norm: forwarded to the project's ``TransformerEncoder`` only.
        skipconnections: forwarded to the project's ``TransformerEncoder`` only.
        seed: forwarded to the project's ``TransformerEncoder`` only.
        output_act: if equal to True, apply activation and dropout to the
            encoder output before flattening.
    """

    def __init__(self, feat_dim, seq_len, d_model, n_heads, num_layers, dim_feedforward, est_head_dim, output_size,
                 dropout_rate=0.2, dropout_rate_est = 0.2, torch_imp = False, activation = "relu", norm = "layernorm", skipconnections = True, seed = 42, output_act = True):
        super().__init__()

        self.feat_dim = feat_dim
        self.seq_len = seq_len
        self.d_model = d_model
        self.n_heads = n_heads
        self.num_layers = num_layers
        self.dim_feedforward = dim_feedforward
        self.est_head_dim = est_head_dim
        self.output_size = output_size
        self.dropout_rate = dropout_rate
        self.dropout_rate_est = dropout_rate_est
        self.activation = activation
        self.norm = norm
        self.skipconnections = skipconnections
        self.output_act = output_act

        if self.activation == "gelu":
            self.act = nn.GELU()
        elif self.activation == "relu":
            self.act = nn.ReLU()
        else:
            print("no valid activation selection, falling back to relu")
            self.act = nn.ReLU()

        self.dropout1 = nn.Dropout(dropout_rate)

        self.project_inp = nn.Linear(feat_dim, d_model)
        self.pos_enc = LearnablePositionalEncoding(d_model, max_len=self.seq_len)

        if torch_imp:
            # nn.TransformerEncoderLayer rejects activation strings other than
            # "relu"/"gelu", so apply the announced ReLU fallback here as well
            # instead of crashing right after printing the fallback message.
            encoder_layer = nn.TransformerEncoderLayer(
                self.d_model,
                self.n_heads,
                self.dim_feedforward,
                dropout=self.dropout_rate,
                activation=self._torch_activation(self.activation),
                batch_first=True,
            )
            self.transformer_encoder = nn.TransformerEncoder(encoder_layer, self.num_layers)
        else:
            self.transformer_encoder = TransformerEncoder(num_layers = self.num_layers,
                                                          d_model = self.d_model,
                                                          num_heads = self.n_heads,
                                                          dim_feedforward = self.dim_feedforward,
                                                          dropout=self.dropout_rate,
                                                          activation = self.activation,
                                                          norm = self.norm,
                                                          skipconnections = self.skipconnections,
                                                          seed = seed)

        # The read-out consumes the whole encoded sequence flattened per sample.
        self.output_net = nn.Linear(self.d_model * self.seq_len, self.output_size)

    @staticmethod
    def _torch_activation(activation):
        """Map ``activation`` to a value ``nn.TransformerEncoderLayer`` accepts.

        Unknown activation *strings* become ``"relu"``; ``"relu"``, ``"gelu"``
        and non-string values (e.g. callables) are returned unchanged.
        """
        if isinstance(activation, str) and activation not in ("relu", "gelu"):
            return "relu"
        return activation

    def forward(self, X, get_attention_maps=False):
        """Encode ``X`` and return the read-out.

        Args:
            X: input tensor; presumably (batch, seq_len, feat_dim), since the
                last dimension feeds the input projection and the flattened
                encoder output must have ``d_model * seq_len`` features.
            get_attention_maps: if True, additionally call
                ``self.transformer_encoder.get_attention_maps`` and return its
                result. ``nn.TransformerEncoder`` (``torch_imp=True``) does not
                provide that method, so the request raises ``AttributeError``
                in that configuration.

        Returns:
            The output tensor, or ``(output, attention_maps)`` when
            ``get_attention_maps`` is True.
        """
        # Input projection (scaled by sqrt(d_model)) and positional embedding
        inp = self.project_inp(X) * math.sqrt(self.d_model)
        inp = self.pos_enc(inp)

        # Encoder; attention maps come from a separate call on the same input
        attention_maps = None
        if get_attention_maps:
            attention_maps = self.transformer_encoder.get_attention_maps(inp)

        output = self.transformer_encoder(inp)

        # Equality (not truthiness) is kept on purpose so that non-bool flag
        # values behave exactly as before.
        if self.output_act == True:  # noqa: E712
            output = self.act(output)
            output = self.dropout1(output)

        # Flatten every non-batch dimension for the linear read-out
        output = output.reshape(output.shape[0], -1)

        # Estimation head
        output = self.output_net(output)
        if get_attention_maps:
            return output, attention_maps
        return output