import logging
import os
import re
import copy
from typing import Any, Dict, Iterator, Tuple

import torch
import torch.utils._pytree as pytree
from spmd import Replicate
from spmd.tensor.api import DTensor
from torch.fx.graph_module import GraphModule
from torch.fx.interpreter import Interpreter
from torch.fx.node import Argument, Target, _get_qualified_name

from meshflow.torch.passes.sharding import get_device_mesh
from meshflow.unifyshard import UnifyOp, view_propagation

logger = logging.getLogger(__name__)


def to_meta(_node_output):
    """Return a meta-device stand-in for tensor values; pass anything else through.

    A value whose type is exactly ``DTensor`` is, when ``get_device_mesh()``
    returns a truthy mesh, redistributed with a ``Replicate()`` placement on
    every mesh dimension and replaced by its ``_local_tensor``.

    A value whose type is exactly ``torch.Tensor`` or
    ``torch.nn.parameter.Parameter`` is then returned as a detached,
    contiguous tensor on the ``"meta"`` device. Every other value (including
    tensor subclasses, and a ``DTensor`` when no mesh is available) is
    returned unchanged.
    """
    if type(_node_output) == DTensor:
        mesh = get_device_mesh()
        if mesh:
            replicated = _node_output.redistribute(mesh, [Replicate()] * mesh.ndim)
            _node_output = replicated._local_tensor

    kind = type(_node_output)
    if kind is torch.Tensor:
        source = _node_output
    elif kind is torch.nn.parameter.Parameter:
        # Go through ``.data`` so the result is built from the plain tensor.
        source = _node_output.data
    else:
        return _node_output

    return source.detach().to(device="meta").contiguous()


def get_shape_info(_node_output):
    """Summarize a tensor as ``{"shape": ..., "dtype": ...}``.

    Only values whose type is exactly ``torch.Tensor`` or
    ``torch.nn.parameter.Parameter`` are summarized; any other value is
    returned as-is.
    """
    kind = type(_node_output)
    if kind is torch.Tensor or kind is torch.nn.parameter.Parameter:
        return dict(shape=_node_output.shape, dtype=_node_output.dtype)
    return _node_output


def inject_view_propagation(sharding_info):
    """Overwrite the entries of view-like ops with ``view_propagation`` results.

    For every recorded argument string of the ops listed below, the input
    tensor shape and the requested output shape are parsed out of the string
    and the entry is replaced, in place, by
    ``view_propagation(input_shape, output_shape)``. Entries of all other
    ops are left untouched.

    Args:
        sharding_info: Mapping ``op_name -> {args_string -> annotation}``.
            An args string looks like
            ``"(tensor(..., device='meta', size=(2, 4, 256, 256)), [8, 256, 256]) | {}"``.
            The first ``size=(...)`` is taken as the input shape and the first
            bracketed integer list as the output shape (for
            ``_reshape_alias`` the later stride list is ignored).

    Returns:
        The same ``sharding_info`` object, mutated in place.

    Raises:
        IndexError: If an args string of a view-like op contains no
            ``size=(...)`` tensor repr or no bracketed list of non-negative
            integers.
    """
    view_like_ops = (
        "torch.ops.aten.view.default",
        "torch.ops.aten._unsafe_view.default",
        "torch.ops.aten.expand.default",
        "torch.ops.aten._reshape_alias.default",
    )

    # The optional trailing comma accepts the repr of a 1-D shape, which is
    # printed as a one-element tuple: ``size=(8,)``.
    input_shape_re = re.compile(r'size=\((\d+(?:, \d+)*),?\)')
    output_shape_re = re.compile(r'\[(\d+(?:, \d+)*)\]')

    def first_shape(pattern, text):
        # Only the first match is relevant; later matches belong to other
        # arguments (e.g. the stride list of ``_reshape_alias``).
        return [int(num) for num in pattern.findall(text)[0].split(', ')]

    for op_name, per_args in sharding_info.items():
        if op_name not in view_like_ops:
            continue
        # Only values of existing keys are reassigned, so iterating the dict
        # while updating it is safe.
        for args in per_args:
            per_args[args] = view_propagation(
                first_shape(input_shape_re, args),
                first_shape(output_shape_re, args),
            )

    return sharding_info


class MFTorchShardingAnn(Interpreter):
    """FX interpreter that records per-node shapes and per-op sharding annotations.

    Node results are kept in the environment as meta tensors (via
    ``to_meta``). Each ``call_function`` node is executed through ``UnifyOp``
    on randomly generated inputs, and the annotations returned by its
    ``sharding_discovery`` are stored in ``sharding_info``.
    """

    def __init__(self, module: GraphModule, use_cache=True, garbage_collect_values: bool = True):
        super().__init__(module, garbage_collect_values)
        # Ops whose qualified name is listed here are executed directly and
        # never annotated.
        self.pass_by_ops = ["_operator.getitem"]
        # When true, an (op, argument-signature) pair is annotated only once.
        self.use_cache = use_cache
        # op name -> {argument signature string -> annotation entry}
        self.sharding_info = {}
        # node name -> shape/dtype summary of the node's output
        self.shape_info = {}

    def run(self, *args) -> Any:
        """
        Interpret `module` on meta tensors and collect shape and sharding information.
        Args:
            *args: The arguments to the Module to run, in positional order. They are
                converted with `to_meta` before being handed to the placeholders.
        Returns:
            Tuple: `(sharding_info, shape_info)`, produced as soon as the `output` node
                has been executed. `sharding_info` has been passed through
                `inject_view_propagation`; `shape_info` maps node names to the
                `get_shape_info` summary of their outputs.
        """
        self.env = {}

        # `placeholder` nodes pull their values from this iterator,
        # left-to-right, so it must yield the meta versions of the inputs.
        meta_inputs = pytree.tree_map(to_meta, args)
        self.args_iter: Iterator[Any] = iter(meta_inputs)

        for node in self.module.graph.nodes:
            # Skip nodes that already have a value in the environment.
            if node in self.env:
                continue

            meta_value = pytree.tree_map(to_meta, self.run_node(node))
            self.env[node] = meta_value
            self.shape_info[node.name] = pytree.tree_map(get_shape_info, meta_value)

            if self.garbage_collect_values:
                # Drop values whose last user is the node just executed.
                for dead in self.user_to_last_uses.get(node, []):
                    del self.env[dead]

            if node.op == 'output':
                self.env = {}
                return inject_view_propagation(self.sharding_info), self.shape_info

    def call_function(self, target: 'Target', args: Tuple[Argument, ...],
                      kwargs: Dict[str, Any]) -> Any:
        """
        Execute a ``call_function`` node and record sharding annotations for it.
        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation
        Return
            Any: For ops listed in ``pass_by_ops``, the direct result of
                ``target(*args, **kwargs)``; otherwise the ``UnifyOp`` execution result
                mapped through ``to_meta``.
        """
        assert not isinstance(target, str)

        ops_name = _get_qualified_name(target)

        # Pass-by ops are simply evaluated on the incoming values.
        if ops_name in self.pass_by_ops:
            return target(*args, **kwargs)

        # The string form of the meta arguments identifies this invocation
        # inside the per-op annotation table.
        meta_args = pytree.tree_map(to_meta, args)
        meta_kwargs = pytree.tree_map(to_meta, kwargs)
        args_key = str(meta_args) + ' | ' + str(meta_kwargs)
        logger.debug({"ops_name": ops_name, "args_meta": args_key})

        output_device = os.environ.get("MESHFLOW_DEVICE", "cuda")

        def materialize(value):
            # Swap every tensor for a random tensor of the same size and
            # dtype on the target device; non-tensors pass through.
            if not isinstance(value, torch.Tensor):
                return value
            size = value.size()
            if value.dtype == torch.bool:
                return torch.rand(size, dtype=torch.float, device=output_device) > 0.5
            if torch.is_floating_point(value):
                return torch.rand(size, dtype=value.dtype, device=output_device)
            return torch.randint(high=8, size=size, dtype=value.dtype, device=output_device)

        real_args = pytree.tree_map(materialize, args)
        real_kwargs = pytree.tree_map(materialize, kwargs)

        unify_op_ = UnifyOp(func=target, input_args=(real_args, real_kwargs), name=ops_name)
        real_out = unify_op_.exec()

        # sharding discovery
        per_op = self.sharding_info.setdefault(ops_name, {})
        needs_discovery = not self.use_cache or args_key not in per_op
        if needs_discovery:
            # With caching on, the first annotation already recorded for
            # this op is handed to the discovery as a prompt.
            prompt_annotation = None
            if self.use_cache and per_op:
                prompt_annotation = next(iter(per_op.values()))["sharding_ann"]

            if "reshape" in ops_name:
                # No discovery is run for ops with "reshape" in their name.
                sharding_ann = combination_ann = None
            else:
                sharding_ann, combination_ann = unify_op_.sharding_discovery(
                    prompt_annotation=copy.deepcopy(prompt_annotation))

            per_op[args_key] = {
                "sharding_ann": sharding_ann,
                "combination_ann": combination_ann,
            }

        return pytree.tree_map(to_meta, real_out)
