import torch
from torch import Tensor
from typing import Optional, Dict, Any, Tuple
from torch import nn
from ..compressed import CompressedTensor
from typing import Optional, Dict, Any, Tuple, List
import random
import glob
import os

class SoftmaxFunction(torch.autograd.Function):
    """Softmax autograd function that can save its output in compressed form.

    The forward pass computes a max-shifted softmax along ``dim``. Only the
    softmax output is kept for the backward pass; when ``compress_kwargs`` is
    given, that output is wrapped in a ``CompressedTensor`` before saving and
    reconstructed again in ``backward``.

    Call via ``SoftmaxFunction.apply(input, dim, dtype, compress_kwargs)``.
    """

    @staticmethod
    @torch.compile
    def forward(
        input: Tensor,
        dim: int,
        dtype: Optional[torch.dtype] = None,
        compress_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tensor:
        """Compute softmax of ``input`` along ``dim``.

        Args:
            input: Tensor to normalize.
            dim: Dimension along which the softmax is taken.
            dtype: If given, ``input`` is cast to this dtype before the
                computation and the result is returned in it; otherwise the
                dtype of ``input`` is used.
            compress_kwargs: Not used here; consumed by ``setup_context``.

        Returns:
            The softmax of ``input``, cast to the compute dtype.
        """
        compute_dtype = input.dtype if dtype is None else dtype
        input_compute = input.to(compute_dtype)
        # Subtract the per-slice maximum so the largest exponent is exp(0).
        shifted = input_compute - torch.amax(input_compute, dim=dim, keepdim=True)
        exp_input = torch.exp(shifted)
        softmax_output = exp_input / torch.sum(exp_input, dim=dim, keepdim=True)
        return softmax_output.to(compute_dtype)

    @staticmethod
    def setup_context(ctx, inputs: Tuple[Tensor, int, Optional[torch.dtype], Optional[Dict]], output: Tensor) -> None:
        """Record on ``ctx`` everything ``backward`` needs.

        Args:
            ctx: Autograd context object for this call.
            inputs: The ``(input, dim, dtype, compress_kwargs)`` tuple that
                was passed to ``forward``.
            output: The tensor returned by ``forward``.

        When ``compress_kwargs`` is not ``None`` the output is saved as
        ``CompressedTensor(output, **kwargs)``, where ``kwargs`` is
        ``compress_kwargs`` without the ``use_optimizer_compress`` and
        ``use_gradient_compress`` entries. The caller's dict is never
        modified.
        """
        input, dim, dtype, compress_kwargs = inputs
        ctx.dim = dim
        ctx.input_dtype = input.dtype
        ctx.output_dtype = output.dtype
        ctx.is_compressed = compress_kwargs is not None
        # Set on both paths so the attribute always exists on ctx.
        ctx.use_optimizer_compress = False

        if compress_kwargs is None:
            ctx.save_for_backward(output)
            return

        # Work on a copy: the same kwargs dict may be reused by the caller,
        # so stripping the flags in place would corrupt it for later use.
        tensor_kwargs = dict(compress_kwargs)
        tensor_kwargs.pop("use_optimizer_compress", None)
        tensor_kwargs.pop("use_gradient_compress", None)
        ctx.save_for_backward(CompressedTensor(output, **tensor_kwargs))

    @staticmethod
    @torch.compile
    def backward(ctx, grad_output: Tensor) -> Tuple[Tensor, None, None, None]:
        """Compute the gradient of the softmax with respect to its input.

        Uses ``grad_input = y * (grad_output - sum(grad_output * y, dim))``
        where ``y`` is the saved softmax output.

        Returns:
            A 4-tuple matching the arguments of ``forward``: the input
            gradient (cast back to the input's dtype) followed by ``None``
            for ``dim``, ``dtype`` and ``compress_kwargs``.
        """
        (output,) = ctx.saved_tensors
        if isinstance(output, CompressedTensor):
            output = output.reconstruct()
        grad_output = grad_output.to(output.dtype)
        dot = torch.sum(grad_output * output, dim=ctx.dim, keepdim=True)
        grad_input = (output * (grad_output - dot)).to(ctx.input_dtype)
        return grad_input, None, None, None
