import json
import logging
from os import path

import torch
from tqdm.auto import tqdm

from .debug_wrapper import DebugContainer

# Module logger. Its level is pinned to INFO so the logger.info(...) progress
# messages in this module are not filtered out at the logger level.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.INFO)

# Optional key in the ADMM JSON config holding the list of parameter names to collect debug tensors for.
DEBUG_KEY_NAME = "debug_tensors"
# File name written inside the destination directory when saving the debug container.
DEBUG_FILE_NAME = "debug_container.pt"

class ADMM:
    """ADMM weight-pruning state: per-parameter U (dual) and Z (auxiliary) variables.

    The pruned parameters and their ratios come from the ``prune_ratios`` mapping
    of a JSON config file. If the config also contains a ``DEBUG_KEY_NAME`` list,
    snapshots of Z / weights / masks for those parameters are collected in a
    ``DebugContainer``.

    Subclasses must implement :meth:`project`.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        file_name: str,
        rho=0.001,
        device="cuda",
    ):
        """Load the pruning config and allocate U and Z for every pruned parameter.

        Args:
            model (torch.nn.Module): PyTorch model; its named parameters are matched
                against the keys of the config's ``prune_ratios`` mapping.
            file_name (str): Path to the JSON config file.
            rho (float): ADMM penalty parameter.
            device: Device on which U and Z are allocated. Defaults to ``"cuda"``,
                the previously hard-coded behaviour; pass e.g. ``"cpu"`` otherwise.

        Raises:
            TypeError: If ``file_name`` is not a str, or the debug entry is not a list.
            ValueError: If a debug parameter name is not a pruned parameter.
        """
        if not isinstance(file_name, str):
            raise TypeError("filename must be a str")

        # Only parsing needs the file handle; close it before walking the model.
        with open(file_name, "r", encoding="utf-8") as stream:
            self._raw_dict = json.load(stream)
        self._prune_ratios = self._raw_dict['prune_ratios']

        self._rho = rho
        self._debug_enabled = False
        self._is_hard_pruned = False
        # Parameter names whose tensors are snapshotted into the debug container.
        self._debug_names = frozenset()

        self._u = {}
        self._z = {}
        for name, weight in model.named_parameters():
            if name not in self._prune_ratios:
                continue
            # Zero-initialise both variables: torch.Tensor(shape) would leave Z as
            # uninitialised memory until initialize() is called.
            self._u[name] = torch.zeros(weight.shape, device=device)
            self._z[name] = torch.zeros(weight.shape, device=device)

        if DEBUG_KEY_NAME in self._raw_dict:
            debug_names = self._raw_dict[DEBUG_KEY_NAME]
            if not isinstance(debug_names, list):
                raise TypeError("Debug key should correspond with a list of parameter names")
            for name in debug_names:
                if name not in self._prune_ratios:
                    raise ValueError("Attempting to get debug information for a non-pruned parameter")

            self._debug_enabled = True
            self._debug_names = frozenset(debug_names)
            logger.info("Debugging enabled for %s", ", ".join(debug_names))
            self._debug_container = DebugContainer(debug_names)

    def _is_debug_tracked(self, name: str) -> bool:
        """Return True when debug snapshots should be recorded for parameter ``name``."""
        # NOTE(review): uses the names the DebugContainer was built from instead of
        # reading one of its attributes (the original code mixed `debugParameters`
        # and `_tracked_parameters`); confirm DebugContainer does not rename entries.
        return self._debug_enabled and name in self._debug_names

    @staticmethod
    def _snapshot(tensor: torch.Tensor) -> torch.Tensor:
        """Return a detached CPU copy of ``tensor`` that later in-place updates cannot alter."""
        # `.cpu()` returns the very same tensor when it already lives on the CPU, so an
        # explicit copy is forced; detaching drops any autograd history.
        return tensor.detach().to("cpu", copy=True)

    def _record_debug(self, name: str, projected: torch.Tensor, weight: torch.Tensor, mask: torch.Tensor):
        """Add (projected, weight, mask) snapshots for ``name`` to the debug container if tracked."""
        if self._is_debug_tracked(name):
            self._debug_container.addTensor(
                name, self._snapshot(projected), self._snapshot(weight), self._snapshot(mask)
            )

    def hard_prune(self, model: torch.nn.Module):
        """
        Prune each weight in the model directly to the threshold rather than by masking.

        The mask and a copy of the unpruned weight are kept for every pruned parameter
        so that :meth:`mask_weights` and :meth:`restore_weights` can be used afterwards.

        Args:
            model (torch.nn.Module): PyTorch model
        """
        logger.info("Performing hard prune.")

        if self._is_hard_pruned:
            logger.warning(
                "hard_prune called while already hard-pruned; the saved original weights will be overwritten."
            )

        self._is_hard_pruned = True
        # Masks feed mask_weights(); original weights feed restore_weights().
        self._masks = dict.fromkeys(self._z)
        self._original_weights = dict.fromkeys(self._z)

        for name, weight in model.named_parameters():
            if name not in self._prune_ratios:
                continue

            # This is no longer a true ADMM projection, so the dual variable is removed
            # from consideration by temporarily zeroing it (on U's own device/dtype, so a
            # GPU project() does not see a CPU tensor). The original author was unsure
            # whether this is correct; it still needs testing.
            saved_u = self._u[name]
            self._u[name] = torch.zeros_like(saved_u)
            try:
                weight_mask, pruned_weights = self.project(weight, name)
            finally:
                # Always restore U, even if project() raises.
                self._u[name] = saved_u

            weight_mask = weight_mask.detach()
            pruned_weights = pruned_weights.detach()

            self._masks[name] = weight_mask
            # Clone so an in-place project() implementation cannot corrupt the backup.
            self._original_weights[name] = weight.detach().clone()

            self._record_debug(name, pruned_weights, weight, weight_mask)

            weight.data = pruned_weights

    def restore_weights(self, model: torch.nn.Module):
        """
        Restore the weights to an unpruned state. This is necessary in order to do end of epoch evaluation for
        ADMM configurations but still be able to continue training.

        Args:
            model (torch.nn.Module): PyTorch model

        Raises:
            RuntimeError: If the model is not currently hard-pruned.
        """
        logger.info("Restoring weight magnitudes.")

        if not self._is_hard_pruned:
            raise RuntimeError("Attempted to restore weights on a model that did not undergo hard-prune.")

        for name, weight in model.named_parameters():
            if name in self._prune_ratios:
                weight.data = self._original_weights[name]

        self._is_hard_pruned = False

    def mask_weights(self, model: torch.nn.Module):
        """
        Used in masked fine-tuning of the model. Must have already called hard_prune on the model at this point.
        This method should be called after the optimizer has updated the weights for a training step.

        Args:
            model (torch.nn.Module): PyTorch model.

        Raises:
            RuntimeError: If hard_prune has not been performed.
        """
        if not self._is_hard_pruned:
            raise RuntimeError("Attempted to mask weights without performing appropriate hard prune")

        for name, weight in model.named_parameters():
            if name in self._prune_ratios:
                weight.data = weight.data * self._masks[name]

    def initialize(self, model: torch.nn.Module):
        """
        Initialize the Z variable for each pruned parameter by projecting its weight.

        Args:
            model (torch.nn.Module) : PyTorch model
        """
        logger.info("Initializing Z")

        for name, weight in model.named_parameters():
            if name not in self._prune_ratios:
                continue

            mask, updated_z = self.project(weight, name)
            # Z is a constant in the ADMM loss; keep it off the autograd graph.
            self._z[name] = updated_z.detach()

            self._record_debug(name, self._z[name], weight, mask)

    def z_u_update(self, model: torch.nn.Module):
        """
        Update the ADMM auxiliary variables as determined by the ADMM instance's weight_pruning rule.

        Args:
            model (torch.nn.Module): PyTorch model
        """
        logger.info("Updating Z")

        for name, weight in model.named_parameters():
            if name not in self._prune_ratios:
                continue

            mask, new_z = self.project(weight, name)  # Euclidean projection
            new_z = new_z.detach()
            self._z[name] = new_z

            # U[k+1] = W[k+1] - Z[k+1] + U[k]. Computed from the detached weight: using the
            # Parameter itself put U on the autograd graph, so loss() would back-propagate
            # through U (and a stale graph chain) and produce wrong gradients.
            self._u[name] = weight.detach() - new_z + self._u[name]

            self._record_debug(name, new_z, weight, mask)

    def loss(self, model: torch.nn.Module):
        """
        Compute the ADMM penalty to append to the model loss
        (somewhat assumed to be cross-entropy but should work for others).

        The penalty is ``sum_i 0.5 * rho * ||W_i - Z_i + U_i||^2`` over pruned parameters.

        Args:
            model (torch.nn.Module): PyTorch model

        Returns:
            loss (torch.Tensor | int): scalar tensor with the summed penalty, or the
            int 0 when the model has no pruned parameters.
        """
        total = 0
        for name, weight in model.named_parameters():
            if name not in self._prune_ratios:
                continue

            residual = weight - self._z[name] + self._u[name]
            total += 0.5 * self._rho * torch.linalg.norm(residual) ** 2

        return total

    def save_debug_information(self, save_path):
        """Save the debug container to ``save_path/DEBUG_FILE_NAME`` with torch.save.

        Raises:
            RuntimeError: If debugging is disabled or the container fails validation.
            FileNotFoundError: If ``save_path`` is not an existing directory.
        """
        if not self._debug_enabled:
            # Ideally this shouldn't ever be triggered
            raise RuntimeError("Can't save debug information if not in debug mode.")

        destination = path.join(save_path, DEBUG_FILE_NAME)

        if not path.isdir(save_path):
            raise FileNotFoundError("Destination directory ({}) does not exist".format(save_path))

        if not self._debug_container.validateStructure():
            raise RuntimeError("Debug container not constructed correctly, aborting.")

        logger.info("Saving debug model container to %s", destination)
        torch.save(self._debug_container, destination)

    def get_prune_ratio(self, name: str):
        """Return the configured pruning ratio for parameter ``name`` (KeyError if not pruned)."""
        return self._prune_ratios[name]

    @property
    def u(self):
        """Dictionary for U auxiliary ADMM variables."""
        return self._u

    @property
    def z(self):
        """Dictionary for Z auxiliary ADMM variables."""
        return self._z

    @property
    def rho(self):
        """ADMM rho parameter."""
        return self._rho

    @property
    def prune_ratios(self):
        """Dictionary from named parameter to associated pruning ratio."""
        return self._prune_ratios

    @property
    def does_accumulate(self):
        """Whether or not the weight pruning system needs gradients."""
        return False

    @property
    def does_debug(self):
        """Whether or not we have been collecting additional information."""
        return self._debug_enabled

    @property
    def is_hard_pruned(self):
        """Whether or not the model is currently in a hard-pruned state."""
        return self._is_hard_pruned

    def project(self, weight: torch.Tensor, name: str):
        """Project ``weight`` onto the pruning constraint set for parameter ``name``.

        Must be overridden by subclasses. Callers in this class unpack the result as
        ``(mask, projected)``: the mask is later multiplied into the weights by
        :meth:`mask_weights`, and ``projected`` becomes Z (or the hard-pruned weight).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("ADMM subclasses must implement project(weight, name)")