from concurrent.futures import ThreadPoolExecutor, as_completed
from adbench.run import RunPipeline
import itertools
import threading
import numpy as np
from tqdm import tqdm

import os, glob

class RunPipeline_unsup(RunPipeline):
    """ADBench ``RunPipeline`` extended with graph and classical-tabular ``.npz`` datasets.

    Dataset filtering runs in a thread pool, and file-based runs execute one
    fresh ``RunPipeline`` per ``.npz`` file. Each result row is a mutable list
    whose element 0 is a parameter tuple and element 1 is the model name (as
    indexed below). For file-based runs, the first entry of the tuple is
    replaced with the file's stem.
    """

    # Glob patterns for the extra dataset collections (relative to the CWD).
    GRAPH_DATA_GLOB = 'data/Graph_by_BGRL/*.npz'
    TABULAR_DATA_GLOB = 'data/Classical/*.npz'

    def dataset_filter(self):
        """Return the datasets that satisfy the experiment's requirements, smallest first.

        Each candidate from ``self.data_generator.generate_dataset_list()`` is
        generated once per seed in ``self.seed_list`` (``la=1.00``). It is kept
        only if, for every seed:

        * it has at least ``self.n_samples_threshold`` samples, unless
          ``self.generate_duplicates`` is set; and
        * ``y_train`` contains at least ``self.nla_list[-1]`` anomalies
          (``mode == 'nla'``) or at least one anomaly (``mode == 'rla'``).
          Any other mode removes every dataset.

        CV/NLP datasets (``self.isin_NLPCV``) are also dropped whenever
        ``realistic_synthetic_mode`` or ``noise_type`` is set. With an empty
        ``seed_list`` nothing can be checked, so every dataset is removed.

        Returns:
            list: kept dataset identifiers, sorted by the last seed's
            ``len(y_train) + len(y_test)``. Ties keep the order produced by
            ``generate_dataset_list()``, so the result is deterministic.
        """
        candidates = list(itertools.chain(*self.data_generator.generate_dataset_list()))

        # The data generator is one shared object configured through its
        # ``seed``/``dataset`` attributes before each ``generator()`` call.
        # Without a lock, concurrent threads overwrite each other's settings and
        # evaluate the wrong dataset/seed, so generation is serialised.
        generator_lock = threading.Lock()
        # Seed-independent exclusion of CV/NLP datasets, decided once up front.
        exclude_nlpcv = self.realistic_synthetic_mode is not None or self.noise_type is not None

        def meets_requirements(data, n_samples):
            # Sample-count and anomaly-count criteria for one generated split.
            if not self.generate_duplicates and n_samples < self.n_samples_threshold:
                return False
            n_train_anomalies = sum(data['y_train'])
            if self.mode == 'nla':
                return n_train_anomalies >= self.nla_list[-1]
            if self.mode == 'rla':
                return n_train_anomalies > 0
            return False

        def process_dataset(dataset):
            # Returns (dataset, n_samples) when kept, otherwise (None, None).
            print(dataset)
            n_samples = None
            keep = not (exclude_nlpcv and self.isin_NLPCV(dataset))
            if keep:
                for seed in self.seed_list:
                    with generator_lock:
                        self.data_generator.seed = seed
                        self.data_generator.dataset = dataset
                        data = self.data_generator.generator(la=1.00, at_least_one_labeled=True)
                    n_samples = len(data['y_train']) + len(data['y_test'])
                    if not meets_requirements(data, n_samples):
                        keep = False
                        break  # one failing seed is enough to drop the dataset
            if not keep or n_samples is None:
                print(f"remove the dataset {dataset}")
                return None, None
            return dataset, n_samples

        kept = []  # (n_samples, original_index, dataset)
        with ThreadPoolExecutor() as executor:
            future_to_index = {
                executor.submit(process_dataset, dataset): index
                for index, dataset in enumerate(candidates)
            }
            for future in tqdm(as_completed(future_to_index), total=len(candidates), desc="Processing Datasets"):
                dataset, n_samples = future.result()
                if dataset is not None:
                    kept.append((n_samples, future_to_index[future], dataset))

        # Completion order is arbitrary. Sorting on (size, original index) gives a
        # stable, reproducible ordering without ever comparing dataset objects.
        kept.sort(key=lambda item: item[:2])
        return [dataset for _, _, dataset in kept]

    # run the experiments in ADBench and Graph
    def run(self, dataset=None, clf=None):
        """Run the ADBench experiments, then the graph datasets, and return all rows.

        Args:
            dataset: forwarded unchanged to ``RunPipeline.run``.
            clf: optional model class. When given, every row's model-name entry
                is replaced with ``clf(seed=42).model_name``, unless that name
                is falsy.

        Returns:
            list: ADBench result rows followed by graph result rows.
        """
        results = super().run(dataset=dataset, clf=clf)
        results.extend(self.run_graph(clf))

        if clf is not None and results:
            # Instantiate once rather than once per row.
            model_name = clf(seed=42).model_name
            if model_name:
                for row in results:
                    row[1] = model_name

        return results

    def run_graph(self, clf=None, model_name=None):
        """Run every graph ``.npz`` dataset in parallel; see ``run_pipeline``."""
        self.cur_model_name = model_name
        npz_files = sorted(glob.glob(self.GRAPH_DATA_GLOB))
        return self.run_pipeline(npz_files, clf)

    def run_graph_old(self, clf=None):
        """Sequentially run every graph ``.npz`` dataset and concatenate the rows."""
        results = []
        for file_path in sorted(glob.glob(self.GRAPH_DATA_GLOB)):
            results.extend(self._run_npz_file(file_path, clf))
        return results

    def run_tabular_old(self, clf=None):
        """Sequentially run every classical tabular ``.npz`` dataset and concatenate the rows."""
        results = []
        for file_path in sorted(glob.glob(self.TABULAR_DATA_GLOB)):
            results.extend(self._run_npz_file(file_path, clf))
        return results

    def run_tabular(self, clf=None, model_name=None):
        """Run every classical tabular ``.npz`` dataset in parallel.

        If ``model_name`` is not None, it overwrites the model-name entry of
        every returned row.
        """
        self.cur_model_name = model_name
        npz_files = sorted(glob.glob(self.TABULAR_DATA_GLOB))
        return self.run_pipeline(npz_files, clf)

    def run_pipeline(self, npz_files, clf=None, max_workers=10):
        """Run ``process_file`` on each path in a thread pool and flatten the rows.

        Rows are returned in the order of ``npz_files``, because
        ``Executor.map`` preserves input order.
        """
        # NOTE(review): each worker builds its own RunPipeline; if those write
        # result files under a shared name, or reseed global RNGs, concurrent
        # runs may interfere with each other — confirm before raising max_workers.
        npz_files = list(npz_files)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            per_file = list(tqdm(
                executor.map(self.process_file, npz_files, itertools.repeat(clf)),
                total=len(npz_files),
            ))
        return list(itertools.chain.from_iterable(per_file))

    def process_file(self, file_path, clf):
        """Run one ``.npz`` file and optionally override the model-name entry of each row."""
        rows = self._run_npz_file(file_path, clf)
        # getattr: run_pipeline may be called without run_tabular/run_graph
        # having set cur_model_name first.
        model_name = getattr(self, 'cur_model_name', None)
        if model_name is not None:
            for row in rows:
                row[1] = model_name
        return rows

    def _run_npz_file(self, file_path, clf):
        # Run a fresh ADBench pipeline on one .npz file and tag its rows with the file stem.
        data_name = os.path.splitext(os.path.basename(file_path))[0]
        # NpzFile holds an open file handle; close it once the run has finished.
        with np.load(file_path) as data:
            pipeline = RunPipeline(suffix='ADBench', parallel='unsupervise', realistic_synthetic_mode=None, noise_type=None)
            rows = pipeline.run(dataset=data, clf=clf)
        for row in rows:
            row[0] = (data_name, row[0][1], row[0][2])
        return rows