import numpy as np
from torch import Tensor
from typing import Union, List
from ml4co_kit import (
    to_numpy, MCutSolver, ATSPSolver, TSPSolver
)


class MetaDiffDecoder(object):
    """Base decoder that converts model heatmaps into per-instance solutions.

    This class only handles the bookkeeping: it slices batched inputs into
    per-instance pieces, dispatches each piece to a decoding hook and, when
    requested, evaluates the decoded solutions with the ``ml4co_kit`` solvers.
    The hooks ``_node_sparse_decode``, ``_edge_sparse_decode`` and
    ``_edge_dense_decode`` must be provided by subclasses.
    """

    def __init__(self) -> None:
        pass

    def sparse_decode(
        self, heatmap: Tensor, task: str, nodes_feature: Tensor, x: Tensor, edges_feature: Tensor, 
        e: Tensor, edge_index: Tensor, graph_list: List[Tensor], ground_truth: Tensor,
        nodes_num_list: list, edges_num_list: list, ref_tour_list: list, return_cost: bool = False
    ) -> Union[List[np.ndarray], np.floating]:
        """Decode a batch of sparse (concatenated) graphs.

        Args:
            heatmap: Model output, sliced along its first dimension. It is
                sliced per node for "MIS"/"MVC"/"MCl"/"MCut" and per edge
                for "TSP".
            task: One of "MIS", "MVC", "MCl", "MCut" or "TSP".
            nodes_feature: Node features of the whole batch; sliced per node
                for "TSP" and passed to the TSP evaluator.
            x: Not used by this method.
            edges_feature: Not used by this method.
            e: Not used by this method.
            edge_index: Edge index of the whole batch; sliced along its
                second dimension per graph for "MCut" and "TSP".
            graph_list: One graph per instance (used by the node tasks).
            ground_truth: Reference labels used when ``return_cost`` is set.
            nodes_num_list: Number of nodes of every instance.
            edges_num_list: Number of edges of every instance.
            ref_tour_list: Reference tours used for the TSP evaluation.
            return_cost: If True, return ``(costs, gap)`` instead of the
                decoded solutions.

        Returns:
            The list of decoded solutions, or the tuple ``(costs, gap)``
            when ``return_cost`` is True.

        Raises:
            NotImplementedError: If ``task`` is not supported.
        """
        solutions = list()
        if task in ["MIS", "MVC", "MCl"]:
            begin_idx = 0
            for idx, graph in enumerate(graph_list):
                end_idx = begin_idx + nodes_num_list[idx]
                solutions.append(self._node_sparse_decode(
                    heatmap=heatmap[begin_idx:end_idx],
                    graph=graph
                ))
                begin_idx = end_idx
        elif task in ["MCut"]:
            node_begin_idx = 0
            edge_begin_idx = 0
            for idx, edges_num in enumerate(edges_num_list):
                # Node offsets advance by the node count and edge offsets by
                # the edge count; the two must not be mixed up.
                node_end_idx = node_begin_idx + nodes_num_list[idx]
                edge_end_idx = edge_begin_idx + edges_num
                solutions.append(self._node_sparse_decode(
                    heatmap=heatmap[node_begin_idx:node_end_idx],
                    graph=graph_list[idx],
                    edge_index=edge_index[:, edge_begin_idx:edge_end_idx],
                ))
                node_begin_idx = node_end_idx
                edge_begin_idx = edge_end_idx
        elif task in ["TSP"]:
            node_begin_idx = 0
            edge_begin_idx = 0
            for idx, edges_num in enumerate(edges_num_list):
                nodes_num = nodes_num_list[idx]
                node_end_idx = node_begin_idx + nodes_num
                edge_end_idx = edge_begin_idx + edges_num
                solutions.append(self._edge_sparse_decode(
                    heatmap=heatmap[edge_begin_idx:edge_end_idx],
                    x=nodes_feature[node_begin_idx:node_end_idx],
                    edge_index=edge_index[:, edge_begin_idx:edge_end_idx],
                    nodes_num=nodes_num
                ))
                node_begin_idx = node_end_idx
                edge_begin_idx = edge_end_idx
        else:
            raise NotImplementedError()

        if not return_cost:
            return solutions

        # evaluate the decoded solutions against the references
        if task in ["MIS", "MVC", "MCl"]:
            costs = np.average(np.array([sum(sol) for sol in solutions]))
            # ``costs`` is a per-graph average, so the reference objective is
            # averaged over the same number of graphs before comparing.
            ref_obj = ground_truth.sum() / len(solutions)
            gap = (ref_obj - costs).abs() / ref_obj * 100
        elif task in ["MCut"]:
            mcut_solver = MCutSolver()
            mcut_solver.from_adj_matrix(
                adj_matrix=[to_numpy(g) for g in graph_list],
                nodes_label=solutions
            )
            # NOTE(review): only the first graph receives reference labels,
            # which presumably assumes a single-instance batch -- confirm.
            mcut_solver.graph_data[0].ref_nodes_label = ground_truth
            costs, _, gap, _ = mcut_solver.evaluate(calculate_gap=True)
        elif task in ["TSP"]:
            # Same pattern as ``dense_decode``: the decoded tours are the
            # solutions and ``ref_tour_list`` holds the references.
            tsp_solver = TSPSolver()
            tsp_solver.from_data(
                points=to_numpy(nodes_feature), tours=solutions
            )
            tsp_solver.ref_tours = ref_tour_list
            costs, _, gap, _ = tsp_solver.evaluate(calculate_gap=True)
        else:
            raise NotImplementedError()
        return costs, gap

    def dense_decode(
        self, heatmap: Tensor, task: str, nodes_feature: Tensor, x: Tensor, graph: Tensor, 
        e: Tensor, ground_truth: Tensor, nodes_num_list: list, ref_tour_list: list,
        return_cost: bool = False
    ) -> Union[List[np.ndarray], np.floating]:
        """Decode a batch of dense instances ("ATSP" or "TSP").

        Args:
            heatmap: Model output; indexed along its first dimension, one
                entry per instance.
            task: Either "ATSP" or "TSP".
            nodes_feature: Node features, indexed like ``heatmap``; also used
                as the points for the TSP evaluation.
            x: Not used by this method.
            graph: Dense graphs, indexed like ``heatmap``; also used as the
                distance matrices for the ATSP evaluation.
            e: Not used by this method.
            ground_truth: Not used by this method.
            nodes_num_list: Must contain exactly one entry.
            ref_tour_list: Reference tours used when ``return_cost`` is set.
            return_cost: If True, return ``(costs, gap)`` instead of the
                decoded solutions.

        Returns:
            The list of decoded solutions, or the tuple ``(costs, gap)``
            when ``return_cost`` is True.

        Raises:
            NotImplementedError: If ``task`` is not supported.
        """
        assert len(nodes_num_list) == 1
        # get solutions
        if task not in ["ATSP", "TSP"]:
            raise NotImplementedError()
        solutions = [
            self._edge_dense_decode(
                heatmap=heatmap[idx], x=nodes_feature[idx], graph=graph[idx]
            )
            for idx in range(heatmap.shape[0])
        ]

        if not return_cost:
            return solutions

        # evaluate the decoded tours against the reference tours
        if task == "ATSP":
            solver = ATSPSolver()
            solver.from_data(dists=to_numpy(graph), tours=solutions)
        else:
            solver = TSPSolver()
            solver.from_data(points=to_numpy(nodes_feature), tours=solutions)
        solver.ref_tours = ref_tour_list
        costs, _, gap, _ = solver.evaluate(calculate_gap=True)
        return costs, gap

    def _node_sparse_decode(self, *args, **kwargs) -> np.ndarray:
        """Decode one sparse node-level instance (hook for subclasses).

        Called with the keywords ``heatmap`` and ``graph``, plus
        ``edge_index`` for the "MCut" task. Arguments are accepted here so
        that a missing override fails with ``NotImplementedError`` rather
        than a ``TypeError`` about unexpected keywords.
        """
        raise NotImplementedError(
            "``_node_sparse_decode`` is required to implemented in subclasses."
        )

    def _edge_sparse_decode(self, *args, **kwargs) -> np.ndarray:
        """Decode one sparse edge-level instance (hook for subclasses).

        Called with the keywords ``heatmap``, ``x``, ``edge_index`` and
        ``nodes_num``.
        """
        raise NotImplementedError(
            "``_edge_sparse_decode`` is required to implemented in subclasses."
        )

    def _edge_dense_decode(self, *args, **kwargs) -> np.ndarray:
        """Decode one dense edge-level instance (hook for subclasses).

        Called with the keywords ``heatmap``, ``x`` and ``graph``.
        """
        raise NotImplementedError(
            "``_edge_dense_decode`` is required to implemented in subclasses."
        )