"""
SGD implementations

1. Implementation of single stage SGD

2. Implementation of the whole Procedure using the single stage Implementation
"""

from typing import List, Sequence, Tuple

import torch
import pandas as pd
from torch import nn
from torch.nn.functional import one_hot
from torch.optim import Optimizer
from torch.utils.data import Dataset

from data import load_stage
from optim_utils import find_max_label, gen_adversarial, stage_accuracy, stage_error


def SGD_Stage(network: nn.Module, optimizer: Optimizer, loss_func: nn.Module,
              train_data: Dataset, sorted_indices: Sequence[int], i: int, T_i: int,
              replacement: bool = True, batch_size: int = 1, one_hot_encode: bool = False,
              get_errors: bool = False,
              at: bool = False, epsilon: float = 8 / 255, attack_iter: int = 10,
              scheduler_name: str = 'constant') -> Tuple[float, float]:
    """
    Run one stage of SGD on the batches produced by ``load_stage`` and report
    the running training loss and accuracy observed during that stage.

    network: The network used (must expose ``num_classes`` when one_hot_encode is set)
    optimizer: The optimizer used
    loss_func: The function used as loss
    train_data: The training dataset used
    sorted_indices: The indices sorted according to their label
    i: The prefix index (forwarded to ``load_stage`` / ``stage_error``)
    T_i: The iterations to do at this stage
    replacement: whether to use replacement when sampling
    batch_size: the number of samples of each mini batch
    one_hot_encode: whether you should one hot encode the labels
    get_errors: evaluate and print ``stage_error`` after every iteration
    at: whether to train on adversarial examples built by ``gen_adversarial``
    epsilon: perturbation budget forwarded to ``gen_adversarial``
    attack_iter: number of attack iterations forwarded to ``gen_adversarial``
    scheduler_name: 'constant' (no scheduling) or 'linear' (LinearLR decaying
        the learning rate from 1.0x to 0.1x over the batches of this stage)

    Returns ``(loss, accuracy)``: the sum of the per-batch loss values divided
    by the number of samples seen, and the fraction of samples whose argmax
    prediction equals the class label. Both are measured on the training
    batches, before the corresponding parameter update.

    Raises ValueError for an unknown ``scheduler_name`` and ZeroDivisionError
    if the stage yields no samples.
    """
    # Fail fast: previously an unknown name only surfaced as a NameError on the
    # first scheduler.step(), after an optimizer step had already been applied.
    if scheduler_name not in ('constant', 'linear'):
        raise ValueError(
            f"Unknown scheduler_name {scheduler_name!r}; expected 'constant' or 'linear'"
        )

    # Take enough samples for all iterations for this stage
    train_data_iter = load_stage(train_data, sorted_indices, i, T_i,
                                 replacement=replacement, batch_size=batch_size)

    # Remember every group's learning rate so the stage leaves the optimizer
    # as it found it, even when param groups use different rates.
    initial_lrs = [group['lr'] for group in optimizer.param_groups]

    scheduler = None
    if scheduler_name == 'linear':
        scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=1.0, end_factor=0.1, total_iters=len(train_data_iter))

    # Train the model
    network.train()
    errors = []  # per-iteration errors; only printed, not part of the return value
    losses = 0
    correct = 0
    total = 0

    for iteration, (images, labels) in enumerate(train_data_iter, start=1):
        total += images.shape[0]

        # Keep the integer class labels: accuracy must be computed against
        # them, not against the one-hot float matrix used by the loss.
        targets = labels
        if one_hot_encode:
            labels = one_hot(labels, network.num_classes).float()
        if at:
            images = gen_adversarial(network, images, labels, loss_func, epsilon=epsilon, attack_iter=attack_iter,
                                     input_noise_rate=0)

        # calculate the loss for this sample batch
        outputs = network(images)
        loss = loss_func(outputs, labels)
        # NOTE(review): per-batch loss values are summed and later divided by
        # the sample count; if loss_func averages over the batch this
        # under-reports the mean loss by a factor of batch_size - confirm.
        losses += loss.item()
        correct += (torch.argmax(outputs, dim=1) == targets).sum().item()

        # clear gradient for this training step
        optimizer.zero_grad()

        # backpropagation, compute gradients
        loss.backward()

        # apply gradients
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        # Get average loss for this iteration
        if get_errors:
            with torch.no_grad():
                network.eval()
                error = stage_error(network, train_data, sorted_indices, i, loss_func, one_hot_encode=one_hot_encode)
                print(f"Iteration: {iteration}/{T_i} error: {error}")
                errors.append(error)
            network.train()

    # Undo any learning-rate decay so the next stage starts from the same rates
    for group, lr in zip(optimizer.param_groups, initial_lrs):
        group['lr'] = lr
    return losses / total, correct / total


def Produce_SGD(network: nn.Module, optimizer: Optimizer, loss_func: nn.Module, train_data: Dataset, test_data: Dataset,
                sorted_indices: Sequence[int], sorted_indices_test: Sequence[int], T_i: int, adversarial_iter: bool,
                stages: int = -1, replacement: bool = True, batch_size: int = 1, one_hot_encode: bool = False,
                epsilon: float = 8 / 255, attack_iter: int = 10, output_filename: str = '',
                scheduler_name: str = 'constant',
                forgetting: bool = False) -> Tuple[List, ...]:
    """
    Run the whole multi-stage procedure: one ``SGD_Stage`` per stage, followed
    by an evaluation on the test set, collecting per-stage metrics.

    network: The network used
    optimizer: The optimizer used
    loss_func: The function used as loss
    train_data: The training dataset used (must expose ``targets``)
    test_data: The test dataset used
    sorted_indices: The indices sorted according to their label
    sorted_indices_test: The indices of the test set sorted according to their label
    T_i: The iterations to do at this stage
    adversarial_iter: whether to train adversarially and to evaluate test robustness
    stages: How many stages the algorithm should do if not all of them
        (negative means ``len(sorted_indices) // batch_size`` stages)
    replacement: whether to use replacement when sampling
    batch_size: the number of samples of each mini batch
    one_hot_encode: whether you should one hot encode the labels
    epsilon: perturbation budget for adversarial training / evaluation
    attack_iter: number of attack iterations used for the robustness evaluation
    output_filename: tab-separated file rewritten with all results after every
        stage; an empty value disables writing
    scheduler_name: learning-rate schedule forwarded to ``SGD_Stage``
    forgetting: walk the prefix index downwards instead of upwards and
        additionally record accuracy on the reversed index order

    Returns the per-stage lists ``(Res, TrainAcc, TestAcc, TrainRob, TestRob, FOs)``;
    with ``forgetting`` the tuple is
    ``(Res, TrainAcc, TestAcc, TestForget, TrainForget, TrainRob, TestRob, FOs)``.
    ``TrainRob`` entries are always None and ``TestRob`` entries are None
    unless ``adversarial_iter`` is set.
    """
    Res = []
    FOs = []
    TestAcc = []
    TestForget = []
    TrainForget = []
    TrainAcc = []
    TestRob = []
    TrainRob = []

    n_train = len(sorted_indices)
    n_test = len(sorted_indices_test)
    if stages < 0:
        stages = n_train // batch_size

    # The reversed orders do not change between stages; build them once.
    if forgetting:
        reversed_indices = list(reversed(sorted_indices))
        reversed_indices_test = list(reversed(sorted_indices_test))

    crFOs = 0
    for j in range(stages):
        # get the label of the current point
        if forgetting:
            i = max(min((n_train // batch_size - j) * batch_size - 1, n_train - 1), 0)
        else:
            i = min((j + 1) * batch_size - 1, n_train - 1)
        current = train_data.targets[sorted_indices[i]]
        if isinstance(current, torch.Tensor):
            current = current.item()

        # calculate the total FO complexity of this stage
        crFOs += T_i * batch_size
        FOs.append(crFOs)

        # run SGD for this stage
        # NOTE(review): attack_iter is not forwarded, so adversarial training
        # always uses SGD_Stage's default number of attack iterations while the
        # evaluation below uses attack_iter - confirm this is intended.
        loss_train, accuracy_train = SGD_Stage(network, optimizer, loss_func, train_data, sorted_indices, i, T_i,
                                               replacement=replacement, batch_size=batch_size,
                                               one_hot_encode=one_hot_encode,
                                               at=adversarial_iter, epsilon=epsilon,
                                               scheduler_name=scheduler_name)

        # Compute resulting loss for this stage accuracy on the train and test set
        label_limit = current - 1 if forgetting else current
        trainlim = find_max_label(train_data, sorted_indices, 0, label_limit)
        testlim = find_max_label(test_data, sorted_indices_test, 0, label_limit)
        with torch.no_grad():
            network.eval()
            accuracy_test = stage_accuracy(network, test_data, sorted_indices_test, testlim, loss_func,
                                           one_hot_encode=one_hot_encode)
            if forgetting:
                forget_label = max(current - 1, 0)
                trainlim_2 = find_max_label(train_data, sorted_indices, 0, forget_label)
                testlim_2 = find_max_label(test_data, sorted_indices_test, 0, forget_label)
                test_forget = stage_accuracy(network, test_data, reversed_indices_test,
                                             n_test - 2 - testlim_2, loss_func,
                                             one_hot_encode=one_hot_encode)
                train_forget = stage_accuracy(network, train_data, reversed_indices,
                                              n_train - 2 - trainlim_2, loss_func,
                                              one_hot_encode=one_hot_encode)
                TestForget.append(test_forget)
                TrainForget.append(train_forget)
                print('forget:', n_train - 2 - trainlim, 'remember', trainlim)

        # Robustness is evaluated outside no_grad (the network is still in eval
        # mode here); train robustness is not computed and stays None.
        robustness_test = None
        robustness_train = None
        if adversarial_iter:
            robustness_test = stage_accuracy(network, test_data, sorted_indices_test, testlim, loss_func,
                                             one_hot_encode=one_hot_encode, adversarial=True,
                                             epsilon=epsilon, attack_iter=attack_iter)
        Res.append(loss_train)
        TrainAcc.append(accuracy_train)
        TestAcc.append(accuracy_test)
        TrainRob.append(robustness_train)
        TestRob.append(robustness_test)
        network.train()
        print(
            f"Stage: {j}/{stages}, Loss: {loss_train}, Training Accuracy: {accuracy_train}, "
            f"Test Accuracy: {accuracy_test}, Train Robustness: {robustness_train}, Test Robustness: {robustness_test}"
        )

        # Checkpoint the results after every stage so a long run can be
        # inspected or recovered. Skipped when no filename was given: passing
        # the default '' to to_csv would fail trying to open an empty path.
        if output_filename:
            if forgetting:
                results_multi_stages = pd.DataFrame(
                    zip(Res, TrainAcc, TestAcc, TestForget, TrainForget, TrainRob, TestRob, FOs),
                    columns=["Loss", "TrainAcc", "TestAcc", "TestForget", "TrainForget", "TrainRob", "TestRob", "FO"])
            else:
                results_multi_stages = pd.DataFrame(
                    zip(Res, TrainAcc, TestAcc, TrainRob, TestRob, FOs),
                    columns=["Loss", "TrainAcc", "TestAcc", "TrainRob", "TestRob", "FO"])
            results_multi_stages.to_csv(output_filename, sep="\t", encoding="utf-8")

    if forgetting:
        return Res, TrainAcc, TestAcc, TestForget, TrainForget, TrainRob, TestRob, FOs
    return Res, TrainAcc, TestAcc, TrainRob, TestRob, FOs
