import logging
import random
from typing import List

import mip

from meshflow.unifyshard.unifyir import SPMD, UnifyGraph, UnifyNode, UnifyVar

# Module-level logger; the solver uses it to report the optimisation status and solution cost.
logger = logging.getLogger(__name__)


def shuffle_list(*ls):
    """Shuffle several sequences with one shared random permutation.

    The sequences are zipped into rows (truncating to the shortest one),
    the rows are shuffled in place with ``random.shuffle``, and the rows are
    transposed back.

    Returns:
        A ``zip`` iterator yielding one tuple per input sequence; elements
        that shared an index before the shuffle still share one afterwards.
    """
    rows = [*zip(*ls)]
    random.shuffle(rows)
    return zip(*rows)


def get_idx_in_var_list(var: UnifyVar, var_list: List[UnifyVar]):
    """Return the position of ``var`` in ``var_list``, matching by ``name``.

    Returns:
        Index of the first entry whose ``name`` equals ``var.name``, or
        ``None`` when no entry matches.
    """
    names = [candidate.name for candidate in var_list]
    try:
        return names.index(var.name)
    except ValueError:
        return None


def calculate_resharding_cost(var: UnifyVar, strategy_in: List[SPMD], strategy_out: List[SPMD],
                              device_mesh):
    """Estimate the communication cost of converting ``var`` between two shardings.

    Each mesh axis is costed independently and the two results are summed.
    An axis of size 1 contributes nothing.  The message size for an axis is
    the variable size, divided by the size of the *other* axis when the
    source sharding is ``SHARD`` along that other axis.

    Per axis, with message size ``m``:
      * source ``PARTIAL``                          -> ``2 * m`` (all-reduce)
      * source ``SHARD``, target not ``SHARD``      -> ``1 * m`` (all-gather)
      * source and target ``SHARD``, args differ    -> ``1 * m`` (all-to-all)
      * anything else                               -> 0

    Args:
        var: variable being resharded; only ``get_var_size()`` is used.
        strategy_in: source sharding, one SPMD entry per mesh axis (two read).
        strategy_out: target sharding, one SPMD entry per mesh axis (two read).
        device_mesh: indexable holding the sizes of mesh axis 0 and axis 1.

    Returns:
        The summed cost over both mesh axes.
    """
    var_size = var.get_var_size()

    s1_in, s2_in = strategy_in[0], strategy_in[1]
    s1_out, s2_out = strategy_out[0], strategy_out[1]

    def axis_cost(src, dst, other_src, other_axis_size):
        # Data already split along the other axis shrinks what each device sends.
        message_size = var_size
        if other_src.state == SPMD.SHARD:
            message_size /= other_axis_size

        if src.state == SPMD.PARTIAL:
            return 2 * message_size  # all-reduce
        if src.state != SPMD.SHARD:
            return 0
        if dst.state != SPMD.SHARD:
            return 1 * message_size  # all-gather
        if src.args != dst.args:
            return 1 * message_size  # all-to-all
        return 0

    resharding_cost = 0
    if device_mesh[0] > 1:
        resharding_cost += axis_cost(s1_in, s1_out, s2_in, device_mesh[1])
    if device_mesh[1] > 1:
        resharding_cost += axis_cost(s2_in, s2_out, s1_in, device_mesh[0])

    return resharding_cost


def calculate_memory_cost(var: UnifyVar, strategy_in: List[SPMD], strategy_out: List[SPMD],
                          device_mesh):
    """Return the per-device size of ``var`` under the ``strategy_out`` sharding.

    The variable size is floor-divided by the product of the sizes of the
    mesh axes along which ``strategy_out`` is ``SHARD``.

    Args:
        var: variable being measured; only ``get_var_size()`` is used.
        strategy_in: accepted for signature symmetry; not read.
        strategy_out: sharding used for the estimate (two entries read).
        device_mesh: indexable holding the sizes of mesh axis 0 and axis 1.
    """
    var_size = var.get_var_size()

    # FIXME: only use out_strategy here for shard_size is ok?
    s1, s2 = strategy_out[0], strategy_out[1]

    shard_size = 1
    for axis, axis_strategy in enumerate((s1, s2)):
        if axis_strategy.state == SPMD.SHARD:
            shard_size *= device_mesh[axis]

    return var_size // shard_size


def generate_comm_matrix(var: UnifyVar, up_strategy, down_strategy, idx_for_up, idx_for_down,
                         device_mesh):
    """Build the resharding-cost matrix for one producer/consumer pair of ``var``.

    Entry ``[i][j]`` is ``calculate_resharding_cost`` from the sharding that
    producer strategy ``i`` assigns to its output ``idx_for_up`` to the
    sharding that consumer strategy ``j`` expects on its input ``idx_for_down``.

    Returns:
        A list of ``len(up_strategy)`` rows, each of ``len(down_strategy)`` costs.
    """
    return [[
        calculate_resharding_cost(var, producer['outvars_sharding'][idx_for_up],
                                  consumer['invars_sharding'][idx_for_down], device_mesh)
        for consumer in down_strategy
    ] for producer in up_strategy]


def generate_mem_matrix(var: UnifyVar, up_strategy, down_strategy, idx_for_up, idx_for_down,
                        device_mesh):
    """Build the memory-cost matrix for one producer/consumer pair of ``var``.

    Entry ``[i][j]`` is ``calculate_memory_cost`` called with the sharding
    that producer strategy ``i`` assigns to its output ``idx_for_up`` and the
    sharding that consumer strategy ``j`` expects on its input ``idx_for_down``.

    Returns:
        A list of ``len(up_strategy)`` rows, each of ``len(down_strategy)`` costs.
    """
    return [[
        calculate_memory_cost(var, producer['outvars_sharding'][idx_for_up],
                              consumer['invars_sharding'][idx_for_down], device_mesh)
        for consumer in down_strategy
    ] for producer in up_strategy]


# Default memory budget. AutoFlowSolver.__init__ reads this value when a solver is
# constructed and multiplies it by 1024 ** 3; change it through set_max_memory().
MAX_MEMORY_CONSTRAIN = 60


def set_max_memory(mem):
    """Replace the module-wide ``MAX_MEMORY_CONSTRAIN`` value with ``mem``.

    The new value is picked up by ``AutoFlowSolver`` instances created after
    this call; the constructor reads the global once and stores the result.
    """
    global MAX_MEMORY_CONSTRAIN
    MAX_MEMORY_CONSTRAIN = mem


class AutoFlowSolver:
    """Pick one sharding strategy per graph node, by ILP or by beam search.

    Nodes (ops) and edges (variables) are registered through ``add_graph``.
    Every node carries one binary MIP variable per candidate strategy; every
    producer/consumer pair of an edge carries a matrix of binary variables
    together with matching communication-cost and memory-cost matrices.
    """

    def __init__(self, device_mesh=None, constrain=None) -> None:
        """Create an empty solver.

        Args:
            device_mesh: indexable holding the sizes of mesh axis 0 and axis 1.
            constrain: optional mapping from graph-input variables to a fixed
                sharding; it is iterated and indexed by variable in ``add_graph``.
        """
        self.m = mip.Model("autoflow")
        # node unique key -> {"node", "strategy", "mip_var"}
        self.nodes = {}
        # variable name -> edge record (layout documented in ``add_edge``)
        self.edges = {}
        self.device_mesh = device_mesh
        self.constrain = constrain

        # MAX_MEMORY_CONSTRAIN scaled by 1024 ** 3 (i.e. GiB expressed in bytes,
        # assuming get_var_size() reports bytes -- TODO confirm).
        self.max_memory_ = MAX_MEMORY_CONSTRAIN * 1024 * 1024 * 1024

    def add_graph(self, graph: UnifyGraph) -> None:
        """Register every op of ``graph`` and the constraints on its inputs.

        Each constrained variable that is a graph input is modelled as a
        synthetic producer node with a single strategy (the constrained
        sharding).  Constraints on variables that are not graph inputs are
        ignored.
        """
        self.graph = graph

        if self.constrain:
            # Build the name lookup once instead of once per constraint.
            input_names = {var.name for var in self.graph.input_list}
            for cons in self.constrain:
                if cons.name not in input_names:
                    continue
                cons_node = UnifyNode("cons_" + cons.name, [], [cons], None)
                self.nodes[cons_node.unique_key()] = {
                    "node": cons_node,
                    "strategy": [{
                        'outvars_sharding': [self.constrain[cons]]
                    }],
                    "mip_var": [self.m.add_var(var_type=mip.BINARY)]
                }
                self.add_edge(cons, up_node=cons_node)

        self.liveness = graph.liveness()
        for op in graph.op_list:
            self.add_node(op)

    def add_node(self, node: UnifyNode) -> None:
        """Register ``node`` with one binary variable per candidate strategy.

        A node whose ``get_strategy()`` is empty is skipped entirely: neither
        the node nor any of its edges is recorded.
        """
        unique_key_ = node.unique_key()

        strategies = node.get_strategy()
        if len(strategies) == 0:
            return

        self.nodes[unique_key_] = {
            "node": node,
            "strategy": strategies,
            "mip_var": [self.m.add_var(var_type=mip.BINARY) for _ in range(len(strategies))]
        }

        for var in node.invars:
            self.add_edge(var, down_node=node)

        for var in node.outvars:
            self.add_edge(var, up_node=node)

    def add_edge(self, edge: UnifyVar, up_node=None, down_node=None) -> None:
        """Record that ``up_node`` produces and/or ``down_node`` consumes ``edge``.

        The edge record holds the producer key (``up_node``), the list of
        consumer keys (``down_node``), the position of the variable in the
        producer outputs / consumer inputs (``idx_for_up`` / ``idx_for_down``)
        and three parallel lists (``mip_var``, ``comm_matrix``,
        ``mem_matrix``) that receive one entry each time a consumer is added
        while a producer is already known.
        """
        unique_key_ = edge.name

        if unique_key_ not in self.edges:
            self.edges[unique_key_] = {
                "edge": edge,
                "up_node": None,
                "down_node": [],
                "idx_for_up": None,
                "idx_for_down": [],
                "comm_matrix": [],
                "mem_matrix": [],
                "mip_var": [],
            }
        record = self.edges[unique_key_]

        if up_node is not None:
            record["up_node"] = up_node.unique_key()
            record["idx_for_up"] = get_idx_in_var_list(edge, up_node.outvars)

        if down_node is None:
            return

        record["down_node"].append(down_node.unique_key())
        record["idx_for_down"].append(get_idx_in_var_list(edge, down_node.invars))

        up_node_key = record["up_node"]
        if up_node_key is None:
            # No producer known: there is nothing to reshard from.
            return

        up_strategy = self.nodes[up_node_key]["strategy"]
        down_strategy = self.nodes[record["down_node"][-1]]["strategy"]

        # One binary variable per (producer strategy, consumer strategy) pair.
        record["mip_var"].append(
            [[self.m.add_var(var_type=mip.BINARY) for _ in range(len(down_strategy))]
             for _ in range(len(up_strategy))])

        idx_for_up = record["idx_for_up"]
        idx_for_down = record["idx_for_down"][-1]

        record["comm_matrix"].append(
            generate_comm_matrix(
                record["edge"],
                up_strategy,
                down_strategy,
                idx_for_up,
                idx_for_down,
                self.device_mesh,
            ))

        record["mem_matrix"].append(
            generate_mem_matrix(
                record["edge"],
                up_strategy,
                down_strategy,
                idx_for_up,
                idx_for_down,
                self.device_mesh,
            ))

    def _invar_mem_costs(self, var_size, down_strategy, idx_for_down):
        """Return the per-device size of an input under each consumer strategy."""
        memory_cost_list = []
        for candidate in down_strategy:
            strategy = candidate['invars_sharding'][idx_for_down]

            # FIXME: only use out_strategy here for shard_size is ok?
            s1, s2 = strategy[0], strategy[1]
            shard_size = 1
            if s1.state == SPMD.SHARD:
                shard_size *= self.device_mesh[0]
            if s2.state == SPMD.SHARD:
                shard_size *= self.device_mesh[1]

            memory_cost_list.append(var_size // shard_size)

        return memory_cost_list

    def ilp_optimize(self, count_invars=False):
        """Solve the strategy selection as an integer linear program.

        The objective is the total communication cost plus the memory cost
        weighted by ``1e-8`` (so memory only separates near-equal solutions).
        Constraints: exactly one strategy per node, edge variables equal to
        the product of the two node variables they connect, and for every
        liveness step the memory of the live tensors must not exceed
        ``self.max_memory_``.

        Args:
            count_invars: when true, the memory of variables without a
                producer node is added to the objective as well.

        Returns:
            The mapping produced by ``get_optimal_stratey``.
        """
        comm_cost, mem_cost = 0, 0
        for edge in self.edges.values():
            for idx in range(len(edge["mip_var"])):
                mip_var = edge["mip_var"][idx]
                comm_matrix = edge["comm_matrix"][idx]
                mem_matrix = edge["mem_matrix"][idx]
                shape_1 = len(mip_var)
                shape_2 = len(mip_var[0])
                comm_cost = comm_cost + mip.xsum(mip_var[i][j] * comm_matrix[i][j]
                                                 for i in range(shape_1) for j in range(shape_2))
                mem_cost = mem_cost + mip.xsum(mip_var[i][j] * mem_matrix[i][j]
                                               for i in range(shape_1) for j in range(shape_2))

            if count_invars and edge["up_node"] is None:
                var_size = edge["edge"].get_var_size()
                for down_node_key, idx_for_down in zip(edge["down_node"], edge["idx_for_down"]):
                    down_strategy = self.nodes[down_node_key]["strategy"]
                    # The result gets its own name so the helper stays callable
                    # for every consumer of this variable.
                    invar_costs = self._invar_mem_costs(var_size, down_strategy, idx_for_down)
                    down_node_mip_var = self.nodes[down_node_key]["mip_var"]
                    mem_cost = mem_cost + mip.xsum(
                        choice * cost for choice, cost in zip(down_node_mip_var, invar_costs))

        # Tie every edge variable to the strategy choices of its two end nodes.
        for edge in self.edges.values():
            for idx in range(len(edge["mip_var"])):
                mip_var = edge["mip_var"][idx]
                shape_1 = len(mip_var)
                shape_2 = len(mip_var[0])
                self.m += mip.xsum(mip_var[i][j] for i in range(shape_1)
                                   for j in range(shape_2)) == 1

                up_node_mip_var = self.nodes[edge["up_node"]]["mip_var"]
                down_node_mip_var = self.nodes[edge["down_node"][idx]]["mip_var"]

                # Linearisation of mip_var[i][j] == up[i] AND down[j].
                for i in range(shape_1):
                    for j in range(shape_2):
                        self.m += mip_var[i][j] <= up_node_mip_var[i]
                        self.m += mip_var[i][j] <= down_node_mip_var[j]
                        self.m += mip_var[i][j] >= up_node_mip_var[i] + down_node_mip_var[j] - 1

        # Memory bound per liveness step.  ``idx_record`` counts how many of
        # the ops visited so far consumed each tensor; that count (clamped to
        # the last available entry) selects which consumer's variable matrix
        # stands for the tensor at the current step.
        idx_record = {}
        for idx_ in range(len(self.liveness)):
            mem_live = self.liveness[idx_]
            op_ = self.graph.op_list[idx_]

            mip_var_list = []
            mem_matrix_list = []
            for tensor_name in mem_live:
                if tensor_name not in self.edges:
                    continue
                live_edge = self.edges[tensor_name]
                if tensor_name not in idx_record:
                    idx_record[tensor_name] = 0
                if len(live_edge["mip_var"]) > 0:
                    edge_idx = min(len(live_edge["mip_var"]) - 1, idx_record[tensor_name])
                    mip_var_list.append(live_edge["mip_var"][edge_idx])
                    mem_matrix_list.append(live_edge["mem_matrix"][edge_idx])

            for var in op_.invars:
                if var.name not in self.edges:
                    continue
                if var.name not in idx_record:
                    idx_record[var.name] = 0
                idx_record[var.name] += 1

            if len(mip_var_list) >= 1:
                need_sum = []
                for mip_var, mem_matrix in zip(mip_var_list, mem_matrix_list):
                    for var_row, mem_row in zip(mip_var, mem_matrix):
                        for j in range(len(mip_var[0])):
                            need_sum.append(var_row[j] * mem_row[j])
                self.m += mip.xsum(need_sum) <= self.max_memory_

        # Exactly one strategy per node.
        for node in self.nodes.values():
            self.m += mip.xsum(node["mip_var"]) == 1

        self.m.objective = mip.minimize(comm_cost + 0.00000001 * mem_cost)

        self.m.verbose = 0
        status = self.m.optimize()
        logger.info(f'=========== optimal status:\t {status}')
        logger.info(f'=========== solution cost:\t {self.m.objective_value}')

        return self.get_optimal_stratey()

    def get_optimal_stratey(self):
        """Read the selected strategy of every node from the solved model.

        Returns:
            ``{node_key: {'node': node, 'strategy': chosen}}`` where ``chosen``
            is ``None`` when no variable of that node is set (for example
            when the solver found no solution).
        """
        optimal_stratey = {}
        for unique_key_, entry in self.nodes.items():
            strategy_list = entry['strategy']
            mip_var = entry['mip_var']
            assert len(strategy_list) == len(mip_var)

            opt_ = None
            for s_, mip_var_s in zip(strategy_list, mip_var):
                value = mip_var_s.x
                # Solver values are floats: a selected binary variable may come
                # back as 0.9999999, and ``x`` is None without a solution.
                if value is not None and value >= 0.99:
                    opt_ = s_
            optimal_stratey[unique_key_] = {'node': entry['node'], 'strategy': opt_}
        return optimal_stratey

    def beam_search(self, candidate_num=100):
        """Choose strategies greedily, keeping the ``candidate_num`` cheapest partial plans.

        Nodes are visited in ``self.nodes`` order.  Each kept plan is extended
        with every strategy of the next node; the added cost is the
        communication cost of the node's input edges whose producer is
        already part of the plan.  Memory cost is not considered.

        Args:
            candidate_num: beam width; must be at least 1.

        Returns:
            ``{node_key: {'node': node, 'strategy': chosen}}`` for the cheapest
            plan, or an empty dict when no node is registered.

        Raises:
            ValueError: if ``candidate_num`` is smaller than 1.
        """
        if candidate_num < 1:
            raise ValueError("candidate_num must be at least 1")
        if not self.nodes:
            # Nothing to choose; avoid unpacking an empty candidate list below.
            return {}

        def get_new_cost(strategy, node, strategy_idx):
            cost = 0.
            node_key = node['node'].unique_key()
            for invar in node['node'].invars:
                edge = self.edges[invar.name]
                up_node_key = edge["up_node"]
                if up_node_key not in strategy:
                    continue
                consumer_pos = edge["down_node"].index(node_key)
                up_node_strategy_idx = strategy[up_node_key]['strategy_idx']
                cost += edge["comm_matrix"][consumer_pos][up_node_strategy_idx][strategy_idx]
            return cost

        def add_candidate(strategy_candidate, accumulate_cost, node):
            key_ = node['node'].unique_key()
            num_choices = len(node['strategy'])

            if len(strategy_candidate) == 0:
                # First node: every strategy starts a plan at zero cost.
                seeds = [{key_: {'node': node['node'], 'strategy_idx': choice}}
                         for choice in range(num_choices)]
                return seeds, [0.] * num_choices

            new_strategy_candidate = []
            new_accumulate_cost = []
            for strategy, old_cost in zip(strategy_candidate, accumulate_cost):
                for choice in range(num_choices):
                    new_cost = get_new_cost(strategy, node, choice)
                    new_strategy = dict(strategy)
                    new_strategy[key_] = {'node': node['node'], 'strategy_idx': choice}
                    new_strategy_candidate.append(new_strategy)
                    new_accumulate_cost.append(old_cost + new_cost)

            return new_strategy_candidate, new_accumulate_cost

        def select_candidate(strategy_candidate, accumulate_cost, candidate_num):
            assert len(strategy_candidate) == len(accumulate_cost)
            if len(accumulate_cost) > candidate_num:
                # Shuffle before the stable sort so that equal-cost plans are
                # dropped at random rather than by insertion order.
                accumulate_cost, strategy_candidate = shuffle_list(accumulate_cost,
                                                                   strategy_candidate)
            accumulate_cost, strategy_candidate = zip(
                *sorted(zip(accumulate_cost, strategy_candidate), key=lambda x: x[0]))

            return strategy_candidate[:candidate_num], accumulate_cost[:candidate_num]

        strategy_candidate = []
        accumulate_cost = []
        for node in self.nodes.values():
            strategy_candidate, accumulate_cost = add_candidate(strategy_candidate,
                                                                accumulate_cost, node)
            strategy_candidate, accumulate_cost = select_candidate(strategy_candidate,
                                                                   accumulate_cost, candidate_num)

        # Candidates are sorted by cost, so the first one is the cheapest.
        strategy = strategy_candidate[0]

        optimal_stratey = {}
        for key in strategy:
            node = strategy[key]['node']
            node_strategy_list = self.nodes[node.unique_key()]['strategy']
            optimal_stratey[key] = {
                "node": node,
                "strategy": node_strategy_list[strategy[key]['strategy_idx']]
            }

        logger.info(f'=========== solution cost:\t {accumulate_cost[0]}')

        return optimal_stratey
