
import copy
import math
import sys
import time
import warnings
from argparse import Namespace
from typing import Tuple

import torch
from torch.optim.lr_scheduler import ExponentialLR

from datasets import get_dataset
from datasets.utils.continual_dataset import ContinualDataset
from models.utils.continual_model import ContinualModel

from utils.loggers import *
from utils.model_utils import get_num_params, get_num_trainable_params
from utils.schedulers import PolynomialLR
from utils.status import ProgressBar

try:
    import wandb
except ImportError:
    wandb = None

def mask_classes(outputs: torch.Tensor, dataset: ContinualDataset, k: int) -> None:
    """
    Overwrites, in place, every logit column that does not belong to task ``k``
    with -inf, restricting ``outputs`` to the classes of that task.
    It is used to obtain the results for the task-il setting.
    Task 0 owns the first ``classes_first_task`` columns; every later task owns
    the next ``classes_per_task`` columns.
    :param outputs: the output tensor (classes along dim 1); modified in place
    :param dataset: the continual dataset (class counts are read from its args)
    :param k: the task index
    """
    first_task_classes = dataset.args.classes_first_task
    if k:
        lo = first_task_classes + (k - 1) * dataset.args.classes_per_task
        hi = lo + dataset.args.classes_per_task
    else:
        lo, hi = 0, first_task_classes

    neg_inf = -float('inf')
    outputs[:, :lo] = neg_inf  # columns of earlier tasks
    outputs[:, hi:] = neg_inf  # columns of later tasks


def evaluate(model: ContinualModel, dataset: ContinualDataset, last=False) -> Tuple[list, list]:
    """
    Evaluates the accuracy of the model for each past task.
    :param model: the model to be evaluated
    :param dataset: the continual dataset at hand
    :param last: if True, only the most recent test loader is evaluated
    :return: a tuple of lists, containing the class-il
             and task-il accuracy (in percent) for each evaluated task.
             A test loader that yields no samples gets an accuracy of 0.
    """
    status = model.net.training
    model.net.eval()
    class_il_model = 'class-il' in model.COMPATIBILITY
    accs, accs_mask_classes = [], []
    try:
        with torch.no_grad():
            for k, test_loader in enumerate(dataset.test_loaders):
                if last and k < len(dataset.test_loaders) - 1:
                    continue
                correct, correct_mask_classes, total = 0.0, 0.0, 0.0
                for inputs, labels in test_loader:
                    inputs, labels = inputs.to(model.device), labels.to(model.device)
                    # Models without class-il support are given the task index k.
                    outputs = model(inputs) if class_il_model else model(inputs, k)

                    _, pred = torch.max(outputs.data, 1)
                    correct += torch.sum(pred == labels).item()
                    total += labels.shape[0]

                    if dataset.SETTING == 'class-il':
                        # Task-il accuracy: restrict predictions to task k's classes.
                        mask_classes(outputs, dataset, k)
                        _, pred = torch.max(outputs.data, 1)
                        correct_mask_classes += torch.sum(pred == labels).item()

                if total:
                    accs.append(correct / total * 100 if class_il_model else 0)
                    accs_mask_classes.append(correct_mask_classes / total * 100)
                else:
                    # Empty test loader: report 0 instead of dividing by zero.
                    accs.append(0.0 if class_il_model else 0)
                    accs_mask_classes.append(0.0)
    finally:
        # Restore the caller's train/eval mode even if evaluation fails.
        model.net.train(status)
    return accs, accs_mask_classes


def get_reweighted_mean_accs(accs, dataset):
    """
    Averages per-task accuracies, weighting each task by the number of samples
    in its test set.
    :param accs: an iterable of rows, each holding one accuracy per test loader
                 (e.g. the class-il and task-il lists returned by evaluate)
    :param dataset: the continual dataset whose test_loaders supply the weights
    :return: a list with one size-weighted mean accuracy per row of accs
    """
    def _weighted_mean(row):
        loaders = dataset.test_loaders
        assert len(row) == len(loaders)
        sizes = [len(ldr.dataset) for ldr in loaders]
        weighted_total = sum(acc * size for acc, size in zip(row, sizes))
        return weighted_total / sum(sizes)

    return [_weighted_mean(row) for row in accs]


def _init_wandb(model: ContinualModel, args: Namespace,
                max_attempts: int = 20, retry_delay: float = 10) -> None:
    """
    Starts a wandb run for this execution, retrying failed attempts, and
    stores the run URL in ``args.wandb_url``.
    :param model: the model, whose total/trainable parameter counts are logged
    :param args: the arguments of the current execution (logged as run config)
    :param max_attempts: how many times wandb.init is attempted
    :param retry_delay: seconds to wait after a failed attempt
    :raises RuntimeError: if no wandb run could be started
    """
    assert wandb is not None, "Wandb not installed, please install it or run without wandb"
    wandb_config = copy.deepcopy(vars(args))
    wandb_config["num_params_total"] = get_num_params(model)
    wandb_config["num_params_trainable"] = get_num_trainable_params(model)
    wandb_name = getattr(args, "experiment_name", None)

    print("Trainable params: ", wandb_config["num_params_trainable"])

    last_error = None
    for _ in range(max_attempts):
        try:
            # Hand the enriched config (incl. parameter counts) and the
            # optional experiment name to the run.
            wandb.init(anonymous="allow", config=wandb_config, name=wandb_name)
        except Exception as err:  # transient failures: back off and retry
            last_error = err
            time.sleep(retry_delay)
        else:
            break

    if wandb.run is None:
        raise RuntimeError(
            f"Could not start a wandb run after {max_attempts} attempts") from last_error
    args.wandb_url = wandb.run.get_url()


def train(model: ContinualModel, dataset: ContinualDataset,
          args: Namespace) -> None:
    """
    The training process, including evaluations and loggers.
    Trains ``model`` on each task of ``dataset`` in order, evaluating every
    ``args.eval_every_n_task`` tasks and after the last task, and reports the
    results to stdout, to the Logger (unless ``args.disable_log``) and to
    wandb (unless ``args.nowand``).
    :param model: the module to be trained
    :param dataset: the continual dataset at hand
    :param args: the arguments of the current execution
    """
    print(args)

    if not args.nowand:
        _init_wandb(model, args)

    model.net.to(model.device)
    results, results_mask_classes = [], []

    if not args.disable_log:
        logger = Logger(dataset.SETTING, dataset.NAME, model.NAME)

    progress_bar = ProgressBar(verbose=not args.non_verbose)

    if not args.ignore_other_metrics:
        # Accuracy on every task before any training done in this function;
        # used below as the reference for forward transfer (add_fwt).
        dataset_copy = get_dataset(args)
        for t in range(dataset.n_tasks):
            model.net.train()
            dataset_copy.get_data_loaders()
        if model.NAME != 'icarl' and model.NAME != 'pnn':
            random_results_class, random_results_task = evaluate(model, dataset_copy)

    if model.NAME == "pec" and args.make_equal_task_sizes:
        # Give PEC the smallest train-loader length across all tasks.
        dataset_copy = get_dataset(args)
        loader_sizes = []
        for t in range(dataset.n_tasks):
            train_loader, _ = dataset_copy.get_data_loaders()
            loader_sizes.append(len(train_loader))
        model.task_size = min(loader_sizes)

    print(file=sys.stderr)
    for t in range(dataset.n_tasks):
        if args.optim_reset_every_task:
            model.reset_optimizer()

        model.net.train()
        class_start = dataset.i
        train_loader, test_loader = dataset.get_data_loaders()
        class_end = dataset.i

        if hasattr(model, 'begin_task'):
            model.begin_task(dataset)
        if t and t % args.eval_every_n_task == 0 and not args.ignore_other_metrics:
            # Accuracy on the newly added task before training on it, appended
            # to the previous results row (consumed by the transfer metrics).
            accs = evaluate(model, dataset, last=True)
            results[-1] = results[-1] + accs[0]
            if dataset.SETTING == 'class-il':
                results_mask_classes[-1] = results_mask_classes[-1] + accs[1]

        scheduler = dataset.get_scheduler(model, args)
        # Step budget for this task: given to the per-step scheduler and, for a
        # fractional n_epochs, also the hard stop of the batch loop below, so
        # both agree on the number of optimisation steps.
        num_steps = int(model.args.n_epochs * len(train_loader))
        model.reset_per_step_lr_scheduler(num_steps=num_steps,
                                          class_start=class_start, class_end=class_end)
        if scheduler is not None and args.optim_scheduler != "none":
            warnings.warn("Turning off per-epoch scheduler since per-step scheduler was specified!")
            scheduler = None

        fractional_epoch = 0 < model.args.n_epochs < 1
        epoch_loops = model.args.n_epochs
        if fractional_epoch:
            epoch_loops = 1

        for epoch in range(epoch_loops):
            if args.model == 'joint':
                continue
            for i, data in enumerate(train_loader):
                if args.debug_mode and i > 3:
                    break
                if fractional_epoch and i >= num_steps:
                    break
                if hasattr(dataset.train_loader.dataset, 'logits'):
                    inputs, labels, not_aug_inputs, logits = data
                    inputs = inputs.to(model.device)
                    labels = labels.to(model.device)
                    not_aug_inputs = not_aug_inputs.to(model.device)
                    logits = logits.to(model.device)
                    loss = model.meta_observe(inputs, labels, not_aug_inputs, logits)
                else:
                    inputs, labels, not_aug_inputs = data
                    inputs, labels = inputs.to(model.device), labels.to(model.device)
                    not_aug_inputs = not_aug_inputs.to(model.device)
                    loss = model.meta_observe(inputs, labels, not_aug_inputs)
                model.step_lr_scheduler(labels)
                assert not math.isnan(loss), \
                    f"Loss is NaN at task {t}, epoch {epoch}, batch {i}"
                progress_bar.prog(i, len(train_loader), epoch, t, loss)

            if scheduler is not None:
                scheduler.step()

        if hasattr(model, 'end_task'):
            model.end_task(dataset)

        is_last_task = (t + 1) == dataset.n_tasks
        if is_last_task and args.model == "pec" and args.unfair_calibration:
            model.unfair_calibration(dataset)

        if is_last_task and args.model == "pec" and args.evolution_calibration:
            model.evolution_calibration(dataset)

        if (t + 1) % args.eval_every_n_task == 0 or is_last_task:
            accs = evaluate(model, dataset)
            results.append(accs[0])
            results_mask_classes.append(accs[1])

            mean_acc = np.mean(accs, axis=1)
            print_mean_accuracy(mean_acc, t + 1, dataset.SETTING)

            reweighted_mean_acc = get_reweighted_mean_accs(accs, dataset)

            if not args.disable_log:
                logger.log(mean_acc)
                logger.log_fullacc(accs)

            if not args.nowand:
                d2 = {'RESULT_reweighted_class_mean_accs': reweighted_mean_acc[0],
                      'RESULT_reweighted_task_mean_accs': reweighted_mean_acc[1],
                      'RESULT_class_mean_accs': mean_acc[0], 'RESULT_task_mean_accs': mean_acc[1],
                      **{f'RESULT_class_acc_{i}': a for i, a in enumerate(accs[0])},
                      **{f'RESULT_task_acc_{i}': a for i, a in enumerate(accs[1])}}

                wandb_safe_log(d2)

    if not args.disable_log and not args.ignore_other_metrics:
        logger.add_bwt(results, results_mask_classes)
        logger.add_forgetting(results, results_mask_classes)
        if model.NAME != 'icarl' and model.NAME != 'pnn':
            logger.add_fwt(results, random_results_class,
                           results_mask_classes, random_results_task)

    if not args.disable_log:
        logger.write(vars(args))
        if not args.nowand:
            d = logger.dump()
            d['wandb_url'] = wandb.run.get_url()
            wandb_safe_log(d)

    if not args.nowand:
        wandb.finish()
