import json
import os.path
from typing import Optional, Tuple, List
import torch
from torch.nn import Module, Sequential, Conv2d, ConvTranspose2d, BatchNorm2d, ReLU
from torch.utils import tensorboard

from Utils import logger, config as cfg
from Utils.Constants import FileNamesConstants


class FeatureMapAutoEncoder(Module):
    """
    Convolutional auto encoder over feature maps.

    The encoder is a stack of Conv2d -> BatchNorm2d -> ReLU stages. The decoder walks the same layer
    descriptions in reverse order with ConvTranspose2d -> BatchNorm2d -> ReLU stages and ends with a 1x1
    ConvTranspose2d that maps back to the number of input channels.
    """

    def __init__(self, layers_shapes: List[Tuple[Tuple[int, int], int]], tb_log_path: Optional[str]):
        """
        Feature map auto encoder model
        :param layers_shapes: list of convolution layer shapes, each entry is a tuple where the first entry is the
        kernel size and the second is the number of channels (filters),
        e.g., [((3,3), 8)] is a single layer AE with a convolution of size 3,3 and 8 filters
        :param tb_log_path: path for tensorboard log. When not None, the model graph is written under
        ``<tb_log_path>/logs`` and the constructor arguments are dumped as JSON into
        ``<tb_log_path>/<FileNamesConstants.MODEL_HYPER_PARAMS>``. When None, nothing is written and
        ``tb_writer()`` returns None.
        """
        super().__init__()
        self._data_type_is_double = True
        # The first entry is used as the channel count of the first convolution; the full shape is only
        # needed to build the example input for the tensorboard graph.
        input_shape = (cfg.stats_used_in_feature_map, cfg.feature_map_layers_size, cfg.feature_map_values_size)
        input_channels = input_shape[0]

        # Encoder is created before the decoder so that submodule registration order stays encoder, decoder.
        self._encoder, latent_channels = self._build_encoder(input_channels, layers_shapes)
        self._decoder = self._build_decoder(latent_channels, input_channels, layers_shapes)

        if self._data_type_is_double:
            self.double()
        self.to(cfg.device)
        logger().log('FeatureMapAutoEncoder::__init__',
                     'Auto encoder is double=', self._data_type_is_double, ' and device in config is=', cfg.device,
                     ' model device is cuda=', next(self.parameters()).is_cuda)

        if tb_log_path is not None:
            self._tb_writer = tensorboard.SummaryWriter(os.path.join(tb_log_path, 'logs'))
            # The example input follows the model's own dtype/device instead of a hard-coded double, so the
            # graph trace keeps working if the precision flag above is ever changed.
            self._tb_writer.add_graph(model=self, input_to_model=self._example_input(input_shape))
            with open(os.path.join(tb_log_path, FileNamesConstants.MODEL_HYPER_PARAMS), 'w') as jf:
                json.dump(dict(layers_shapes=layers_shapes, tb_log_path=tb_log_path), jf)
        else:
            self._tb_writer = None

    @staticmethod
    def _build_encoder(in_channels: int,
                       layers_shapes: List[Tuple[Tuple[int, int], int]]) -> Tuple[Sequential, int]:
        """
        Build the encoder: one Conv2d -> BatchNorm2d -> ReLU stage per entry of layers_shapes
        :param in_channels: number of channels of the encoder input
        :param layers_shapes: (kernel size, number of filters) per layer, in encoder order
        :return: the encoder and the number of channels it outputs (``in_channels`` when there are no layers)
        """
        stages = []
        for layer in layers_shapes:
            kernel_size, out_channels = layer[0], layer[1]
            stages.append(Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size))
            stages.append(BatchNorm2d(out_channels))
            stages.append(ReLU())
            in_channels = out_channels
        return Sequential(*stages), in_channels

    @staticmethod
    def _build_decoder(in_channels: int, out_channels: int,
                       layers_shapes: List[Tuple[Tuple[int, int], int]]) -> Sequential:
        """
        Build the decoder: one ConvTranspose2d -> BatchNorm2d -> ReLU stage per entry of layers_shapes, taken in
        reverse order, followed by a 1x1 ConvTranspose2d that restores the original number of channels
        :param in_channels: number of channels produced by the encoder
        :param out_channels: number of channels the decoder must output
        :param layers_shapes: (kernel size, number of filters) per layer, in encoder order
        :return: the decoder
        """
        stages = []
        for layer in reversed(layers_shapes):
            kernel_size, stage_channels = layer[0], layer[1]
            stages.append(ConvTranspose2d(in_channels=in_channels, out_channels=stage_channels,
                                          kernel_size=kernel_size))
            stages.append(BatchNorm2d(stage_channels))
            stages.append(ReLU())
            in_channels = stage_channels
        stages.append(ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, 1)))
        return Sequential(*stages)

    def _example_input(self, input_shape: Tuple[int, ...], batch_size: int = 2) -> torch.Tensor:
        """
        Create a random batch matching the dtype and device of the model parameters
        :param input_shape: shape of a single sample (without the batch dimension)
        :param batch_size: number of samples in the batch
        :return: random tensor of shape (batch_size, *input_shape)
        """
        reference = next(self.parameters())
        return torch.rand(batch_size, *input_shape, dtype=reference.dtype, device=reference.device)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """
        Encode the input and decode it back
        :param inp: batch of feature maps
        :return: the decoder output for the encoded input
        """
        x = self._encoder(inp)
        out = self._decoder(x)
        return out

    def tb_writer(self):
        """
        :return: the tensorboard SummaryWriter created in the constructor, or None when no log path was given
        """
        return self._tb_writer

    def encoder(self) -> Sequential:
        """
        :return: the encoder part of the model
        """
        return self._encoder

    def is_double(self) -> bool:
        """
        :return: True when the model was converted to double precision in the constructor
        """
        return self._data_type_is_double
