import math
from dataclasses import dataclass
from enum import Enum
from typing import NamedTuple, Tuple

import torch
from choices import *
from config_base import BaseConfig
from torch import nn
from torch.nn import init

from .blocks import *
from .nn import timestep_embedding
from .unet import *


class LatentNetType(Enum):
	"""String-valued enumeration of latent network variants; each value equals its member name."""
	# presumably selects "no latent network" -- confirm against the code that reads this enum
	none = 'none'
	# injecting inputs into the hidden layers
	skip = 'skip'


class LatentNetReturn(NamedTuple):
	"""Return container of a latent network forward pass (MLPSkipNet.forward returns one)."""
	# tensor produced by the network; None when the tuple is built without arguments
	pred: torch.Tensor = None


@dataclass
class MLPSkipNetConfig(BaseConfig):
	"""
	Configuration of MLPSkipNet, the default MLP for the latent DPM in the paper.

	Call make_model() to build the network described by these fields.
	"""
	# feature width of the network input and output; also the width of the
	# time-conditioning vector handed to every layer
	num_channels: int
	# indices of the layers whose input gets the network input concatenated to it
	skip_layers: Tuple[int, ...]
	# feature width of the hidden layers
	num_hid_channels: int
	# total number of MLPLNAct layers in the network
	num_layers: int
	# size of the timestep embedding fed to the time-embedding MLP
	num_time_emb_channels: int = 64
	# activation used in the time-embedding MLP and in every layer except the last
	activation: Activation = Activation.silu
	# whether the layers other than the last one apply LayerNorm
	use_norm: bool = True
	# constant added to the conditioning scale before it multiplies the features
	condition_bias: float = 1
	# dropout probability of the layers other than the last one (0 disables dropout)
	dropout: float = 0
	# activation applied to the output of the last layer
	last_act: Activation = Activation.none
	# number of linear layers in the time-embedding MLP
	num_time_layers: int = 2
	# whether an activation also follows the last linear layer of the time-embedding MLP
	time_last_act: bool = False
	# NOTE(review): not read anywhere in the visible part of this file -- confirm where it is used
	mask_threshold: float = None

	def make_model(self):
		"""Build an MLPSkipNet from this configuration."""
		return MLPSkipNet(self)


class MLPSkipNet(nn.Module):
	"""
	MLP that concatenates its input x to the input of selected hidden layers.

	default MLP for the latent DPM in the paper!

	The timestep is embedded, passed through a small MLP (time_embed) and the
	result conditions every layer except the output layer.
	"""
	def __init__(self, conf: MLPSkipNetConfig):
		"""
		Build the time-embedding MLP and the stack of MLPLNAct layers.

		Args:
			conf: network configuration; it is stored on the module as self.conf
				and read again in forward().
		"""
		super().__init__()
		self.conf = conf

		# time-embedding MLP: num_time_emb_channels -> num_channels -> ... -> num_channels
		time_layers = []
		in_dim = conf.num_time_emb_channels
		for i in range(conf.num_time_layers):
			time_layers.append(nn.Linear(in_dim, conf.num_channels))
			# the activation after the last linear layer is optional
			if i < conf.num_time_layers - 1 or conf.time_last_act:
				time_layers.append(conf.activation.get_act())
			in_dim = conf.num_channels
		self.time_embed = nn.Sequential(*time_layers)

		self.layers = nn.ModuleList([])
		for i in range(conf.num_layers):
			# NOTE: layer 0 is always built as a hidden layer, so with
			# num_layers == 1 the single layer outputs num_hid_channels.
			is_output = i != 0 and i == conf.num_layers - 1

			in_dim = conf.num_channels if i == 0 else conf.num_hid_channels
			out_dim = conf.num_channels if is_output else conf.num_hid_channels
			if i in conf.skip_layers:
				# forward() concatenates the network input to this layer's input
				in_dim += conf.num_channels

			# the output layer is a plain linear map: no activation, norm,
			# conditioning or dropout
			self.layers.append(
				MLPLNAct(
					in_dim,
					out_dim,
					norm=False if is_output else conf.use_norm,
					activation=Activation.none if is_output else conf.activation,
					cond_channels=conf.num_channels,
					use_cond=not is_output,
					condition_bias=conf.condition_bias,
					dropout=0 if is_output else conf.dropout,
				))
		self.last_act = conf.last_act.get_act()

	def forward(self, x, t, **kwargs):
		"""
		Run the network on x conditioned on the timestep t.

		Args:
			x: input tensor; it is concatenated to hidden features along dim 1
				and fed to linear layers with num_channels input features, so
				it is expected to be (batch, num_channels).
			t: timesteps, passed to timestep_embedding() -- presumably one
				value per batch element; confirm against that helper.
			**kwargs: accepted and ignored.

		Returns:
			LatentNetReturn whose pred is the output of the last layer after
			last_act.
		"""
		t_emb = timestep_embedding(t, self.conf.num_time_emb_channels)
		cond = self.time_embed(t_emb)
		h = x
		for i, layer in enumerate(self.layers):
			if i in self.conf.skip_layers:
				# injecting input into the hidden layers
				h = torch.cat([h, x], dim=1)
			# call the module itself (not .forward) so registered hooks run
			h = layer(x=h, cond=cond)
		h = self.last_act(h)
		return LatentNetReturn(h)


class MLPLNAct(nn.Module):
	"""
	Linear layer followed by optional conditional scaling, optional LayerNorm,
	an activation and optional dropout (applied in that order).
	"""
	def __init__(
		self,
		in_channels: int,
		out_channels: int,
		norm: bool,
		use_cond: bool,
		activation: Activation,
		cond_channels: int,
		condition_bias: float = 0,
		dropout: float = 0,
	):
		"""
		Args:
			in_channels: input feature width of the linear layer.
			out_channels: output feature width of the linear layer.
			norm: apply LayerNorm over the output features when True.
			use_cond: scale the linear output by a projection of the
				conditioning vector when True.
			activation: activation choice; it also selects the weight
				initialisation in init_weights().
			cond_channels: feature width of the conditioning vector.
			condition_bias: constant added to the conditioning scale.
			dropout: dropout probability; values <= 0 disable dropout.
		"""
		super().__init__()
		self.activation = activation
		self.condition_bias = condition_bias
		self.use_cond = use_cond

		self.linear = nn.Linear(in_channels, out_channels)
		self.act = activation.get_act()
		if self.use_cond:
			# The projection is registered both directly (linear_emb) and inside
			# cond_layers; both names refer to the same parameters. The
			# activation module is the same instance as self.act.
			self.linear_emb = nn.Linear(cond_channels, out_channels)
			self.cond_layers = nn.Sequential(self.act, self.linear_emb)
		self.norm = nn.LayerNorm(out_channels) if norm else nn.Identity()
		self.dropout = nn.Dropout(p=dropout) if dropout > 0 else nn.Identity()

		self.init_weights()

	def init_weights(self):
		"""
		Re-initialise the weights of every nn.Linear in this module.

		relu and silu use Kaiming-normal with the 'relu' gain, lrelu uses
		Kaiming-normal for leaky_relu with slope 0.2; any other activation
		keeps the default nn.Linear initialisation. Biases are not touched.
		"""
		if self.activation == Activation.lrelu:
			kaiming_kwargs = dict(a=0.2, nonlinearity='leaky_relu')
		elif self.activation in (Activation.relu, Activation.silu):
			kaiming_kwargs = dict(a=0, nonlinearity='relu')
		else:
			# leave it as default
			return

		for module in self.modules():
			if isinstance(module, nn.Linear):
				init.kaiming_normal_(module.weight, **kaiming_kwargs)

	def forward(self, x, cond=None):
		"""
		Apply the layer to x.

		Args:
			x: input features; the last dimension must match in_channels.
			cond: conditioning vector; required when use_cond is True and
				ignored otherwise.

		Returns:
			The transformed features with out_channels in the last dimension.
		"""
		x = self.linear(x)
		if self.use_cond:
			# scale only (no shift term), applied before the norm
			scale = self.cond_layers(cond)
			x = x * (self.condition_bias + scale)
		x = self.norm(x)
		x = self.act(x)
		return self.dropout(x)