import subprocess as sp
import os
import sys

#sys.path.append(os.path.expanduser('~')+'/hrockmate/')
sys.path.append('../hrockmate/')
savepath='measure_result'

import gc
import argparse
import pandas as pd
import datetime
from copy import deepcopy
import importlib

import time
import numpy as np
import torch
import torch.nn as nn

from models.FNO1d import FNO1d
from models.FNO3d import FNO3d
from models.GPT import get_GPT
from models.UNO import Uno3D_T40
from models.UFNO import Net3d
from models.unet import UNet

import neuralop
import hrockmate

import rotor
from rotor import timing

import tensorly as tl
# from tensorly.plugins import use_opt_einsum
tl.set_backend('pytorch')

from argparse import Namespace
from configmypy import ConfigPipeline, YamlConfig, ArgparseConfig
from configmypy.utils import iter_nested_dict_flat
from memory_utils import *

# Default device for the benchmark: the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  


def get_args():
    """Parse the benchmark's command-line options.

    Option names are dotted (``--fno.width``, ``--remat.algo``, ...), so
    argparse stores them on the namespace under the dotted string itself.

    Returns:
        argparse.Namespace: the parsed command-line arguments.
    """
    # Architectures handled explicitly by get_input()/get_model().  'tfno2d' is
    # the spelling those functions test for; the older 'tfno2' spelling is kept
    # so existing command lines are still accepted by the parser.
    custom_archs = ['fno1d', 'fno3d', 'tfno2', 'tfno2d', 'unet', 'uno3d_t40', 'yolop',
                    'gpt2', 'transformer', 'ufno', 'mlpmixer']
    # Every other architecture is looked up by name in rotor.models.
    rotor_archs = (rotor.models.resnet.__all__ + rotor.models.densenet.__all__
                   + rotor.models.inception.__all__ + rotor.models.vgg.__all__)

    parser = argparse.ArgumentParser()
    parser.add_argument('--arch', type=str, default='resnet18',
                        choices=custom_archs + rotor_archs)
    parser.add_argument('--verbose', type=int, default=1)

    parser.add_argument('--fno.modes1', type=int, help='Number of Fourier modes to multiply, at most floor(N/2) + 1', default=16)
    parser.add_argument('--fno.modes2', type=int, help='Number of Fourier modes to multiply, at most floor(N/2) + 1', default=16)
    parser.add_argument('--fno.modes3', type=int, help='Number of Fourier modes to multiply, at most floor(N/2) + 1', default=16)
    parser.add_argument('--fno.width', type=int, help='Number of channels', default=64)
    parser.add_argument('--fno.block_number', type=int, help='Number of integral kernel block', default=4)

    parser.add_argument('--tfno2d.n_layers', type=int, help='Number of FNO layers')
    parser.add_argument('--tfno2d.implementation', type=str, default='reconstructed', choices=['reconstructed', 'factorized'])
    parser.add_argument('--tfno2d.factorization', type=str, default=None, choices=[None, 'tucker', 'cp', 'tt'],
        help="Tensor factorization of the parameters weight to use (default: None)")
    parser.add_argument('--tfno2d.rank', type=float, default=1.0, help='Rank of the tensor factorization of the Fourier weights')
    parser.add_argument('--tfno2d.separable', type=int, default=0)

    parser.add_argument('--gpt2.model', type=str, default='GPT2-small', choices=['GPT2-small', 'GPT2-medium', 'GPT2-large'])
    parser.add_argument('--gpt2.dropout', type=float, default=0.1)

    parser.add_argument('--transformer.num_decoder_layers', type=int, default=6, help='the number of sub-decoder-layers in the decoder')
    parser.add_argument('--transformer.num_encoder_layers', type=int, default=6)
    parser.add_argument('--transformer.n_head', type=int, default=16, help="the number of heads in the multiheadattention models")
    parser.add_argument('--transformer.d_model', type=int, default=512, help=" the number of expected features in the encoder/decoder inputs")
    parser.add_argument('--transformer.target_sequence_length', type=int, default=100, help='the target sequence length')
    parser.add_argument('--transformer.source_sequence_length', type=int, default=100, help='the source sequence length')

    parser.add_argument('--remat.memory_budget', type=int, help='Memory budget')
    parser.add_argument('--remat.budget_intervals', type=int, default=12, help='number of partitions for Memory budget')
    # parser.add_argument('--remat.HILP.nb_total_nodes', type=int, help='nb_total_nodes in the subgraph')
    parser.add_argument('--remat.algo', default='rotor', type=str, choices=['rotor', 'rockmate', 'twremat', 'hilp', 'checkmate', 'twremat+hilp',
                                                                            'hilp+rotor', 'twremat+hilp+rotor'])
    parser.add_argument('--remat.fix_memory_limit', action='store_true')
    parser.add_argument('--remat.train_iterations', type=int, default=15)
    # store_false: the value defaults to True and passing the flag turns it off.
    parser.add_argument('--remat.twremat.contains_data_node', action='store_false')

    parser.add_argument('--remat.hremat.max_size_S_graph_for_no_partitioning', type=int, default=40)
    parser.add_argument('--remat.hremat.partitioner_bottom_to_top.max_estimate_per_sub_graph', type=int, default=20)
    parser.add_argument('--remat.hremat.partitioner_bottom_to_top.value_power_total_size', type=float, default=0.5)

    #parser.add_argument('--data.subsampling_rate', type=int, help='subsampling_rate', default=3)
    parser.add_argument('--data.resolution_x', type=int, help='Grid size for X axis', default=64) # resolution = 2**13 // subsampling_rate
    parser.add_argument('--data.resolution_y', type=int, help='Grid size for Y axis', default=64)
    parser.add_argument('--data.resolution_z', type=int, help='Grid size for Z axis', default=40)
    parser.add_argument('--data.batch_size', type=int, default=20) # default = 10 for FNO3d
    parser.add_argument('--data.in_channels', type=int, default=1) # default = 10 for FNO3d

    parser.add_argument('--data.resolution', type=int, help='UNet image resolution', default=256)

    parser.add_argument('--data.train_resolution', type=int, help='Resolution of training data') #TFNO

    parser.add_argument('--model_dir', type=str)

    return parser.parse_args()

class NestedNamespace(Namespace):
    """Namespace whose dotted attribute names create nested namespaces.

    Setting ``'a.b.c'`` stores the value as ``self.a.b.c``, creating the
    intermediate namespaces (of the same type) when they do not exist yet.
    """

    def __setattr__(self, name, value):
        # Plain names behave exactly like a regular Namespace attribute.
        if '.' not in name:
            super().__setattr__(name, value)
            return

        # Peel off the first component and delegate the remainder to the child,
        # which recurses through this same method.
        head, _, tail = name.partition('.')
        if not hasattr(self, head):
            setattr(self, head, type(self)())
        setattr(getattr(self, head), tail, value)

    def dict(self):
        """Return the namespace as a (recursively) nested plain dictionary."""
        return {
            key: (val.dict() if isinstance(val, NestedNamespace) else val)
            for key, val in self.__dict__.items()
        }


def get_input(args):
    """Build a random sample input matching the architecture in ``args.arch``.

    Returns:
        A random integer token tensor for ``'gpt2'``, a ``[src, tgt]`` list of
        float tensors for ``'transformer'``, and a single ``torch.rand`` tensor
        for every other architecture (unknown names get a 224x224 RGB batch).
    """
    arch = args.arch

    # Token-id input: the sequence length depends on the GPT-2 variant.
    if arch == 'gpt2':
        seq_len = 256 if args.gpt2.model == 'GPT2-large' else 500
        tokens = torch.randint(0, 600, [seq_len, args.data.batch_size])
        print(tokens.shape)
        return tokens

    # Encoder/decoder input pair, each shaped (sequence, batch, d_model).
    if arch == 'transformer':
        cfg = args.transformer
        src = torch.rand((cfg.source_sequence_length, args.data.batch_size, cfg.d_model))
        tgt = torch.rand((cfg.target_sequence_length, args.data.batch_size, cfg.d_model))
        return [src, tgt]

    # Everything else is one dense float tensor; only the shape differs.
    batch = args.data.batch_size
    if arch == 'fno3d':
        grid = (args.data.resolution_x, args.data.resolution_y, args.data.resolution_z)
        shape = (batch, *grid, args.data.in_channels)
    elif arch == 'fno1d':
        shape = (batch, args.data.resolution_x, args.data.in_channels)
    elif arch == 'tfno2d':
        side = args.data.train_resolution
        shape = (batch, 3, side, side)
    elif arch == 'unet':
        shape = (batch, 3, args.data.resolution, args.data.resolution)
    elif arch == 'uno3d_t40':
        spatial, steps_in = 64, 10
        shape = (batch, spatial, spatial, steps_in, 1)
    elif arch == 'ufno':
        shape = (batch, 96, 200, 24, 12)
    elif arch in ('yolop', 'mlpmixer'):
        shape = (batch, 3, 256, 256)
    elif arch in ('Inception3', 'inception_v3'):
        assert batch >= 2
        shape = (batch, 3, 299, 299)
    else:
        shape = (batch, 3, 224, 224)

    return torch.rand(*shape)

def get_model(args):
    """Instantiate the network selected by ``args.arch``.

    Architectures without an explicit branch are looked up by name in
    ``rotor.models`` and constructed with no arguments.

    Args:
        args: configuration object exposing ``arch`` plus the per-architecture
            sub-configs read below (``fno``, ``gpt2``, ``transformer``).

    Returns:
        The freshly constructed model (not moved to any device).
    """
    import warnings

    if args.arch == 'fno3d':
        # The original built an AssertionError here without raising it, so the
        # message was never shown.  Emit it as a warning instead, and only
        # query the device name when CUDA exists (the call raises otherwise).
        if torch.cuda.is_available() and 'V100' in torch.cuda.get_device_name():
            warnings.warn('FNO3d has problems with running on V100, try A100 instead')

        model = FNO3d(
            args.fno.modes1,
            args.fno.modes2,
            args.fno.modes3,
            args.fno.width,
            block_number=args.fno.block_number
            )

    elif args.arch == 'fno1d':
        model = FNO1d(
            args.fno.modes1,
            args.fno.width,
            block_number=args.fno.block_number
            )

    elif args.arch == 'tfno2d':
        model = neuralop.get_model(args)

    elif args.arch == 'unet':
        model = UNet()

    elif args.arch == 'uno3d_t40':
        model = Uno3D_T40(in_width = 6, width = 8)

    elif args.arch == 'yolop':
        # Known-broken path: surface the caveat instead of silently dropping it.
        warnings.warn('Does not work correctly yet, some bugs with inputs')
        # The original imported `get_netget_checkp_model` (two names fused
        # together) while calling `get_net`; import the name actually used.
        from lib.models.YOLOP import get_net
        model = get_net(False)

    elif args.arch in ['Inception3', 'inception_v3']:
        model = getattr(rotor.models, args.arch)(transform_input = True, aux_logits=False)

    elif args.arch == 'gpt2':
        model = get_GPT(args.gpt2.model)

    elif args.arch == 'ufno':
        mode1 = 10
        mode2 = 10
        mode3 = 10
        width = 36
        model = Net3d(mode1, mode2, mode3, width)

    elif args.arch == 'transformer':
        model = nn.Transformer(args.transformer.d_model, args.transformer.n_head, args.transformer.num_encoder_layers, args.transformer.num_decoder_layers)

    elif args.arch == 'mlpmixer':
        from mlp_mixer_pytorch import MLPMixer

        model = MLPMixer(
            image_size = 256,
            channels = 3,
            patch_size = 16,
            dim = 512,
            depth = 12,
            num_classes = 1000
        )
    else:
        model = getattr(rotor.models, args.arch)()

    return model


def run_rk_model(model, sample, budgets, algos:str, device='cuda', verbose=0,result_columns=[], nbar=10, nall=10):
    """Benchmark ``hrockmate.Rockmate`` at each budget, then the plain PyTorch model.

    Args:
        model: module to benchmark; a deep copy is moved to ``device``.
        sample: input tensor, or list of input tensors, copied to ``device``.
        budgets: memory budgets; the Rockmate module is built with the largest.
        algos: not used by this function.
        device: device the copies of the model and sample are moved to.
        verbose: not used by this function.
        result_columns: keys pre-populated (with ``None``) in every result dict.
        nbar: forwarded to Rockmate as ``nb_budget_abar``.
        nall: forwarded to Rockmate as ``nb_budget_all``.

    Returns:
        list[dict]: one result per budget followed by one for the unmodified
        model.  A failed run stores its exception under ``"error"``.
    """
    # Work on device-resident copies so the caller's objects stay untouched.
    _model = deepcopy(model).to(device)
    if isinstance(sample, list):
        _sample = [deepcopy(tensor).to(device) for tensor in sample]
    else:
        _sample = deepcopy(sample).to(device)

    # The build time is reported as "solve_time" for every budget.
    t0 = time.time()
    rkMod = hrockmate.Rockmate(_model, _sample, max(budgets), nb_budget_abar=nbar, nb_budget_all=nall)
    solve_time = time.time() - t0

    results = []
    for budget in budgets:
        print(f'budget -- {budget}')

        result = dict.fromkeys(result_columns)
        try:
            rkMod.get_sequence(budget)
            rkMod.get_compiled_fct()
            peak_mem, times = exec_rkmod(rkMod, _sample)
            # The first five iterations are excluded from the average.
            result["average_time"] = np.mean(times[5:])
            result["times"] = times
            result["peak_mem"] = peak_mem
            result["budget"] = budget
            result["solve_time"] = solve_time
        except Exception as e:
            result["error"] = e
        results.append(result)

    # Baseline: the original module without rematerialization.
    try:
        print('run pytorch model')

        result = dict.fromkeys(result_columns)

        original_mod = deepcopy(rkMod.original_mod)
        del rkMod
        torch.cuda.empty_cache()

        peak_mem, times = exec_torch_mod(original_mod, _sample)
        result["average_time"] = np.mean(times[5:])
        result["times"] = times
        result["peak_mem"] = peak_mem
    except Exception as e:
        result["error"] = e
    results.append(result)

    return results
        
    
def run_H_rk_model(model:torch.nn.Module, sample, budgets, algos:str, result_columns=None, **dict_kwargs):
    """Benchmark ``hrockmate.HRockmate`` at each budget, then the plain PyTorch model.

    Args:
        model: module to benchmark; a deep copy is moved to the module-level
            ``device``.  Missing parameter gradients on ``model`` itself are
            filled with zeros first.
        sample: input tensor, or list of input tensors, copied to ``device``.
        budgets: memory budgets; the HRockmate module is built with the largest.
        algos: solver selection such as ``'hilp'`` or ``'twremat+hilp+rotor'``
            (matched case-insensitively); ``'checkmate'`` and ``'rockmate'``
            select dedicated partitioning setups.
        result_columns: keys pre-populated (with ``None``) in every result
            dict.  Defaults to no keys.
        **dict_kwargs: recognised keys are ``repeat`` (number of timed
            iterations per run, default 15) and ``args`` (configuration object
            providing ``remat.hremat``).  When ``args`` is not given, the
            module-level ``args`` set by the script entry point is used.  Other
            keys are accepted and ignored.

    Returns:
        list[dict]: one result per budget followed by one for the unmodified
        model.  A failed run stores its exception under ``"error"``.
    """
    print('run h_rk_model')

    result_columns = [] if result_columns is None else list(result_columns)
    algo_key = algos.lower()

    # The caller passes the iteration count as `repeat`; it used to be dropped,
    # so every run silently used exec_*'s default of 15.
    repeat = dict_kwargs.get('repeat')
    if repeat is None:
        repeat = 15
    config = dict_kwargs.get('args')

    for _, p in model.named_parameters():
        if p.grad is None:
            p.grad = torch.zeros_like(p)

    #Build HRemat module with given remat solvers
    _model = deepcopy(model).to(device)

    solvers = {"hilp":hrockmate.solvers.HILP(), "rotor": hrockmate.solvers.RK_rotor(), "twremat": hrockmate.solvers.TwRemat()}
    list_solver = [solver for solver_name, solver in solvers.items() if solver_name in algo_key]

    if isinstance(sample, list):
        _sample = [deepcopy(tensor).to(device) for tensor in sample]
    else:
        _sample = deepcopy(sample).to(device)

    print('create HRemat')
    start = time.time()
    ptools = hrockmate.rkgb.Ptools
    if algo_key == 'checkmate':
        max_size_S_graph_for_no_partitioning = 999 # 60
        partitioners = None

    elif algo_key == 'rockmate':
        max_size_S_graph_for_no_partitioning = 0
        partitioners = [ptools.Partitioner_seq(sub_partitioner=ptools.Partitioner())]

    else:
        # Explicit `args` wins; otherwise fall back to the module-level `args`
        # (only defined when this file runs as a script).
        hremat = (config if config is not None else args).remat.hremat
        bottom_to_top = hremat.partitioner_bottom_to_top
        max_size_S_graph_for_no_partitioning = hremat.max_size_S_graph_for_no_partitioning

        if 'rotor' not in algo_key: # hilp, twremat+hilp
            partitioners = [ptools.Partitioner_bottom_to_top(
                can_use_rotor=False,
                value_power_total_size=bottom_to_top.value_power_total_size,
                max_estimate_per_sub_graph=bottom_to_top.max_estimate_per_sub_graph)]
        else: # twremat+hilp+rotor
            partitioners = [
                ptools.Partitioner_bottom_to_top(
                    can_use_rotor=True,
                    value_power_total_size=bottom_to_top.value_power_total_size,
                    max_estimate_per_sub_graph=bottom_to_top.max_estimate_per_sub_graph),
                ptools.Partitioner_seq(
                    ptools.Partitioner_bottom_to_top(
                        can_use_rotor=True,
                        value_power_total_size=bottom_to_top.value_power_total_size,
                        main_graph_as_any_other=True,
                        max_estimate_per_sub_graph=bottom_to_top.max_estimate_per_sub_graph))]

    rkMod = hrockmate.HRockmate(
            _model, _sample, max(budgets),
            list_solvers=list_solver,
            partitioners=partitioners,
            max_size_S_graph_for_no_partitioning=max_size_S_graph_for_no_partitioning
        )

    # The build time is reported as "solve_time" for every budget.
    solve_time = time.time() - start
    print('finish HRemat')

    results = []
    for budget in budgets:
        print(f'budget---{budget}')

        result = dict.fromkeys(result_columns)
        try:
            rkMod.solve_sched(budget, rec=False)
            rkMod.get_compiled_fct()
            peak_mem, times = exec_rkmod(rkMod, _sample, repeat=repeat)
            # The first five iterations are excluded from the average.
            result["average_time"] = np.mean(times[5:])
            result["times"] = times
            result["peak_mem"] = peak_mem
            result["budget"] = budget
            result["solve_time"] = solve_time
        except Exception as e:
            result["error"] = e

        results.append(result)

        # Drop the schedule and cached blocks before trying the next budget.
        rkMod.op_sched = None
        torch.cuda.empty_cache()
        gc.collect()

    # Baseline: the original module without rematerialization.
    original_mod = deepcopy(rkMod.original_mod)
    del rkMod
    del _model

    torch.cuda.empty_cache()

    result = dict.fromkeys(result_columns)
    try:
        peak_mem, times = exec_torch_mod(original_mod, _sample, repeat=repeat)
        result["average_time"] = np.mean(times[5:])
        result["times"] = times
        result["peak_mem"] = peak_mem
    except Exception as e:
        result["error"] = e

    del original_mod
    del _sample
    torch.cuda.empty_cache()
    gc.collect()

    results.append(result)

    return results


def exec_rkmod(rkMod, inputs, repeat=15, dict_kwargs=None):
    """Re-initialise ``rkMod`` and measure it with ``exec_mod``.

    The keyword inputs are built by ``hrockmate.rkgb.make_inputs`` from the
    wrapped original module and a deep copy of ``inputs``.

    Returns:
        The ``(peak_mem, times)`` pair produced by ``exec_mod``.
    """
    rkMod.reinit()
    make_inputs = hrockmate.rkgb.make_inputs
    call_kwargs = make_inputs(rkMod.original_mod, deepcopy(inputs), dict_kwargs)
    return exec_mod(rkMod, call_kwargs, repeat)

def exec_torch_mod(module, inputs, repeat=15, dict_kwargs=None):
    """Measure a plain module with ``exec_mod``.

    The keyword inputs are built by ``hrockmate.rkgb.make_inputs`` from
    ``module`` and a deep copy of ``inputs``.

    Returns:
        The ``(peak_mem, times)`` pair produced by ``exec_mod``.
    """
    make_inputs = hrockmate.rkgb.make_inputs
    call_kwargs = make_inputs(module, deepcopy(inputs), dict_kwargs)
    return exec_mod(module, call_kwargs, repeat)

def exec_mod(module, dict_inputs, repeat=15):
    """Time forward/backward passes of ``module`` and record peak CUDA memory.

    Each iteration reseeds torch's RNG, calls ``module(**dict_inputs)``, takes
    the mean of the output as the loss and back-propagates it.

    Args:
        module: module called with ``dict_inputs`` as keyword arguments; its
            output must support ``.mean()``.
        dict_inputs: keyword arguments for the forward call.
        repeat: number of timed iterations.

    Returns:
        tuple: ``(peak_mem, times)``.  ``peak_mem`` is the increase of
        ``torch.cuda.max_memory_allocated()`` over the run, or ``None`` when
        the run hit a CUDA out-of-memory error; ``times`` holds the timer
        reading of every completed iteration.

    Raises:
        Exception: any error other than CUDA out-of-memory is printed and
            re-raised unchanged.
    """
    # Give every parameter a gradient buffer before measuring -- presumably so
    # their allocation is not counted in the peak (TODO confirm).
    for _, p in module.named_parameters():
        if p.grad is None:
            p.grad = torch.zeros_like(p)

    times = []
    timer = timing.make_timer(device)
    try:
        torch.cuda.reset_peak_memory_stats()
        max_before = torch.cuda.max_memory_allocated()
        for _ in range(repeat):
            timer.start()
            torch.random.manual_seed(0)
            y = module(**dict_inputs)
            loss = y.mean()
            loss.backward()

            timer.end()
            times.append(timer.elapsed())

            del y
            del loss

        peak_mem = torch.cuda.max_memory_allocated() - max_before
        print(f"*********************CUDA measured peak_mem is {peak_mem}")

    except Exception as e:
        print(e)
        # Only out-of-memory is an expected outcome of a too-small budget;
        # anything else propagates with its original traceback.
        if not isinstance(e, torch.cuda.OutOfMemoryError):
            raise
        peak_mem = None

    del dict_inputs
    torch.cuda.empty_cache()

    return peak_mem, times

def main(args):
    """Run the rematerialization benchmark described by ``args`` and save a CSV.

    Builds the model and a sample input, picks the memory budgets, runs
    ``run_H_rk_model`` and writes one row per result (flattened configuration
    values followed by the measurements) to a file under ``savepath``.

    Returns:
        list[list]: the rows written to the CSV file.
    """
    # Flatten the nested configuration into parallel column/value lists.
    flat_config = dict(iter_nested_dict_flat(args)) # args.dict() is a nested dict
    config_columns = list(flat_config)
    config_values = [flat_config[column] for column in config_columns]

    model = get_model(args)
    sample = get_input(args)

    if not args.remat.memory_budget:
        # No explicit budget: derive a range of budgets from the model graph.
        mem_limits = get_memory_limits_from_graph(model, sample, args.remat.budget_intervals)

        torch.cuda.empty_cache()
        gc.collect()
    else:
        # The explicit budget is scaled by 1024**2.
        mem_limits = [args.remat.memory_budget*1024**2]

    result_columns = ['average_time', 'peak_mem', 'budget', 'sequence', 'times', 'error', 'solve_time']

    partial_results = run_H_rk_model(model, sample, mem_limits, result_columns=result_columns, algos=args.remat.algo, repeat=args.remat.train_iterations, device=device,  verbose=args.verbose)

    # Each row is the configuration values followed by the measured values.
    results = [config_values + list(partial.values()) for partial in partial_results]

    torch.cuda.empty_cache()
    gc.collect()

    df = pd.DataFrame(results, columns = config_columns + result_columns)
    print(df)

    os.makedirs(savepath, exist_ok=True)

    current_time = datetime.datetime.now()

    # The file name encodes the algorithm, the model and a timestamp.
    if 'fno' in args.arch:
        csv_path = f'{savepath}/{args.remat.algo}-{args.arch}-{args.data.batch_size}-{args.fno.block_number}-{current_time}.csv'
    elif 'gpt2' in args.arch:
        csv_path = f'{savepath}/{args.remat.algo}-{args.gpt2.model}-{args.data.batch_size}-{current_time}.csv'
    else:
        csv_path = f'{savepath}/{args.remat.algo}-{args.arch}-{args.data.batch_size}-{current_time}.csv'
    df.to_csv(csv_path, index=False)

    del model
    del sample
    torch.cuda.empty_cache()

    return results



if __name__=="__main__":
    # Script entry point: layer the YAML defaults and the command-line options
    # into one configuration object, then run the benchmark.

    remat_config = YamlConfig('./remat_config.yaml', config_name='default', config_folder='./config')

    # First pass over the command line, only used to pick the model config file.
    args = get_args()

    # Architectures with a dedicated config file; any name containing 'fno'
    # shares the FNO file and everything else uses the ResNet one.
    config_by_arch = {
        'resnet18': './resnet_config.yaml',
        'transformer': './transformer_config.yaml',
        'gpt2': './gpt2_config.yaml',
        'unet': './unet_config.yaml',
    }
    if args.arch in config_by_arch:
        config_file = config_by_arch[args.arch]
    elif 'fno' in args.arch:
        config_file = './fno_config.yaml'
    else:
        config_file = './resnet_config.yaml'

    model_config = YamlConfig(config_file, config_name='default', config_folder='./config')
    if args.arch == 'gpt2':
        print(model_config)

    pipe = ConfigPipeline([remat_config, model_config ,ArgparseConfig(),])
    args = pipe.read_conf()
    args.model_dir = os.path.expanduser('~')
    print(args)
    main(args)
