import time
import math

import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable

def MonoidReduce(
    fun_name,
    *,
    init, # construct identity-tensors based on a and b
    avals, # number of A-tensors
    bvals, # number of B-tensors
    islice, # slicing over i-axis
    jslice, # slicing over j-axis
    proj_fold, # fused projection and fold over A and B-slices.
    proj_fold_bwd, # fused projection and fold backwards op.
    binary_reduce, # monoidal product
    ):
    """Create a custom autograd Function performing a blockwise reduction.

    The generated class treats its first ``avals`` positional inputs as the
    "A" group and the following ``bvals`` inputs as the "B" group.

    Forward: ``init(a, b)`` supplies the output container. For every pair of
    i-slices of A and of the output, the results of ``proj_fold(ai, bj)`` are
    folded over all j-slices of B with ``binary_reduce``, and the folded
    values are copied in place into the output slice.

    Backward: for every (i-slice, j-slice) pair, ``proj_fold_bwd(ai, bj, pi,
    gpi)`` returns local gradients for A and B which are accumulated in place
    into zero-initialised gradient buffers.

    NOTE(review): the in-place writes only reach the full tensors if
    ``islice``/``jslice`` yield views rather than copies -- confirm against
    the callers that supply them.

    Returns:
        A ``torch.autograd.Function`` subclass named
        ``MonoidReduce_<fun_name>``.
    """
    a_end = avals
    b_end = a_end + bvals

    class DynamicFunction(torch.autograd.Function):
        @staticmethod
        def forward(*inputs):
            a_group = inputs[:a_end]
            b_group = inputs[a_end:b_end]
            out = init(a_group, b_group)
            # The i-blocks are independent of each other and could be
            # processed in parallel (the construction of the final output
            # would have to be reworked for that).
            for a_blk, out_blk in zip(islice(a_group), islice(out)):
                running = out_blk
                # The j-blocks could be processed in parallel as well, given
                # a parallel reduction of the per-block values over j.
                for b_blk in jslice(b_group):
                    running = binary_reduce(running, proj_fold(a_blk, b_blk))

                # Write the folded result back into the output slice.
                for dst, src in zip(out_blk, running):
                    dst.copy_(src)

            return out

        @staticmethod
        def setup_context(ctx, inputs, outputs):
            # Inputs first, outputs last: backward relies on this order.
            to_save = (*inputs, *outputs)
            ctx.save_for_backward(*to_save)

        @staticmethod
        @once_differentiable
        def backward(ctx, *grad_outputs):
            tensors = ctx.saved_tensors
            a_group = tensors[:a_end]
            b_group = tensors[a_end:b_end]
            out = tensors[b_end:]

            grad_a = [t.new_zeros(t.shape) for t in a_group]
            a_needs_grad = [t.requires_grad for t in a_group]
            grad_b = [t.new_zeros(t.shape) for t in b_group]
            b_needs_grad = [t.requires_grad for t in b_group]

            # The i-blocks could run in parallel (needs atomic adds or
            # another parallel sum over the per-block A gradients).
            i_blocks = zip(
                islice(a_group),
                islice(grad_a),
                islice(out),
                islice(grad_outputs),
            )
            for a_blk, grad_a_blk, out_blk, grad_out_blk in i_blocks:
                # The j-blocks could run in parallel (needs atomic adds or
                # another parallel sum over the per-block B gradients).
                j_blocks = zip(jslice(b_group), jslice(grad_b))
                for b_blk, grad_b_blk in j_blocks:
                    local_ga, local_gb = proj_fold_bwd(
                        a_blk, b_blk, out_blk, grad_out_blk
                    )
                    for target, update, needed in zip(
                            grad_a_blk, local_ga, a_needs_grad):
                        if needed:
                            target.add_(update)
                    for target, update, needed in zip(
                            grad_b_blk, local_gb, b_needs_grad):
                        if needed:
                            target.add_(update)
                    # Release the local gradients before the next block.
                    del local_ga
                    del local_gb

            return (*grad_a, *grad_b)

    class_name = f'MonoidReduce_{fun_name}'
    DynamicFunction.__name__ = class_name
    DynamicFunction.__qualname__ = class_name
    DynamicFunction.__module__ = getattr(init, '__module__', __name__)

    return DynamicFunction

def check_equality(f1, f2, inputs, mock):
    """Compare outputs and input gradients of two callables; print a report.

    Each callable is invoked as ``f(*inputs)``. Its result is multiplied by
    ``mock``, summed and back-propagated, so that ``p.grad`` of every input
    reflects that callable alone (existing gradients are zeroed in place
    before each run and once more at the end).

    Args:
        f1: First callable, called with the unpacked ``inputs``.
        f2: Second callable, called with the same ``inputs``.
        inputs: Sequence of tensors passed to both callables.
        mock: Tensor multiplied with the output before summing; acts as the
            upstream gradient.

    Returns:
        None. The comparison is written to stdout.
    """

    def zero_grads():
        for p in inputs:
            if p.grad is not None:
                p.grad.zero_()

    def run(f):
        # Returns the detached output and one gradient entry per input
        # (None where the input received no gradient), so that the two
        # gradient lists stay index-aligned.
        zero_grads()
        y = f(*inputs)
        (y * mock).sum().backward()
        grads = [
            None if p.grad is None else p.grad.detach().clone()
            for p in inputs
        ]
        return y.detach(), grads

    def report_problem():
        print()
        print('   Something might be off ...')
        print()

    def check_pair(a, b):
        shapes_match = a.shape == b.shape
        if not shapes_match:
            # Elementwise statistics are meaningless (or raise) when the
            # shapes differ, so only report the mismatch.
            print(f'{" shapes match": <20}: {shapes_match}')
            report_problem()
            return

        delta = a - b
        all_close = torch.allclose(a, b, atol=1e-04, rtol=1e-04)
        squared = delta.pow(2)
        l2_diff = squared.sum().sqrt()
        rmse = squared.mean().sqrt()
        max_diff = delta.abs().max()
        mae = delta.abs().mean()
        print(f'{" numel": <20}: {delta.numel()}')
        print(f'{" shapes match": <20}: {shapes_match}')
        print(f'{" all close": <20}: {all_close}')
        print(f'{" l2 diff": <20}: {l2_diff}')
        print(f'{" max_diff": <20}: {max_diff}')
        print(f'{" RMSE": <20}: {rmse}')
        print(f'{" MAE": <20}: {mae}')
        if all_close:
            print()
            print('   All good! :)')
            print()
        else:
            report_problem()

    y1, g1 = run(f1)
    y2, g2 = run(f2)
    zero_grads()

    print(f'{" output ":=^30}')
    check_pair(y1, y2)

    print(f'{" grad ":=^30}')
    compared = 0
    for grad1, grad2 in zip(g1, g2):
        if grad1 is None and grad2 is None:
            # Input received no gradient from either callable.
            continue
        name = f' grad_{compared} '
        print(f'  {name:-^26}  ')
        compared += 1
        if grad1 is None or grad2 is None:
            # Only one callable produced a gradient for this input.
            missing = 'f1' if grad1 is None else 'f2'
            print(f'{" gradient missing": <20}: {missing}')
            report_problem()
            continue
        check_pair(grad1, grad2)


def check_speed(f1, inputs, mock, runs=10, warmup=3):
    """Return the mean wall-clock seconds of one forward+backward pass.

    Each pass calls ``f1(*inputs)``, multiplies the result by ``mock``, sums
    it and back-propagates. Gradients are zeroed before every warm-up pass
    but not during the timed passes, so they accumulate there.

    Args:
        f1: Callable under test, called with the unpacked ``inputs``.
        inputs: Sequence of tensors passed to ``f1``.
        mock: Tensor multiplied with the output before summing.
        runs: Number of timed passes; the total time is divided by it.
        warmup: Number of untimed passes executed first.

    Returns:
        Average duration of a timed pass in seconds (float).
    """
    on_cuda = any(p.is_cuda for p in inputs)

    for _ in range(warmup):
        for p in inputs:
            if p.grad is not None:
                p.grad.zero_()
        y1 = f1(*inputs)
        (y1 * mock).sum().backward()

    if on_cuda:
        # Make sure pending warm-up kernels are not billed to the timed runs.
        torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(runs):
        y1 = f1(*inputs)
        (y1 * mock).sum().backward()
        with torch.no_grad():
            grad_sums = [p.grad.sum() for p in inputs if p.grad is not None]
        if grad_sums:
            # Pulling a scalar to the host forces the backward pass to finish.
            sum(grad_sums).item()
        elif on_cuda:
            # No input gradient to read back (e.g. only captured parameters
            # require grad); still wait for the device before moving on.
            torch.cuda.synchronize()

    return (time.perf_counter() - start) / runs
    
def check(f1, f2, inputs, mock, runs=10, warmup=3):
    """Print an equality report and a timing comparison for two callables.

    First runs ``check_equality(f1, f2, inputs, mock)``, then times each
    callable with ``check_speed`` and prints the per-pass times and the
    ratio ``time(f1) / time(f2)``.

    Args:
        f1: First callable.
        f2: Second callable, used as the timing baseline.
        inputs: Sequence of tensors passed to both callables.
        mock: Tensor multiplied with each output before summing.
        runs: Number of timed passes per callable.
        warmup: Number of untimed warm-up passes per callable.

    Returns:
        None. Everything is written to stdout.
    """
    check_equality(f1, f2, inputs, mock)
    print(f'{" speed ":=^30}')
    print(f' {" f1 ":-^26}  ')
    s1 = check_speed(f1, inputs, mock, runs, warmup)
    print(f' {s1:.2f}')
    print(f' {" f2 ":-^26}  ')
    s2 = check_speed(f2, inputs, mock, runs, warmup)
    print(f' {s2:.2f}')
    ratio = s1 / s2
    # Two decimals, matching the per-function timings printed above.
    print(f' relative time: {ratio:.2f}')
    if ratio > 1:
        print('   f1 is slower :(')
    else:
        print('   f1 is faster :)')
