from typing import Callable, Any
import socket
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from datetime import datetime
from tqdm import tqdm
from torch.utils.tensorboard.writer import SummaryWriter
from easydict import EasyDict
from typing import Any, Optional

from ..Activation import ACTIVATIONS_CLASSES

def LN(x: torch.Tensor, eps: float = 1e-5) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Standardize ``x`` along its last dimension.

    The mean over the last dimension is removed, then the result is divided
    by its standard deviation (``torch.Tensor.std`` with default settings)
    plus ``eps``. Note that ``eps`` is added to the standard deviation itself,
    not to the variance.

    Args:
        x: Input tensor; statistics are taken over ``dim=-1``.
        eps: Small constant added to the standard deviation before dividing.

    Returns:
        A tuple ``(normalized, mean, std)`` where ``mean`` and ``std`` keep the
        reduced dimension (``keepdim=True``) so they broadcast against ``x``.
    """
    mean = x.mean(dim=-1, keepdim=True)
    centered = x - mean
    spread = centered.std(dim=-1, keepdim=True)
    normalized = centered / (spread + eps)
    return normalized, mean, spread

def c2a(x):
    """
    Convert Cartesian coordinates (feature, dim) to hyperspherical angles (feature, dim-1).

    Every row is first scaled to unit length, so only its direction matters.

    Output angle ranges:
    theta_1, ..., theta_{dim-2} in [0, pi]
    theta_{dim-1} in [0, 2*pi)

    Requires a 2-D input with dim >= 2.

    Degenerate cases:
    - When a coordinate and every coordinate after it are zero, the polar
      angle at that position is undefined (0 / 0). It is reported as 0
      instead of NaN.
    - All-zero rows have no direction and are not handled: the division by
      a zero norm yields NaN.
    """
    # x: (f, dim)
    r = torch.norm(x, dim=-1, keepdim=True)
    x = x / r
    _, dim = x.shape
    # norms_suffix[:, k] = sqrt(x_k^2 + x_{k+1}^2 + ... + x_{dim-1}^2)
    norms_suffix = torch.sqrt(torch.flip(torch.cumsum(torch.flip(x ** 2, dims=[1]), dim=1), dims=[1]))

    # The first dim-2 angles are acos(x_k / norms_suffix_k).
    n_polar = max(dim - 2, 0)
    head = x[:, :n_polar]
    tail_norm = norms_suffix[:, :n_polar]

    # Guard the 0 / 0 case. `== 0` (rather than `<= 0`) keeps NaN rows as NaN.
    degenerate = tail_norm == 0
    safe_norm = torch.where(degenerate, torch.ones_like(tail_norm), tail_norm)
    ratio = torch.where(degenerate, torch.ones_like(head), head / safe_norm)

    # Rounding can push the ratio marginally outside [-1, 1], where acos is NaN.
    theta_part = torch.acos(ratio.clamp(-1.0, 1.0))

    # The last angle covers the full circle, so it uses atan2 wrapped to [0, 2*pi).
    theta_last = torch.atan2(x[:, -1], x[:, -2]) % (2 * torch.pi)
    return torch.cat([theta_part, theta_last.unsqueeze(1)], dim=1)

def a2c(angles):
    """
    Convert hyperspherical angles to Cartesian coordinates on the unit sphere.

    angles: (B, m)  # B features, m hyperspherical angles, in [0, 2*pi) or [0, pi)
    return: (B, m+1)  # the corresponding Cartesian coordinates

    Coordinate k (k < m) is sin(a_0) * ... * sin(a_{k-1}) * cos(a_k); the last
    coordinate is the product of all sines. Expects a 2-D input with m >= 1.
    """
    B, _ = angles.shape

    # Compute cos and sin once.
    cos_vals = torch.cos(angles)   # (B, m)
    sin_vals = torch.sin(angles)   # (B, m)

    # Running product of the sines, computed a single time and reused for
    # both the prefix products and the final coordinate.
    sin_cumprod = sin_vals.cumprod(dim=1)  # (B, m)

    # prefix_sin[:, i] is the product of the sines of angles 0..i-1;
    # the first column is all ones (empty product).
    prefix_sin = torch.cat(
        [torch.ones((B, 1), device=angles.device, dtype=angles.dtype),
         sin_cumprod[:, :-1]],
        dim=1
    )  # (B, m)

    # First m coordinates.
    coords = prefix_sin * cos_vals  # (B, m)
    # The last coordinate is the product of all the sines.
    last_coord = sin_cumprod[:, -1:]  # (B, 1)
    return torch.cat([coords, last_coord], dim=1)  # (B, m+1)


def angles_weight(n, m, device="cpu"):
    """
    Build an [n, m] tensor of hyperspherical angles for random unit directions.

    Each row holds the angles of a direction drawn from the unit sphere in
    (m + 1)-dimensional space, obtained by normalising a standard Gaussian sample.

    n: number of samples (rows)
    m: number of angles per sample (the ambient space has m + 1 dimensions)
    device: device on which the random sample is created
    """
    ambient_dim = m + 1
    # Gaussian samples scaled to unit length give the random directions.
    gaussian = torch.randn(n, ambient_dim, device=device)
    lengths = torch.norm(gaussian, dim=-1, keepdim=True)
    unit_vectors = gaussian / lengths
    # Express each direction by its angles: (n, m + 1) -> (n, m).
    return c2a(unit_vectors)

class RotaryEncoder(nn.Module):
    """Bias-free linear map whose matrix is stored as hyperspherical angles.

    The trainable parameter ``weight`` holds angles of shape
    (feature_dim, hidden_dim - 1). On every forward pass they are converted
    with ``a2c`` into a (feature_dim, hidden_dim) Cartesian matrix, which the
    input is multiplied by.
    """

    def __init__(self, feature_dim: int, hidden_dim: int, device=None):
        """
        Args:
            feature_dim: size of the input's last dimension.
            hidden_dim: size of the output's last dimension.
            device: device for the parameter; must not be None.
        """
        super().__init__()
        assert device is not None
        self.feature_dim, self.hidden_dim = feature_dim, hidden_dim
        initial_angles = angles_weight(feature_dim, hidden_dim - 1, device=device)
        self.weight = nn.Parameter(initial_angles)

    def forward(self, x):
        """Multiply ``x`` by the Cartesian matrix rebuilt from the stored angles."""
        cartesian = a2c(self.weight)
        return torch.matmul(x, cartesian)


class TiedTranspose(nn.Module):
    """Decoder that shares its matrix with an existing linear encoder.

    No parameters of its own are created: the wrapped layer's ``weight`` is
    used directly as the right-hand operand of the matrix product.
    """

    def __init__(self, encoder: nn.Linear):
        """
        Args:
            encoder: the linear layer whose weight is reused for decoding.
        """
        super().__init__()
        # Stored under the name ``decoder``; it is the very same module object.
        self.decoder = encoder

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` right-multiplied by the shared weight matrix."""
        shared = self.decoder.weight
        return torch.matmul(x, shared)

    @property
    def weight(self) -> torch.Tensor:
        """The wrapped layer's weight tensor (not a copy)."""
        return self.decoder.weight


class RotaryDecoder(nn.Module):
    """Bias-free decoder whose matrix is stored as hyperspherical angles.

    The trainable parameter ``a_weight`` holds angles of shape
    (feature_dim, hidden_dim - 1). ``a2c`` turns them into a
    (feature_dim, hidden_dim) Cartesian matrix; its transpose is the matrix
    the input is multiplied by.
    """

    def __init__(self, feature_dim: int, hidden_dim: int, device=None):
        """
        Args:
            feature_dim: size of the output's last dimension.
            hidden_dim: size of the input's last dimension.
            device: device for the parameter; must not be None.
        """
        super().__init__()
        assert device is not None
        self.feature_dim, self.hidden_dim = feature_dim, hidden_dim
        initial_angles = angles_weight(feature_dim, hidden_dim - 1, device=device)
        self.a_weight = nn.Parameter(initial_angles)

    @property
    def weight(self):
        """Cartesian decoder matrix of shape (hidden_dim, feature_dim), rebuilt on each access."""
        cartesian = a2c(self.a_weight)
        return cartesian.T

    def forward(self, x):
        """Multiply ``x`` by the Cartesian decoder matrix."""
        return x @ self.weight
    

class RotarySAE(nn.Module):
    """Sparse autoencoder with a linear encoder and a tied or angle-parameterised decoder.

    Pipeline of ``forward``: optional per-sample normalisation (``LN``),
    subtraction of ``pre_bias``, linear encoding plus ``latent_bias``,
    activation, decoding, re-adding ``pre_bias`` and undoing the normalisation.
    """

    def __init__(self, args, feature_dim: int, activation: nn.Module = nn.ReLU()):
        """
        Args:
            args (): The args, coming from the config.toml. This constructor reads
                ``args.autoencoder.tied``, ``args.autoencoder.normalize``,
                ``args.autoencoder.dtype``, ``args.autoencoder.rate`` and
                ``args.exp.device``.
            feature_dim (int): the model middle activation dimension.
            activation (nn.Module): non-linearity applied to the latent
                pre-activations.
        """
        super().__init__()
        # NOTE(security): eval() executes arbitrary expressions. This is only
        # acceptable while config.toml is a trusted file; the values are
        # presumably strings such as "True" / "False" -- confirm against the
        # config before replacing eval with a stricter parser.
        self.tied = eval(args.autoencoder.tied)
        self.normalize = eval(args.autoencoder.normalize)

        self.dtype = args.autoencoder.dtype
        self.device = args.exp.device
        self.rate = int(args.autoencoder.rate)

        # The latent space has `rate * feature_dim` units.
        self.pre_bias = nn.Parameter(torch.zeros(feature_dim, device=self.device))
        self.latent_bias = nn.Parameter(torch.zeros(self.rate * feature_dim, device=self.device))
        self.encoder = nn.Linear(feature_dim, self.rate * feature_dim, device=self.device, bias=False)

        if self.tied:
            # Decoder reuses the encoder weight; no extra parameters.
            self.decoder = TiedTranspose(self.encoder)
        else:
            self.decoder = RotaryDecoder(feature_dim, self.rate * feature_dim, device=self.device)

        self.activation = activation

    def preprocess(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:
        """Optionally normalise ``x`` and return the statistics needed to undo it.

        :param x: input data
        :return: ``(x, info)``; ``info`` is empty when normalisation is off,
            otherwise it holds ``mu`` and ``std`` as returned by ``LN``.
        """
        if not self.normalize:
            return x, dict()
        x, mu, std = LN(x)
        return x, dict(mu=mu, std=std)

    def encode_pre_act(self, x: torch.Tensor, latent_slice: slice = slice(None)) -> torch.Tensor:
        """
        :param x: input data (shape: [batch, n_inputs])
        :param latent_slice: slice of latents to compute
            Example: latent_slice = slice(0, 10) to compute only the first 10 latents.
        :return: autoencoder latents before activation (shape: [batch, n_latents])
        """
        x = x - self.pre_bias
        # Slice the weight rows together with the bias so that both describe
        # the same subset of latents.
        latents_pre_act = F.linear(
            x, self.encoder.weight[latent_slice], self.latent_bias[latent_slice]
        )
        return latents_pre_act

    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:
        """
        :param x: input data (shape: [batch, n_inputs])
        :return: autoencoder latents (shape: [batch, n_latents]) and the
            preprocessing info required by ``decode``
        """
        x, info = self.preprocess(x)
        return self.activation(self.encode_pre_act(x)), info

    def decode(self, latents: torch.Tensor, info: Optional[dict[str, Any]] = None) -> torch.Tensor:
        """
        :param latents: autoencoder latents (shape: [batch, n_latents])
        :param info: preprocessing statistics from ``preprocess`` / ``encode``;
            required (must contain ``std`` and ``mu``) when normalisation is on
        :return: reconstructed data (shape: [batch, n_inputs])
        """
        ret = self.decoder(latents) + self.pre_bias
        if self.normalize:
            assert info is not None
            # Undo the normalisation applied in preprocess.
            ret = ret * info["std"] + info["mu"]
        return ret

    def forward(self, x):
        """Run the full autoencoder.

        :param x: input data (shape: [batch, n_inputs])
        :return: ``(latents_pre_act, latents, recons)``
        """
        x, info = self.preprocess(x)
        latents_pre_act = self.encode_pre_act(x)
        latents = self.activation(latents_pre_act)
        recons = self.decode(latents, info)

        return latents_pre_act, latents, recons

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """Return the module state plus the activation's class name and state.

        Two extra entries are added on top of the regular ``nn.Module`` state:
        ``<prefix>activation`` (class name, a string) and
        ``<prefix>activation_state_dict``. ``from_state_dict`` consumes both.
        """
        # Keyword arguments: passing these positionally to
        # nn.Module.state_dict is deprecated and emits a warning.
        sd = super().state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )
        sd[prefix + "activation"] = self.activation.__class__.__name__
        if hasattr(self.activation, "state_dict"):
            sd[prefix + "activation_state_dict"] = self.activation.state_dict()
        return sd

    @classmethod
    def from_state_dict(
        cls, args, state_dict: dict[str, torch.Tensor], strict: bool = True
    ) -> "RotarySAE":
        """Rebuild an autoencoder from a saved state dict.

        Note that ``state_dict`` is modified in place: the activation entries
        are popped and ``decoder.a_weight`` may be added.

        Args:
            args (EasyDict): coming from the config.toml, controlling the model hyperparameters
            state_dict (dict[str, torch.Tensor]): loaded from the save path;
                must contain ``encoder.weight``.
            strict (bool, optional): forwarded to the activation's state
                loading only. Defaults to True.

        Returns:
            RotarySAE: the SAE model
        """
        _, feature_dim = state_dict["encoder.weight"].shape

        # Retrieve activation
        activation_class_name = state_dict.pop("activation", "ReLU")
        activation_class = ACTIVATIONS_CLASSES.get(activation_class_name, nn.ReLU)
        activation_state_dict = state_dict.pop("activation_state_dict", {})
        if hasattr(activation_class, "from_state_dict"):
            activation = activation_class.from_state_dict(
                activation_state_dict, strict=strict
            )
        else:
            activation = activation_class()
            if hasattr(activation, "load_state_dict"):
                activation.load_state_dict(activation_state_dict, strict=strict)

        autoencoder = cls(args, feature_dim, activation=activation)

        # A `decoder.weight` entry is never produced by this class (the
        # decoder's `weight` is a property, not a parameter), so it is only
        # present in checkpoints written by a different model. Map it onto
        # the angle parameter when there is one; otherwise leave the state
        # dict as it is.
        legacy_weight = state_dict.get("decoder.weight")
        angle_param = getattr(autoencoder.decoder, "a_weight", None)
        if legacy_weight is not None and angle_param is not None:
            n_rows, n_angles = angle_param.shape
            if tuple(legacy_weight.shape) == (n_rows, n_angles + 1):
                # NOTE(review): assumes a tensor with one more column than
                # `a_weight` is a Cartesian decoder matrix whose rows should
                # be converted to angles (which discards the row norms) --
                # confirm against the checkpoint format being migrated.
                with torch.no_grad():
                    legacy_weight = c2a(legacy_weight)
            state_dict["decoder.a_weight"] = legacy_weight

        # strict=False: the checkpoint may carry keys this module does not own
        # (for example the legacy `decoder.weight` handled above).
        autoencoder.load_state_dict(state_dict, strict=False)
        return autoencoder
