import copy
import os
import pathlib
import time
from typing import *

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from approaches.supsup import adaptors, trainers, utils
from approaches.supsup.args import args
from utils import print_num_params


class Appr(object):
    def __init__(self, inputsize: Tuple[int, ...],
                 lr_factor: float, lr_min: float, patience_max: int,
                 list__ncls: List[int], nhid: int, drop1: float, drop2: float,
                 ):
        """Create the run directory, build the model and wire up loss, logger and trainer.

        Args:
            inputsize: Forwarded to ``utils.get_model`` as ``inputsize``.
            lr_factor: Stored; ``train_task`` uses ``1.0 / lr_factor`` as the plateau
                scheduler's decay factor.
            lr_min: Stored; lower bound for the learning rate and the threshold used by
                the early-stopping counter in ``train_task``.
            patience_max: Stored; drives both the scheduler patience and the
                early-stopping limit in ``train_task``.
            list__ncls: Forwarded to ``utils.get_model`` as ``list__ncls``
                (presumably the number of classes per task -- TODO confirm).
            nhid: Forwarded to ``utils.get_model``.
            drop1: Forwarded to ``utils.get_model``.
            drop2: Forwarded to ``utils.get_model``.

        Side effects:
            Creates ``{args.log_dir}/{args.name}~try=<i>`` for the first free ``i``,
            writes ``settings.txt`` into it, and mutates the global ``args``
            (``args.name`` gets the ``~try=<i>`` suffix, ``args.run_base_dir`` is set).
        """
        self.lr_factor = lr_factor
        self.lr_min = lr_min
        self.patience_max = patience_max
        self.list__ncls = list__ncls

        # Make a directory corresponding to this run for saving results, checkpoints etc.
        # Creation itself is the existence test: checking first and creating afterwards
        # lets two concurrent runs pick the same index and crash the slower one.
        i = 0
        while True:
            self.run_base_dir = pathlib.Path(f"{args.log_dir}/{args.name}~try={str(i)}")
            try:
                os.makedirs(self.run_base_dir)
            except FileExistsError:
                # Index already taken (possibly by a run started in parallel); try the next.
                i += 1
                continue
            # endtry
            args.name = args.name + f"~try={i}"
            break
        # endwhile

        # args.name already carries the ~try suffix here, so the dump records it.
        (self.run_base_dir / "settings.txt").write_text(str(args))
        args.run_base_dir = self.run_base_dir

        print(f"=> Saving data in {self.run_base_dir}")

        # Track accuracy on all tasks.
        # NOTE(review): these lists only exist when args.num_tasks is truthy, yet
        # train_task/test_task index them unconditionally.
        if args.num_tasks:
            self.best_acc1 = [0.0 for _ in range(args.num_tasks)]
            self.best_loss1 = [100000 for _ in range(args.num_tasks)]
            self.curr_acc1 = [0.0 for _ in range(args.num_tasks)]
            self.adapt_acc1 = [0.0 for _ in range(args.num_tasks)]
        # endif

        # Get the model. Despite its name, self.inputsize holds every keyword argument
        # passed to utils.get_model, not just the input size.
        self.inputsize = {'inputsize': inputsize, 'list__ncls': self.list__ncls,
                          'nhid': nhid, 'drop1': drop1, 'drop2': drop2}
        print('Initializing a model')
        self.model = utils.get_model(**self.inputsize)
        print('Initialized a model')

        print_num_params(self.model)

        # If necessary, set the sparsity of the model using the ER sparsity budget (see paper).
        # The formula reads weight.size(0..3), i.e. it assumes 4-D weights on every
        # module that exposes a ``sparsity`` attribute.
        if args.er_sparsity:
            for n, m in self.model.named_modules():
                if hasattr(m, "sparsity"):
                    shape = [m.weight.size(d) for d in range(4)]
                    m.sparsity = min(
                        0.5,
                        args.sparsity
                        * (shape[0] + shape[1])
                        / (shape[0] * shape[1] * shape[2] * shape[3]),
                        )
                    print(f"Set sparsity of {n} to {m.sparsity}")
                # endif
            # endfor
        # endif

        # Put the model on the GPU.
        self.model = utils.set_gpu(self.model)  # type: nn.Module

        # NOTE: resuming from a checkpoint (args.resume) is not supported by this class.

        self.criterion = nn.CrossEntropyLoss().to(args.device)

        self.writer = SummaryWriter(log_dir=self.run_base_dir)

        # Track the number of tasks learned.
        self.num_tasks_learned = 0

        self.trainer = getattr(trainers, args.trainer or "default")
        print(f"=> Using trainer {self.trainer}")

        # Plain functions taken from the trainer; they are called without ``self``.
        self.train, self.test = self.trainer.train, self.trainer.test

        # Initialize model specific context (editorial note: avoids polluting main file)
        if hasattr(self.trainer, "init"):
            self.trainer.init(args)
        # endif

        # NOTE: the single-task evaluation mode (args.task_eval) is not supported by this class.
    # enddef

    def train_task(self, idx_task: int, dl_train: DataLoader, dl_val: DataLoader) -> float:
        """Train the model on task ``idx_task`` and return the time spent training.

        Every epoch runs ``self.train`` on ``dl_train`` followed by ``self.test`` on
        ``dl_val``. The weights with the lowest validation loss are snapshotted, and
        that snapshot is restored when training stops early, i.e. once the learning
        rate has been at or below ``self.lr_min`` for ``self.patience_max``
        consecutive epochs.

        NOTE(review): when the loop ends because ``args.epochs`` or ``args.iter_lim``
        is exhausted, the last-epoch weights are kept (no restore), although
        ``best_acc1`` still reports the best epoch.

        Args:
            idx_task: Index of the task to train; written to the ``task`` attribute
                of every module of the model.
            dl_train: Loader handed to the trainer's ``train`` function.
            dl_val: Loader handed to the trainer's ``test`` function; its loss drives
                the scheduler and the best-weights selection.

        Returns:
            Wall-clock seconds spent in the epoch loop.

        Raises:
            NotImplementedError: If ``args.optimizer`` is not ``'sgd'``.
        """
        print(f"Task {args.set}: {idx_task}")

        # Tell the model which task it is trying to solve -- in Scenario NNs this is ignored.
        self.model.apply(lambda m: setattr(m, "task", idx_task))

        # Clear the grad on all the parameters.
        for p in self.model.parameters():
            p.grad = None
        # endfor

        # train_weight_tasks specifies the number of tasks that the weights are trained for.
        # e.g. in SupSup, train_weight_tasks = 0. in BatchE, train_weight_tasks = 1.
        # A negative value means the weights are trained for every task.
        train_weights = (args.train_weight_tasks < 0
                         or self.num_tasks_learned < args.train_weight_tasks)
        is_nns = bool(args.trainer and "nns" in args.trainer)

        # Make a list of the parameters relevant to this task.
        params = []
        for n, p in self.model.named_parameters():
            if not p.requires_grad:
                continue
            # endif
            split = n.split(".")
            # Per-task parameters are named "<...>.<scores|s|t>.<task index>".
            if split[-2] in ["scores", "s", "t"] and (
                    int(split[-1]) == idx_task or is_nns
            ):
                params.append(p)
            # endif

            if train_weights and split[-1] in ("weight", "bias"):
                params.append(p)
            # endif
        # endfor

        # If training weights, use train_weight_lr. Else use lr.
        lr = args.train_weight_lr if train_weights else args.lr

        # Get optimizer and scheduler.
        if args.optimizer != 'sgd':
            raise NotImplementedError(f"Unsupported optimizer: {args.optimizer!r}")
        # endif
        optimizer = optim.SGD(
            params, lr=lr, momentum=args.momentum, weight_decay=args.wd
            )

        train_epochs = args.epochs

        if args.no_scheduler:
            scheduler = None
        else:
            scheduler = ReduceLROnPlateau(optimizer,
                                          mode='min',
                                          factor=1.0 / self.lr_factor,
                                          patience=max(self.patience_max - 1, 0),
                                          min_lr=self.lr_min,
                                          verbose=True,
                                          )
        # endif

        # Train on the current task.
        time_start = time.time()
        patience = 0
        best_state = None
        for epoch in range(1, train_epochs + 1):
            self.train(
                self.model,
                self.writer,
                dl_train,
                optimizer,
                self.criterion,
                epoch,
                idx_task,
                None,
                )

            # Required for our PSP implementation, not used otherwise.
            utils.cache_weights(self.model, self.num_tasks_learned + 1)

            test_loss, test_acc = self.test(self.model,
                                            self.writer,
                                            self.criterion,
                                            dl_val,
                                            epoch,
                                            idx_task)
            self.curr_acc1[idx_task] = test_acc

            # The first epoch always becomes the baseline, whatever best_loss1 held before.
            if epoch == 1 or test_loss < self.best_loss1[idx_task]:
                self.best_loss1[idx_task] = test_loss
                self.best_acc1[idx_task] = test_acc
                # state_dict() hands back references to the live tensors, so the snapshot
                # must be a deep copy or later epochs would overwrite it in place.
                best_state = copy.deepcopy(self.model.state_dict())
            # endif

            if scheduler:
                scheduler.step(test_loss)
            # endif

            list__lr = [param_group['lr'] for param_group in optimizer.param_groups]
            assert len(list__lr) == 1, list__lr
            # Count consecutive epochs spent at the learning-rate floor.
            if list__lr[0] <= self.lr_min:
                patience += 1
            else:
                patience = 0
            # endif

            if patience >= self.patience_max:
                # Early stop: roll back to the best validation-loss weights.
                self.model.load_state_dict(best_state)
                break
            # endif

            # Optional cap on the total number of training iterations.
            if args.iter_lim > 0 and len(dl_train) * epoch > args.iter_lim:
                break
            # endif
        # endfor
        time_end = time.time()
        time_consumed = time_end - time_start

        utils.write_result_to_csv(
            name=f"{args.name}~set={args.set}~task={idx_task}",
            curr_acc1=self.curr_acc1[idx_task],
            best_acc1=self.best_acc1[idx_task],
            save_dir=self.run_base_dir,
            )

        # Save memory by deleting the optimizer, scheduler and weight snapshot.
        del optimizer, scheduler, params, best_state

        # Increment the number of tasks learned.
        self.num_tasks_learned += 1

        # If operating in NNS scenario, get the number of tasks learned count from the model.
        if is_nns:
            self.model.apply(
                lambda m: setattr(
                    m, "num_tasks_learned", min(self.model.num_tasks_learned, args.num_tasks)
                    )
                )
        else:
            self.model.apply(lambda m: setattr(m, "num_tasks_learned", self.num_tasks_learned))
        # endif

        return time_consumed
    # enddef

    def test_task(self, idx_task: int, dl_test: DataLoader) -> None:
        """Evaluate task ``idx_task`` on ``dl_test`` through the configured adaptor.

        The accuracy returned by the adaptor is stored in ``self.adapt_acc1[idx_task]``,
        appended to the adapt-results file via ``utils.write_adapt_results`` and logged
        to TensorBoard under ``adapt/avg_acc`` at step ``self.num_tasks_learned``.

        Args:
            idx_task: Index of the task to evaluate.
            dl_test: Loader passed to the adaptor as ``test_loader``.

        Raises:
            NotImplementedError: If ``args.adaptor`` is anything other than ``"gt"``;
                only that adaptor is supported here.
        """
        # Setting task to -1 tells the model to infer task identity instead of being given the task.
        self.model.apply(lambda m: setattr(m, "task", -1))

        # An "adaptor" is used to infer task identity.
        # args.adaptor == gt implies we are in scenario GG.
        if args.adaptor != "gt":
            raise NotImplementedError(args)
        # endif

        adapt = getattr(adaptors, args.adaptor)

        # Only the requested task is evaluated per call.
        tasks_to_eval = [idx_task]
        total_acc = 0.0

        for i in tasks_to_eval:
            print(f"Testing {i}: {args.set} ({i})")

            # Clear the stored information -- memory leak happens if not.
            for p in self.model.parameters():
                p.grad = None
            # endfor

            for b in self.model.buffers():
                b.grad = None
            # endfor

            torch.cuda.empty_cache()

            adapt_acc = adapt(
                model=self.model,
                writer=self.writer,
                test_loader=dl_test,
                num_tasks_learned=self.num_tasks_learned,
                task=i,
                )

            self.adapt_acc1[i] = adapt_acc
            total_acc += adapt_acc

            torch.cuda.empty_cache()
            utils.write_adapt_results(
                name=args.name,
                task=f"{args.set}_{i}",
                num_tasks_learned=self.num_tasks_learned,
                curr_acc1=self.curr_acc1[i],
                adapt_acc1=adapt_acc,
                task_number=i,
                )
        # endfor

        # Average over the tasks actually evaluated. Dividing by num_tasks_learned
        # would understate the value and fail outright before any task is trained.
        self.writer.add_scalar(
            "adapt/avg_acc", total_acc / len(tasks_to_eval), self.num_tasks_learned
            )

        utils.clear_masks(self.model)
        torch.cuda.empty_cache()
    # enddef

    # TODO: Remove this with task-eval
    def get_optimizer(self, args, model):
        """Return an SGD optimizer over the trainable parameters of ``model``.

        Prints one debug line per parameter saying whether it receives gradients.
        Trainable parameters are split into two groups -- those whose name contains
        ``"bn"`` and the rest -- and both groups use ``args.wd`` as weight decay.

        Args:
            args: Namespace providing ``optimizer``, ``lr``, ``momentum`` and ``wd``.
            model: Module whose parameters are optimized.

        Returns:
            A ``torch.optim.SGD`` instance with two parameter groups (bn, rest).

        Raises:
            NotImplementedError: If ``args.optimizer`` is not ``"sgd"``.
        """
        named = list(model.named_parameters())

        # Report the gradient status of every parameter.
        for name, tensor in named:
            if tensor.requires_grad:
                print("<DEBUG> gradient to", name)
            else:
                print("<DEBUG> no gradient to", name)
            # endif
        # endfor

        if args.optimizer != "sgd":
            raise NotImplementedError()
        # endif

        # Disabled alternatives, kept as an inert string literal.
        """
        elif args.optimizer == "adam":
            optimizer = torch.optim.Adam(
                filter(lambda p: p.requires_grad, model.parameters()),
                lr=args.lr,
                weight_decay=args.wd,
                )
        elif args.optimizer == "rmsprop":
            optimizer = torch.optim.RMSprop(
                filter(lambda p: p.requires_grad, model.parameters()), lr=lr
                )                
        """

        # Partition the trainable parameters by whether their name mentions "bn".
        bn_params = []
        rest_params = []
        for name, tensor in named:
            if not tensor.requires_grad:
                continue
            # endif
            bucket = bn_params if "bn" in name else rest_params
            bucket.append(tensor)
        # endfor

        groups = [
            {"params": bn_params, "weight_decay": args.wd},
            {"params": rest_params, "weight_decay": args.wd},
            ]
        return torch.optim.SGD(
            groups,
            args.lr,
            momentum=args.momentum,
            weight_decay=args.wd,
            nesterov=False,
            )
    # enddef
