import torch
import os
import prettytable
from utils.conv_type import ConvMask, Conv1dMask, LinearMask, STRConv, ConvMaskMW, Conv1dMaskMW, LinearMaskMW
from datetime import datetime
import uuid
import numpy as np
import random
import yaml
import torch.nn as nn
import torch.distributed as dist

from fastargs import get_current_config
from fastargs.decorators import param
from typing import Any, Dict, Optional, Tuple 


# Called for its side effect at import time; the returned config object is discarded.
# NOTE(review): presumably this makes sure fastargs' global config exists before the
# @param-decorated functions below are called -- confirm.
get_current_config()


def sync_model_weights_all_reduce(model):
    """
    Average every parameter of ``model`` element-wise across all ranks.

    Each parameter's ``.data`` is summed in place with the collective
    ``dist.all_reduce`` and then divided (in place) by the world size, so every
    rank ends up holding the mean. All ranks must call this together.
    """
    for tensor in (p.data for p in model.parameters()):
        # Sum this tensor over every process in the default group...
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        # ...then turn the sum into a mean.
        tensor.div_(dist.get_world_size())

def rescale_mw_model(model):
    """Call ``rescale_mw()`` on every ConvMaskMW / Conv1dMaskMW / LinearMaskMW in ``model``.

    Modules are visited in ``model.modules()`` order; all other module types are
    left untouched.
    """
    mw_layer_types = (ConvMaskMW, Conv1dMaskMW, LinearMaskMW)
    for layer in filter(lambda mod: isinstance(mod, mw_layer_types), model.modules()):
        layer.rescale_mw()

def _squared_geometric_anneal(step, total_steps, floor, ceil):
    """Evaluate ``floor * (ceil / floor) ** (t ** 2)`` at ``t = step / (total_steps - 1)``.

    The value is ``floor`` at the first step and ``ceil`` at the last; squaring ``t``
    keeps it close to ``floor`` early on before it grows geometrically towards ``ceil``.
    This is the closed form of indexing
    ``floor * (ceil / floor) ** np.linspace(0, 1, total_steps) ** 2`` at ``step``, so each
    call is O(1) instead of allocating a ``total_steps``-long array on every step.

    Args:
        step (int): Current step; negative values count from the end, like the array
            indexing this replaces.
        total_steps (int): Number of steps in the schedule.
        floor (float): Value at the first step.
        ceil (float): Value at the last step.

    Returns:
        float: The scheduled value, or ``0`` when ``floor`` and ``ceil`` are both zero.

    Raises:
        ValueError: If ``floor`` is zero while ``ceil`` is not (the geometric ratio is
            undefined), or if ``floor`` and ``ceil`` have opposite signs (a fractional
            power of a negative ratio is not real).
        IndexError: If ``step`` is outside ``[-total_steps, total_steps)``.
    """
    if floor == 0 and ceil == 0:
        return 0
    if floor == 0:
        raise ValueError(
            f"Annealing floor must be non-zero when ceil is non-zero "
            f"(got floor={floor}, ceil={ceil})."
        )
    ratio = ceil / floor
    if ratio < 0:
        raise ValueError(
            f"Annealing floor and ceil must have the same sign "
            f"(got floor={floor}, ceil={ceil})."
        )
    idx = step + total_steps if step < 0 else step
    if not 0 <= idx < total_steps:
        raise IndexError(f"step {step} is out of range for total_steps={total_steps}")
    # A single-step schedule sits at its start value (linspace(0, 1, 1) == [0.0]).
    t = idx / (total_steps - 1) if total_steps > 1 else 0.0
    return floor * ratio ** (t * t)


@param("prune_params.mw_l1_floor")
@param("prune_params.mw_l1_ceil")
def anneal_scale_mw_l1(step, total_steps, mw_l1_floor=5e-6, mw_l1_ceil=5e-5):
    """Annealed coefficient for the MW "l1" penalty at ``step``.

    ``mw_l1_floor`` / ``mw_l1_ceil`` are injected from ``prune_params`` by fastargs.
    The coefficient rises from ``mw_l1_floor`` (first step) to ``mw_l1_ceil`` (last
    step) along ``floor * (ceil / floor) ** (t ** 2)`` with
    ``t = step / (total_steps - 1)``, and is ``0`` when both bounds are zero.
    See ``_squared_geometric_anneal`` for the errors raised on invalid inputs.
    """
    return _squared_geometric_anneal(step, total_steps, mw_l1_floor, mw_l1_ceil)


@param("prune_params.mw_wd_floor")
@param("prune_params.mw_wd_ceil")
def anneal_scale_mw(step, total_steps, mw_wd_floor=5e-6, mw_wd_ceil=5e-5):
    """Annealed coefficient for the MW weight-decay penalty at ``step``.

    ``mw_wd_floor`` / ``mw_wd_ceil`` are injected from ``prune_params`` by fastargs.
    The coefficient rises from ``mw_wd_floor`` (first step) to ``mw_wd_ceil`` (last
    step) along ``floor * (ceil / floor) ** (t ** 2)`` with
    ``t = step / (total_steps - 1)``, and is ``0`` when both bounds are zero.
    See ``_squared_geometric_anneal`` for the errors raised on invalid inputs.
    """
    return _squared_geometric_anneal(step, total_steps, mw_wd_floor, mw_wd_ceil)

def get_mw_l1(model):
    """Return the masked sum of ``weight ** 2 + m ** 2`` over all ``*MaskMW`` layers.

    For every ConvMaskMW / Conv1dMaskMW / LinearMaskMW module this adds
    ``((weight ** 2 + m ** 2) * mask).sum()``. The terms are not detached, so the
    result can be back-propagated into ``weight`` and ``m``.
    NOTE(review): since (w**2 + m**2) / 2 >= |w * m|, this is presumably a smooth
    surrogate for an L1 penalty on the product ``w * m`` -- confirm with the caller.

    Args:
        model (torch.nn.Module): Model to scan.

    Returns:
        The penalty as a 0-dim tensor, or the int ``0`` if the model has no
        ``*MaskMW`` layers.
    """
    mw_types = (ConvMaskMW, Conv1dMaskMW, LinearMaskMW)
    return sum(
        ((m.weight ** 2 + m.m ** 2) * m.mask).sum()
        for m in model.modules()
        if isinstance(m, mw_types)
    )

def get_mw_wd(model):
    """Return the sum of squares of ``weight * m * mask`` over all ``*MaskMW`` layers.

    For every ConvMaskMW / Conv1dMaskMW / LinearMaskMW module this adds
    ``((weight * m * mask) ** 2).sum()``, i.e. the squared L2 norm of the masked
    product. The terms are not detached, so the result can be back-propagated into
    ``weight`` and ``m``.

    Args:
        model (torch.nn.Module): Model to scan.

    Returns:
        The penalty as a 0-dim tensor, or the int ``0`` if the model has no
        ``*MaskMW`` layers.
    """
    mw_types = (ConvMaskMW, Conv1dMaskMW, LinearMaskMW)
    return sum(
        ((m.weight * m.m * m.mask) ** 2).sum()
        for m in model.modules()
        if isinstance(m, mw_types)
    )

def get_model_mask(model):
    """
    Collect the pruning mask of every prunable layer in ``model``.

    ConvMask / Conv1dMask / LinearMask layers contribute their ``mask`` attribute
    (the stored object itself, not a copy); STRConv layers contribute the result of
    ``get_mask()``. Masks appear in ``model.modules()`` traversal order.

    Returns:
        list: The collected masks.
    """
    masks = []
    for layer in model.modules():
        has_mask_attr = isinstance(layer, (ConvMask, Conv1dMask, LinearMask))
        is_str_layer = isinstance(layer, STRConv)
        if has_mask_attr:
            masks.append(layer.mask)
        if is_str_layer:
            masks.append(layer.get_mask())
    return masks

class GradientTracker:
    """Accumulate averaged weight gradients of prunable layers ahead of each RigL step.

    ``register_hooks`` attaches a backward hook to the ``weight`` of every
    ConvMask / Conv1dMask / LinearMask / STRConv module. While
    ``check_if_backward_hook_should_accumulate_grad()`` is True, each hook adds
    ``grad / window`` into ``gradient_history[name]``; on any other backward pass it
    resets that entry to ``None``. The caller advances the step counter with
    ``update_step()`` and reads the running averages with ``get_acc_grads()``.
    """

    def __init__(self, window, dst_every, T_end):
        """
        Initialize the gradient tracker.

        Args:
            window (int): Number of steps right before each RigL step whose gradients
                are accumulated; each gradient is scaled by ``1 / window``.
            dst_every (int): Interval, in steps, between RigL steps.
            T_end (int): Step count from which no more gradients are accumulated.
        """
        self.window = window
        self.dst_every = dst_every
        self.T_end = T_end
        self.step_count = 0

        # module name -> accumulated gradient tensor, or None when not accumulating
        self.gradient_history = {}

    def register_hooks(self, model):
        """
        Register backward hooks to track gradients for each prunable layer.

        ``gradient_history`` is keyed by module name; an entry is initialised to
        ``None`` only if that name is not tracked yet, so re-registering does not
        discard gradients accumulated so far.
        NOTE(review): every call adds another hook per weight, so calling this twice
        on the same model accumulates each gradient twice.

        Args:
            model (torch.nn.Module): The model whose gradients are to be tracked.
        """
        for n, m in model.named_modules():
            if isinstance(m, (ConvMask, Conv1dMask, LinearMask, STRConv)):
                if n not in self.gradient_history:
                    self.gradient_history[n] = None
                m.weight.register_hook(self._create_hook(n))

    @param("experiment_params.dense_grad")
    def _create_hook(self, name, dense_grad=False):
        """
        Create the backward hook for the module called ``name``.

        Args:
            name (str): The name of the module (key into ``gradient_history``).
            dense_grad (bool): Injected from ``experiment_params.dense_grad`` by
                fastargs; currently not used by the hook.

        Returns:
            callable: A hook that records the gradient and returns ``None``, so the
            gradient seen by autograd is left unchanged.
        """
        def hook(grad):
            if self.check_if_backward_hook_should_accumulate_grad():
                # The division allocates a fresh tensor, so no explicit clone is needed.
                contribution = grad.detach() / self.window
                if self.gradient_history[name] is None:
                    self.gradient_history[name] = contribution
                else:
                    self.gradient_history[name] += contribution
            else:
                # Outside the accumulation window: drop any stale history.
                self.gradient_history[name] = None

        return hook

    def update_step(self):
        """Advance the internal step counter by one."""
        self.step_count += 1

    def check_if_backward_hook_should_accumulate_grad(self):
        """
        Return True if the current backward pass falls inside the accumulation window.

        Used by the backward hooks. Always False once ``step_count >= T_end``;
        otherwise True when the next multiple of ``dst_every`` strictly after
        ``step_count`` (the next RigL step) is at most ``window`` steps away.
        """
        if self.step_count >= self.T_end:
            return False

        steps_til_next_rigl_step = self.dst_every - (self.step_count % self.dst_every)
        return steps_til_next_rigl_step <= self.window

    def get_acc_grads(self):
        """Return the live ``gradient_history`` dict (not a copy)."""
        return self.gradient_history

def mask_momentum_rigl(model, optimizer):
    """Multiply the momentum buffer of every masked layer's weight by its mask, in place.

    For each ConvMask / Conv1dMask / LinearMask module whose ``weight`` has optimizer
    state containing a ``momentum_buffer``, entries where ``mask`` is 0 are zeroed.
    Weights without optimizer state, or whose state has no (or a ``None``) momentum
    buffer, are skipped instead of raising ``KeyError``. An example is an optimizer
    such as Adam that keeps different state keys.

    Args:
        model (torch.nn.Module): Model whose masked layers are visited.
        optimizer (torch.optim.Optimizer): Optimizer whose state is modified in place.

    Returns:
        torch.optim.Optimizer: The same ``optimizer`` object.
    """
    for m in model.modules():
        if not isinstance(m, (ConvMask, Conv1dMask, LinearMask)):
            continue
        # .get() on optimizer.state (a defaultdict) does not insert a new entry.
        buf = optimizer.state.get(m.weight, {}).get("momentum_buffer")
        if buf is not None:
            buf *= m.mask
    return optimizer

def get_gradient_norm(model: nn.Module, masked=False) -> float:
    """Compute the squared L2 norm of the gradients of all prunable layers.

    Only ConvMask / Conv1dMask / LinearMask / STRConv modules contribute. For a
    ConvMaskMW / Conv1dMaskMW / LinearMaskMW layer the per-layer gradient is
    ``m * weight.grad + weight * m.grad``, i.e. the first-order change of the product
    ``weight * m`` when both factors take a gradient step.
    NOTE(review): this assumes the layer's effective weight is ``weight * m`` --
    confirm against utils.conv_type.

    All arithmetic stays on the tensors' own device and only the per-layer scalar norm
    is transferred to the host. Moving one operand to the CPU before multiplying it by
    a mask that lives on the GPU would raise a device-mismatch error.

    Args:
        model (torch.nn.Module): Model whose gradients are measured. ``.grad`` must be
            populated on every contributing ``weight`` (and on ``m`` for ``*MaskMW``
            layers), i.e. call this after ``backward()``.
        masked (bool): If True, multiply each gradient by the layer's mask first
            (``get_mask()`` for STRConv, ``mask`` otherwise) so pruned entries are ignored.

    Returns:
        float: Sum over layers of the squared per-layer gradient norms, i.e. the
        *squared* global gradient norm (no square root is taken); ``0.0`` if the model
        has no prunable layers.
    """
    mw_types = (ConvMaskMW, Conv1dMaskMW, LinearMaskMW)
    total_norm = 0.0
    with torch.no_grad():
        for m in model.modules():
            if not isinstance(m, (ConvMask, Conv1dMask, LinearMask, STRConv)):
                continue
            if masked and isinstance(m, STRConv):
                grad = m.get_mask() * m.weight.grad
            else:
                if isinstance(m, mw_types):
                    grad = m.m * m.weight.grad + m.m.grad * m.weight
                else:
                    grad = m.weight.grad
                if masked:
                    # STRConv was handled above, so this is a ConvMask-family layer.
                    grad = m.mask * grad
            total_norm += grad.norm().item() ** 2

    return total_norm

def get_norm(model: nn.Module, masked=False) -> float:
    """Compute the squared L2 norm of the weights of all prunable layers.

    Only ConvMask / Conv1dMask / LinearMask / STRConv modules contribute. For the
    ConvMaskMW / Conv1dMaskMW / LinearMaskMW variants the product ``weight * m`` is
    measured instead of ``weight`` alone.

    All arithmetic stays on the tensors' own device and only the per-layer scalar norm
    is transferred to the host. Moving one operand to the CPU before multiplying it by
    a mask that lives on the GPU would raise a device-mismatch error.

    Args:
        model (torch.nn.Module): Model whose weights are measured.
        masked (bool): If True, multiply each weight by the layer's mask first
            (``get_mask()`` for STRConv, ``mask`` otherwise) so pruned entries are ignored.

    Returns:
        float: Sum over layers of the squared per-layer weight norms, i.e. the
        *squared* global weight norm (no square root is taken); ``0.0`` if the model
        has no prunable layers.
    """
    mw_types = (ConvMaskMW, Conv1dMaskMW, LinearMaskMW)
    total_norm = 0.0
    with torch.no_grad():
        for m in model.modules():
            if not isinstance(m, (ConvMask, Conv1dMask, LinearMask, STRConv)):
                continue
            if masked and isinstance(m, STRConv):
                weight = m.get_mask() * m.weight
            else:
                weight = m.weight * m.m if isinstance(m, mw_types) else m.weight
                if masked:
                    # STRConv was handled above, so this is a ConvMask-family layer.
                    weight = m.mask * weight
            total_norm += weight.norm().item() ** 2

    return total_norm

def reset_weights(
    expt_dir: str, model: torch.nn.Module, training_type: str
) -> torch.nn.Module:
    """Reset (or don't) the model to a saved checkpoint based on the training type.

    - ``"imp"``: load ``<expt_dir>/checkpoints/model_init.pt``.
    - ``"wr"``: load ``<expt_dir>/checkpoints/model_rewind.pt``.
    - anything else (e.g. ``"lrr"``): print a message and return ``model`` unchanged.

    Every key in the checkpoint (not only weights and biases) overwrites the matching
    entry of the model's current state dict. The merged dict is then loaded strictly,
    so checkpoint keys the model does not have raise an error.

    The checkpoint is loaded onto the CPU, and ``load_state_dict`` copies each tensor
    onto the device of the corresponding model tensor. Without ``map_location`` the
    tensors would be restored onto the device they were saved from (typically
    ``cuda:0``) in every process.

    Args:
        expt_dir (str): Directory of the experiment.
        model (torch.nn.Module): The model to reset (modified in place).
        training_type (str): Type of training ('imp', 'wr', or 'lrr').

    Returns:
        torch.nn.Module: The same ``model`` object.
    """
    checkpoint_names = {"imp": "model_init.pt", "wr": "model_rewind.pt"}
    if training_type not in checkpoint_names:
        print("probably LRR, aint nothing to do -- or if PaI, we aren't touching it any case.")
        return model

    original_dict = torch.load(
        os.path.join(expt_dir, "checkpoints", checkpoint_names[training_type]),
        map_location="cpu",
    )
    model_dict = model.state_dict()
    model_dict.update(original_dict)
    model.load_state_dict(model_dict)

    return model


def reset_optimizer(
    expt_dir: str, optimizer: torch.optim.Optimizer, training_type: str
) -> torch.optim.Optimizer:
    """Reset the optimizer state based on the provided training type.

    - ``"imp"`` and ``"lrr"``: load ``<expt_dir>/artifacts/optimizer_init.pt``.
    - ``"wr"``: load ``<expt_dir>/artifacts/optimizer_rewind.pt``.
    - anything else: leave ``optimizer`` unchanged.

    The saved state is loaded onto the CPU. ``Optimizer.load_state_dict`` then moves
    the state tensors onto each parameter's device, so no process materialises the
    state on the device it was saved from (typically ``cuda:0``).

    Args:
        expt_dir (str): Directory of the experiment.
        optimizer (torch.optim.Optimizer): The optimizer to reset (modified in place).
        training_type (str): Type of training ('imp', 'wr', or 'lrr').

    Returns:
        torch.optim.Optimizer: The same ``optimizer`` object.
    """
    checkpoint_names = {
        "imp": "optimizer_init.pt",
        "lrr": "optimizer_init.pt",
        "wr": "optimizer_rewind.pt",
    }
    ckpt_name = checkpoint_names.get(training_type)
    if ckpt_name is not None:
        optimizer.load_state_dict(
            torch.load(os.path.join(expt_dir, "artifacts", ckpt_name), map_location="cpu")
        )

    return optimizer


def reset_only_weights(expt_dir: str, ckpt_name: str, model: torch.nn.Module) -> None:
    """Reset only the weights and biases of the model from a specified checkpoint.

    Loads ``<expt_dir>/checkpoints/<ckpt_name>`` and copies into ``model`` every
    state-dict entry whose key ends with ``.weight`` or ``.bias``. All other entries
    (e.g. masks and other buffers) keep their current values. The merged state dict is
    loaded strictly, so weight/bias keys the model does not have raise an error.

    The checkpoint is loaded onto the CPU, and ``load_state_dict`` copies each tensor
    onto the device of the corresponding model tensor. It is therefore never
    materialised on the device it was saved from (typically ``cuda:0``).

    Args:
        expt_dir (str): Directory of the experiment.
        ckpt_name (str): Checkpoint file name inside ``checkpoints/``.
        model (torch.nn.Module): The model to reset (modified in place).
    """
    checkpoint = torch.load(
        os.path.join(expt_dir, "checkpoints", ckpt_name), map_location="cpu"
    )
    saved_params = {
        key: value
        for key, value in checkpoint.items()
        if key.endswith((".weight", ".bias"))
    }
    model_dict = model.state_dict()
    model_dict.update(saved_params)
    model.load_state_dict(model_dict)


def reset_only_masks(expt_dir: str, ckpt_name: str, model: torch.nn.Module) -> None:
    """Reset only the masks of the model from a specified checkpoint.

    Loads ``<expt_dir>/checkpoints/<ckpt_name>`` and copies into ``model`` every
    state-dict entry whose key ends with ``.mask``. All other entries (weights, biases,
    other buffers) keep their current values. The merged state dict is loaded strictly,
    so mask keys the model does not have raise an error.

    The checkpoint is loaded onto the CPU, and ``load_state_dict`` copies each tensor
    onto the device of the corresponding model tensor. It is therefore never
    materialised on the device it was saved from (typically ``cuda:0``).

    Args:
        expt_dir (str): Directory of the experiment.
        ckpt_name (str): Checkpoint file name inside ``checkpoints/``.
        model (torch.nn.Module): The model to reset (modified in place).
    """
    checkpoint = torch.load(
        os.path.join(expt_dir, "checkpoints", ckpt_name), map_location="cpu"
    )
    saved_masks = {
        key: value for key, value in checkpoint.items() if key.endswith(".mask")
    }
    model_dict = model.state_dict()
    model_dict.update(saved_masks)
    model.load_state_dict(model_dict)


def get_model_density(model: nn.Module) -> float:
    """
    Return the summed mask values divided by the total number of mask entries.

    Both quantities are accumulated over every STRConv layer (via ``get_mask()``) and
    every ConvMask / Conv1dMask / LinearMask layer (via ``mask``); for binary masks
    this is the fraction of kept weights.
    NOTE: when at least one such layer exists the result is a 0-dim tensor despite the
    ``float`` annotation; with none, the ``0 / 0`` raises ``ZeroDivisionError``.
    """
    kept = 0
    numel = 0
    for layer in model.modules():
        layer_masks = []
        if isinstance(layer, STRConv):
            layer_masks.append(layer.get_mask())
        if isinstance(layer, (ConvMask, Conv1dMask, LinearMask)):
            layer_masks.append(layer.mask)
        for mask in layer_masks:
            kept += mask.sum()
            numel += mask.numel()

    return kept / numel

def compute_sparsity(tensor: torch.Tensor) -> Tuple[float, int, int]:
    """Compute the sparsity of a tensor, i.e. the fraction of its elements that are zero.

    Args:
        tensor (torch.Tensor): The tensor (typically a pruning mask) to inspect.

    Returns:
        tuple: ``(sparsity, remaining, total)``, where ``remaining`` is the number of
        non-zero elements (an ``int``), ``total`` is the number of elements and
        ``sparsity = 1 - remaining / total``. An empty tensor yields ``(0.0, 0, 0)``.
    """
    total = tensor.numel()
    if total == 0:
        return 0.0, 0, 0
    # Count non-zero entries rather than summing values: the count is an exact int and
    # stays correct for masks that are not strictly 0/1.
    remaining = int(torch.count_nonzero(tensor).item())
    sparsity = 1.0 - (remaining / total)
    return sparsity, remaining, total


def print_sparsity_info(model: torch.nn.Module, verbose: bool = True) -> float:
    """Print and return the sparsity information of the model.

    Every ConvMask / Conv1dMask / LinearMask layer (via ``mask``) and every STRConv
    layer (via ``get_mask()``) gets one table row with its sparsity, density and
    ``remaining/total`` counts, as reported by ``compute_sparsity``.

    Args:
        model (torch.nn.Module): The model to check.
        verbose (bool, optional): Whether to print the per-layer table and the overall
            sparsity. Default is True.

    Returns:
        float: Overall sparsity across all listed layers, or ``0.0`` if the model has
        no prunable layers (nothing is pruned, and this avoids dividing by zero).
    """
    my_table = prettytable.PrettyTable()
    my_table.field_names = ["Layer Name", "Layer Sparsity", "Density", "Non-zero/Total"]
    total_params = 0
    total_params_kept = 0

    for name, layer in model.named_modules():
        layer_masks = []
        if isinstance(layer, (ConvMask, Conv1dMask, LinearMask)):
            layer_masks.append(layer.mask)
        if isinstance(layer, STRConv):
            layer_masks.append(layer.get_mask())

        for weight_mask in layer_masks:
            sparsity, remaining, total = compute_sparsity(weight_mask)
            my_table.add_row([name, sparsity, 1 - sparsity, f"{remaining}/{total}"])
            total_params += total
            total_params_kept += remaining

    overall_sparsity = (
        1 - (total_params_kept / total_params) if total_params else 0.0
    )

    if verbose:
        print(my_table)
        print("-----------")
        print(f"Overall Sparsity of All Layers: {overall_sparsity:.4f}")
        print("-----------")

    return overall_sparsity


@param("experiment_params.base_dir")
@param("experiment_params.resume_level")
@param("experiment_params.expt_name")
@param("experiment_params.resume_expt_name")
def gen_expt_dir(
    base_dir: str, resume_level: int, expt_name: str, resume_expt_name: Optional[str] = None
) -> str:
    """Create (or reuse) the experiment directory and its standard sub-directories.

    - ``resume_level != 0`` and a non-empty ``resume_expt_name``: reuse
      ``<base_dir>/<resume_expt_name>``.
    - ``resume_level == 0`` and ``resume_expt_name is None``: create a fresh
      ``<base_dir>/experiment_<expt_name>_<YYYYmmdd_HHMMSS>_<6 hex chars>``.
    - any other combination raises ``AssertionError``.

    In both valid cases ``checkpoints/``, ``metrics/epochwise_metrics/`` and
    ``artifacts/`` are ensured to exist. Using ``exist_ok=True`` fills in missing
    sub-directories of an existing experiment (e.g. on resume). It also avoids the
    race of testing ``os.path.exists`` before creating.

    Args:
        base_dir (str): Base directory for experiments.
        resume_level (int): Level to resume from (0 means start from scratch).
        expt_name (str): Name embedded in a freshly created directory name.
        resume_expt_name (str, optional): Name of the experiment to resume from.
            Default is None.

    Returns:
        str: Path to the experiment directory.

    Raises:
        AssertionError: If ``resume_level`` and ``resume_expt_name`` are inconsistent.
    """
    if resume_level != 0 and resume_expt_name:
        expt_dir = os.path.join(base_dir, resume_expt_name)
        print(f"Resuming from Level -- {resume_level}")
    elif resume_level == 0 and resume_expt_name is None:
        current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
        unique_id = uuid.uuid4().hex[:6]
        unique_name = f"experiment_{expt_name}_{current_time}_{unique_id}"
        expt_dir = os.path.join(base_dir, unique_name)
        print(f"Creating this Folder {expt_dir}:)")
    else:
        raise AssertionError(
            "Either start from scratch, or provide a path to the checkpoint :)"
        )

    # makedirs also creates expt_dir and metrics/ as intermediate directories.
    for sub_dir in (
        "checkpoints",
        os.path.join("metrics", "epochwise_metrics"),
        "artifacts",
    ):
        os.makedirs(os.path.join(expt_dir, sub_dir), exist_ok=True)

    return expt_dir


@param("experiment_params.seed")
def set_seed(seed: int, is_deterministic: bool = False) -> None:
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) for reproducibility.

    Args:
        seed (int): Seed value (injected from ``experiment_params.seed`` by fastargs).
        is_deterministic (bool, optional): If True, also force deterministic cuDNN and
            PyTorch algorithms. This is slower, and ops without a deterministic
            implementation raise a ``RuntimeError``. Default is False.
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    if is_deterministic:
        print("Deterministic mode enabled (cudnn.deterministic=True, cudnn.benchmark=False, deterministic algorithms on)")
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        # NOTE(review): on CUDA, deterministic cuBLAS additionally needs the
        # CUBLAS_WORKSPACE_CONFIG environment variable set before CUDA initialises.
        torch.use_deterministic_algorithms(True)
    # Exported for child processes (e.g. dataloader workers started afterwards); the
    # already-running interpreter's str hashing is fixed at startup and is unaffected.
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")


@param("prune_params.prune_method")
@param("prune_params.num_levels")
@param("prune_params.prune_rate")
def generate_densities(
    prune_method: str, num_levels: int, prune_rate: float
) -> Optional[list[float]]:
    """Generate the target density of each pruning level.

    Level ``i`` has density ``(1 - prune_rate) ** i``. For example,
    ``prune_rate = 0.2`` and ``num_levels = 5`` give approximately
    ``[1.0, 0.8, 0.64, 0.512, 0.4096]``.

    Args:
        prune_method (str): Method of pruning. Injected by fastargs; it does not
            affect the returned densities.
        num_levels (int): Number of pruning levels (length of the returned list).
        prune_rate (float): Fraction of the remaining weights removed per level.

    Returns:
        Optional[list[float]]: Densities for levels ``0 .. num_levels - 1``, or
        ``None`` when ``num_levels == 0``.
    """
    if num_levels == 0:
        return None
    return [(1 - prune_rate) ** level for level in range(num_levels)]


def save_config(expt_dir: str, config: Any) -> None:
    """Write ``config.content`` to ``<expt_dir>/expt_config.yaml`` as a two-level mapping.

    ``config.content`` must map ``(section, name)`` key pairs to values; keys of any
    other length fail to unpack. The pairs are regrouped into
    ``{section: {name: value}}`` and dumped as block-style YAML.

    Args:
        expt_dir (str): Directory of the experiment.
        config (Any): Configuration object exposing a ``content`` mapping
            (presumably a fastargs config -- confirm).
    """
    sections: Dict[str, Dict[str, Any]] = {}
    for key_pair, value in config.content.items():
        section, name = key_pair
        sections.setdefault(section, {})[name] = value

    out_path = os.path.join(expt_dir, "expt_config.yaml")
    with open(out_path, "w") as stream:
        yaml.dump(sections, stream, default_flow_style=False)

def save_ckpt(expt_dir, level, epoch, optimizer, model):
    """Save optimizer and model state for ``(level, epoch)`` under ``expt_dir``.

    Writes ``artifacts/optimizer_{level}_{epoch}.pt`` and
    ``checkpoints/model_{level}_{epoch}.pt``; both sub-directories must already exist.
    A wrapper exposing ``.module`` (e.g. ``DistributedDataParallel``) is unwrapped, so
    the saved keys carry no ``module.`` prefix. A plain, unwrapped model is saved
    as-is instead of raising ``AttributeError``.
    NOTE(review): every process that calls this writes the same paths; presumably only
    rank 0 should call it -- confirm at the call site.

    Args:
        expt_dir (str): Directory of the experiment.
        level: Pruning level used in the file names.
        epoch: Epoch used in the file names.
        optimizer (torch.optim.Optimizer): Optimizer whose ``state_dict`` is saved.
        model (torch.nn.Module): Model (optionally wrapped) whose weights are saved.
    """
    net = getattr(model, "module", model)
    torch.save(
        optimizer.state_dict(),
        os.path.join(expt_dir, "artifacts", f"optimizer_{level}_{epoch}.pt"),
    )
    torch.save(
        net.state_dict(),
        os.path.join(expt_dir, "checkpoints", f"model_{level}_{epoch}.pt"),
    )