from copy import deepcopy
from glob import glob
from os import makedirs

import attr
import pandas as pd
from attr.validators import and_, ge, gt, in_, instance_of, le
import torch
from torch import nn
from torch.cuda import device_count
from torch.optim import SGD
from torch.utils.data import Subset
from torchvision.datasets import CIFAR10, CIFAR100, FashionMNIST, MNIST
from torchvision.transforms import Compose, Normalize, Pad, RandomCrop, RandomHorizontalFlip, ToTensor

from adversarial import Produce_SGD_Standard, SIOPT_Standard, SIOPT_Continual
from models import ModelHpars
from preprocess import ExpandChannels, LoadToDevice, pre_process_dataset, TargetToTensor
from sgd import Produce_SGD, SGD_Stage
from siopt import frozen, full_gradient, FUM, SIOPT, update_full_gradient, VR_G
from utils import AbstractConf, reproduce


class Experiment:
    """Set up datasets, output folders and hyper-parameters for one run, then train.

    ``__init__`` builds the torchvision datasets (with their transforms) and
    creates the output folder ``results/<dataset>/runs/<run_name>``.
    ``run`` trains a network with either ``sgd`` or ``siopt`` and writes the
    collected metrics as tab-separated text plus the final ``network.tar``
    state dict into that folder.
    """

    # Column layouts of the per-stage result tables written by ``run``.
    _STAGE_COLUMNS = ["Loss", "TrainAcc", "TestAcc", "TrainRob", "TestRob", "FO"]
    _FORGETTING_STAGE_COLUMNS = ["Loss", "TrainAcc", "TestAcc", "TestForget", "TrainForget",
                                 "TrainRob", "TestRob", "FO"]

    def __init__(self, seed: int, device: str,
                 dataset_name: str, model_hpars: dict, optimizer_name: str, continual: bool,
                 adversarial: bool, adversarial_iter: bool,
                 dataset_fraction: float, num_stages: int, num_iterations: int,
                 batch_size: int, lr: float, siopt_alpha: float,
                 adversarial_epsilon: float, adversarial_attack_iter: int, momentum: float,
                 scheduler: str, weight_decay: float, forgetting: bool, checkpoint_path: str):
        """Store the hyper-parameters, build the datasets and create the output folder.

        Args:
            seed: Seed passed to ``reproduce`` at the start of ``run``.
            device: Device string the data, network and loss are moved to.
            dataset_name: One of ``mnist``, ``fashion_mnist``, ``cifar10``, ``cifar100``.
            model_hpars: Model configuration dict; must contain ``"name"``. A copy is
                extended with ``input_size``, ``input_channels`` and ``num_classes``.
            optimizer_name: ``"sgd"`` or ``"siopt"``.
            continual: Selects the staged code path in ``run`` (see ``run``).
            adversarial: Forwarded to the training routines of the non-``continual`` path.
            adversarial_iter: Forwarded to the multi-stage training routines.
            dataset_fraction: Fraction of the training set used to size the per-class prefix.
            num_stages: Number of stages; ``1`` selects the single-stage path when
                ``continual`` is true.
            num_iterations: Iteration count forwarded to the training routines.
            batch_size: Mini-batch size forwarded to the training routines.
            lr: Learning rate of the optimizer.
            siopt_alpha: Forwarded to the SIOPT routines.
            adversarial_epsilon: Forwarded as ``epsilon``.
            adversarial_attack_iter: Forwarded as ``attack_iter``.
            momentum: SGD momentum (non-``continual`` path only).
            scheduler: Scheduler name; the non-``continual`` path accepts
                ``"constant"`` and ``"GradAlign"``.
            weight_decay: SGD weight decay (non-``continual`` path only).
            forgetting: When true, ``run`` starts from ``checkpoint_path`` and the
                multi-stage results carry two extra forgetting columns.
            checkpoint_path: State-dict file loaded when ``forgetting`` is true.

        Raises:
            ValueError: If ``dataset_name`` is not one of the supported datasets.
        """
        # Local dispatch table: one entry per supported dataset (the original
        # if/elif chain contained an unreachable duplicate "cifar10" branch).
        dataset_classes = {"mnist": MNIST, "fashion_mnist": FashionMNIST,
                           "cifar10": CIFAR10, "cifar100": CIFAR100}
        if dataset_name not in dataset_classes:
            raise ValueError(
                f"Unknown dataset_name {dataset_name!r}; expected one of {sorted(dataset_classes)}"
            )
        dataset_cls = dataset_classes[dataset_name]

        is_cifar = dataset_name in ("cifar10", "cifar100")
        image_size = 32 if is_cifar else 28
        model_name = model_hpars["name"]

        self.seed = seed
        self.device = device
        self.dataset_name = dataset_name
        # Work on a shallow copy so the caller's dict is not modified.
        self.model_hpars = dict(model_hpars)
        self.model_hpars["input_size"] = image_size
        self.model_hpars["input_channels"] = 3 if is_cifar else 1
        self.model_hpars["num_classes"] = 100 if dataset_name == "cifar100" else 10

        print(self.model_hpars["num_classes"])
        self.optimizer_name = optimizer_name
        self.continual = continual
        self.adversarial = adversarial
        self.adversarial_iter = adversarial_iter
        self.dataset_fraction = dataset_fraction
        self.num_stages = num_stages
        self.num_iterations = num_iterations
        self.batch_size = batch_size
        self.lr = lr
        self.momentum = momentum
        self.scheduler = scheduler
        self.weight_decay = weight_decay
        self.forgetting = forgetting
        self.checkpoint_path = checkpoint_path

        self.siopt_alpha = siopt_alpha
        self.adversarial_epsilon = adversarial_epsilon
        self.adversarial_attack_iter = adversarial_attack_iter

        # Output locations; runs resumed from a checkpoint get a "forgetting_" prefix.
        run_name = (
            ("forgetting_" if self.forgetting else "")
            + f"{self.optimizer_name}_{model_name}_continual{self.continual}_"
            + f"{self.dataset_name}fraction{self.dataset_fraction}_bs{self.batch_size}"
            + f"_stages{self.num_stages}_iters{self.num_iterations}_lr{self.lr}"
            + f"_alpha{self.siopt_alpha}_scheduler{self.scheduler}_seed{self.seed}"
        )
        self.output_folder = f"results/{self.dataset_name}/runs/" + run_name
        self.output_filename = self.output_folder + "/results_multi_stages.txt"

        # Train-only augmentation followed by the steps shared with the test set.
        # No Normalize step is applied.
        augmentations = [RandomCrop(image_size, padding=4), RandomHorizontalFlip(p=0.5)]
        shared_steps = [ToTensor(), LoadToDevice(self.device)]
        # Model-specific resizing. Both CIFAR datasets are 32x32 RGB (see
        # input_size / input_channels above), so they are treated alike; the
        # 28x28 single-channel datasets get extra padding and channel expansion.
        if model_name == "alexnet":
            if is_cifar:
                shared_steps.append(Pad(96))  # 32 + 2 * 96 = 224
            else:
                shared_steps.append(Pad(98))  # 28 + 2 * 98 = 224
                shared_steps.append(ExpandChannels(3))
        elif model_name == "vgg":
            if not is_cifar:
                shared_steps.append(Pad(2))  # 28 + 2 * 2 = 32
                shared_steps.append(ExpandChannels(3))
        elif model_name == "resnet":
            if not is_cifar:
                shared_steps.append(ExpandChannels(3))

        input_transform_test = Compose(shared_steps)
        input_transform = Compose(augmentations + shared_steps)
        target_transform = Compose([TargetToTensor(), LoadToDevice(self.device)])

        self.train_data = dataset_cls("data/", train=True, transform=input_transform,
                                      target_transform=target_transform, download=True)
        self.test_data = dataset_cls("data/", train=False, transform=input_transform_test,
                                     target_transform=target_transform)

        # Logging
        makedirs(f"results/{self.dataset_name}/runs", exist_ok=True)
        # Counted before this run's own folder is created.
        self.run_id = len(glob(f"results/{self.dataset_name}/runs/*")) + 1
        makedirs(self.output_folder, exist_ok=True)

    def run(self):
        """Train the network and write metrics and the final weights to the output folder.

        With ``continual`` true, ``num_stages == 1`` runs a single stage and writes
        ``losses_single_stage.txt``; otherwise the multi-stage routines
        (``Produce_SGD`` / ``SIOPT``) run and the per-stage table is written to
        ``output_filename``. With ``continual`` false, the datasets are first
        truncated to the selected prefixes and ``Produce_SGD_Standard`` /
        ``SIOPT_Continual`` are run with a learning-rate scheduler.

        Raises:
            ValueError: If ``optimizer_name`` is not ``"sgd"``/``"siopt"``, or (in
                the non-``continual`` path) ``scheduler`` is not
                ``"constant"``/``"GradAlign"``.
        """
        reproduce(self.seed)

        # Data Preprocessing
        num_samples_per_class = int(
            len(self.train_data) * self.dataset_fraction / self.model_hpars["num_classes"]
        )

        if self.dataset_fraction == 1:
            prefix_train, _ = pre_process_dataset(self.train_data, num_samples=-1)
            prefix_test, _ = pre_process_dataset(self.test_data, num_samples=-1)
        else:
            # Both prefixes come from the training set, so it is also used for testing.
            prefix_train, prefix_test = pre_process_dataset(self.train_data,
                                                            num_samples=num_samples_per_class)
            self.test_data = self.train_data

        print(len(prefix_train), len(prefix_test))

        # Training and Plots
        network = self._build_network()
        loss = nn.CrossEntropyLoss().to(self.device)

        if self.continual:
            if self.num_stages == 1:
                self._run_single_stage(network, loss, prefix_train)
            else:
                self._run_multi_stage(network, loss, prefix_train, prefix_test)
        else:
            self._run_standard(network, loss, prefix_train, prefix_test)

        # Save final model
        torch.save(network.state_dict(), self.output_folder + "/network.tar")

    def _build_network(self):
        """Create the model, optionally restore ``checkpoint_path``, and move it to the device."""
        network = ModelHpars.from_dict(self.model_hpars).make()
        if self.forgetting:
            # map_location lets a checkpoint saved on another device load on this one.
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            state_dict = torch.load(self.checkpoint_path, map_location=self.device)
            network.load_state_dict(state_dict)
        return network.to(self.device)

    def _make_siopt_state(self, network):
        """Return the frozen network copy and the two optimizers used by SIOPT."""
        frozen_network = deepcopy(network).to(self.device)
        vr_optim = VR_G(network.parameters(), lr=self.lr)
        frozen_optim = frozen(frozen_network.parameters())
        return frozen_network, vr_optim, frozen_optim

    def _unknown_optimizer(self):
        """Build the error raised for an unsupported ``optimizer_name``."""
        return ValueError(
            f"Unknown optimizer_name {self.optimizer_name!r}; expected 'sgd' or 'siopt'"
        )

    def _make_scheduler(self, optimizer, lr_steps, **cyclic_kwargs):
        """Create the learning-rate scheduler named by ``self.scheduler``.

        ``cyclic_kwargs`` are forwarded to ``CyclicLR`` only.
        """
        if self.scheduler == 'constant':
            # NOTE(review): ConstantLR's default factor is 1/3, so the learning rate
            # is scaled by 1/3 for the first ``lr_steps`` steps — confirm this is intended.
            return torch.optim.lr_scheduler.ConstantLR(optimizer, total_iters=lr_steps)
        if self.scheduler == 'GradAlign':
            return torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0, max_lr=self.lr,
                                                     step_size_up=lr_steps / 2,
                                                     step_size_down=lr_steps / 2,
                                                     **cyclic_kwargs)
        raise ValueError(
            f"Unknown scheduler {self.scheduler!r}; expected 'constant' or 'GradAlign'"
        )

    def _write_stage_results(self, outputs, columns, optimizer_label, adversarial):
        """Write the per-stage metric sequences in ``outputs`` to ``output_filename``."""
        outputs = tuple(outputs)
        if len(outputs) != len(columns):
            raise ValueError(
                f"Expected {len(columns)} result sequences ({columns}), got {len(outputs)}"
            )
        table = pd.DataFrame(zip(*outputs), columns=columns)
        table = table.assign(Dataset=self.dataset_name,
                             Model=self.model_hpars["name"],
                             Optimizer=optimizer_label,
                             Continual=self.continual,
                             Adversarial=adversarial)
        table.index.name = "Stage"
        table.to_csv(self.output_filename, sep="\t", encoding="utf-8")

    def _run_single_stage(self, network, loss, prefix_train):
        """Run one stage and write its losses to ``losses_single_stage.txt``."""
        if self.optimizer_name == "sgd":
            # SGD
            sgd_optim = SGD(network.parameters(), lr=self.lr)
            losses = SGD_Stage(network, sgd_optim, loss, self.train_data, prefix_train,
                               1, self.num_iterations,
                               batch_size=self.batch_size, get_errors=True)
        elif self.optimizer_name == "siopt":
            # SIOPT
            frozen_network, vr_optim, frozen_optim = self._make_siopt_state(network)
            full_gradient(frozen_network, vr_optim, frozen_optim, loss,
                          self.train_data, prefix_train, 0)
            update_full_gradient(frozen_network, vr_optim, frozen_optim, loss,
                                 self.train_data, prefix_train, 1)
            losses = FUM(network, frozen_network, vr_optim, frozen_optim, loss,
                         self.train_data, prefix_train, 1, self.num_iterations,
                         batch_size=self.batch_size, get_errors=True)
        else:
            raise self._unknown_optimizer()

        table = pd.Series(losses, name="Loss").to_frame().assign(Dataset=self.dataset_name,
                                                                 Model=self.model_hpars["name"],
                                                                 Optimizer=self.optimizer_name,
                                                                 Continual=self.continual,
                                                                 Adversarial=False)
        table.index.name = "Iteration"
        table.to_csv(self.output_folder + "/losses_single_stage.txt",
                     sep="\t", encoding="utf-8")

    def _run_multi_stage(self, network, loss, prefix_train, prefix_test):
        """Run all stages with ``Produce_SGD`` or ``SIOPT`` and write the stage table."""
        shared_kwargs = dict(batch_size=self.batch_size,
                             epsilon=self.adversarial_epsilon,
                             attack_iter=self.adversarial_attack_iter,
                             output_filename=self.output_filename,
                             scheduler_name=self.scheduler,
                             forgetting=self.forgetting)

        if self.optimizer_name == "sgd":
            # SGD
            sgd_optim = SGD(network.parameters(), lr=self.lr)
            outputs = Produce_SGD(
                network, sgd_optim, loss, self.train_data, self.test_data, prefix_train, prefix_test,
                self.num_iterations, self.adversarial_iter, self.num_stages, **shared_kwargs
            )
        elif self.optimizer_name == "siopt":
            # SIOPT
            frozen_network, vr_optim, frozen_optim = self._make_siopt_state(network)
            outputs = SIOPT(
                network, frozen_network, vr_optim, frozen_optim, loss,
                self.train_data, self.test_data, prefix_train, prefix_test,
                self.num_iterations, self.siopt_alpha,
                self.adversarial_iter, self.num_stages, **shared_kwargs
            )
        else:
            raise self._unknown_optimizer()

        # With forgetting enabled the routines return two additional sequences.
        columns = self._FORGETTING_STAGE_COLUMNS if self.forgetting else self._STAGE_COLUMNS
        self._write_stage_results(outputs, columns,
                                  optimizer_label=self.optimizer_name, adversarial=False)

    def _run_standard(self, network, loss, prefix_train, prefix_test):
        """Run the non-``continual`` path on the truncated datasets and write the stage table."""
        # Truncate dataset
        self.train_data = Subset(self.train_data, prefix_train)
        self.test_data = Subset(self.test_data, prefix_test)

        # NOTE(review): the length used here is that of the already-truncated
        # subset, yet it is scaled by dataset_fraction again — confirm intended.
        lr_steps = self.num_stages * 10 * int(
            len(self.train_data) * self.dataset_fraction / 10
        ) // self.batch_size  # 196 is the number of steps in a cifar10 epoch

        if self.optimizer_name == "sgd":
            # SGD
            sgd_optim = SGD(network.parameters(), lr=self.lr, momentum=self.momentum,
                            weight_decay=self.weight_decay)
            scheduler = self._make_scheduler(sgd_optim, lr_steps)
            outputs = Produce_SGD_Standard(
                network, sgd_optim, scheduler, loss, self.train_data, self.test_data,
                self.num_iterations, self.adversarial, self.adversarial_iter, self.num_stages,
                batch_size=self.batch_size, epsilon=self.adversarial_epsilon,
                attack_iter=self.adversarial_attack_iter
            )
        elif self.optimizer_name == "siopt":
            # SIOPT
            frozen_network, vr_optim, frozen_optim = self._make_siopt_state(network)
            scheduler = self._make_scheduler(vr_optim, lr_steps, cycle_momentum=False)
            outputs = SIOPT_Continual(
                network, frozen_network, vr_optim, frozen_optim, scheduler, loss,
                self.train_data, self.test_data, self.num_iterations, self.siopt_alpha,
                self.output_filename,
                self.adversarial, self.adversarial_iter,
                self.num_stages, batch_size=self.batch_size, epsilon=self.adversarial_epsilon,
                attack_iter=self.adversarial_attack_iter
            )
        else:
            raise self._unknown_optimizer()

        self._write_stage_results(
            outputs, self._STAGE_COLUMNS,
            optimizer_label=f"{self.optimizer_name}{'_adv_iter' if self.adversarial_iter else ''}",
            adversarial=self.adversarial
        )


@attr.s
class ExperimentHpars(AbstractConf):
    """Validated hyper-parameters for an :class:`Experiment`.

    Each attribute corresponds to the ``Experiment.__init__`` parameter of the
    same name. Validators run when an instance is created and raise on invalid
    values.
    """

    # Maps the configuration ``name`` below to the class it configures.
    # Presumably consumed by AbstractConf to instantiate the experiment — TODO confirm.
    OPTIONS = {"experiment": Experiment}

    # Reproducibility and placement.
    seed = attr.ib(default=1, validator=instance_of(int))
    # Allowed devices are "cpu" plus every CUDA device visible when this class is defined.
    device = attr.ib(default="cpu", validator=in_(["cpu", ] + [f"cuda:{i}" for i in range(device_count())]))

    # Data and model.
    dataset_name = attr.ib(default="mnist", validator=in_(["mnist", "fashion_mnist", "cifar10", "cifar100"]))
    # attrs validators must raise to reject a value; their return value is ignored,
    # so a boolean-returning lambda would never reject anything. instance_of raises
    # TypeError for anything that is not a ModelHpars.
    model_hpars = attr.ib(factory=ModelHpars, validator=instance_of(ModelHpars))

    # Training mode switches.
    optimizer_name = attr.ib(default="sgd", validator=in_(["sgd", "siopt"]))
    continual = attr.ib(default=False, validator=instance_of(bool))
    adversarial = attr.ib(default=True, validator=instance_of(bool))
    adversarial_iter = attr.ib(default=False, validator=instance_of(bool))

    # Problem size: fraction in (0, 1], and at least one stage / iteration / sample per batch.
    dataset_fraction = attr.ib(default=1.0, validator=and_(instance_of(float), gt(0), le(1)))
    num_stages = attr.ib(default=1, validator=and_(instance_of(int), ge(1)))
    num_iterations = attr.ib(default=200, validator=and_(instance_of(int), ge(1)))
    batch_size = attr.ib(default=1, validator=and_(instance_of(int), ge(1)))

    # Optimizer, attack and scheduler settings.
    lr = attr.ib(default=0.01, validator=instance_of(float))
    siopt_alpha = attr.ib(default=0.01, validator=instance_of(float))
    adversarial_epsilon = attr.ib(default=8 / 255, validator=instance_of(float))
    adversarial_attack_iter = attr.ib(default=10, validator=instance_of(int))
    momentum = attr.ib(default=0.9, validator=instance_of(float))
    scheduler = attr.ib(default='constant', validator=instance_of(str))
    weight_decay = attr.ib(default=5e-4, validator=instance_of(float))

    # Checkpoint restore: when ``forgetting`` is true the experiment loads ``checkpoint_path``.
    forgetting = attr.ib(default=False, validator=instance_of(bool))
    checkpoint_path = attr.ib(default='', validator=instance_of(str))

    # Key of this configuration in OPTIONS (a plain class attribute, not an attrs field).
    name = "experiment"
