import copy
import torch
import numpy as np
from torch import nn
from torch.nn.utils.parametrizations import weight_norm

from baseline.bendr.model.trainable.utils import _make_mask


class _SingleAxisOperation(nn.Module):
    """
    Abstract base for layers that act along a single, configurable tensor axis.

    Parameters
    ----------
    axis : int
        Axis the concrete subclass operates on; defaults to the last axis (-1).
    """

    def __init__(self, axis=-1):
        super().__init__()
        # Read by subclasses inside ``forward``.
        self.axis = axis

    def forward(self, x):
        # Concrete subclasses supply the actual operation.
        raise NotImplementedError

# Some general purpose convenience layers
# ---------------------------------------


class Expand(_SingleAxisOperation):
    """Insert a new dimension of size one at position ``self.axis``."""

    def forward(self, x):
        return torch.unsqueeze(x, self.axis)


class Squeeze(_SingleAxisOperation):
    """Drop the dimension at ``self.axis`` when its size is one (otherwise a no-op)."""

    def forward(self, x):
        return torch.squeeze(x, self.axis)


class Permute(nn.Module):
    """
    Reorder the dimensions of a tensor according to a fixed ordering.

    Parameters
    ----------
    axes : sequence of int
        New dimension order, handed to ``torch.permute`` unchanged.
    """

    def __init__(self, axes):
        super().__init__()
        self.axes = axes

    def forward(self, x):
        order = self.axes
        return torch.permute(x, order)


class Concatenate(_SingleAxisOperation):
    """
    Concatenate several tensors along ``self.axis``.

    The tensors may be given as separate positional arguments (``layer(a, b, c)``) or bundled in a single tuple
    or list (``layer((a, b, c))`` / ``layer([a, b, c])``). Accepting a list keeps this layer composable with
    ``IndexSelect``, which returns a list when it selects more than one element.
    """

    def forward(self, *x):
        # Unwrap a single bundled sequence. Lists used to fall through here, leaving torch.cat with a nested
        # sequence ((list,),) that it cannot concatenate.
        if len(x) == 1 and isinstance(x[0], (tuple, list)):
            x = x[0]
        return torch.cat(x, dim=self.axis)


class IndexSelect(nn.Module):
    """
    Pick one or more elements out of a sequence of inputs.

    Parameters
    ----------
    indices : int or list/tuple of int
        Position(s) to select. With a single index the element itself is returned; with several indices a list
        of the selected elements (in the given order) is returned.
    """

    def __init__(self, indices):
        super().__init__()
        # NOTE: validation uses assert (kept for compatibility); it is skipped under ``python -O``.
        assert isinstance(indices, (int, list, tuple))
        if isinstance(indices, int):
            indices = [indices]
        for i in indices:
            assert isinstance(i, int)
        self.indices = list(indices)

    def forward(self, *x):
        # Inputs may arrive as separate arguments or bundled in a single tuple *or list* (e.g. the list output of
        # another multi-index IndexSelect). Previously a single list was not unwrapped, so index 0 returned the
        # whole list and any other index raised IndexError.
        if len(x) == 1 and isinstance(x[0], (tuple, list)):
            x = x[0]
        if len(self.indices) == 1:
            return x[self.indices[0]]
        return [x[i] for i in self.indices]


class Flatten(nn.Module):
    """Collapse every dimension after the first into one, giving ``(x.size(0), -1)``."""

    def forward(self, x):
        batch = x.size(0)
        # ``view`` needs contiguous memory, so materialise a contiguous layout first.
        packed = x.contiguous()
        return packed.view(batch, -1)


class ConvBlock2D(nn.Module):
    """
    Implements complete convolution block with order:
    - Convolution
    - dropout (spatial)
    - activation
    - batch-norm (only when ``batch_norm=True``)
    - (optional) residual reconnection

    Parameters
    ----------
    in_filters, out_filters : int
        Input / output channel counts of the convolution.
    kernel, stride, padding, dilation, groups
        Forwarded to :class:`torch.nn.Conv2d`.
    do_rate : float
        Drop probability of the :class:`torch.nn.Dropout2d` layer.
    batch_norm : bool
        If True, a :class:`torch.nn.BatchNorm2d` follows the activation and the convolution has no bias. If False,
        the convolution keeps its bias and no normalization is applied.
    activation : callable
        Zero-argument factory for the activation module.
    residual : bool
        Add the block input to its output. This requires the output to have the input's shape (equal channel
        counts and a shape-preserving kernel/padding).
    """

    def __init__(self, in_filters, out_filters, kernel, stride=(1, 1), padding=0, dilation=1, groups=1, do_rate=0.5,
                 batch_norm=True, activation=nn.LeakyReLU, residual=False):
        super().__init__()
        self.kernel = kernel
        self.activation = activation()
        self.residual = residual

        self.conv = nn.Conv2d(in_filters, out_filters, kernel, stride=stride, padding=padding, dilation=dilation,
                              groups=groups, bias=not batch_norm)
        self.dropout = nn.Dropout2d(p=do_rate)
        # BatchNorm2d used to be created unconditionally, so batch_norm=False only re-enabled the conv bias while
        # still normalizing. nn.Identity keeps the attribute and forward() uniform for both settings.
        # NOTE(review): state dicts saved by the old code with batch_norm=False contain ``batch_norm.*`` keys,
        # so loading them with strict=True will now report unexpected keys.
        self.batch_norm = nn.BatchNorm2d(out_filters) if batch_norm else nn.Identity()

    def forward(self, input, **kwargs):
        # Extra keyword arguments are handed to the convolution call unchanged.
        res = input
        out = self.conv(input, **kwargs)
        out = self.dropout(out)
        out = self.activation(out)
        out = self.batch_norm(out)
        return out + res if self.residual else out

# ---------------------------------------


# New layers
# ---------------------------------------


class DenseFilter(nn.Module):
    """
    DenseNet-style filter block from TIDNet (Kostas & Rudzicz 2020, "Thinker Invariance").

    The block applies BN -> activation -> 1x1 bottleneck conv -> BN -> activation -> a 2D conv whose kernel spans
    only one of the last two axes -> Dropout2d. It then concatenates the resulting ``growth_rate`` feature maps
    onto its input along dim 1. In TIDNet it is used to develop channel operations independently of temporal
    changes.

    Parameters
    ----------
    in_features : int
        Number of input feature maps (dim 1 of the input).
    growth_rate : int
        Number of feature maps appended to the input.
    filter_len : int
        Kernel length along the chosen axis. Padding is ``filter_len // 2``, so only odd lengths preserve that
        axis' size (which the final concatenation requires).
    do : float
        Dropout2d probability applied to the new feature maps.
    bottleneck : int
        Width multiplier of the 1x1 bottleneck (``bottleneck * growth_rate`` maps).
    activation : callable
        Zero-argument factory for the activation modules.
    dim : int
        Axis the kernel spans: 2 / -2 or 3 / -1 (of a 4D input).

    Raises
    ------
    ValueError
        If ``dim`` does not name one of the last two axes.
    """

    def __init__(self, in_features, growth_rate, filter_len=5, do=0.5, bottleneck=2, activation=nn.LeakyReLU, dim=-2):
        super().__init__()
        # Map negative axes onto a 4D tensor's positive ones.
        axis = dim + 4 if dim <= 0 else dim
        if not 2 <= axis <= 3:
            raise ValueError('Only last two dimensions supported')
        kernel = (filter_len, 1) if axis == 2 else (1, filter_len)
        padding = tuple(k // 2 for k in kernel)
        hidden = bottleneck * growth_rate

        self.net = nn.Sequential(
            nn.BatchNorm2d(in_features),
            activation(),
            nn.Conv2d(in_features, hidden, 1),
            nn.BatchNorm2d(hidden),
            activation(),
            nn.Conv2d(hidden, growth_rate, kernel, padding=padding),
            nn.Dropout2d(do),
        )

    def forward(self, x):
        new_maps = self.net(x)
        return torch.cat([x, new_maps], dim=1)


class DenseSpatialFilter(nn.Module):
    """
    Stack of ``depth`` :any:`DenseFilter` blocks, optionally followed by a collapse of channel space.

    Parameters
    ----------
    channels : int
        Height of the collapse kernel ``(channels, 1)``, i.e. the size of dim 2 that ``channel_collapse`` reduces.
    growth : int
        Growth rate of every DenseFilter.
    depth : int
        Number of DenseFilter blocks.
    in_ch : int
        Feature maps entering the first block.
    bottleneck : int
        Bottleneck multiplier forwarded to each DenseFilter.
    dropout_rate : float
        Dropout probability forwarded to each DenseFilter.
    activation : callable
        Activation factory forwarded to each DenseFilter.
    collapse : bool
        If True, finish with a ConvBlock2D over ``channels`` rows and squeeze dim -2 of its output.
    """

    def __init__(self, channels, growth, depth, in_ch=1, bottleneck=4, dropout_rate=0.0, activation=nn.LeakyReLU,
                 collapse=True):
        super().__init__()
        blocks = []
        for d in range(depth):
            # Each block sees the input plus every previous block's ``growth`` maps.
            blocks.append(DenseFilter(in_ch + growth * d, growth, bottleneck=bottleneck, do=dropout_rate,
                                      activation=activation))
        self.net = nn.Sequential(*blocks)
        self.collapse = collapse
        if collapse:
            width = in_ch + growth * depth
            self.channel_collapse = ConvBlock2D(width, width, (channels, 1), do_rate=0)

    def forward(self, x):
        if x.dim() < 4:
            # 3D input (b, a, c) becomes (b, 1, c, a).
            # NOTE(review): after this permute dim 2 holds the input's last axis, which is the axis the collapse
            # kernel spans -- confirm which axis carries the EEG channels for 3D callers.
            x = x.unsqueeze(1).permute([0, 1, 3, 2])
        x = self.net(x)
        if not self.collapse:
            return x
        return self.channel_collapse(x).squeeze(-2)


class SpatialFilter(nn.Module):
    """
    Reduce the ``channels`` rows (dim 2) to one with a stack of ``depth`` ConvBlock2D layers.

    The first ``depth - 1`` kernels are ``channels // depth`` rows tall. The last kernel is sized so the total
    reduction (sum of ``k - 1``) is ``channels - 1``, leaving a single row that is squeezed away.

    Parameters
    ----------
    channels : int
        Number of rows to collapse.
    filters : int
        Output feature maps of every layer.
    depth : int
        Number of ConvBlock2D layers.
    in_ch : int
        Input feature maps (a 3D input is treated as having one).
    dropout_rate : float
        Total dropout, split evenly (``dropout_rate / depth``) across the layers.
    activation : callable
        Activation factory for each ConvBlock2D.
    batch_norm : bool
        Forwarded to each ConvBlock2D.
    residual : bool
        Add a 1x1 Conv1d projection of the flattened input to the output.
    """

    def __init__(self, channels, filters, depth, in_ch=1, dropout_rate=0.0, activation=nn.LeakyReLU, batch_norm=True,
                 residual=False):
        super().__init__()
        kernels = [(channels // depth, 1)] * (depth - 1)
        kernels.append((channels - sum(k for k, _ in kernels) + depth - 1, 1))

        block_rate = dropout_rate / depth
        first, *rest = kernels
        layers = [ConvBlock2D(in_ch, filters, first, do_rate=block_rate, activation=activation,
                              batch_norm=batch_norm)]
        layers += [ConvBlock2D(filters, filters, k, do_rate=block_rate, activation=activation, batch_norm=batch_norm)
                   for k in rest]
        self.filter = nn.Sequential(*layers)
        self.residual = nn.Conv1d(channels * in_ch, filters, 1) if residual else None

    def forward(self, x):
        res = x
        if x.dim() < 4:
            x = x.unsqueeze(1)
        elif self.residual is not None:
            # Fold feature maps and rows together so the Conv1d residual sees (batch, in_ch * channels, time).
            res = res.contiguous().view(res.shape[0], -1, res.shape[3])
        out = self.filter(x).squeeze(-2)
        if self.residual is not None:
            return out + self.residual(res)
        return out


class TemporalFilter(nn.Module):

 def __init__(self, channels, filters, depth, temp_len, dropout=0., activation=nn.LeakyReLU, residual='netwise'):
 """
 This implements the dilated temporal-only spanning convolution from TIDNet.

 Parameters
 ----------
 channels
 filters
 depth
 temp_len
 dropout
 activation
 residual
 """
 super().__init__()
 temp_len = temp_len + 1 - temp_len % 2
 self.residual_style = str(residual)
 net = list()

 for i in range(depth):
 dil = depth - i
 conv = weight_norm(nn.Conv2d(channels if i == 0 else filters, filters, kernel_size=(1, temp_len),
 dilation=dil, padding=(0, dil * (temp_len - 1) // 2)))
 net.append(nn.Sequential(
 conv,
 activation(),
 nn.Dropout2d(dropout)
 ))
 if self.residual_style.lower() == 'netwise':
 self.net = nn.Sequential(*net)
 self.residual = nn.Conv2d(channels, filters, (1, 1))
 elif residual.lower() == 'dense':
 self.net = net

 def forward(self, x):
 if self.residual_style.lower() == 'netwise':
 return self.net(x) + self.residual(x)
 elif self.residual_style.lower() == 'dense':
 for l in self.net:
 x = torch.cat((x, l(x)), dim=1)
 return x



class _BENDREncoder(nn.Module):
    """
    Base class for BENDR encoders: records sizes and offers checkpoint and freezing helpers.

    Parameters
    ----------
    in_features : int
        Stored as ``self.in_features``.
    encoder_h : int
        Stored as ``self.encoder_h``.
    """

    def __init__(self, in_features, encoder_h=256):
        super().__init__()
        self.in_features, self.encoder_h = in_features, encoder_h

    def load(self, filename, strict=True):
        """Restore parameters from the state dict stored at ``filename``.

        NOTE: ``torch.load`` unpickles the file -- only load trusted checkpoints.
        """
        self.load_state_dict(torch.load(filename), strict=strict)

    def save(self, filename):
        """Write this module's state dict to ``filename`` via ``torch.save``."""
        torch.save(self.state_dict(), filename)

    def freeze_features(self, unfreeze=False):
        """Set ``requires_grad`` of every parameter to ``unfreeze`` (the default False freezes them)."""
        for p in self.parameters():
            p.requires_grad = unfreeze


class ConvEncoderBENDR(_BENDREncoder):
    """
    Strided 1D-convolutional feature encoder used by BENDR.

    Every stage is ``Conv1d -> Dropout1d -> GroupNorm -> GELU`` with ``encoder_h`` output channels. An optional
    1x1 "projection head" stage can be appended.

    Parameters
    ----------
    in_features : int
        Input channels of the first convolution.
    encoder_h : int
        Output channels of every stage. GroupNorm uses ``encoder_h // 2`` groups, so this should be even.
    enc_width : int or sequence of int
        Kernel width of each stage. Even widths are bumped to the next odd value so every kernel is centrable with
        ``padding = width // 2``.
    dropout : float
        Dropout1d probability of each stage (the projection head uses ``2 * dropout``).
    projection_head : bool
        Whether to append the 1x1 projection stage.
    enc_downsample : int or sequence of int
        Stride of each stage; must have as many entries as ``enc_width``.
    """

    def __init__(self, in_features, encoder_h=256, enc_width=(3, 2, 2, 2, 2, 2),
                 dropout=0., projection_head=False, enc_downsample=(3, 2, 2, 2, 2, 2)):
        super().__init__(in_features, encoder_h)
        if not isinstance(enc_width, (list, tuple)):
            enc_width = [enc_width]
        if not isinstance(enc_downsample, (list, tuple)):
            enc_downsample = [enc_downsample]
        assert len(enc_downsample) == len(enc_width)

        # Centerable convolutions make life simpler
        enc_width = [e if e % 2 else e + 1 for e in enc_width]
        self._downsampling = enc_downsample
        self._width = enc_width

        self.encoder = nn.Sequential()
        for i, (width, downsample) in enumerate(zip(enc_width, enc_downsample)):
            self.encoder.add_module("Encoder_{}".format(i), nn.Sequential(
                nn.Conv1d(in_features, encoder_h, width, stride=downsample, padding=width // 2),
                nn.Dropout1d(dropout),
                nn.GroupNorm(encoder_h // 2, encoder_h),
                nn.GELU(),
            ))
            in_features = encoder_h

        if projection_head:
            self.encoder.add_module("projection-1", nn.Sequential(
                nn.Conv1d(in_features, in_features, 1),
                nn.Dropout1d(dropout * 2),
                nn.GroupNorm(in_features // 2, in_features),
                nn.GELU()
            ))

    def description(self, sfreq=None, sequence_len=None):
        """
        Summarise the encoder's receptive field and downsampling as a human-readable string.

        Parameters
        ----------
        sfreq : float, optional
            Input sampling frequency; adds the duration in seconds and the output rate when given.
        sequence_len : int, optional
            Input length in samples; adds the number of encoded samples when given.

        Returns
        -------
        str
        """
        # Walk the stages from last to first. A window of ``rf`` samples at a stage's output covers
        # ``(rf - 1) * stride + width`` samples at its input. The previous recurrence added ``width - 1`` and
        # ignored the stride of width-1 stages, under-reporting the receptive field.
        rf = 1
        for w, s in zip(reversed(self._width), reversed(self._downsampling)):
            rf = (rf - 1) * s + w

        desc = "Receptive field: {} samples".format(rf)
        if sfreq is not None:
            desc += ", {:.2f} seconds".format(rf / sfreq)

        ds_factor = np.prod(self._downsampling)
        desc += " | Downsampled by {}".format(ds_factor)
        if sfreq is not None:
            desc += ", new sfreq: {:.2f} Hz".format(sfreq / ds_factor)
        desc += " | Overlap of {} samples".format(rf - ds_factor)
        if sequence_len is not None:
            # Each stage outputs ceil(L / stride) samples, which can exceed sequence_len // ds_factor.
            desc += " | {} encoded samples/trial".format(int(self.downsampling_factor(sequence_len)))
        return desc

    def downsampling_factor(self, samples):
        """
        Return the encoded length for an input of ``samples`` samples.

        Every stage maps ``L`` samples to ``ceil(L / stride)``, which matches the Conv1d output length for odd
        widths padded by ``width // 2``. The result comes from ``np.ceil``, so it is a float for scalar input.
        """
        for factor in self._downsampling:
            samples = np.ceil(samples / factor)
        return samples

    def forward(self, x):
        # x: Conv1d input, (batch, in_features, samples).
        return self.encoder(x)


# FIXME this is redundant with part of the contextualizer
class EncodingAugment(nn.Module):
    """
    Masking, relative-position encoding and input conditioning for encoder features.

    Mirrors the front end of ``BENDRContextualizer``, but returns the conditioned features with shape
    ``(batch, 3 * in_features, seq)`` instead of running a transformer.

    Parameters
    ----------
    in_features : int
        Feature channels of the input (dim 1). The position convolution uses ``groups=16``, so this must be
        divisible by 16.
    mask_p_t, mask_t_span : float, int
        Passed to ``_make_mask`` to build temporal masks during training (no masks if either is <= 0).
    mask_p_c, mask_c_span : float, int
        Passed to ``_make_mask`` to build feature masks during training (no masks if either is <= 0).
    dropout : float
        Dropout applied in the input conditioning.
    position_encoder : int
        Kernel width of the position convolution.
    """

    def __init__(self, in_features, mask_p_t=0.1, mask_p_c=0.01, mask_t_span=6, mask_c_span=64, dropout=0.1,
                 position_encoder=25):
        super().__init__()
        # Learned vector written into time-masked positions (initialised to zeros).
        self.mask_replacement = torch.nn.Parameter(torch.zeros(in_features), requires_grad=True)
        self.p_t = mask_p_t
        self.p_c = mask_p_c
        self.mask_t_span = mask_t_span
        self.mask_c_span = mask_c_span
        transformer_dim = 3 * in_features

        conv = nn.Conv1d(in_features, in_features, position_encoder, padding=position_encoder // 2, groups=16)
        nn.init.normal_(conv.weight, mean=0, std=2 / transformer_dim)
        nn.init.constant_(conv.bias, 0)
        conv = weight_norm(conv, dim=2)
        self.relative_position = nn.Sequential(conv, nn.GELU())

        # LayerNorm over the feature axis (hence the permutes), dropout, then a 1x1 projection to 3x width.
        self.input_conditioning = nn.Sequential(
            Permute([0, 2, 1]),
            nn.LayerNorm(in_features),
            nn.Dropout(dropout),
            Permute([0, 2, 1]),
            nn.Conv1d(in_features, transformer_dim, 1),
        )

    def forward(self, x, mask_t=None, mask_c=None):
        """
        Mask (optionally), add the position encoding and condition ``x`` of shape ``(batch, features, seq)``.

        ``mask_t`` indexes ``x.transpose(2, 1)`` (masked time steps get ``mask_replacement``). ``mask_c`` indexes
        ``x`` (masked entries are zeroed). During training, missing masks are generated with ``_make_mask``.
        """
        bs, feat, seq = x.shape

        if self.training:
            if mask_t is None and self.p_t > 0 and self.mask_t_span > 0:
                mask_t = _make_mask((bs, seq), self.p_t, x.shape[-1], self.mask_t_span)
            if mask_c is None and self.p_c > 0 and self.mask_c_span > 0:
                mask_c = _make_mask((bs, feat), self.p_c, x.shape[1], self.mask_c_span)

        if mask_t is not None or mask_c is not None:
            # Masking writes in place. Copy first so the caller's tensor is not modified; this also avoids autograd
            # errors when x is a leaf that requires grad.
            x = x.clone()
        if mask_t is not None:
            x.transpose(2, 1)[mask_t] = self.mask_replacement
        if mask_c is not None:
            x[mask_c] = 0

        x = self.input_conditioning(x + self.relative_position(x))
        return x

    def init_from_contextualizer(self, filename):
        """
        Load matching weights (mask vector, position encoder, conditioning) from a contextualizer state dict.

        Loading uses ``strict=False``, so keys without a counterpart are ignored. Afterwards every parameter of
        this module is frozen. NOTE: ``torch.load`` unpickles the file -- only use trusted checkpoints.
        """
        state_dict = torch.load(filename)
        self.load_state_dict(state_dict, strict=False)
        for param in self.parameters():
            param.requires_grad = False
        print("Initialized mask embedding and position encoder from ", filename)


class _Hax(nn.Module):
    """
    Pass-through module. T-fixup assumes self-attention norms are removed, so this stands in for them.
    """

    def forward(self, x):
        # No-op: hand the input back unchanged.
        return x


class BENDRContextualizer(nn.Module):
    """
    Transformer contextualizer of BENDR with T-fixup style initialisation.

    The input ``(batch, in_features, seq)`` goes through these steps:
    - it is optionally masked;
    - a convolutional relative-position term is added;
    - it is projected to ``3 * in_features`` channels;
    - it runs through ``layers`` transformer encoder layers whose LayerNorms are replaced by identities.

    A constant start token may be prepended along the sequence axis. The result is projected back to
    ``in_features`` channels: ``(batch, in_features, seq)``, or ``seq + 1`` with a start token.

    Parameters
    ----------
    in_features : int
        Input feature channels. Must be divisible by 16 when the position encoder is enabled (grouped conv).
    hidden_feedforward : int
        Feed-forward width of each transformer layer.
        NOTE(review): 3076 (not 3072) fixes the feed-forward weight shapes of existing checkpoints; keep it.
    heads : int
        Attention heads; must divide ``3 * in_features``.
    layers : int
        Number of transformer layers.
    dropout : float
        Dropout in the transformer layers and in the input conditioning.
    activation : str
        Feed-forward activation of the transformer layers.
    position_encoder : int
        Kernel width of the position convolution; values <= 0 disable it.
    layer_drop : float
        Probability of skipping each transformer layer during training.
    mask_p_t, mask_t_span, mask_p_c, mask_c_span
        Passed to ``_make_mask`` when masks are generated (training with ``finetuning=True``).
    start_token : float or None
        Fill value of the token prepended to the sequence; None disables it.
    finetuning : bool
        Generate masks automatically during training. Also keeps ``mask_replacement`` frozen in
        :meth:`freeze_features`.
    """

    def __init__(
            self,
            in_features,
            hidden_feedforward=3076,
            heads=8,
            layers=8,
            dropout=0.15,
            activation='gelu',
            position_encoder=25,
            layer_drop=0.0,
            mask_p_t=0.1,
            mask_p_c=0.004,
            mask_t_span=6,
            mask_c_span=64,
            start_token=-5,
            finetuning=False,
    ):
        super().__init__()

        self.dropout = dropout
        self.in_features = in_features
        self._transformer_dim = in_features * 3

        encoder = nn.TransformerEncoderLayer(d_model=self._transformer_dim, nhead=heads,
                                             dim_feedforward=hidden_feedforward, dropout=dropout,
                                             activation=activation)
        # T-fixup: the per-layer LayerNorms are replaced by identities.
        encoder.norm1 = _Hax()
        encoder.norm2 = _Hax()

        # Not used in forward(); removing it would change the state-dict keys of saved models.
        self.norm = nn.LayerNorm(self._transformer_dim)

        self.transformer_layers = nn.ModuleList([copy.deepcopy(encoder) for _ in range(layers)])
        self.layer_drop = layer_drop
        self.p_t = mask_p_t
        self.p_c = mask_p_c
        self.mask_t_span = mask_t_span
        self.mask_c_span = mask_c_span
        self.start_token = start_token
        self.finetuning = finetuning

        # Learned vector written into time-masked positions, initialised from N(0, in_features ** -0.5).
        self.mask_replacement = torch.nn.Parameter(torch.normal(0, in_features ** (-0.5), size=(in_features,)),
                                                   requires_grad=True)

        self.position_encoder = position_encoder > 0
        # Use the stored flag so a non-positive width consistently disables the encoder (a negative width used
        # to reach Conv1d and fail).
        if self.position_encoder:
            conv = nn.Conv1d(in_features, in_features, position_encoder, padding=position_encoder // 2, groups=16)
            nn.init.normal_(conv.weight, mean=0, std=2 / self._transformer_dim)
            nn.init.constant_(conv.bias, 0)
            conv = weight_norm(conv, dim=2)
            self.relative_position = nn.Sequential(conv, nn.GELU())

        # LayerNorm over features, dropout, 1x1 projection to 3x width, then (seq, batch, dim) for the transformer.
        self.input_conditioning = nn.Sequential(
            Permute([0, 2, 1]),
            nn.LayerNorm(in_features),
            nn.Dropout(dropout),
            Permute([0, 2, 1]),
            nn.Conv1d(in_features, self._transformer_dim, 1),
            Permute([2, 0, 1]),
        )

        self.output_layer = nn.Conv1d(self._transformer_dim, in_features, 1)
        self.apply(self.init_bert_params)

    def init_bert_params(self, module):
        """
        Initialise ``module`` (applied to every submodule via ``self.apply``).

        Linear layers get Xavier-uniform weights and zero bias. The weights are then scaled by
        ``0.67 * layers ** -0.25`` (T-fixup).
        """
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight.data)
            if module.bias is not None:
                module.bias.data.zero_()
            # Tfixup
            module.weight.data = 0.67 * len(self.transformer_layers) ** (-0.25) * module.weight.data

    def forward(self, x, mask_t=None, mask_c=None):
        """
        Contextualise ``x`` of shape ``(batch, in_features, seq)``.

        ``mask_t`` indexes ``x.transpose(2, 1)`` (masked time steps get ``mask_replacement``). ``mask_c`` indexes
        ``x`` (masked entries are zeroed). Missing masks are generated only in training mode with
        ``finetuning=True``.
        """
        bs, feat, seq = x.shape
        if self.training and self.finetuning:
            if mask_t is None and self.p_t > 0:
                mask_t = _make_mask((bs, seq), self.p_t, x.shape[-1], self.mask_t_span)
            if mask_c is None and self.p_c > 0:
                mask_c = _make_mask((bs, feat), self.p_c, x.shape[1], self.mask_c_span)

        # Multi-gpu workaround, wastes memory. It also keeps the in-place masking below off the caller's tensor.
        x = x.clone()

        if mask_t is not None:
            x.transpose(2, 1)[mask_t] = self.mask_replacement
        if mask_c is not None:
            x[mask_c] = 0

        if self.position_encoder:
            x = x + self.relative_position(x)
        x = self.input_conditioning(x)  # -> (seq, batch, 3 * in_features)

        if self.start_token is not None:
            # Constant token built directly on x's device and dtype. The old version made a requires_grad
            # float32 tensor and copied it to the device on every call.
            in_token = torch.full((1, *x.shape[1:]), self.start_token, dtype=x.dtype, device=x.device)
            x = torch.cat([in_token, x], dim=0)

        for layer in self.transformer_layers:
            # LayerDrop: during training each layer is skipped with probability ``layer_drop``.
            if not self.training or torch.rand(1) > self.layer_drop:
                x = layer(x)

        return self.output_layer(x.permute([1, 2, 0]))

    def freeze_features(self, unfreeze=False, finetuning=False):
        """Set ``requires_grad`` of all parameters to ``unfreeze``; keep ``mask_replacement`` frozen when finetuning."""
        for param in self.parameters():
            param.requires_grad = unfreeze
        if self.finetuning or finetuning:
            self.mask_replacement.requires_grad = False

    def load(self, filename, strict=True):
        """Restore parameters from the state dict at ``filename`` (``torch.load`` unpickles: trusted files only)."""
        state_dict = torch.load(filename)
        self.load_state_dict(state_dict, strict=strict)

    def save(self, filename):
        """Write this module's state dict to ``filename`` via ``torch.save``."""
        torch.save(self.state_dict(), filename)
