import torch
import numpy as np
from torch import Tensor
from typing import List, Union, Any
from ml4co_kit import (
    check_dim, MClGraphData, MCutGraphData, 
    MISGraphData, MVCGraphData
)
from meta_diffusion.env.dense import (
    atsp_dense_process, tsp_dense_process
)

class MetaDiffDenser(object):
    """Convert a batch of raw problem instances into stacked dense tensors.

    Each ``*_batch_data_process`` method performs the same steps:

    1. It runs a per-task dense-process helper on every instance.
    2. It appends the helper's output to the internal accumulator lists
       ``sampling_num`` times.
    3. It stacks those lists into batch tensors placed on ``self.device``
       (see ``merge_process``).

    Only the ATSP and TSP paths are implemented. The MCl / MCut / MIS / MVC
    entry points raise ``NotImplementedError``.
    """

    def __init__(self, device: str) -> None:
        # Device that ``merge_process`` moves every stacked tensor to.
        self.device = device
    
    #################################
    #        Raw Data Process       #
    #################################
    
    def initial_lists(self) -> None:
        """Reset the per-instance accumulators before building a new batch."""
        self.nodes_feature_list = list()
        self.x_list = list()
        self.graph_list = list()
        self.e_list = list()
        self.ground_truth_list = list()
        self.nodes_num_list = list()
        self.ref_tour_list = list()
        
    def update_lists(self, dense_data: Any) -> None:
        """Append one instance's dense data to the accumulators.

        ``dense_data`` is read positionally as
        ``(nodes_feature, x, graph, e, ground_truth, nodes_num, ref_tour)``.
        That order must match what the dense-process helpers return.
        """
        self.nodes_feature_list.append(dense_data[0])
        self.x_list.append(dense_data[1])
        self.graph_list.append(dense_data[2])
        self.e_list.append(dense_data[3])
        self.ground_truth_list.append(dense_data[4])
        self.nodes_num_list.append(dense_data[5])
        self.ref_tour_list.append(dense_data[6])
    
    def merge_process(self, task: str, with_gt: bool) -> Any:
        """Stack the accumulated per-instance data into batch tensors.

        Args:
            task: Task tag, returned unchanged as the first tuple element.
            with_gt: Whether ground-truth entries were collected and should
                be stacked. When False, ``ground_truth`` is returned as None.

        Returns:
            Tuple ``(task, nodes_feature, x, graph, e, ground_truth,
            nodes_num_list, ref_tour_list)``. ``nodes_feature`` is None when
            the first collected entry is None. ``x`` is always None. The last
            two elements are the raw Python lists (not stacked).
        """
        # Node features are optional; the first entry decides for the batch.
        if self.nodes_feature_list[0] is not None:
            nodes_feature = torch.stack(self.nodes_feature_list, 0).to(self.device)
        else:
            nodes_feature = None
            
        # Node decision variables are not emitted here. ``x_list`` is
        # collected by ``update_lists`` but intentionally left unused.
        x = None
        
        # edges decision variable
        e = torch.stack(self.e_list, 0).to(self.device)
        
        # graph
        graph = torch.stack(self.graph_list, 0).to(self.device)

        # ground truth
        if with_gt:
            ground_truth = torch.stack(self.ground_truth_list, 0).to(self.device) # (B, V, V) or (B, V)
        else:
            ground_truth = None
        
        return (
            task, nodes_feature, x, graph, e, ground_truth, self.nodes_num_list, self.ref_tour_list
        )  
        
    def atsp_batch_data_process(
        self,
        dists: np.ndarray,
        ref_tours: Union[np.ndarray, None],
        sampling_num: int = 1,
    ) -> Any:
        """Build a dense ATSP batch.

        Args:
            dists: Batch of instances, rank-checked to be 3-D (presumably
                ``(B, V, V)`` distance matrices — confirm with caller).
            ref_tours: Optional batch of reference tours, rank-checked to be
                2-D when given. Pass None when no ground truth is available.
            sampling_num: How many times each instance is repeated in the
                output batch.

        Returns:
            The tuple produced by ``merge_process`` with task ``"ATSP"``.
        """
        with_gt = ref_tours is not None

        # check dimension (ref_tours is optional, so only validate it if given)
        check_dim(dists, 3)
        if with_gt:
            check_dim(ref_tours, 2)
        
        # initialize lists
        self.initial_lists()
        
        # dense process: each instance is densified once, then repeated
        for idx in range(dists.shape[0]):
            dense_data = atsp_dense_process(
                dists=dists[idx], 
                ref_tour=ref_tours[idx] if with_gt else None
            )
            for _ in range(sampling_num):
                self.update_lists(dense_data)
            
        # merge
        return self.merge_process(task="ATSP", with_gt=with_gt)

    def mcl_batch_data_process(
        self, graph_data: List[MClGraphData], sampling_num: int = 1
    ) -> Any:
        """Maximum-clique batching; not implemented for this denser."""
        raise NotImplementedError()
    
    def mcut_batch_data_process(
        self, graph_data: List[MCutGraphData], sampling_num: int = 1
    ) -> Any:
        """Max-cut batching; not implemented for this denser."""
        raise NotImplementedError()

    def mis_batch_data_process(
        self, graph_data: List[MISGraphData], sampling_num: int = 1
    ) -> Any:
        """Maximum-independent-set batching; not implemented for this denser."""
        raise NotImplementedError()

    def mvc_batch_data_process(
        self, graph_data: List[MVCGraphData], sampling_num: int = 1
    ) -> Any:
        """Minimum-vertex-cover batching; not implemented for this denser."""
        raise NotImplementedError()
    
    def tsp_batch_data_process(
        self,
        points: np.ndarray,
        ref_tours: Union[np.ndarray, None],
        sampling_num: int = 1,
    ) -> Any:
        """Build a dense TSP batch.

        Args:
            points: Batch of instances, rank-checked to be 3-D (presumably
                ``(B, V, 2)`` node coordinates — confirm with caller).
            ref_tours: Optional batch of reference tours, rank-checked to be
                2-D when given. Pass None when no ground truth is available.
            sampling_num: How many times each instance is repeated in the
                output batch.

        Returns:
            The tuple produced by ``merge_process`` with task ``"TSP"``.
        """
        with_gt = ref_tours is not None

        # check dimension (ref_tours is optional, so only validate it if given)
        check_dim(points, 3)
        if with_gt:
            check_dim(ref_tours, 2)
        
        # initialize lists
        self.initial_lists()
        
        # dense process: each instance is densified once, then repeated
        for idx in range(points.shape[0]):
            dense_data = tsp_dense_process(
                points=points[idx], 
                ref_tour=ref_tours[idx] if with_gt else None, 
            )
            for _ in range(sampling_num):
                self.update_lists(dense_data)
        
        # merge
        return self.merge_process(task="TSP", with_gt=with_gt)