from numpy.linalg import svd
from utils import *
import torch
from tqdm import tqdm
from my_ipca import MyIPCA as IPCA
from sklearn.decomposition import PCA

def _encode_images(model_clip, x, device):
    """Encode an image batch with ``model_clip.encode_image`` without gradients.

    Args:
        model_clip: object exposing ``encode_image`` (the frozen image encoder).
        x: image batch tensor; it is moved to ``device`` before encoding.
        device: target device for the input and the returned features.

    Returns:
        float32 feature tensor on ``device``.
    """
    with torch.no_grad():
        # ``.float()`` casts on the device the features already live on; the
        # former ``.type(torch.FloatTensor)`` forced a copy to host memory and
        # then back to ``device`` on every batch.
        return model_clip.encode_image(x.to(device)).float().to(device)


def _fit_class_pca(args, model, model_clip, loader):
    """Append per-class mean and top principal directions/variances to ``model``.

    For every label present in ``loader`` this appends one entry to
    ``model.mu_list`` (class mean), ``model.eigvec_list`` (the first
    ``args.n_components`` right-singular vectors, one per row) and
    ``model.eigval_list`` (the matching sample variances ``S**2 / (n - 1)``).

    When ``args.dynamic`` is set, ``args.n_components`` is first recomputed as
    ``min(args.dynamic // (#seen + #new classes), 512)`` and the eigenpairs
    already stored for earlier classes are truncated to the new size.
    """
    from sklearn.utils.extmath import svd_flip

    feat_chunks, label_chunks = [], []
    args.logger.print("Computing Eigenpairs for PCA")
    for x, y, _, _ in tqdm(loader):
        feats = _encode_images(model_clip, x, args.device)
        # Per-sample L2 normalisation. The previous ``feats.norm(1, keepdim=True)``
        # passed 1 as ``p`` (not ``dim``), dividing the whole batch by a single
        # batch-wide L1 norm, so the feature scale depended on batch size.
        # NOTE(review): train/eval below pass raw features to the model;
        # presumably the model normalises them the same way -- confirm.
        feats = feats / feats.norm(dim=-1, keepdim=True)
        feat_chunks.append(feats.cpu().numpy())
        label_chunks.append(y.numpy())
    feats = np.concatenate(feat_chunks)
    labels = np.concatenate(label_chunks)

    classes = sorted(set(labels))

    # If dynamic memory is used, update n_components and eigenpairs
    if args.dynamic is not None:
        new_components = min(args.dynamic // (len(model.seen_ids) + len(classes)), 512)
        args.logger.print(f"Save {args.n_components} per class -> Save {new_components} per class")
        args.n_components = new_components
        # Truncate in place so any other references to these lists stay valid.
        for i in range(len(model.eigval_list)):
            model.eigvec_list[i] = model.eigvec_list[i][:args.n_components]
            model.eigval_list[i] = model.eigval_list[i][:args.n_components]

    for cls in classes:
        # Boolean-mask indexing copies, so centring never touches ``feats``.
        data = feats[labels == cls]
        mu = np.mean(data, 0)
        model.mu_list.append(mu)

        U, S, V = np.linalg.svd(data - mu, full_matrices=False)
        # Deterministic sign convention for the principal directions.
        U, V = svd_flip(U, V, u_based_decision=False)
        model.eigvec_list.append(V[:args.n_components])
        # Clamp the denominator: a single-sample class would otherwise give 0/0.
        model.eigval_list.append(S[:args.n_components] ** 2 / max(len(data) - 1, 1))
    args.logger.print("Done")


def _summarize_task(acc_mat, task_id):
    """Fill the summary cells of ``acc_mat`` once task ``task_id`` is evaluated.

    - ``acc_mat[task_id, -1]``: mean accuracy over tasks ``0..task_id``.
    - ``acc_mat[-1, i]`` for ``i < task_id``: forgetting, i.e. the accuracy
      on task ``i`` right after learning it minus its accuracy now.
    - ``acc_mat[-1, -1]``: average incremental accuracy (mean of the per-row
      averages for rows ``0..task_id``).
    """
    acc_mat[task_id, -1] = np.mean(acc_mat[task_id, :task_id + 1])

    # Compute forgetting
    for i in range(task_id):
        acc_mat[-1, i] = acc_mat[i, i] - acc_mat[task_id, i]

    # Compute average incremental accuracy
    acc_mat[-1, -1] = np.mean(acc_mat[:task_id + 1, -1])


def train(task_list, args, train_data, test_data, model_clip, model):
    """Run the task-incremental training/evaluation loop.

    For each entry of ``task_list`` this builds the task's train/test loaders,
    optionally fits per-class PCA statistics (``args.model == 'batch_pca'``),
    trains for ``args.n_epochs`` epochs via ``model.observe``, evaluates on the
    current task after every epoch, calls ``model.end_task`` when present,
    saves ``model.net``'s state dict, evaluates on every task seen so far, and
    prints the class-incremental (CIL) and task-incremental (TIL) accuracy
    matrices together with forgetting.

    Args:
        task_list: sequence of tasks; only its length is used here.
        args: run configuration (device, logger, n_tasks, n_epochs, flags ...).
        train_data: provides ``make_dataset()`` for each task's training split
            (returning ``(train, test)`` when ``args.validation`` is set).
        test_data: provides ``make_dataset()`` for each task's test split.
        model_clip: frozen image encoder exposing ``encode_image``.
        model: continual learner being trained and evaluated.
    """
    n_rows = args.n_tasks * 2 + 1
    # -100 marks cells that have not been filled in yet.
    cil_acc_mat_test = np.zeros((n_rows, n_rows)) - 100
    til_acc_mat_test = np.zeros((n_rows, n_rows)) - 100

    # cil_correct, til_correct are for cumulative accuracy throughout training
    cil_correct, til_correct, total = 0, 0, 0

    train_loaders, test_loaders = [], []

    # Forwarded to model.observe as ``text_embedding``; never updated here.
    param_copy = None

    for task_id in range(len(task_list)):
        task_loss_list = []

        if args.validation is None:
            t_train = train_data.make_dataset()
            t_test = test_data.make_dataset()
        else:
            t_train, t_test = train_data.make_dataset()

        train_loaders.append(make_loader(t_train, args, train=True))
        test_loaders.append(make_loader(t_test, args, train=False))

        if args.model == 'batch_pca':
            _fit_class_pca(args, model, model_clip, train_loaders[-1])

        args.logger.print("Start Training...")
        for epoch in range(args.n_epochs):
            model.reset_eval()
            for x, y, f_y, names in tqdm(train_loaders[-1]):
                # for simplicity, consider that we know the labels ahead
                f_y = f_y[:, 1]
                y = y.to(args.device)
                x = _encode_images(model_clip, x, args.device)
                loss = model.observe(x, y, names, x, f_y, text_embedding=param_copy)
                task_loss_list.append(loss)

            if args.n_epochs == 1:
                # Single pass: training predictions double as online accuracy.
                cil_correct += model.correct
                til_correct += model.til_correct
                total += model.total
                cil_acc, til_acc = model.acc()
                cumulative_acc = cil_correct / total * 100 if total else 0.0
                args.logger.print("Task {}, CIL Cumulative Acc: {:.2f}".format(task_id, cil_acc))
                args.logger.print("All seen classes, CIL Cumulative Acc: {:.2f}".format(cumulative_acc))

            model.reset_eval()
            for x, y, _, _ in test_loaders[-1]:
                y = y.to(args.device)
                x = _encode_images(model_clip, x, args.device)
                model.evaluate(x, y, task_id)

            cil_acc, til_acc = model.acc()
            args.logger.print("Task {}, Epoch {}/{}, Total Loss: {:.4f}, CIL Acc: {:.2f}, TIL Acc: {:.2f}".format(task_id, epoch + 1, args.n_epochs, np.mean(task_loss_list), cil_acc, til_acc))

        # End task
        if hasattr(model, 'end_task'):
            model.end_task(train_loaders[-1])

        torch.save(model.net.state_dict(), args.logger.dir() + 'model_task_{}'.format(task_id))

        args.logger.print("######################")
        true_lab, pred_lab = [], []
        for p_task_id, loader in enumerate(test_loaders):
            model.reset_eval()
            for x, y, _, _ in loader:
                y = y.to(args.device)
                x = _encode_images(model_clip, x, args.device)
                model.evaluate(x, y, task_id=p_task_id)

            if args.tsne:
                tsne(np.concatenate(model.output_list),
                     np.concatenate(model.label_list),
                     logger=args.logger)

            if args.confusion:
                true_lab_ = np.concatenate(model.true_lab)
                pred_lab_ = np.concatenate(model.pred_lab)

                plot_confusion(true_lab_, pred_lab_, model.seen_names, task_id, p_task_id, logger=args.logger)

                true_lab.append(true_lab_)
                pred_lab.append(pred_lab_)

            # After the last seen task, plot the confusion matrix over all of them.
            if args.confusion and p_task_id == len(test_loaders) - 1:
                true_lab_ = np.concatenate(true_lab)
                pred_lab_ = np.concatenate(pred_lab)
                plot_confusion(true_lab_, pred_lab_, model.seen_names,
                               name='confusion mat task {}'.format(p_task_id), logger=args.logger)

            cil_acc, til_acc = model.acc()
            cil_acc_mat_test[task_id, p_task_id] = cil_acc
            til_acc_mat_test[task_id, p_task_id] = til_acc

        # test_loaders holds exactly task_id + 1 loaders at this point.
        _summarize_task(cil_acc_mat_test, task_id)
        _summarize_task(til_acc_mat_test, task_id)

        args.logger.print()
        args.logger.print("CIL result")
        print_result(cil_acc_mat_test, task_id, type='acc', print=args.logger.print)
        print_result(cil_acc_mat_test, task_id, type='forget', print=args.logger.print)
        args.logger.print("TIL result")
        print_result(til_acc_mat_test, task_id, type='acc', print=args.logger.print)
        print_result(til_acc_mat_test, task_id, type='forget', print=args.logger.print)
        args.logger.print()
