import time
import logging

import functorch
import rich
import torch

from meshflow.unifyshard import unifyir
from meshflow.autoflow.solver import AutoFlowSolver
from .sharding_interpreter import MFTorchShardingAnn
from .bridge import torch2mf_bridge, get_torch_sharding_strategy
from .passes import eliminate_detach, fix_addmm, fix_batch_norm, sharding_transform
from .passes.sharding import get_device_mesh

logger = logging.getLogger(__name__)

# When True, `meshflow_shard` applies `sharding_transform` to the FX graph using
# the solved strategy; when False the graph is returned without that rewrite.
# Flipped on by `enable_transform()`.
ENABLE_TRANSFORM = False
# Constraint passed as the second argument to `AutoFlowSolver`. Starts as None;
# the first `meshflow_shard` call replaces it with the per-output strategy map
# built by `_get_output_strategy`, and later calls feed that map to the solver.
# NOTE(review): "BW" presumably stands for "backward" -- confirm.
BW_CONSTRAIN = None
# Per-input sharding list recorded by the first `meshflow_shard` call (while
# BW_CONSTRAIN is still None); exposed through `get_input_strategy()`.
INPUT_STRATEGY = None


def get_input_strategy():
    """Return the module-level ``INPUT_STRATEGY`` value.

    The value is ``None`` until ``meshflow_shard`` records a strategy list
    into ``INPUT_STRATEGY``.
    """
    # Reading a module global needs no ``global`` declaration.
    return INPUT_STRATEGY


def enable_transform():
    """Switch on the module-level ``ENABLE_TRANSFORM`` flag.

    Once set, ``meshflow_shard`` applies ``sharding_transform`` to the graphs
    it compiles. There is no matching function here to switch it back off.
    """
    # Rebinding a module global from inside a function requires ``global``.
    global ENABLE_TRANSFORM
    ENABLE_TRANSFORM = True


def _get_output_strategy(opt_strategy, unify_graph, input_strategy):
    partial_strategy = {}
    for op in unify_graph.op_list:
        op_key = op.unique_key()
        if op_key in opt_strategy:
            for idx, var in enumerate(op.outvars):
                if var in unify_graph.output_list:
                    strategy = opt_strategy[op_key]['strategy']['outvars_sharding'][idx]
                    partial_strategy[var] = strategy


    for var in unify_graph.input_list:
        if var in unify_graph.output_list:
            partial_strategy[var] = input_strategy[var]

    return partial_strategy


def _get_input_strategy(opt_strategy, unify_graph):
    partial_strategy = {}
    for op in reversed(unify_graph.op_list):
        op_key = op.unique_key()
        if op_key in opt_strategy:
            for idx, var in enumerate(op.invars):
                if var in unify_graph.input_list:
                    strategy = opt_strategy[op_key]['strategy']['invars_sharding'][idx]
                    partial_strategy[var] = strategy

    partial_strategy_list = []

    for var in unify_graph.input_list:
        if var in partial_strategy:
            partial_strategy_list.append(partial_strategy[var])
        else:
            partial_strategy_list.append(
                [unifyir.SPMD(unifyir.SPMD.REPLICATE),
                 unifyir.SPMD(unifyir.SPMD.REPLICATE)])

    return partial_strategy, partial_strategy_list


@functorch.compile.make_boxed_compiler
def meshflow_shard(fx_module: torch.fx.GraphModule, inps):
    """Solve for a sharding strategy for ``fx_module`` and optionally apply it.

    Steps, in order:
      1. Normalise the graph with ``fix_batch_norm``, ``fix_addmm`` and
         ``eliminate_detach``, then recompile it.
      2. Run ``MFTorchShardingAnn`` on ``inps`` to obtain per-op sharding
         annotations and shape information (TF32 is enabled for this run).
      3. Convert the graph with ``torch2mf_bridge`` and solve it with
         ``AutoFlowSolver.ilp_optimize`` on the device mesh returned by
         ``get_device_mesh()`` (its first two dimensions are used).
      4. On the first call only (while ``BW_CONSTRAIN`` is ``None``), record
         ``INPUT_STRATEGY`` and ``BW_CONSTRAIN`` from the solution. Later
         calls hand the stored ``BW_CONSTRAIN`` to the solver instead.
         NOTE(review): presumably the first call is the forward graph and the
         second the backward graph -- confirm. These globals are never reset
         in this module, so the "first call" state persists for the process.
      5. If ``ENABLE_TRANSFORM`` is set, rewrite the graph with
         ``sharding_transform``.

    Args:
        fx_module: The FX graph module to compile.
        inps: Example inputs, unpacked into ``MFTorchShardingAnn.run``.

    Returns:
        The (possibly transformed) graph module.
    """
    global BW_CONSTRAIN, INPUT_STRATEGY

    fx_module = fix_batch_norm(fx_module)
    fx_module = fix_addmm(fx_module)
    fx_module = eliminate_detach(fx_module)
    fx_module.recompile()
    if logging.root.level <= logging.DEBUG:
        print(fx_module.graph)

    # Enable TF32 only for the annotation run. The caller's settings are saved
    # and restored (even on failure) rather than forced to False afterwards,
    # so compiling a graph does not change numerics for the rest of the process.
    prev_matmul_tf32 = torch.backends.cuda.matmul.allow_tf32
    prev_cudnn_tf32 = torch.backends.cudnn.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    try:
        start_t = time.perf_counter()
        sharding_interpreter = MFTorchShardingAnn(fx_module)
        sharding_info, fwd_shape_info = sharding_interpreter.run(*inps)
        logger.info("[MFTorchShardingAnn.run]: %s s.", time.perf_counter() - start_t)
    finally:
        torch.backends.cuda.matmul.allow_tf32 = prev_matmul_tf32
        torch.backends.cudnn.allow_tf32 = prev_cudnn_tf32

    if logging.root.level <= logging.DEBUG:
        rich.print("sharding_info:\n", sharding_info)
        rich.print("fwd_shape_info:\n", fwd_shape_info)

    unify_graph = torch2mf_bridge(fx_module, sharding_info, fwd_shape_info)

    if logging.root.level <= logging.INFO:
        rich.print(unify_graph)

    # Only the first two mesh dimensions are passed to the solver.
    device_mesh = get_device_mesh()
    device_mesh_shape = (device_mesh.size(0), device_mesh.size(1))

    solver = AutoFlowSolver(device_mesh_shape, BW_CONSTRAIN)
    solver.add_graph(unify_graph)
    start_t = time.perf_counter()
    # The flag passed to the ILP is True only while no constraint is recorded.
    count_invars = BW_CONSTRAIN is None
    opt_strategy = solver.ilp_optimize(count_invars)
    logger.info("[AutoFlowSolver.ilp_optimize]: %s s.", time.perf_counter() - start_t)
    # NOTE: a timed `solver.beam_search()` call used to sit here, commented
    # out; only the ILP result is used.

    if logging.root.level <= logging.INFO:
        rich.print(opt_strategy)

    if BW_CONSTRAIN is None:
        # First compilation: remember the input shardings and derive the
        # output constraint that later compilations will be solved under.
        strategy_map, INPUT_STRATEGY = _get_input_strategy(opt_strategy, unify_graph)
        BW_CONSTRAIN = _get_output_strategy(opt_strategy, unify_graph, strategy_map)

    if ENABLE_TRANSFORM:
        sharding_strategy = get_torch_sharding_strategy(fx_module, opt_strategy)

        if logging.root.level <= logging.DEBUG:
            print(sharding_strategy)

        fx_module = sharding_transform(fx_module, sharding_strategy)

        if logging.root.level <= logging.DEBUG:
            print(fx_module.graph)

    return fx_module