from functools import reduce
from threading import Lock

import torch
from torch.nn import Module
import torch.nn.functional as F

from coach.coding import SensorDecoder, SensorEncoder
from coach.traverse import Traverse

'''
Intended usage:

class MyModel(Module):
    def __init__(self):
        super().__init__()
        self.sensoriplexer = Sensoriplexer()
        self.sensoriplexer.add('image', (48, 48, 1))
        self.sensoriplexer.add('sound', (221, 1, 2))
        ...

    def forward(self, x, y):
        # Pass each input as a keyword argument named after its registered signal;
        # each tensor is expected to have shape (batch, *registered_shape).
        # Registered signals that are omitted are fed as zeros.
        signals = self.sensoriplexer(image=x, sound=y)
        xbar = signals['image']
        ...
'''

class DuplicateError(Exception):
    """Raised by `Sensoriplexer.add` when the requested signal name is already registered."""

    def __init__(self, key: str):
        message = f"{key} already exists. Please choose another name or `remove` it first."
        super().__init__(message)


class Sensoriplexer(Module):
    """Multiplex several named sensor signals through a shared code space.

    Each registered signal gets its own ``SensorEncoder``/``SensorDecoder``
    pair. On ``forward`` every signal is encoded, the codes are concatenated
    along dim 1 and passed through a shared ``Traverse`` module, and the
    result is split back per signal (in registration order) and decoded.

    ``add``, ``remove`` and ``forward`` are all serialised through one
    ``threading.Lock`` (non-reentrant).
    """

    def __init__(self, code_size:int=512, dtype=torch.float32, device=torch.device('cpu'), with_residuals:bool=False):
        super().__init__()

        # Registered signals in registration order: key -> {'shape', 'enc', 'dec'}.
        self.signals = {}
        self.encoders = torch.nn.ModuleDict({})
        self.decoders = torch.nn.ModuleDict({})
        # Shared module applied to the concatenated codes. It is sized on the
        # number of encoders, so it is rebuilt whenever that number changes.
        self.traverse = None
        # Upper bound on the flattened size accepted by `_manageable_shape`.
        self._upper_flat_shape = int(1e4)

        self.with_residuals = with_residuals
        self.signal_code_size = code_size
        self.dtype = dtype
        self.device = device

        self.mod_lock = Lock()

    def add(self, key:str, shape:tuple) -> None:
        """Register a new signal named ``key`` whose per-sample shape is ``shape``.

        Raises:
            DuplicateError: if ``key`` is already registered.
        """
        with self.mod_lock:
            if key in self.signals:
                raise DuplicateError(key)
            self.signals[key] = self.__plug(key, shape)

    def _manageable_shape(self, shape:tuple) -> bool:
        """Return True if the flattened size of ``shape`` is within ``_upper_flat_shape``."""
        # Initial value 1 makes an empty (scalar) shape valid instead of raising TypeError.
        flat_size = reduce(lambda x, y: x * y, shape, 1)
        return flat_size <= self._upper_flat_shape

    def remove(self, key:str) -> None:
        """Unregister signal ``key`` and rebuild the shared traverse module.

        Raises:
            KeyError: if ``key`` is not registered.
        """
        with self.mod_lock:
            del self.signals[key]
            del self.encoders[key]
            del self.decoders[key]
            # Traverse is sized on the number of encoders; keep it consistent with
            # the remaining signals, otherwise `forward` would chunk incorrectly.
            self.__update_traverse()

    def __update_traverse(self) -> None:
        """(Re)build the shared traverse module for the current set of encoders.

        This creates a fresh module, so any previously trained traverse weights
        are discarded. With no encoders left, the traverse is dropped.
        """
        if len(self.encoders) == 0:
            self.traverse = None
        else:
            self.traverse = Traverse(self.signal_code_size, len(self.encoders), self.dtype, device=self.device)

    def __plug(self, key:str, shape:tuple) -> dict:
        """Create and register the encoder/decoder pair for ``key``; return its signal record."""
        enc, dec = self.__coding(shape)

        self.encoders.update({key: enc})
        self.decoders.update({key: dec})

        self.__update_traverse()

        return {'shape': shape, 'enc': enc, 'dec': dec}

    def __coding(self, shape:tuple) -> tuple:
        """Build a matching (encoder, decoder) pair for a signal of ``shape``."""
        enc = SensorEncoder(shape, self.signal_code_size, with_residuals=self.with_residuals)
        dec = SensorDecoder(shape, self.signal_code_size, enc.unflatten_output_size, with_residuals=self.with_residuals)
        return (enc, dec)

    def __zero_input(self, batch_size:int, key:str) -> torch.Tensor:
        """Return an all-zero input of shape ``(batch_size, *shape)`` for signal ``key``."""
        # Unpacking (rather than tuple concatenation) also accepts shapes registered as lists.
        return torch.zeros((batch_size, *self.signals[key]['shape']), dtype=self.dtype, device=self.device)

    def forward(self, kwargs=None, **signals):
        """Encode, traverse and decode every registered signal.

        Inputs may be given as keyword arguments (``sp(image=x, sound=y)``) or,
        for backward compatibility, as a single dict (``sp({'image': x})``).
        Both forms may be combined; keyword arguments take precedence. A
        signal literally named ``kwargs`` can only be passed via the dict form.

        Registered signals missing from the input are replaced by zeros. The
        batch size is taken from the first provided input. Input keys that
        are not registered are ignored.

        Returns:
            dict mapping every registered signal name to its decoded tensor.

        Raises:
            Exception: if no input signal is provided.
            RuntimeError: if no signal has been registered yet.
        """
        inputs = dict(kwargs) if kwargs is not None else {}
        inputs.update(signals)

        with self.mod_lock:
            if not inputs:
                raise Exception(f'Please specify at least one signal input among {", ".join(self.signals.keys())}')
            if not self.signals:
                raise RuntimeError('No signal registered; call `add` before `forward`.')

            batch_size = next(iter(inputs.values())).shape[0]
            keys = list(self.signals.keys())

            # Each encoder returns a (code, indices) pair; indices are handed back to
            # the matching decoder.
            encoded = [
                self.encoders[k](inputs[k] if k in inputs else self.__zero_input(batch_size, k))
                for k in keys
            ]
            codes, indices = zip(*encoded)

            traversed = self.traverse(torch.cat(codes, dim=1))
            chunks = torch.chunk(traversed, len(keys), dim=1)

            return {
                k: self.decoders[k]((chunks[i], indices[i], (batch_size, *self.signals[k]['shape'])))[0]
                for i, k in enumerate(keys)
            }
