import random
import pandas
import numpy as np
from collections import defaultdict

# Label column used to group rows for the "dirichlet-noniid" strategy, per dataset.
_NONIID_LABEL_COLUMNS = {
    "m-a-p/CodeFeedback-Filtered-Instruction": "lang",
    "databricks/dolly-15k": "category",
    "TIGER-Lab/MathInstruct": "source",
    "EleutherAI/hendrycks_math": "type",
    "meta-math/MetaMathQA": "type",
}

# Label column and per-client, per-label cap used by the "iid" strategy.
_IID_LABEL_COLUMN = "type"
_IID_SAMPLES_PER_TYPE_PER_CLIENT = 1500


def _group_indices_by_label(dataset, column):
    """Map each distinct value of ``dataset[column]`` to the row indices holding it.

    Labels are returned in sorted order (compared via ``str``) instead of
    set/hash order: string hashing is randomized per process, so iterating a
    ``set`` of labels would consume the seeded RNG in a different label order
    on every run and make the split irreproducible.
    """
    groups = defaultdict(list)
    for idx, label in enumerate(dataset[column]):
        groups[label].append(idx)
    return {label: groups[label] for label in sorted(groups, key=str)}


def _split_equal_per_label(groups, num_clients, cap):
    """Give every client the same number of rows (at most ``cap``) of each label.

    Rows of a label beyond ``num_clients * per_client`` are not assigned to any
    client. Labels with fewer rows than ``num_clients`` are skipped with a
    warning. Draws from NumPy's global RNG.
    """
    client_indices = [[] for _ in range(num_clients)]
    for label, indices in groups.items():
        indices = list(indices)
        np.random.shuffle(indices)
        num_samples = len(indices)
        num_to_take = min(cap, num_samples // num_clients)
        if num_to_take == 0:
            print(f"Warning: Not enough data of type '{label}' ({num_samples} total) to give at least 1 sample per client. Skipping type for distribution.")
            continue
        for i in range(num_clients):
            start = i * num_to_take
            client_indices[i].extend(indices[start:start + num_to_take])
    return client_indices


def _split_dirichlet(groups, num_clients, alpha):
    """Split each label's rows across clients with Dirichlet(alpha) proportions.

    Every row is assigned to exactly one client. Draws from NumPy's global RNG.
    """
    client_indices = [[] for _ in range(num_clients)]
    for indices in groups.values():
        proportions = np.random.dirichlet([alpha] * num_clients)
        # Cumulative proportions become cut points; the final cut (== len) is implicit.
        cut_points = (np.cumsum(proportions)[:-1] * len(indices)).astype(int)
        for i, part in enumerate(np.split(np.asarray(indices), cut_points)):
            client_indices[i].extend(part.tolist())
    return client_indices


def split_dataset(fed_args, script_args, dataset):
    """Shuffle ``dataset`` and partition it into one local dataset per client.

    Strategies (``fed_args.split_strategy``):

    * ``"iid"`` -- groups rows by the ``type`` column and gives every client
      ``min(1500, rows_of_type // num_clients)`` rows of each type; the
      remaining rows are not used.
    * ``"dirichlet-noniid"`` -- groups rows by a dataset-specific label column
      (see ``_NONIID_LABEL_COLUMNS``) and splits each group across clients with
      proportions drawn from ``Dirichlet(fed_args.noniid_degree)``.

    The dataset is shuffled with ``script_args.seed`` and NumPy's global RNG is
    re-seeded with the same value, so the split is reproducible across runs.

    Args:
        fed_args: needs ``split_strategy``, ``num_clients`` and, for the
            Dirichlet strategy, ``noniid_degree``.
        script_args: needs ``seed`` and, for the Dirichlet strategy,
            ``dataset_name``.
        dataset: object supporting ``shuffle(seed=...)``, column access
            ``dataset[column]`` and ``select(indices)`` (presumably a Hugging
            Face ``Dataset``).

    Returns:
        A list with one ``dataset.select(...)`` result per client.

    Raises:
        ValueError: if the strategy is unknown, the Dirichlet strategy is
            requested for a dataset with no known label column, or
            ``num_clients < 1``.
    """
    strategy = fed_args.split_strategy
    if strategy == "iid":
        column = _IID_LABEL_COLUMN
    elif strategy == "dirichlet-noniid":
        column = _NONIID_LABEL_COLUMNS.get(script_args.dataset_name)
        if column is None:
            raise ValueError(
                f"dirichlet-noniid split is not configured for dataset "
                f"{script_args.dataset_name!r}; known datasets: "
                f"{sorted(_NONIID_LABEL_COLUMNS)}"
            )
    else:
        raise ValueError(f"Unknown split_strategy {strategy!r}")

    num_clients = fed_args.num_clients
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")

    dataset = dataset.shuffle(seed=script_args.seed)
    groups = _group_indices_by_label(dataset, column)

    np.random.seed(script_args.seed)
    if strategy == "iid":
        client_indices = _split_equal_per_label(
            groups, num_clients, _IID_SAMPLES_PER_TYPE_PER_CLIENT
        )
    else:
        client_indices = _split_dirichlet(groups, num_clients, fed_args.noniid_degree)

    local_datasets = []
    for i, indices in enumerate(client_indices):
        if not indices:
            print(f"Warning: Client {i} has no data.")
        local_datasets.append(dataset.select(indices))
    return local_datasets

def get_dataset_this_round(dataset, round, fed_args, script_args):
    """Draw this round's random training subset from ``dataset``.

    The subset size is ``batch_size * gradient_accumulation_steps * max_steps``
    from ``script_args``, capped at ``len(dataset)``. Indices are drawn without
    replacement after seeding Python's global ``random`` module with ``round``,
    so a given round always yields the same subset for a dataset of a given
    length. ``fed_args`` is accepted for interface compatibility but unused.
    """
    samples_per_round = (
        script_args.batch_size
        * script_args.gradient_accumulation_steps
        * script_args.max_steps
    )
    dataset_size = len(dataset)
    sample_count = min(samples_per_round, dataset_size)

    # Seeds the *global* RNG on purpose; later `random` calls observe this state.
    random.seed(round)
    chosen_positions = random.sample(range(dataset_size), sample_count)
    return dataset.select(chosen_positions)