import torch
import torch.nn as nn

from datasets.cater.enums import MAX_FRAMES

from math import ceil, floor

import logging

# NOTE(review): basicConfig runs at import time and configures the *root*
# logger for the whole process at INFO level (it is a no-op if the root logger
# already has handlers). Kept as-is because importers may rely on it.
logging.basicConfig(level=logging.INFO)
# Per-module logger used by TemporalQuantization below.
logger = logging.getLogger(__name__)

class TemporalQuantization(nn.Module):
	"""Pool the last (temporal) axis of a 4-D tensor into roughly ``time_dim_quant`` bins.

	When ``time_dim > time_dim_quant`` the temporal axis is reduced with a depthwise
	``nn.Conv1d`` (one learnable kernel per event channel, ``groups=num_atomic_events``)
	whose kernel size and stride come from :meth:`bin_size`; the conv output is then
	scaled by the learnable scalar ``conv_scalar``. Otherwise the temporal axis is left
	at its input length. Optionally every temporal position is multiplied by its index.

	A strided conv without padding produces ``(L - kernel) // stride + 1`` outputs, so
	the requested ``time_dim_quant`` is not always reachable. The length actually
	produced is recorded in ``time_dim_quant_feasible`` on the first forward pass.

	Args:
		num_atomic_events: number of event channels (third input axis); used as the
			conv's in/out channel count and group count.
		time_dim: nominal input temporal length, used to size the conv.
		time_dim_quant: requested output temporal length.
		time_mult: if True, multiply the output by ``arange(T_out)`` along the last
			axis (this zeroes the first temporal position).
		fill: if not None, constant used to initialise the conv weights (and the bias
			is zeroed); otherwise PyTorch's default initialisation is kept.
		bias: whether the conv has a bias term.
	"""

	def __init__(self, num_atomic_events, time_dim=MAX_FRAMES, time_dim_quant=4, time_mult=True, fill=None, bias=False):
		super().__init__()
		self.time_dim = time_dim
		# Requested length of the output temporal axis.
		self.time_dim_quant = time_dim_quant

		# Only downsample when the input is longer than the requested output.
		self.quantize = self.time_dim > time_dim_quant

		# The conv usually cannot hit time_dim_quant exactly for a given input length;
		# the length it actually produces is filled in lazily by forward().
		self.time_dim_quant_feasible = None

		self.time_mult = time_mult
		self.num_atomic_events = num_atomic_events

		if self.quantize:
			# Depthwise conv: each event channel gets its own kernel (a learnable soft
			# sum over its window), i.e.
			# conv1d.weight.shape == [num_atomic_events, 1 (in_channels / groups), kernel_size].
			kernel_size, stride = self.bin_size(time_dim, time_dim_quant)
			logger.info(f'Conv Kernel Size: {kernel_size} Stride: {stride}')
			self.conv1d = nn.Conv1d(
				num_atomic_events,
				num_atomic_events,
				kernel_size,
				stride=stride,
				groups=num_atomic_events,
				bias=bias,
			)

			# Optional constant initialisation of the kernels.
			if fill is not None:
				nn.init.constant_(self.conv1d.weight, fill)
				if bias:
					nn.init.zeros_(self.conv1d.bias)

		# Learnable output scale. It is registered unconditionally (state_dict keys do
		# not depend on configuration) but only applied when quantize is True.
		self.conv_scalar = nn.Parameter(torch.tensor([1.0]))

	def bin_size(self, time_dim, time_dim_quant):
		"""Return ``(kernel_size, stride)`` that split ``time_dim`` into ~``time_dim_quant`` bins.

		The kernel is the ceiling and the stride the floor of ``time_dim / time_dim_quant``,
		so adjacent windows overlap by one element when the division is inexact.

		Raises:
			ValueError: if ``time_dim_quant`` is not positive.
		"""
		if time_dim_quant <= 0:
			raise ValueError(f'time_dim_quant must be positive, got {time_dim_quant}')
		ratio = time_dim / time_dim_quant
		return ceil(ratio), floor(ratio)

	def forward(self, ts):
		"""Quantize the temporal (last) axis of ``ts``.

		Args:
			ts: 4-D tensor ``(bs, n_ts, n_events, dim_ts)``. When quantizing,
				``n_events`` must equal ``num_atomic_events`` (the conv's channel count).

		Returns:
			Tensor ``(bs, n_ts, n_events, T_out)``, where ``T_out`` is the conv output
			length when quantizing and ``dim_ts`` otherwise.
		"""
		bs, n_ts, n_events, dim_ts = ts.shape

		if self.quantize:
			# Conv1d expects (N, C, L): fold the batch and time-series axes together.
			flat = ts.reshape(bs * n_ts, n_events, dim_ts)
			flat = self.conv1d(flat) * self.conv_scalar
			ts = flat.reshape(bs, n_ts, n_events, flat.shape[-1])

		if self.time_mult:
			# Weight each temporal position by its index. Broadcasting over the leading
			# axes avoids building (and copying to the device) a full index tensor
			# every call, and works for any number of event channels.
			ts = ts * torch.arange(ts.shape[-1], device=ts.device)

		# Record (once) the temporal length actually produced and report whether it
		# matches the requested one.
		if self.time_dim_quant_feasible is None:
			self.time_dim_quant_feasible = ts.shape[-1]
			if self.time_dim_quant_feasible != self.time_dim_quant:
				logger.warning(
					f'time dim quant {self.time_dim_quant} is not feasible given time dim {self.time_dim}, '
					f'actual quant dim used is {self.time_dim_quant_feasible}'
				)
			else:
				logger.info(f'feasible time quant dim {self.time_dim_quant_feasible}')
		return ts