"""Element multiplication with the A matrix based on its sign."""
import torch
import time
from typing import Optional, Tuple
from torch import Tensor
from ..patches import Patches


# Switch off TorchScript's profiling executor and profiling mode through the
# private torch._C hooks. Both calls run at import time, before the
# @torch.jit.script functions further down in this module are defined.
# NOTE(review): presumably done so the scripted kernels skip profiling runs and
# re-specialization -- confirm the motivation, and that these private APIs are
# still available in every supported PyTorch version.
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)


# @torch.jit.script
def _reference_multiply_by_A_signs(A: Tensor, d_pos: Tensor, d_neg: Tensor,
        b_pos: Optional[Tensor], b_neg: Optional[Tensor], patches_mode: bool) -> Tuple[Tensor, Tensor]:
    """Reference implementation."""
    A_pos = A.clamp(min=0)
    A_neg = A.clamp(max=0)
    A_new = d_pos * A_pos + d_neg * A_neg
    bias_pos = bias_neg = torch.tensor(0.)
    if b_pos is not None:
        if patches_mode:
            bias_pos = torch.einsum('sb...chw,sb...chw->sb...', A_pos, b_pos)
        else:
            bias_pos = torch.einsum('sb...,sb...->sb', A_pos, b_pos)
    if b_neg is not None:
        if patches_mode:
            bias_neg = torch.einsum('sb...chw,sb...chw->sb...', A_neg, b_neg)
        else:
            bias_neg = torch.einsum('sb...,sb...->sb', A_neg, b_neg)
    return A_new, bias_pos + bias_neg


class ClampedMultiplication(torch.autograd.Function):
    """Autograd function for ``d_pos * clamp(A, min=0) + d_neg * clamp(A, max=0)``.

    Besides the product it returns a bias obtained by contracting the clamped
    parts of ``A`` with the optional ``b_pos`` / ``b_neg``. Forward and
    backward are implemented as TorchScript functions; the backward is written
    by hand instead of being derived by autograd.
    """

    @staticmethod
    @torch.jit.script
    def clamp_mutiply_forward(A: Tensor, d_pos: Tensor, d_neg: Tensor,
            b_pos: Optional[Tensor], b_neg: Optional[Tensor], patches_mode: bool) -> Tuple[Tensor, Tensor]:
        """Scripted forward pass; mirrors the reference implementation.

        Returns ``(A_new, bias)``. In patches mode only the last three
        dimensions are summed out of the bias, otherwise every dimension after
        the first two. A missing bias contributes a 0-dim ``torch.tensor(0.)``.
        """
        # The contraction is the same for both bias terms; pick it once.
        if patches_mode:
            equation = 'sb...chw,sb...chw->sb...'
        else:
            equation = 'sb...,sb...->sb'
        nonneg_A = A.clamp(min=0)
        nonpos_A = A.clamp(max=0)
        zero = torch.tensor(0.)
        pos_bias = zero
        neg_bias = zero
        if b_pos is not None:
            pos_bias = torch.einsum(equation, nonneg_A, b_pos)
        if b_neg is not None:
            neg_bias = torch.einsum(equation, nonpos_A, b_neg)
        return d_pos * nonneg_A + d_neg * nonpos_A, pos_bias + neg_bias

    @staticmethod
    @torch.jit.script
    def clamp_mutiply_backward(A: Tensor, d_pos: Tensor, d_neg: Tensor,
            b_pos: Optional[Tensor], b_neg: Optional[Tensor], grad_output_A: Tensor, grad_output_bias: Optional[Tensor],
            patches_mode: bool) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], None]:
        """Hand-written backward pass, intended to be cheaper than the one autograd would generate.

        Returns the gradients for ``(A, d_pos, d_neg, b_pos, b_neg, patches_mode)``.
        The bias gradients are None for a bias that was not supplied (or when no
        bias gradient is available); the last entry is always None.
        ``patches_mode`` is accepted for signature symmetry but not read here.
        """
        # Entries with A >= 0 belong to the "positive" side, everything else to the negative one.
        nonneg_mask = (A >= 0).to(dtype=grad_output_A.dtype)
        neg_mask = 1. - nonneg_mask
        grad_A_nonneg = nonneg_mask * grad_output_A
        grad_A_neg = neg_mask * grad_output_A
        gd_pos = A * grad_A_nonneg
        gd_neg = A * grad_A_neg
        # Two terms that are always present; bias terms are appended below, b_pos before b_neg.
        gA = d_pos * grad_A_nonneg + d_neg * grad_A_neg
        gb_pos: Optional[Tensor] = None
        gb_neg: Optional[Tensor] = None
        if grad_output_bias is not None:
            # Append size-1 dimensions so the bias gradient broadcasts against A.
            n_trailing = len(A.shape) - len(grad_output_bias.shape)
            grad_bias = grad_output_bias.view(grad_output_bias.shape + (1, ) * n_trailing)
            if b_pos is not None:
                grad_bias_nonneg = nonneg_mask * grad_bias
                gb_pos = A * grad_bias_nonneg
                gA = gA + b_pos * grad_bias_nonneg
            if b_neg is not None:
                grad_bias_neg = neg_mask * grad_bias
                gb_neg = A * grad_bias_neg
                gA = gA + b_neg * grad_bias_neg
        return gA, gd_pos, gd_neg, gb_pos, gb_neg, None

    @staticmethod
    def forward(ctx, A, d_pos, d_neg, b_pos, b_neg, patches_mode):
        """Stash the inputs for backward and run the scripted forward."""
        ctx.patches_mode = patches_mode
        # Only the raw inputs are kept; the clamped parts of A are recomputed
        # from masks in backward rather than stored.
        ctx.save_for_backward(A, d_pos, d_neg, b_pos, b_neg)
        return ClampedMultiplication.clamp_mutiply_forward(A, d_pos, d_neg, b_pos, b_neg, patches_mode)

    @staticmethod
    def backward(ctx, grad_output_A, grad_output_bias):
        """Forward the saved inputs and incoming gradients to the scripted backward."""
        saved_inputs = ctx.saved_tensors
        return ClampedMultiplication.clamp_mutiply_backward(
            *saved_inputs, grad_output_A, grad_output_bias, ctx.patches_mode)


def multiply_by_A_signs(A, d_pos, d_neg, b_pos, b_neg, contiguous='auto'):
    """Multiply ``A`` by ``d_pos`` / ``d_neg`` according to the sign of its entries.

    Non-negative entries of ``A`` are scaled by ``d_pos`` and the remaining
    entries by ``d_neg``; a bias is produced from ``b_pos`` / ``b_neg`` in the
    same sign-split fashion.

    Args:
        A: Either a dense ``Tensor`` or a ``Patches`` object.
        d_pos: Multiplier for the non-negative part of ``A``.
        d_neg: Multiplier for the non-positive part of ``A``.
        b_pos: Bias paired with the non-negative part of ``A``. May be None,
            except in the 1-D special case below.
        b_neg: Bias paired with the non-positive part of ``A``. May be None,
            except in the 1-D special case below.
        contiguous: ``True`` always makes ``d_pos`` / ``d_neg`` contiguous.
            ``'auto'`` (default) does so for dense tensors only. Any other
            value leaves them untouched.

    Returns:
        ``(new_A, bias)``. For a dense ``A`` both are tensors; for ``Patches``
        the first element is ``A.create_similar(...)`` wrapping the product.
        If ``A`` is neither a ``Tensor`` nor a ``Patches`` the function falls
        through and returns None.
    """
    if isinstance(A, Tensor):
        if contiguous is True or contiguous == 'auto':
            # For dense mode, convert d_pos and d_neg to contiguous tensor by default.
            d_pos = d_pos.contiguous()
            d_neg = d_neg.contiguous()
        if d_pos.ndim == 1:
            # Special case for LSTM, the bias term is 1-dimension. (FIXME)
            # Here the bias is an elementwise product and is NOT summed out.
            assert d_neg.ndim == 1 and b_pos.ndim == 1 and b_neg.ndim == 1
            # Clamp once and reuse for both the product and the bias.
            A_pos = A.clamp(min=0)
            A_neg = A.clamp(max=0)
            new_A = A_pos * d_pos + A_neg * d_neg
            new_bias = A_pos * b_pos + A_neg * b_neg
            return new_A, new_bias
        return ClampedMultiplication.apply(A, d_pos, d_neg, b_pos, b_neg, False)
    elif isinstance(A, Patches):
        # `contiguous` may be the truthy string 'auto'; compare against True
        # explicitly so that the default does NOT trigger a copy here.
        if contiguous is True:
            # For patches mode, do not convert d_pos and d_neg to contiguous tensor by default.
            d_pos = d_pos.contiguous()
            d_neg = d_neg.contiguous()
        assert A.identity == 0  # TODO: handle the A.identity = 1 case. Currently not used.
        patches = A.patches
        patches_shape = patches.shape
        # patches shape: [out_c, batch_size, out_h, out_w, in_c, H, W]. Here out_c is the spec dimension.
        # or (unstable_size, batch_size, in_c, H, W) when it is sparse.
        # NOTE(review): the layout above has 7 dimensions while the branch below
        # tests for 6 -- confirm which layout is expected to reach it.
        if len(patches_shape) == 6:
            # Merge the middle dimensions so that every operand becomes 5-D.
            flat_shape = (*patches_shape[:2], -1, *patches_shape[-2:])
            patches = patches.view(flat_shape)
            # d_pos / d_neg may be non-contiguous now that the default no longer
            # copies them; reshape returns a view when possible and copies only
            # when the layout requires it.
            d_pos = d_pos.reshape(flat_shape) if d_pos is not None else None
            d_neg = d_neg.reshape(flat_shape) if d_neg is not None else None
            b_pos = b_pos.view(flat_shape) if b_pos is not None else None
            b_neg = b_neg.view(flat_shape) if b_neg is not None else None
        # Apply the multiplication based on signs.
        A_prod, bias = ClampedMultiplication.apply(patches, d_pos, d_neg, b_pos, b_neg, True)
        # prod has shape [out_c, batch_size, out_h, out_w, in_c, H, W] or (unstable_size, batch_size, in_c, H, W) when it is sparse.
        # For sparse patches the return bias size is (unstable_size, batch).
        # For regular patches the return bias size is (spec, batch, out_h, out_w).
        if len(patches_shape) == 6:
            # Undo the flattening applied above.
            A_prod = A_prod.view(*patches_shape)
        return A.create_similar(A_prod), bias


def _sync_if_cuda(tensor):
    """Block until queued CUDA work is done when ``tensor`` lives on a CUDA device."""
    if tensor.is_cuda:
        torch.cuda.synchronize()


def _relative_l1_diff(ref, new, eps=1e-10):
    """Return ``sum(|ref - new|) / (sum(|ref|) + eps)`` as a Python float."""
    return (ref - new).abs().sum().item() / (ref.abs().sum().item() + eps)


def _timed_forward_backward(fn, A, d_pos, d_neg, b_pos, b_neg, patches_mode):
    """Time one forward/backward pass of ``fn`` and collect its results.

    ``fn`` is called as ``fn(A, d_pos, d_neg, b_pos, b_neg, patches_mode)`` and
    must return ``(out_A, out_bias)``. Returns ``(elapsed, out_A, out_bias,
    grads)`` where ``grads`` holds cloned gradients for ``A, d_pos, d_neg,
    b_pos, b_neg`` (a scalar zero for inputs that are None). The ``.grad``
    fields of the inputs are reset to None before returning.
    """
    inputs = (A, d_pos, d_neg, b_pos, b_neg)
    _sync_if_cuda(A)
    start = time.perf_counter()
    out_A, out_bias = fn(A, d_pos, d_neg, b_pos, b_neg, patches_mode)
    loss = out_A.sum() + out_bias.sum()
    loss.backward()
    _sync_if_cuda(A)
    elapsed = time.perf_counter() - start
    grads = tuple(
        t.grad.detach().clone() if t is not None else torch.tensor(0.)
        for t in inputs)
    # Clear accumulated gradients so that the next run starts from scratch.
    for t in inputs:
        if t is not None:
            t.grad = None
    return elapsed, out_A, out_bias, grads


def _speed_test(A, d_pos, d_neg, b_pos, b_neg, patches_mode=False, n_test=20, warmup=3):
    """Benchmark ``ClampedMultiplication`` against the reference implementation.

    Each iteration runs a forward and backward pass through both
    implementations, prints their timings, and asserts that outputs and
    gradients agree to a relative L1 difference below 1e-6.

    Args:
        A, d_pos, d_neg: Leaf tensors with ``requires_grad=True``.
        b_pos, b_neg: Leaf tensors with ``requires_grad=True``, or None.
        patches_mode: Passed to both implementations to select the bias contraction.
        n_test: Total number of iterations, including warmup.
        warmup: Number of leading iterations excluded from the averages.
    """
    print(f'patches_mode = {patches_mode}, b_pos is {type(b_pos)}, b_neg is {type(b_neg)}')
    total_ref = 0.
    total_new = 0.
    for i in range(n_test):
        ref_time, ref_A, ref_bias, ref_grads = _timed_forward_backward(
            _reference_multiply_by_A_signs, A, d_pos, d_neg, b_pos, b_neg, patches_mode)
        # Call the autograd function directly: multiply_by_A_signs() takes no
        # patches_mode argument (its 6th parameter is `contiguous`) and would
        # always run the dense path on raw tensors.
        new_time, new_A, new_bias, new_grads = _timed_forward_backward(
            ClampedMultiplication.apply, A, d_pos, d_neg, b_pos, b_neg, patches_mode)

        print(f'Loop {i:3d} {"(warmup)" if i < warmup else "        "} time ref {ref_time:.5f} new {new_time:.6f} speedup {ref_time / new_time if i >= warmup else float("nan"):.3f}')
        if i >= warmup:
            total_ref += ref_time
            total_new += new_time

        A_diff = _relative_l1_diff(ref_A, new_A)
        bias_diff = _relative_l1_diff(ref_bias, new_bias)
        gA_diff, gd_pos_diff, gd_neg_diff, gb_pos_diff, gb_neg_diff = (
            _relative_l1_diff(ref_g, new_g) for ref_g, new_g in zip(ref_grads, new_grads))
        print(f'                  diff {A_diff} {gA_diff} {bias_diff} {gd_pos_diff} {gd_neg_diff} {gb_pos_diff} {gb_neg_diff}')
        assert A_diff < 1e-6 and bias_diff < 1e-6 and gA_diff < 1e-6 and gd_pos_diff < 1e-6 and gd_neg_diff < 1e-6
        assert gb_pos_diff < 1e-6 and gb_neg_diff < 1e-6

    n_timed = n_test - warmup
    if n_timed <= 0:
        # Every iteration was warmup; there is nothing to average.
        print('Avg. time: n/a (no timed iterations after warmup)')
        return
    avg_ref_time = total_ref / n_timed
    avg_new_time = total_new / n_timed
    print(f'Avg. time: reference {avg_ref_time:.5f} new {avg_new_time:.6f} speedup {avg_ref_time / avg_new_time:.3f}')


if __name__ == '__main__':
    # Benchmark driver: one tensor shape per mode, each exercised with every
    # combination of present/absent biases. Requires a CUDA device.
    benchmark_shapes = (
        (True, (256, 8, 8, 8, 16, 32)),
        (False, (256, 8, 128, 256)),
    )
    tensor_kwargs = dict(device='cuda', requires_grad=True)
    for patches_mode, shape in benchmark_shapes:
        A = torch.randn(shape, **tensor_kwargs)
        d_pos = torch.randn(shape, **tensor_kwargs)
        d_neg = torch.randn(shape, **tensor_kwargs)
        b_pos = torch.randn(shape, **tensor_kwargs)
        b_neg = torch.randn(shape, **tensor_kwargs)
        # No bias, negative bias only, positive bias only, both biases.
        _speed_test(A, d_pos, d_neg, None, None, patches_mode=patches_mode)
        _speed_test(A, d_pos, d_neg, None, b_neg, patches_mode=patches_mode)
        _speed_test(A, d_pos, d_neg, b_pos, None, patches_mode=patches_mode)
        _speed_test(A, d_pos, d_neg, b_pos, b_neg, patches_mode=patches_mode)
        print('Press Enter key to continue.')
        input()
        # Drop the references before the next shape is allocated.
        del A, d_pos, d_neg, b_pos, b_neg
