import os, random, pathlib
from torch.utils.data import Dataset
import re
from copy import copy
import os
from ..constants import BACKDOOR_LABEL, CLEAN_LABEL

# NOTE(review): this assignment rebinds the CLEAN_LABEL name imported from
# ..constants above, so everything below in this module sees 'clean'
# regardless of the imported value. Confirm whether the override is intended.
CLEAN_LABEL = 'clean'

def extract_models_paths(root_dir):
    """Collect the paths of every ``.pt`` file found below ``root_dir``.

    The directory tree is traversed with ``os.walk``; results are returned in
    the order the walk yields them.

    Args:
        root_dir: Directory to search recursively.

    Returns:
        A list of ``os.path.join(directory, file_name)`` strings, one for each
        file whose name ends with ``.pt``. Empty when nothing matches.
    """
    return [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(root_dir)
        for filename in filenames
        if filename.endswith('.pt')
    ]

class ModelDataset(Dataset):
    """Dataset over saved ``.pt`` model files, split into clean and attacked.

    Directory layout read by ``__init__``::

        root_folder/
            <CLEAN_LABEL>/     # clean models, searched recursively
            <attack folder>/   # one directory per attack, searched recursively

    Every sample is described by a dict holding ``'path'`` (the file path),
    ``'label'`` (``CLEAN_LABEL`` or ``BACKDOOR_LABEL``) and, for attacked
    models only, ``'attack'`` (the name of the attack folder).

    After construction the instance exposes:
        data: shuffled list of all sample dicts (attacked + clean).
        bads_data: sample dicts of the attacked models.
        cleans_data: sample dicts of the clean models.
        model_data_dict: sample dicts grouped by attack folder name, plus a
            ``CLEAN_LABEL`` entry for the clean ones.
        loader: the ``model_loader`` callable.
        version: the ``version`` argument, stored as given.
    """

    def __init__(self, root_folder, model_loader, sample=False,
                 sample_k=5, discards=None, version=None, more_clean=True, balanced=True):
        """Index the model files under ``root_folder``.

        Args:
            root_folder: Directory containing the clean folder and the attack
                folders.
            model_loader: Callable invoked as ``model_loader(path, sample_dict)``
                by ``load_model``.
            sample: When true, keep a random subset of each attack folder.
            sample_k: Maximum number of models kept per attack folder when
                ``sample`` is true. Folders with fewer models are kept whole.
            discards: Folder names to ignore. Including ``CLEAN_LABEL`` drops
                the clean models. ``None`` means nothing is discarded.
            version: Stored on ``self.version``; not used by this class.
            more_clean: Accepted for interface compatibility; currently unused.
            balanced: When true, keep at most as many clean models as there
                are attacked ones (chosen at random).
        """
        self.loader = model_loader
        self.version = version
        self.data = []
        self.model_data_dict = {}
        self.bads_data = []
        self.cleans_data = []

        # Build the exclusion set once instead of on every loop iteration.
        discarded = set(discards) if discards is not None else set()

        for attack_folder in os.listdir(root_folder):
            attack_dir = os.path.join(root_folder, attack_folder)
            if not os.path.isdir(attack_dir):
                continue
            # The clean folder lives next to the attack folders; it is handled
            # separately below.
            if attack_folder == CLEAN_LABEL or attack_folder in discarded:
                continue

            # 'path' is required by load_model and 'attack' by __getitem__.
            bad_data = [
                {'path': model_path, 'attack': attack_folder, 'label': BACKDOOR_LABEL}
                for model_path in extract_models_paths(attack_dir)
            ]

            if sample:
                # Clamp so folders with fewer than sample_k models do not make
                # random.sample raise ValueError.
                bad_data = random.sample(bad_data, min(sample_k, len(bad_data)))
            self.bads_data += bad_data
            self.model_data_dict[attack_folder] = bad_data

        if CLEAN_LABEL not in discarded:
            cleans_folder = os.path.join(root_folder, CLEAN_LABEL)
            cleans_data = [
                {'path': model_path, 'label': CLEAN_LABEL}
                for model_path in extract_models_paths(cleans_folder)
            ]

            if balanced:
                cleans_data = random.sample(
                    cleans_data, min(len(self.bads_data), len(cleans_data)))
            self.cleans_data = cleans_data

        self.data = self.bads_data + self.cleans_data
        self.model_data_dict[CLEAN_LABEL] = self.cleans_data

        random.shuffle(self.data)

    def load_model(self, model_data):
        """Return ``self.loader(model_data['path'], model_data)``.

        Args:
            model_data: Sample dict; must contain a ``'path'`` key.
        """
        return self.loader(model_data['path'], model_data)

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(model, label)`` for the sample at ``idx``.

        ``model`` is whatever the loader returns. ``label`` is ``1`` when the
        sample dict has no ``'attack'`` entry and ``0`` otherwise.
        """
        entry = self.data[idx]
        model = self.load_model(entry)
        # Only attacked samples carry an 'attack' key, so its absence marks a
        # clean model.
        label = int(entry.get('attack') is None)

        return model, label
    