import os
import copy
import torch
import math
import time
import pickle
import wandb
import gzip
from torch.distributed import is_initialized, get_rank, all_reduce, ReduceOp
from tools import get_first_device, get_gpu_mem_usage, block_split, CopyDirection
from torch.optim import Optimizer
import ista_daslab_tools
import ista_daslab_micro_adam
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple, Optional


class NanoAdam(Optimizer):
    """
    1. select weights that have  smallest magnitudes to update. Instead of topk gradients.
    2. dynamically change the update density, and the mask chosen by topk is updated every interval steps.
    """

    def __init__(
        self,
        params,
        lr: float,
        k_init: float = 0.01,
        largest: bool = False,
        betas: Tuple[float, float] = (0.9, 0.999),
        weight_decay: float = 0,
        eps: float = 1e-8,
        log_every: int = 100,
        total_steps: int = 1,
        # {"classifier", "layernorm"}, set(),
        exclude_layers: set = {"classifier", "layernorm"},
        dynamic_density: bool = False,
        mask_interval: int = 100,
        density_interval: int = 100,
        mask_criterion: str = "weights",
    ):
        defaults = dict(lr=lr, weight_decay=weight_decay, eps=eps)
        super(NanoAdam, self).__init__(params, defaults)

        self.lr = lr
        self.k_init = k_init
        self.end_density = 0.04
        self.mask_criterion = mask_criterion
        if dynamic_density:
            assert (
                self.k_init > self.end_density
            ), f"init density {self.k_init} should be larger than end_density {self.end_density}."

        self.k_current = k_init
        self.largest = largest
        self.weight_decay = weight_decay
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.total_steps = total_steps
        self.exclude_layers = set(exclude_layers) if exclude_layers else set()
        self.dynamic_density = dynamic_density
        self.mask_interval = mask_interval if mask_interval > 0 else int(
            total_steps/2)
        self.density_interval = density_interval

        self.model_size = sum(
            [p.numel() for group in self.param_groups for p in group["params"]]
        )

        self.steps = 0  # how many optimization steps were performed so far
        self.log_every = log_every

        self.device = get_first_device()
        self.blocks = ista_daslab_tools.get_sm_count()
        self.d_block_size = (
            ista_daslab_tools.get_max_floats_for_shared_memory_per_thread_block() // 2
        )
        # for a100
        # self.blocks = 108
        # self.d_block_size = 20736
        # for a6000
        # self.blocks = 84
        # self.d_block_size = 6144

        self.reduce_mem = True
        self.log_effective_lr = False
        if not self.reduce_mem:
            self.update_frequencies = {}
            self.updated_params = {}
            for group in self.param_groups:
                for name, p in zip(group["names"], group["params"]):

                    self.update_frequencies[name] = torch.zeros_like(
                        p.data.view(-1), dtype=torch.int
                    )
                    self.updated_params[name] = []

        self.log_microadam_statistics = True

    def _initialize_parameter_state(self, name, p, lr, wd):
        layer_size = p.numel()
        st = self.state[p]

        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
        if is_exclude_layer(name, self.exclude_layers):
            # Use in-place operations with preserved memory format
            st.update(
                {
                    # Exponential moving average of gradient values
                    "exp_avg": torch.zeros_like(p, memory_format=torch.preserve_format),
                    # Exponential moving average of squared gradient values
                    "exp_avg_sq": torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    ),
                }
            )
            if self.log_effective_lr:
                st.update(
                    {
                        "effective_lr": torch.zeros_like(
                            p, memory_format=torch.preserve_format
                        ),

                    }
                )

        else:
            # Precompute reusable values
            d_block_size = min(layer_size, self.d_block_size)

            # Compute block parameters once
            topk_blocks, d_index_topk = block_split(layer_size, d_block_size)
            k_block_many = int(math.ceil(d_block_size * self.k_init))
            k_block_few = int(
                math.ceil((layer_size - d_index_topk) * self.k_init))
            k_index = topk_blocks * k_block_many
            k = topk_blocks * k_block_many + k_block_few

            # Memory-efficient tensor initialization
            st.update(
                {
                    "d": layer_size,
                    "d_block_size": d_block_size,
                    "topk_full_blocks_count": topk_blocks,
                    "d_index_topk": d_index_topk,
                    "k_block_size_many": k_block_many,
                    "k_block_size_few": k_block_few,
                    "k_index": k_index,
                    "k": k,
                    "I": torch.zeros(k, dtype=torch.int16, device=self.device),
                    # Exponential moving average of gradient values
                    "exp_avg": torch.zeros(k, dtype=torch.bfloat16, device=self.device),
                    # Exponential moving average of squared gradient values
                    "exp_avg_sq": torch.zeros(
                        k, dtype=torch.bfloat16, device=self.device
                    ),
                }
            )
            if self.log_effective_lr:
                st.update(
                    {
                        "effective_lr": torch.zeros(
                            k, dtype=torch.bfloat16, device=self.device
                        ),
                    }
                )

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        self.steps += 1
        # self.mask_update_flag = False
        self._initialize_wandb_dir()

        time_start = time.time()
        sparsity_u = self._update_parameters()
        elapsed_step = time.time() - time_start

        self._log_statistics(sparsity_u, elapsed_step)

        return loss

    def _initialize_wandb_dir(self):
        if self.steps == 1:
            rank = (
                torch.distributed.get_rank()
                if torch.distributed.is_initialized()
                else 0
            )
            if rank == 0:
                self.wandb_dir = wandb.run.dir

    def _update_parameters(self):
        sparsity_u = 0
        for group in self.param_groups:
            lr = group["lr"]
            wd = group.get("weight_decay", self.weight_decay)
            for name, p in zip(group["names"], group["params"]):
                if p.grad is None or p is None:
                    continue
                sp_u = self.update_step(p, lr, wd, name=name)

                sparsity_u += sp_u

        return sparsity_u

    def _log_statistics(self, sparsity_u, elapsed_step):
        if self.log_microadam_statistics:
            self._log(sparsity_u, elapsed_step)
        if not self.reduce_mem and self.steps % self.log_every == 0:
            self.update_freq_to_wandb()

    @torch.no_grad()
    def update_step(
        self,
        p,
        lr,
        wd,
        name="",
    ):
        st = self.state[p]
        if not st:
            self._initialize_parameter_state(name, p, lr, wd)
            st = self.state[p]

        log_microadam_statistics = self.log_microadam_statistics and self.steps % self.log_every == 0
        sp_u = 0

        if is_exclude_layer(name, self.exclude_layers):
            exp_avg, exp_avg_sq = st["exp_avg"], st["exp_avg_sq"]
            grad = p.grad

            # Apply weight decay directly to the parameter (AdamW-style decay)
            if wd > 0:
                p.data.mul_(1 - lr * wd)

            # Update biased first and second moment estimates
            # m_t = β1 * m_t-1 + (1 - β1) * g_t
            exp_avg.mul_(self.beta1).add_(grad, alpha=1 - self.beta1)
            # v_t = β2 * v_t-1 + (1 - β2) * g_t^2
            exp_avg_sq.mul_(self.beta2).addcmul_(
                grad, grad, value=1 - self.beta2)

            # Bias correction
            bias_correction1 = 1 - self.beta1**self.steps
            bias_correction2_sqrt = (1 - self.beta2**self.steps) ** 0.5

            # Compute step size
            adapted_lr = lr / bias_correction1
            denom = exp_avg_sq.sqrt().div_(bias_correction2_sqrt).add_(self.eps)
            # Parameter update
            p.addcdiv_(exp_avg, denom, value=-adapted_lr)
            if self.log_effective_lr:
                st["effective_lr"] = lr / denom

            if not self.reduce_mem:
                self.update_freq(p_name=name, is_exclude_layer=True)
            if log_microadam_statistics:
                # compute sparsify
                sp_u = (grad == 0).sum()  # check sparsity before zerorizing
        else:
            grad = p.grad.view(-1)
            param_data = p.data.view(-1)
            if self.mask_criterion == "weights":
                target_ = param_data
            elif self.mask_criterion == "gradients":
                target_ = grad
            else:
                raise ValueError(
                    f"mask_criterion {self.mask_criterion} is not supported. "
                    "Please choose from ['weights', 'gradients']."
                )

            density_update_flag = self.update_density(st)

            # !===== method1: store index in int16, then select out the grad =====!
            # d = st["d"]
            # d_block_size = st["d_block_size"]
            # topk_full_blocks_count, d_index_topk = (
            #     st["topk_full_blocks_count"],
            #     st["d_index_topk"],
            # )
            # k_block_size_many = st["k_block_size_many"]
            # k_block_size_few = st["k_block_size_few"]
            # k_index = st["k_index"]
            # k = st["k"]
            # I = st["I"]
            # exp_avg = st["exp_avg"]
            # exp_avg_sq = st["exp_avg_sq"]
            # # STEP 5 + 9 (only for I)
            # # if time to update mask:
            # if self.steps == 1 or density_update_flag or self.steps % self.mask_interval == 0:
            #     I[:k_index] = (
            #         torch.topk(
            #             input=param_data[0:d_index_topk]
            #             .abs()
            #             .view(topk_full_blocks_count, d_block_size),
            #             # example: slice has size 1, but ks[-1] is 4
            #             k=k_block_size_many,
            #             sorted=False,
            #             largest=self.largest,
            #         )
            #         .indices.to(dtype=torch.int16)
            #         .view(-1)
            #     )

            #     if k_block_size_few > 0:  # there is a small block left
            #         I[k_index:] = (
            #             torch.topk(
            #                 input=param_data[d_index_topk:].abs(),
            #                 # example: slice has size 1, but ks[-1] is 4
            #                 k=k_block_size_few,
            #                 sorted=False,
            #                 largest=self.largest,
            #             )
            #             .indices.to(dtype=torch.int16)
            #             .view(-1)
            #         )

            #     if not self.reduce_mem:
            #         # update the param update frequencies
            #         self.update_freq(
            #             name,
            #             I,
            #             d_block_size,
            #             topk_full_blocks_count,
            #             k_block_size_many,
            #             k_block_size_few,
            #             is_exclude_layer=False,
            #         )
            #     # update mask count
            #     if not self.mask_update_flag:
            #         self.mask_update_times += 1
            #         self.mask_update_flag = True

            # # weight decay step
            # if wd > 0:
            #     p.mul_(1 - lr * wd)

            # # Define the row and column indices
            # row_indices = torch.arange(topk_full_blocks_count, dtype=torch.int)
            # col_indices = I[:k_index].view(
            #     topk_full_blocks_count, -1).to(torch.int)
            # I_k_index_int = I[k_index:].to(torch.int)  # Avoid repeated casting

            # # Select the elements
            # chosen_grad_many = grad[0:d_index_topk].view(
            #     topk_full_blocks_count, d_block_size)[row_indices[:, None], col_indices]
            # chosen_grad_few = grad[d_index_topk:][I_k_index_int]

            # # Flatten selected_elements and concatenate with selected_elements_few
            # chosen_grad = torch.cat(
            #     [chosen_grad_many.flatten(), chosen_grad_few])
            # # Update biased first and second moment estimates
            # # m_t = β1 * m_t-1 + (1 - β1) * g_t
            # exp_avg.mul_(self.beta1).add_(chosen_grad, alpha=1 - self.beta1)
            # # v_t = β2 * v_t-1 + (1 - β2) * g_t^2
            # exp_avg_sq.mul_(self.beta2).addcmul_(
            #     chosen_grad, chosen_grad, value=1 - self.beta2
            # )

            # # Bias correction
            # bias_correction1 = 1 - self.beta1**self.steps
            # bias_correction2_sqrt = (1 - self.beta2**self.steps) ** 0.5

            # # Compute step size
            # adapted_lr = lr / bias_correction1
            # denom = exp_avg_sq.sqrt().div_(bias_correction2_sqrt).add_(self.eps)

            # # Modified parameter update
            # update_values = -adapted_lr * (exp_avg / denom)
            # param_data[d_index_topk:][I_k_index_int] += update_values[k_index:]
            # param_data[0:d_index_topk].view(topk_full_blocks_count, d_block_size)[
            #     row_indices[:, None], col_indices] += update_values[:k_index].view(topk_full_blocks_count, -1)
            # p.data = param_data.view(p.data.shape)

            # if log_microadam_statistics:
            #     # compute sparsify
            #     sp_u = d - k  # check sparsity before zerorizin

            # !===== method2: store index in long, then select out the grad =====!
            d = st["d"]
            k = st["k"]
            I = st["I"]
            exp_avg = st["exp_avg"]
            exp_avg_sq = st["exp_avg_sq"]
            d_block_size = st["d_block_size"]
            topk_full_blocks_count, d_index_topk = (
                st["topk_full_blocks_count"],
                st["d_index_topk"],
            )
            k_block_size_many = st["k_block_size_many"]
            k_index = st["k_index"]
            # STEP 5 + 9 (only for I)
            # if time to update mask:
            if self.steps == 1 or density_update_flag or self.steps % self.mask_interval == 0:
                k_block_size_few = st["k_block_size_few"]

                I[:k_index] = (
                    torch.topk(
                        input=target_[0:d_index_topk]
                        .abs()
                        .view(topk_full_blocks_count, d_block_size),
                        k=k_block_size_many,
                        sorted=False,
                        largest=self.largest,
                    )
                    .indices.to(dtype=torch.int16)
                    .view(-1)
                )

                if k_block_size_few > 0:  # there is a small block left
                    I[k_index:] = (
                        torch.topk(
                            input=target_[d_index_topk:].abs(),
                            # example: slice has size 1, but ks[-1] is 4
                            k=k_block_size_few,
                            sorted=False,
                            largest=self.largest,
                        )
                        .indices.to(dtype=torch.int16)
                        .view(-1)
                    )

                if not self.reduce_mem:
                    # update the param update frequencies
                    self.update_freq(
                        name,
                        I,
                        d_block_size,
                        topk_full_blocks_count,
                        k_block_size_many,
                        k_block_size_few,
                        is_exclude_layer=False,
                    )

            # weight decay step
            if wd > 0:
                p.mul_(1 - lr * wd)

            # Create mask for top-k elements
            mask = I.clone().to(torch.long).to(grad.device)
            block_offset = torch.arange(
                topk_full_blocks_count, device=grad.device, dtype=torch.long) * d_block_size

            mask[:k_index] += block_offset.repeat_interleave(k_block_size_many)
            mask[k_index:] += topk_full_blocks_count * d_block_size

            # Extract the chosen gradients
            chosen_grad = grad[mask]

            # Update biased first and second moment estimates
            # m_t = β1 * m_t-1 + (1 - β1) * g_t
            # print(f"exp_avg device: {exp_avg.device}, chosen_grad: {chosen_grad.device}")
            exp_avg.mul_(self.beta1).add_(chosen_grad, alpha=1 - self.beta1)
            # v_t = β2 * v_t-1 + (1 - β2) * g_t^2
            exp_avg_sq.mul_(self.beta2).addcmul_(
                chosen_grad, chosen_grad, value=1 - self.beta2
            )

            # Bias correction
            bias_correction1 = 1 - self.beta1**self.steps
            bias_correction2_sqrt = (1 - self.beta2**self.steps)**0.5

            # Compute step size
            adapted_lr = lr / bias_correction1
            denom = exp_avg_sq.sqrt().div_(bias_correction2_sqrt).add_(self.eps)
            # Modified parameter update
            update_values = -adapted_lr * (exp_avg / denom)
            param_data.index_add_(0, mask, update_values)
            p.data.copy_(param_data.view(p.data.shape))
            if self.log_effective_lr:
                st["effective_lr"] = lr / denom
            if log_microadam_statistics:
                # compute sparsify
                sp_u = d - k  # check sparsity before zerorizin

        return sp_u

    def _log(self, sparsity_u, elapsed_step):
        if self.reduce_mem:
            return
        if self.steps % self.log_every != 0:
            return

        if is_initialized():
            sync_data = torch.tensor(
                [
                    sparsity_u,
                    elapsed_step,
                ],
                dtype=torch.float32,
                requires_grad=False,
                device="cuda",
            )  # correct, loss, size
            all_reduce(sync_data, op=ReduceOp.AVG)
            (
                sparsity_u,
                elapsed_step,
            ) = sync_data
            del sync_data
            torch.cuda.empty_cache()

        if not is_initialized() or get_rank() == 0:
            wandb_data = {
                "step/optimizer_steps": self.steps,
                "step/gpu_mem_usage": get_gpu_mem_usage(),
                "step/sparsity_u": sparsity_u / self.model_size * 100.0,
                "step/elapsed_step": elapsed_step,
                "step/density": self.k_current,
            }
            wandb.log(wandb_data, commit=False)

    def update_freq(
        self,
        p_name="",
        chosen_index=None,
        d_block_size=None,
        topk_full_blocks_count=None,
        k_block_size_many=None,
        k_block_size_few=None,
        is_exclude_layer=False,
    ):
        if is_exclude_layer:
            self.update_frequencies[p_name] += 1
        else:
            device = chosen_index.device if chosen_index is not None else "cpu"
            recovered_index = self.recover_original_indices(
                chosen_index,
                d_block_size,
                topk_full_blocks_count,
                k_block_size_many,
            )
            freq_dtype = self.update_frequencies[p_name].dtype
            increment_value = torch.tensor(1, dtype=freq_dtype, device=device)
            # Use in-place index_put with proper dtype handling
            self.update_frequencies[p_name].index_put_(
                indices=(recovered_index,
                         ), values=increment_value, accumulate=True
            )
            # Explicit cleanup (optional but recommended)
            del recovered_index, increment_value
            torch.cuda.empty_cache() if device.type == "cuda" else None

        # Maintain original counting logic
        with torch.no_grad():
            count = torch.sum(self.update_frequencies[p_name] > 0).cpu()
            self.updated_params[p_name].append(count.item())

    def recover_original_indices(self, chosen_index, d_block_size,
                                 topk_full_blocks_count,
                                 k_block_size_many,):
        device = chosen_index.device
        d_index_topk = topk_full_blocks_count * d_block_size
        k_index = topk_full_blocks_count * k_block_size_many

        # Recover original indices for full blocks
        block_offsets = (
            torch.arange(
                topk_full_blocks_count, device=device, dtype=torch.long
            ).repeat_interleave(k_block_size_many)
            * d_block_size
        )
        # Vectorized index calculation
        full_indices = chosen_index[:k_index].to(torch.long) + block_offsets

        # Recover original indices for the small block
        # Small block processing with direct device placement
        small_indices = chosen_index[k_index:].to(torch.long) + d_index_topk

        # Concatenate without device transfer
        return torch.cat([full_indices.to(torch.long), small_indices.to(torch.long)])

    def update_freq_to_wandb(self):
        if not is_initialized() or get_rank() == 0:
            total_num_updated_params = sum(
                updated_param[-1] for updated_param in self.updated_params.values()
            )
            wandb.log(
                {
                    f"statistics/accumulated number of updated parameters": total_num_updated_params,
                    f"statistics/fraction of parameters updated at least once": 100
                    * (total_num_updated_params / self.model_size),
                },
                commit=False,
            )

    def update_density(self, st):
        if self.dynamic_density and self.steps % self.density_interval == 0:
            self.k_current = (
                self.k_init
                - self.steps * (self.k_init - self.end_density) /
                self.total_steps
            )

            # Recalculate density-related quantities in `st`
            st["k_block_size_many"] = int(
                math.ceil(st["d_block_size"] * self.k_current))
            st["k_block_size_few"] = int(
                math.ceil((st["d"] - st["d_index_topk"]) * self.k_current)
            )  # 0 for d % self.d_block_size = 0
            st["k_index"] = st["topk_full_blocks_count"] * \
                st["k_block_size_many"]
            st["k"] = st["k_index"] + st["k_block_size_few"]

            # Explicitly delete old tensors to free memory
            # if "I" in st:
            st["I"][:st["k"]] = 0
            st["I"] = st["I"][:st["k"]].clone()
            # if "exp_avg" in st:
            st["exp_avg"][:st["k"]] = 0
            st["exp_avg"] = st["exp_avg"][:st["k"]].clone()
            # if "exp_avg_sq" in st:
            st["exp_avg_sq"][:st["k"]] = 0
            st["exp_avg_sq"] = st["exp_avg_sq"][:st["k"]].clone()

            # Call torch.cuda.empty_cache() if tensors are on GPU
            torch.cuda.empty_cache()

            return True
        else:
            return False


def is_exclude_layer(layer_name, exclude_layers):
    """
    Check if a layer should be excluded from sparse updates based on its name.

    NanoAdam gives excluded layers a dense Adam update. Matching is a
    case-insensitive substring test. A layer is excluded when its name
    contains "norm", "classifier" or "lm_head", or any entry of
    ``exclude_layers``.

    Args:
        layer_name (str): The name of the layer (parameter).
        exclude_layers (Iterable[str] | None): Additional name substrings to
            exclude, compared case-insensitively. None is treated as empty.

    Returns:
        bool: True if the layer should be excluded, False otherwise.
    """
    lowered = layer_name.lower()

    # Normalization layers and the final layer (commonly named "classifier"
    # or "lm_head") are always excluded.
    if "norm" in lowered or "classifier" in lowered or "lm_head" in lowered:
        return True

    # Lowercase the user-supplied patterns too; otherwise mixed-case entries
    # could never match the lowercased layer name.
    return any(part.lower() in lowered for part in (exclude_layers or ()))
