# Load packages & set simulation parameters
# load system packages
import getopt
import argparse

# load torch module
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# load modules for data processing
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import pandas as pd

# import comparing float module & prealign_utils
from comparing_float import *
from comparing_float_cuda import *
from prealign_utils import *
from prealign_mm import fp_linear_with_pre_alignment

# import custom cuda
import decompose_fp_cuda
import reconstruct_fp_cuda
import prealign_linear_cuda
import bf16_prealign_linear_cuda
import bgemm_cuda
import sgemm_cuda

# for profiling
import torch.cuda.nvtx as nvtx

# Forbid TF32 for CUDA matmuls so FP32 GEMMs on the GPU are computed in
# full FP32 precision rather than the reduced-mantissa TF32 format.
torch.backends.cuda.matmul.allow_tf32 = False

# Print floating-point tensor values with 10 digits of precision.
torch.set_printoptions(precision=10)



"""
[STEP1] Get options
"""
parser = argparse.ArgumentParser(description='Script to evaluate MAC results')

parser.add_argument('-w', '--wbits', default=4, type=int,
                    help='weight bits (default: %(default)s)')
parser.add_argument('--num-repeats', default=2, type=int,
                    help='number of repetitions of the sweep, each with freshly '
                         'generated random data (default: %(default)s)')
parser.add_argument('--num-batch', default=200, type=int,
                    help='number of input rows (default: %(default)s)')
parser.add_argument('--max-in-features', default=8192, type=int,
                    help='width (in_features) at which input and weight data are '
                         'generated; must be >= 128 (default: %(default)s)')
parser.add_argument('--out-features', default=128, type=int,
                    help='number of output features / weight rows (default: %(default)s)')
args = parser.parse_args()

# Reject values the evaluation cannot use. parser.error prints usage and exits
# with status 2, instead of failing later with a shape error or silently
# running an empty sweep.
if args.num_repeats < 0:
    parser.error('--num-repeats must be >= 0')
if args.num_batch <= 0:
    parser.error('--num-batch must be > 0')
if args.out_features <= 0:
    parser.error('--out-features must be > 0')
if args.max_in_features < 128:
    # The in_features sweep starts at 128 columns.
    parser.error('--max-in-features must be >= 128')


"""
[STEP2] Evaluation of MAC results
"""
# Reduction lengths (in_features) to sweep: 128, 256, 512, ... doubling while
# the size still fits in the generated data. Integer doubling replaces the
# float `math.ceil(math.log(max_in_features / 128, 2))`, which overshoots
# whenever max_in_features is not 128 * 2**k (e.g. 200 -> [128, 256]) and is
# exposed to floating-point rounding. An oversized entry fails in
# `linear.weight.copy_`, because the sliced weight is only max_in_features
# columns wide.
list_in_features = []
size = 128
while size <= args.max_in_features:
    list_in_features.append(size)
    size *= 2

# Values passed as num_systolic_row to the pre-alignment kernels; -1 is a
# sentinel that is replaced by the current in_features.
num_systolic_row_list = [128, 64, 32, -1]
# extra_bits values to test (widen, e.g. to range(0, 4), to sweep more).
extra_bits_list = [0]
# Passed unchanged to prealign_linear_cuda.forward.
# NOTE(review): the meaning of mode 1 is defined by that extension — confirm.
rounding_mode = 1

for _repeat in range(args.num_repeats):
    # Random input of shape (num_batch, max_in_features), generated once at the
    # widest size; each sweep point uses its leading in_features columns.
    input_parent = torch.randn(args.num_batch, args.max_in_features)
    input_parent_cuda = input_parent.cuda()

    # Random binary weights of shape (out_features, max_in_features):
    # 1.0 where a standard-normal sample is > 0, else 0.0.
    q_weight_parent = torch.randn(args.out_features, args.max_in_features).gt(0).to(torch.float32)
    q_weight_parent_cuda = q_weight_parent.cuda()

    for in_features in list_in_features:
        print(f"test in_features: {in_features}")
        # nvtx.range pops its marker even if a call inside raises, so the
        # profiler ranges stay balanced.
        with nvtx.range("in_features " + str(in_features)):
            # Column slices of a wider tensor are strided views; .contiguous()
            # gives the kernels a compact copy when needed.
            input_cuda = input_parent_cuda[:, 0:in_features].contiguous()
            q_weight = q_weight_parent[:, 0:in_features]
            q_weight_cuda = q_weight_parent_cuda[:, 0:in_features].contiguous()

            # Bias-free nn.Linear holding q_weight: the torch GEMM reference.
            linear = nn.Linear(in_features, args.out_features, bias=False)
            with torch.no_grad():
                linear.weight.copy_(q_weight)
            linear.cuda()

            with torch.no_grad():
                with nvtx.range("gpu_mac"):
                    gpu_mac = linear(input_cuda).contiguous()
                # NOTE(review): gpu_mac is passed as the third argument to
                # sgemm_cuda.forward and prealign_linear_cuda.forward. If those
                # kernels write their result into it (as the call shape
                # suggests — confirm), the torch reference is overwritten;
                # allocate separate output buffers before comparing results.
                with nvtx.range("gpu_mac_custom"):
                    sgemm_cuda.forward(input_cuda, q_weight_cuda, gpu_mac)

            for num_systolic_row in num_systolic_row_list:
                rows = in_features if num_systolic_row == -1 else num_systolic_row
                for extra_bits in extra_bits_list:
                    # The return value is not used below.
                    with nvtx.range(f"naive_prealign_fp32_{rows}"):
                        fp_linear_with_pre_alignment(input_cuda, q_weight_cuda, rows, extra_bits)
                    with nvtx.range(f"custom_prealign_fp32_{rows}"):
                        prealign_linear_cuda.forward(
                            input_cuda, q_weight_cuda, gpu_mac, rows, extra_bits, rounding_mode
                        )
   
