from .utils_UNET import *

import math
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np




class AbstractSelector(nn.Module):
  """Base class for selection networks.

  Stores the sizes the selector maps between. The selectors in this module treat
  them as per-sample shapes (they flatten from dimension 1 and size their layers
  with ``prod(size)``). When ``output_size`` is ``None`` it defaults to
  ``input_size``.

  Subclasses provide the computation by implementing ``forward`` (the standard
  ``nn.Module`` entry point, so hooks run), or by overriding ``__call__``
  directly as the selectors in this module do.
  """

  def __init__(self, input_size=(1, 28, 28), output_size=(1, 28, 28)):
    super().__init__()

    self.input_size = input_size
    # Fall back to the input size when no explicit output size is given.
    self.output_size = input_size if output_size is None else output_size

  def forward(self, x):
    # Defining ``forward`` (rather than overriding ``__call__``) keeps
    # nn.Module.__call__ intact, so subclasses may implement ``forward``.
    raise NotImplementedError



def init_weights(m):
    """Initialise ``nn.Linear`` modules in place; leave any other module untouched.

    Meant to be used with ``Module.apply``. The weight gets Xavier/Glorot uniform
    initialisation and the bias, when the layer has one, is filled with 0.01.
    """
    if isinstance(m, nn.Linear):
        # The trailing-underscore function is the in-place API; the
        # un-suffixed ``xavier_uniform`` is deprecated and warns on each call.
        torch.nn.init.xavier_uniform_(m.weight)
        # ``nn.Linear(..., bias=False)`` stores ``bias`` as None.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.01)

class SelectorRealX(AbstractSelector):
  """Fully-connected selector with two ReLU hidden layers.

  Each sample is flattened and passed through ``fc1 -> ReLU -> fc2 -> ReLU -> pi``.
  ``pi`` has ``prod(output_size)`` units and no activation. All three linear
  layers are (re)initialised with ``init_weights``.

  Args:
    input_size: per-sample input shape; its product is the width of ``fc1``'s input.
    output_size: per-sample output shape; its product is the width of ``pi``.
    middle_size: width of the two hidden layers (default 100, the former fixed value).
  """

  def __init__(self, input_size=(1, 28, 28), output_size=(1, 28, 28), middle_size=100):
    super().__init__(input_size=input_size, output_size=output_size)
    # int() so layer sizes are plain Python ints instead of numpy scalars.
    in_features = int(np.prod(self.input_size))
    out_features = int(np.prod(self.output_size))
    self.fc1 = nn.Linear(in_features, middle_size)
    self.fc2 = nn.Linear(middle_size, middle_size)
    self.pi = nn.Linear(middle_size, out_features)

    # Same order as layer creation, so RNG consumption is unchanged.
    for layer in (self.fc1, self.fc2, self.pi):
      layer.apply(init_weights)

  # NOTE: overrides __call__ directly, so nn.Module forward hooks are not run.
  def __call__(self, x):
    x = x.flatten(1)  # (batch_size, prod of the remaining dimensions)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    return self.pi(x)

  
class RealXSelector(AbstractSelector):
  """Fully-connected selector with a configurable hidden width.

  Computes ``pi(relu(fc2(relu(fc1(flatten(x))))))``, where both hidden layers
  have ``middle_size`` units and ``pi`` has ``prod(output_size)`` units.
  The layers keep PyTorch's default ``nn.Linear`` initialisation.
  """

  def __init__(self, input_size, output_size, middle_size=100):
    super().__init__(input_size=input_size, output_size=output_size)

    width = middle_size
    self.fc1 = nn.Linear(np.prod(self.input_size), width)
    self.fc2 = nn.Linear(width, width)
    self.pi = nn.Linear(width, np.prod(self.output_size))

  def __call__(self, x):
    hidden = x.flatten(1)
    # Two hidden layers, each followed by a ReLU.
    for hidden_layer in (self.fc1, self.fc2):
      hidden = F.relu(hidden_layer(hidden))
    return self.pi(hidden)



def calculate_blocks_patch(input_size, kernel_size, kernel_stride):
  """
  Return the shape ``(1, n_1, ..., n_k)`` of the grid of patches over the input.

  ``input_size[0]`` is skipped. For each ``d`` in ``range(len(kernel_size))``,
  ``n_{d+1}`` is the number of positions of an unpadded window of size
  ``kernel_size[d]`` moved with step ``kernel_stride[d]`` along
  ``input_size[d + 1]``:
  ``floor((input_size[d + 1] - kernel_size[d]) / kernel_stride[d]) + 1``.
  The leading entry of the result is always 1.
  """
  spatial = [
    math.floor((input_size[d + 1] - (kernel_size[d] - 1) - 1) / kernel_stride[d]) + 1
    for d in range(len(kernel_size))
  ]
  return (1, *spatial)


class SelectorUNET(AbstractSelector):
  """Convolutional selector: a patch-extraction convolution followed by ``UNet``.

  ``getconfiguration`` applies an unpadded convolution with ``kernel_size`` /
  ``kernel_stride`` (one output location per patch), then a 3x3 conv,
  BatchNorm and ReLU, all with ``2**log2_min_channel`` channels. The result goes
  through ``UNet`` (built with ``n_classes=1``) and is reshaped to
  ``(batch_size, prod(output_size))``.

  Args:
    input_size: per-sample input shape; index 0 is the conv input channel count.
    output_size: must equal
      ``calculate_blocks_patch(input_size, kernel_size, kernel_stride)``.
    kernel_size, kernel_stride: window size and step of the first convolution.
    bilinear: forwarded to ``UNet``; also sets ``factor`` (2 if True, else 1).
    log2_min_channel: log2 of the channel width of ``getconfiguration``.
  """

  def __init__(self, input_size=(1, 28, 28), output_size=(1, 28, 28), kernel_size=(1, 1),
               kernel_stride=(1, 1), bilinear=True, log2_min_channel=6):
    aux_output_size = calculate_blocks_patch(input_size, kernel_size, kernel_stride)
    assert aux_output_size == output_size, (
      "Output size of the selector must be the same as the output size of the unet: "
      f"expected {aux_output_size}, got {output_size}."
    )
    # The base class stores input_size / output_size on self.
    super().__init__(input_size=input_size, output_size=output_size)
    self.channels = self.input_size[0]
    self.out_channel = self.output_size[0]
    self.w = self.input_size[1]
    self.h = self.input_size[2]
    self.bilinear = bilinear
    self.log2_min_channel = log2_min_channel
    self.factor = 2 if self.bilinear else 1

    # floor(log2(smallest spatial output dim) / 2); passed to UNet as nb_block
    # (presumably its number of down/up stages). math.log2 is exact for powers
    # of two, unlike math.log(x, 2), so the floor division cannot drop a block.
    self.nb_block = int(math.log2(min(self.output_size[1], self.output_size[2])) // 2)
    # NOTE(review): true division gives floats here; not used in this file.
    self.dim_latent = (
      2 ** (self.log2_min_channel + self.nb_block) / self.factor,
      self.w / (2 ** self.nb_block),
      self.h / (2 ** self.nb_block),
    )

    min_channels = 2 ** self.log2_min_channel
    self.getconfiguration = nn.Sequential(
      # Unpadded strided conv: its spatial output equals output_size[1:].
      nn.Conv2d(self.channels, min_channels, kernel_size=kernel_size, stride=kernel_stride),
      nn.ReLU(inplace=False),
      nn.Conv2d(min_channels, min_channels, kernel_size=3, padding=1),
      nn.BatchNorm2d(min_channels),
      nn.ReLU(inplace=False),
    )

    self.UNET = UNet(n_classes=1, bilinear=self.bilinear, nb_block=self.nb_block,
                     log2_min_channel=self.log2_min_channel)

  # NOTE: overrides __call__ directly, so nn.Module forward hooks are not run.
  def __call__(self, x):
    batch_size = x.shape[0]
    x = self.getconfiguration(x)
    x = self.UNET(x)
    # reshape (unlike view) also accepts a non-contiguous tensor; int() turns the
    # numpy scalar from np.prod into a plain Python int.
    return x.reshape(batch_size, int(np.prod(self.output_size)))

  def encode(self, x):
    """Apply ``getconfiguration`` then ``UNET.encode``."""
    x = self.getconfiguration(x)
    x = self.UNET.encode(x)
    return x

  def decode(self, x):
    """Delegate to ``UNET.decode``."""
    x = self.UNET.decode(x)
    return x
    

# Registry mapping selector names to selector classes (used by get_selection_network).
selection_network_list = dict(
  SelectorREALX=SelectorRealX,
  SelectorUNET=SelectorUNET,
)

def get_selection_network(selector_name):
  """Look up a selector class by its registry name.

  Returns ``None`` when ``selector_name`` is ``None`` or ``"none"``; otherwise
  returns the class registered under that name in ``selection_network_list``.

  Raises:
    NotImplementedError: if ``selector_name`` is not registered.
  """
  no_selector = selector_name is None or selector_name == "none"
  if no_selector:
    return None
  if selector_name not in selection_network_list:
    raise NotImplementedError(f"This selector {selector_name} is not implemented")
  return selection_network_list[selector_name]