import torch
import time 
from butterfly_cpp import butterfly_copy_forward, butterfly_transcendental_forward, butterfly_const_forward
import numpy as np

# Fixed seed so the randomly generated benchmark inputs are reproducible run to run.
torch.manual_seed(0)

def butterfly(x, twiddle_factors_fft):
    """Pure-PyTorch reference for the butterfly kernels.

    Broadcast-multiplies ``x`` by ``twiddle_factors_fft`` and sums the product
    over dimension 4.

    In this script ``x`` is laid out as ``(B, H, 1, m, m, l // m)`` and the
    twiddles as ``(1, 1, m, m, m, l // m)``. The broadcast product is
    therefore ``(B, H, m, m, m, l // m)`` and the returned tensor is
    ``(B, H, m, m, l // m)``. The full product is materialized before the
    reduction, so peak memory grows with that 6-D shape.
    """
    weighted = torch.mul(x, twiddle_factors_fft)
    return torch.sum(weighted, dim=4)

# Problem size: sequence length l, batch B, heads H and radix m.
# m and N are reassigned later in the script.
l, B, H, m = 32768, 2, 32, 16
N = l * m

# Draw flat buffers, then view them in the 6-D layouts used in this script:
#   x:                   (B, H, 1, m, m, l // m)
#   twiddle_factors_fft: (1, 1, m, m, m, l // m)
x = torch.randn(B * H * N, device='cuda').to(torch.chalf).view(B, H, 1, m, m, l // m)
twiddle_factors_fft = (
    torch.randn(m * m * m * l // m, device='cuda')
    .to(torch.chalf)
    .view(1, 1, m, m, m, l // m)
)

for tensor in (x, twiddle_factors_fft):
    print(tensor.shape)
repeats = 100

# butterfly(x, twiddle_factors_fft)
# butterfly(x, twiddle_factors_fft)
# torch.cuda.synchronize()
# start = time.time()
# for i in range(repeats):
#     out = butterfly(x, twiddle_factors_fft)
# torch.cuda.synchronize()
# butterfly_time = time.time() - start

# torch.clone(x)
# torch.clone(x)
# torch.cuda.synchronize()
# start = time.time()
# for i in range(repeats):
#     z = torch.clone(x)
# torch.cuda.synchronize()
# clone_time_1 = time.time() - start


# torch.clone(twiddle_factors_fft)
# torch.cuda.synchronize()
# start = time.time()
# for i in range(repeats):
#     z = torch.clone(twiddle_factors_fft)
# torch.cuda.synchronize()
# clone_time_2 = m * (time.time() - start)/2

# clone_time = clone_time_1 #+ clone_time_2


# butterfly_copy_forward(x, twiddle_factors_fft)
# butterfly_copy_forward(x, twiddle_factors_fft)
# torch.cuda.synchronize()
# start = time.time()
# for i in range(repeats):
#     out_cuda = butterfly_copy_forward(x, twiddle_factors_fft)
# torch.cuda.synchronize()
# butterfly_copy_time = time.time() - start

# butterfly_transcendental_forward(x)
# butterfly_transcendental_forward(x)
# torch.cuda.synchronize()
# start = time.time()
# for i in range(repeats):
#     out_cuda = butterfly_transcendental_forward(x)
# torch.cuda.synchronize()
# butterfly_transcendental_time = time.time() - start


# print(x.shape)
# print(twiddle_factors_fft.shape)
# out = butterfly(x, twiddle_factors_fft)
# print("out shape: ", out.shape)
# out_cuda = butterfly_copy_forward(x, twiddle_factors_fft)


# print("butterfly copy time: ", butterfly_copy_time/repeats)
# print("butterfly compute time: ", butterfly_transcendental_time/repeats)
# print("clone time: ", clone_time / repeats)
# print("clone bandwidth: TB/s", (x.numel() * x.element_size() * 2 + twiddle_factors_fft.numel() * twiddle_factors_fft.element_size()) / ((clone_time/repeats) * 1000 * 2**30) )
# print("Copy speedup: ",  clone_time / butterfly_copy_time)
# print("Compute speedup: ",  clone_time / butterfly_transcendental_time)
# print("compute time vs copy time: ", butterfly_transcendental_time / butterfly_copy_time)
# print("max diff: ", (out- out_cuda).abs().max())
# print("butterfly time vs butterfly copy time", butterfly_time/ butterfly_copy_time)
# print(out[:, :, 1, :, :].abs().max()) 

# print(out[:, :, :, 0, :])
# print(out_cuda[:, :, :, 0, :])


# # print(out_cuda)
# # print(out)



def _time_cuda(fn, *args, warmup=2, iters=None):
    """Time ``fn(*args)`` on the GPU and return ``(mean_seconds, last_result)``.

    ``warmup`` untimed calls run first so that one-off costs are not measured.
    ``torch.cuda.synchronize()`` brackets the timed loop because CUDA kernel
    launches are asynchronous; without it the clock could stop before the
    queued work finishes. ``iters`` defaults to the module-level ``repeats``.

    Raises:
        ValueError: if ``iters`` is less than 1, since there would be no result
            to return and the mean would divide by zero.
    """
    iters = repeats if iters is None else iters
    if iters < 1:
        raise ValueError("iters must be >= 1")
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = time.perf_counter()  # monotonic, high-resolution clock
    for _ in range(iters):
        result = fn(*args)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters, result


# Mean seconds per call for each m in the sweep, in sweep order.
clone_times = []
butterfly_copy_times = []
butterfly_compute_times = []
const_compute_times = []

for m in [4, 8, 16, 32]:
    N = l * m
    # Sample genuinely complex inputs. torch.randn on the default (real) dtype
    # followed by .to(torch.chalf) gives tensors whose imaginary parts are all
    # zero, so the correctness check below would never exercise the imaginary
    # half of the complex multiply.
    x = torch.randn(B, H, 1, m, m, l//m, dtype=torch.cfloat, device='cuda').to(torch.chalf)
    twiddle_factors_fft = torch.randn(
        1, 1, m, m, m, l//m, dtype=torch.cfloat, device='cuda'
    ).to(torch.chalf)

    print(x.shape)
    # Pure-PyTorch reference, compared against butterfly_copy_forward below.
    out = butterfly(x, twiddle_factors_fft)

    butterfly_copy_time, out_cuda = _time_cuda(butterfly_copy_forward, x, twiddle_factors_fft)
    butterfly_copy_times.append(butterfly_copy_time)

    # Baseline: a plain device copy of x. Cloning the twiddle factors is
    # deliberately not part of the baseline.
    clone_time, _ = _time_cuda(torch.clone, x)
    clone_times.append(clone_time)

    butterfly_compute_time, _ = _time_cuda(butterfly_transcendental_forward, x)
    butterfly_compute_times.append(butterfly_compute_time)

    const_compute_time, _ = _time_cuda(butterfly_const_forward, x)
    const_compute_times.append(const_compute_time)

    print(f"m={m} max |out - out_cuda|:", (out - out_cuda).abs().max())
    


import matplotlib.pyplot as plt

# Must match the m values swept in the benchmark loop.
m_values = [4, 8, 16, 32]

# Each line carries its own label so the legend cannot drift out of sync
# with the plotted series.
plt.plot(m_values, butterfly_copy_times, '*-', label='butterfly_copy')
plt.plot(m_values, clone_times, '*-', label='clone')
plt.plot(m_values, butterfly_compute_times, '*-', label='butterfly_compute')
# plt.plot(m_values, const_compute_times, '*-', label='const_compute')
plt.legend()
plt.xlabel('m')
plt.ylabel('time per call (s)')
plt.grid()
plt.savefig('butterfly_vs_clone.png')
plt.close()  # release the figure once it has been written to disk

print(butterfly_copy_times)
print(butterfly_compute_times)
print(clone_times)

# Ratios > 1 mean the kernel is slower than a plain clone of the same input.
print('Butterfly copy vs clone')
print(np.array(butterfly_copy_times) / np.array(clone_times))

print('Butterfly compute vs clone')
print(np.array(butterfly_compute_times) / np.array(clone_times))



from benchmark import benchmark_forward, benchmark_backward, benchmark_combined, pytorch_profiler, benchmark_memory


# print((out- out_cuda).abs().max())
# benchmark_forward(butterfly_copy_forward, x, twiddle_factors_fft, repeats=repeats)
# benchmark_forward(butterfly_transcendental_forward, x, twiddle_factors_fft, repeats=repeats)
# benchmark_forward(torch.clone, x,  repeats=repeats)


# out = butterfly(x, twiddle_factors_fft)
# out_cuda = butterfly_copy_forward(x, twiddle_factors_fft)
# print((out_cuda - x.squeeze()).abs().max())
# print(out.abs().max())
# print(out_cuda.abs().max())




# pytorch_profiler(butterfly_copy_forward, x, twiddle_factors_fft, cpu=False, backward=False, trace_filename='butterfly_copy_forward.json')
# pytorch_profiler(torch.clone, x, cpu=False, backward=False, trace_filename='torch_clone.json')
# pytorch_profiler(butterfly, x, twiddle_factors_fft, cpu=False, backward=False, trace_filename='butterfly.json')


y = torch.randn((B, H, N), device='cuda').to(dtype=torch.chalf)  # flat (B, H, N) complex-half tensor
# pytorch_profiler(torch.clone, y, cpu=False, backward=False, trace_filename='torch_clone_complex.json')

#print('manual bandwidth for clone complex: ', (x.numel() * x.element_size() * 2) / (clone_time/repeats * 1000 * 2**30) )



# print(twiddle_factors_fft.abs().min())
# print(twiddle_factors_fft.abs().max())
# print(twiddle_factors_fft.abs().mean())
# print((out - out_cuda).abs().max())
# print(out.abs().max())
# print(out_cuda.abs().max())




# print((x - twiddle_factors_fft).abs().max())
# print(x.abs().max())
# print(twiddle_factors_fft.abs().max())
# print(out.abs().max())
# print(out_cuda.abs().max())


# Final sanity check on the last configuration from the sweep.
cuda_shape = out_cuda.shape
print(cuda_shape)
squeezed_input_shape = x.squeeze().shape
print(squeezed_input_shape)
max_abs_err = (out_cuda - out.squeeze()).abs().max()
print(max_abs_err)