import subprocess as sp
import sys
import gc

from copy import deepcopy

import numpy as np
import torch

import rotor

from hrockmate.solvers.def_op import RunOp

import tensorly as tl
# from tensorly.plugins import use_opt_einsum
# Switch tensorly's backend to PyTorch. This runs at import time as a
# module-level side effect of importing this file.
tl.set_backend('pytorch')

from graph_utils import get_rockmate_graphs, get_autograd_schedule_from_graph, get_twremat_graph, twremat_to_rockmate_schedule

def get_gpu_memory():
    """Return the free memory of every visible GPU as reported by ``nvidia-smi``.

    Runs ``nvidia-smi --query-gpu=memory.free --format=csv``, skips the CSV
    header line and parses the leading integer of each remaining non-blank
    line, one per GPU. The unit is whatever ``nvidia-smi`` prints. Callers in
    this file multiply by 1024*1024 to get bytes, i.e. they treat it as MiB.

    Returns:
        List of ints, one per GPU, in ``nvidia-smi`` output order.

    Raises:
        FileNotFoundError: if ``nvidia-smi`` is not installed.
        subprocess.CalledProcessError: if ``nvidia-smi`` exits with an error.
    """
    command = "nvidia-smi --query-gpu=memory.free --format=csv"
    output = sp.check_output(command.split()).decode('ascii')
    # Line 0 is the CSV header. Skipping blank lines (rather than blindly
    # dropping the last line) keeps the last GPU even if the output has no
    # trailing newline.
    rows = [line for line in output.splitlines()[1:] if line.strip()]
    return [int(row.split()[0]) for row in rows]

def mem_usage():
    """Print and return current and peak CUDA memory statistics, in MiB.

    Peak values are the highest seen since the CUDA peak statistics were last
    reset. They are reset only *after* being read, so consecutive calls
    report the peak of the interval between them. Resetting before reading
    would make each peak equal to the current value.

    Returns:
        Tuple ``(ma, max_ma, ca, max_ca)`` with, in MiB: memory allocated,
        peak memory allocated, memory reserved and peak memory reserved.
    """
    mib = 1024 * 1024

    ma = torch.cuda.memory_allocated() / mib
    max_ma = torch.cuda.max_memory_allocated() / mib
    ca = torch.cuda.memory_reserved() / mib
    max_ca = torch.cuda.max_memory_reserved() / mib

    print(
        f"MA {round(ma, 4)} MB \
        Max_MA {round(max_ma, 4)} MB \
        CA {round(ca, 4)} MB \
        Max_CA {round(max_ca, 4)} MB "
        )

    # Start a fresh peak-measurement window for the next call.
    if hasattr(torch.cuda, "reset_peak_memory_stats"):  # pytorch 1.4+
        torch.cuda.reset_peak_memory_stats()

    return (ma, max_ma, ca, max_ca)



def get_dtype_bytes(dtype):
    """Return the size in bytes of one element of the torch ``dtype``.

    Works for every torch dtype (float32/64/16 as before, plus bfloat16,
    integer, bool and complex types) by asking an empty 0-d tensor of that
    dtype for its ``element_size()``. Previously only the three float types
    were handled and ``None`` was returned otherwise.

    Raises:
        TypeError: if ``dtype`` is not a torch dtype.
    """
    return torch.empty((), dtype=dtype).element_size()

def get_mem_params(model):
    """Return the total size in bytes of the parameters of ``model``.

    Each parameter contributes ``numel() * element_size()``, so every
    dtype is supported (not only float32/64/16). Shared parameters are
    counted once, as ``Module.parameters()`` de-duplicates them.
    """
    return sum(p.numel() * p.element_size() for p in model.parameters())

def get_memory_for_schedule(op_list, alive_list, K_graph):
    """Return the peak memory reached while executing a schedule.

    Args:
        op_list: Scheduled operations. An operation that is a ``RunOp`` also
            adds the ``overhead`` of the compute node with the same ``name``.
        alive_list: Sequence parallel to ``op_list``. Each entry indexes the
            array ``K_graph.list_kdn + [input_kdn_grad, input_kdn_data]`` to
            select the data nodes alive at that step (presumably boolean
            masks — TODO confirm). Iteration stops at the shorter of the two
            sequences.
        K_graph: Graph exposing ``list_kdn``, ``input_kdn_grad``,
            ``input_kdn_data`` and ``list_kcn``.

    Returns:
        The maximum over all steps of the summed ``mem`` of the alive data
        nodes, plus the overhead of the running compute node. Returns 0 for
        an empty schedule instead of raising ``ValueError``.
    """
    data_nodes = np.array(K_graph.list_kdn + [K_graph.input_kdn_grad, K_graph.input_kdn_data])
    overhead_by_name = {cnode.name: cnode.overhead for cnode in K_graph.list_kcn}

    step_mems = []
    # ``alive_mask`` is named so it no longer shadows the ``alive_list`` parameter.
    for alive_mask, op in zip(alive_list, op_list):
        mem = sum(dnode.mem for dnode in data_nodes[alive_mask])
        if isinstance(op, RunOp):
            # A running computation also counts its node's ``overhead``
            # (presumably transient memory used during the op).
            mem += overhead_by_name[op.name]
        step_mems.append(mem)

    return max(step_mems, default=0)
    

def get_max_memory_from_graph(K_graph):
    """Return the peak memory of the default autograd schedule of ``K_graph``.

    The schedule (operations and the matching per-step liveness entries) is
    obtained from ``get_autograd_schedule_from_graph``. Its peak is then
    evaluated by ``get_memory_for_schedule``.
    """
    schedule_ops, schedule_alive = get_autograd_schedule_from_graph(K_graph)
    return get_memory_for_schedule(schedule_ops, schedule_alive, K_graph)

def get_min_memory_from_graph(K_graph):
    """Return a lower-bound estimate of the memory needed to run ``K_graph``.

    For each compute node in ``K_graph.list_kcn`` the estimate is the sum of:

    * the ``mem`` of its real input data nodes (``deps_real``);
    * unless it is the target (the last compute node), the ``mem`` of its
      output data nodes (``users``), plus the largest total ``mem`` of the
      *other* inputs needed by any consumer (``users_real``) of those outputs;
    * the node's own ``overhead``.

    The output and contributor terms are reset for every node, so the target
    never inherits the previous node's values. An output with no consumers
    contributes no extra inputs.

    Returns:
        The maximum per-node estimate, or 0 if the graph has no compute nodes.
    """
    compute_nodes = K_graph.list_kcn
    if not compute_nodes:
        return 0
    target = compute_nodes[-1].unique_id

    mem_dict = {}
    for cnode in compute_nodes:
        ### Memory to store input tensors of the operation
        mem_inputs = sum(dnode.mem for dnode in cnode.deps_real)

        mem_outputs = 0
        mem_contributors = 0
        if cnode.unique_id != target:
            ### Memory to store output tensors of the operation
            mem_outputs = sum(dnode.mem for dnode in cnode.users)

            ### Maximum among memories to store the other input tensors of
            ### the next computations that consume these outputs
            for dnode in cnode.users:
                for next_cnode in dnode.users_real:
                    other_inputs = sum(
                        dn.mem for dn in next_cnode.deps_real if dn not in cnode.users
                    )
                    mem_contributors = max(mem_contributors, other_inputs)

        mem_dict[cnode.unique_id] = mem_inputs + mem_outputs + mem_contributors + cnode.overhead

    return max(mem_dict.values())

        

def get_memory_limits_from_graph(model, sample, intervals):
    """Return ``intervals`` evenly spaced memory budgets (bytes) for ``model``.

    Both bounds are computed on the rk-GB graph built by
    ``get_rockmate_graphs(model, sample)``:

    * upper bound: ``get_max_memory_from_graph`` (peak of the default
      autograd schedule), capped at 87% of the free memory of the first GPU
      reported by ``get_gpu_memory`` (scaled by 1024*1024, i.e. treated as
      MiB);
    * lower bound: ``get_min_memory_from_graph``.

    The budgets are ``max - k * ((max - min) // intervals)`` for
    ``k in range(intervals)``, so the lower bound itself is not included.
    The list is also printed in MiB.

    Args:
        model: Model passed to ``get_rockmate_graphs``.
        sample: Example input passed to ``get_rockmate_graphs``.
        intervals: Number of budgets; used as a divisor, so must be non-zero.

    Returns:
        List of ``intervals`` budgets, largest first when the upper bound
        exceeds the lower bound.
    """

    K_graph = get_rockmate_graphs(model, sample)
    # Return cached allocator blocks and collect garbage left over from graph
    # construction, so the free-memory query below sees them as free.
    torch.cuda.empty_cache()
    gc.collect()

    min_memory = get_min_memory_from_graph(K_graph)
    max_memory = get_max_memory_from_graph(K_graph)

    gpu_budget = get_gpu_memory()[0]
    # Cap the upper bound at 87% of the first GPU's free memory, in bytes.
    max_memory = min(max_memory, int(gpu_budget*0.87)*1024*1024)

    # NOTE(review): ZeroDivisionError when intervals == 0. If the GPU cap
    # pushes max_memory below min_memory, delta is negative and the budgets
    # increase instead — TODO decide how to handle.
    delta_memory = (max_memory - min_memory)//intervals
    
    mem_limits = [(max_memory - delta_memory*k) for k in range(intervals)]
    print(f'Memory limits {[mem/1024**2 for mem in mem_limits]}')

    # Release the graph's fake input gradient (presumably a placeholder
    # tensor created during graph building — TODO confirm) before dropping
    # the graph.
    K_graph.release_fake_input_kdn_grad()

    del K_graph
    torch.cuda.empty_cache()
    gc.collect()

    return mem_limits


def get_memory_limits_with_twremat(model, sample, intervals=4):
    """Return ``intervals`` memory budgets (bytes) bounded by a twremat schedule.

    The upper bound is the peak memory (``get_memory_for_schedule``) of the
    schedule twremat produces for the rk-GB graph of ``model``. It is capped
    at 87% of the free memory of the first GPU reported by
    ``get_gpu_memory`` (scaled by 1024*1024, i.e. treated as MiB). The lower
    bound is ``get_min_memory_from_graph``. The budgets are
    ``max - k * ((max - min) // intervals)`` for ``k in range(intervals)``
    and are also printed in MiB.

    Args:
        model: Model passed to ``get_rockmate_graphs``.
        sample: Example input passed to ``get_rockmate_graphs``.
        intervals: Number of budgets. Defaults to 4, the previously
            hard-coded value. Used as a divisor, so it must be non-zero.

    Returns:
        List of ``intervals`` budgets.
    """
    # Get graphs with rk-GB
    K_graph = get_rockmate_graphs(model, sample)
    torch.cuda.empty_cache()
    gc.collect()

    # Generate input graph for twremat
    node_info, target, _ = get_twremat_graph(K_graph, contains_data_node=False)

    # Call twremat to generate a schedule, with the first GPU's free memory
    # (scaled to bytes) as its budget.
    # NOTE(review): ``twremat_caller`` is not imported in this file's visible
    # imports, so this line raises NameError unless the name is provided
    # elsewhere — TODO add the proper import.
    steps = twremat_caller.runtwremat(node_info, get_gpu_memory()[0]*1024**2, target, [])

    # Convert to the schedule compatible with rk-Exec
    _, sched = twremat_to_rockmate_schedule(K_graph, steps)

    max_memory = get_memory_for_schedule(sched.op_list, sched.alive_list, K_graph)
    gpu_budget = get_gpu_memory()[0]
    max_memory = min(max_memory, int(gpu_budget*0.87*1024*1024))
    min_memory = get_min_memory_from_graph(K_graph)

    delta_memory = (max_memory - min_memory)//intervals

    mem_limits = [(max_memory - delta_memory*k) for k in range(intervals)]
    print(f'Memory limits {[mem/1024**2 for mem in mem_limits]}')

    # Release the graph and schedule the same way get_memory_limits_from_graph
    # does, so repeated calls do not keep GPU memory alive.
    K_graph.release_fake_input_kdn_grad()
    del sched
    del K_graph
    gc.collect()
    torch.cuda.empty_cache()

    return mem_limits


def get_memory_limits_with_rotor(model, sample, device='cuda', intervals=4):
    """Return ``intervals`` memory budgets (bytes) derived from rotor's estimates.

    A deep copy of ``model`` and a clone of ``sample`` (with
    ``requires_grad`` enabled) are moved to ``device`` and wrapped in
    ``rotor.Checkpointable``.

    * Upper bound: the expected memory after ``compute_pytorch_sequence()``,
      capped at 87% of the free memory of the first GPU reported by
      ``get_gpu_memory`` (scaled by 1024*1024, i.e. treated as MiB).
    * Lower bound: the expected memory after ``compute_min_sequence()``.

    The budgets are ``max - k * ((max - min) // intervals)`` for
    ``k in range(intervals)`` and are also printed in MiB.

    Args:
        model: Module to analyse; only a deep copy is modified.
        sample: Example input; only a clone is used.
        device: Device for the copied model and sample.
        intervals: Number of budgets. Defaults to 4, the previously
            hard-coded value. Used as a divisor, so it must be non-zero.

    Returns:
        List of ``intervals`` budgets.
    """
    _model = deepcopy(model).to(device)
    # Give every parameter an explicit zero gradient, presumably so gradient
    # buffers exist while rotor profiles the model — TODO confirm.
    for param in _model.parameters():
        if param.grad is None:
            param.grad = torch.zeros_like(param)
    _sample = sample.clone().to(device).requires_grad_()

    gpu_budget = get_gpu_memory()[0]

    checkp_model = rotor.Checkpointable(_model, _sample, verbosity=0)

    checkp_model.compute_pytorch_sequence()
    max_memory = checkp_model.get_expected_memory()
    max_memory = min(max_memory, int(gpu_budget*0.87)*1024*1024)

    checkp_model.compute_min_sequence()
    min_memory = checkp_model.get_expected_memory()

    delta_memory = (max_memory - min_memory)//intervals

    mem_limits = [(max_memory - delta_memory*k) for k in range(intervals)]
    print(f'Memory limits {[mem/1024**2 for mem in mem_limits]}')

    del _model
    del checkp_model
    del _sample
    # Collect any cycles first, then hand the freed blocks back to the
    # device (empty_cache is a no-op if CUDA was never initialised).
    gc.collect()
    torch.cuda.empty_cache()

    return mem_limits
