import os
import sys

from copy import deepcopy
import numpy as np
import torch

from hrockmate import rkgb
from hrockmate.solvers.def_op import RunOp, DelOp, OpSchedule

import tensorly as tl
# from tensorly.plugins import use_opt_einsum
# Import-time side effect: selects PyTorch as tensorly's backend for all code
# that uses tensorly after this module has been imported.
tl.set_backend('pytorch')



def get_rockmate_graphs(model, sample, device='cuda'):
    """Build the rk-GB K-graph of ``model`` traced on ``sample``.

    ``model`` and ``sample`` are deep-copied and the copies are moved to
    ``device``, so the caller's objects are not modified. The copies are
    released and ``torch.cuda.empty_cache()`` is called whether or not graph
    construction succeeds.

    Parameters
    ----------
    model
        Model to trace (anything supporting ``deepcopy`` and ``.to(device)``).
    sample
        Example input passed to ``rkgb.make_all_graphs``. A ``list`` is copied
        element-wise (each element must support ``.to(device)``); any other
        object is copied as a whole.
    device : str or torch.device, optional
        Device the copies are moved to (default ``'cuda'``).

    Returns
    -------
    K_graph
        ``rkgb_res.K_graph``, after ``fake_input_kdn_grad()`` was called on it.

    Raises
    ------
    Exception
        If building the K-graph fails. The original error is printed and
        chained as ``__cause__``.
    """
    _model = deepcopy(model).to(device)
    if isinstance(sample, list):
        _sample = [deepcopy(inp).to(device) for inp in sample]
    else:
        _sample = deepcopy(sample).to(device)

    # Get graphs with rk-GB
    try:
        rkgb_res = rkgb.make_all_graphs(_model, _sample, wanted_graphs=['K'])
        K_graph = rkgb_res.K_graph
        K_graph.fake_input_kdn_grad()
    except Exception as e:
        print(e)
        raise Exception(
            "K-graph can't be built on one GPU for given model hyperparameters"
        ) from e
    finally:
        # Drop the device copies on both paths (an in-flight exception's
        # traceback would otherwise keep this frame's locals alive).
        del _model
        del _sample
        torch.cuda.empty_cache()

    return K_graph

def get_autograd_schedule_from_graph(K_graph):
    """Build the plain-autograd schedule of ``K_graph``.

    Every computation node of ``K_graph.list_kcn`` is run in list order. After
    each run, every alive data node that no later computation reads through
    ``deps_real`` is deleted. ``deps_fake`` is not treated as a read here. The
    input data node is never deleted.

    Parameters
    ----------
    K_graph : rkgb K_graph
        Graph whose ``list_kcn`` order is used as the execution order.

    Returns
    -------
    op_list : list
        Interleaved ``RunOp`` / ``DelOp`` operations.
    alive : numpy.ndarray of bool
        One row per entry of ``op_list``: the liveness of every data node
        right after that operation. Columns follow
        ``K_graph.list_kdn + [input_kdn_grad, input_kdn_data]``.
    """
    data_nodes = K_graph.list_kdn + [K_graph.input_kdn_grad, K_graph.input_kdn_data]
    # unique_id -> column index in the liveness matrix
    kdn_id_idx_mapping = {dn.unique_id: idx for idx, dn in enumerate(data_nodes)}
    input_id = K_graph.input_kdn_data.unique_id

    # Position of the last computation reading each data node (via deps_real).
    # A data node is still needed after step i iff it is read at some j > i.
    # Precomputing this avoids rescanning the remaining nodes at every step.
    last_read = {}
    for pos, cnode in enumerate(K_graph.list_kcn):
        for dnode in cnode.deps_real:
            last_read[dnode.unique_id] = pos

    alive_status = np.zeros(len(data_nodes), dtype=bool)
    alive_status[kdn_id_idx_mapping[input_id]] = True
    alive_list = []
    op_list = []

    for i, cnode in enumerate(K_graph.list_kcn):
        # Outputs of the computation become alive
        for dnode in cnode.users:
            alive_status[kdn_id_idx_mapping[dnode.unique_id]] = True
        alive_list.append(alive_status.copy())

        op = RunOp(cnode)
        op.no_grad = False
        op_list.append(op)

        # Free every alive data node that no later computation reads
        for dn_id, idx in kdn_id_idx_mapping.items():
            still_needed = dn_id == input_id or last_read.get(dn_id, -1) > i
            if alive_status[idx] and not still_needed:
                alive_status[idx] = False
                alive_list.append(alive_status.copy())

                op = DelOp(data_nodes[idx])
                op.proxy = True
                op_list.append(op)

    return op_list, np.array(alive_list, dtype=bool)

def get_twremat_graph(K_graph, contains_data_node=False, allow_loss_recomputation=False):
    '''Build the computation-only graph given as input to TWRemat.

    For each computation node ``cnode`` of ``K_graph.list_kcn``, collects:

    - memory of its input data nodes (``deps_real`` and ``deps_fake``), ``mem_inputs``
    - memory of its output data nodes (``users``), ``mem_outputs``; this is 0
      for the final node
    - its computational overhead, ``cnode.overhead`` (``0`` when falsy)
    - the ids of the computations it depends on, ``deps``. These come from
      ``deps_through_size_artefacts`` plus the producers (``dnode.deps``) of
      every input data node; duplicates are kept.

    Parameters
    ----------
    K_graph : rkgb.K_graph
        A data-flow graph obtained with rk-GB. It has two types of nodes:
        data nodes and computation nodes.
    contains_data_node : bool, optional
        Currently unused; kept for backward compatibility.
    allow_loss_recomputation : bool, optional
        If False (default), the id of the first computation node whose name
        contains ``'loss'`` is returned in ``loss_node`` (presumably so the
        solver can keep it from being recomputed). If True, ``loss_node`` is
        empty.

    Returns
    -------
    node_info : dict of dict
        Keys are computation-node ids. Each value uses the format expected
        by the Haskell TWRemat code::

            node_info[cnode.unique_id] = {
                'type': 'normal',
                'cpu': mem_inputs + mem_outputs,
                'mem': mem_outputs,
                'overhead': cnode.overhead,
                'deps': cnode_deps,
            }

    targets : list of int
        One-element list with the id of the last node of ``K_graph.list_kcn``,
        which is treated as the final node.
    loss_node : list of int
        ``[loss_node_id]``, or ``[]`` when ``allow_loss_recomputation`` is True.

    Raises
    ------
    ValueError
        If ``K_graph.list_kcn`` is empty, or if no computation node name
        contains ``'loss'`` while ``allow_loss_recomputation`` is False.
    '''
    if not K_graph.list_kcn:
        raise ValueError("K_graph has no computation nodes")
    target = K_graph.list_kcn[-1].unique_id  # final computational node

    node_info = {}
    for cnode in K_graph.list_kcn:
        cnode_deps = [cn.unique_id for cn in cnode.deps_through_size_artefacts]

        # Real and fake inputs both count towards input memory and both add
        # dependencies on the computations that produce them.
        mem_inputs = 0
        for input_nodes in (cnode.deps_real, cnode.deps_fake):
            for dnode in input_nodes:
                mem_inputs += dnode.mem
                cnode_deps.extend(cn.unique_id for cn in dnode.deps)

        # Outputs of the final node are not accounted for
        if cnode.unique_id == target:
            mem_outputs = 0
        else:
            mem_outputs = sum(dnode.mem for dnode in cnode.users)

        node_info[cnode.unique_id] = {
            'type': 'normal',
            'cpu': mem_inputs + mem_outputs,
            'mem': mem_outputs,
            'overhead': cnode.overhead or 0,
            'deps': cnode_deps,
        }

    if allow_loss_recomputation:
        loss_node = []
    else:
        loss_cnode = next(
            (node for node in K_graph.list_kcn if 'loss' in node.name), None
        )
        if loss_cnode is None:
            raise ValueError(
                "no computation node with 'loss' in its name; "
                "pass allow_loss_recomputation=True to skip this check"
            )
        loss_node = [loss_cnode.unique_id]

    return node_info, [target], loss_node


def twremat_to_rockmate_schedule(K_graph, steps):
    """Translate a TWRemat schedule into a Rockmate ``OpSchedule``.

    A ``'compute'`` step emits a ``RunOp``, marks the node's outputs alive,
    then deletes every alive data node that no later ``'compute'`` step reads
    (through ``deps_real`` or ``deps_fake``). The input data node is never
    deleted this way. A ``'free'`` step deletes the still-alive outputs
    (``users``) of the given computation node.

    Parameters
    ----------
    K_graph : rkgb K_graph
        Graph whose computation nodes are referenced by ``steps``.
    steps : sequence of (str, id) pairs
        TWRemat instructions ``(step_type, cnode_id)``. ``cnode_id`` may be an
        int or a numeric string. Steps whose type is neither ``'compute'``
        nor ``'free'`` are looked up but otherwise ignored.

    Returns
    -------
    op_list : list
        The generated ``RunOp`` / ``DelOp`` sequence.
    sched : OpSchedule
        Schedule built once from ``op_list`` and the per-op liveness snapshots.

    Notes
    -----
    Prints the output data size and the peak memory predicted from ``sched``.
    """
    kcn_id_to_node = {cnode.unique_id: cnode for cnode in K_graph.list_kcn}
    data_nodes = K_graph.list_kdn + [K_graph.input_kdn_grad, K_graph.input_kdn_data]
    input_id = K_graph.input_kdn_data.unique_id

    print(f"OUTPUT SIZE TO KEEP ----{K_graph.output_kdn_data.mem}")

    # Ids may arrive as strings from TWRemat; normalise them once so every
    # lookup into kcn_id_to_node uses the same (int) key.
    parsed_steps = [(step_type, int(cnode_id)) for step_type, cnode_id in steps]

    # Position of the last 'compute' step that reads each data node. A data
    # node is still needed after step i iff it is read at some step j > i.
    # This replaces rescanning the remaining steps after every compute step.
    last_read = {}
    for pos, (step_type, cnode_id) in enumerate(parsed_steps):
        if step_type == 'compute':
            cn = kcn_id_to_node[cnode_id]
            for dn in set.union(cn.deps_real, cn.deps_fake):
                last_read[dn.unique_id] = pos

    # Column used for each entry of data_nodes; reproduces list.index()
    # (first matching position) in case a node appears more than once.
    first_idx = [data_nodes.index(dn) for dn in data_nodes]

    op_list = []
    alive_list = []
    alive_status = np.zeros(len(data_nodes), dtype=bool)
    alive_status[-1] = True  # input data tensor should be always alive!

    for i, (step_type, cnode_id) in enumerate(parsed_steps):
        cnode = kcn_id_to_node[cnode_id]

        if step_type == 'compute':
            # Output data of the cnode computation should be stored
            for dnode in cnode.users:
                alive_status[data_nodes.index(dnode)] = 1
            alive_list.append(alive_status.copy())

            op = RunOp(cnode)
            op.no_grad = False
            op_list.append(op)

            # Free alive data that no later computation uses as input
            for pos, dnode in enumerate(data_nodes):
                idx = first_idx[pos]
                dn_id = dnode.unique_id
                still_needed = dn_id == input_id or last_read.get(dn_id, -1) > i
                if alive_status[idx] and not still_needed:
                    alive_status[idx] = 0
                    alive_list.append(alive_status.copy())

                    op = DelOp(data_nodes[idx])
                    op.proxy = True
                    op_list.append(op)

        elif step_type == 'free':
            # TWRemat "free computation" = delete the data outputs of cnode
            for dnode in cnode.users:
                idx = data_nodes.index(dnode)
                if alive_status[idx]:
                    op = DelOp(data_nodes[idx])
                    op.proxy = True
                    op_list.append(op)
                    alive_status[idx] = 0
                    alive_list.append(alive_status.copy())

    # Build the schedule once, after every step has been translated
    sched = OpSchedule(op_list, alive_list,
                       K_graph.input_kdn_data,
                       K_graph.input_kdn_grad,
                       K_graph.output_kdn_data,
                       K_graph.list_kdn,
                       no_grad=False)

    pred_mem = []
    for alive, op in zip(sched.alive_list, sched.op_list):
        acc_mem = np.dot(alive, sched.mem_sizes) - sched.input_size[1]
        if op.op_type == "Run":
            acc_mem += op.overhead
        pred_mem.append(acc_mem)

    # np.max fails on an empty sequence, so only report when ops exist
    if pred_mem:
        print("peak_mem from op schedule " + str(np.max(np.array(pred_mem))))

    return op_list, sched