from argparse import ArgumentParser
import logging
import math
import os.path
import sys
import time
import warnings

import numpy as np
import torch
import torch.nn.functional as F
from glm_saga.elasticnet import maximum_reg_loader, get_device, elastic_loss_and_acc_loader
from torch import nn

from crossProjectHelpers.utils import safe_zip
import torch as ch

# TODO checkout this change: marks the places where this file deviates from the group version of glm-saga

"""
This module requires the glm_saga package (https://github.com/MadryLab/glm_saga).
Example usage, selecting 50 features with the parameters used in the paper:
`metadata` contains information about the precomputed training features in `feature_loaders`;
`args` contains the default glm-saga arguments, as defined by the argument parser at the bottom of this file.
def get_glm_to_zero(feature_loaders, metadata,  args, num_classes, device, train_ds, Ntotal):
    num_features = metadata["X"]["num_features"][0]
    fittingClass = FeatureSelectionFitting(num_features, num_classes, args, 0.8,
                                           50,
                                           0.1,
                                           lookback=3, tol=1e-4,
                                           epsilon=1,)
    to_drop, test_acc = fittingClass.fit(feature_loaders, metadata, device)
    return to_drop
    
to_drop holds the indices of the features that were not selected; it is then used to remove those features from the downstream fitting and fine-tuning.
"""


class FeatureSelectionFitting:
    """Greedy group-sparse feature selection with a modified glm-saga solver.

    This is an adaptation of the group version of glm-saga
    (https://github.com/MadryLab/DebuggableDeepNetworks). The proximal operator
    (``extended_mask_max``) only lets the already selected features plus at most one
    new feature stay non-zero. Each time a new feature becomes non-zero it is added to
    ``selected_features``; this repeats until ``nKeep`` features are selected.
    """

    def __init__(self, n_features, n_classes, args, selalpha, nKeep, lam_fac, lookback=None, tol=None,
                 epsilon=None):
        """
        Args:
            n_features: number of input features (columns of the linear layer's weight).
            n_classes: number of output classes.
            args: default args for glm-saga (e.g. the argparse namespace at the bottom of
                this file); ``lookback``, ``tol`` and ``epsilon`` override fields of it.
            selalpha: alpha for the elastic net (1 means pure group L1).
            nKeep: target number of selected features.
            lam_fac: factor every lambda of the regularization path is multiplied with.
            lookback: if not None, stored as ``args.lookbehind`` (early-stopping patience).
            tol: if not None, stored as ``args.tol``.
            epsilon: if not None, stored as ``args.lr_decay_factor``.
        """
        self.selected_features = torch.zeros(n_features, dtype=torch.bool)
        self.num_features = n_features
        self.selalpha = selalpha
        self.lam_Fac = lam_fac
        self.n_classes = n_classes
        self.nKeep = nKeep
        # Optional output directory for glm_saga (log file + per-feature checkpoints).
        # Set it after construction to enable saving; it overrides glm_saga's `checkpoint`.
        self.out_dir = None
        self.args = self.extend_args(args, lookback, tol, epsilon)

    # Extended Proximal Operator for Feature Selection
    def extended_mask_max(self, greater_to_keep, thresh):
        """Return the boolean mask of features allowed to stay non-zero.

        All already selected features are kept. Among the others, only the one with the
        largest value in ``greater_to_keep`` may be added, and only if that value exceeds
        ``thresh``. ``greater_to_keep`` itself is left unmodified.
        """
        selected = self.selected_features.to(greater_to_keep.device)
        # Hide the selected features so that argmax picks the strongest unselected candidate.
        candidates = greater_to_keep.masked_fill(selected, greater_to_keep.min())
        max_entry = torch.argmax(candidates)
        new_feature = torch.zeros_like(selected)
        new_feature[max_entry] = greater_to_keep[max_entry] > thresh
        return torch.logical_or(selected, new_feature)

    def extend_args(self, args, lookback, tol, epsilon):
        """Copy the non-None overrides onto ``args`` (epsilon becomes lr_decay_factor)."""
        for key, entry in safe_zip(["lookbehind", "tol",
                                    "lr_decay_factor", ], [lookback, tol, epsilon]):
            if entry is not None:
                setattr(args, key, entry)
        return args

    # Grouped L1 regularization
    # proximal operator for f(weight) = lam * \|weight\|_2
    # where the 2-norm is taken columnwise
    def group_threshold(self, weight, lam):
        """Shrink every column of ``weight`` by ``lam`` in 2-norm and zero disallowed columns.

        Which columns are allowed is decided by ``extended_mask_max``.
        NOTE(review): selected features are kept even when their column norm is below
        ``lam``; then ``1 - lam / norm`` is negative and the column flips sign. Confirm this
        is intended.
        """
        norm = weight.norm(p=2, dim=0) + 1e-6  # epsilon avoids division by zero
        return (weight - lam * weight / norm) * self.extended_mask_max(norm, lam)

    # Elastic net regularization with group sparsity
    # proximal operator for f(x) = alpha * \|x\|_1 + beta * \|x\|_2^2
    # where the 2-norm is taken columnwise
    def group_threshold_with_shrinkage(self, x, alpha, beta):
        """Group threshold with ``alpha`` followed by the ridge shrinkage ``1 / (1 + beta)``."""
        y = self.group_threshold(x, alpha)
        return y / (1 + beta)

    def threshold(self, weight_new, lr, lam):
        """Apply the (group) elastic-net proximal step with step size ``lr`` and strength ``lam``.

        The mixing parameter is ``self.selalpha``.
        """
        alpha = self.selalpha
        if alpha == 1:
            # Pure L1 regularization
            return self.group_threshold(weight_new, lr * lam * alpha)
        # Elastic net regularization
        return self.group_threshold_with_shrinkage(weight_new, lr * lam * alpha,
                                                   lr * lam * (1 - alpha))

    @staticmethod
    def _saga_state(a_table, w_grad_avg, b_grad_avg):
        """Pack the SAGA tables (moved to CPU) so that a later call can warm-start from them."""
        return {
            "a_table": a_table.cpu(),
            "w_grad_avg": w_grad_avg.cpu(),
            "b_grad_avg": b_grad_avg.cpu()
        }

    # Train an elastic GLM with proximal SAGA
    def train_saga(self, linear, loader, lr, nepochs, lam, alpha, group=True, verbose=None,
                   state=None, table_device=None, n_ex=None, n_classes=None, tol=1e-4,
                   preprocess=None, lookbehind=None, family='multinomial', logger=None):
        """Train ``linear`` in place with proximal SAGA at a single (lam, lr) point.

        Since SAGA stores a scalar for each example-class pair, either pass the number of
        examples and classes or let them be computed with an initial pass over ``loader``.
        Batches must be ``(X, y, idx)`` or ``(X, y, w, idx)``, where ``idx`` indexes the
        SAGA table and ``w`` are per-example weights. The proximal step is
        ``self.threshold`` (it uses ``self.selalpha``); ``alpha`` only enters the logged
        objective. ``group`` is accepted for glm-saga compatibility but unused.

        Training stops early when one update moves the parameters by at most ``tol``
        (``lookbehind`` is None), or when the objective has not improved by ``tol`` for
        ``lookbehind`` epochs.

        Returns:
            dict with ``a_table``, ``w_grad_avg`` and ``b_grad_avg`` on CPU, which can be
            passed back as ``state``.
        """
        if logger is None:
            logger = print
        with ch.no_grad():
            weight, bias = list(linear.parameters())
            device = weight.device
            if table_device is None:
                table_device = device

            # Total number of examples/classes to size the SAGA table.
            if n_ex is None:
                n_ex = sum(tensors[0].size(0) for tensors in loader)
            if n_classes is None:
                if family == 'multinomial':
                    n_classes = max(tensors[1].max().item() for tensors in loader) + 1
                elif family == 'gaussian':
                    n_classes = next(iter(loader))[1].size(1)

            # Storage for scalar gradients and their averages
            if state is None:
                a_table = ch.zeros(n_ex, n_classes).to(table_device)
                w_grad_avg = ch.zeros_like(weight).to(device)
                b_grad_avg = ch.zeros_like(bias).to(device)
            else:
                a_table = state["a_table"].to(table_device)
                w_grad_avg = state["w_grad_avg"].to(device)
                b_grad_avg = state["b_grad_avg"].to(device)

            # One-hot lookup for multinomial targets, built once on the weight device.
            eye = ch.eye(weight.size(0), device=device)
            obj_history = []
            obj_best = None
            nni = 0
            # Norm of the last parameter update (only tracked when lookbehind is None).
            criteria = dw = db = None
            for t in range(nepochs):
                total_loss = 0
                for batch in loader:
                    if len(batch) == 3:
                        X, y, idx = batch
                        w = None
                    elif len(batch) == 4:
                        X, y, w, idx = batch
                    else:
                        raise ValueError(
                            f"Loader must return (data, target, index) or (data, target, weight, index) but instead got a tuple of length {len(batch)}")

                    if preprocess is not None:
                        X = preprocess(X.to(get_device(preprocess)))
                    X = X.to(device)
                    y = y.to(device)
                    if w is not None:
                        w = w.to(device)
                    out = linear(X)

                    # split gradient on only the loss term
                    # for efficient storage of gradient information
                    if family == 'multinomial':
                        if w is None:
                            loss = F.cross_entropy(out, y, reduction='mean')
                        else:
                            loss = (F.cross_entropy(out, y, reduction='none') * w).mean()
                        target = eye[y]  # one-hot encoding
                        logits = F.softmax(out, dim=1)
                    elif family == 'gaussian':
                        if w is None:
                            loss = 0.5 * F.mse_loss(out, y, reduction='mean')
                        else:
                            loss = 0.5 * F.mse_loss(out, y, reduction='none')
                            loss = (loss * w.unsqueeze(1)).mean()
                        target = y
                        logits = out
                    else:
                        raise ValueError(f"Unknown family: {family}")
                    total_loss += loss.item() * X.size(0)

                    # Scalar gradients: BS x NUM_CLASSES
                    a = logits - target
                    if w is not None:
                        a = a * w.unsqueeze(1)
                    a_prev = a_table[idx].to(device)

                    # weight parameter: SAGA step followed by the proximal operator
                    w_grad = (a.unsqueeze(2) * X.unsqueeze(1)).mean(0)
                    w_grad_prev = (a_prev.unsqueeze(2) * X.unsqueeze(1)).mean(0)
                    w_saga = w_grad - w_grad_prev + w_grad_avg
                    weight_new = self.threshold(weight - lr * w_saga, lr, lam)
                    # bias parameter (not regularized)
                    b_grad = a.mean(0)
                    b_grad_prev = a_prev.mean(0)
                    b_saga = b_grad - b_grad_prev + b_grad_avg
                    bias_new = bias - lr * b_saga

                    # update table and running averages
                    a_table[idx] = a.to(table_device)
                    w_grad_avg.add_((w_grad - w_grad_prev) * X.size(0) / n_ex)
                    b_grad_avg.add_((b_grad - b_grad_prev) * X.size(0) / n_ex)

                    if lookbehind is None:
                        dw = (weight_new - weight).norm(p=2)
                        db = (bias_new - bias).norm(p=2)
                        criteria = ch.sqrt(dw ** 2 + db ** 2)
                        if criteria.item() <= tol:
                            return self._saga_state(a_table, w_grad_avg, b_grad_avg)

                    weight.data = weight_new
                    bias.data = bias_new

                saga_obj = (total_loss / n_ex + lam * alpha * weight.norm(p=1)
                            + 0.5 * lam * (1 - alpha) * (weight ** 2).sum())
                obj = saga_obj.item()

                # save amount of improvement
                obj_history.append(obj)
                if obj_best is None or obj + tol < obj_best:
                    obj_best = obj
                    nni = 0
                else:
                    nni += 1

                # Stop if there was no progress for lookbehind epochs
                stop_early = lookbehind is not None and nni >= lookbehind

                nnz = (weight.abs() > 1e-5).sum().item()
                total = weight.numel()
                if verbose and (t % verbose) == 0:
                    if lookbehind is None:
                        logger(
                            f"obj {obj} weight nnz {nnz}/{total} ({nnz / total:.4f}) criteria {criteria:.4f} {dw} {db}")
                    else:
                        logger(
                            f"obj {obj} weight nnz {nnz}/{total} ({nnz / total:.4f}) obj_best {obj_best}")

                if stop_early:
                    logger(
                        f"obj {obj} weight nnz {nnz}/{total} ({nnz / total:.4f}) obj_best {obj_best} [early stop at {t}]")
                    return self._saga_state(a_table, w_grad_avg, b_grad_avg)

            logger(f"did not converge at {nepochs} iterations (criteria {criteria})")
            return self._saga_state(a_table, w_grad_avg, b_grad_avg)

    @staticmethod
    def _make_logger(checkpoint):
        """Return a log callable; with a checkpoint dir, log to ``output.log`` there and stdout."""
        if checkpoint is None:
            return print
        os.makedirs(checkpoint, exist_ok=True)
        file_handler = logging.FileHandler(filename=os.path.join(checkpoint, 'output.log'))
        stdout_handler = logging.StreamHandler(sys.stdout)
        logging.basicConfig(
            level=logging.DEBUG,
            format='[%(asctime)s] %(levelname)s - %(message)s',
            handlers=[file_handler, stdout_handler]
        )
        return logging.getLogger('glm_saga').info

    @staticmethod
    def _optional_loss_and_acc(linear, loader, lam, alpha, preprocess, family):
        """Return (loss, acc) on ``loader`` as floats, or (-1, -1) if ``loader`` is falsy."""
        if not loader:
            return -1, -1
        loss, acc = elastic_loss_and_acc_loader(linear, loader, lam, alpha, preprocess=preprocess,
                                                family=family)
        return loss.item(), acc.item()

    def glm_saga(self, linear, loader, max_lr, nepochs, alpha, dropout, tries,
                 table_device=None, preprocess=None, group=False,
                 verbose=None, state=None, n_ex=None, n_classes=None,
                 tol=1e-4, epsilon=0.001, k=100, checkpoint=None,
                 do_zero=True, lr_decay_factor=1, metadata=None,
                 val_loader=None, test_loader=None, lookbehind=None,
                 family='multinomial', encoder=None, tot_tries=1):
        """Greedily select features by repeatedly walking a regularization path.

        For each new feature, ``linear`` is trained with ``train_saga`` along ``k``
        log-spaced decreasing lambdas (plus lambda = 0 if ``do_zero``). The walk stops as
        soon as ``check_new_feature`` reports a newly used feature. This repeats until
        ``self.nKeep`` features are selected, capped at the total number of features.

        ``dropout`` is accepted for glm-saga compatibility but unused. If ``self.out_dir``
        is set, it takes precedence over ``checkpoint``. With a checkpoint directory, logs
        go to ``output.log`` there, and the parameters at which each new feature appeared
        are saved as ``params{n}.pth``.

        Returns:
            dict with ``path`` (one entry per trained lambda), ``best`` (the entry with the
            lowest validation loss) and ``state`` (SAGA tables), or False if nothing was fit.
        """
        if encoder is not None:
            warnings.warn("encoder argument is deprecated; please use preprocess instead", DeprecationWarning)
            preprocess = encoder
        device = get_device(linear)
        # self.out_dir overrides the argument (previously it was read unconditionally, but
        # it was never set, which raised AttributeError).
        out_dir = getattr(self, 'out_dir', None)
        if out_dir is not None:
            checkpoint = out_dir
        if preprocess is not None and (device != get_device(preprocess)):
            raise ValueError(
                f"Linear and preprocess must be on same device (got {get_device(linear)} and {get_device(preprocess)})")

        if metadata is not None:
            if n_ex is None:
                n_ex = metadata['X']['num_examples']
            if n_classes is None:
                n_classes = metadata['y']['num_classes']
        lam_fac = (1 + (tries - 1) / tot_tries)
        print("Using lam_fac ", lam_fac)
        max_lam = maximum_reg_loader(loader, group=group, preprocess=preprocess, metadata=metadata,
                                     family=family) / max(0.001, alpha) * lam_fac
        min_lam = epsilon * max_lam
        # logspace is base 10 but log is base e so use log10
        lams = ch.logspace(math.log10(max_lam), math.log10(min_lam), k)
        lrs = ch.logspace(math.log10(max_lr), math.log10(max_lr / lr_decay_factor), k)
        if do_zero:
            lams = ch.cat([lams, lams.new_zeros(1)])
            lrs = ch.cat([lrs, lrs.new_ones(1) * lrs[-1]])

        path = []
        best_val_loss = float('inf')
        best_params = None
        logger = self._make_logger(checkpoint)

        self.selected_features = self.selected_features.to(device)
        # Cap the target so that the loop terminates even if nKeep exceeds the feature count.
        n_target = min(self.nKeep, self.selected_features.numel())
        # TODO checkout this change, one iteration per feature
        while int(self.selected_features.sum()) < n_target:
            n_feature_to_keep = int(self.selected_features.sum())
            for i, (lam, lr) in enumerate(zip(lams, lrs)):
                lam = lam * self.lam_Fac
                start_time = time.time()
                state = self.train_saga(linear, loader, lr, nepochs, lam, alpha,
                                        table_device=table_device, preprocess=preprocess, group=group, verbose=verbose,
                                        state=state, n_ex=n_ex, n_classes=n_classes, tol=tol, lookbehind=lookbehind,
                                        family=family, logger=logger)

                with ch.no_grad():
                    loss, acc = elastic_loss_and_acc_loader(linear, loader, lam, alpha, preprocess=preprocess,
                                                            family=family)
                    loss, acc = loss.item(), acc.item()
                    loss_val, acc_val = self._optional_loss_and_acc(linear, val_loader, lam, alpha,
                                                                    preprocess, family)
                    loss_test, acc_test = self._optional_loss_and_acc(linear, test_loader, lam, alpha,
                                                                      preprocess, family)

                    params = {
                        "lam": lam,
                        "lr": lr,
                        "alpha": alpha,
                        "time": time.time() - start_time,
                        "loss": loss,
                        "metrics": {
                            "loss_tr": loss,
                            "acc_tr": acc,
                            "loss_val": loss_val,
                            "acc_val": acc_val,
                            "loss_test": loss_test,
                            "acc_test": acc_test,
                        },
                        "weight": linear.weight.detach().cpu().clone(),
                        "bias": linear.bias.detach().cpu().clone()
                    }
                    path.append(params)
                    # Without a val_loader, loss_val is -1, so the latest entry counts as best.
                    if loss_val < best_val_loss:
                        best_val_loss = loss_val
                        best_params = params
                    nnz = (linear.weight.abs() > 1e-5).sum().item()
                    total = linear.weight.numel()
                    if family == 'multinomial':
                        logger(
                            f"{n_feature_to_keep} Feature ({i}) lambda {lam:.4f}, loss {loss:.4f}, acc {acc:.4f} [val acc {acc_val:.4f}] [test acc {acc_test:.4f}], sparsity {nnz / total} [{nnz}/{total}], time {time.time() - start_time}, lr {lr:.4f}")
                    elif family == 'gaussian':
                        logger(
                            f"({i}) lambda {lam:.4f}, loss {loss:.4f} [val loss {loss_val:.4f}] [test loss {loss_test:.4f}], sparsity {nnz / total} [{nnz}/{total}], time {time.time() - start_time}, lr {lr:.4f}")

                # TODO checkout this change, canceling if new feature is used
                if self.check_new_feature(linear.weight):
                    if checkpoint is not None:
                        ch.save(params, os.path.join(checkpoint, f"params{n_feature_to_keep}.pth"))
                    break
        if best_params is None:
            return False
        return {
            'path': path,
            'best': best_params,
            'state': state
        }

    def check_new_feature(self, weight):
        """Mark every feature with a non-zero weight column as selected.

        Returns:
            True if at least one of those features was not selected before, else False.
        """
        # TODO checkout this change, checking if new feature is used
        used_features = torch.unique(torch.nonzero(weight.detach().cpu())[:, 1])
        if not (~self.selected_features.cpu()[used_features]).any():
            return False
        self.selected_features[used_features.to(self.selected_features.device)] = True
        return True

    def fit(self, feature_loaders, metadata, device):
        """Run the greedy feature selection on precomputed features.

        Args:
            feature_loaders: dict with 'train', 'val' and 'test' loaders yielding
                ``(X, y, idx)`` or ``(X, y, w, idx)`` batches.
            metadata: feature statistics, used for normalization and by glm_saga.
            device: device on which the linear model is created.

        Returns:
            (to_drop, test_acc): indices of the features that were not selected, and the
            test accuracy of the last trained point of the path (None if nothing was fit).
        """
        # TODO checkout this change, glm saga code slightly adapted to return to_drop
        print("Initializing linear model...")
        linear = nn.Linear(self.num_features, self.n_classes).to(device)
        for p in [linear.weight, linear.bias]:
            p.data.zero_()

        print("Preparing normalization preprocess and indexed dataloader")
        preprocess = NormalizedRepresentation(feature_loaders['train'],
                                              metadata=metadata,
                                              device=linear.weight.device)

        print("Calculating the regularization path")
        # Silence matplotlib's debug messages (glm_saga may configure logging at DEBUG level).
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
        result = self.glm_saga(linear,
                               feature_loaders['train'],
                               self.args.lr,
                               self.args.max_epochs,
                               self.selalpha, 0, 1,
                               val_loader=feature_loaders['val'],
                               test_loader=feature_loaders['test'],
                               n_classes=self.n_classes,
                               verbose=self.args.verbose,
                               tol=self.args.tol,
                               lookbehind=self.args.lookbehind,
                               lr_decay_factor=self.args.lr_decay_factor,
                               group=True,
                               epsilon=self.args.lam_factor,
                               metadata=metadata,
                               preprocess=preprocess, tot_tries=1)
        to_drop = np.where(self.selected_features.cpu().numpy() == 0)[0]
        # glm_saga returns False when no point of the path was fit (e.g. nKeep <= 0).
        test_acc = result["path"][-1]["metrics"]["acc_test"] if result else None
        # NOTE(review): re-enables autograd globally; kept from the original code.
        torch.set_grad_enabled(True)
        return to_drop, test_acc


class NormalizedRepresentation(ch.nn.Module):
    """Standardize features with precomputed statistics: ``(X - mean) / max(std, tol)``."""

    def __init__(self, loader, metadata, device='cuda', tol=1e-5):
        """
        Args:
            loader: unused; kept for interface compatibility.
            metadata: dict whose ``metadata['X']['mean']`` and ``metadata['X']['std']``
                hold the per-feature statistics as tensors.
            device: device the statistics are moved to in ``forward``.
            tol: lower bound for the standard deviation, avoiding division by zero.

        Raises:
            ValueError: if ``metadata`` is None (an explicit check, unlike ``assert``,
                survives ``python -O``).
        """
        super().__init__()
        if metadata is None:
            raise ValueError("metadata with feature statistics (metadata['X']['mean'/'std']) is required")
        self.device = device
        self.mu = metadata['X']['mean']
        self.sigma = ch.clamp(metadata['X']['std'], tol)

    def forward(self, X):
        """Return ``X`` standardized per feature, computed on ``self.device``."""
        return (X - self.mu.to(self.device)) / self.sigma.to(self.device)


if __name__ == '__main__':
    # Default args from glm_saga, https://github.com/MadryLab/glm_saga
    # (the parsed namespace is the `args` object expected by FeatureSelectionFitting).
    parser = ArgumentParser()
    parser.add_argument('--dataset', type=str, help='dataset name')
    parser.add_argument('--dataset-type', type=str, help='One of ["language", "vision"]')
    parser.add_argument('--dataset-path', type=str, help='path to dataset')
    parser.add_argument('--model-path', type=str, help='path to model checkpoint')
    parser.add_argument('--arch', type=str, help='model architecture type')
    parser.add_argument('--out-path', help='location for saving results')
    parser.add_argument('--cache', action='store_true', help='cache deep features')
    parser.add_argument('--balance', action='store_true', help='balance classes for evaluation')

    parser.add_argument('--device', default='cuda')
    # type=int so that a seed given on the command line is not passed on as a string
    parser.add_argument('--random-seed', type=int, default=0)
    parser.add_argument('--num-workers', type=int, default=2)
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--val-frac', type=float, default=0.1)
    parser.add_argument('--lr-decay-factor', type=float, default=1)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--alpha', type=float, default=0.99)
    parser.add_argument('--max-epochs', type=int, default=2000)
    parser.add_argument('--verbose', type=int, default=200)
    parser.add_argument('--tol', type=float, default=1e-4)
    parser.add_argument('--lookbehind', type=int, default=3)
    parser.add_argument('--lam-factor', type=float, default=0.001)
    parser.add_argument('--group', action='store_true')
    args = parser.parse_args()
