import torch
import torch.nn as nn
from torch.nn import functional as F
from model.layers.avg_window import AvgWindow

class DoubleConv1D(nn.Module):
	"""Two stacked ``Conv1d -> BatchNorm1d -> ReLU`` stages followed by dropout.

	Args:
		in_channels: channel count of the incoming tensor.
		out_channels: channel count produced by both convolutions.
		kernel_size: convolution width. Padding is ``kernel_size // 2``, so an odd
			width keeps the temporal length unchanged while an even width grows it
			by one sample per convolution (two in total).
		dropout_rate: probability of the trailing ``nn.Dropout``.
	"""

	def __init__(self, in_channels, out_channels, kernel_size, dropout_rate=0):
		super().__init__()
		pad = kernel_size // 2
		stages = []
		# Build conv/bn/relu twice, in the same order as a literal Sequential, so
		# submodule indices (0..5) and parameter initialization order are stable.
		for stage_in in (in_channels, out_channels):
			stages += [
				nn.Conv1d(stage_in, out_channels, kernel_size=kernel_size, padding=pad),
				nn.BatchNorm1d(out_channels),
				nn.ReLU(inplace=True),
			]
		self.double_conv = nn.Sequential(*stages)
		self.dropout = nn.Dropout(dropout_rate)

	def forward(self, x):
		return self.dropout(self.double_conv(x))


class Down1D(nn.Module):
	"""Downsampling stage: ``MaxPool1d(sampling_scale)`` followed by ``DoubleConv1D``.

	The max-pool (default ``ceil_mode=False``) floor-divides the temporal length by
	``sampling_scale``; the double convolution then maps ``in_channels`` to
	``out_channels``.
	"""

	def __init__(self, in_channels, out_channels, kernel_size, sampling_scale, dropout_rate):
		super().__init__()
		pool = nn.MaxPool1d(sampling_scale)
		convs = DoubleConv1D(in_channels, out_channels, kernel_size, dropout_rate)
		# Kept as a two-element Sequential so state_dict keys stay maxpool_conv.1.*.
		self.maxpool_conv = nn.Sequential(pool, convs)

	def forward(self, x):
		return self.maxpool_conv(x)


class Up1D(nn.Module):
	"""Upsampling stage that fuses a coarse feature map with a skip connection.

	``x1`` (coarse) is upsampled by ``sampling_scale``. When ``bilinear`` is true
	this is linear interpolation; otherwise it is a learned ``ConvTranspose1d`` on
	``in_channels // 2`` channels. The result is then zero-padded, or cropped when
	it is longer, to the length of ``x2`` (skip). It is concatenated after ``x2``
	along dim 1 and refined by ``DoubleConv1D(in_channels, out_channels)``.
	"""

	def __init__(self, in_channels, out_channels, kernel_size, sampling_scale, dropout_rate, bilinear=True):
		super().__init__()
		if not bilinear:
			half = in_channels // 2
			self.up = nn.ConvTranspose1d(half, half, kernel_size=sampling_scale, stride=sampling_scale)
		else:
			self.up = nn.Upsample(scale_factor=sampling_scale, mode='linear', align_corners=True)

		self.conv = DoubleConv1D(in_channels, out_channels, kernel_size, dropout_rate)

	def forward(self, x1, x2):
		upsampled = self.up(x1)
		# A negative gap makes F.pad crop, so lengths are aligned either way.
		gap = x2.size(2) - upsampled.size(2)
		left = gap // 2
		upsampled = F.pad(upsampled, [left, gap - left])
		merged = torch.cat([x2, upsampled], dim=1)
		return self.conv(merged)

# Task embed
class TaskEncoder(nn.Module):
	"""Map integer task IDs to a learned ``(out_channel, out_len)`` feature map.

	Pipeline: ``Embedding(task_num, emb_len) -> ReLU -> Linear(emb_len,
	out_len * out_channel)``. Each sample's flat output is then viewed as
	``(out_channel, out_len)``.
	"""

	def __init__(self, task_num, emb_len, out_len, out_channel):
		super(TaskEncoder, self).__init__()
		self.out_len = out_len
		self.out_channel = out_channel
		layers = [
			nn.Embedding(task_num, emb_len),
			nn.ReLU(inplace=True),
			nn.Linear(emb_len, out_len * out_channel),
		]
		self.encoder = nn.Sequential(*layers)

	def forward(self, x):
		flat = self.encoder(x)
		batch = flat.size(0)
		return flat.view(batch, self.out_channel, self.out_len)


class OutConv1D(nn.Module):
	"""Pointwise (kernel size 1) convolution projecting features to ``out_channels``."""

	def __init__(self, in_channels, out_channels):
		super(OutConv1D, self).__init__()
		projection = nn.Conv1d(in_channels, out_channels, kernel_size=1)
		# Wrapped in a Sequential so parameters remain under outConv.0.*.
		self.outConv = nn.Sequential(projection)

	def forward(self, x):
		return self.outConv(x)


class EventComprehensionNet(nn.Module):
	"""Task-conditioned 1D U-Net followed by an ``AvgWindow`` layer.

	Encoder: ``inc`` plus four ``Down1D`` stages with channel widths
	``convchannels * (1, 2, 4, 8, 8)``. Stage ``i`` max-pools by
	``sampling_scale[i]``. At the bottleneck, a ``TaskEncoder`` feature map of shape
	``(convchannels * 8, task_embed_out_len)`` is added to the deepest features.
	Decoder: four ``Up1D`` stages with skip connections, then a pointwise
	``OutConv1D`` to ``n_classes`` channels, then ``AvgWindow(moving_avg_size)``.

	Args:
		n_channels: input channel count.
		convchannels: base channel width of the U-Net.
		kernel_size: convolution width. ``task_embed_out_len`` is computed assuming
			length-preserving convolutions (odd ``kernel_size``). Even widths grow the
			sequence inside every ``DoubleConv1D``, so the bottleneck length no longer
			matches the task embedding.
		sampling_scale: indexable with at least four pooling/upsampling factors;
			only entries 0..3 are used.
		n_classes: output channel count.
		input_len: temporal length of the inputs; used only to size the task
			embedding to the bottleneck length.
		task_num: number of distinct task IDs (embedding table size).
		dropout_rate: dropout probability used in every ``DoubleConv1D``.
		moving_avg_size: argument forwarded to ``AvgWindow`` (presumably its window
			size; see that layer).
		bilinear: upsample with linear interpolation (True) or a learned
			``ConvTranspose1d`` (False).
		task_emb_len: width of the task-ID embedding. Defaults to 10, the value that
			was previously hard-coded.
	"""

	def __init__(self, n_channels, convchannels, kernel_size, sampling_scale, n_classes, input_len, task_num, dropout_rate, moving_avg_size, bilinear=True, task_emb_len=10):
		super(EventComprehensionNet, self).__init__()
		self.n_channels = n_channels
		self.n_classes = n_classes
		self.bilinear = bilinear
		# Bottleneck length after the four floor-dividing max-pools. This assumes the
		# convolutions preserve length (odd kernel_size).
		self.task_embed_out_len = input_len // sampling_scale[0] // sampling_scale[1] // sampling_scale[2] // sampling_scale[3]

		# Contracting path. Module construction order is unchanged so seeded
		# initialization and state_dict keys are identical to earlier versions.
		self.inc = DoubleConv1D(n_channels, convchannels, kernel_size, dropout_rate)
		self.down1 = Down1D(convchannels, convchannels*2, kernel_size, sampling_scale[0], dropout_rate)
		self.down2 = Down1D(convchannels*2, convchannels*4, kernel_size, sampling_scale[1], dropout_rate)
		self.down3 = Down1D(convchannels*4, convchannels*8, kernel_size, sampling_scale[2], dropout_rate)
		self.down4 = Down1D(convchannels*8, convchannels*8, kernel_size, sampling_scale[3], dropout_rate)
		# Expanding path. Each Up1D's in_channels is the upsampled width plus its
		# skip width, e.g. up1: 8c (x5) + 8c (x4) = 16c; up4: c + c (x1) = 2c.
		self.up1 = Up1D(convchannels*16, convchannels*4, kernel_size, sampling_scale[3], dropout_rate, bilinear)
		self.up2 = Up1D(convchannels*8, convchannels*2, kernel_size, sampling_scale[2], dropout_rate, bilinear)
		self.up3 = Up1D(convchannels*4, convchannels, kernel_size, sampling_scale[1], dropout_rate, bilinear)
		self.up4 = Up1D(convchannels*2, convchannels, kernel_size, sampling_scale[0], dropout_rate, bilinear)
		self.taskEncoder = TaskEncoder(task_num, task_emb_len, self.task_embed_out_len, convchannels*8)
		self.outc = OutConv1D(convchannels, n_classes)
		self.avg_window = AvgWindow(moving_avg_size)


	def forward(self, x, taskID):
		"""Run the task-conditioned U-Net.

		Args:
			x: input batch, presumably ``(batch, n_channels, input_len)`` — TODO confirm.
			taskID: integer indices for ``nn.Embedding``, presumably one per sample.

		Returns:
			The ``n_classes``-channel output of ``outc`` after passing through
			``avg_window``.
		"""
		x1 = self.inc(x)
		x2 = self.down1(x1)
		x3 = self.down2(x2)
		x4 = self.down3(x3)
		x5 = self.down4(x4)
		# Condition the bottleneck on the task by adding its learned feature map.
		taskEncode = self.taskEncoder(taskID)
		x = x5 + taskEncode
		x = self.up1(x, x4)
		x = self.up2(x, x3)
		x = self.up3(x, x2)
		x = self.up4(x, x1)
		out = self.outc(x)
		out = self.avg_window(out)
		return out
