"""
Randomized test script that checks the fused SELU implementations from `fselu`
(forward outputs and gradients) against pure-PyTorch reference implementations.
"""
import torch
import fselu
import numpy as np
import os
import torch.nn.functional as F
import torch.nn as nn

# Select the GPU used for the test run. The variable has to be in place before
# the first CUDA call below. A value already exported by the caller is
# respected; GPU 3 remains the default when nothing is set.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")

# Raise explicitly instead of using `assert`, which is stripped under
# `python -O` and would let the script continue without a usable GPU.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available")
device = torch.device('cuda')

# Define the auto-gradient functions

class FuselSELUfn(torch.autograd.Function):
    """Autograd function that delegates to ``fselu.fuse_lselu_f`` / ``fselu.fuse_lselu_b``.

    Forward inputs are ``(x, alpha, beta, lambda_)``. Gradients are produced
    for ``x`` and ``lambda_`` only; ``alpha`` and ``beta`` receive ``None``.
    """

    @staticmethod
    def forward(ctx, x, alpha, beta, lambda_):
        """Compute the fused forward pass and record what backward needs."""
        # x and lambda_ are stored through save_for_backward; alpha and beta
        # are kept as plain attributes on the context object.
        ctx.alpha, ctx.beta = alpha, beta
        ctx.save_for_backward(x, lambda_)
        return fselu.fuse_lselu_f(x, alpha, beta, lambda_)

    @staticmethod
    def backward(ctx, grad_y):
        """Return one gradient slot per forward input."""
        saved_x, saved_lambda = ctx.saved_tensors
        # NOTE(review): the meaning of the trailing True flag is defined by
        # fselu.fuse_lselu_b and is not visible here -- confirm.
        grads = fselu.fuse_lselu_b(
            grad_y, saved_x, ctx.alpha, ctx.beta, saved_lambda, True
        )
        grad_x, grad_lambda = grads
        # Slot order mirrors forward's inputs: x, alpha, beta, lambda_.
        return grad_x, None, None, grad_lambda
    

class FusesSELUfn(torch.autograd.Function):
    """Autograd function that delegates to ``fselu.fuse_sselu_f`` / ``fselu.fuse_sselu_b``.

    Forward inputs are ``(x, alpha, beta, lambda_)``. Gradients are produced
    for ``x`` and ``lambda_`` only; ``alpha`` and ``beta`` receive ``None``.
    """

    @staticmethod
    def forward(ctx, x, alpha, beta, lambda_):
        """Compute the fused forward pass and record what backward needs."""
        # x and lambda_ are stored through save_for_backward; alpha and beta
        # are kept as plain attributes on the context object.
        ctx.alpha, ctx.beta = alpha, beta
        ctx.save_for_backward(x, lambda_)
        return fselu.fuse_sselu_f(x, alpha, beta, lambda_)

    @staticmethod
    def backward(ctx, grad_y):
        """Return one gradient slot per forward input."""
        saved_x, saved_lambda = ctx.saved_tensors
        # NOTE(review): the meaning of the trailing True flag is defined by
        # fselu.fuse_sselu_b and is not visible here -- confirm.
        grads = fselu.fuse_sselu_b(
            grad_y, saved_x, ctx.alpha, ctx.beta, saved_lambda, True
        )
        grad_x, grad_lambda = grads
        # Slot order mirrors forward's inputs: x, alpha, beta, lambda_.
        return grad_x, None, None, grad_lambda
    

# Callable entry points for the two custom autograd Functions. Both take
# (x, alpha, beta, lambda_), matching the forward() signatures above.
fuse_lselu = FuselSELUfn.apply
fuse_sselu = FusesSELUfn.apply

def sselu(x, alpha, beta, lambda_):
    """Reference sSELU assembled from stock ``torch.nn.functional`` ops.

    Computes ``lambda_ * (ELU(beta * x; alpha) + (1 - beta) * ReLU(x))``.

    Args:
        x: Input tensor.
        alpha: ``alpha`` argument passed to ``F.elu``.
        beta: Scale applied to ``x`` before the ELU; ``1 - beta`` weights the
            ReLU term.
        lambda_: Overall output scale (number or tensor).

    Returns:
        Tensor with the same shape as ``x`` (subject to broadcasting with
        ``lambda_``).
    """
    scaled_elu = F.elu(beta * x, alpha)
    relu_correction = F.relu(input=x, inplace=False) * (1 - beta)
    return lambda_ * (scaled_elu + relu_correction)

def lselu(x, alpha, beta, lambda_):
    """Reference lSELU assembled from stock ``torch.nn.functional`` ops.

    Computes ``lambda_ * (ELU(x; alpha) + LeakyReLU(x; beta) - ReLU(x))``.

    Args:
        x: Input tensor.
        alpha: ``alpha`` argument passed to ``F.elu``.
        beta: ``negative_slope`` passed to ``F.leaky_relu``.
        lambda_: Overall output scale (number or tensor).

    Returns:
        Tensor with the same shape as ``x`` (subject to broadcasting with
        ``lambda_``).
    """
    elu_term = F.elu(x, alpha)
    leaky_term = F.leaky_relu(input=x, negative_slope=beta, inplace=False)
    relu_term = F.relu(input=x, inplace=False)
    return lambda_ * (elu_term + leaky_term - relu_term)

# Single test

def _report_mismatch(stage, error, tol):
    """Check one error tensor against ``tol`` and print a report on failure.

    Args:
        stage: Label printed in the report (e.g. "Forward" or "Backward").
        error: Element-wise difference between the fused and reference results.
        tol: Largest absolute difference that still counts as a match.

    Returns:
        True when the maximum absolute error is within ``tol``; False when it
        exceeds ``tol`` or is NaN.
    """
    max_error = torch.max(torch.abs(error)).item()
    if not (max_error > tol or np.isnan(max_error)):
        return True
    # Count differing entries with a boolean reduction so that no index tensor
    # has to be materialised for large feature maps.
    num_diff = int((error != 0).sum().item())
    print("[%s] there are %d different entries in overall %d entries. The maximum difference is %f"
          % (stage, num_diff, error.numel(), max_error))
    return False


def _lambda_grad_matches(g_l, g_l_ref, scale, rtol):
    """Compare the fused and reference gradients with respect to lambda.

    Args:
        g_l: Gradient of lambda from the implementation under test, or None.
        g_l_ref: Gradient of lambda from the reference, or None.
        scale: Magnitude the tolerance is relative to.
        rtol: Relative tolerance applied to ``scale``.

    Returns:
        True when both gradients exist and differ by at most ``rtol * scale``.
    """
    if g_l is None or g_l_ref is None:
        print("[Lambda] gradient is missing (fused: %s, reference: %s)" % (g_l, g_l_ref))
        return False
    diff = torch.max(torch.abs(g_l - g_l_ref)).item()
    allowed = rtol * scale
    if diff > allowed or np.isnan(diff):
        print("[Lambda] gradient differs by %g, which exceeds the allowed %g" % (diff, allowed))
        return False
    return True


def single_test(func, func_ref, *, max_dim=128, tol=1e-5, lambda_rtol=1e-3, test_device=None):
    """Run one randomized comparison of ``func`` against ``func_ref``.

    A random NCHW feature map and random ``alpha``, ``beta`` and ``lambda_``
    are drawn, both callables are evaluated on identical inputs, and the same
    upstream gradient is back-propagated through each. The forward output, the
    input gradient and the lambda gradient are then compared.

    Args:
        func: Implementation under test, called as ``func(x, alpha, beta, lambda_)``.
        func_ref: Reference implementation with the same call signature.
        max_dim: Exclusive upper bound for each of the four random dimensions
            (must be at least 2).
        tol: Absolute tolerance for the forward output and the input gradient.
        lambda_rtol: Tolerance for the lambda gradient, relative to the sum of
            the absolute per-element contributions to that gradient.
        test_device: Device to run on; the module-level ``device`` is used
            when this is None.

    Returns:
        True when all three comparisons pass, False otherwise.

    Raises:
        ValueError: If ``max_dim`` is smaller than 2.
    """
    if max_dim < 2:
        raise ValueError("max_dim must be at least 2, got %r" % (max_dim,))
    run_device = device if test_device is None else test_device

    # Random input feature map size; drawn one by one in N, C, H, W order.
    n, c, h, w = (int(np.random.randint(low=1, high=max_dim)) for _ in range(4))
    alpha = np.random.uniform(low=0, high=1.)
    beta = np.random.uniform(low=0, high=1.)
    l = np.random.uniform(low=0.5, high=2.)

    # Two independent leaf tensors so that each implementation accumulates its
    # own lambda gradient.
    lambda_ = (torch.ones(1, device=run_device) * l).requires_grad_(True)
    lambda_2 = (torch.ones(1, device=run_device) * l).requires_grad_(True)

    feature = torch.randn(size=(n, c, h, w), device=run_device,
                          requires_grad=False, dtype=torch.float32)
    print(feature.numel())

    # Separate clones keep the two autograd graphs independent.
    feat_in = feature.clone().requires_grad_(True)
    feat_in_ref = feature.clone().requires_grad_(True)

    feat_out = func(feat_in, alpha, beta, lambda_)
    feat_out_ref = func_ref(feat_in_ref, alpha, beta, lambda_2)

    grad = torch.rand_like(feat_out) / 10000.

    feat_out.backward(grad)
    feat_out_ref.backward(grad)

    g_l = lambda_.grad
    g_l_ref = lambda_2.grad

    print(g_l)
    print(g_l_ref)

    # Every check is evaluated (no short-circuit) so that all reports print.
    forward_ok = _report_mismatch("Forward", feat_out - feat_out_ref, tol)
    backward_ok = _report_mismatch("Backward", feat_in.grad - feat_in_ref.grad, tol)

    # The lambda gradient is a sum over every element, so its rounding error
    # grows with the tensor size and a fixed absolute tolerance would not fit.
    # The tolerance is therefore taken relative to the sum of the absolute
    # per-element contributions. This assumes func_ref has the form
    # lambda_ * g(x), for which d(out)/d(lambda_) equals out / lambda_.
    lambda_scale = (grad * feat_out_ref.detach()).abs().sum().item() / abs(l)
    lambda_ok = _lambda_grad_matches(g_l, g_l_ref, lambda_scale, lambda_rtol)

    return forward_ok and backward_ok and lambda_ok


# Repeat the randomized comparison and tally how many runs succeed.
num_trials = 10
num_pass = sum(1 for _ in range(num_trials) if single_test(fuse_lselu, lselu))

print("%d out of %d tests passed" % (num_pass, num_trials))

# mem_stat = torch.cuda.memory_stats(device=device)
# print(mem_stat['allocated_bytes.all.peak'])