import os
import math
from copy import deepcopy, copy
from pathlib import Path
from argparse import ArgumentParser
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union, Optional, Text

import pytorch_lightning as pl
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR


def weighted_mean(outputs: List[Dict], key: str, batch_size_key: str) -> float:
    """Average ``outputs[i][key]`` using ``outputs[i][batch_size_key]`` as weights.

    Weighting by batch size keeps a smaller final batch from skewing an epoch-level
    average the way a plain mean of per-batch values would.

    Args:
        outputs (List[Dict]): per-step output dicts (e.g. from validation steps).
        key (str): entry whose values are averaged.
        batch_size_key (str): entry holding each dict's weight (its batch size).

    Returns:
        float: the weighted mean. For tensor values the result is a tensor whose
        dimension 0 is squeezed away when it has size 1; plain numbers are returned
        unchanged.

    Raises:
        ZeroDivisionError: if ``outputs`` is empty (the sums stay the integer 0).
    """

    weighted_sum = 0
    total_weight = 0
    for record in outputs:
        weight = record[batch_size_key]
        weighted_sum += weight * record[key]
        total_weight += weight

    mean = weighted_sum / total_weight
    try:
        # tensor results: turn e.g. shape (1,) into a scalar tensor
        return mean.squeeze(0)
    except AttributeError:
        # plain Python numbers have no ``squeeze``
        return mean


class BaseTrain(pl.LightningModule):
    """Shared PyTorch Lightning logic for classification training.

    Builds the optimizer / learning-rate scheduler and implements the train,
    validation and test steps. Subclasses must provide:

    * ``learnable_params`` -- a list of optimizer param-group dicts, each with a
      ``'params'`` iterable of tensors;
    * ``forward(data)`` -- returning a dict with a ``'logits'`` entry.

    A batch is either an object exposing ``y`` (optionally together with a
    ``train_mask`` / ``val_mask`` / ``test_mask`` selecting which entries are
    scored), or an indexable pair whose element 1 holds the targets.
    """

    def __init__(
        self,
        max_epochs: int,
        batch_size: int,
        optimizer: str,
        lr: float,
        momentum: float,
        beta2: float,
        weight_decay: float,
        scheduler: str,
        min_lr: float,
        warmup_start_lr: float,
        warmup_epochs: float,
        lr_decay_steps: Sequence = None,
        **kwargs,
    ):
        """Stores the optimization hyper-parameters.

        Args:
            max_epochs (int): number of epochs; also the length of the cosine schedules.
            batch_size (int): batch size (stored only; not used in this class).
            optimizer (str): one of ``"sgd"``, ``"adam"``, ``"adamw"``.
            lr (float): base learning rate.
            momentum (float): SGD momentum, or beta1 for Adam/AdamW.
            beta2 (float): beta2 for Adam/AdamW.
            weight_decay (float): weight decay passed to the optimizer.
            scheduler (str): one of ``"cosine"``, ``"warmup_cosine"``, ``"step"``, ``"none"``.
            min_lr (float): final learning rate of the cosine schedules.
            warmup_start_lr (float): starting learning rate of ``"warmup_cosine"``.
            warmup_epochs (float): warmup length of ``"warmup_cosine"``.
            lr_decay_steps (Sequence, optional): milestones of the ``"step"`` scheduler.
            **kwargs: accepted and ignored (presumably so a full argparse namespace
                can be passed in).
        """
        super().__init__()
        self.max_epochs = max_epochs
        self.batch_size = batch_size
        self.optimizer = optimizer
        self.lr = lr
        self.momentum = momentum
        self.beta2 = beta2
        self.weight_decay = weight_decay
        self.scheduler = scheduler
        self.lr_decay_steps = lr_decay_steps
        self.min_lr = min_lr
        self.warmup_start_lr = warmup_start_lr
        self.warmup_epochs = warmup_epochs

    @staticmethod
    def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
        """Adds the optimizer/scheduler arguments shared by all methods.

        Args:
            parent_parser (ArgumentParser): parser in which a "training" argument
                group is created.

        Returns:
            ArgumentParser: ``parent_parser`` itself, to allow chaining.
        """

        parser = parent_parser.add_argument_group("training")

        # general train
        parser.add_argument("--lr", type=float, default=0.015)
        parser.add_argument("--momentum", type=float, default=0.9)
        parser.add_argument("--beta2", type=float, default=0.999)
        parser.add_argument("--weight_decay", type=float, default=5e-4)

        # optimizer
        SUPPORTED_OPTIMIZERS = ["sgd", "adam", "adamw"]
        parser.add_argument("--optimizer", choices=SUPPORTED_OPTIMIZERS, type=str, default='adamw')

        # scheduler
        SUPPORTED_SCHEDULERS = [
            "cosine",
            "warmup_cosine",
            "step",
            "none",
        ]

        parser.add_argument("--scheduler", choices=SUPPORTED_SCHEDULERS, type=str, default="cosine")
        parser.add_argument("--lr_decay_steps", default=None, type=int, nargs="+")
        parser.add_argument("--min_lr", default=0.0, type=float)
        parser.add_argument("--warmup_start_lr", default=0.003, type=float)
        parser.add_argument("--warmup_epochs", default=0, type=int)

        return parent_parser

    @property
    def learnable_params(self) -> List[Dict[str, Any]]:
        """Optimizer param groups (dicts with a ``'params'`` entry); must be overridden."""
        raise NotImplementedError()

    def forward(self, *args, **kwargs) -> Dict[str, Any]:
        """Must be overridden; the steps below index the result with ``'logits'``."""
        raise NotImplementedError()

    def configure_optimizers(self) -> Union[torch.optim.Optimizer, Tuple[List, List]]:
        """Collects learnable parameters and configures the optimizer and learning rate scheduler.

        Returns:
            The bare optimizer when ``scheduler == "none"``, otherwise
            ``([optimizer], [scheduler])``.

        Raises:
            ValueError: for an unknown optimizer or scheduler name, or when the
                ``"step"`` scheduler is selected without ``lr_decay_steps``.
        """
        # Keep only trainable tensors in each param group. These filtered groups are
        # the ones handed to the optimizer: reading ``self.learnable_params`` again
        # may build fresh groups and silently discard the filtering.
        learnable_params = self.learnable_params
        for group in learnable_params:
            group['params'] = [p for p in group['params'] if p.requires_grad]

        # select optimizer
        if self.optimizer == "sgd":
            optimizer_cls = partial(torch.optim.SGD, momentum=self.momentum)
        elif self.optimizer == "adam":
            optimizer_cls = partial(torch.optim.Adam, betas=(self.momentum, self.beta2))
        elif self.optimizer == "adamw":
            optimizer_cls = partial(torch.optim.AdamW, betas=(self.momentum, self.beta2))
        else:
            raise ValueError(f"{self.optimizer} not in (sgd, adam, adamw)")

        optimizer = optimizer_cls(
            learnable_params,
            lr=self.lr,
            weight_decay=self.weight_decay,
        )

        if self.scheduler == "none":
            return optimizer

        if self.scheduler == "warmup_cosine":
            scheduler = LinearWarmupCosineAnnealingLR(
                optimizer,
                warmup_epochs=self.warmup_epochs,
                max_epochs=self.max_epochs,
                warmup_start_lr=self.warmup_start_lr,
                eta_min=self.min_lr,
            )
        elif self.scheduler == "cosine":
            scheduler = CosineAnnealingLR(optimizer, self.max_epochs, eta_min=self.min_lr)
        elif self.scheduler == "step":
            # MultiStepLR accepts milestones=None (Counter(None) is empty) and would
            # then silently keep the learning rate constant.
            if self.lr_decay_steps is None:
                raise ValueError("scheduler 'step' requires lr_decay_steps to be set")
            scheduler = MultiStepLR(optimizer, self.lr_decay_steps)
        else:
            raise ValueError(f"{self.scheduler} not in (warmup_cosine, cosine, step)")

        return [optimizer], [scheduler]

    def _calculate_metrics(self, logits: torch.Tensor, targets: torch.Tensor) -> Dict:
        """Computes the cross-entropy loss and top-1 accuracy (in percent) of ``logits``
        against the class-index ``targets``.
        """

        loss = F.cross_entropy(logits, targets)
        acc = 100 * logits.argmax(-1).eq(targets).float().mean(-1)

        return {"loss": loss, "acc": acc}

    def _logits_and_targets(self, data, mask_attr: str) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """Forwards ``data`` and returns ``(logits, targets, batch_size)`` for scoring.

        If ``data`` has the attribute ``mask_attr``, only the masked entries of the
        logits and of ``data.y`` are kept and the batch size is reported as 1.
        Otherwise targets come from ``data.y`` (or ``data[1]`` for indexable
        batches) and the batch size is their first dimension; 0-d targets are
        reshaped to shape (1,) and counted as a batch of 1.
        """
        outs = self(data)
        if hasattr(data, mask_attr):
            mask = getattr(data, mask_attr)
            return outs['logits'][mask], data.y[mask], 1

        logits = outs['logits']
        try:
            targets = data.y
        except AttributeError:
            targets = data[1]
        try:
            batch_size = targets.size(0)
        except IndexError:
            # 0-d tensors have no dimension 0
            targets = targets.reshape(-1)
            batch_size = 1
        return logits, targets, batch_size

    @staticmethod
    def _prefixed(prefix: str, metrics: Dict[str, Any]) -> Dict[str, Any]:
        """Returns a copy of ``metrics`` with every key renamed to ``f"{prefix}_{key}"``."""
        return {f"{prefix}_{name}": value for name, value in metrics.items()}

    def _log_weighted_means(self, outs: List[Dict[str, Any]]) -> None:
        """Logs the batch-size-weighted mean of every metric collected in ``outs``.

        The ``"batch_size"`` entry is only used as the weight and is not logged.
        Nothing is logged when ``outs`` is empty.
        """
        if not outs:
            return
        metrics = {
            key: weighted_mean(outs, key, "batch_size")
            for key in outs[0]
            if key != "batch_size"
        }
        self.log_dict(metrics, sync_dist=True)

    def training_step(self, data, batch_idx: int) -> torch.Tensor:
        """Training step for pytorch lightning: forwards the batch, logs ``train_*``
        metrics (per step and per epoch) and returns the loss to optimize.
        """
        logits, targets, batch_size = self._logits_and_targets(data, 'train_mask')
        metrics = self._calculate_metrics(logits, targets)
        loss = metrics['loss']

        self.log_dict(
            self._prefixed('train', metrics),
            prog_bar=True,
            on_epoch=True,
            sync_dist=True,
            batch_size=batch_size,
        )
        return loss

    def validation_step(
        self, data, batch_idx: int, dataloader_idx: Optional[int] = None
    ) -> Dict[str, Any]:
        """Validation step for pytorch lightning: forwards the batch and returns the
        ``val_*`` metrics together with the ``batch_size`` used for epoch averaging.
        """
        logits, targets, batch_size = self._logits_and_targets(data, 'val_mask')
        metrics = self._prefixed('val', self._calculate_metrics(logits, targets))
        metrics['batch_size'] = batch_size
        return metrics

    def validation_epoch_end(self, outs: List[Dict[str, Any]]):
        """Averages the losses and accuracies of all the validation batches, weighted
        by batch size, because the last batch can be smaller than the others and
        would otherwise slightly skew the metrics.
        """
        self._log_weighted_means(outs)

    def test_step(
        self, data, batch_idx: int, dataloader_idx: Optional[int] = None
    ) -> Dict[str, Any]:
        """Test step for pytorch lightning: forwards the batch and returns the
        ``test_*`` metrics together with the ``batch_size`` used for epoch averaging.
        """
        logits, targets, batch_size = self._logits_and_targets(data, 'test_mask')
        metrics = self._prefixed('test', self._calculate_metrics(logits, targets))
        metrics['batch_size'] = batch_size
        return metrics

    def test_epoch_end(self, outs: List[Dict[str, Any]]):
        """Averages the losses and accuracies of all the test batches, weighted by
        batch size, because the last batch can be smaller than the others and would
        otherwise slightly skew the metrics.
        """
        self._log_weighted_means(outs)


class BasePrune(object):
    """Base class that approximates a trained target network by pruning a wider random one.

    Each bias-free ``nn.Linear`` layer of the target encoder is paired with a randomly
    initialised ``Linear -> ReLU -> Linear`` source sub-network (see
    ``init_src_from_trg_mod``). A subclass-specific ``prune`` solves the resulting
    subset-sum problems, and the solved sub-networks are spliced into the source encoder.

    Subclasses must implement:

    * ``arch`` -- a callable that, given ``target_args``, builds a model with an
      ``encoder`` attribute whose state dict matches the checkpoint's ``'state_dict'``;
    * ``prune(problems, draft)`` -- returning a dict keyed like ``problems`` whose values
      contain at least ``'model'``, ``'max_rel_error'`` and ``'max_abs_error'``.
    """

    @property
    def arch(self):
        raise NotImplementedError('Implement a BasePrune and include the architecture that will be pruned...')

    def __init__(self, target_net_path: Path, target_args: Dict[str, Any],
                 source_net_path: Path,
                 check_w_lt_eps: bool,
                 eps: float,
                 debug: bool,
                 overparam_factor: float,
                 solver: Text,
                 num_workers: int = 1,
                 num_threads: int = 0,
                 timeout: int = 300,
                 **kwargs):
        """Loads the target network, prepares the subset-sum problems and builds the source net.

        If ``source_net_path`` already exists, the problems are solved with
        ``draft=True`` (presumably just enough to rebuild the source architecture) and
        the saved source weights are loaded on top. Otherwise they are solved with
        ``draft=False`` and the resulting source weights are saved to that path.

        Args:
            target_net_path: checkpoint read with ``torch.load``; it must contain a
                ``'state_dict'`` entry for ``self.arch(**target_args)``.
            target_args: keyword arguments for ``self.arch``.
            source_net_path: cache file for the pruned source network's state dict.
            check_w_lt_eps, debug, solver, num_workers, num_threads, timeout: stored
                but not used in this class (presumably consumed by ``prune``).
            eps: per-weight approximation tolerance; enters the over-parameterisation
                factor.
            overparam_factor: multiplicative constant of the over-parameterisation
                factor.
            **kwargs: ignored.

        Raises:
            ValueError: if the target encoder contains no layer accepted by
                ``init_src_from_trg_mod``.
        """
        self.solver = solver
        self.source_net_path = source_net_path
        self.overparam_factor = overparam_factor
        self.eps = eps
        self.debug = debug
        self.check_w_lt_eps = check_w_lt_eps
        self.num_workers = num_workers
        self.num_threads = num_threads
        self.timeout = timeout

        print("Loading target network...")
        self.states = torch.load(target_net_path, map_location='cpu')
        self.target_model = self._load_frozen_encoder(target_args)

        # determine the bounds of the uniform distribution from which the a_i coefficients are drawn:
        # twice the largest weight magnitude of the target network
        high = max(map(torch.max, self.target_model.parameters()))
        low = min(map(torch.min, self.target_model.parameters()))
        self.bound = 2 * max(abs(high), abs(low))

        print("Preparing subsetsum problems...")
        # The over-parameterisation factor depends on the depth (number of problems),
        # which is only known after a first pass. That pass uses depth=1 and is discarded.
        self.target_net_depth = 1
        self.problems = self.prepare_problems(self.target_model)
        if not self.problems:
            raise ValueError("the target network contains no prunable layers")
        self.target_net_depth = len(self.problems)
        del self.problems  # release the first-pass source modules before building the real ones
        self.problems = self.prepare_problems(self.target_model)

        # The source network starts as a second copy of the target encoder; solve()
        # replaces its prunable layers with the pruned over-parameterised sub-networks.
        self.source_model = self._load_frozen_encoder(target_args)

        if os.path.isfile(source_net_path):
            self.solve(draft=True)
            self.source_model.load_state_dict(torch.load(source_net_path, map_location='cpu'))
        else:
            self.solve(draft=False)
            torch.save(self.source_model.state_dict(), source_net_path)

    def _load_frozen_encoder(self, target_args: Dict[str, Any]) -> nn.Module:
        """Builds ``self.arch(**target_args)``, loads ``self.states['state_dict']`` and
        returns its ``encoder`` in eval mode with gradients disabled."""
        train_model = self.arch(**target_args)
        train_model.load_state_dict(self.states['state_dict'])
        encoder = train_model.encoder
        encoder.eval()
        encoder.requires_grad_(False)
        return encoder

    def solve(self, draft=False):
        """Runs ``prune`` on ``self.problems`` and splices the solutions into ``self.source_model``.

        The ``'model'`` entries are dropped from ``self.solutions`` afterwards (they now
        live inside the source model); the remaining entries are kept for
        ``pruning_results``.
        """
        print("Solving subsetsum problems...")
        self.solutions = self.prune(self.problems, draft=draft)

        print("Putting together pruned overparameterized network...")
        self.source_model = self.build_src_model(self.solutions, self.source_model)

        print(self.source_model)
        for solution in self.solutions.values():
            del solution['model']

    def init_src_from_trg_mod(self, mod: nn.Linear, type_=torch.nn.Linear):
        """Creates the random over-parameterised replacement for one target layer.

        A bias-free ``Linear(n_in -> n_out)`` is replaced by
        ``Linear(n_in -> n_in * k) -> ReLU -> Linear(n_in * k -> n_out)`` with
        ``k = 2 * ceil(log2(depth * n_in * n_out / eps) * overparam_factor)``.
        In the first layer, rows ``[n*k, (n+1)*k)`` are connected only to input ``n``.
        All non-zero weights are drawn uniformly from ``[-bound, bound]``.

        Returns:
            tuple: ``(k, [lin1, ReLU, lin2])``.

        Raises:
            TypeError: if ``mod`` is not an instance of ``type_`` (callers use this to
                skip non-prunable modules).
            AssertionError: if ``mod`` has a bias.
        """
        if not isinstance(mod, type_):
            raise TypeError(f'not a {type_}')
        assert(mod.bias is None)

        n_out, n_in = mod.weight.shape
        # times 2, because of ReLU
        overparam_factor = 2 * int(math.ceil(math.log2(self.target_net_depth * n_in * n_out/self.eps) * self.overparam_factor))
        n_int = n_in * overparam_factor

        lin1 = nn.Linear(n_in, n_int, bias=False)
        with torch.no_grad():
            lin1.weight.fill_(0.)
            for n in range(n_in):
                lin1.weight[n*overparam_factor:(n+1)*overparam_factor, n].uniform_(-self.bound, self.bound)

        lin2 = nn.Linear(n_int, n_out, bias=False)
        with torch.no_grad():
            lin2.weight.uniform_(-self.bound, self.bound)

        src_mod = [
            lin1,
            nn.ReLU(inplace=True),
            lin2,
        ]

        return overparam_factor, src_mod

    def prepare_problems(self, target_model, prefix=''):
        """Builds one subset-sum problem per layer of ``target_model`` accepted by
        ``init_src_from_trg_mod``.

        Returns:
            dict: key ``str(i)`` (or ``f"{prefix}_{i}"``) for the i-th child, mapped to
            ``(prefix, overparam_factor, source_modules, target_module)``.
        """
        problems = dict()
        for i, trg_mod in enumerate(target_model):
            try:
                over_factor, src_mod = self.init_src_from_trg_mod(trg_mod)
            except TypeError:
                # not a prunable layer type: leave it out
                continue
            key = prefix + '_' + str(i) if prefix else str(i)
            problems[key] = (prefix, over_factor, src_mod, trg_mod)
        return problems

    def build_src_model(self, solutions, source_model: nn.ModuleList, prefix=''):
        """Replaces, in place, every child of ``source_model`` whose key (same scheme as
        ``prepare_problems``) has a solution by ``solutions[key]['model']``; returns
        ``source_model``."""
        for idx in range(len(source_model)):
            key = prefix + '_' + str(idx) if prefix else str(idx)
            try:
                source_subnet = solutions[key]['model']
            except KeyError:
                continue
            source_model[idx] = source_subnet
        return source_model

    @property
    def pruning_results(self):
        """Parameter counts of the source/target networks and approximation errors.

        ``params_after_pruning`` counts the non-zero source parameters; the weight
        errors are the maxima of the per-problem ``'max_rel_error'`` /
        ``'max_abs_error'`` reported by ``prune``.
        """
        overparams_before = sum(p.numel() for p in self.source_model.parameters())
        params_target = sum(p.numel() for p in self.target_model.parameters())
        overparams_after = sum(p.abs().gt(0.).int().sum() for p in self.source_model.parameters())
        return dict(params_before_pruning=int(overparams_before),
                    params_after_pruning=int(overparams_after),
                    percentage_remaining=int(overparams_after) / int(overparams_before) * 100,
                    params_target=int(params_target),
                    overparam_factors=[p[1] for p in self.problems.values()],
                    max_rel_weight_error=max(s['max_rel_error'] for s in self.solutions.values()),
                    max_abs_weight_error=max(s['max_abs_error'] for s in self.solutions.values()),
                    )

    def prune(self, problems, draft=False):
        """Solves ``problems``; must be implemented by subclasses (see class docstring)."""
        raise NotImplementedError()

    @torch.no_grad()
    def test_src_model(self, dataloader, device=None):
        """Compares the pruned source network with the target network on ``dataloader``.

        Args:
            dataloader: iterable of ``(x, y)`` batches, ``y`` holding class indices.
            device: device both networks and the batches are moved to.

        Returns:
            dict: ``src_acc`` / ``trg_acc`` (top-1 accuracy as a fraction of samples)
            and ``avg_rel_out_error`` / ``max_rel_out_error``, the mean / maximum of
            ``||src_logits - trg_logits|| / ||trg_logits||`` taken over the last logit
            dimension (assumes logits are (batch, classes) -- TODO confirm).
        """
        self.source_model.to(device)
        self.target_model.to(device)

        N = 0
        src_acc = 0.
        trg_acc = 0.
        avg_rel_out_error = 0.
        max_rel_out_error = 0.
        for x, y in tqdm(dataloader):
            x = x.to(device)
            # out-of-place squeeze, so the caller's batch tensor is left untouched
            y = y.to(device).squeeze()
            # a single-sample batch squeezes to a 0-d tensor, which has no size(0)
            N += y.size(0) if y.dim() > 0 else 1

            src_logits = self.source_model(x)
            src_acc += src_logits.argmax(-1).eq(y).float().sum()

            trg_logits = self.target_model(x)
            norm_trg_logits = torch.linalg.norm(trg_logits, dim=-1)
            trg_acc += trg_logits.argmax(-1).eq(y).float().sum()

            rel_error = torch.linalg.norm(src_logits - trg_logits, dim=-1) / norm_trg_logits
            avg_rel_out_error += rel_error.sum()
            max_rel_out_error = max(max_rel_out_error, float(rel_error.max()))

        src_acc /= N
        trg_acc /= N
        avg_rel_out_error /= N

        return dict(src_acc=float(src_acc),
                    trg_acc=float(trg_acc),
                    avg_rel_out_error=float(avg_rel_out_error),
                    max_rel_out_error=float(max_rel_out_error))

    @staticmethod
    def add_model_specific_args(parent_parser: ArgumentParser):
        """Registers the pruning options in a "pruning" argument group of ``parent_parser``.

        Boolean flags are parsed with ``ast.literal_eval`` rather than ``eval``, so no
        arbitrary expression from the command line is executed. Values such as
        ``true`` (lower case) are reported by argparse as usage errors instead of
        crashing with ``NameError``.

        Returns:
            ArgumentParser: ``parent_parser`` itself.
        """
        import ast  # local import: only needed to parse the boolean flags

        parser = parent_parser.add_argument_group("pruning")
        parser.add_argument("--check_w_lt_eps", type=ast.literal_eval, default=False, choices=[True, False],
            help='before running subset sum, check if the weight magnitude is less than eps. If it is, then the approximation is thresholded to 0')
        parser.add_argument("--eps", type=float, default=0.01, metavar='E',
            help='tolerance for each weight approximation (default: 0.01)')
        # NOTE: the factor actually used is computed in init_src_from_trg_mod:
        # 2 * ceil(c * log2(depth * n_in * n_out / eps))
        parser.add_argument("--overparam_factor", type=float, default=5., metavar='C',
            help='multiplicative constant; only used when deterministic approximation is not used. n = round(c * log(1/eps)) (default: 5)')
        parser.add_argument("--debug", type=ast.literal_eval, default=False, choices=[True, False])
        parser.add_argument("--num_threads", type=int, default=0,
            help='how many threads per subset-sum problem')  # Zero threads means as many as virtual cores
        parser.add_argument("--timeout", type=int, default=2000,
            help='timeout of each subset-sum problem in ms')
        parser.add_argument("--solver", type=str, default='ortools', choices=['mock', 'gurobi', 'ortools'])

        return parent_parser
