"""Adversarial Training from Instance Optimality"""

from typing import List, Tuple

import numpy as np
import pandas as pd
import torch
from torch import nn
from tqdm import tqdm
from torch.nn.functional import one_hot
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import Dataset, DataLoader

from data import Sample
from optim_utils import AdversarialLoader, dataset_accuracy, dataset_error, gen_adversarial


def SGD_Stage_Standard(network: nn.Module, optimizer: Optimizer, scheduler: LRScheduler, loss_func: nn.Module,
                       train_data: Dataset, dataloader: AdversarialLoader, T_i: int, one_hot_encode: bool = False,
                       adversarial: bool = False, epsilon: float = 8 / 255, attack_iter: int = 10,
                       get_errors: bool = False) -> List[float]:
    """
    Run one stage of plain (optionally adversarial) SGD over every batch of ``dataloader``.

    network: The network used
    optimizer: The optimizer used (zero_grad / step once per mini-batch)
    scheduler: The learning-rate scheduler, stepped once per mini-batch
    loss_func: The function used as loss
    train_data: The training dataset used (only read when ``get_errors`` is set)
    dataloader: The training dataloader used
    T_i: The iterations to do at this stage (only used in the progress message)
    one_hot_encode: whether you should one hot encode the labels (10 classes, float targets)
    adversarial: whether each mini-batch is replaced by the output of ``gen_adversarial``
    epsilon: perturbation budget forwarded to ``gen_adversarial``
    attack_iter: number of attack iterations forwarded to ``gen_adversarial``
    get_errors: give the errors (evaluated with the network in eval mode after every step)

    Returns the list of training errors; empty unless ``get_errors`` is set.
    """
    errors = []
    for j, (images, labels) in enumerate(tqdm(dataloader), start=1):
        network.train()
        if one_hot_encode:
            # Cast once, before the attack, so the attack and the training loss both see
            # floating-point one-hot targets (matches FUM_Stage_Standard).
            labels = one_hot(labels, 10).float()
        if adversarial:
            images = gen_adversarial(network, images, labels, loss_func, epsilon=epsilon, attack_iter=attack_iter,
                                     input_noise_rate=0)
        outputs = network(images)
        loss = loss_func(outputs, labels)

        # clear gradient for this training step
        optimizer.zero_grad()

        # backpropagation, compute gradients
        loss.backward()

        # apply gradients
        optimizer.step()
        scheduler.step()

        if get_errors:
            # Evaluate in eval mode so dropout / batch-norm statistics are not disturbed,
            # then return to training mode.
            network.eval()
            with torch.no_grad():
                error = dataset_error(train_data, network, loss_func)
            network.train()
            print(f"Iteration: {j}/{T_i} error: {error}")
            errors.append(error)
    return errors


def full_gradient_standard(frozen_model: nn.Module, vr_optimizer: Optimizer, frozen_optimizer: Optimizer,
                           loss_func: nn.Module, dataloader: AdversarialLoader, one_hot_encode: bool = False):
    """
    Accumulate the gradient of ``frozen_model`` over every batch of ``dataloader`` and hand it
    to ``vr_optimizer`` as its reference term.

    frozen_model: The network at the reference point
    vr_optimizer: The optimizer receiving the reference gradient (``set_vr``) and the
        stage counter (``set_cr_stage(len(dataloader))``)
    frozen_optimizer: The optimizer of the reference point; its ``get_param_groups()`` is read
        after all batches have been back-propagated
    loss_func: The function used as loss
    dataloader: The training dataloader used
    one_hot_encode: whether you should one hot encode the labels (10 classes, float targets)
    """
    # Zero once, then let backward() accumulate the gradients of every batch's loss.
    frozen_optimizer.zero_grad()
    for batch_images, batch_labels in tqdm(dataloader):
        targets = one_hot(batch_labels, 10).float() if one_hot_encode else batch_labels
        loss_func(frozen_model(batch_images), targets).backward()

    vr_optimizer.set_vr(frozen_optimizer.get_param_groups())
    vr_optimizer.set_cr_stage(len(dataloader))


def update_full_gradient_standard(frozen_network: nn.Module, optimizer: Optimizer,
                                  frozen_optimizer: Optimizer, loss_func: nn.Module, sample: Sample):
    """
    Add the reference-point gradient of one extra batch to the optimizer's reference term.

    frozen_network: The network at the reference point
    optimizer: The optimizer whose reference term is updated via ``update_vr`` and whose
        stage counter is advanced by the batch size via ``update_cr_stage``
    frozen_optimizer: The optimizer of the reference point; it is zeroed before the backward
        pass and its ``get_param_groups()`` is read afterwards
    loss_func: The function used as loss
    sample: ``(images, labels)``; labels are passed to ``loss_func`` unchanged, so they must
        already be in the format it expects (e.g. float one-hot targets)
    """
    batch_images, batch_labels = sample

    frozen_optimizer.zero_grad()
    reference_loss = loss_func(frozen_network(batch_images), batch_labels)
    reference_loss.backward()

    optimizer.update_vr(frozen_optimizer.get_param_groups())
    optimizer.update_cr_stage(batch_images.shape[0])


def FUM_Stage_Standard(network: nn.Module, frozen_network: nn.Module,
                       optimizer: Optimizer, frozen_optimizer: Optimizer, scheduler: LRScheduler, loss_func: nn.Module,
                       train_data: Dataset, dataloader: AdversarialLoader, T_init: int,
                       last_batch: torch.Tensor = None, last_label: torch.Tensor = None, one_hot_encode: bool = True,
                       adversarial: bool = False, epsilon: float = 8 / 255, attack_iter: int = 10,
                       get_errors: bool = False) -> List[float]:
    """
    Run one stage of training that pairs each step of ``network`` with the gradient of the
    reference point ``frozen_network`` on the same clean mini-batch.

    Each step back-propagates the loss of ``network`` (optionally on adversarial images, and
    optionally with the extra batch ``last_batch`` appended), back-propagates the loss of
    ``frozen_network`` on the clean images, and calls ``optimizer.step2(...)`` when an extra
    batch is used or ``optimizer.step(...)`` otherwise, passing the reference gradients.

    network: The network used
    frozen_network: The network at the reference point
    optimizer: The optimizer used (its ``cr_stage`` attribute sets the extra-batch weight)
    frozen_optimizer: An optimizer used for the reference point
    scheduler: The learning-rate scheduler, stepped once per mini-batch
    loss_func: The function used as loss
    train_data: The training dataset used (only read when ``get_errors`` is set)
    dataloader: The training dataloader used
    T_init: The iterations to do for the original training (only used in the progress message)
    last_batch: Extra images appended to every mini-batch, or None; moved to the batch's device
    last_label: Labels of ``last_batch``; cast to the device and dtype of the batch labels
    one_hot_encode: whether you should one hot encode the labels (10 classes, float targets)
    adversarial: whether the images of ``network`` are replaced by ``gen_adversarial`` output
    epsilon: perturbation budget forwarded to ``gen_adversarial``
    attack_iter: number of attack iterations forwarded to ``gen_adversarial``
    get_errors: give the errors (evaluated in eval mode after every step)

    Returns the list of training errors; empty unless ``get_errors`` is set.
    """
    errors = []
    # Per-sample losses so the extra batch can be re-weighted. `reduction` is the correct
    # keyword: the deprecated boolean `reduce` flag would silently keep the mean reduction.
    loss_no_reduce = nn.CrossEntropyLoss(reduction='none')
    for j, (images, labels) in enumerate(tqdm(dataloader), start=1):
        bs = images.shape[0]
        network.train()
        if one_hot_encode:
            labels = one_hot(labels, 10).float()
        images_adv = images.clone()
        if adversarial:
            images_adv = gen_adversarial(network, images_adv, labels, loss_func, epsilon=epsilon,
                                         attack_iter=attack_iter, input_noise_rate=0)
        if last_batch is not None:
            extra_images = last_batch.to(images_adv.device)
            extra_labels = last_label.to(device=labels.device, dtype=labels.dtype)
            # torch.cat takes a sequence of tensors.
            outputs = network(torch.cat((images_adv, extra_images), dim=0))
            per_sample = loss_no_reduce(outputs, torch.cat((labels, extra_labels), dim=0))
            # Scale the extra samples by (1/c) / (1 - 1/c), c = optimizer.cr_stage. Built as a
            # separate weight tensor instead of editing the loss in place under autograd.
            weights = torch.ones_like(per_sample)
            weights[bs:] = (1 / optimizer.cr_stage) / (1 - (1 / optimizer.cr_stage))
            loss = (per_sample * weights).mean()
        else:
            outputs = network(images_adv)
            loss = loss_func(outputs, labels)
        # clear gradient for this training step
        optimizer.zero_grad()

        # backpropagation, compute gradients
        loss.backward()

        # clear gradients for reference point
        frozen_optimizer.zero_grad()

        # compute loss for reference point on the clean images of this batch
        # NOTE(review): the reference gradient includes neither the adversarial perturbation nor
        # `last_batch`; confirm this is what step/step2 expect.
        frozen_outputs = frozen_network(images)
        frozen_loss = loss_func(frozen_outputs, labels)

        # backpropagation, compute gradients in the frozen network
        frozen_loss.backward()

        # apply gradients
        if last_batch is not None:
            optimizer.step2(frozen_optimizer.get_param_groups())
        else:
            optimizer.step(frozen_optimizer.get_param_groups())
        scheduler.step()
        if get_errors:
            network.eval()
            with torch.no_grad():
                error = dataset_error(train_data, network, loss_func)
            # Do not leave the network in eval mode after the last step.
            network.train()
            print(f"Iteration: {j}/{T_init} error: {error}")
            errors.append(error)

    return errors


def evaluate(
        network: nn.Module, loss_func: nn.Module, train_data: Dataset, test_data: Dataset,
        robustness: bool, one_hot_encode: bool, epsilon: float, attack_iter: int
) -> Tuple[float, float, float, float, float]:
    """
    Measure the network on the training and test sets.

    Returns ``(train loss, train accuracy, test accuracy, train robustness, test robustness)``.
    Train robustness is not computed and is always None. Test robustness is None unless
    ``robustness`` is set; in that case it is computed outside ``torch.no_grad`` with the given
    ``epsilon`` / ``attack_iter``. The network is put in eval mode for the measurements and is
    left in training mode afterwards.
    """
    network.eval()
    with torch.no_grad():
        train_loss = dataset_error(train_data, network, loss_func, False, one_hot_encode=one_hot_encode)
        train_acc, test_acc = [
            dataset_accuracy(split, network, loss_func, False, one_hot_encode=one_hot_encode)
            for split in (train_data, test_data)
        ]

    # Robustness on the training set is intentionally not evaluated.
    train_robust = None
    test_robust = (
        dataset_accuracy(test_data, network, loss_func, True,
                         one_hot_encode=one_hot_encode,
                         epsilon=epsilon, attack_iter=attack_iter)
        if robustness else None
    )

    network.train()
    return train_loss, train_acc, test_acc, train_robust, test_robust


def Produce_SGD_Standard(network: nn.Module, optimizer: Optimizer,
                         scheduler: LRScheduler, loss_func: nn.Module,
                         train_data: Dataset, test_data: Dataset, T: int,
                         adversarial: bool = False, adversarial_iter: bool = False, stages: int = -1,
                         batch_size: int = 1, one_hot_encode: bool = False,
                         epsilon: float = 8 / 255, attack_iter: int = 10) -> Tuple[
    List[float], List[float], List[float], List[float], List[float], List[int]
]:
    """
    Baseline: run ``stages`` stages of plain SGD, evaluating before every stage and at the end,
    and optionally adding adversarial examples to the loader after each stage.

    network: The network used
    optimizer: The optimizer used
    scheduler: The learning-rate scheduler, stepped once per mini-batch
    loss_func: The function used as loss
    train_data: The training dataset used
    test_data: The test dataset used
    T: The iterations to do for the training
    adversarial: whether to add adversarial examples every stage
    adversarial_iter: whether to learn on adversarial examples every iteration
    stages: How many stages the algorithm should do if not all of them
        (negative means ``len(train_data) // 3``)
    batch_size: the number of samples of each mini batch
    one_hot_encode: whether you should one hot encode the labels
    epsilon: perturbation budget for the attacks and the robustness evaluation
    attack_iter: attack iterations for the per-iteration attack and the robustness evaluation

    Returns ``(losses, train accuracies, test accuracies, train robustness, test robustness, FOs)``.
    The metric lists have ``stages + 1`` entries (one before each stage plus a final one);
    ``FOs[i]`` is the cumulative FO count through stage ``i`` (``T * batch_size`` per stage).
    """
    FOs = []
    Res = []
    TrainAcc = []
    TestAcc = []
    TestRob = []
    TrainRob = []

    N = len(train_data)
    robust_eval = adversarial or adversarial_iter
    # Run standard training on full dataset
    dataloader = AdversarialLoader(train_data, T, batch_size)
    dataloader.full_stage = True

    if stages < 0:
        stages = N // 3

    def _evaluate_and_record(stage_label: int) -> None:
        # Evaluate the network, append every metric to its list and print a progress line.
        loss, accuracy_train, accuracy_test, robustness_train, robustness_test = evaluate(
            network, loss_func, train_data, test_data, robust_eval,
            one_hot_encode, epsilon, attack_iter
        )
        Res.append(loss)
        TrainAcc.append(accuracy_train)
        TestAcc.append(accuracy_test)
        TestRob.append(robustness_test)
        TrainRob.append(robustness_train)
        print(
            f"Stage: {stage_label}/{stages}, Loss: {loss}, Training Accuracy: {accuracy_train}, "
            f"Test Accuracy: {accuracy_test}, Train Robustness: {robustness_train}, Test Robustness: {robustness_test}"
        )

    def _add_adversarial_examples() -> None:
        # Draw `batch_size` training samples (with replacement) and pass them to
        # `dataloader.new_adversarial` to add adversarial examples.
        sample_index = np.random.choice(N, size=batch_size, replace=True)
        samples = [train_data[idx] for idx in sample_index]
        x = torch.stack([image for image, _ in samples], dim=0)
        y_orig = torch.stack([label for _, label in samples], dim=0)
        y = one_hot(y_orig, 10).float() if one_hot_encode else y_orig.clone()
        dataloader.new_adversarial(network, x, y, y_orig, loss_func, epsilon=epsilon)

    crFOs = 0
    for i in range(stages):
        # calculate the total FO complexity of this stage
        crFOs += T * batch_size
        FOs.append(crFOs)

        # Compute resulting loss for this stage accuracy on the train and test set
        _evaluate_and_record(i)

        # run SGD for this stage; forward attack_iter so the per-iteration attack honours it
        SGD_Stage_Standard(network, optimizer, scheduler, loss_func, train_data, dataloader, T, one_hot_encode,
                           adversarial=adversarial_iter, epsilon=epsilon, attack_iter=attack_iter)

        if adversarial:
            _add_adversarial_examples()

    # Final evaluation; labelled with `stages` so it also works when no stage was run.
    _evaluate_and_record(stages)

    return Res, TrainAcc, TestAcc, TrainRob, TestRob, FOs


def SIOPT_Standard(network: nn.Module, frozen_network: nn.Module, optimizer: Optimizer, frozen_optimizer: Optimizer,
                   scheduler: LRScheduler, loss_func: nn.Module, train_data: Dataset, test_data: Dataset, T: int,
                   alpha: int, adversarial: bool = False, adversarial_iter: bool = False, stages: int = -1,
                   batch_size: int = 1, one_hot_encode: bool = False,
                   epsilon: float = 8 / 255, attack_iter: int = 10) -> Tuple[
    List[float], List[float], List[float], List[float], List[float], List[int]
]:
    """
    Run ``stages`` stages of FUM training against a reference point that is either fully
    recomputed or incrementally updated with one sampled batch per stage.

    network: The network used
    frozen_network: The copy of the network
    optimizer: The optimizer used
    frozen_optimizer: The optimizer used for the full gradient
    scheduler: The learning-rate scheduler, stepped once per mini-batch
    loss_func: The function used as loss
    train_data: The training dataset used
    test_data: The test dataset used
    T: The iterations to do at this stage
    alpha: How often to do a full gradient computation: at stage ``i`` a full gradient is
        computed (before and after the stage) when ``i - last_refresh_stage >= alpha * i``
    adversarial: whether to add adversarial examples every stage
    adversarial_iter: whether to learn on adversarial examples every iteration
    stages: How many stages the algorithm should do (negative means ``len(train_data) // 3``)
    batch_size: the number of samples of each mini batch
    one_hot_encode: whether you should one hot encode the labels (10 classes, float targets)
    epsilon: perturbation budget for the attacks and the robustness evaluation
    attack_iter: attack iterations for the per-iteration attack and the robustness evaluation

    Returns ``(losses, train accuracies, test accuracies, train robustness, test robustness, FOs)``.
    The metric lists have ``stages + 1`` entries; ``FOs[i]`` is the cumulative per-sample
    gradient count after stage ``i``.
    """
    FOs = []
    Res = []
    TrainAcc = []
    TestAcc = []
    TrainRob = []
    TestRob = []

    N = len(train_data)
    robust_eval = adversarial or adversarial_iter
    # Run standard training on full dataset
    dataloader = AdversarialLoader(train_data, T, batch_size=batch_size)
    dataloader.full_stage = True

    if stages < 0:
        stages = N // 3

    def _evaluate_and_record(stage_label: int) -> None:
        # Evaluate the network, append every metric to its list and print a progress line.
        loss, accuracy_train, accuracy_test, robustness_train, robustness_test = evaluate(
            network, loss_func, train_data, test_data, robust_eval,
            one_hot_encode, epsilon, attack_iter
        )
        Res.append(loss)
        TrainAcc.append(accuracy_train)
        TestAcc.append(accuracy_test)
        TrainRob.append(robustness_train)
        TestRob.append(robustness_test)
        print(
            f"Stage: {stage_label}/{stages}, Loss: {loss}, Training Accuracy: {accuracy_train}, "
            f"Test Accuracy: {accuracy_test}, Train Robustness: {robustness_train}, Test Robustness: {robustness_test}"
        )

    def _sample_batch() -> Tuple[torch.Tensor, torch.Tensor]:
        # Draw `batch_size` training samples with replacement; labels are returned as stored.
        sample_index = np.random.choice(N, size=batch_size, replace=True)
        samples = [train_data[idx] for idx in sample_index]
        x = torch.stack([image for image, _ in samples], dim=0)
        y = torch.stack([label for _, label in samples], dim=0)
        return x, y

    def _refresh_full_gradient() -> int:
        # Move the reference point to the current weights and recompute the full gradient;
        # returns the FO cost (all training samples plus the loader's adversarial samples).
        frozen_optimizer.set_param_groups(optimizer.get_param_groups())
        cost = N + len(dataloader.adv_data)
        full_gradient_standard(frozen_network, optimizer, frozen_optimizer, loss_func,
                               dataloader, one_hot_encode=one_hot_encode)
        return cost

    def _add_adversarial_examples() -> None:
        # Pass a fresh sample to `dataloader.new_adversarial` to add adversarial examples.
        x, y_orig = _sample_batch()
        y = one_hot(y_orig, 10).float() if one_hot_encode else y_orig.clone()
        dataloader.new_adversarial(network, x, y, y_orig, loss_func, epsilon=epsilon)

    crFOs = 0
    prev = 0
    last_batch = None
    last_label = None
    for i in range(stages):

        # Compute resulting loss for this stage accuracy on the train and test set
        _evaluate_and_record(i)

        # update the stage counter of the optimizer
        optimizer.update_cr_stage(batch_size)

        # calculate the total FO complexity of this stage (network + reference point)
        crFOs += 2 * T * batch_size

        full_refresh = i - prev >= alpha * i
        if full_refresh:
            crFOs += _refresh_full_gradient()
        else:
            # The reference gradient is updated with `batch_size` per-sample gradients.
            crFOs += batch_size
            x, y = _sample_batch()
            if one_hot_encode:
                # CrossEntropyLoss needs floating-point class-probability targets.
                y = one_hot(y, 10).float()
            last_batch = x
            last_label = y
            update_full_gradient_standard(frozen_network, optimizer, frozen_optimizer, loss_func, (x, y))

        FUM_Stage_Standard(network, frozen_network, optimizer, frozen_optimizer, scheduler, loss_func,
                           train_data, dataloader, T, last_batch=last_batch, last_label=last_label,
                           one_hot_encode=one_hot_encode, adversarial=adversarial_iter, epsilon=epsilon,
                           attack_iter=attack_iter)

        if full_refresh:
            prev = i
            crFOs += _refresh_full_gradient()
        FOs.append(crFOs)

        if adversarial:
            _add_adversarial_examples()

    # Final evaluation; labelled with `stages` so it also works when no stage was run.
    _evaluate_and_record(stages)

    return Res, TrainAcc, TestAcc, TrainRob, TestRob, FOs


def SIOPT_Continual(network: nn.Module, frozen_network: nn.Module, optimizer: Optimizer, frozen_optimizer: Optimizer,
                    scheduler: LRScheduler, loss_func: nn.Module, train_data: Dataset, test_data: Dataset, T: int,
                    alpha: int, output_filename: str, adversarial: bool = False, adversarial_iter: bool = False,
                    stages: int = -1, batch_size: int = 1, one_hot_encode: bool = False,
                    epsilon: float = 8 / 255, attack_iter: int = 10, eval_every: int = 200) -> Tuple[
    List[float], List[float], List[float], List[float], List[float], List[int]
]:
    """
    Continual variant of SIOPT_Standard: the loader runs in non-full-stage mode except while a
    full gradient is computed, evaluation happens every ``eval_every`` stages, and results are
    written to ``output_filename`` as a tab-separated file.

    network: The network used
    frozen_network: The copy of the network
    optimizer: The optimizer used
    frozen_optimizer: The optimizer used for the full gradient
    scheduler: The learning-rate scheduler, stepped once per mini-batch
    loss_func: The function used as loss
    train_data: The training dataset used
    test_data: The test dataset used
    T: The iterations to do at this stage
    alpha: How often to do a full gradient computation: at stage ``i`` a full gradient is
        computed (before and after the stage) when ``i - last_refresh_stage >= alpha * i``
    output_filename: path of the TSV file (columns Loss, TrainAcc, TestAcc, TrainRob, TestRob, FO)
    adversarial: whether to add adversarial examples every stage
    adversarial_iter: whether to learn on adversarial examples every iteration
    stages: How many stages the algorithm should do (negative means ``len(train_data) // 3``)
    batch_size: the number of samples of each mini batch
    one_hot_encode: whether you should one hot encode the labels (10 classes, float targets)
    epsilon: perturbation budget for the attacks and the robustness evaluation
    attack_iter: attack iterations for the per-iteration attack and the robustness evaluation
    eval_every: evaluate (and write the results file) every this many stages

    Returns ``(losses, train accuracies, test accuracies, train robustness, test robustness, FOs)``.
    The metric lists hold one entry per evaluation plus a final one; ``FOs`` holds the cumulative
    per-sample gradient count after every stage. In the results file, each row's FO value is the
    count after the stage at which that row was evaluated.

    Raises ValueError if ``eval_every`` is not positive.
    """
    if eval_every < 1:
        raise ValueError(f"eval_every must be a positive integer, got {eval_every}")

    FOs = []
    Res = []
    TrainAcc = []
    TestAcc = []
    TrainRob = []
    TestRob = []
    EvalFOs = []  # FO count aligned with each evaluation row (used for the results file)

    N = len(train_data)
    robust_eval = adversarial or adversarial_iter
    # Run standard training on full dataset
    dataloader = AdversarialLoader(train_data, T, batch_size=batch_size)

    if stages < 0:
        stages = N // 3

    def _evaluate_and_record(stage_label: int) -> None:
        # Evaluate the network, append every metric to its list and print a progress line.
        loss, accuracy_train, accuracy_test, robustness_train, robustness_test = evaluate(
            network, loss_func, train_data, test_data, robust_eval,
            one_hot_encode, epsilon, attack_iter
        )
        Res.append(loss)
        TrainAcc.append(accuracy_train)
        TestAcc.append(accuracy_test)
        TrainRob.append(robustness_train)
        TestRob.append(robustness_test)
        print(
            f"Stage: {stage_label}/{stages}, Loss: {loss}, Training Accuracy: {accuracy_train}, "
            f"Test Accuracy: {accuracy_test}, Train Robustness: {robustness_train}, Test Robustness: {robustness_test}"
        )

    def _save_results() -> None:
        # Rewrite the results file with every evaluation recorded so far.
        results_multi_stages = pd.DataFrame(zip(Res, TrainAcc, TestAcc, TrainRob, TestRob, EvalFOs),
                                            columns=["Loss", "TrainAcc", "TestAcc", "TrainRob", "TestRob", "FO"])
        results_multi_stages.to_csv(output_filename, sep="\t", encoding="utf-8")

    def _sample_batch() -> Tuple[torch.Tensor, torch.Tensor]:
        # Draw `batch_size` training samples with replacement; labels are returned as stored.
        sample_index = np.random.choice(N, size=batch_size, replace=True)
        samples = [train_data[idx] for idx in sample_index]
        x = torch.stack([image for image, _ in samples], dim=0)
        y = torch.stack([label for _, label in samples], dim=0)
        return x, y

    def _refresh_full_gradient() -> int:
        # Move the reference point to the current weights and recompute the full gradient with
        # the loader in full-stage mode; returns the FO cost of the recomputation.
        frozen_optimizer.set_param_groups(optimizer.get_param_groups())
        cost = N + len(dataloader.adv_data)
        dataloader.full_stage = True
        try:
            full_gradient_standard(frozen_network, optimizer, frozen_optimizer, loss_func,
                                   dataloader, one_hot_encode=one_hot_encode)
        finally:
            dataloader.full_stage = False
        return cost

    def _add_adversarial_examples() -> None:
        # Pass a fresh sample to `dataloader.new_adversarial` to add adversarial examples.
        x, y_orig = _sample_batch()
        y = one_hot(y_orig, 10).float() if one_hot_encode else y_orig.clone()
        dataloader.new_adversarial(network, x, y, y_orig, loss_func, epsilon=epsilon)

    crFOs = 0
    prev = 0
    for i in range(stages):
        evaluated = i % eval_every == 0
        if evaluated:
            # Compute resulting loss for this stage accuracy on the train and test set
            _evaluate_and_record(i)

        # update the stage counter of the optimizer
        optimizer.update_cr_stage(batch_size)

        # calculate the total FO complexity of this stage (network + reference point)
        crFOs += 2 * T * batch_size

        full_refresh = i - prev >= alpha * i
        if full_refresh:
            crFOs += _refresh_full_gradient()
        else:
            # The reference gradient is updated with `batch_size` per-sample gradients.
            crFOs += batch_size
            x, y = _sample_batch()
            if one_hot_encode:
                # CrossEntropyLoss needs floating-point class-probability targets.
                y = one_hot(y, 10).float()
            update_full_gradient_standard(frozen_network, optimizer, frozen_optimizer, loss_func, (x, y))

        FUM_Stage_Standard(network, frozen_network, optimizer, frozen_optimizer, scheduler, loss_func,
                           train_data, dataloader, T, one_hot_encode=one_hot_encode,
                           adversarial=adversarial_iter, epsilon=epsilon, attack_iter=attack_iter)

        if full_refresh:
            prev = i
            crFOs += _refresh_full_gradient()
        FOs.append(crFOs)
        if evaluated:
            # Pair this evaluation row with the FO count of the same stage.
            EvalFOs.append(crFOs)

        if adversarial:
            _add_adversarial_examples()

        if evaluated:
            # The file contents only change when a new evaluation row was added.
            _save_results()

    # Final evaluation; labelled with `stages` so it also works when no stage was run.
    _evaluate_and_record(stages)
    EvalFOs.append(crFOs)
    _save_results()

    return Res, TrainAcc, TestAcc, TrainRob, TestRob, FOs
