"""
SIOPT Implementations

1. Implementation of the Variance Reduced Optimizer In Pytorch Optim.
The code for this implementation is heavily reliant on https://github.com/yueqiw/OptML-SVRG-PyTorch/tree/master

2. Implementation of FUM

3. Full Gradient Computation for a prefix

4. Implementation of SIOPT
"""

from copy import deepcopy
from typing import List, Sequence, Tuple

import torch
import pandas as pd
from tqdm import tqdm
from torch import nn, optim
from torch.nn.functional import one_hot
from torch.optim import Optimizer
from torch.utils.data import Dataset

from data import full_stage, load_stage, Sample
from optim_utils import find_max_label, gen_adversarial, stage_accuracy, stage_error


class VR_G(Optimizer):
    """
    The frequent iterate optimizer.

    Performs variance-reduced gradient steps. Besides the regular parameter
    groups it keeps two pieces of state:

    * ``vr``: a private deep copy of the reference parameter groups whose
      ``.grad`` fields hold the accumulated gradient at the reference point
      (``None`` until ``set_vr`` has been called);
    * ``cr_stage``: the count the accumulated gradient is divided by inside
      ``step`` / ``step2``.

    A missing (``None``) gradient in the reference groups is treated as
    "no contribution" everywhere instead of raising.
    """

    def __init__(self, params, lr):
        """
        params: iterable of parameters (or parameter-group dicts) to optimize
        lr: learning rate stored as the ``lr`` default of every group
        """
        self.vr = None
        defaults = dict(lr=lr)
        self.cr_stage = 0
        super(VR_G, self).__init__(params, defaults)

    def get_param_groups(self):
        """Return the optimizer's parameter groups (the live list, not a copy)."""
        return self.param_groups

    def set_vr(self, new_vr):
        """
        Replace the stored accumulated gradient with the gradients found in new_vr.

        new_vr: parameter groups whose ``.grad`` fields hold the new gradient
        """
        if self.vr is None:
            # First call: build private holders with the same structure.
            self.vr = deepcopy(new_vr)
        for u_group, new_group in zip(self.vr, new_vr):
            for u, new_u in zip(u_group['params'], new_group['params']):
                # Clone so later backward passes on new_vr cannot alter the
                # snapshot; a parameter without gradient simply has none here.
                u.grad = None if new_u.grad is None else new_u.grad.clone()

    def update_vr(self, sample_grad):
        """
        Add the gradients found in sample_grad to the stored accumulated gradient.

        sample_grad: parameter groups whose ``.grad`` fields hold the gradient to add
        """
        if self.vr is None:
            raise RuntimeError("VR_G.update_vr() called before set_vr()")
        for u_group, u_sample in zip(self.vr, sample_grad):
            for u, new_u in zip(u_group['params'], u_sample['params']):
                if new_u.grad is None:
                    continue
                if u.grad is None:
                    # Nothing accumulated yet for this parameter: start from
                    # the sample gradient instead of dropping it.
                    u.grad = new_u.grad.clone()
                else:
                    u.grad += new_u.grad

    def set_cr_stage(self, new_cr_stage):
        """Set the divisor applied to the accumulated gradient."""
        self.cr_stage = new_cr_stage

    def update_cr_stage(self, batch_size):
        """Increase the divisor applied to the accumulated gradient by batch_size."""
        self.cr_stage += batch_size

    def _variance_reduced_step(self, params, damped):
        """
        Shared implementation of step and step2.

        params: parameter groups of the reference network (their ``.grad`` is read)
        damped: if True the learning rate is multiplied by (1 - 1/cr_stage)
        """
        if self.vr is None:
            raise RuntimeError("VR_G: set_vr() must be called before step()/step2()")
        for group, new_group, vr_group in zip(self.param_groups, params, self.vr):
            for p, q, vr in zip(group['params'], new_group['params'], vr_group['params']):
                if (p.grad is None) or (q.grad is None):
                    continue
                # p.grad is the gradient at the random sample and new weights
                # q.grad is the gradient at the random sample and frozen weights
                # vr.grad is the accumulated gradient at the frozen weights
                vr_grad = p.grad.data - q.grad.data
                if vr.grad is not None:
                    vr_grad = torch.add(vr_grad, vr.grad.data, alpha=1.0 / self.cr_stage)

                step_size = group['lr']
                if damped:
                    step_size = step_size * (1 - 1 / self.cr_stage)
                p.data = torch.add(p.data, vr_grad, alpha=-step_size)

    def step(self, params):
        """
        Apply one variance-reduced update with the plain learning rate.

        params: parameter groups of the reference network, with gradients
                computed on the same sample as this optimizer's parameters
        """
        self._variance_reduced_step(params, damped=False)

    def step2(self, params):
        '''
        This gradient estimate is the one in the paper, it considers the gradient at the latest adversarial example.

        Same update as step, but the learning rate is scaled by (1 - 1/cr_stage).
        '''
        self._variance_reduced_step(params, damped=True)

    # debugging function
    def print_vr(self):
        """Print every stored accumulated gradient that is not None."""
        for u_group in self.vr or ():
            for u in u_group['params']:
                if u.grad is None:
                    continue
                print(u.grad)


class frozen(optim.Optimizer):
    """
    Optimizer whose sole purpose is the sparse full gradient computation.

    It defines no ``step`` of its own; it only exposes its parameter groups
    (so their gradients can be read) and lets the parameter values be
    overwritten from another set of parameter groups.
    """

    def __init__(self, params):
        """
        params: iterable of parameters (or parameter-group dicts) to hold
        """
        defaults = dict()
        super(frozen, self).__init__(params, defaults)

    def get_param_groups(self):
        """Return the optimizer's parameter groups (the live list, not a copy)."""
        return self.param_groups

    def set_param_groups(self, new_params):
        """
        Copy the parameter values of new_params into this optimizer's parameters.

        new_params: parameter groups with the same structure as this
                    optimizer's groups; values are copied in place, so the
                    tensors held here keep their identity
        """
        for group, new_group in zip(self.param_groups, new_params):
            for p, q in zip(group['params'], new_group['params']):
                # copy_ instead of slice assignment: slicing a 0-dim tensor
                # raises IndexError, copy_ works for every shape.
                p.data.copy_(q.data)


def FUM(network: nn.Module, frozen_network: nn.Module, optimizer: Optimizer, frozen_optimizer: Optimizer,
        loss_func: nn.Module, train_data: Dataset, sorted_indices: Sequence[int], i: int, T_i: int,
        replacement: bool = True, batch_size: int = 1,
        one_hot_encode: bool = True, get_errors: bool = False,
        at: bool = False, epsilon: float = 8 / 255, attack_iter: int = 10,
        scheduler_name: str = 'constant') -> Tuple[float, float]:
    """
    Run the variance-reduced updates of one stage.

    network: The network used
    frozen_network: The network at the reference point
    optimizer: The optimizer used
    frozen_optimizer: An optimizer used for the reference point
    loss_func: The function used as loss
    train_data: The training dataset used
    sorted_indices: The indices sorted according to their label
    i: The prefix index
    T_i: The iterations to do at this stage
    replacement: whether to use replacement when sampling
    batch_size: the number of samples of each mini batch
    one_hot_encode: whether you should one hot encode the labels
    get_errors: if True, print the stage error after every iteration
    at: if True, `network` is trained on the output of gen_adversarial instead
        of the sampled images (the reference network always sees the sampled images)
    epsilon: passed to gen_adversarial
    attack_iter: passed to gen_adversarial
    scheduler_name: 'constant' (no scheduler) or 'linear' (LinearLR from
        factor 1.0 to 0.1 over the batches of this stage)

    Returns (losses / total, correct / total): the sum of the per-batch loss
    values divided by the number of samples seen, and the fraction of samples
    whose arg-max prediction equals the label.
    NOTE(review): if loss_func averages over the batch, the first value is
    smaller than the mean loss by a factor of batch_size -- confirm intent.

    Raises ValueError for an unknown scheduler_name and ZeroDivisionError if
    the stage loader yields no batch.
    """
    if scheduler_name not in ('constant', 'linear'):
        # Fail before any update is applied instead of hitting an unbound
        # `scheduler` after the first optimizer step.
        raise ValueError(
            f"Unsupported scheduler_name {scheduler_name!r}; expected 'constant' or 'linear'")

    # Take enough samples for all iterations for this stage
    train_data_iter = load_stage(train_data, sorted_indices, i, T_i,
                                 replacement=replacement, batch_size=batch_size)
    # Remember every group's learning rate so the scheduler's decay does not
    # leak into the next stage.
    initial_lrs = [g['lr'] for g in optimizer.param_groups]

    scheduler = None
    if scheduler_name == 'linear':
        scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.1,
                                                      total_iters=len(train_data_iter))
    # Train the model
    network.train()
    losses = 0
    correct = 0
    total = 0
    j = 0

    try:
        for images, labels in tqdm(train_data_iter):
            total += images.shape[0]
            # calculate the loss for this sample batch
            if one_hot_encode:
                # the number of classes is fixed to 10 here
                labels = one_hot(labels, 10).float()
            images_adv = images.clone()
            if at:
                images_adv = gen_adversarial(network, images_adv, labels, loss_func, epsilon=epsilon,
                                             attack_iter=attack_iter, input_noise_rate=0)
            outputs = network(images_adv)
            if one_hot_encode:
                # back to class indices for the accuracy count and the loss
                labels = torch.argmax(labels, dim=1)

            correct += (torch.argmax(outputs, dim=1) == labels).sum().item()

            loss = loss_func(outputs, labels)

            # clear gradient for this training step
            optimizer.zero_grad()

            # backpropagation, compute gradients
            loss.backward()
            losses += loss.item()

            # clear gradients for reference point
            frozen_optimizer.zero_grad()

            # compute loss for reference point
            frozen_outputs = frozen_network(images)
            frozen_loss = loss_func(frozen_outputs, labels)

            # backpropagation, compute gradients in the frozen network
            frozen_loss.backward()

            # apply gradients
            optimizer.step(frozen_optimizer.get_param_groups())
            if scheduler is not None:
                scheduler.step()
            # Report the stage error for this iteration
            if get_errors:
                j += 1
                with torch.no_grad():
                    network.eval()
                    error = stage_error(network, train_data, sorted_indices, i, loss_func,
                                        one_hot_encode=one_hot_encode)
                    print(f"Iteration: {j}/{T_i} error: {error}")
                network.train()
    finally:
        # Restore each group's own learning rate, also when the loop raised.
        for g, group_lr in zip(optimizer.param_groups, initial_lrs):
            g['lr'] = group_lr
    return losses/total, correct/total


def full_gradient(frozen_model: nn.Module, vr_optimizer: Optimizer, frozen_optimizer: Optimizer, loss_func: nn.Module,
                  train_data: Dataset, sorted_indices: Sequence[int], i: int,
                  batch_size: int = 1, one_hot_encode: bool = True):
    """
    Accumulate the gradient of the reference network over every batch of the
    stage and hand the result to the variance-reduced optimizer.

    frozen_model: The network at the reference point
    vr_optimizer: The optimizer used; receives the gradients via set_vr and
                  the number of batches via set_cr_stage
    frozen_optimizer: An optimizer used for the reference point
    loss_func: The function used as loss
    train_data: The training dataset used
    sorted_indices: The indices sorted according to their label
    i: The prefix index
    batch_size: the batch_size of the data_loader
    one_hot_encode: whether you should one hot encode the labels
    """
    stage_loader = full_stage(train_data, sorted_indices, i, batch_size=batch_size)

    # Gradients are cleared once; each backward() below adds to them.
    frozen_optimizer.zero_grad()
    for batch_images, batch_labels in tqdm(stage_loader):
        targets = one_hot(batch_labels, 10).float() if one_hot_encode else batch_labels
        batch_loss = loss_func(frozen_model(batch_images), targets)
        batch_loss.backward()

    # Publish the accumulated gradients together with the batch count.
    vr_optimizer.set_vr(frozen_optimizer.get_param_groups())
    vr_optimizer.set_cr_stage(len(stage_loader))


def update_full_gradient(frozen_network: nn.Module, optimizer: Optimizer,
                         frozen_optimizer: Optimizer, loss_func: nn.Module, sample: Sample,
                         one_hot_encode: bool = True):
    """
    Add the gradient of one extra batch, evaluated at the reference point, to
    the gradient accumulated in the optimizer.

    frozen_network: The network at the reference point
    optimizer: The optimizer used; receives the batch gradient via update_vr
               and has its count raised by one via update_cr_stage
    frozen_optimizer: An optimizer used for the reference point
    loss_func: The function used as loss
    sample: An (images, labels) pair holding the batch to include
    one_hot_encode: whether you should one hot encode the labels
    """

    # get image and label to include in the full gradient
    images, labels = sample
    if one_hot_encode:
        # the number of classes is fixed to 10 here
        labels = one_hot(labels, 10).float()

    # clear gradients for reference point
    frozen_optimizer.zero_grad()

    # compute loss for reference point
    frozen_outputs = frozen_network(images)
    frozen_loss = loss_func(frozen_outputs, labels)

    # backpropagation, compute gradients in the frozen network
    frozen_loss.backward()

    sample_grad = frozen_optimizer.get_param_groups()
    optimizer.update_vr(sample_grad)
    # the whole batch counts as one unit, whatever its size
    optimizer.update_cr_stage(1)


def SIOPT(network: nn.Module, frozen_network: nn.Module, optimizer: Optimizer, frozen_optimizer: Optimizer,
          loss_func: nn.Module, train_data: Dataset, test_data: Dataset,
          sorted_indices: Sequence[int], sorted_indices_test: Sequence[int], T: int, alpha: int, adversarial_iter: bool,
          stages: int = -1, replacement: bool = True, batch_size: int = 1, one_hot_encode: bool = False,
          epsilon: float = 8 / 255, attack_iter: int = 10, output_filename: str = 'results',
          scheduler_name: str = 'constant', forgetting: bool = False) -> Tuple[list, ...]:
    """
    Train stage by stage on a growing (or, with forgetting, shrinking) prefix
    of the sorted training indices and record metrics after every stage.

    network: The network used
    frozen_network: The copy of the network
    optimizer: The optimizer used
    frozen_optimizer: The optimizer used for the full gradient
    loss_func: The function used as loss
    train_data: The training dataset used
    test_data: The test dataset used
    sorted_indices: The indices sorted according to their label
    sorted_indices_test: The indices sorted according to their label of the test set
    T: The iterations to do at every stage (passed to FUM as T_i)
    alpha: How often to do a full gradient computation: one is triggered when
           |i - prev| >= alpha * i, with prev the prefix index of the last one
    adversarial_iter: if True, FUM trains on adversarial examples and the test
                      robustness is evaluated after every stage
    stages: number of stages; a negative value means len(sorted_indices) // batch_size
    replacement: whether to use replacement when sampling
    batch_size: the number of samples of each mini batch
    one_hot_encode: whether you should one hot encode the labels
    epsilon: passed to FUM and to the robustness evaluation
    attack_iter: passed to FUM and to the robustness evaluation
    output_filename: path of the tab-separated results file, rewritten after every stage
    scheduler_name: passed to FUM
    forgetting: if True the prefix index decreases with the stage number and
                the two "forget" accuracies are recorded as well

    Returns the per-stage lists
    (Res, TrainAcc, TestAcc, TrainRob, TestRob, FOs), or, with forgetting,
    (Res, TrainAcc, TestAcc, TestForget, TrainForget, TrainRob, TestRob, FOs).
    TrainRob only ever holds None; TestRob holds None unless adversarial_iter is set.
    """
    Res = []
    FOs = []
    TestAcc = []
    TestForget = []
    TrainForget = []
    TrainAcc = []
    TestRob = []
    TrainRob = []

    n_train = len(sorted_indices)
    print(n_train)
    if stages < 0:
        stages = n_train // batch_size

    # The reversed orders are only needed by the forgetting metrics and do
    # not change between stages, so build them once.
    if forgetting:
        reversed_test_indices = list(reversed(sorted_indices_test))
        reversed_train_indices = list(reversed(sorted_indices))

    crFOs = 0
    prev = 0
    for j in range(0, stages):
        if forgetting:
            i = max(min((n_train // batch_size - j) * batch_size - 1, n_train - 1), 0)
        else:
            i = min((j + 1) * batch_size - 1, n_train - 1)
        current = train_data.targets[sorted_indices[i]]
        if isinstance(current, torch.Tensor):
            current = current.item()

        # NOTE: the optimizer's count is maintained only by full_gradient
        # (set_cr_stage) and update_full_gradient (update_cr_stage(1)).
        # Incrementing it here as well made it grow by two for every single
        # batch gradient added in the incremental branch.

        # calculate the total FO complexity of this stage
        crFOs += 2 * T * batch_size
        refresh = abs(i - prev) >= alpha * i
        if refresh:
            frozen_optimizer.set_param_groups(optimizer.get_param_groups())
            crFOs += i
            full_gradient(frozen_network, optimizer, frozen_optimizer, loss_func, train_data, sorted_indices, i,
                          batch_size=batch_size, one_hot_encode=one_hot_encode)
        else:
            crFOs += 1
            # Clamp the start: a negative start would be read as an offset
            # from the end of the sequence and select the wrong samples.
            start = max(i - batch_size + 1, 0)
            x, y = [], []
            for k in sorted_indices[start: i + 1]:
                xx, yy = train_data[k]
                x.append(xx.unsqueeze(0))
                # labels may be plain Python ints rather than tensors
                y.append(torch.as_tensor(yy).unsqueeze(0))
            x, y = torch.cat(x, dim=0), torch.cat(y, dim=0)
            update_full_gradient(frozen_network, optimizer, frozen_optimizer, loss_func, (x, y),
                                 one_hot_encode=one_hot_encode)
        # attack_iter is forwarded so training uses the same attack strength
        # as the robustness evaluation below.
        loss_train, accuracy_train = FUM(network, frozen_network, optimizer, frozen_optimizer, loss_func,
                                         train_data, sorted_indices, i, T,
                                         replacement=replacement, batch_size=batch_size,
                                         one_hot_encode=one_hot_encode,
                                         at=adversarial_iter, epsilon=epsilon, attack_iter=attack_iter,
                                         scheduler_name=scheduler_name)
        if refresh:
            # Move the reference point to the freshly trained weights and
            # recompute the full gradient there.
            prev = i
            frozen_optimizer.set_param_groups(optimizer.get_param_groups())
            crFOs += i
            full_gradient(frozen_network, optimizer, frozen_optimizer, loss_func, train_data, sorted_indices, i,
                          batch_size=batch_size, one_hot_encode=one_hot_encode)
        FOs.append(crFOs)
        # Compute resulting loss for this stage accuracy on the train and test set
        label_limit = max(current - 1, 0) if forgetting else current
        trainlim = find_max_label(train_data, sorted_indices, 0, label_limit)
        testlim = find_max_label(test_data, sorted_indices_test, 0, label_limit)
        with torch.no_grad():
            network.eval()
            accuracy_test = stage_accuracy(network, test_data, sorted_indices_test, testlim, loss_func,
                                           one_hot_encode=one_hot_encode)
            if forgetting:
                test_forget = stage_accuracy(network, test_data, reversed_test_indices,
                                             len(sorted_indices_test)-2-testlim, loss_func,
                                             one_hot_encode=one_hot_encode)
                train_forget = stage_accuracy(network, train_data, reversed_train_indices,
                                              len(sorted_indices)-2-trainlim, loss_func,
                                              one_hot_encode=one_hot_encode)
                TestForget.append(test_forget)
                TrainForget.append(train_forget)
        robustness_test = None
        # Train robustness is not evaluated; the entry is kept so the
        # columns of the results file stay the same.
        robustness_train = None
        if adversarial_iter:
            # evaluated outside no_grad
            robustness_test = stage_accuracy(network, test_data, sorted_indices_test, testlim, loss_func,
                                             one_hot_encode=one_hot_encode, adversarial=True,
                                             epsilon=epsilon, attack_iter=attack_iter)
        Res.append(loss_train)
        TrainAcc.append(accuracy_train)
        TestAcc.append(accuracy_test)
        TrainRob.append(robustness_train)
        TestRob.append(robustness_test)
        print(
            f"Stage: {j}/{stages}, Loss: {loss_train}, Training Accuracy: {accuracy_train}, "
            f"Test Accuracy: {accuracy_test}, Train Robustness: {robustness_train}, Test Robustness: {robustness_test}"
        )
        # The results file is rewritten after every stage so partial results
        # are on disk while the run is still going.
        if forgetting:
            results_multi_stages = pd.DataFrame(
                zip(Res, TrainAcc, TestAcc, TestForget, TrainForget, TrainRob, TestRob, FOs),
                columns=["Loss", "TrainAcc", "TestAcc", "TestForget", "TrainForget", "TrainRob", "TestRob", "FO"])
        else:
            results_multi_stages = pd.DataFrame(
                zip(Res, TrainAcc, TestAcc, TrainRob, TestRob, FOs),
                columns=["Loss", "TrainAcc", "TestAcc", "TrainRob", "TestRob", "FO"])
        results_multi_stages.to_csv(output_filename, sep="\t", encoding="utf-8")
    if forgetting:
        return Res, TrainAcc, TestAcc, TestForget, TrainForget, TrainRob, TestRob, FOs
    else:
        return Res, TrainAcc, TestAcc, TrainRob, TestRob, FOs
