import copy
from collections import Counter

import torch
import numpy as np
from tqdm import tqdm
from typing import *
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import get_scheduler
from transformers import AutoModelForCausalLM, AutoTokenizer

import utils
from data import CustomDataset



def create_noniid_data(dataset, config, train: bool = True) -> List[CustomDataset]:
    """
    Split `dataset` into one non-i.i.d. subset per client.

    Class `c` acts as the "home" class of client `c`; samples whose label is
    not below `config.num_clients` are ignored. For every class, the first
    `noniid_size * config.num_clients` sample indices form a shared pool and
    the remaining indices stay with the home client. Client `ci` then receives
    the slice `[ci * noniid_size, (ci + 1) * noniid_size)` of every class's
    pool (its own class included), placed in front of its home-class indices,
    and the result is truncated to the client size.

    Args:
        dataset: indexable collection of samples whose element `[1]` is the
            integer class label.
        config: object providing `train_client_size`, `eval_client_size`,
            `noniid_ratio` and `num_clients`.
        train: pick `config.train_client_size` when True, otherwise
            `config.eval_client_size`.

    Returns:
        A list with one `CustomDataset` per client, ordered by client index.
    """
    client_size = config.train_client_size if train else config.eval_client_size

    # Number of samples each client draws from every class's shared pool.
    if config.noniid_ratio == 0:
        noniid_size = 0
    else:
        noniid_size = int(client_size / (config.num_clients - 1 + 1 / config.noniid_ratio))
    print("non iid size:", noniid_size)

    num_clients = config.num_clients
    labels = np.array([sample[1] for sample in dataset], dtype=np.int32)

    # Bucket sample indices by label, dropping classes without a matching client.
    per_class = [[] for _ in range(num_clients)]
    for sample_idx, label in enumerate(labels):
        if label < num_clients:
            per_class[label].append(sample_idx)

    # The head of every bucket is shared between clients, the tail stays home.
    pool_span = noniid_size * num_clients
    shared_pool = [np.array(ids[:pool_span], dtype=np.int32) for ids in per_class]
    home_pool = [np.array(ids[pool_span:], dtype=np.int32) for ids in per_class]

    # Prepend each client's slice of every shared pool to its home indices.
    client_ids = []
    for ci in range(num_clients):
        lo = ci * noniid_size
        hi = lo + noniid_size
        merged = home_pool[ci]
        for cj in range(num_clients):
            merged = np.concatenate((shared_pool[cj][lo:hi], merged), axis=0)
        client_ids.append(merged[:client_size])

    # Materialise the per-client datasets and report their label histograms.
    res = []
    for c, ids in enumerate(client_ids):
        res.append(CustomDataset([dataset[i] for i in ids]))
        label_counter = dict(sorted(Counter(labels[ids]).items()))
        print(f"client {c}: {label_counter}")

    return res


def train_fed(server_model: Union[nn.Module, AutoModelForCausalLM],
              client_loaders: Dict[int, CustomDataset],
              config: Dict,
              eval_fn: Callable,
              is_llm: bool = False,
              tokenizer = None,
              device=None,
              ) -> Union[nn.Module, AutoModelForCausalLM]:
    """
    Run federated training rounds and return the updated server model.

    In every round each client trains a deep copy of `server_model` on its
    own "train" loader, the copies are moved to the server model's device,
    and `aggr_params` merges them into `server_model` with weights
    proportional to the number of training samples per client.

    Args:
        server_model: global model; updated in place by the aggregation.
        client_loaders: mapping from client index to an object indexable by
            "train" that yields that client's training DataLoader.
        config: object exposing `federated.num_rounds`,
            `federated.log_frequency` and `local` (passed to the local
            trainer).
        eval_fn: callable invoked with the server model after every round;
            its result is iterated with `.items()` when logging. Pass None
            to skip evaluation.
        is_llm: use `train_model_llm_pytorch` instead of `train_model`.
        tokenizer: forwarded to the LLM trainer only.
        device: device handed to the local trainers.

    Returns:
        `server_model`.
    """
    client_indices = client_loaders.keys()
    client_models = {}

    num_steps = config.federated.num_rounds * len(client_indices)
    progress_bar = tqdm(range(num_steps), desc="train_fl")

    for rnd in range(config.federated.num_rounds):
        # local training: every client starts from the current server weights
        for c_idx in client_indices:
            local_model = copy.deepcopy(server_model)
            train_loader = client_loaders[c_idx]["train"]
            if is_llm:
                local_model = train_model_llm_pytorch(
                    local_model,
                    train_loader,
                    tokenizer=tokenizer,
                    config=config.local,
                    device=device,
                )
            else:
                local_model = train_model(
                    local_model,
                    train_loader,
                    verbose=False,
                    config=config.local,
                    device=device,
                )
            client_models[c_idx] = local_model
            progress_bar.update(1)

        # weighted aggregation: weight = share of the total training samples
        server_device = next(server_model.parameters()).device
        sample_counts = []
        for c_idx in client_indices:
            sample_counts.append(len(client_loaders[c_idx]["train"].dataset))
            client_models[c_idx].to(server_device)
        total_samples = sum(sample_counts)
        client_weights = [count / total_samples for count in sample_counts]

        aggr_params(server_model, client_models.values(), client_weights)

        # drop the client copies so their GPU memory can be reclaimed
        for c_idx in client_indices:
            del client_models[c_idx]
            torch.cuda.empty_cache()

        # evaluation runs every round; printing only on logging rounds
        if eval_fn is None:
            continue
        eval_results = eval_fn(server_model)
        finished = rnd + 1
        if finished % config.federated.log_frequency == 0 or finished == config.federated.num_rounds:
            print("round:", rnd)
            for k, v in eval_results.items():
                print("{}: {:.4f}".format(k, v))

    return server_model


def train_model(model: nn.Module,
                train_loader: DataLoader,
                config: Dict,
                eval_loader: DataLoader = None,
                verbose: bool = False,
                device=None,
                ) -> nn.Module:
    """
    Train a local model on client data.

    The optimizer class is looked up by name in `torch.optim`, the loss class
    in `torch.nn`, and the optional learning-rate scheduler class in
    `torch.optim.lr_scheduler`. Constructor arguments for the optimizer and
    the scheduler come from `utils.get_kwargs(cls, config)`.

    NOTE(review): this function calls `model.to(device)`. An earlier comment
    here warned that doing so causes problems with shared memory in
    multiprocessing; confirm whether that restriction still applies to callers.

    Args:
        model: model to train; moved to `device` and modified in place.
        train_loader: yields batches where `batch[0]` holds the inputs and
            `batch[1]` the targets.
        config: object exposing `optimizer`, `loss` and `num_epochs`, and
            optionally `lr_scheduler` and `grad_clip`.
        eval_loader: per-epoch evaluation is not implemented; must be None.
        verbose: print the running average loss every 500 steps.
        device: device the model and every batch are moved to.

    Returns:
        The trained `model`.

    Raises:
        NotImplementedError: at the end of the first epoch when
            `eval_loader` is not None.
    """
    model.to(device)
    model.train()

    # create optimizer
    optimizer_cls = getattr(torch.optim, config.optimizer)
    optimizer_kwargs = utils.get_kwargs(optimizer_cls, config)
    optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)

    # create loss
    compute_loss = getattr(torch.nn, config.loss)()

    # create learning rate scheduler if necessary
    lr_scheduler = None
    scheduler_name = getattr(config, 'lr_scheduler', None)
    if scheduler_name is not None:
        lr_scheduler_cls = getattr(torch.optim.lr_scheduler, scheduler_name)
        # The kwargs must be resolved against the scheduler *class*; the
        # scheduler instance does not exist yet at this point.
        lr_scheduler_kwargs = utils.get_kwargs(lr_scheduler_cls, config)
        lr_scheduler = lr_scheduler_cls(optimizer, **lr_scheduler_kwargs)

    # the clipping threshold does not change during training, look it up once
    grad_clip = getattr(config, 'grad_clip', None)

    # start training
    count_step = 0
    average_loss = 0.0
    log_freq = 500
    for epoch in range(config.num_epochs):
        for batch in train_loader:
            count_step += 1
            optimizer.zero_grad()
            inputs = batch[0].to(device)
            targets = batch[1].to(device)
            outputs = model(inputs)
            loss = compute_loss(outputs, targets)
            loss.backward()
            average_loss += loss.item()
            if grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            if count_step % log_freq == 0:
                if verbose:
                    print(f'step {count_step}, average_loss: {average_loss / log_freq}')
                # the accumulator restarts every window, printed or not
                average_loss = 0.0
        # the scheduler advances once per epoch, not once per batch
        if lr_scheduler is not None:
            lr_scheduler.step()
        if eval_loader is not None:
            raise NotImplementedError("per-epoch evaluation is not implemented in train_model")

    return model

def train_model_llm_pytorch(model: AutoModelForCausalLM,
                            train_loader: DataLoader,
                            config: Dict,
                            tokenizer: AutoTokenizer,
                            eval_loader: DataLoader = None,
                            verbose: bool = False,
                            device=None,
                            ) -> AutoModelForCausalLM:
    """
    Train a causal language model on one client's data with a plain PyTorch loop.

    The optimizer class is looked up by name in `torch.optim`, and a "linear"
    schedule without warmup is created over
    `config.num_epochs * len(train_loader)` steps. Each batch is a mapping of
    tensors that is moved to `device` and unpacked into the model call; the
    loss is read from the model output's `.loss` attribute.

    Args:
        model: model to train; moved to `device` and modified in place.
        train_loader: yields dict-like batches of tensors.
        config: object exposing `optimizer` and `num_epochs`, and optionally
            `grad_clip`.
        tokenizer: not used by the training loop itself.
        eval_loader: evaluation is not implemented; must be None.
        verbose: show a progress bar and print the loss every 10 steps.
        device: device the model and the batches are moved to.

    Returns:
        The trained `model`.

    Raises:
        NotImplementedError: after training when `eval_loader` is not None.
    """
    # prepare the model
    model.to(device)
    model.train()

    # build the optimizer from its class name
    optim_factory = getattr(torch.optim, config.optimizer)
    optimizer = optim_factory(model.parameters(), **utils.get_kwargs(optim_factory, config))

    # linear decay over the whole run, no warmup
    num_training_steps = config.num_epochs * len(train_loader)
    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
    )

    # the progress bar doubles as the step counter used for logging
    progress_bar = tqdm(range(num_training_steps)) if verbose else None
    log_freq = 10

    for epoch in range(config.num_epochs):
        for raw_batch in train_loader:
            batch = {key: tensor.to(device) for key, tensor in raw_batch.items()}

            loss = model(**batch).loss
            loss.backward()

            # optional gradient clipping
            max_norm = getattr(config, 'grad_clip', None)
            if max_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            if not verbose:
                continue
            if progress_bar.n % log_freq == 0:
                print(f"Epoch: {epoch+1}/{config.num_epochs}, Step: {progress_bar.n+1}/{num_training_steps}, Loss: {loss.item():.4f}")
            progress_bar.update(1)

    # evaluation after training is not supported yet
    if eval_loader is not None:
        raise NotImplementedError
    return model

def convert_dataloader_to_hf_dataset_with_text(dataloader, tokenizer, *, include_text: bool = False):
    """
    Convert a PyTorch DataLoader to a Hugging Face Dataset.

    Every batch must be a mapping with 'input_ids', 'labels' and
    'attention_mask' entries that support `.tolist()`. The rows of all
    batches are concatenated in iteration order.

    Args:
    dataloader (torch.utils.data.DataLoader): The DataLoader to convert.
    tokenizer (transformers.PreTrainedTokenizer): The tokenizer used to decode
        token ids back to text. Only used when `include_text` is True.
    include_text (bool): When True, additionally decode 'input_ids' and
        'labels' (skipping -100 entries and special tokens) and store the
        results in 'input_texts' and 'label_texts' columns. Defaults to
        False, which yields only the three tensor-derived columns.

    Returns:
    datasets.Dataset: Dataset with 'input_ids', 'labels' and 'attention_mask'
        columns, plus the two text columns when requested.
    """
    # Imported locally: the module does not import `datasets` at the top
    # level, and only this function needs it.
    from datasets import Dataset

    def decode_row(ids):
        # -100 marks positions ignored by the loss; it is not a token id.
        if ids is None:
            return None
        return tokenizer.decode([tok for tok in ids if tok != -100], skip_special_tokens=True)

    columns = {'input_ids': [], 'labels': [], 'attention_mask': []}
    input_texts = []
    label_texts = []

    for batch in dataloader:
        batch_rows = {key: batch[key].tolist() for key in columns}
        for key, rows in batch_rows.items():
            columns[key].extend(rows)

        # Decoding is comparatively expensive, so only do it on request.
        if include_text:
            input_texts.extend(decode_row(ids) for ids in batch_rows['input_ids'])
            label_texts.extend(decode_row(ids) for ids in batch_rows['labels'])

    if include_text:
        columns['input_texts'] = input_texts
        columns['label_texts'] = label_texts

    return Dataset.from_dict(columns)

def formatting_func(
    example: Dict[str, List[List[str]]],
) -> List[str]:
    """
    Return the training text of every entry in a batch.

    Args:
        example (Dict[str, List[List[str]]]): A batch of the dataset; only its
            "input_texts" column is read.

    Returns:
        List[str]: One item per entry of `example["input_texts"]`, in the
            same order.
    """
    # One output item per batch entry; a fresh list is returned so callers
    # can modify the result without touching the batch.
    return list(example["input_texts"])

def aggr_params(server_model, client_models, client_weights):
    """
    Overwrite the server parameters with a weighted sum of client parameters.

    Only tensors yielded by `parameters()` are touched, and server parameters
    with `requires_grad == False` are left unchanged. Server and client
    parameters are paired by iteration order. If `client_models` is empty the
    server model is returned unmodified.

    Args:
        server_model: model whose trainable parameters are updated in place.
        client_models: iterable of models to aggregate.
        client_weights: iterable of scalar weights, one per client model and
            in the same order.

    Returns:
        `server_model`.
    """
    # All updates are in place so the server keeps the same parameter
    # tensors; no_grad keeps autograd from recording them.
    with torch.no_grad():
        is_first_client = True
        for c_model, c_weight in zip(client_models, client_weights):
            for param, c_param in zip(server_model.parameters(), c_model.parameters()):
                if not param.requires_grad:
                    continue
                if is_first_client:
                    # zero_() also clears inf/NaN entries, which subtracting
                    # the tensor from itself would turn into NaN
                    param.zero_()
                param.add_(c_param * c_weight)
            is_first_client = False
    return server_model
