import torch
import numpy as np
import random
import ResNet18Model
import IncDataStream
from torch import nn
import RotateNet
from torch.utils.data import DataLoader


# Base class shared by all the continual-learning algorithms in this project.
class CLLearningAlgo:
    """Base class for continual-learning (CL) algorithms.

    Builds the model, SGD optimiser and exponential LR scheduler, constructs the
    task stream for one of the ``run_*_setting`` protocols, trains over the
    stream and then evaluates. Subclasses customise an algorithm by overriding
    ``calc_reg_loss_term`` / ``loss_fn`` and the no-op hooks
    (``before_batch_calc``, ``after_loss_calc``, ``after_optimiser_step``,
    ``at_end_of_task``, ``before_eval_on_task``) that the training and
    evaluation loops call.
    """

    # dataset name -> (window_len, nclasses)
    _DATASET_SETTINGS = {
        'CIFAR100': (5, 100),
        'MiniImageNet': (5, 100),
        'CIFAR10': (2, 10),
        'MNIST': (2, 10),
    }

    # Lower bound the learning rate is clamped to between tasks (see _step_lr_schedule).
    _MIN_LR = 0.00005

    # For each setting: dataset name -> name of the IncDataStream class to build.
    # Names rather than classes, so an IncDataStream attribute is only looked up
    # for the dataset actually requested.
    _SHIFT_WIN_STREAMS = {
        'CIFAR100': 'CIFAR100ShiftWinStream',
        'MiniImageNet': 'MiniImageNetShiftWinStream',
        'CIFAR10': 'CIFAR10ShiftWinStream',
    }
    _DISJOINT_STREAMS = {
        'CIFAR100': 'SplitCIFAR100',
        'MiniImageNet': 'SplitMiniImageNet',
        'CIFAR10': 'SplitCIFAR10',
        'MNIST': 'SplitMNIST',
    }
    _DUMMY_SHIFT_WIN_STREAMS = {
        'CIFAR100': 'CIFAR100DummyShiftWinStream',
        'MiniImageNet': 'MiniImageNetDummyShiftWinStream',
        'CIFAR10': 'CIFAR10DummyShiftWinStream',
    }
    _DUMMY_DISJOINT_STREAMS = {
        'CIFAR100': 'DummySplitCIFAR100',
        'MiniImageNet': 'DummySplitMiniImageNet',
        'CIFAR10': 'DummySplitCIFAR10',
    }

    def __init__(self, args, classification_head=True):
        """Seed the RNGs and build the model, optimiser and scheduler.

        Args:
            args: namespace providing ``seed`` (indexable: ``seed[0]`` seeds
                torch/numpy/random, ``seed[1]`` is passed to
                ``torch.cuda.manual_seed``), ``lr``, ``epochs``, ``batch_size``
                and ``dataset``.
            classification_head: forwarded to the model config.

        Raises:
            ValueError: if ``args.dataset`` is not a supported dataset.
        """
        torch.backends.cudnn.deterministic = True
        self.seed = args.seed
        torch.manual_seed(self.seed[0])
        np.random.seed(self.seed[0])
        random.seed(self.seed[0])
        torch.cuda.manual_seed(self.seed[1])

        self.dropout = 0.0  # 0.2
        self.gamma = 1.0  # 0.8
        self.learning_rate = args.lr  # 0.1
        self.momentum = 0.0  # 0.6 # 0.0
        self.epochs = args.epochs
        self.batch_size = args.batch_size  # 10
        self.reg_coef = 1  # algo specific
        self.model_width = 10  # 10
        self.classification_head = classification_head
        self.dataset = args.dataset

        try:
            self.window_len, self.nclasses = self._DATASET_SETTINGS[args.dataset]
        except KeyError:
            raise ValueError('The dataset given is not supported') from None

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using {self.device} device")

        config = {'dropout': self.dropout, 'classification_head': self.classification_head}
        self.model = ResNet18Model.ResNet18(config=config, nf=self.model_width, nclasses=self.nclasses)
        # Alternative backbones:
        # self.model = RotateNet.ResNet18(config=config, nf=self.model_width, nclasses=self.nclasses, cl_algo=self)
        # self.model = ResNet18Model.MNISTModel(config=config, nclasses=self.nclasses)
        self.model = self.model.to(self.device)

        self.optimiser = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=self.momentum)
        # Alternative optimiser:
        # self.optimiser = torch.optim.Adam(self.model.parameters())

        self.task_stream = None

        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimiser, gamma=self.gamma)

        # These fields store values relating to the current task while a setting runs
        self.task_id = 0
        self.batch_num = -1
        self.data_loader = None
        self.nullClasses = None

        # this is used to store training accuracies
        self.train_acc = []
        self.per_seen_task_acc = None

        # these are set by the run_*_setting methods and map task ids to classes
        self.calc_classes = None
        self.calc_null_classes = None
        self.calc_per_point_nullclasses = None

        # used to change behaviour of predict() between training and eval
        self.training = True

        # used to store the current training batch
        self.batch = None

    def eval(self):
        """Switch the algorithm (and its model) to evaluation mode."""
        self.training = False
        self.model.eval()

    def train(self):
        """Switch the algorithm (and its model) to training mode."""
        self.training = True
        self.model.train()

    # ---- hooks: no-ops here, overridden by concrete algorithms -------------

    def after_optimiser_step(self):
        return

    def at_end_of_task(self):
        return

    def before_eval_on_task(self, task_id):
        return

    def before_batch_calc(self):
        return

    def after_loss_calc(self):
        return

    def calc_reg_loss_term(self):
        """Regularisation term added to the loss; zero in the base class."""
        return torch.zeros(1, device=self.device)

    def loss_fn(self, X, Y):
        """Cross-entropy between logits ``X`` and targets ``Y``."""
        return nn.functional.cross_entropy(X, Y)

    def regularised_loss_fn(self, X, Y):
        """Task loss plus ``reg_coef`` times the regularisation term."""
        return self.loss_fn(X, Y) + self.reg_coef * self.calc_reg_loss_term()

    # ---- training ----------------------------------------------------------

    def _start_training_task(self, task_id, data_loader):
        """Record the task about to be trained on and print its banner."""
        self.task_id = task_id
        self.data_loader = data_loader
        print(" ")
        print("------------------ task " + str(task_id) + " -------------------------")

    def _step_lr_schedule(self, task_id):
        """Clamp the LR at _MIN_LR, otherwise decay it once every window_len tasks."""
        if self.optimiser.param_groups[0]['lr'] <= self._MIN_LR:
            self.optimiser.param_groups[0]['lr'] = self._MIN_LR
        elif task_id % self.window_len == 0:
            self.scheduler.step()

    def _train(self):
        """Train sequentially on each task, masking that task's null classes."""
        self.per_seen_task_acc = {}
        for task_id, data_loader in enumerate(self.task_stream):
            self._start_training_task(task_id, data_loader)
            self.nullClasses = self.calc_null_classes(task_id, self.task_stream.classes)
            train_task(self.device, self.regularised_loss_fn, self.optimiser,
                       self.scheduler, self.epochs, data_loader, self)
            self.at_end_of_task()
            self._step_lr_schedule(task_id)

    def _dummy_train(self):
        """Train sequentially on each task, with null classes chosen per sample."""
        for task_id, data_loader in enumerate(self.task_stream):
            self._start_training_task(task_id, data_loader)
            dummy_train_task(self.model, self.device, self.regularised_loss_fn, self.optimiser,
                             self.scheduler, self.epochs, data_loader, self.task_stream, self,
                             self.window_len)
            self.at_end_of_task()
            self._step_lr_schedule(task_id)

    # ---- evaluation --------------------------------------------------------

    def _iter_eval_tasks(self, set_null_classes=True):
        """Enter eval mode, print the eval banner and yield (task_id, data_loader).

        Before each task is yielded, the per-task state is set, the
        ``before_eval_on_task`` hook runs and, when ``set_null_classes`` is
        true, ``self.nullClasses`` is recomputed for that task.
        """
        self.eval()
        self.task_stream.eval()
        print(" ")
        print("---------------------------- eval --------------------------------")
        for task_id, data_loader in enumerate(self.task_stream):
            self.task_id = task_id
            self.data_loader = data_loader
            self.before_eval_on_task(task_id)
            if set_null_classes:
                self.nullClasses = self.calc_null_classes(task_id, self.task_stream.classes)
            yield task_id, data_loader

    # implements the evaluation procedure of task/class incremental settings, i.e. evaluate on all the tasks'
    # val sets with the network learnt at train time
    def classic_test(self):
        acc_data = []
        for task_id, data_loader in self._iter_eval_tasks():
            acc, _ = test_task(self, self.device, data_loader)
            acc_data.append(acc)
            print("task " + str(task_id) + " acc: " + str(acc))
        print("seed: " + str(self.seed))
        print("overall accuracy: " + str(sum(acc_data) / len(acc_data)))

    # our new evaluation method where we retrain on the tasks and measure the speed at which methods become accurate
    def shifting_window_test(self):
        # NOTE(author): this should evaluate on a validation set, not the test set.
        acc_data = []
        for task_id, data_loader in self._iter_eval_tasks():
            acc, _, learning_curve = shifting_window_test_task(self.model, self.device, self.regularised_loss_fn,
                                                               self.optimiser, self.scheduler, data_loader,
                                                               self)
            acc_data.append(acc)
            print("task " + str(task_id) + " acc: " + str(acc))
            print(learning_curve)
        print("overall accuracy: " + str(sum(acc_data) / len(acc_data)))
        print(acc_data)
        print(self.train_acc)

    def dummy_classic_test(self):
        """Like classic_test, but null classes are computed per sample."""
        acc_data = []
        for task_id, data_loader in self._iter_eval_tasks(set_null_classes=False):
            acc, _ = dummy_test_task(self, self.device, data_loader, self.task_stream, self.window_len)
            acc_data.append(acc)
            print("task " + str(task_id) + " acc: " + str(acc))
        print("seed: " + str(self.seed))
        print("overall accuracy: " + str(sum(acc_data) / len(acc_data)))

    # ---- settings ----------------------------------------------------------

    def _make_stream(self, stream_names):
        """Instantiate the IncDataStream class registered for ``self.dataset``.

        Raises:
            ValueError: if ``stream_names`` has no entry for ``self.dataset``.
        """
        try:
            class_name = stream_names[self.dataset]
        except KeyError:
            raise ValueError(
                f'The dataset {self.dataset} is not supported in this setting') from None
        stream_cls = getattr(IncDataStream, class_name)
        return stream_cls(batch_size=self.batch_size, window_length=self.window_len)

    def _use_shifting_window_classes(self):
        """Task t covers classes[t : t + window_len]; every other class is null."""
        def calc_classes(task_id, classes):
            return classes[task_id:task_id + self.window_len]

        def calc_null_classes(task_id, classes):
            null_classes = list(classes)
            del null_classes[task_id:task_id + self.window_len]
            return null_classes

        self.calc_classes = calc_classes
        self.calc_null_classes = calc_null_classes

    def _use_disjoint_classes(self):
        """Task t covers the t-th non-overlapping block of window_len classes."""
        def calc_classes(task_id, classes):
            start = task_id * self.window_len
            return classes[start:start + self.window_len]

        def calc_null_classes(task_id, classes):
            start = task_id * self.window_len
            null_classes = list(classes)
            del null_classes[start:start + self.window_len]
            return null_classes

        self.calc_classes = calc_classes
        self.calc_null_classes = calc_null_classes

    def run_full_window_setting(self):
        """Train on the shifting-window stream, then run shifting_window_test."""
        self.task_stream = self._make_stream(self._SHIFT_WIN_STREAMS)
        self._use_shifting_window_classes()
        self._train()
        self.shifting_window_test()

    def run_shifting_window_setting(self):
        """Train on the shifting-window stream, then run classic_test."""
        self.task_stream = self._make_stream(self._SHIFT_WIN_STREAMS)
        self._use_shifting_window_classes()
        self._train()
        self.classic_test()

    def run_disjoint_tasks_setting(self):
        """Train on the disjoint (split) stream, then run classic_test."""
        self.task_stream = self._make_stream(self._DISJOINT_STREAMS)
        self._use_disjoint_classes()
        self._train()
        self.classic_test()

    def run_dummy_window_setting(self):
        """Dummy shifting-window setting: each sample's window comes from the stream."""
        self.task_stream = self._make_stream(self._DUMMY_SHIFT_WIN_STREAMS)

        def calc_per_point_nullclasses(stream, X, Y, window_len, training):
            per_point_nullclasses = []
            for j in range(X.size(0)):
                null_classes = list(stream.classes)
                task_id = stream.window_num_of_data(X[j], Y[j].item(), train=training)
                del null_classes[task_id:task_id + window_len]
                per_point_nullclasses.append(null_classes)
            return per_point_nullclasses

        self.calc_per_point_nullclasses = calc_per_point_nullclasses
        self._dummy_train()
        self.dummy_classic_test()

    def run_dummy_disjoint_tasks_setting(self):
        """Dummy disjoint setting: each sample's block follows from its label's position."""
        self.task_stream = self._make_stream(self._DUMMY_DISJOINT_STREAMS)

        def calc_per_point_nullclasses(stream, X, Y, window_len, training):
            per_point_nullclasses = []
            for j in range(X.size(0)):
                null_classes = list(stream.classes)
                task_id = null_classes.index(Y[j].item()) // window_len
                del null_classes[task_id * window_len:task_id * window_len + window_len]
                per_point_nullclasses.append(null_classes)
            return per_point_nullclasses

        self.calc_per_point_nullclasses = calc_per_point_nullclasses
        self._dummy_train()
        self.dummy_classic_test()

    def predict(self, X):
        """Model output restricted to the current task's classes.

        Columns listed in ``self.nullClasses`` are overwritten with -10e10.
        Returns these masked logits in training mode and their argmax over
        dim 1 (predicted class indices) in eval mode.
        """
        out = self.model(X)
        # Task-indexed alternative: out = self.model(X, self.task_id)
        out[:, self.nullClasses] = -10e10
        return out if self.training else out.argmax(1)


# Output-head selection where every input in the batch may belong to a different head.
def calc_multi_head_model_output(model, X, per_point_nullClasses):
    """Run ``model`` on ``X`` and mask a separate set of classes for each row.

    For every row index ``row`` of the batch, the columns listed in
    ``per_point_nullClasses[row]`` are overwritten in place with -10e10, so
    they cannot win an argmax. Returns the masked output.
    """
    logits = model(X)
    for row in range(X.shape[0]):
        logits[row, per_point_nullClasses[row]] = -10e10
    return logits


def calc_model_output(model, X, nullClasses):
    """Run ``model`` on ``X`` and mask the same classes in every row.

    The columns listed in ``nullClasses`` are overwritten in place with
    -10e10 for the whole batch. Returns the masked output.
    """
    logits = model(X)
    logits[:, nullClasses] = -10e10
    return logits


# Forgetting metric (could be vectorised with numpy arrays).
def calc_forgetting(per_seen_task_acc):
    """Average forgetting over all tasks except the last one recorded.

    Args:
        per_seen_task_acc: dict mapping a task id (the point at which an
            evaluation was made) to a sequence of accuracies, where entry ``j``
            is the accuracy on task ``j``. Assumes every index ``j`` that
            appears in a sequence is also a key, and that keys were inserted
            in training order (the last key is treated as the final task).

    Returns:
        Mean over tasks ``j`` other than the final one of
        (best accuracy ever recorded on ``j``) - (accuracy on ``j`` at the
        final task). Returns 0.0 when fewer than two tasks are recorded, since
        forgetting is undefined there.
    """
    if len(per_seen_task_acc) < 2:
        return 0.0

    best_acc = {task_id: 0.0 for task_id in per_seen_task_acc}
    for accs in per_seen_task_acc.values():
        for old_task_id, acc in enumerate(accs):
            if best_acc[old_task_id] < acc:
                best_acc[old_task_id] = acc

    # Last-inserted key = final task (explicit rather than a leaked loop variable).
    final_task = list(per_seen_task_acc)[-1]
    final_accs = per_seen_task_acc[final_task]

    total = 0.0
    for task_id in per_seen_task_acc:
        if task_id != final_task:
            total += best_acc[task_id] - final_accs[task_id]
    return total / (len(per_seen_task_acc) - 1)


def train_batch(X, Y, learning_algo, device, optimiser, lossfunction, i, scheduler):
    """Run one optimisation step on batch (X, Y) using ``learning_algo.predict``.

    Calls the algorithm's hooks around the step: ``before_batch_calc`` first,
    ``after_loss_calc`` after ``loss.backward()``, and
    ``after_optimiser_step`` after ``optimiser.step()``. Every 10th batch
    (``i % 10 == 0``), prints the loss and the optimiser's current learning
    rate. ``scheduler`` is kept in the signature for compatibility.
    """
    learning_algo.before_batch_calc()
    X, Y = X.to(device), Y.to(device)
    optimiser.zero_grad()
    loss = lossfunction(learning_algo.predict(X), Y)
    loss.backward()
    learning_algo.after_loss_calc()
    optimiser.step()
    learning_algo.after_optimiser_step()
    if i % 10 == 0:
        # scheduler.get_lr() is only meant to be called inside scheduler.step():
        # outside it PyTorch warns and, for chainable schedulers like
        # ExponentialLR, reports lr * gamma instead of the LR in use.
        current_lr = optimiser.param_groups[0]['lr']
        print(f"loss: {loss.item():>7f}  [{i * len(X):>5d}, {current_lr}]")


def train_task(device, lossfunction, optimiser, scheduler, epochs, data_loader, learning_algo):
    """Train ``learning_algo`` on one task for ``epochs`` passes over ``data_loader``.

    Before each step, increments ``learning_algo.batch_num`` (which is not reset
    here) and stores the raw ``(X, Y)`` batch in ``learning_algo.batch`` so
    the algorithm's hooks can read it.
    """
    for epoch in range(epochs):
        print(" ")
        print(f"epoch: {epoch}")
        for batch_idx, batch in enumerate(data_loader):
            X, Y = batch
            learning_algo.batch_num += 1
            learning_algo.batch = (X, Y)
            train_batch(X, Y, learning_algo, device, optimiser, lossfunction, batch_idx, scheduler)


def shifting_window_test_task(model, device, lossfunction, optimiser,
                              scheduler, data_loader, learning_algo, n_train_batches=2):
    """Fine-tune on the first batches of a task, then measure accuracy on the rest.

    Each of the first ``n_train_batches`` batches gets one training step via
    train_batch, with ``learning_algo`` in training mode. Every later batch is
    only evaluated, with ``learning_algo`` in eval mode and no gradient
    tracking. ``learning_algo`` is left in eval mode on return.

    Args:
        model: unused; kept for backward compatibility. Modes are switched
            through ``learning_algo.train()`` / ``eval()``, which also toggle
            the ``training`` flag that ``predict`` reads.
        device: device the batches are moved to.
        lossfunction, optimiser, scheduler: forwarded to train_batch.
        data_loader: the task's data.
        learning_algo: the CL algorithm being evaluated.
        n_train_batches: number of leading batches used for training.

    Returns:
        A tuple (accuracy, size, curve), where:
        - accuracy is the accuracy over the evaluated samples only
          (0.0 if none were evaluated);
        - size is ``len(data_loader.dataset)``;
        - curve is the list of per-batch accuracies of the evaluated batches.
    """
    correct = 0.0
    n_evaluated = 0
    learning_curve = []

    for i, (X, Y) in enumerate(data_loader):
        X, Y = X.to(device), Y.to(device)
        if i < n_train_batches:
            learning_algo.train()
            # i is passed as 1 so train_batch does not print the loss.
            train_batch(X, Y, learning_algo, device, optimiser, lossfunction, 1, scheduler)
        else:
            learning_algo.eval()
            with torch.no_grad():
                # In eval mode predict() already returns class indices.
                batch_correct = (learning_algo.predict(X) == Y).type(torch.float).sum().item()
            correct += batch_correct
            n_evaluated += len(X)
            learning_curve.append(batch_correct / len(X))

    learning_algo.eval()
    acc = correct / n_evaluated if n_evaluated else 0.0
    return acc, len(data_loader.dataset), learning_curve


def test_task(learning_algo, device, data_loader):
    learning_algo.eval()
    acc = 0

    for (X, Y) in data_loader:
        X, Y = X.to(device), Y.to(device)
        acc += (learning_algo.predict(X) == Y).type(torch.float).sum().item()
    return acc/len(data_loader.dataset), len(data_loader.dataset)


def dummy_train_batch(X, Y, learning_algo, device, optimiser, lossfunction, model, per_point_nullClasses, i, scheduler):
    """Run one optimisation step where each sample has its own masked classes.

    Same as train_batch, except the logits come from
    calc_multi_head_model_output(model, X, per_point_nullClasses) instead of
    ``learning_algo.predict``. Every 10th batch (``i % 10 == 0``), prints the
    loss and the optimiser's current learning rate. ``scheduler`` is kept in
    the signature for compatibility.
    """
    learning_algo.before_batch_calc()
    X, Y = X.to(device), Y.to(device)
    optimiser.zero_grad()
    loss = lossfunction(calc_multi_head_model_output(model, X, per_point_nullClasses), Y)
    loss.backward()
    learning_algo.after_loss_calc()
    optimiser.step()
    learning_algo.after_optimiser_step()
    if i % 10 == 0:
        # scheduler.get_lr() is only meant to be called inside scheduler.step():
        # outside it PyTorch warns and, for chainable schedulers like
        # ExponentialLR, reports lr * gamma instead of the LR in use.
        current_lr = optimiser.param_groups[0]['lr']
        print(f"loss: {loss.item():>7f}  [{i * len(X):>5d}, {current_lr}]")


def dummy_train_task(model, device, lossfunction, optimiser, scheduler, epochs,
                     data_loader, stream, learning_algo, window_len):
    """Train on one task of a "dummy" stream, where masked classes vary per sample.

    For every batch, asks ``learning_algo.calc_per_point_nullclasses`` for each
    sample's null classes and runs one dummy_train_batch step. Like
    train_task, it first increments ``learning_algo.batch_num`` and stores the
    raw batch in ``learning_algo.batch``, so hooks see the same state in the
    dummy and non-dummy settings.
    """
    for e in range(epochs):
        print(" ")
        print("epoch: " + str(e))
        for i, (X, Y) in enumerate(data_loader):
            learning_algo.batch_num += 1  # keep the counter consistent with train_task
            learning_algo.batch = (X, Y)
            per_point_nullclasses = learning_algo.calc_per_point_nullclasses(
                stream, X, Y, window_len, stream.training)
            dummy_train_batch(X, Y, learning_algo, device, optimiser, lossfunction,
                              model, per_point_nullclasses, i, scheduler)


def dummy_test_task(learning_algo, device, data_loader, stream, window_len):
    """Accuracy of ``learning_algo.model`` on one task of a "dummy" stream.

    Before the argmax, each sample's logits are masked with its own null
    classes from ``learning_algo.calc_per_point_nullclasses``. Puts the model
    in eval mode and runs without gradient tracking.

    Returns:
        (accuracy, len(data_loader.dataset)); accuracy is 0.0 for an empty
        dataset.
    """
    model = learning_algo.model
    model.eval()
    correct = 0.0

    with torch.no_grad():  # evaluation only: don't build autograd graphs
        for X, Y in data_loader:
            # Null classes are computed from the CPU-side batch, before the device move.
            per_point_nullclasses = learning_algo.calc_per_point_nullclasses(
                stream, X, Y, window_len, stream.training)
            X, Y = X.to(device), Y.to(device)
            out = calc_multi_head_model_output(model, X, per_point_nullclasses)
            correct += (out.argmax(1) == Y).type(torch.float).sum().item()

    n = len(data_loader.dataset)
    return (correct / n if n else 0.0), n

