from copy import deepcopy
from math import inf
from typing import Iterator, List, Sequence, Tuple

import torch
from torch import nn
from torch.nn.functional import one_hot
from torch.utils.data import ConcatDataset, DataLoader, Dataset, TensorDataset

from data import full_stage, full_stage_standard, load_stage_standard


# Set one_hot_encode=True to one-hot encode the integer labels before computing the loss (the default is False).
def dataset_error(train_data: Dataset, network: nn.Module, loss_func: nn.Module, adversarial: bool = False,
                  max_samples: int = inf, one_hot_encode: bool = False, batch_size: int = 100,
                  epsilon: float = 8 / 255, attack_iter: int = 10, num_classes: int = -1) -> float:
    """
    Compute the per-sample average loss of ``network`` over the first samples of a dataset.

    train_data: The dataset to evaluate on; it is iterated in order (no shuffling)
    network: The model
    loss_func: The loss function; each batch loss is multiplied by the batch size, so it is
        assumed to average over the batch (e.g. ``reduction='mean'``)
    adversarial: Whether to evaluate on adversarial images generated from the dataset images
    max_samples: Evaluate at most this many samples (the first ones in dataset order)
    one_hot_encode: Whether to one-hot encode the integer labels (as float) before the loss
    batch_size: Batch size used for evaluation
    epsilon: Attack radius forwarded to ``gen_adversarial``
    attack_iter: Number of attack steps forwarded to ``gen_adversarial``
    num_classes: Width of the one-hot encoding; the default -1 lets ``one_hot`` infer it from
        the largest label in each batch, which may not match the network's output width
    returns: The loss averaged over the samples actually evaluated
    """
    total_error = 0.0
    num_samples = 0
    data_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False, num_workers=0)
    for images, labels in data_loader:
        # Trim so that exactly min(len(train_data), max_samples) samples are evaluated.
        budget = max_samples - num_samples
        if budget < 1:
            break
        if labels.size(0) > budget:
            keep = int(budget)
            images, labels = images[:keep], labels[:keep]
        curr_batch_size = labels.size(0)
        if one_hot_encode:
            labels = one_hot(labels, num_classes=num_classes).float()
        if adversarial:
            # The attack needs gradients, so it must run outside of no_grad.
            images = gen_adversarial(network, images, labels, loss_func, epsilon=epsilon, attack_iter=attack_iter,
                                     clamp_01=True)
        with torch.no_grad():
            outputs = network(images)
            batch_loss = loss_func(outputs, labels).item()
        total_error += curr_batch_size * batch_loss
        num_samples += curr_batch_size
    # Normalise by the samples actually evaluated, not by the requested limit.
    return total_error / num_samples


def dataset_accuracy(train_data: Dataset, network: nn.Module, loss_func: nn.Module, adversarial: bool = False,
                     max_samples: int = inf, one_hot_encode: bool = False, batch_size: int = 100,
                     epsilon: float = 8 / 255, attack_iter: int = 10, num_classes: int = -1) -> float:
    """
    Compute the top-1 accuracy of ``network`` over the first samples of a dataset.

    train_data: The dataset to evaluate on; it is iterated in order (no shuffling)
    network: The model
    loss_func: The loss function; only used to generate adversarial images
    adversarial: Whether to evaluate on adversarial images generated from the dataset images
    max_samples: Evaluate at most this many samples (the first ones in dataset order)
    one_hot_encode: Whether the attack's loss targets are the one-hot encoded (float) labels
    batch_size: Batch size used for evaluation
    epsilon: Attack radius forwarded to ``gen_adversarial``
    attack_iter: Number of attack steps forwarded to ``gen_adversarial``
    num_classes: Width of the one-hot encoding; the default -1 lets ``one_hot`` infer it from
        the largest label in each batch
    returns: Fraction of evaluated samples whose argmax prediction equals the integer label
    """
    correct = 0
    num_samples = 0
    data_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False, num_workers=0)
    for images, labels in data_loader:
        # Trim so that exactly min(len(train_data), max_samples) samples are evaluated.
        budget = max_samples - num_samples
        if budget < 1:
            break
        if labels.size(0) > budget:
            keep = int(budget)
            images, labels = images[:keep], labels[:keep]
        if adversarial:
            # Float targets, consistent with dataset_error; the attack needs gradients (no no_grad here).
            attack_targets = one_hot(labels, num_classes=num_classes).float() if one_hot_encode else labels
            images = gen_adversarial(network, images, attack_targets, loss_func, epsilon=epsilon,
                                     attack_iter=attack_iter, clamp_01=True)
        with torch.no_grad():
            predictions = torch.argmax(network(images), dim=1)
        correct += (labels == predictions).sum().item()
        num_samples += labels.size(0)
    return correct / num_samples


def stage_error(network: nn.Module, train_data: Dataset, sorted_indices: Sequence[int], i: int, loss_func: nn.Module,
                one_hot_encode: bool = True, batch_size: int = 100,
                adversarial=False, epsilon: float = 8 / 255, attack_iter: int = 10) -> float:
    """
    Average loss of ``network`` over the samples of one curriculum stage.

    network: The network used
    train_data: The training dataset used
    sorted_indices: The dataset indices sorted according to their label
    i: Index of the end of the stage prefix; the summed loss is divided by ``i + 1``
        (assumes ``full_stage`` yields exactly the samples at ``sorted_indices[:i + 1]`` — TODO confirm)
    loss_func: The function used as loss; batch losses are weighted by batch size
    one_hot_encode: Whether to one-hot encode the labels yielded by ``full_stage``
    batch_size: Batch size forwarded to ``full_stage``
    adversarial: Whether to evaluate on adversarial images instead of the originals
    epsilon: Attack radius forwarded to ``gen_adversarial``
    attack_iter: Number of attack steps forwarded to ``gen_adversarial``
    """
    total_error = 0.0
    for batch_images, batch_labels in full_stage(train_data, sorted_indices, i, batch_size=batch_size):
        n_in_batch = batch_labels.size(0)
        targets = one_hot(batch_labels) if one_hot_encode else batch_labels
        if adversarial:
            batch_images = gen_adversarial(network, batch_images, targets, loss_func, epsilon=epsilon,
                                           attack_iter=attack_iter, clamp_01=True)
        predictions = network(batch_images)
        batch_loss = loss_func(predictions.float(), targets.float()).item()
        total_error += n_in_batch * batch_loss
    return total_error / (i + 1)


def stage_accuracy(network: nn.Module, data: Dataset, sorted_indices: Sequence[int], i: int, loss_func: nn.Module,
                   one_hot_encode: bool = True, batch_size: int = 100,
                   adversarial=False, epsilon: float = 8 / 255, attack_iter: int = 10) -> float:
    """
    Top-1 accuracy of ``network`` over the samples of one curriculum stage.

    network: The network used
    data: The dataset upon which the accuracy of the model is calculated
    sorted_indices: The dataset indices sorted according to their label
    i: Up to which part of the prefix the accuracy is calculated (forwarded to ``full_stage``)
    loss_func: Loss used only to generate adversarial images
    one_hot_encode: If True, the labels yielded by ``full_stage`` are converted with argmax along dim 1
        before comparison
    batch_size: Batch size forwarded to ``full_stage``
    adversarial: Whether to evaluate on adversarial images instead of the originals
    epsilon: Attack radius forwarded to ``gen_adversarial``
    attack_iter: Number of attack steps forwarded to ``gen_adversarial``

    NOTE(review): with one_hot_encode=True this expects ``full_stage`` to yield labels that are already
    one-hot (argmax over dim 1), whereas stage_error one-hot encodes the labels from the same iterator
    itself — confirm which label format ``full_stage`` produces.
    """
    hits = 0
    seen = 0
    for batch_images, batch_labels in full_stage(data, sorted_indices, i, batch_size=batch_size):
        seen += batch_labels.size(0)
        if adversarial:
            batch_images = gen_adversarial(network, batch_images, batch_labels, loss_func, epsilon=epsilon,
                                           attack_iter=attack_iter, clamp_01=True)
        predicted = torch.argmax(network(batch_images), dim=1)
        expected = torch.argmax(batch_labels, dim=1) if one_hot_encode else batch_labels
        hits += (expected == predicted).sum().item()
    return float(hits) / seen


def find_max_label(data: Dataset, sorted_indices: Sequence[int], cid: int, current: int) -> int:
    """
    Find where the run of samples with label <= ``current`` ends, scanning forward from ``cid``.

    data: The dataset which we parse; must expose ``targets`` indexable by dataset index
    sorted_indices: The indices of the dataset sorted according to label
    cid: The position in ``sorted_indices`` where the scan starts
    current: The current max label of the data
    returns: The last position p >= cid such that every label at positions cid..p is <= current,
        or ``cid - 1`` when the label at ``cid`` already exceeds ``current`` or ``cid`` is past the end
    """
    last_ok = cid - 1
    for position in range(cid, len(sorted_indices)):
        if data.targets[sorted_indices[position]] <= current:
            last_ok = position
        else:
            break
    return last_ok


def gen_adversarial(model: nn.Module, x: torch.Tensor, y: torch.Tensor, loss_func: nn.Module,
                    epsilon: float = 8 / 255, input_noise_rate: int = 1, step_size: float = 0.25,
                    attack_iter: int = 10, clamp: bool = True, clamp_01: bool = False) -> torch.Tensor:
    """
    Generate adversarial inputs by iterated signed-gradient ascent on ``loss_func(model(x_adv), y)``.

    model: The model being attacked; it is run in whatever train/eval mode it is currently in
    x: The clean inputs
    y: The targets passed to ``loss_func``
    loss_func: Loss to maximise; must return a scalar
    epsilon: Radius of the elementwise box [x - epsilon, x + epsilon] used when ``clamp`` is True
    input_noise_rate: The random start is uniform in [-input_noise_rate * epsilon, input_noise_rate * epsilon]
    step_size: Each step moves every element by ``epsilon * step_size`` in the sign of the gradient
    attack_iter: Number of ascent steps
    clamp: Whether to project back into [x - epsilon, x + epsilon] after every step
    clamp_01: Whether to clip to [0, 1] after every step (the random start itself is not clipped)
    returns: The adversarial inputs, detached from any autograd graph

    Gradients are taken with respect to the input only (torch.autograd.grad), so the model's
    parameter ``.grad`` fields are neither cleared nor overwritten by the attack.
    """
    x = x.detach()
    noise = input_noise_rate * epsilon * (2 * torch.rand(x.shape, device=x.device) - 1)
    x_adv = x.clone() + noise
    for _ in range(attack_iter):
        x_adv.requires_grad_(True)
        loss = loss_func(model(x_adv), y)
        (grad,) = torch.autograd.grad(loss, x_adv)
        with torch.no_grad():
            x_adv = x_adv + epsilon * step_size * torch.sign(grad)
            if clamp:
                x_adv = torch.min(torch.max(x_adv, x - epsilon), x + epsilon)
            if clamp_01:
                x_adv = torch.clamp(x_adv, 0, 1)
    return x_adv.detach()


# def gen_adversarial_2(model: nn.Module, x: torch.Tensor, y: torch.Tensor, loss_func: nn.Module,
#                     epsilon: float = 8 / 255, input_noise_rate: int = 1, step_size: float = 0.25,
#                     attack_iter: int = 10, clamp: bool = True, clamp_01: bool = False) -> torch.Tensor:
#     x_adv = nn.Parameter(deepcopy(x) + input_noise_rate * epsilon * (2 * torch.rand(x.shape, device=x.device) - 1))
#     for i in range(attack_iter):
#         x_adv.grad = None
#         outputs = model(x_adv)
#         loss_func(outputs.float(), y.float()).backward(retain_graph=True)
#         with torch.no_grad():
#             grads_input = deepcopy(x_adv.grad)
#             x_adv.data = x_adv.data + epsilon * step_size * torch.sign(grads_input)
#             if clamp:
#                 x_adv.data = torch.min(torch.max(x_adv.data, x - epsilon), x + epsilon)
#             if clamp_01:
#                 x_adv.data = torch.clamp(x_adv.data, 0, 1)
#         #print(i,torch.abs(x-x_adv).sum())
#     return x_adv.detach().data


# def save_adversarial(adv_list: List[Tuple[torch.Tensor, torch.Tensor]], x: torch.Tensor, y: torch.Tensor):
#     """
#     adv_list: The list of previous adversarial examples
#     x: The sample
#     y: The label
#     """
#     adv_list.append((x, y))


class EfficientDataset(Dataset):
    """
    In-memory copy of another dataset.

    Every sample of the wrapped dataset is read once at construction; afterwards indexing is a plain
    list lookup. ``targets`` holds the second element of each sample, in the same order.
    """

    def __init__(self, dataset: Dataset):
        super().__init__()
        cached = []
        for item in dataset:
            cached.append(item)
        self.samples = cached
        self.targets = [label for (_unused, label) in cached]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i: int):
        samples = self.samples
        return samples[i]


class AdversarialLoader(DataLoader):
    """
    Batch iterator over the training data, or over adversarial examples collected during training.

    ``DataLoader.__init__`` is never called: only ``__len__`` and ``__iter__`` are provided, so other
    DataLoader attributes (e.g. ``dataset``, ``sampler``) are not set on instances.
    """

    def __init__(self, train_data, num_iterations, batch_size):
        # train_data: the clean training dataset
        # num_iterations: forwarded to load_stage_standard when not in full-stage mode
        # batch_size: batch size of the yielded batches
        self.train_data = train_data
        self.adv_data = []  # list of (x_adv, y_orig) tensor pairs, one entry per new_adversarial call
        self.num_iterations = num_iterations
        self.batch_size = batch_size
        self.full_stage = False  # True: iterate via full_stage_standard, False: via load_stage_standard

    def new_adversarial(self, model, x, y, y_orig, loss_func, epsilon=8 / 255, input_noise_rate=1, step_size=0.25,
                        attack_iter=10, clamp=True):
        """
        Generate adversarial versions of the batch ``x`` and store them together with ``y_orig``.

        ``y`` are the targets used by ``loss_func`` during the attack; ``y_orig`` are the labels stored
        alongside the adversarial images.
        """
        x_adv = gen_adversarial(model, x, y, loss_func, epsilon, input_noise_rate, step_size, attack_iter, clamp)
        self.adv_data.append((x_adv, y_orig))

    def __len__(self):
        # Each adv_data entry is a whole batch, so count its samples rather than the entries.
        adv_samples = sum(x_adv.size(0) for x_adv, _ in self.adv_data)
        # NOTE(review): this counts clean + adversarial samples, but __iter__ currently yields only the
        # adversarial samples once any exist, and the number of batches produced by load_stage_standard
        # is not visible here — confirm the intended length.
        return (len(self.train_data) + adv_samples) // self.batch_size

    def __iter__(self) -> Iterator:
        if self.adv_data:
            adv_x = torch.cat([x for x, _ in self.adv_data], dim=0)
            adv_y = torch.cat([y for _, y in self.adv_data], dim=0)
            # Only the adversarial samples are iterated; ConcatDataset([self.train_data, adv_dataset])
            # would mix the clean training data back in.
            train_data_plus_adv = TensorDataset(adv_x, adv_y)
        else:
            train_data_plus_adv = self.train_data
        if self.full_stage:
            yield from full_stage_standard(train_data_plus_adv, batch_size=self.batch_size)
        else:
            yield from load_stage_standard(train_data_plus_adv, self.num_iterations, batch_size=self.batch_size)
