import copy
from collections import Counter

import torch
import numpy as np
from tqdm import tqdm
from typing import *
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import get_scheduler
from transformers import AutoModelForCausalLM, AutoTokenizer

import utils
from data import CustomDataset



def create_noniid_data(dataset, config, train: bool = True) -> List[CustomDataset]:
    """
    Synthesize a non-i.i.d. federated split where client ``c`` mostly holds class ``c``.

    Only samples whose label lies in ``[0, num_clients)`` are used, and class
    ``c`` is the "home" class of client ``c``.

    From each class, the first ``noniid_size * num_clients`` samples are set
    aside as a shared pool. Every client receives a disjoint ``noniid_size``
    slice of every class's pool, including its own class's pool. The remaining
    slots, up to ``client_size``, are filled with the rest of the client's
    home-class samples.

    ``noniid_size`` is derived from ``config.federated.noniid_ratio`` as
    ``int(client_size / (num_clients - 1 + 1 / noniid_ratio))``. A ratio of 0
    gives a purely label-partitioned split.

    NOTE(review): ``noniid_ratio`` is not literally the fraction of
    foreign-class samples per client. When no truncation occurs and the home
    class has enough samples, that fraction is about
    ``(num_clients - 1) / (num_clients - 1 + 1 / noniid_ratio)``.

    Args:
        dataset: Indexable and iterable collection. ``sample[1]`` of each item
            is used as the class label (assumed integer-valued; it is cast to
            int32).
        config: Object exposing ``config.federated.num_clients``,
            ``noniid_ratio``, ``train_client_size`` and ``eval_client_size``.
        train: If True use ``train_client_size``, otherwise ``eval_client_size``.

    Returns:
        One ``CustomDataset`` per client, in client order. As a side effect,
        ``noniid_size`` and per-client label counts are printed.
    """
    fed = config.federated
    num_clients = fed.num_clients
    client_size = fed.train_client_size if train else fed.eval_client_size

    if fed.noniid_ratio == 0:
        noniid_size = 0
    else:
        noniid_size = int(client_size / (num_clients - 1 + 1 / fed.noniid_ratio))
    print("non iid size:", noniid_size)

    labels = np.array([sample[1] for sample in dataset], dtype=np.int32)

    # Group sample indices by class. Only classes 0..num_clients-1 are used.
    # The lower bound matters: a negative label (e.g. -1) would otherwise index
    # the list from the end and silently land in the last client's class.
    class_ids = [[] for _ in range(num_clients)]
    for idx, lbl in enumerate(labels):
        if 0 <= lbl < num_clients:
            class_ids[lbl].append(idx)

    # Set aside the head of each class as a pool that is shared out among
    # clients. The tail stays with the class's home client.
    pool_size = noniid_size * num_clients
    shared_pools = [np.array(ids[:pool_size], dtype=np.int32) for ids in class_ids]
    home_ids = [np.array(ids[pool_size:], dtype=np.int32) for ids in class_ids]

    client_ids = []
    for ci in range(num_clients):
        start = ci * noniid_size
        # Pool slices go first, ordered from class num_clients-1 down to 0, then
        # the home samples. This order decides which samples survive the
        # truncation to client_size (pool slices are cut only when
        # noniid_ratio > 1).
        pieces = [
            shared_pools[cj][start:start + noniid_size]
            for cj in reversed(range(num_clients))
        ]
        pieces.append(home_ids[ci])
        client_ids.append(np.concatenate(pieces, axis=0)[:client_size])

    # Map per-client indices back to data instances.
    res = []
    for c, ids in enumerate(client_ids):
        # int() so datasets that reject numpy integer indices still work.
        res.append(CustomDataset([dataset[int(i)] for i in ids]))
        # .tolist() yields Python ints, so keys print as plain numbers
        # (NumPy >= 2 would otherwise print them as np.int32(k)).
        label_counter = dict(sorted(Counter(labels[ids].tolist()).items()))
        print(f"client {c}: {label_counter}")

    return res