from libs import *
import argparse
import torch.autograd.profiler as profiler
DEBUG = False


class ProfileResult:
    """Parse a saved ``torch.autograd.profiler`` text table and total its columns.

    The table is read back from the text file it was printed to. After
    construction the instance exposes:

    * ``columns``: the header names found in the table,
    * ``df``: one row per profiled op, every cell kept as the printed string,
    * ``cpu_time_total``: text after ``'Self CPU time total: '``,
    * ``cuda_time_total``: text after ``'Self CUDA time total: '``
      (only set when ``cuda`` is true).
    """

    # Labels of the summary lines printed underneath the table.
    _CPU_TOTAL_PREFIX = 'Self CPU time total: '
    _CUDA_TOTAL_PREFIX = 'Self CUDA time total: '

    def __init__(self,
                 result_file,
                 num_iters=1,
                 cuda=True) -> None:
        '''
        Hard-coded result computation based on torch.autograd.profiler
        text printout, tested for PyTorch 1.8-1.10
        columns = ['Name', 'Self CPU %', 'Self CPU',
                        'CPU total %', 'CPU total', 'CPU time avg',
                        'Self CUDA', 'Self CUDA %', 'CUDA total', 'CUDA time avg',
                        'CPU Mem', 'Self CPU Mem', 'CUDA Mem', 'Self CUDA Mem',
                        '# of Calls', 'FLOPS']
        last col can be MFLOPS or GFLOPS

        Args:
            result_file: path of the text file holding the printed table.
            num_iters: number of profiled iterations; used to report
                per-iteration figures.
            cuda: whether the table is expected to carry a
                'Self CUDA time total' summary line.

        Raises:
            ValueError: if an expected summary line is missing from the file.
        '''
        self.result_file = result_file
        self.num_iters = num_iters
        self.cuda = cuda
        self._initialize()

    def _initialize(self) -> None:
        """Read the table and split it into header, data rows and summary lines."""
        # Cells are separated by runs of two or more spaces; the first line of
        # the file (the opening dashed rule) is skipped.
        df = pd.read_csv(self.result_file,
                         sep=r'\s{2,}',
                         engine='python',
                         header=None,
                         skiprows=range(1))
        self.columns = df.iloc[0].values
        # Row 0 is the header and row 1 the dashed rule beneath it.
        df = df.iloc[2:]
        df.columns = self.columns

        # Missing cells come back as NaN; map them to '' so string tests are safe.
        first_fields = [val if isinstance(val, str) else ''
                        for val in df.iloc[:, 0]]

        # Summary lines are located by their label, not by a fixed offset from
        # the end, so the number of summary lines in the file is not assumed.
        self.cpu_time_total = self._footer_value(first_fields,
                                                 self._CPU_TOTAL_PREFIX)
        if self.cuda:
            self.cuda_time_total = self._footer_value(first_fields,
                                                      self._CUDA_TOTAL_PREFIX)
        self.df = df.iloc[:self._data_end(first_fields)]

    def _footer_value(self, first_fields, prefix):
        """Return the text following ``prefix`` on the last line carrying it."""
        for text in reversed(first_fields):
            if text.startswith(prefix):
                return text[len(prefix):]
        raise ValueError(
            f"'{prefix.strip()}' line not found in {self.result_file}")

    @classmethod
    def _data_end(cls, first_fields):
        """Return the position one past the last data row."""
        # Normal case: data rows stop at the closing dashed rule.
        for pos in range(len(first_fields) - 1, -1, -1):
            text = first_fields[pos]
            if text and set(text) == {'-'}:
                return pos
        # No closing rule: stop at the first summary line instead.
        labels = (cls._CPU_TOTAL_PREFIX, cls._CUDA_TOTAL_PREFIX)
        for pos, text in enumerate(first_fields):
            if text.startswith(labels):
                return pos
        return len(first_fields)

    @staticmethod
    def _try_float(text):
        """Return ``float(text)``, or None when ``text`` is not a number."""
        try:
            return float(text)
        except ValueError:
            return None

    @classmethod
    def _mem_in_gb(cls, val):
        """Convert a cell such as '1.50 Gb' to GB; None if no unit is recognised."""
        if val[-2:] == 'Gb':
            return cls.get_str_val(val[:-2])
        if val[-2:] == 'Mb':
            return cls.get_str_val(val[:-2])/1e3
        # Smaller units use the same decimal scaling as the Mb branch, so that
        # they are no longer dropped from the total.
        for suffix, divisor in (('Kb', 1e6), ('b', 1e9)):
            if val.endswith(suffix):
                number = cls._try_float(val[:-len(suffix)])
                return None if number is None else number/divisor
        return None

    @classmethod
    def _time_in_ms(cls, val):
        """Convert a cell such as '12.3us' to ms; None if no unit is recognised."""
        if val[-2:] == 'ms':
            return float(val[:-2])
        if val[-2:] == 'us':
            return float(val[:-2])/1e3
        # Whole seconds are checked last because 'ms' and 'us' also end in 's'.
        if val.endswith('s'):
            number = cls._try_float(val[:-1])
            return None if number is None else number*1e3
        return None

    @staticmethod
    def _plain_number(val):
        """Convert a cell ending in a digit to float; None otherwise (e.g. '--')."""
        if val and val[-1].isnumeric():
            return float(val)
        return None

    def _sum_column(self, col_name, parse):
        """Sum ``parse(cell)`` over a column, skipping cells that yield None."""
        total = 0
        for val in self.df[col_name].values:
            # Empty cells are NaN (not None) after read_csv, so only strings
            # are handed to the parser.
            if not isinstance(val, str):
                continue
            amount = parse(val)
            if amount is not None:
                total += amount
        return round(total, 2)

    def compute_total_mem(self, col_names):
        """Return, per column name, the summed memory in GB rounded to 2 decimals."""
        return [self._sum_column(col_name, self._mem_in_gb)
                for col_name in col_names]

    def compute_total_time(self, col_names):
        """Return, per column name, the summed time in ms rounded to 2 decimals."""
        return [self._sum_column(col_name, self._time_in_ms)
                for col_name in col_names]

    def compute_total(self, col_names):
        """Return, per column name, the sum of the plain numeric cells (2 decimals)."""
        return [self._sum_column(col_name, self._plain_number)
                for col_name in col_names]

    def print_total_mem(self, col_names):
        """Print the GB total of each memory column in ``col_names``."""
        total_mems = self.compute_total_mem(col_names)
        for col_name, total_mem in zip(col_names, total_mems):
            print(f"{col_name} total: {total_mem} GB")

    def print_total(self, col_names):
        """Print the numeric total of each column in ``col_names``."""
        totals = self.compute_total(col_names)
        for col_name, total in zip(col_names, totals):
            print(f"{col_name} total: {total}")

    def print_total_time(self):
        """Print the iteration count and the summary-line time totals."""
        print(f"# of backprop iters: {self.num_iters}")
        print(f"CPU time total: {self.cpu_time_total}")
        if self.cuda:
            print(f"CUDA time total: {self.cuda_time_total}")

    def print_total_flops(self, flops_col):
        """Print the numeric total of each FLOPs column in ``flops_col``."""
        totals = self.compute_total(flops_col)
        for col, total in zip(flops_col, totals):
            print(f"{col}: {total}")

    def print_flops_per_iter(self, flops_col, avg_time=False):
        """FLOPS per iter

        Args:
            flops_col: list of FLOPs column names to report.
            avg_time: when true, divide each total by the leading number of
                the total-time string before dividing by ``num_iters``. The
                CUDA total is used when ``cuda`` is true, otherwise the CPU
                total. The unit suffix of that string is ignored.
        """
        totals = self.compute_total(flops_col)
        time_total = None
        if avg_time:
            # The time total is only parsed when it is needed, and the CPU
            # total is used when no CUDA total was recorded.
            time_str = self.cuda_time_total if self.cuda else self.cpu_time_total
            time_total = float(re.findall(r'\d+\.*\d*', time_str)[0])
        for col, total in zip(flops_col, totals):
            flops = total/time_total if avg_time else total
            print(f"{col} per iteration: {flops/self.num_iters}")

    @staticmethod
    def get_str_val(string):
        """Convert a numeric string, with an optional leading '-', to float."""
        if string[0] == '-':
            return -float(string[1:])
        else:
            return float(string)


def main():
    """Profile one model on random inputs and print totals from the saved table.

    Steps:
      1. Parse the command-line options and load the model settings for
         ``--model`` from ``./configs.yml``.
      2. Build the model and random input tensors on the selected device.
      3. Unless ``--no-profile`` is given, run ``--num-iter`` forward passes
         (each followed by a backward pass unless ``--eval``) under
         ``torch.autograd.profiler`` and write the resulting table to
         ``profile_result_<model>.txt`` next to this file.
      4. Unless FLOPs or memory profiling is disabled, read that file back
         with ``ProfileResult`` and print memory, FLOPs and time totals.

    Raises:
        NotImplementedError: if ``--model`` is not one of the handled names.
        FileNotFoundError: if ``--no-profile`` is given but no result file
            from an earlier run exists.
    """
    # Training settings
    parser = argparse.ArgumentParser(
        description='FLOPs profiling')
    parser.add_argument('--model', type=str, default='uit', metavar='model',
                        help='evaluation model name, uit (integral transformer), uit-c3 (UIT with 3 channels) , ut (with traditional softmax normalization), hut (hybrid ut with linear attention), xut (cross-attention with hadamard product interaction), fno2d (Fourier neural operator 2d), afno2d (FNO with token mixing layer), mwo (Multiwavelet neural operator), unet (traditional UNet with CNN, big baseline, 33m params), unets (UNet with the same number of layers with U-integral transformer), deeponet (Deep Operator Net with a CNN branch). default: uit)')
    parser.add_argument('--batch-size', type=int, default=8, metavar='N',
                        help='input batch size for profiling (default: 8)')
    parser.add_argument('--grid-size', type=int, default=128, metavar='n',
                        help='input grid size n (default: 128)')
    # The help text now states the actual default (100).
    parser.add_argument('--num-iter', type=int, default=100, metavar='k',
                        help='input number of iteration of backpropagations for profiling (default: 100)')
    parser.add_argument('--no-memory', action='store_true', default=False,
                        help='disables memory profiling')
    parser.add_argument('--no-flops', action='store_true', default=False,
                        help='disables FLOPs profiling')
    parser.add_argument('--eval', action='store_true', default=False,
                        help='Evaluation')
    parser.add_argument('--no-profile', action='store_true', default=False,
                        help='print result only no profiling')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA in profiling')
    args = parser.parse_args()
    cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device('cuda' if cuda else 'cpu')
    current_path = os.path.dirname(os.path.abspath(__file__))
    file_name = os.path.join(current_path, f'profile_result_{args.model}.txt')

    config = load_yaml(r'./configs.yml', key=args.model)
    print("="*10+f"Model setting for {args.model}:"+"="*10)
    for a in config.keys():
        if not a.startswith('__'):
            print(f"{a}: {config[a]}")
    print("="*33)

    # Model classes are imported lazily so only the selected one is loaded.
    if args.model in ["uit", "uit-c3", "uit-c", "ut", "xut"]:
        from libs.ut import UTransformer
        model = UTransformer(**config)
    elif args.model in ["hut"]:
        from libs.hut import HybridUT
        model = HybridUT(**config)
    elif args.model in ["fno2d", "fno2d-big", "afno2d"]:
        from libs.fno import FourierNeuralOperator
        model = FourierNeuralOperator(**config)
    elif args.model in ["unet", "unet-small"]:
        from libs.unet import UNet
        model = UNet(**config)
    elif args.model in ["mwo"]:
        from libs.mwo import MWT2d
        model = MWT2d(**config)
    elif args.model in ["deeponet"]:
        from libs.deeponet import DeepONet2d
        model = DeepONet2d(**config)
    else:
        raise NotImplementedError(f"unsupported model: {args.model}")

    print(f"\nNumber of params: {get_num_params(model)}")
    model.to(device)

    # Random tensors stand in for real data; only the shapes matter here.
    n_grid = args.grid_size
    x = torch.randn(args.batch_size, n_grid, n_grid, 1).to(device)
    gradx = torch.randn(args.batch_size, n_grid, n_grid, 2).to(device)
    target = torch.randn(args.batch_size, n_grid, n_grid, 1).to(device)
    grid = torch.randn(args.batch_size, n_grid, n_grid, 2).to(device)

    if not args.no_profile:
        with profiler.profile(profile_memory=not args.no_memory,
                              with_flops=not args.no_flops,
                              use_cuda=cuda,) as prof:
            with tqdm(total=args.num_iter, disable=(args.num_iter<10)) as pbar:
                for _ in range(args.num_iter):
                    y = model(x, gradx, grid)
                    y = y['preds']
                    if not args.eval:
                        loss = ((y-target)**2).mean()
                        loss.backward()
                    pbar.update()

        sort_by = "self_cuda_memory_usage" if cuda else "self_cpu_memory_usage"
        # The file is written only when a new profile exists, and closed
        # before it is parsed below. This keeps --no-profile from truncating
        # an earlier result and ensures the parser sees the flushed table.
        with open(file_name, 'w') as f:
            print(prof.key_averages().table(sort_by=sort_by,
                                            row_limit=300,
                                            ), file=f)
    elif not os.path.isfile(file_name):
        raise FileNotFoundError(
            f"no saved profile at {file_name}; run once without --no-profile")

    if not args.no_flops and not args.no_memory:
        result = ProfileResult(file_name, num_iters=args.num_iter, cuda=cuda)
        if cuda:
            result.print_total_mem(['CUDA Mem'])
        if 'FLOPs' in result.columns[-1]:
            result.print_flops_per_iter([result.columns[-1]])
        result.print_total_time()

# Run the profiling command-line entry point only when executed as a script.
if __name__ == '__main__':
    main()
