import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class SkeletonConv(nn.Module):
    """1-D temporal convolution whose weights are masked by skeleton topology.

    ``in_channels`` and ``out_channels`` are split evenly over ``joint_num`` joints.
    The output channels of joint ``i`` are connected only to the input channels of the
    joints listed in ``neighbour_list[i]``; every other weight is zeroed by a fixed mask.

    Args:
        neighbour_list: For each joint, the joint indices it may read from.
        in_channels: Total input channels (must be divisible by ``joint_num``).
        out_channels: Total output channels (must be divisible by ``joint_num``).
        kernel_size: Temporal kernel size.
        joint_num: Number of joints the channels are divided among.
        stride: Temporal stride.
        padding: Padding applied to both ends of the last dimension before convolving.
        bias: Whether to learn an additive bias.
        padding_mode: ``'zeros'`` (mapped to ``'constant'``), ``'reflection'`` (mapped
            to ``'reflect'``) or any other mode accepted by ``F.pad``.
        add_offset: If True, add ``offset_enc(offset) / 100`` to the output, where the
            offset is supplied through :meth:`set_offset`.
        in_offset_channel: Offset channels per joint, used when ``add_offset`` is True.

    Raises:
        Exception: If ``in_channels`` or ``out_channels`` is not divisible by ``joint_num``.
    """

    def __init__(self, neighbour_list, in_channels, out_channels, kernel_size, joint_num, stride=1, padding=0,
                 bias=True, padding_mode='zeros', add_offset=False, in_offset_channel=0):
        super(SkeletonConv, self).__init__()
        if in_channels % joint_num != 0 or out_channels % joint_num != 0:
            raise Exception('BAD')
        self.in_channels_per_joint = in_channels // joint_num
        self.out_channels_per_joint = out_channels // joint_num

        # Translate nn.Conv1d-style padding names into the ones F.pad understands.
        if padding_mode == 'zeros':
            padding_mode = 'constant'
        if padding_mode == 'reflection':
            padding_mode = 'reflect'

        self.expanded_neighbour_list = []
        self.expanded_neighbour_list_offset = []
        self.neighbour_list = neighbour_list
        self.add_offset = add_offset
        self.joint_num = joint_num
        self.offset = None  # populated by set_offset()

        self.stride = stride
        self.dilation = 1
        self.groups = 1
        self.padding = padding
        self.padding_mode = padding_mode
        self._padding_repeated_twice = (padding, padding)

        # Expand each joint's neighbour list into the flat input-channel indices it covers.
        for neighbour in neighbour_list:
            self.expanded_neighbour_list.append(
                [k * self.in_channels_per_joint + i for k in neighbour for i in range(self.in_channels_per_joint)])

        if self.add_offset:
            self.offset_enc = SkeletonLinear(neighbour_list, in_offset_channel * len(neighbour_list), out_channels)

            # Every offset channel of every neighbouring joint (range(in_offset_channel),
            # not range(add_offset), which is range(1) for a boolean flag).
            for neighbour in neighbour_list:
                self.expanded_neighbour_list_offset.append(
                    [k * in_offset_channel + i for k in neighbour for i in range(in_offset_channel)])

        self.weight = torch.zeros(out_channels, in_channels, kernel_size)
        if bias:
            self.bias = torch.zeros(out_channels)
        else:
            self.register_parameter('bias', None)

        # Fixed 0/1 mask: joint i's output rows only see its neighbours' input columns.
        mask = torch.zeros_like(self.weight)
        for i, neighbour in enumerate(self.expanded_neighbour_list):
            mask[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), neighbour, ...] = 1
        self.mask = nn.Parameter(mask, requires_grad=False)

        self.description = 'SkeletonConv(in_channels_per_armature={}, out_channels_per_armature={}, kernel_size={}, ' \
                           'joint_num={}, stride={}, padding={}, bias={})'.format(
                               in_channels // joint_num, out_channels // joint_num, kernel_size, joint_num, stride, padding, bias
                           )

        self.reset_parameters()

    def reset_parameters(self):
        """(Re-)initialise every joint's unmasked weight block and its bias.

        Weights get Kaiming-uniform init (``a=sqrt(5)``); each joint's bias is drawn from
        ``U(-1/sqrt(fan_in), 1/sqrt(fan_in))`` where ``fan_in`` is that joint's
        neighbourhood size times the kernel size. Safe to call again after construction:
        writes happen under ``torch.no_grad()`` and existing Parameters are updated in
        place instead of being re-wrapped.
        """
        with torch.no_grad():
            for i, neighbour in enumerate(self.expanded_neighbour_list):
                rows = slice(self.out_channels_per_joint * i, self.out_channels_per_joint * (i + 1))
                # Advanced indexing returns a copy, so initialise a temporary and write it back.
                tmp = torch.zeros_like(self.weight[rows, neighbour, ...])
                nn.init.kaiming_uniform_(tmp, a=math.sqrt(5))
                self.weight[rows, neighbour, ...] = tmp
                if self.bias is not None:
                    # tmp has the same shape as the joint's weight block, hence the same fan-in.
                    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(tmp)
                    bound = 1 / math.sqrt(fan_in)
                    bias_tmp = torch.zeros_like(self.bias[rows])
                    nn.init.uniform_(bias_tmp, -bound, bound)
                    self.bias[rows] = bias_tmp

        if not isinstance(self.weight, nn.Parameter):
            self.weight = nn.Parameter(self.weight)
        if self.bias is not None and not isinstance(self.bias, nn.Parameter):
            self.bias = nn.Parameter(self.bias)

    def set_offset(self, offset):
        """Store the offset consumed by ``forward`` (flattened to ``(offset.shape[0], -1)``).

        Raises:
            Exception: If the layer was built with ``add_offset=False``.
        """
        if not self.add_offset:
            raise Exception('Wrong Combination of Parameters')
        self.offset = offset.reshape(offset.shape[0], -1)

    def forward(self, input):
        """Pad, convolve with the masked weight and optionally add the offset encoding.

        Raises:
            RuntimeError: If ``add_offset`` is True and :meth:`set_offset` was never called.
        """
        weight_masked = self.weight * self.mask
        res = F.conv1d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
                       weight_masked, self.bias, self.stride,
                       0, self.dilation, self.groups)

        if self.add_offset:
            if self.offset is None:
                raise RuntimeError('set_offset() must be called before forward() when add_offset=True')
            offset_res = self.offset_enc(self.offset)
            # Append a trailing singleton axis so the encoding broadcasts over the last dim.
            offset_res = offset_res.reshape(offset_res.shape + (1, ))
            res += offset_res / 100
        return res


class SkeletonLinear(nn.Module):
    """Linear layer whose weight is masked by skeleton topology.

    ``in_channels`` and ``out_channels`` are split evenly over ``len(neighbour_list)``
    joints; the outputs of joint ``i`` are connected only to the inputs of the joints
    listed in ``neighbour_list[i]``.

    Args:
        neighbour_list: For each joint, the joint indices it may read from.
        in_channels: Total input features (each sample is flattened in ``forward``).
        out_channels: Total output features.
        extra_dim1: If True, append a trailing singleton dimension to the output.
    """

    def __init__(self, neighbour_list, in_channels, out_channels, extra_dim1=False):
        super(SkeletonLinear, self).__init__()
        self.neighbour_list = neighbour_list
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.in_channels_per_joint = in_channels // len(neighbour_list)
        self.out_channels_per_joint = out_channels // len(neighbour_list)
        self.extra_dim1 = extra_dim1
        # Flat input-feature indices covered by each joint's neighbourhood.
        self.expanded_neighbour_list = [
            [k * self.in_channels_per_joint + i for k in neighbour for i in range(self.in_channels_per_joint)]
            for neighbour in neighbour_list
        ]

        self.weight = torch.zeros(out_channels, in_channels)
        self.mask = torch.zeros(out_channels, in_channels)
        self.bias = nn.Parameter(torch.empty(out_channels))  # filled by reset_parameters()

        self.reset_parameters()

    def reset_parameters(self):
        """(Re-)initialise the masked weight, the mask and the bias.

        Each joint's unmasked block gets Kaiming-uniform init (``a=sqrt(5)``). The bias is
        drawn from ``U(-1/sqrt(fan_in), 1/sqrt(fan_in))`` using the fan-in of the full
        weight matrix (all ``in_channels``), not of each joint's neighbourhood. Safe to
        call again after construction: writes happen under ``torch.no_grad()`` and
        existing Parameters are updated in place instead of being re-wrapped.
        """
        with torch.no_grad():
            for i, neighbour in enumerate(self.expanded_neighbour_list):
                rows = slice(i * self.out_channels_per_joint, (i + 1) * self.out_channels_per_joint)
                # Advanced indexing returns a copy, so initialise a temporary and write it back.
                tmp = torch.zeros_like(self.weight[rows, neighbour])
                self.mask[rows, neighbour] = 1
                nn.init.kaiming_uniform_(tmp, a=math.sqrt(5))
                self.weight[rows, neighbour] = tmp

            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

        if not isinstance(self.weight, nn.Parameter):
            self.weight = nn.Parameter(self.weight)
        if not isinstance(self.mask, nn.Parameter):
            self.mask = nn.Parameter(self.mask, requires_grad=False)

    def forward(self, input):
        """Flatten all but the first dimension and apply the masked linear map."""
        input = input.reshape(input.shape[0], -1)
        res = F.linear(input, self.weight * self.mask, self.bias)
        if self.extra_dim1:
            res = res.reshape(res.shape + (1,))
        return res


class SkeletonPool(nn.Module):
    """Mean-pool edges along kinematic chains with a fixed averaging matrix.

    The edge graph is walked from joint 0 and split into chains (``seq_list``) that end
    at an end effector (degree 1) or at a branching joint (degree > 2, root excluded).
    Unless ``last_pool`` is set, each chain is pooled in consecutive pairs of edges; a
    chain of odd length keeps its first edge on its own, and ``new_edges`` records the
    resulting coarse edges. With ``last_pool`` each whole chain becomes one pooled group
    and ``new_edges`` stays empty.

    Args:
        edges: Sequence of ``(parent_joint, child_joint)`` pairs.
        pooling_mode: Only ``'mean'`` is supported.
        channels_per_edge: Number of feature channels belonging to each edge.
        last_pool: If True, pool each chain into a single edge.

    Raises:
        Exception: If ``pooling_mode`` is not ``'mean'``.
    """

    def __init__(self, edges, pooling_mode, channels_per_edge, last_pool=False):
        super(SkeletonPool, self).__init__()
        from collections import Counter

        if pooling_mode != 'mean':
            raise Exception('Unimplemented pooling mode in matrix_implementation')

        self.channels_per_edge = channels_per_edge
        self.pooling_mode = pooling_mode
        self.edge_num = len(edges)
        self.seq_list = []
        self.pooling_list = []
        self.new_edges = []

        # degree[j] = number of edges touching joint j. A Counter supports any joint
        # index (a fixed-size list raised IndexError for indices >= its length).
        degree = Counter()
        for edge in edges:
            degree[edge[0]] += 1
            degree[edge[1]] += 1

        def find_seq(j, seq):
            """Depth-first walk from joint j, appending finished chains to seq_list."""
            # A branching joint closes the current chain and starts a fresh one.
            if degree[j] > 2 and j != 0:
                self.seq_list.append(seq)
                seq = []

            # An end effector closes the current chain.
            if degree[j] == 1:
                self.seq_list.append(seq)
                return

            for idx, edge in enumerate(edges):
                if edge[0] == j:
                    find_seq(edge[1], seq + [idx])

        find_seq(0, [])

        for seq in self.seq_list:
            if last_pool:
                self.pooling_list.append(seq)
                continue
            if len(seq) % 2 == 1:
                self.pooling_list.append([seq[0]])
                self.new_edges.append(edges[seq[0]])
                seq = seq[1:]
            for i in range(0, len(seq), 2):
                self.pooling_list.append([seq[i], seq[i + 1]])
                self.new_edges.append([edges[seq[i]][0], edges[seq[i + 1]][1]])

        self.description = 'SkeletonPool(in_edge_num={}, out_edge_num={})'.format(
            len(edges), len(self.pooling_list)
        )

        # Row block i averages the channel blocks of the edges in pooling_list[i].
        weight = torch.zeros(len(self.pooling_list) * channels_per_edge, self.edge_num * channels_per_edge)
        for i, group in enumerate(self.pooling_list):
            for j in group:
                for c in range(channels_per_edge):
                    weight[i * channels_per_edge + c, j * channels_per_edge + c] = 1.0 / len(group)

        self.weight = nn.Parameter(weight, requires_grad=False)

    def forward(self, input: torch.Tensor):
        """Apply the pooling matrix to the second-to-last dimension of ``input``.

        NOTE(review): presumably input is (batch, edge_num * channels_per_edge, T) — confirm.
        """
        return torch.matmul(self.weight, input)


class SkeletonUnpool(nn.Module):
    """Undo a SkeletonPool grouping by copying each pooled edge back to its source edges.

    ``pooling_list[i]`` lists the original edges that were merged into pooled edge ``i``.
    The fixed (non-trainable) ``weight`` matrix copies the ``channels_per_edge`` channels
    of pooled edge ``i`` into the channel block of every edge in that group.

    Args:
        pooling_list: For each pooled edge, the indices of the original edges it covers.
        channels_per_edge: Number of feature channels belonging to each edge.
    """

    def __init__(self, pooling_list, channels_per_edge):
        super(SkeletonUnpool, self).__init__()
        self.pooling_list = pooling_list
        self.input_edge_num = len(pooling_list)
        self.output_edge_num = sum(len(group) for group in pooling_list)
        self.channels_per_edge = channels_per_edge

        self.description = 'SkeletonUnpool(in_edge_num={}, out_edge_num={})'.format(
            self.input_edge_num, self.output_edge_num,
        )

        cpe = channels_per_edge
        scatter = torch.zeros(self.output_edge_num * cpe, self.input_edge_num * cpe)
        for src, group in enumerate(pooling_list):
            for dst in group:
                for ch in range(cpe):
                    scatter[dst * cpe + ch, src * cpe + ch] = 1

        self.weight = nn.Parameter(scatter, requires_grad=False)

    def forward(self, input: torch.Tensor):
        """Multiply ``input`` by the fixed unpooling matrix.

        NOTE(review): presumably input is (batch, input_edge_num * channels_per_edge, T) — confirm.
        """
        unpooled = torch.matmul(self.weight, input)
        return unpooled


"""
Helper functions for skeleton operation
"""


def dfs(x, fa, vis, dist):
    """Recursive depth-first walk over the tree described by the parent array ``fa``.

    Joints ``x`` and ``y`` are treated as connected when either is the other's parent.
    Marks every reached joint in ``vis`` (set to 1) and writes its hop distance from the
    starting joint into ``dist``, relative to the value already stored at ``dist[x]``.
    Both ``vis`` and ``dist`` are modified in place.
    """
    vis[x] = 1
    for y, parent_of_y in enumerate(fa):
        if vis[y] != 0:
            continue
        if parent_of_y == x or fa[x] == y:
            dist[y] = dist[x] + 1
            dfs(y, fa, vis, dist)



def build_edge_topology(topology):
    """Return the ``(parent, child)`` edge list of a joint hierarchy.

    Args:
        topology: Parent index of every joint; entry 0 (the root) is ignored.

    Returns:
        A list whose first element ``(0, len(topology))`` links the root to a virtual
        joint, followed by ``(topology[i], i)`` for every non-root joint ``i`` in order.
    """
    joint_num = len(topology)
    virtual_root_edge = (0, joint_num)
    bone_edges = [(topology[child], child) for child in range(1, joint_num)]
    return [virtual_root_edge] + bone_edges


def build_joint_topology(edges, origin_names):
    """Rebuild a joint hierarchy from an edge list, inserting virtual joints at branches.

    Starting from the root (original joint 0), every edge reachable from it is turned
    into a new joint in depth-first order. When the edge's parent joint has more than one
    outgoing edge, a zero-offset joint named ``<child>_virtual`` is inserted first and
    becomes the parent of the real child joint.

    Args:
        edges: Sequence of ``(parent_joint, child_joint, offset)`` edges.
        origin_names: Joint names indexed by original joint index.

    Returns:
        Tuple ``(parent, offset, names, edge2joint)``:
            parent: parent index of every new joint (new joint 0, the root, is its own parent).
            offset: offset of every new joint (zero vector for the root and virtual joints).
            names: name of every new joint.
            edge2joint: for each new non-root joint in creation order, the index of the
                edge it came from, or -1 for a virtual joint. It has no entry for the
                root, so ``edge2joint[k]`` describes new joint ``k + 1``.
    """
    from collections import Counter

    parent = []
    offset = []
    names = []
    edge2joint = []
    joint_cnt = 0
    # Outgoing-edge count per original joint. A Counter accepts any joint index
    # (a list sized len(edges) + 10 raised IndexError for larger indices).
    out_degree = Counter(edge[0] for edge in edges)

    # Root joint.
    parent.append(0)
    offset.append(np.array([0, 0, 0]))
    names.append(origin_names[0])
    joint_cnt += 1

    def make_topology(edge_idx, pa):
        """Append the joint(s) for edge ``edge_idx`` under new joint ``pa``, then recurse."""
        nonlocal joint_cnt
        edge = edges[edge_idx]
        if out_degree[edge[0]] > 1:
            # Branching parent: add a zero-offset virtual joint between it and this child.
            parent.append(pa)
            offset.append(np.array([0, 0, 0]))
            names.append(origin_names[edge[1]] + '_virtual')
            edge2joint.append(-1)
            pa = joint_cnt
            joint_cnt += 1

        parent.append(pa)
        offset.append(edge[2])
        names.append(origin_names[edge[1]])
        edge2joint.append(edge_idx)
        pa = joint_cnt
        joint_cnt += 1

        for idx, e in enumerate(edges):
            if e[0] == edge[1]:
                make_topology(idx, pa)

    for idx, e in enumerate(edges):
        if e[0] == 0:
            make_topology(idx, 0)

    return parent, offset, names, edge2joint


def calc_edge_mat(edges):
    """Return the all-pairs hop distance between edges.

    Two distinct edges are adjacent (distance 1) when they share an endpoint, i.e. any
    of their first two elements are equal. Distances are then closed with
    Floyd-Warshall.

    Args:
        edges: Sequence of edges whose first two elements are joint indices.

    Returns:
        ``edge_mat`` as a list of lists where ``edge_mat[i][j]`` is the distance between
        edge ``i`` and edge ``j``. ``edge_mat[i][i]`` is 0 and unreachable pairs keep the
        sentinel value 100000.
    """
    edge_num = len(edges)
    unreachable = 100000
    edge_mat = [[unreachable] * edge_num for _ in range(edge_num)]

    # Direct neighbours. The diagonal is handled separately: an edge trivially shares
    # its joints with itself, and testing it would overwrite its distance 0 with 1.
    for i, a in enumerate(edges):
        for j, b in enumerate(edges):
            if i == j:
                edge_mat[i][j] = 0
            elif any(a[x] == b[y] for x in range(2) for y in range(2)):
                edge_mat[i][j] = 1

    # Floyd-Warshall relaxation through every intermediate edge k.
    for k in range(edge_num):
        row_k = edge_mat[k]
        for i in range(edge_num):
            row_i = edge_mat[i]
            via_k = row_i[k]  # unchanged during the j loop because row_k[k] == 0
            for j in range(edge_num):
                row_i[j] = min(row_i[j], via_k + row_k[j])
    return edge_mat


def find_neighbor(edges, d):
    """
    Args:
        edges: The list contains N elements, each element represents (parent, child).
        d: Distance between edges (the distance of the same edge is 0 and the distance of adjacent edges is 1).

    Returns:
        The list contains N elements, each element is a list of edge indices whose distance <= d,
        in increasing index order.
    """
    edge_mat = calc_edge_mat(edges)
    edge_num = len(edge_mat)
    # Self-distance is 0 by definition, so it is applied explicitly rather than
    # read from the diagonal of edge_mat; this keeps d == 0 returning [i] for edge i.
    neighbor_list = [
        [j for j in range(edge_num) if (0 if j == i else edge_mat[i][j]) <= d]
        for i in range(edge_num)
    ]

    # NOTE: a disabled variant appended an extra "global part" (index N) that was a
    # neighbour of every edge in neighbor_list[0]. It is deliberately left out: changing
    # the neighbour lists breaks loading of pretrained models. See
    # https://github.com/DeepMotionEditing/deep-motion-editing/issues/30

    return neighbor_list


def calc_node_depth(topology):
    """Return the depth of every joint in a parent-index array.

    A joint whose parent index is negative is a root with depth 0; every other joint
    is one deeper than its parent.

    Args:
        topology: Parent index of each joint (negative for roots).

    Returns:
        List of depths, one per joint, in joint order.
    """
    def _depth(joint):
        parent = topology[joint]
        return 0 if parent < 0 else 1 + _depth(parent)

    return [_depth(joint) for joint in range(len(topology))]


def residual_ratio(k):
    """Return the weight ``1 / (k + 1)`` given to the residual branch."""
    denominator = k + 1
    return 1 / denominator


class Affine(nn.Module):
    """Learnable per-channel scale and shift applied along dimension 1.

    Args:
        num_parameters: Number of channels (size of dimension 1 of the input).
        scale: If True, learn a multiplicative per-channel scale.
        bias: If True, learn an additive per-channel bias.
        scale_init: Initial value of every scale entry.
    """

    def __init__(self, num_parameters, scale=True, bias=True, scale_init=1.0):
        super(Affine, self).__init__()
        if scale:
            self.scale = nn.Parameter(torch.ones(num_parameters) * scale_init)
        else:
            self.register_parameter('scale', None)

        if bias:
            self.bias = nn.Parameter(torch.zeros(num_parameters))
        else:
            self.register_parameter('bias', None)

    @staticmethod
    def _expand(param, ndim):
        """Reshape a ``(C,)`` parameter to ``(1, C, 1, ..., 1)`` so it broadcasts over ``ndim`` dims."""
        param = param.unsqueeze(0)
        while param.dim() < ndim:
            param = param.unsqueeze(2)
        return param

    def forward(self, input):
        """Return ``input * scale + bias``, skipping whichever term is disabled.

        Out-of-place operations are used throughout, so ``input`` is never modified.
        """
        output = input
        if self.scale is not None:
            output = output * self._expand(self.scale, input.dim())
        if self.bias is not None:
            output = output + self._expand(self.bias, input.dim())
        return output


class BatchStatistics(nn.Module):
    """Pass-through layer that records a per-channel statistics penalty.

    Each forward pass stores in ``self.loss`` the value
    ``mean_c(mu_c ** 2) + mean_c(log(sigma_c) ** 2)``, where ``mu_c`` and ``sigma_c``
    are the mean and biased standard deviation of channel ``c`` (dimension 1) over all
    other dimensions. It then returns ``self.affine(input)``.

    Args:
        affine: ``-1`` for no affine transform (identity); otherwise the channel count
            passed to ``Affine``.
    """

    def __init__(self, affine=-1):
        super(BatchStatistics, self).__init__()
        self.affine = nn.Sequential() if affine == -1 else Affine(affine)
        self.loss = 0

    def clear_loss(self):
        """Reset the stored penalty to 0."""
        self.loss = 0

    def compute_loss(self, input):
        """Store in ``self.loss`` a penalty that is 0 when every channel has zero mean and unit std."""
        channels = input.size(1)
        # Move channels to the front before flattening. A plain view(C, -1) of an
        # (N, C, ...) tensor groups values by memory order, mixing samples and channels
        # whenever N > 1, and it fails on non-contiguous input.
        input_flat = input.transpose(0, 1).reshape(channels, -1)
        mu = input_flat.mean(1)
        # Log of the biased std. var() avoids the cancellation in E[x^2] - mu^2, which
        # can come out slightly negative and turn into NaN after sqrt.
        log_std = input_flat.var(1, unbiased=False).sqrt().log()

        self.loss = mu.pow(2).mean() + log_std.pow(2).mean()

    def forward(self, input):
        """Update ``self.loss`` from ``input`` and return ``self.affine(input)``."""
        self.compute_loss(input)
        return self.affine(input)


class ResidualBlock(nn.Module):
    """Conv1d residual block mixing a convolutional branch with a 1x1-conv shortcut.

    ``forward`` returns ``residual_ratio * residual(x) + (1 - residual_ratio) * shortcut(x)``.
    The residual branch is Conv1d, then optional BatchStatistics, then PReLU/Tanh unless
    ``last_layer`` is set. The shortcut is 2x average pooling when ``stride == 2``, a 1x1
    Conv1d, and BatchStatistics only when the channel count changes and
    ``batch_statistics is True``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, residual_ratio, activation, batch_statistics=False, last_layer=False):
        super(ResidualBlock, self).__init__()

        self.residual_ratio = residual_ratio
        self.shortcut_ratio = 1 - residual_ratio

        branch = [nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)]
        if batch_statistics:
            branch.append(BatchStatistics(out_channels))
        if not last_layer:
            branch.append(nn.PReLU() if activation == 'relu' else nn.Tanh())
        self.residual = nn.Sequential(*branch)

        downsample = nn.AvgPool1d(kernel_size=2) if stride == 2 else nn.Sequential()
        projection = nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        # The shortcut is normalised only when batch_statistics is exactly True.
        normalise_shortcut = in_channels != out_channels and batch_statistics is True
        shortcut_norm = BatchStatistics(out_channels) if normalise_shortcut else nn.Sequential()
        self.shortcut = nn.Sequential(downsample, projection, shortcut_norm)

    def forward(self, input):
        main_path = self.residual(input).mul(self.residual_ratio)
        skip_path = self.shortcut(input).mul(self.shortcut_ratio)
        return main_path + skip_path


class ResidualBlockTranspose(nn.Module):
    """ConvTranspose1d residual block mixing an upsampling branch with a 1x1-conv shortcut.

    ``forward`` returns ``residual_ratio * residual(x) + (1 - residual_ratio) * shortcut(x)``.
    The residual branch is ConvTranspose1d followed by PReLU/Tanh. The shortcut is 2x
    linear upsampling when ``stride == 2`` followed by a 1x1 Conv1d.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, residual_ratio, activation):
        super(ResidualBlockTranspose, self).__init__()

        self.residual_ratio = residual_ratio
        self.shortcut_ratio = 1 - residual_ratio

        up_conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, padding)
        act = nn.PReLU() if activation == 'relu' else nn.Tanh()
        self.residual = nn.Sequential(up_conv, act)

        if stride == 2:
            resample = nn.Upsample(scale_factor=2, mode='linear', align_corners=False)
        else:
            resample = nn.Sequential()
        projection = nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.shortcut = nn.Sequential(resample, projection)

    def forward(self, input):
        main_path = self.residual(input).mul(self.residual_ratio)
        skip_path = self.shortcut(input).mul(self.shortcut_ratio)
        return main_path + skip_path


class SkeletonResidual(nn.Module):
    """Downsampling skeleton residual block.

    The residual branch is ``extra_conv`` stride-1 SkeletonConv + activation pairs, a
    final strided SkeletonConv, and GroupNorm. The shortcut is a 1x1 strided
    SkeletonConv. Their sum passes through SkeletonPool (only when pooling actually
    reduces the edge count) and a PReLU/Tanh activation.
    """

    def __init__(self, topology, neighbour_list, joint_num, in_channels, out_channels, kernel_size, stride, padding, padding_mode, bias, extra_conv, pooling_mode, activation, last_pool):
        super(SkeletonResidual, self).__init__()

        # The extra stride-1 convolutions use kernel_size - 1 when kernel_size is even.
        extra_kernel = kernel_size - 1 if kernel_size % 2 == 0 else kernel_size

        def make_activation():
            return nn.PReLU() if activation == 'relu' else nn.Tanh()

        branch = []
        for _ in range(extra_conv):
            # (T, J, D) => (T, J, D)
            branch += [
                SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=in_channels,
                             joint_num=joint_num, kernel_size=extra_kernel,
                             stride=1,
                             padding=padding, padding_mode=padding_mode, bias=bias),
                make_activation(),
            ]
        # (T, J, D) => (T/2, J, 2D)
        branch.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels,
                                   joint_num=joint_num, kernel_size=kernel_size, stride=stride,
                                   padding=padding, padding_mode=padding_mode, bias=bias, add_offset=False))
        # FIXME (original author: "REMEMBER TO CHANGE BACK"): the group count is hard-coded
        # to 10, so out_channels must be divisible by 10.
        branch.append(nn.GroupNorm(10, out_channels))
        self.residual = nn.Sequential(*branch)

        # (T, J, D) => (T/2, J, 2D)
        self.shortcut = SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels,
                                     joint_num=joint_num, kernel_size=1, stride=stride, padding=0,
                                     bias=True, add_offset=False)

        tail = []
        # (T/2, J, 2D) => (T/2, J', 2D)
        pool = SkeletonPool(edges=topology, pooling_mode=pooling_mode,
                            channels_per_edge=out_channels // len(neighbour_list), last_pool=last_pool)
        # Skip the pooling matrix when it would keep every edge.
        if len(pool.pooling_list) != pool.edge_num:
            tail.append(pool)
        tail.append(make_activation())
        self.common = nn.Sequential(*tail)

    def forward(self, input):
        merged = self.residual(input) + self.shortcut(input)
        return self.common(merged)


class SkeletonResidualTranspose(nn.Module):
    """Upsampling skeleton residual block.

    The input first goes through optional 2x temporal upsampling (``upsampling`` mode,
    or none when it is None) and a SkeletonUnpool (only when unpooling changes the edge
    count). Then the residual branch is ``extra_conv`` SkeletonConv + activation pairs
    and a final SkeletonConv. It is summed with a 1x1 SkeletonConv shortcut, and a
    PReLU/Tanh activation is applied unless ``last_layer`` is set.
    """

    def __init__(self, neighbour_list, joint_num, in_channels, out_channels, kernel_size, padding, padding_mode, bias, extra_conv, pooling_list, upsampling, activation, last_layer):
        super(SkeletonResidualTranspose, self).__init__()

        kernel_even = False if kernel_size % 2 else True

        seq = []
        # (T, J, D) => (2T, J, D)
        if upsampling is not None:
            # align_corners is only accepted by the interpolating modes. Passing it with
            # e.g. 'nearest' makes F.interpolate raise ValueError at forward time.
            align_corners = False if upsampling in ('linear', 'bilinear', 'bicubic', 'trilinear') else None
            seq.append(nn.Upsample(scale_factor=2, mode=upsampling, align_corners=align_corners))
        # (2T, J, D) => (2T, J', D)
        unpool = SkeletonUnpool(pooling_list, in_channels // len(neighbour_list))
        if unpool.input_edge_num != unpool.output_edge_num:
            seq.append(unpool)
        self.common = nn.Sequential(*seq)

        seq = []
        for _ in range(extra_conv):
            # (2T, J', D) => (2T, J', D)
            seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=in_channels,
                                    joint_num=joint_num, kernel_size=kernel_size - 1 if kernel_even else kernel_size,
                                    stride=1,
                                    padding=padding, padding_mode=padding_mode, bias=bias))
            seq.append(nn.PReLU() if activation == 'relu' else nn.Tanh())
        # (2T, J', D) => (2T, J', D/2)
        seq.append(SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels,
                                joint_num=joint_num, kernel_size=kernel_size - 1 if kernel_even else kernel_size,
                                stride=1,
                                padding=padding, padding_mode=padding_mode, bias=bias, add_offset=False))
        self.residual = nn.Sequential(*seq)

        # (2T, J', D) => (2T, J', D/2)
        self.shortcut = SkeletonConv(neighbour_list, in_channels=in_channels, out_channels=out_channels,
                                     joint_num=joint_num, kernel_size=1, stride=1, padding=0,
                                     bias=True, add_offset=False)

        if activation == 'relu':
            self.activation = nn.PReLU() if not last_layer else None
        else:
            self.activation = nn.Tanh() if not last_layer else None

    def forward(self, input):
        output = self.common(input)
        output = self.residual(output) + self.shortcut(output)

        if self.activation is not None:
            return self.activation(output)
        else:
            return output