import copy
import torch
import cvxpy as cp
from tqdm import tqdm
from train import train_eval_loop
import numpy as np
from Models import cifar_mobilenetv2, imagenet_mobilenetv2
from Layers.layers import replace_activation
from scipy.optimize import fsolve


def sparsity_schedule(sparsity, schedule, cur_epoch, epochs):
    """Return the sparsity level to apply at pruning epoch ``cur_epoch``.

    ``progress`` below denotes ``(cur_epoch + 1) / epochs``.

    Args:
        sparsity: Final sparsity value the schedule is built around.
        schedule: Name of the schedule. One of:
            ``'exponential'`` -- ``sparsity ** progress``;
            ``'linear'`` -- ``1 - (1 - sparsity) * progress``;
            ``'constant'`` -- always ``sparsity``;
            ``'increase'`` -- ``sparsity * progress``.
        cur_epoch: Zero-based index of the current pruning epoch.
        epochs: Total number of pruning epochs.

    Returns:
        The scheduled sparsity value for this epoch.

    Raises:
        ValueError: If ``schedule`` is not one of the names listed above.
            (Previously an unknown name surfaced as an UnboundLocalError.)
    """
    if schedule == 'exponential':
        return sparsity**((cur_epoch + 1) / epochs)
    if schedule == 'linear':
        return 1.0 - (1.0 - sparsity) * ((cur_epoch + 1) / epochs)
    if schedule == 'constant':
        return sparsity
    if schedule == 'increase':
        # Grows linearly from sparsity / epochs up to sparsity.
        return sparsity * (cur_epoch + 1) / epochs
    # There is no default schedule: fail loudly on an unrecognized name.
    raise ValueError(
        "unknown sparsity schedule {!r}; expected one of 'exponential', "
        "'linear', 'constant', 'increase'".format(schedule))


def prune_loop(model,
               loss,
               pruner,
               dataloader,
               device,
               sparsity,
               schedule,
               scope,
               epochs,
               reinitialize=False,
               reinitialize_sparse=True,
               train_mode=False,
               shuffle=False,
               uniform_shuffle=False,
               invert=False,
               verbose=False,
               random=False):
    """Applies the score/mask loop iteratively to a final sparsity level.

    Each epoch scores the model with ``pruner.score``, looks up the sparsity
    for that epoch via ``sparsity_schedule`` and applies ``pruner.mask``.
    The loop stops early once the remaining fraction of parameters reported
    by ``pruner.stats()`` drops below ``sparsity``.

    Args:
        model: Model to prune. MobileNetV2 models get their activations
            swapped to ReLU for the duration of the loop and back to ReLU6
            afterwards.
        loss: Loss passed through to ``pruner.score``.
        pruner: Object providing ``score``, ``mask``, ``stats``, ``invert``,
            ``shuffle`` and ``uniform_shuffle``.
        dataloader: Data passed through to ``pruner.score``.
        device: Device passed to ``pruner.score``; the model is moved to it
            after the loop.
        sparsity: Target fraction of parameters that should remain.
        schedule: Schedule name understood by ``sparsity_schedule``.
        scope: Masking scope passed to ``pruner.mask``. If it contains
            ``'oneshot'`` or equals ``'perturb_global'`` a single epoch is
            run regardless of ``epochs``.
        epochs: Number of pruning epochs; ``0`` makes this a no-op.
        reinitialize: If true, call ``model._initialize_weights`` afterwards.
        reinitialize_sparse: Forwarded as ``sparse_init`` when reinitializing.
        train_mode: Keep the model in train mode (otherwise eval mode).
        shuffle: If true, call ``pruner.shuffle`` after pruning.
        uniform_shuffle: If true, call ``pruner.uniform_shuffle`` after
            pruning.
        invert: If true, call ``pruner.invert()`` after scoring each epoch.
        verbose: Show a progress bar and forward verbosity to ``pruner.mask``.
        random: Unused in this function.

    Returns:
        None. The model and pruner are modified in place.
    """
    if epochs == 0:
        return

    mobilenet_types = (cifar_mobilenetv2.MobileNetV2,
                       imagenet_mobilenetv2.MobileNetV2)

    # replace all ReLU6 with ReLU to help the gradient flow
    if isinstance(model, mobilenet_types):
        replace_activation(model, 'relu')

    # Set model to train or eval mode
    model.train()
    if not train_mode:
        model.eval()
    remaining_params, total_params = pruner.stats()
    if 'oneshot' in scope or scope == 'perturb_global':
        epochs = 1

    for_range = range(epochs)
    if verbose:
        for_range = tqdm(for_range)
    for epoch in for_range:
        pruner.score(model, loss, dataloader, device, double=True)
        sparse = sparsity_schedule(sparsity, schedule, epoch, epochs)
        if invert:
            pruner.invert()
        pruner.mask(sparse, scope, verbose)
        remaining_params, total_params = pruner.stats()
        # Already below the target: further epochs would only over-prune.
        if remaining_params / total_params < sparsity:
            print(f'early stop at {epoch}/{epochs}')
            break
    model.to(device)

    # Reinitialize weights
    if reinitialize:
        model._initialize_weights(sparse_init=reinitialize_sparse)

    # Shuffle masks
    if shuffle:
        # NOTE(review): this previously read `mask_scope == 'filter'`, but no
        # `mask_scope` name is defined in this function, so shuffle=True
        # raised NameError. `scope` is the only scope argument available
        # here -- confirm this is the intended value.
        pruner.shuffle(filter=scope == 'filter')
    if uniform_shuffle:
        pruner.uniform_shuffle()

    # replace all ReLU back with ReLU6
    if isinstance(model, mobilenet_types):
        replace_activation(model, 'relu6')

    # Confirm sparsity level
    remaining_params, total_params = pruner.stats()
    if np.abs(remaining_params - total_params * sparsity) >= 10:
        print("WARNING: {} prunable parameters remaining, expected {}".format(
            remaining_params, total_params * sparsity))


def lt_prune(model,
             new_model,
             pruner,
             new_pruner,
             sparsity,
             scope,
             loss,
             optimizer,
             scheduler,
             train_loader,
             test_loader,
             device,
             epochs,
             verbose,
             result_dir,
             resume=None):
    """Prune ``model`` globally and transfer the result onto ``new_model``.

    Without ``resume``, the state of ``model`` is saved as ``init_model.pt``,
    the model is trained with ``train_eval_loop`` and saved again as
    ``trained_model.pt`` (both under ``result_dir``). With ``resume``, those
    two checkpoints are loaded from the ``resume`` directory instead:
    ``init_model.pt`` into ``new_model`` and ``trained_model.pt`` into
    ``model``.

    ``model`` is then scored and masked with scope ``'global'``. If ``scope``
    is ``'global'`` the resulting ``weight_mask`` tensors are copied onto the
    matching modules of ``new_model``; otherwise the per-layer remaining
    fractions are handed to ``new_pruner`` as ``sparsity_list`` and
    ``new_pruner.mask`` is called with ``scope``.

    Returns:
        ``new_model``.
    """
    init_ckpt = "{}/init_model.pt"
    trained_ckpt = "{}/trained_model.pt"

    if resume is not None:
        # Reuse checkpoints from an earlier run instead of training again.
        init_state = torch.load(init_ckpt.format(resume))
        new_model.load_state_dict(init_state)
        trained_state = torch.load(trained_ckpt.format(resume))
        model.load_state_dict(trained_state)
    else:
        torch.save(model.state_dict(), init_ckpt.format(result_dir))
        train_eval_loop(model, loss, optimizer, scheduler,
                        train_loader, test_loader, device, epochs, verbose)
        torch.save(model.state_dict(), trained_ckpt.format(result_dir))

    # only for storing the model and params for later analysis
    new_pruner.score(new_model, loss, train_loader, device)

    # The trained model is always pruned with the global scope.
    pruner.score(model, loss, train_loader, device)
    pruner.mask(sparsity, 'global', verbose)

    if scope != 'global':
        # Fraction of entries kept in each layer's mask.
        kept_fractions = [
            mask.sum().item() / mask.numel()
            for mask, _ in pruner.masked_parameters
        ]
        print(kept_fractions)
        new_pruner.sparsity_list = kept_fractions
        new_pruner.mask(sparsity, scope, verbose)
    else:
        # Walk both models in lockstep and copy every mask across.
        paired_modules = zip(new_model.named_modules(), model.named_modules())
        for (_, target_module), (_, source_module) in paired_modules:
            if not hasattr(source_module, 'weight_mask'):
                continue
            target_module.weight_mask.copy_(source_module.weight_mask)

    del model
    return new_model


def single_solve(alpha, compression):
    """Solve for per-entry ratios under a single weighted budget.

    Maximizes ``sum(log(x))`` over a vector ``x`` with one entry per element
    of ``alpha``, subject to ``alpha @ x <= alpha.sum() * compression`` and
    ``1e-4 <= x <= 1``.

    Args:
        alpha: Per-entry weights (converted with ``np.array``).
        compression: Multiplier applied to ``alpha.sum()`` to form the budget.

    Returns:
        The ``.value`` of the CVXPY variable after ``solve()``.
    """
    weights = np.array(alpha)
    ratios = cp.Variable(weights.size)
    total_budget = weights.sum() * compression

    within_budget = weights @ ratios <= total_budget
    upper_bound = ratios <= 1
    lower_bound = ratios >= 1e-4

    log_sum = cp.sum(cp.log(ratios))
    problem = cp.Problem(cp.Maximize(log_sum),
                         [within_budget, upper_bound, lower_bound])
    problem.solve()
    return ratios.value


def double_solve(alpha, beta, compression, flops_compression, expand=False, expand_limit=None):
    """Solve for per-entry ratios under two weighted budgets.

    Maximizes ``sum(log(x))`` over a vector ``x`` with one entry per element
    of ``alpha``, subject to

    * ``alpha @ x <= alpha.sum() * compression``
    * ``beta @ x <= beta.sum() * flops_compression``
    * ``x >= 1e-4``
    * an upper bound on ``x``: ``1`` when ``expand`` is false, ``expand_limit``
      when ``expand`` is true and a limit is given, and none at all when
      ``expand`` is true and ``expand_limit`` is ``None``.

    Args:
        alpha: Weights for the first budget (converted with ``np.array``).
        beta: Weights for the second budget; must have the same shape as
            ``alpha``.
        compression: Multiplier applied to ``alpha.sum()``.
        flops_compression: Multiplier applied to ``beta.sum()``.
        expand: Whether the default upper bound of 1 is lifted.
        expand_limit: Upper bound used when ``expand`` is true; ``None`` means
            unbounded above.

    Returns:
        The ``.value`` of the CVXPY variable after ``solve()``.
    """
    alpha = np.array(alpha)
    beta = np.array(beta)
    assert alpha.shape == beta.shape, 'alpha and beta must have the same shape'

    ratios = cp.Variable(alpha.size)
    param_budget = alpha.sum() * compression
    flops_budget = beta.sum() * flops_compression

    # Pick the upper bound for this mode; None means "no upper bound".
    upper = expand_limit if expand else 1

    constraints = [
        alpha @ ratios <= param_budget,
        beta @ ratios <= flops_budget,
    ]
    if upper is not None:
        constraints.append(ratios <= upper)
    constraints.append(ratios >= 1e-4)

    problem = cp.Problem(cp.Maximize(cp.sum(cp.log(ratios))), constraints)
    problem.solve()
    return ratios.value
