import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from typing import List, Callable, Iterable, Tuple, Optional
from attributors.utils import set_seed
from data.utils import process_batch

def get_train_mask(group_ids: List[int], sorted_unique_group_ids: List[int], target: int) -> torch.Tensor:
    """
    Build a training mask that drops the most important groups until `target` points are removed.

    Groups are visited in the order given by `sorted_unique_group_ids`. Every point of a
    visited group is removed, except in the last group visited, where only as many points
    as are still needed are removed (the first ones in index order).

    :param group_ids: Group ID of each datapoint in the trainset.
    :param sorted_unique_group_ids: Unique group IDs, most important first.
    :param target: Number of datapoints to remove. Values <= 0 remove nothing.
    :return: 1-D int tensor of length len(group_ids); 1 = keep, 0 = removed.
        If `target` exceeds the number of points belonging to the listed groups,
        every point of those groups is removed (instead of raising IndexError).
    """
    # as_tensor avoids a redundant copy/warning when a tensor is passed in.
    group_ids = torch.as_tensor(group_ids)
    train_mask = torch.ones(len(group_ids), dtype=torch.int)
    num_removed = 0
    for group_id in sorted_unique_group_ids:
        if num_removed >= target:
            break
        idxs = torch.where(group_ids == group_id)[0]
        remainder = target - num_removed
        # Slicing covers both cases: the whole group if it fits, otherwise a partial group.
        train_mask[idxs[:remainder]] = 0
        num_removed += len(idxs)
    return train_mask

def remove_and_retrain_evaluator_num_points(trainset: Dataset,
                                            testset: Dataset,
                                            sorted_unique_group_ids: List[int],
                                            group_ids: List[int],
                                            train_model: Callable[[Iterable[Dataset]], nn.Module],
                                            num_points_to_remove: List[int],
                                            use_model_cache: bool=False,
                                            device: str='cpu',
                                            verbose: bool=False,
                                            seed: Optional[int]=None,
                                            batch_size: int=256) -> Tuple[List[List[int]], List[torch.Tensor]]:
    """
    For each removal budget, drop that many points (most important groups first),
    retrain the model and record its outputs on the testset.

    Each budget is applied independently to the full trainset (masks are not cumulative
    across budgets); see `get_train_mask` for how partial groups are handled.

    :param trainset: Training dataset. If it has a `name` attribute of 'qnli' or
        'qnli_noisy', the test loader uses transformers' `default_data_collator`.
    :param testset: Test dataset
    :param sorted_unique_group_ids: Unique group IDs sorted by importance
    :param group_ids: group IDs for each datapoint in the trainset
    :param train_model: Called as train_model(trainset, mask, device=..., verbose=...,
        use_model_cache=...) and must return an nn.Module.
    :param num_points_to_remove: Number of points to remove, one retraining per entry
    :param use_model_cache: Forwarded to `train_model`
    :param device: Device used for training (forwarded) and for test batches
    :param verbose: Print intermediate results
    :param seed: If given, `set_seed(seed)` is called before each retraining
    :param batch_size: Batch size of the test DataLoader
    :return: List of training masks (lists of 0/1 ints, one per budget), and a list of
        test outputs (one CPU tensor per budget, batches concatenated along dim 0)
    """
    if verbose:
        print("Number of groups:", len(sorted_unique_group_ids))

    # The test loader does not depend on the removal budget, so build it once.
    if getattr(trainset, 'name', None) in ['qnli', 'qnli_noisy']:
        from transformers import default_data_collator
        test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, collate_fn=default_data_collator)
    else:
        test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)

    train_masks, test_outputs = [], []
    for n_remove in num_points_to_remove:
        train_mask = get_train_mask(group_ids, sorted_unique_group_ids, n_remove).tolist()
        if verbose:
            print("Size of train_mask:", sum(train_mask))
        train_masks.append(train_mask)

        # Load model or retrain model and get test outputs
        if verbose:
            print("Beginning training...")
        if seed is not None:
            set_seed(seed)
        model = train_model(trainset, train_mask, device=device, verbose=verbose, use_model_cache=use_model_cache).eval()

        # Inference only: no_grad avoids building an autograd graph for every batch.
        outs = []
        with torch.no_grad():
            for data in test_loader:
                inputs, _ = process_batch(data, device=device)
                outs.append(model(**inputs).detach().cpu())
        test_outputs.append(torch.cat(outs, dim=0))

        if verbose:
            print(f"Number of points removed: {n_remove}\n")

    return train_masks, test_outputs


def remove_and_retrain_evaluator(trainset: Dataset,
                                 testset: Dataset,
                                 sorted_unique_group_ids: List[int],
                                 group_ids: List[int],
                                 train_model: Callable[[Iterable[Dataset]], nn.Module],
                                 topk: int=-1,
                                 step: int=1,
                                 use_model_cache: bool=False,
                                 device: str='cpu',
                                 verbose: bool=False) -> Tuple[List[List[int]], List[torch.Tensor]]:
    """
    Iteratively remove the most influential group(s) without replacement,
    and retrain the model and evaluate on the testset at each iteration.

    Removals are cumulative: iteration k removes groups [0, min((k+1)*step, topk)).

    :param trainset: Training dataset
    :param testset: Test dataset; its `.data` attribute is fed to the model as one float batch
    :param sorted_unique_group_ids: Unique group IDs sorted by importance
    :param group_ids: group IDs for each datapoint in the trainset
    :param train_model: Called as train_model(trainset, mask, device=..., verbose=...,
        use_model_cache=...) and must return an nn.Module.
    :param topk: Number of groups to remove in total (-1, or anything larger than the
        number of groups, means all groups)
    :param step: Number of groups to remove at each iteration
    :param use_model_cache: Forwarded to `train_model`
    :param device: Device used for training (forwarded) and for the test data
    :param verbose: Print intermediate results
    :return: List of training masks (lists of 0/1 ints, one per iteration), and a list of
        detached test outputs (on `device`, one per iteration)
    """
    if topk == -1 or topk > len(sorted_unique_group_ids):
        topk = len(sorted_unique_group_ids)
    # Convert once instead of on every group removal.
    group_ids_t = torch.as_tensor(group_ids)
    train_mask = torch.ones(len(group_ids), dtype=torch.int)
    train_masks, test_outputs = [], []

    if verbose:
        print("Number of groups:", len(sorted_unique_group_ids))

    for start in range(0, topk, step):
        # Clamp so the last chunk never removes groups beyond `topk`.
        end = min(start + step, topk)
        for unique_group_id in sorted_unique_group_ids[start:end]:
            train_mask[group_ids_t == unique_group_id] = 0
        mask_list = train_mask.tolist()
        train_masks.append(mask_list)

        # Load model or retrain model and get test outputs; pass a copy so the
        # recorded mask cannot be mutated by train_model.
        model = train_model(trainset, list(mask_list), device=device, verbose=verbose, use_model_cache=use_model_cache).eval()
        # Inference only: without no_grad each stored output would keep its graph alive.
        with torch.no_grad():
            test_outputs.append(model(testset.data.float().to(device)).detach())

        if verbose:
            print(f"Groups removed: {start} to {end} of {topk}")

    return train_masks, test_outputs