from functools import partial
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import transforms
import torchvision.transforms.functional as TF

from conf.dataset import CelebAParams


class toTensorNoNorm(nn.Module):
    """
    Build a ``torch.Tensor`` from the input with the plain ``torch.Tensor``
    constructor: no rescaling and no normalization is applied.

    The result uses torch's default floating dtype (float32 unless changed).
    """

    def forward(self, tensor):
        converted = torch.Tensor(tensor)
        return converted


class toType(nn.Module):
    """Cast every input tensor with ``Tensor.to(target_type)``."""

    def __init__(self, target_type):
        super().__init__()
        # Forwarded unchanged to Tensor.to on every call (e.g. a torch dtype).
        self.target_type = target_type

    def forward(self, x):
        target = self.target_type
        return x.to(target)


class Permute(nn.Module):
    """Reorder tensor dimensions with a fixed permutation (``Tensor.permute``)."""

    def __init__(self, permutation):
        super().__init__()
        self.permutation = permutation

    def forward(self, x):
        dims = self.permutation
        return x.permute(dims)


class Reshape(nn.Module):
    """Reshape every input tensor to a fixed shape (``Tensor.reshape``)."""

    def __init__(self, new_shape):
        super().__init__()
        self.new_shape = new_shape

    def forward(self, img):
        target_shape = self.new_shape
        return img.reshape(target_shape)


class FuseOnChannel(nn.Module):
    """
    Concatenate a sequence of tensors along dim 0.

    With ``pass_dim=True`` the module also returns a metadata dict holding,
    for each input in order, its size along dim 0 (``'fused_dims'``) and its
    ``Tensor.type()`` string (``'fused_type'``), so the concatenation can be
    undone later.
    """

    def __init__(self, pass_dim: bool = False):
        super().__init__()
        self.pass_dim = pass_dim

    def forward(self, concat_domains):
        fused = torch.cat(concat_domains, dim=0)
        if not self.pass_dim:
            return fused

        info = {
            'fused_dims': [domain.shape[0] for domain in concat_domains],
            'fused_type': [domain.type() for domain in concat_domains],
        }
        return fused, info


class WrapFusedTransform(nn.Module):
    """
    Apply ``operation`` to the first element of a ``(sample, metadata)``
    pair and pass the metadata through untouched.
    """

    def __init__(self, operation):
        super().__init__()
        self.operation = operation

    def forward(self, c):
        payload, metadata = c
        transformed = self.operation(payload)
        return transformed, metadata


class DefuseOnChannel(nn.Module):
    """
    Split a tensor concatenated along dim 0 back into its parts and restore
    each part's tensor type; the inverse of ``FuseOnChannel``.

    Two modes:

    * static -- ``channels`` is given: ``forward`` receives the fused tensor
      and cuts it into consecutive slices of ``channels[k]`` rows along dim 0.
      If ``types`` is also given, slice ``k`` is converted with
      ``Tensor.type(types[k])``; otherwise the slices are returned as-is.
    * dynamic -- ``channels`` is None: ``forward`` receives the
      ``(fused, info)`` pair produced by ``FuseOnChannel(pass_dim=True)`` and
      reads sizes/types from ``info['fused_dims']`` / ``info['fused_type']``.

    Args:
        channels: size along dim 0 of each part, in concatenation order.
        types: one target type per part (anything ``Tensor.type`` accepts).

    Raises:
        ValueError: from ``forward`` when the number of types differs from
            the number of parts (extra parts used to be dropped silently).
    """

    def __init__(self, channels: List[int] = None, types: List = None):
        super(DefuseOnChannel, self).__init__()
        self.C = None         # slice boundaries [0, c0, c0+c1, ...] (static mode)
        self.C_types = types  # per-part target types; None keeps the dtypes
        if channels is not None:
            self.C = self._offsets(channels)

    @staticmethod
    def _offsets(sizes):
        """Return the cumulative slice boundaries ``[0, s0, s0+s1, ...]``."""
        offsets = [0]
        for size in sizes:
            offsets.append(offsets[-1] + size)
        return offsets

    @staticmethod
    def _split(tensor, offsets, types):
        """Slice ``tensor`` along dim 0 at ``offsets``; retype parts if ``types`` is given."""
        parts = [tensor[start:stop] for start, stop in zip(offsets, offsets[1:])]
        if types is None:
            return parts
        if len(types) != len(parts):
            raise ValueError(
                f"DefuseOnChannel: got {len(types)} types for {len(parts)} parts"
            )
        return [part.type(t) for part, t in zip(parts, types)]

    def forward(self, fused) -> List[torch.Tensor]:
        """Split ``fused`` (a tensor, or a ``(tensor, info)`` pair in dynamic mode)."""
        if self.C is not None:
            return self._split(fused, self.C, self.C_types)

        # Dynamic mode: the sizes and types travel alongside the fused tensor.
        sample, data = fused
        offsets = self._offsets(data['fused_dims'])
        return self._split(sample, offsets, data['fused_type'])


def get_image_transform(params: CelebAParams):
    """
    Build the RGB image pipeline: an optional first resize to
    ``(height0, width0)`` (only when both are set), a resize to
    ``(height, width)``, then ``ToTensor``.

    No normalization is applied; the commented-out
    ``Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`` step was left disabled.
    """
    steps = []
    if params.width0 is not None and params.height0 is not None:
        steps.append(transforms.Resize([params.height0, params.width0]))
    steps += [
        transforms.Resize([params.height, params.width]),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)


def get_sketch_transform(params: CelebAParams):
    """
    Build the sketch pipeline: an optional first resize to
    ``(height0, width0)`` (only when both are set), a resize to
    ``(height, width)``, ``ToTensor``, then a 1-channel ``Grayscale``.

    Per the original author, sketches are stored as 3x1024x1024 with three
    identical channels, so collapsing to one channel loses nothing.
    No normalization is applied (a ``Normalize((0.5,), (0.5,))`` step was
    left disabled).
    """
    steps = []
    if params.width0 is not None and params.height0 is not None:
        steps.append(transforms.Resize([params.height0, params.width0]))
    steps += [
        transforms.Resize([params.height, params.width]),
        transforms.ToTensor(),
        transforms.Grayscale(num_output_channels=1),
    ]
    return transforms.Compose(steps)


def get_mask_transform(params: CelebAParams):
    """
    Build the segmentation-mask pipeline.

    Steps, in order:

    1. convert to a tensor without normalization and cast to int64 (needed by
       ``F.one_hot``);
    2. optionally resize to ``(height0, width0)`` (only when both are set);
    3. resize to ``(height, width)``. Both resizes use NEAREST interpolation,
       so no new label values are invented;
    4. reshape to ``(height, width)``, one-hot encode with
       ``number_class_before_fusion`` classes, and move the class axis first;
    5. if ``segmentation_fusion`` is set, merge class channels with
       ``CollapseMask``.
    """
    nearest = TF.InterpolationMode.NEAREST
    steps = [
        toTensorNoNorm(),
        toType(target_type=torch.int64),
    ]
    if params.width0 is not None and params.height0 is not None:
        steps.append(transforms.Resize([params.height0, params.width0], nearest))
    steps.extend([
        transforms.Resize([params.height, params.width], nearest),
        Reshape([params.height, params.width]),
        partial(F.one_hot, num_classes=params.number_class_before_fusion),
        Permute([2, 0, 1]),
    ])
    if params.segmentation_fusion:
        steps.append(CollapseMask())
    return transforms.Compose(steps)


def get_join_transform(params: CelebAParams):
    """
    Build a transform applied jointly to a list of tensors.

    When ``params.random_flip`` is set, the tensors are concatenated along
    dim 0 and horizontally flipped together by one ``RandomHorizontalFlip``,
    so all of them are flipped (or not) as a unit. They are then split back
    into the original pieces and dtypes. Otherwise the returned ``Compose``
    is empty.
    """
    if not params.random_flip:
        return transforms.Compose([])
    return transforms.Compose([
        FuseOnChannel(pass_dim=True),
        WrapFusedTransform(transforms.RandomHorizontalFlip()),
        DefuseOnChannel(),
    ])


# Raw segmentation class names. The position of each name in this list is the
# class id, i.e. the channel index in the one-hot mask (``F.one_hot`` in
# get_mask_transform, ``mask[i]`` in collapse_mask).
attr_list = [
    'background',          # 0 (added by hand)
    'skin',                # 1
    'l_brow', 'r_brow',    # 2, 3
    'l_eye', 'r_eye',      # 4, 5
    'eye_g',               # 6
    'l_ear', 'r_ear',      # 7, 8
    'ear_r',               # 9
    'nose',                # 10
    'mouth',               # 11
    'u_lip', 'l_lip',      # 12, 13
    'neck',                # 14
    'neck_l',              # 15
    'cloth', 'hair', 'hat',  # 16, 17, 18
]

# Groups of attr_list names that are merged into a single channel when
# segmentation fusion is enabled (see collapse_mask / CollapseMask).
attr_collapse = [
    ('background',),
    ('skin',),
    ('l_brow', 'r_brow'),
    ('l_eye', 'r_eye'),
    # presumably clothing/accessories (CelebAMask-HQ style names) -- confirm
    ('cloth', 'hat', 'ear_r', 'eye_g', 'neck_l'),
    ('l_ear', 'r_ear'),
    ('nose',),
    ('u_lip', 'l_lip', 'mouth'),
    ('neck',),
    ('hair',),
]


def __get_attr_index(attr):
    """Return the class id (position in ``attr_list``) of *attr*; ValueError if unknown."""
    position = attr_list.index(attr)
    return position


def __get_attr_indexes(attrs):
    """Return the set of class ids (positions in ``attr_list``) for the names in *attrs*."""
    return {__get_attr_index(name) for name in attrs}


# Maps the lowest attr_list index of each attr_collapse group to the set of
# attr_list indexes merged into that group. Keys are inserted in ascending
# order (attr_list is walked front to back), and collapse_mask enumerates this
# dict in insertion order to assign output channels.
collapse_indexes = dict()
# region create collapse indexes
# NOTE(review): seen / a_i / a / attr_collapse_tuple leak into the module
# namespace as a side effect of building the table at import time.
seen = set()  # attribute names already assigned to some group
for a_i, a in enumerate(attr_list):
    if a in seen:
        # already covered by the group of an earlier (lower-index) attribute
        continue
    seen.add(a)

    for attr_collapse_tuple in attr_collapse:
        if a in attr_collapse_tuple:
            collapse_indexes[a_i] = __get_attr_indexes(attr_collapse_tuple)
            seen |= set(attr_collapse_tuple)
# NOTE(review): an attribute listed in no attr_collapse tuple gets no output
# channel (it is silently dropped); every current attr_list entry is covered.
new_nb_chan = len(collapse_indexes)  # number of channels after collapsing
# endregion


def collapse_mask(mask: torch.Tensor, collapse_map=None) -> torch.Tensor:
    """
    Merge the class channels of a mask into grouped channels.

    Output channel ``k`` is the element-wise sum of the input channels listed
    in the ``k``-th group of ``collapse_map`` (dict iteration order), clamped
    to ``[0, 1]``. For a one-hot mask this is the union of the group's classes.

    Args:
        mask: tensor indexed by class along dim 0, e.g. ``(num_classes, H, W)``
            as produced by ``F.one_hot`` followed by ``Permute([2, 0, 1])``.
            It must have a channel for every index referenced by the groups.
        collapse_map: optional ``{key: iterable_of_channel_indexes}`` mapping;
            only its values and their order are used. Defaults to the
            module-level ``collapse_indexes``.

    Returns:
        A new tensor of shape ``(len(collapse_map), *mask.shape[1:])`` with the
        same dtype and device as ``mask``.
    """
    if collapse_map is None:
        collapse_map = collapse_indexes
    groups = list(collapse_map.values())

    # Allocate exactly one output channel per group instead of a full-size
    # copy of ``mask`` that is sliced afterwards. This also works when the
    # mask has fewer channels than there are groups, and the result does not
    # keep the oversized buffer alive.
    new_mask = mask.new_zeros((len(groups), *mask.shape[1:]))
    for i_new_map, index_list in enumerate(groups):
        for i in index_list:
            new_mask[i_new_map] += mask[i]

    # One clamp after all groups are filled gives the same result as clamping
    # the whole tensor after every group, without the repeated full passes.
    return torch.clamp(new_mask, 0, 1)


class CollapseMask(nn.Module):
    """
    Module wrapper around ``collapse_mask``: merges the class channels of a
    mask into the groups defined by ``collapse_indexes``.
    """

    @staticmethod
    def getCollapsedName(index_in_map: int) -> List[str]:
        """
        Return the attribute names merged into collapsed channel ``index_in_map``.

        Channels are ordered by their group's lowest class id, which matches
        the order ``collapse_mask`` assigns output channels in. Names are
        returned in ``attr_list`` order. They used to come out in set
        iteration order, which is not ascending.

        Raises:
            IndexError: if ``index_in_map`` is not a valid channel index.
        """
        group_keys = sorted(collapse_indexes)
        collapse_set = collapse_indexes[group_keys[index_in_map]]
        return [attr_list[i] for i in sorted(collapse_set)]

    def forward(self, mask):
        return collapse_mask(mask)
