import os
import torch
import numpy as np
from typing import Any, List, Union, Dict, Tuple
from ml4co_kit import *
from torch.utils.data import DataLoader, Dataset
import random
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from meta_diffusion.env.denser import MetaDiffDenser
from meta_diffusion.env.sparser import MetaDiffSparser


class FakeDataset(Dataset):
    """Index-only placeholder dataset.

    Item ``i`` is simply ``torch.tensor([i])``. Its only role is to make a
    ``DataLoader`` produce ``data_size`` items (and hence a fixed number of
    steps per epoch); the actual task data is presumably generated elsewhere
    (e.g. by ``MetaDiffEnv.generate_*_data``).
    """

    def __init__(self, data_size: int):
        # number of indices exposed to the DataLoader
        self.data_size = data_size

    def __len__(self) -> int:
        return self.data_size

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return ``torch.tensor([idx])``.

        Raises:
            IndexError: if ``idx`` is outside ``[0, data_size)``.
        """
        # torch's map-style Dataset defines no __iter__, so plain iteration
        # (`for x in ds`, `list(ds)`) falls back to the sequence protocol and
        # only stops when __getitem__ raises IndexError.
        if not 0 <= idx < self.data_size:
            raise IndexError(
                f"index {idx} out of range for FakeDataset of size {self.data_size}"
            )
        return torch.tensor([idx])


class MetaDiffEnv(BaseEnv):
    def __init__(
        self,
        task: List[str] = None,
        mode: str = None,
        train_data_size: int = 128000,
        val_data_size: int = 128,
        train_batch_size: int = 4,
        val_batch_size: int = 1,
        num_workers: int = 4,
        sparse_factor: int = 50,
        device: str = "cpu",
        train_folder: Dict[str, str] = None,
        train_path: str = None,
        val_path: Dict[str, str] = None,
        store_data: bool = True,
        finetune: bool = False
    ):
        """Set up the multi-task MetaDiff environment.

        Args:
            task: task names handled by this env (the dispatchers below know
                "ATSP", "MCl", "MCut", "MIS", "MVC" and "TSP").
            mode: forwarded to ``BaseEnv``; when not ``None`` the file lists and
                caches are prepared immediately via ``load_data``.
            train_data_size: length of the placeholder train dataset, i.e. how
                many indices the train ``DataLoader`` iterates over.
            val_data_size: length of the placeholder validation dataset.
            train_batch_size, val_batch_size, num_workers, device: forwarded to
                ``BaseEnv``.
            sparse_factor: a value > 0 selects ``MetaDiffSparser``; otherwise
                ``MetaDiffDenser`` is used.
            train_folder: per-task folder whose files are the training data.
            train_path: a single training file; when given, ``load_data``
                registers it for ``task[0]`` instead of using ``train_folder``.
            val_path: per-task validation file.
            store_data: keep every parsed training file in memory for reuse.
            finetune: passed as ``normalize`` when loading TSP files.
        """
        super().__init__(
            name="MetaDiffEnv",
            mode=mode,
            train_batch_size=train_batch_size,
            val_batch_size=val_batch_size,
            num_workers=num_workers,
            device=device,
        )

        # task selection and graph representation
        self.task_pool = task
        self.sparse_factor = sparse_factor
        self.sparse = sparse_factor > 0
        self.finetune = finetune

        # data locations
        self.train_folder, self.val_path, self.train_path = (
            train_folder, val_path, train_path
        )

        # one ml4co-kit solver per task; in this class they are used to parse
        # data files via ``from_txt``
        self.atsp_solver = ATSPSolver()
        self.mcl_solver = MClSolver()
        self.mcut_solver = MCutSolver()
        self.mis_solver = MISSolver()
        self.mvc_solver = MVCSolver()
        self.tsp_solver = TSPSolver()

        # placeholder datasets that only fix the number of loader steps
        # NOTE(review): no ``test_dataset`` is created here although
        # ``test_dataloader`` reads one — confirm ``BaseEnv`` provides it.
        self.store_data = store_data
        self.train_dataset, self.val_dataset = (
            FakeDataset(train_data_size), FakeDataset(val_data_size)
        )

        # batch processor matching the chosen graph representation
        self.data_processor = (
            MetaDiffSparser(self.sparse_factor, self.device)
            if self.sparse
            else MetaDiffDenser(self.device)
        )

        if self.mode is not None:
            self.load_data()

    def load_data(self):
        """Prepare per-task file lists and data caches for ``self.mode``.

        The validation cache is created in every mode, because the
        ``generate_val_data_*`` methods read ``self.val_data_cache`` regardless
        of mode (previously it only existed in ``"train"`` mode, so validating
        in any other mode raised ``AttributeError``).

        In ``"train"`` mode the training files are also collected:

        * ``train_path is None``: every regular file inside
          ``train_folder[task]``, sorted by path. ``os.listdir`` order is
          arbitrary, so sorting keeps the seeded random file choice in the
          ``generate_train_data_*`` methods reproducible.
        * otherwise: the single ``train_path`` file, registered for ``task[0]``.
        """
        task_pool = self.task_pool or []
        self.val_data_cache = {task: None for task in task_pool}
        if self.mode != "train":
            return

        def list_files(folder: str) -> List[str]:
            # skip sub-directories: they cannot be parsed as data files
            paths = (os.path.join(folder, name) for name in os.listdir(folder))
            return sorted(path for path in paths if os.path.isfile(path))

        if self.train_path is None:
            self.train_sub_files = {
                task: list_files(folder) for task, folder in self.train_folder.items()
            }
        else:
            self.train_sub_files = {self.task_pool[0]: [self.train_path]}
        self.train_sub_files_num = {
            task: len(files) for task, files in self.train_sub_files.items()
        }
        # attribute name spelling ("historty") kept as-is for existing users
        self.train_data_historty_cache = {task: dict() for task in task_pool}
        self.train_data_cache = {task: None for task in task_pool}
        self.train_data_cache_idx = {task: 0 for task in task_pool}
        
    def train_dataloader(self):
        """Return the shuffled training loader over the placeholder dataset.

        ``drop_last=True`` keeps every step at exactly ``train_batch_size``
        indices.
        """
        train_dataloader = DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            # DataLoader raises ValueError for persistent_workers=True with
            # num_workers == 0, so only keep workers alive when there are some.
            persistent_workers=self.num_workers > 0,
            drop_last=True,
        )
        return train_dataloader

    def val_dataloader(self):
        """Return the sequential (unshuffled) loader over the validation dataset."""
        return DataLoader(
            dataset=self.val_dataset,
            batch_size=self.val_batch_size,
            shuffle=False,
        )
    
    def test_dataloader(self):
        """Return the sequential (unshuffled) loader over the test dataset.

        NOTE(review): ``self.test_dataset`` and ``self.test_batch_size`` are not
        assigned in the visible code of this class; presumably ``BaseEnv``
        provides them — confirm, otherwise this raises ``AttributeError``.
        """
        return DataLoader(
            dataset=self.test_dataset,
            batch_size=self.test_batch_size,
            shuffle=False,
        )

    #################################
    #       Generate Val Data       #           
    #################################

    def generate_val_data(self, task: str, val_idx: int) -> Any:
        """Return processed validation batch number ``val_idx`` for ``task``.

        The batch covers instances
        ``[val_idx * val_batch_size, (val_idx + 1) * val_batch_size)`` of the
        task's validation file.

        Raises:
            ValueError: if ``task`` is not one of "ATSP", "MCl", "MCut", "MIS",
                "MVC", "TSP".
        """
        generators = {
            "ATSP": self.generate_val_data_atsp,
            "MCl": self.generate_val_data_mcl,
            "MCut": self.generate_val_data_mcut,
            "MIS": self.generate_val_data_mis,
            "MVC": self.generate_val_data_mvc,
            "TSP": self.generate_val_data_tsp,
        }
        generator = generators.get(task)
        if generator is None:
            # Fail loudly: returning None would defer the error to whatever
            # consumes the batch, far from the actual cause.
            raise ValueError(
                f"Unsupported task {task!r}; expected one of {sorted(generators)}"
            )
        begin_idx = val_idx * self.val_batch_size
        return generator(begin_idx, begin_idx + self.val_batch_size)

    def generate_val_data_atsp(self, begin_idx: int, end_idx: int) -> Any:
        """Process ATSP validation instances ``[begin_idx, end_idx)``.

        ``self.val_path["ATSP"]`` is parsed on the first call; its distance
        matrices and reference tours are memoised in ``self.val_data_cache``.
        """
        name = "ATSP"
        cache = self.val_data_cache
        if cache[name] is None:
            solver = self.atsp_solver
            solver.from_txt(self.val_path[name], ref=True)
            cache[name] = {"dists": solver.dists, "ref_tours": solver.ref_tours}
        entry = cache[name]
        window = slice(begin_idx, end_idx)
        return self.data_processor.atsp_batch_data_process(
            dists=entry["dists"][window],
            ref_tours=entry["ref_tours"][window],
        )

    def generate_val_data_mcl(self, begin_idx: int, end_idx: int) -> Any:
        """Process MCl validation graphs ``[begin_idx, end_idx)``.

        ``self.val_path["MCl"]`` is parsed on the first call and its
        ``graph_data`` memoised in ``self.val_data_cache``.
        """
        name = "MCl"
        cache = self.val_data_cache
        if cache[name] is None:
            self.mcl_solver.from_txt(self.val_path[name], ref=True)
            cache[name] = {"graph_data": self.mcl_solver.graph_data}
        graphs = cache[name]["graph_data"]
        return self.data_processor.mcl_batch_data_process(
            graph_data=graphs[begin_idx:end_idx]
        )

    def generate_val_data_mcut(self, begin_idx: int, end_idx: int) -> Any:
        """Process MCut validation graphs ``[begin_idx, end_idx)``.

        ``self.val_path["MCut"]`` is parsed on the first call and its
        ``graph_data`` memoised in ``self.val_data_cache``.
        """
        name = "MCut"
        cache = self.val_data_cache
        if cache[name] is None:
            self.mcut_solver.from_txt(self.val_path[name], ref=True)
            cache[name] = {"graph_data": self.mcut_solver.graph_data}
        graphs = cache[name]["graph_data"]
        return self.data_processor.mcut_batch_data_process(
            graph_data=graphs[begin_idx:end_idx]
        )
    
    def generate_val_data_mis(self, begin_idx: int, end_idx: int) -> Any:
        """Process MIS validation graphs ``[begin_idx, end_idx)``.

        ``self.val_path["MIS"]`` is parsed on the first call and its
        ``graph_data`` memoised in ``self.val_data_cache``.
        """
        name = "MIS"
        cache = self.val_data_cache
        if cache[name] is None:
            self.mis_solver.from_txt(self.val_path[name], ref=True)
            cache[name] = {"graph_data": self.mis_solver.graph_data}
        graphs = cache[name]["graph_data"]
        return self.data_processor.mis_batch_data_process(
            graph_data=graphs[begin_idx:end_idx]
        )
    
    def generate_val_data_mvc(self, begin_idx: int, end_idx: int) -> Any:
        """Process MVC validation graphs ``[begin_idx, end_idx)``.

        ``self.val_path["MVC"]`` is parsed on the first call and its
        ``graph_data`` memoised in ``self.val_data_cache``.
        """
        name = "MVC"
        cache = self.val_data_cache
        if cache[name] is None:
            self.mvc_solver.from_txt(self.val_path[name], ref=True)
            cache[name] = {"graph_data": self.mvc_solver.graph_data}
        graphs = cache[name]["graph_data"]
        return self.data_processor.mvc_batch_data_process(
            graph_data=graphs[begin_idx:end_idx]
        )

    def generate_val_data_tsp(self, begin_idx: int, end_idx: int) -> Any:
        """Process TSP validation instances ``[begin_idx, end_idx)``.

        ``self.val_path["TSP"]`` is parsed on the first call (``normalize`` is
        set to ``self.finetune``); its points and reference tours are memoised
        in ``self.val_data_cache``.
        """
        name = "TSP"
        cache = self.val_data_cache
        if cache[name] is None:
            solver = self.tsp_solver
            solver.from_txt(self.val_path[name], ref=True, normalize=self.finetune)
            cache[name] = {"points": solver.points, "ref_tours": solver.ref_tours}
        entry = cache[name]
        window = slice(begin_idx, end_idx)
        return self.data_processor.tsp_batch_data_process(
            points=entry["points"][window],
            ref_tours=entry["ref_tours"][window],
        )
        
    #################################
    #      Generate Train Data      #
    #################################
    
    def generate_train_data(self, task: str, batch_size: int) -> Any:
        """Return the next processed training batch of ``batch_size`` for ``task``.

        Raises:
            ValueError: if ``task`` is not one of "ATSP", "MCl", "MCut", "MIS",
                "MVC", "TSP".
        """
        generators = {
            "ATSP": self.generate_train_data_atsp,
            "MCl": self.generate_train_data_mcl,
            "MCut": self.generate_train_data_mcut,
            "MIS": self.generate_train_data_mis,
            "MVC": self.generate_train_data_mvc,
            "TSP": self.generate_train_data_tsp,
        }
        generator = generators.get(task)
        if generator is None:
            # Fail loudly: returning None would defer the error to the
            # training step, far from the actual cause.
            raise ValueError(
                f"Unsupported task {task!r}; expected one of {sorted(generators)}"
            )
        return generator(batch_size)

    def generate_train_data_atsp(self, batch_size: int)  -> Any:
        """Return the next processed ATSP training batch of ``batch_size``.

        Instances are read sequentially from the currently cached training
        file. When nothing is cached yet, or fewer than ``batch_size`` instances
        remain, a training file is drawn uniformly at random. With
        ``store_data`` on, a file parsed earlier is reused from
        ``train_data_historty_cache`` instead of being re-read. Reading then
        restarts at index 0.
        """
        name = "ATSP"
        cached = self.train_data_cache[name]
        start = self.train_data_cache_idx[name]
        if cached is None or start + batch_size > cached["data_size"]:
            n_files = self.train_sub_files_num[name]
            sel_idx = np.random.randint(low=0, high=n_files, size=(1,))[0]
            sel_train_sub_file_path = self.train_sub_files[name][sel_idx]
            history = self.train_data_historty_cache[name]
            if self.store_data and sel_train_sub_file_path in history:
                print(f"\nusing data cache ({sel_train_sub_file_path})")
                cached = history[sel_train_sub_file_path]
            else:
                print(f"\nload atsp train data from {sel_train_sub_file_path}")
                solver = self.atsp_solver
                solver.from_txt(sel_train_sub_file_path, show_time=True, ref=True)
                cached = {
                    "dists": solver.dists,
                    "ref_tours": solver.ref_tours,
                    "data_size": solver.dists.shape[0],
                }
                if self.store_data:
                    history[sel_train_sub_file_path] = cached
            self.train_data_cache[name] = cached
            start = 0
        stop = start + batch_size
        self.train_data_cache_idx[name] = stop
        return self.data_processor.atsp_batch_data_process(
            cached["dists"][start:stop], cached["ref_tours"][start:stop]
        )
    
    def generate_train_data_mcl(self, batch_size: int) -> Any:
        """Return the next processed MCl training batch of ``batch_size`` graphs.

        Graphs are read sequentially from the currently cached training file.
        When nothing is cached yet, or fewer than ``batch_size`` graphs remain,
        a training file is drawn uniformly at random. With ``store_data`` on, a
        file parsed earlier is reused from ``train_data_historty_cache``
        instead of being re-read. Reading then restarts at index 0.
        """
        name = "MCl"
        cached = self.train_data_cache[name]
        start = self.train_data_cache_idx[name]
        if cached is None or start + batch_size > cached["data_size"]:
            n_files = self.train_sub_files_num[name]
            sel_idx = np.random.randint(low=0, high=n_files, size=(1,))[0]
            sel_train_sub_file_path = self.train_sub_files[name][sel_idx]
            history = self.train_data_historty_cache[name]
            if self.store_data and sel_train_sub_file_path in history:
                print(f"\nusing data cache ({sel_train_sub_file_path})")
                cached = history[sel_train_sub_file_path]
            else:
                print(f"\nload mcl train data from {sel_train_sub_file_path}")
                self.mcl_solver.from_txt(
                    sel_train_sub_file_path, show_time=True, ref=True
                )
                graphs = self.mcl_solver.graph_data
                cached = {"graph_data": graphs, "data_size": len(graphs)}
                if self.store_data:
                    history[sel_train_sub_file_path] = cached
            self.train_data_cache[name] = cached
            start = 0
        stop = start + batch_size
        self.train_data_cache_idx[name] = stop
        return self.data_processor.mcl_batch_data_process(
            cached["graph_data"][start:stop]
        )

    def generate_train_data_mcut(self, batch_size: int) -> Any:
        """Return the next processed MCut training batch of ``batch_size`` graphs.

        Graphs are read sequentially from the currently cached training file.
        When nothing is cached yet, or fewer than ``batch_size`` graphs remain,
        a training file is drawn uniformly at random. With ``store_data`` on, a
        file parsed earlier is reused from ``train_data_historty_cache``
        instead of being re-read. Reading then restarts at index 0.
        """
        name = "MCut"
        cached = self.train_data_cache[name]
        start = self.train_data_cache_idx[name]
        if cached is None or start + batch_size > cached["data_size"]:
            n_files = self.train_sub_files_num[name]
            sel_idx = np.random.randint(low=0, high=n_files, size=(1,))[0]
            sel_train_sub_file_path = self.train_sub_files[name][sel_idx]
            history = self.train_data_historty_cache[name]
            if self.store_data and sel_train_sub_file_path in history:
                print(f"\nusing data cache ({sel_train_sub_file_path})")
                cached = history[sel_train_sub_file_path]
            else:
                print(f"\nload mcut train data from {sel_train_sub_file_path}")
                self.mcut_solver.from_txt(
                    sel_train_sub_file_path, show_time=True, ref=True
                )
                graphs = self.mcut_solver.graph_data
                cached = {"graph_data": graphs, "data_size": len(graphs)}
                if self.store_data:
                    history[sel_train_sub_file_path] = cached
            self.train_data_cache[name] = cached
            start = 0
        stop = start + batch_size
        self.train_data_cache_idx[name] = stop
        return self.data_processor.mcut_batch_data_process(
            cached["graph_data"][start:stop]
        )

    def generate_train_data_mis(self, batch_size: int) -> Any:
        """Return the next processed MIS training batch of ``batch_size`` graphs.

        Graphs are read sequentially from the currently cached training file.
        When nothing is cached yet, or fewer than ``batch_size`` graphs remain,
        a training file is drawn uniformly at random. With ``store_data`` on, a
        file parsed earlier is reused from ``train_data_historty_cache``
        instead of being re-read. Reading then restarts at index 0.
        """
        name = "MIS"
        cached = self.train_data_cache[name]
        start = self.train_data_cache_idx[name]
        if cached is None or start + batch_size > cached["data_size"]:
            n_files = self.train_sub_files_num[name]
            sel_idx = np.random.randint(low=0, high=n_files, size=(1,))[0]
            sel_train_sub_file_path = self.train_sub_files[name][sel_idx]
            history = self.train_data_historty_cache[name]
            if self.store_data and sel_train_sub_file_path in history:
                print(f"\nusing data cache ({sel_train_sub_file_path})")
                cached = history[sel_train_sub_file_path]
            else:
                print(f"\nload mis train data from {sel_train_sub_file_path}")
                self.mis_solver.from_txt(
                    sel_train_sub_file_path, show_time=True, ref=True
                )
                graphs = self.mis_solver.graph_data
                cached = {"graph_data": graphs, "data_size": len(graphs)}
                if self.store_data:
                    history[sel_train_sub_file_path] = cached
            self.train_data_cache[name] = cached
            start = 0
        stop = start + batch_size
        self.train_data_cache_idx[name] = stop
        return self.data_processor.mis_batch_data_process(
            cached["graph_data"][start:stop]
        )

    def generate_train_data_mvc(self, batch_size: int) -> Any:
        """Return the next processed MVC training batch of ``batch_size`` graphs.

        Graphs are read sequentially from the currently cached training file.
        When nothing is cached yet, or fewer than ``batch_size`` graphs remain,
        a training file is drawn uniformly at random. With ``store_data`` on, a
        file parsed earlier is reused from ``train_data_historty_cache``
        instead of being re-read. Reading then restarts at index 0.
        """
        name = "MVC"
        cached = self.train_data_cache[name]
        start = self.train_data_cache_idx[name]
        if cached is None or start + batch_size > cached["data_size"]:
            n_files = self.train_sub_files_num[name]
            sel_idx = np.random.randint(low=0, high=n_files, size=(1,))[0]
            sel_train_sub_file_path = self.train_sub_files[name][sel_idx]
            history = self.train_data_historty_cache[name]
            if self.store_data and sel_train_sub_file_path in history:
                print(f"\nusing data cache ({sel_train_sub_file_path})")
                cached = history[sel_train_sub_file_path]
            else:
                print(f"\nload mvc train data from {sel_train_sub_file_path}")
                self.mvc_solver.from_txt(
                    sel_train_sub_file_path, show_time=True, ref=True
                )
                graphs = self.mvc_solver.graph_data
                cached = {"graph_data": graphs, "data_size": len(graphs)}
                if self.store_data:
                    history[sel_train_sub_file_path] = cached
            self.train_data_cache[name] = cached
            start = 0
        stop = start + batch_size
        self.train_data_cache_idx[name] = stop
        return self.data_processor.mvc_batch_data_process(
            cached["graph_data"][start:stop]
        )
     
    def generate_train_data_tsp(self, batch_size: int) -> Any:
        """Return the next processed TSP training batch of ``batch_size``.

        Instances are read sequentially from the currently cached training
        file. When nothing is cached yet, or fewer than ``batch_size`` instances
        remain, a training file is drawn uniformly at random. With
        ``store_data`` on, a file parsed earlier is reused from
        ``train_data_historty_cache`` instead of being re-read. Reading then
        restarts at index 0. Files are loaded with ``normalize=self.finetune``.
        """
        name = "TSP"
        cached = self.train_data_cache[name]
        start = self.train_data_cache_idx[name]
        if cached is None or start + batch_size > cached["data_size"]:
            n_files = self.train_sub_files_num[name]
            sel_idx = np.random.randint(low=0, high=n_files, size=(1,))[0]
            sel_train_sub_file_path = self.train_sub_files[name][sel_idx]
            history = self.train_data_historty_cache[name]
            if self.store_data and sel_train_sub_file_path in history:
                print(f"\nusing data cache ({sel_train_sub_file_path})")
                cached = history[sel_train_sub_file_path]
            else:
                print(f"\nload tsp train data from {sel_train_sub_file_path}")
                solver = self.tsp_solver
                solver.from_txt(
                    sel_train_sub_file_path,
                    show_time=True,
                    ref=True,
                    normalize=self.finetune,
                )
                cached = {
                    "points": solver.points,
                    "ref_tours": solver.ref_tours,
                    "data_size": solver.points.shape[0],
                }
                if self.store_data:
                    history[sel_train_sub_file_path] = cached
            self.train_data_cache[name] = cached
            start = 0
        stop = start + batch_size
        self.train_data_cache_idx[name] = stop
        return self.data_processor.tsp_batch_data_process(
            cached["points"][start:stop], cached["ref_tours"][start:stop]
        )

            
    #################################
    #            Finetune           #
    #################################
    
    def finetune_sparse(
        self, task: str, pred: Tensor, edges_feature, edge_index: Tensor
    ) -> Tensor:
        """Dispatch sparse-graph finetuning to the data processor hook for ``task``.

        Only the MCut and MIS hooks receive ``pred``, ``edges_feature`` and
        ``edge_index`` (as keyword arguments); the other hooks are called with
        no arguments.

        Raises:
            ValueError: if ``task`` is not one of "ATSP", "MCl", "MCut", "MIS",
                "MVC", "TSP".
        """
        hook_names = {
            "ATSP": "atsp_finetune",
            "MCl": "mcl_finetune",
            "MCut": "mcut_finetune",
            "MIS": "mis_finetune",
            "MVC": "mvc_finetune",
            "TSP": "tsp_finetune",
        }
        if task not in hook_names:
            # Fail loudly instead of silently returning None.
            raise ValueError(
                f"Unsupported task {task!r}; expected one of {sorted(hook_names)}"
            )
        # Resolve lazily so only the requested hook has to exist on the processor.
        hook = getattr(self.data_processor, hook_names[task])
        if task in ("MCut", "MIS"):
            return hook(pred=pred, edges_feature=edges_feature, edge_index=edge_index)
        return hook()

    def finetune_dense(
        self, task: str, pred: Tensor, graph: Tensor
    ) -> Tensor:
        """Dispatch dense-graph finetuning to the data processor hook for ``task``.

        Every hook is called with no arguments: ``pred`` and ``graph`` are
        accepted (presumably for signature parity with ``finetune_sparse``) but
        are not forwarded.
        NOTE(review): the sparse path passes the prediction to the MCut/MIS
        hooks — confirm the dense hooks really need no inputs.

        Raises:
            ValueError: if ``task`` is not one of "ATSP", "MCl", "MCut", "MIS",
                "MVC", "TSP".
        """
        hook_names = {
            "ATSP": "atsp_finetune",
            "MCl": "mcl_finetune",
            "MCut": "mcut_finetune",
            "MIS": "mis_finetune",
            "MVC": "mvc_finetune",
            "TSP": "tsp_finetune",
        }
        if task not in hook_names:
            # Fail loudly instead of silently returning None.
            raise ValueError(
                f"Unsupported task {task!r}; expected one of {sorted(hook_names)}"
            )
        # Resolve lazily so only the requested hook has to exist on the processor.
        return getattr(self.data_processor, hook_names[task])()