import os
import json
import random
from collections import defaultdict
import numpy as np 

class IncrementalDatasetBuilder:
    """Split a labelled dataset into a sequence of class-incremental tasks.

    Label files are read line by line; the first whitespace-separated field
    of each line is expected to look like ``<class_name>/<sample_name>``.
    The classes found are partitioned into ``n_tasks`` groups and, for every
    group, per-class train / test (and optionally validation) sample lists
    are produced.
    """

    def __init__(self,
                 dataset_path,
                 train_label_file,
                 test_label_file,
                 n_tasks=20,
                 first_task_class_num=5,
                 val_ratio=0.1,
                 random_seed=2023,
                 random_order=True):
        """Store the configuration and initialise the random generators.

        Args:
            dataset_path: Stored on the instance; not read by any method of
                this class.
            train_label_file: Path of the training label file.
            test_label_file: Path of the test label file.
            n_tasks: Total number of tasks, including the first one.
            first_task_class_num: Number of classes assigned to the first
                task.
            val_ratio: Fraction of each class's training samples moved to a
                validation split. Values ``<= 0`` disable the validation
                split entirely.
            random_seed: Seed used for class-order and validation shuffles.
            random_order: Whether to shuffle the class order before
                splitting it into tasks.
        """
        self.dataset_path = dataset_path
        self.train_label_file = train_label_file
        self.test_label_file = test_label_file
        self.n_tasks = n_tasks
        self.first_task_class_num = first_task_class_num
        self.val_ratio = val_ratio
        self.random_seed = random_seed
        self.random_order = random_order

        # Kept for backward compatibility: other code may rely on the global
        # generators having been seeded by constructing a builder.
        random.seed(self.random_seed)
        np.random.seed(self.random_seed)

        # Private generator so this builder's shuffles do not depend on (or
        # disturb) the process-wide `random` state. Seeded identically to
        # the global generator above, so it yields the same sequence.
        self._rng = random.Random(self.random_seed)

    def load_label_file(self, label_file):
        """Parse a label file into a ``{class_name: [sample_name, ...]}`` dict.

        Only the first whitespace-separated field of each line is used.
        Blank lines and lines whose first field contains no ``/`` are
        skipped. The field is split at its first ``/``; everything after it
        (including further slashes) is the sample name.

        Args:
            label_file: Path of the label file to read (decoded as UTF-8).

        Returns:
            A plain ``dict`` mapping each class name to its sample names in
            file order.
        """
        label_dict = defaultdict(list)
        with open(label_file, 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                class_name, separator, sample_name = fields[0].partition('/')
                if not separator:
                    continue
                label_dict[class_name].append(sample_name)
        return dict(label_dict)

    def split_into_tasks(self, all_classes):
        """Partition class names into per-task lists.

        The first task receives the first ``first_task_class_num`` classes.
        When ``n_tasks > 1`` the remaining classes are spread as evenly as
        possible over the other ``n_tasks - 1`` tasks, earlier tasks taking
        the leftovers. When ``n_tasks <= 1`` only the first task is returned
        and the remaining classes are not assigned to any task.

        Args:
            all_classes: Sequence of class names. It is not modified.

        Returns:
            A list of lists of class names, one inner list per task.
        """
        # Work on a copy so shuffling never reorders the caller's list.
        ordered = list(all_classes)
        if self.random_order:
            self._rng.shuffle(ordered)

        first_task_classes = ordered[:self.first_task_class_num]
        remaining_classes = ordered[self.first_task_class_num:]

        tasks = [first_task_classes]
        if self.n_tasks > 1:
            num_remaining_tasks = self.n_tasks - 1
            base, extra = divmod(len(remaining_classes), num_remaining_tasks)
            start_idx = 0
            for i in range(num_remaining_tasks):
                # The first `extra` tasks absorb one leftover class each.
                end_idx = start_idx + base + (1 if i < extra else 0)
                tasks.append(remaining_classes[start_idx:end_idx])
                start_idx = end_idx
        return tasks

    def create_data_for_tasks(self, train_dict, test_dict, tasks):
        """Build per-task train / test (and validation) sample dictionaries.

        Args:
            train_dict: ``{class_name: [sample_name, ...]}`` for training.
            test_dict: ``{class_name: [sample_name, ...]}`` for testing.
            tasks: List of class-name lists, one per task.

        Returns:
            A dict with keys ``'train'`` and ``'test'`` (plus ``'val'`` when
            ``val_ratio > 0``). Each value is a list with one
            ``{class_name: [sample_name, ...]}`` dict per task. Classes
            missing from an input dict map to an empty list. The returned
            lists are copies; the input dicts are not modified.
        """
        use_val = self.val_ratio > 0

        data = {
            'train': [],
            'test': []
        }
        if use_val:
            data['val'] = []

        for task_classes in tasks:
            task_train = {}
            task_test = {}
            task_val = {}
            for cls in task_classes:
                train_samples = list(train_dict.get(cls, []))
                if use_val:
                    self._rng.shuffle(train_samples)
                    # Truncating: small classes may get no validation items.
                    split_idx = int(len(train_samples) * self.val_ratio)
                    task_val[cls] = train_samples[:split_idx]
                    task_train[cls] = train_samples[split_idx:]
                else:
                    task_train[cls] = train_samples

                task_test[cls] = list(test_dict.get(cls, []))

            data['train'].append(task_train)
            data['test'].append(task_test)
            if use_val:
                data['val'].append(task_val)
        return data

    def build(self, save_json_path=None):
        """Load both label files, split the classes and assemble the data.

        Args:
            save_json_path: If not ``None``, the result is written to this
                path as indented JSON. Otherwise ``done!`` is printed.

        Returns:
            The dict produced by ``create_data_for_tasks``.
        """
        train_dict = self.load_label_file(self.train_label_file)
        test_dict = self.load_label_file(self.test_label_file)

        # Sorted union gives a deterministic base order before any shuffle.
        all_classes = sorted(set(train_dict) | set(test_dict))

        tasks = self.split_into_tasks(all_classes)
        data = self.create_data_for_tasks(train_dict, test_dict, tasks)

        if save_json_path is not None:
            with open(save_json_path, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=4)
        else:
            print("done!")
        return data