import pdb
from time import time
from typing import List, Callable, Union

import einops
import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch import nn
import tinycudann as tcnn
from torch.cuda.amp import custom_bwd, custom_fwd
from model.encoder.transformer import CrossAttn, SelfAttn
from utils.model_tools import hook_fn_decorator
import numpy as np
import math
from pdb import set_trace as bb

def trunc_exp(x, max_val=15.0):
    """Exponentiate ``x`` after capping it from above at ``max_val``.

    Capping the argument bounds the result by ``exp(max_val)``.

    Args:
        x: Input tensor.
        max_val: Upper bound applied to ``x`` before exponentiation.

    Returns:
        Tensor with the same shape as ``x``.
    """
    capped = x.clamp(max=max_val)
    return capped.exp()

class Embedder:
    """Frequency encoding built from a list of periodic functions.

    The configuration is taken from keyword arguments, which are kept in
    ``self.kwargs``. Keys read by this class:

    * ``input_dims``    - number of input coordinates (used for ``out_dim``).
    * ``max_freq_log2`` - log2 of the highest frequency.
    * ``num_freqs``     - number of frequency bands.
    * ``log_sampling``  - if true, bands are ``2 ** linspace(0, max_freq_log2)``;
      otherwise they are spaced linearly between ``1`` and ``2 ** max_freq_log2``.
    * ``periodic_fns``  - callables applied to every scaled coordinate.

    Any other key (for example ``include_input``) is stored but never read here.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        """Compute ``self.freq_bands`` and ``self.out_dim`` from the config."""
        cfg = self.kwargs
        n_coords = cfg['input_dims']
        top_exponent = cfg['max_freq_log2']
        band_count = cfg['num_freqs']

        if cfg['log_sampling']:
            exponents = torch.linspace(0., top_exponent, steps=band_count)
            self.freq_bands = 2. ** exponents
        else:
            lowest, highest = 2. ** 0., 2. ** top_exponent
            self.freq_bands = torch.linspace(lowest, highest, steps=band_count)

        n_fns = len(cfg['periodic_fns'])
        self.out_dim = n_coords * len(self.freq_bands) * n_fns

    def embed(self, inputs: torch.Tensor):
        """Encode ``inputs`` along its last axis.

        Features are ordered coordinate-major, then by frequency, then by
        periodic function, and stacked on a new trailing axis that replaces
        the coordinate axis.
        """
        periodic_fns = self.kwargs['periodic_fns']
        features = [
            fn(inputs[..., axis] * freq * torch.pi)
            for axis in range(inputs.shape[-1])  # every coordinate
            for freq in self.freq_bands          # every frequency band
            for fn in periodic_fns
        ]
        return torch.stack(features, dim=-1)


def get_embedder(multires, include_input=False, i=0):
    """Build a sin/cos frequency encoder for 3-D inputs.

    Args:
        multires: Number of frequency bands; the highest band is
            ``2 ** (multires - 1)``.
        include_input: Forwarded to ``Embedder`` as a config entry.
        i: When ``-1`` the encoding is disabled and an identity module is
            returned instead.

    Returns:
        ``(embed_fn, out_dim)`` where ``embed_fn`` maps a tensor to its
        encoding and ``out_dim`` is the encoded feature size (``3`` for the
        identity case).
    """
    if i == -1:
        return nn.Identity(), 3

    encoder = Embedder(
        include_input=include_input,
        input_dims=3,
        max_freq_log2=multires - 1,
        num_freqs=multires,
        log_sampling=True,
        periodic_fns=[torch.sin, torch.cos],
    )

    def embed(x, eo=encoder):
        return eo.embed(x)

    return embed, encoder.out_dim

class TargetNet(nn.Module):
    """Base class for networks whose weights are produced by a hypernetwork.

    Subclasses are expected to provide ``get_submodules()`` (an iterable of
    the sub-modules whose weights are generated) and to override
    ``construct_opt_blocks``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

    def construct_opt_blocks(
        self, ftask_dim, weight_dim, deriv_hidden_dim, driv_num_layers, *args, **kwargs
    ):
        """Build the optimisation blocks for this network (subclass hook)."""
        raise NotImplementedError

    def topo_order(self):
        raise NotImplementedError

    def reverse_topo_order(self):
        raise NotImplementedError

    def get_submodule_names(self) -> List[str]:
        """Return the registered names of ``get_submodules()``, in that order.

        The result is cached in ``self.submodule_names`` after the first
        successful call. A sub-module that is not registered on this module
        is omitted from the list.
        """
        if not hasattr(self, "submodule_names"):
            # named_modules() yields each module object once, so an id lookup
            # returns the same name as scanning for the first identity match.
            name_of = {id(module): name for name, module in self.named_modules()}
            names = [
                name_of[id(subm)]
                for subm in self.get_submodules()
                if id(subm) in name_of
            ]
            # Cache only once the list is complete: if get_submodules() raises,
            # no partial/empty list is left behind for later calls to return.
            self.submodule_names = names
        return self.submodule_names

    def merge_submodule_weights(self, weight_dicts):
        """
        convert the weight dict of the submodules to the weight dict of the target net.

        Each key of the i-th dict is prefixed with the i-th sub-module name
        and a dot, and all entries are merged into a single dict.
        """
        weight_dict = {}
        for sub_name, wd in zip(self.get_submodule_names(), weight_dicts):
            for k, v in wd.items():
                weight_dict[sub_name + "." + k] = v

        return weight_dict


class MLP(nn.Module):
    """
    Stack of linear layers with a configurable depth and hidden width.

    With ``num_layers == 1`` the network is a single ``in_dim -> out_dim``
    linear map. Otherwise it is ``in_dim -> hidden_dim``, followed by
    ``num_layers - 2`` square hidden layers, followed by
    ``hidden_dim -> out_dim``.
    """

    def __init__(
        self, in_dim, out_dim, hidden_dim, num_layers, activation=F.leaky_relu
    ):
        super().__init__()
        if num_layers == 1:
            widths = [in_dim, out_dim]
        else:
            n_square = max(num_layers - 2, 0)
            widths = [in_dim, hidden_dim] + [hidden_dim] * n_square + [out_dim]
        self.fcs = nn.ModuleList(
            [nn.Linear(src, dst) for src, dst in zip(widths[:-1], widths[1:])]
        )
        self.activation = activation

    def forward(self, x):
        """Apply the layers; the final layer is linear (no activation).

        Non-final layers whose input and output widths match add a skip
        connection around the activated output.
        """
        *body, head = self.fcs
        for layer in body:
            activated = self.activation(layer(x))
            if layer.in_features == layer.out_features:
                x = activated + x
            else:
                x = activated
        return head(x)


class FIN_FOUT(nn.Module):
    """Boundary block placed at the input and output ends of an opt-block chain.

    It runs a ``SelfAttn`` network on whatever it is given and caches the
    result so that linked blocks can read it through ``get_zin`` (forward
    direction) or ``get_dldin`` (backward direction).

    Positional args: ``(in_dim, out_dim)``.
    Keyword args: ``hidden_dim`` (default 1024), ``num_layers`` (default 1).
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.in_dim, self.out_dim = args[0], args[1]
        self.hidden_dim = kwargs.get("hidden_dim", 1024)
        self.num_layers = kwargs.get("num_layers", 1)

        # Graph neighbours; plain lists, so they are not registered as sub-modules.
        self.next_blocks = []
        self.prev_blocks = []
        self.net = SelfAttn(
            self.in_dim, hidden_dim=self.hidden_dim, num_layers=self.num_layers
        )

    def add2in_dim(self, dim):
        """Dummy function, does not change anything."""

    def link(self, next_block):
        """Register ``next_block`` as a successor and this block as its predecessor."""
        next_block.add2in_dim(self.out_dim)
        self.next_blocks.append(next_block)
        next_block.prev_blocks.append(self)

    def pseudo_forward(self, x):
        """Run the network on ``x`` and cache the output as ``z_in``."""
        self.z_in = self.net(x)
        return self.z_in

    def pseudo_backward(self, x):
        """Run the network on ``x`` and cache the output as ``dldin``."""
        self.dldin = self.net(x)
        return self.dldin

    def get_zin(self):
        return self.z_in

    def get_dldin(self, prev_blk):
        # The same cached tensor is returned whichever predecessor asks.
        return self.dldin


class ModuleEncoder(nn.Module):
    """Encode a module's ``weight`` matrix into a sequence of token embeddings.

    The weight has shape ``[out_dim, in_dim]`` (read from the ``"weight"``
    entry of ``target_net.named_parameters()``).

    * ``column_grouping=False``: every row is a token, i.e. ``out_dim`` tokens
      of length ``in_dim``.
    * ``column_grouping=True``: every column is a token, i.e. ``in_dim``
      tokens of length ``out_dim``.

    Args:
        target_net: Module exposing a parameter named ``"weight"``.
        weight_dim: Size of each produced token embedding.
        hidden_dim: Hidden width of the per-token MLP.
        num_layers: Depth of the per-token MLP.
        column_grouping: Selects column tokens instead of row tokens.
    """

    def __init__(
        self, target_net, weight_dim, hidden_dim, num_layers, column_grouping=False
    ):
        super().__init__()
        self.name_shape_dict = {k: v.shape for k, v in target_net.named_parameters()}
        weight_shape = self.name_shape_dict["weight"]
        self.out_dim = weight_shape[-2]
        self.in_dim = weight_shape[-1]
        self.weight_dim = weight_dim
        # Remembered so that forward() can lay the weight out to match the MLP.
        self.column_grouping = bool(column_grouping)
        if self.column_grouping:
            self.num_tokens, token_len = self.in_dim, self.out_dim
        else:
            self.num_tokens, token_len = self.out_dim, self.in_dim
        self.split_mlp = MLP(
            in_dim=token_len,
            out_dim=weight_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
        )

        # Learned per-token positional embedding.
        self.pe = nn.Parameter(torch.randn(1, self.num_tokens, weight_dim))
        self.norm = nn.LayerNorm(weight_dim)

    def forward(self, weight_dict):
        """Return token embeddings of shape ``[B, num_tokens, weight_dim]``.

        ``weight_dict["weight"]`` may be ``[out_dim, in_dim]`` or carry a
        leading batch axis.
        """
        W = weight_dict["weight"]
        if W.ndim == 2:
            W = W.unsqueeze(0)

        # With column grouping the MLP and positional embedding are sized for
        # [B, in_dim, out_dim], so the columns must be moved to the token axis.
        tokens = W.transpose(-1, -2) if self.column_grouping else W
        tokens = self.split_mlp(tokens)  # [B, num_tokens, weight_dim]
        tokens = tokens + self.pe
        tokens = self.norm(tokens)
        return tokens  # shape = [B, num_tokens, weight_dim]



class ModuleDecoder(nn.Module):
    """Decode token embeddings back into a module's ``weight`` matrix.

    The target weight has shape ``[out_dim, in_dim]`` (read from the
    ``"weight"`` entry of ``target_net.named_parameters()``).

    * ``column_grouping=False``: each of the ``out_dim`` tokens is decoded
      into one row of length ``in_dim``.
    * ``column_grouping=True``: each of the ``in_dim`` tokens is decoded into
      one column of length ``out_dim``.

    Args:
        target_net: Module exposing a parameter named ``"weight"``.
        weight_dim: Size of each incoming token embedding.
        hidden_dim: Hidden width of the per-token MLP.
        num_layers: Depth of the per-token MLP.
        column_grouping: Selects column tokens instead of row tokens.
    """

    def __init__(
        self, target_net, weight_dim, hidden_dim, num_layers, column_grouping=False
    ):
        super().__init__()
        self.name_shape_dict = {k: v.shape for k, v in target_net.named_parameters()}
        weight_shape = self.name_shape_dict["weight"]

        self.out_dim = weight_shape[-2]
        self.in_dim = weight_shape[-1]
        # Remembered so that forward() can restore the [out_dim, in_dim] layout.
        self.column_grouping = bool(column_grouping)
        if self.column_grouping:
            self.num_tokens, token_len = self.in_dim, self.out_dim
        else:
            self.num_tokens, token_len = self.out_dim, self.in_dim
        self.post_mlp = MLP(
            in_dim=weight_dim,
            out_dim=token_len,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            activation=F.relu,
        )

        self.norm = nn.LayerNorm(weight_dim)

    def forward(self, weight_vec):
        """Return ``{"weight": W}`` with ``W`` shaped ``[B, out_dim, in_dim]``."""
        weight_vec = self.norm(weight_vec)
        W = self.post_mlp(weight_vec)  # [B, num_tokens, token length]

        if self.column_grouping:
            # Tokens are columns here: [B, in_dim, out_dim] -> [B, out_dim, in_dim].
            # A plain view would reinterpret memory instead of transposing.
            W = W.transpose(-1, -2)

        # reshape (not view) because the transposed tensor is not contiguous.
        return {"weight": W.reshape(-1, *self.name_shape_dict["weight"])}


class OptBlock(nn.Module):
    """
    Learned stand-in for the forward and backward pass of one target module.

    ``pseudo_forward`` produces the features handed to the following blocks
    (and optionally a per-sample step size); ``pseudo_backward`` produces an
    estimate of the loss derivative w.r.t. this block's input (cached for the
    preceding blocks) and w.r.t. the module's weight embedding (returned).

    Blocks are wired together with ``link``; the attention networks are only
    created by ``setup``.

    Recognised keyword arguments: ``use_hyper_crossattn``,
    ``enable_each_layer_lr``, ``column_grouping``, ``slice_dl_din``.
    """

    def __init__(
        self,
        module,
        ftask_dim,
        out_dim,
        weight_dim,
        hidden_dim,
        num_layers,
        dl_din_way="direct",
        dl_dw_way="direct",
        **kwargs,
    ):
        super().__init__()
        self.module = module

        # Sizes.
        self.ftask_dim = ftask_dim
        self.in_dim = 0  # grows through add2in_dim() as predecessors are linked
        self.out_dim = out_dim
        self.weight_dim = weight_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Mode switches.
        self.dl_din_way = dl_din_way
        self.dl_dw_way = dl_dw_way
        self.use_hyper_crossattn = kwargs.get("use_hyper_crossattn")
        self.enable_each_layer_lr = kwargs.get("enable_each_layer_lr")
        self.column_grouping = kwargs.get("column_grouping")
        self.slice_dl_din = kwargs.get("slice_dl_din")

        # Graph neighbours; plain lists, so they are not registered as sub-modules.
        self.next_blocks = []
        self.prev_blocks = []

    def setup(self):
        """Create the attention networks and learned query tokens.

        The number of tokens is the module's ``in_features`` when column
        grouping is enabled and its ``out_features`` otherwise.
        """
        if self.column_grouping:
            token_count = self.module.in_features
        else:
            token_count = self.module.out_features
        self.num_tokens = int(token_count)

        def new_attn():
            return CrossAttn(
                self.weight_dim,
                hidden_dim=self.hidden_dim,
                num_layers=self.num_layers,
            )

        def new_queries():
            return nn.Parameter(torch.randn(1, self.num_tokens, self.weight_dim))

        # NOTE: creation order is kept as-is so parameter order and random
        # initialisation stay unchanged.
        self.ftask_zin = new_attn()
        self.token_task = new_attn()
        self.learn_p_forward = new_queries()
        if self.use_hyper_crossattn:
            self.weightemb_zprime = new_attn()
            self.task_grad = new_attn()
            self.token_grad = new_attn()
            self.dl_dout_prime_net = new_attn()
            self.learn_p_backward_z_grad = new_queries()
            self.learn_p_backward_dl_dout = new_queries()
        else:
            self.forward_net = new_attn()

        if self.enable_each_layer_lr:
            self.lr_attn = new_attn()
            # Per-token reduction weight_dim -> 32, then all tokens -> one scalar.
            self.lr_pre_mlp = MLP(self.weight_dim, 32, self.weight_dim // 2, num_layers=2)
            self.lr_mlp = MLP(32 * self.num_tokens, 1, 128, num_layers=2)

        self.out_din = new_attn()
        self.dout_dw = new_attn()
        self.dl_din_net = new_attn()
        self.dl_dw_net = new_attn()

    def pseudo_forward(self, ftask, weight_emb):
        """Mimic the module's forward pass.

        Concatenates the predecessors' features along the token axis, mixes
        them with the task features and the weight embedding, and caches the
        result for the successors. Also sets ``self.lr`` (a tensor when
        ``enable_each_layer_lr`` is set, otherwise ``1.0``).
        """
        upstream = [blk.get_zin() for blk in self.prev_blocks]
        z_in = torch.cat(upstream, dim=-2)

        batch = z_in.shape[0]
        if batch != weight_emb.shape[0]:
            weight_emb = einops.repeat(weight_emb, "1 L D -> n L D", n=batch)

        task_context = self.ftask_zin(ftask, z_in)
        queries = self.learn_p_forward.expand(ftask.shape[0], -1, -1)
        z_prime = self.token_task(queries, task_context)

        if self.use_hyper_crossattn:
            out = self.weightemb_zprime(weight_emb, z_prime)
        else:
            out = self.forward_net(weight_emb, z_in)

        if self.enable_each_layer_lr:
            lr_tokens = self.lr_pre_mlp(self.lr_attn(z_prime, weight_emb))  # [B, L, 32]
            flattened = lr_tokens.view(lr_tokens.shape[0], -1)              # [B, L*32]
            self.lr = self.lr_mlp(flattened).unsqueeze(-1)                  # [B, 1, 1]
        else:
            self.lr = 1.0

        self.z_in = z_in
        self.nxt_z_in = out
        return out

    def get_zin(self):
        """Features this block exposes to its successors."""
        return self.nxt_z_in

    def get_lr(self):
        """Step size computed by the most recent ``pseudo_forward``."""
        return self.lr

    def pseudo_backward(self, ftask, weight_emb):
        """Mimic the module's backward pass.

        Uses the input cached by ``pseudo_forward`` and the derivative
        estimates of the successor blocks. Caches the derivative w.r.t. the
        input in ``self.dl_din`` and returns the derivative w.r.t. the weight
        embedding.
        """
        z_in = self.z_in
        downstream = [blk.get_dldin(self) for blk in self.next_blocks]
        dl_dout = torch.cat(downstream, dim=-2)

        batch = z_in.shape[0]
        if batch != weight_emb.shape[0]:
            weight_emb = einops.repeat(weight_emb, "1 L D -> n L D", n=batch)

        if self.use_hyper_crossattn:
            n_task = ftask.shape[0]
            task_context = self.task_grad(ftask, z_in)
            z_grad = self.token_grad(
                self.learn_p_backward_z_grad.expand(n_task, -1, -1), task_context
            )
            dout_din = self.out_din(weight_emb, z_grad)  # N L D
            dout_dw = self.dout_dw(z_grad, weight_emb)  # N L D

            dl_dout_prime = self.dl_dout_prime_net(
                self.learn_p_backward_dl_dout.expand(n_task, -1, -1), dl_dout
            )
            dl_din = self.dl_din_net(dl_dout_prime, dout_din)
            dl_dw = self.dl_dw_net(dl_dout_prime, dout_dw)
        else:
            dout_din = self.out_din(z_in, weight_emb)  # N L D
            dout_dw = self.dout_dw(weight_emb, z_in)  # N L D

            dl_din = self.dl_din_net(dout_din, dl_dout)
            dl_dw = self.dl_dw_net(dout_dw, dl_dout)

        self.dl_din = dl_din
        return dl_dw

    def get_dldin(self, prev_blk):
        """Return the cached input derivative destined for ``prev_blk``.

        With ``slice_dl_din`` set, the token axis is divided evenly among the
        predecessors and only ``prev_blk``'s share is returned; otherwise the
        whole tensor is returned.
        """
        n_prev = len(self.prev_blocks)
        slot = self.prev_blocks.index(prev_blk)
        share = self.dl_din.shape[1] // n_prev
        if not self.slice_dl_din:
            return self.dl_din
        return self.dl_din[:, share * slot : share * (slot + 1)]

    def link(self, next_block):
        """
        connect the output of a module to another module's input
        """
        next_block.add2in_dim(self.weight_dim)

        self.next_blocks.append(next_block)
        next_block.prev_blocks.append(self)

    def add2in_dim(self, dim):
        """Accumulate the width contributed by a newly linked predecessor."""
        self.in_dim += dim


# @torch.compile(mode='reduce-overhead')  # Disabled to avoid CUDA Graph warnings
class OptLayer(nn.Module):
    """
    Runs one learned "optimisation step" over all opt blocks of a target net.

    It first does a pseudo forward pass mimicking the forward pass of the
    target network and storing the intermediate results inside the blocks,
    then a pseudo backward pass mimicking back-propagation.

    The target net must implement ``construct_opt_blocks``, returning the opt
    blocks together with the boundary blocks used at the input
    (``forward_in``) and output (``dloss_dout``) ends.
    """

    def __init__(
        self,
        target_net: TargetNet,
        ftask_dim,
        weight_dim,
        deriv_hidden_dim,
        driv_num_layers,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.ftask_dim = ftask_dim
        self.weight_dim = weight_dim
        constructed = target_net.construct_opt_blocks(
            ftask_dim,
            weight_dim,
            deriv_hidden_dim,
            driv_num_layers,
            *args,
            **kwargs,
        )
        self.opt_blocks, self.forward_in, self.dloss_dout = constructed
        # Blocks create their networks only after the whole graph is wired.
        for block in self.opt_blocks:
            block.setup()

    def pseudo_forward(self, ftask, weight_embs):
        """Run the input boundary blocks, then every opt block in order."""
        for boundary in self.forward_in:
            boundary.pseudo_forward(ftask)

        for block, emb in zip(self.opt_blocks, weight_embs):
            block.pseudo_forward(ftask, emb)

    def pseudo_backward(self, ftask, weight_embs):
        """Run the output boundary blocks, then every opt block in reverse.

        Returns:
            The blocks' weight-derivative estimates, listed in forward
            (block) order.
        """
        for boundary in self.dloss_dout:
            boundary.pseudo_backward(ftask)

        paired = list(zip(self.opt_blocks, weight_embs))
        grads = [None] * len(paired)
        # Visit blocks last-to-first, but store results at their forward index.
        for pos in range(len(paired) - 1, -1, -1):
            block, emb = paired[pos]
            grads[pos] = block.pseudo_backward(ftask, emb)
        return grads

    def get_lrs(self):
        """Step sizes set by the blocks during the last pseudo forward pass."""
        return [block.get_lr() for block in self.opt_blocks]

    def forward(self, ftask, weight_embs):
        """Pseudo forward followed by pseudo backward; returns the latter's result."""
        self.pseudo_forward(ftask, weight_embs)
        return self.pseudo_backward(ftask, weight_embs)


class ParamLN(nn.Module):
    """Layer normalisation over the last axis (of size ``weight_dim``)."""

    def __init__(self, weight_dim):
        super().__init__()
        self.ln = nn.LayerNorm(weight_dim)

    def forward(self, weight_dict):
        """Normalise ``weight_dict``; despite the name it is passed straight to LayerNorm."""
        normalised = self.ln(weight_dict)
        return normalised

class InitWeightHypernet(nn.Module):
    """Produce initial weight embeddings for every target sub-module.

    For each sub-module a set of learned query tokens attends to the task
    features through its own ``CrossAttn`` network. The number of tokens is
    the sub-module's ``in_features`` when ``column_grouping`` is set and its
    ``out_features`` otherwise.
    """

    def __init__(self, target_net, embed_dim, hidden_dim=128, num_layers=4, column_grouping=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_tokens = []
        self.token_queries = nn.ParameterList()
        self.cross_attn_layers = nn.ModuleList()

        for module in target_net.get_submodules():
            if not (hasattr(module, 'in_features') and hasattr(module, 'out_features')):
                raise ValueError("Target module must have `in_features` and `out_features` attributes")

            token_count = module.in_features if column_grouping else module.out_features
            self.num_tokens.append(token_count)

            self.token_queries.append(
                nn.Parameter(torch.randn(1, token_count, embed_dim))
            )
            self.cross_attn_layers.append(
                CrossAttn(self.embed_dim, hidden_dim=hidden_dim, num_layers=num_layers)
            )

    def forward(self, ftask):
        """Return one ``[B, num_tokens, D]`` embedding per sub-module.

        Args:
            ftask: Task features with three axes; the last must equal
                ``embed_dim``.
        """
        batch, _, feat_dim = ftask.shape
        assert feat_dim == self.embed_dim, "Task embedding dimension D must match hypernet embed_dim"

        return [
            attn(queries.expand(batch, -1, -1), ftask)  # [B, L, D]
            for queries, attn in zip(self.token_queries, self.cross_attn_layers)
        ]

class Hypernet(nn.Module):
    """Generate the weights of ``target_net`` from task features.

    Initial weight embeddings are predicted from the task features, decoded
    to weight matrices, and then refined for ``num_layers`` steps: each step
    runs the shared ``OptLayer`` to obtain update embeddings, decodes them,
    adds them to the current weights scaled by a step size, and re-encodes
    the result. After every step the generated weights are multiplied
    element-wise with the target net's own parameters and recorded.

    Recognised keyword arguments: ``enable_each_layer_lr``,
    ``enable_weight_init``, ``column_grouping``; all keyword arguments are
    also forwarded to ``OptLayer``.
    """

    def __init__(
        self,
        target_net: TargetNet,
        ftask_dim,
        weight_dim,
        deriv_hidden_dim,
        driv_num_layers,
        codec_hidden_dim,
        codec_num_layers,
        num_layers,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.target_net = target_net
        self.ftask_dim = ftask_dim
        self.weight_dim = weight_dim
        self.ftask_adapter_mlp = MLP(ftask_dim, weight_dim, weight_dim, 4)
        self.encoder_adapter_mlp = MLP(weight_dim, weight_dim, weight_dim, 4)
        self.decoder_adapter_mlp = MLP(weight_dim, weight_dim, weight_dim, 4)
        self.enable_each_layer_lr = kwargs.get("enable_each_layer_lr")
        self.enable_weight_init = kwargs.get("enable_weight_init")
        self.num_layers = num_layers
        column_grouping = kwargs.get("column_grouping")

        if not self.enable_each_layer_lr:
            # One learned scalar step size per refinement step.
            self.dynamic_lrs = nn.Parameter(torch.zeros(num_layers).fill_(-1e-2))

        # The encoders are needed by forward_blocks() in every configuration,
        # so they are built exactly once, unconditionally.
        self.encoders = nn.ModuleList(
            [
                ModuleEncoder(
                    target_module,
                    weight_dim,
                    codec_hidden_dim,
                    codec_num_layers,
                    column_grouping=column_grouping,
                )
                for target_module in target_net.get_submodules()
            ]
        )
        self.decoders = nn.ModuleList(
            [
                ModuleDecoder(
                    target_module,
                    weight_dim,
                    codec_hidden_dim,
                    codec_num_layers,
                    column_grouping=column_grouping,
                )
                for target_module in target_net.get_submodules()
            ]
        )
        # Task features are projected to weight_dim before reaching the opt
        # layer, hence weight_dim is passed as its task-feature size too.
        self.opt_layer = OptLayer(
            target_net,
            weight_dim,
            weight_dim,
            deriv_hidden_dim,
            driv_num_layers,
            *args,
            **kwargs,
        )

        self.layer_norms = nn.ModuleList(
            [ParamLN(weight_dim) for _ in self.target_net.get_submodules()]
        )
        self.init_weight_hypernet = InitWeightHypernet(
            target_net,
            weight_dim,
            hidden_dim=128,
            num_layers=4,
            column_grouping=column_grouping,
        )

    def forward_block(self, ftask, weight_dicts, opt_block):
        """Run ``opt_block`` once and layer-normalise each update embedding.

        Args:
            ftask: Task features passed to ``opt_block``.
            weight_dicts: Per-sub-module inputs passed to ``opt_block``
                (``forward_blocks`` passes the weight embeddings here).
            opt_block: Callable returning one update per sub-module and
                exposing ``get_lrs()``.

        Returns:
            ``(updates, lrs)`` when ``enable_each_layer_lr`` is set, otherwise
            just ``updates``.
        """
        weight_upd_dicts = opt_block(ftask, weight_dicts)
        weight_upd_dicts = [
            ln(submodule) for ln, submodule in zip(self.layer_norms, weight_upd_dicts)
        ]
        if self.enable_each_layer_lr:
            lrs = opt_block.get_lrs()
            return weight_upd_dicts, lrs
        return weight_upd_dicts

    @staticmethod
    def _scale_update(update, lr):
        """Multiply ``update`` ([B, ...]) by a per-sample or scalar step size."""
        # Flatten to [B] (or [1] for a scalar) and broadcast over trailing axes.
        per_sample = lr.reshape(-1, *([1] * (update.ndim - 1)))
        return update * per_sample

    def forward_blocks(self, ftask, warmup_lr=1.0):
        """Generate the target weights after every refinement step.

        Args:
            ftask: Task features; projected to ``weight_dim`` by
                ``ftask_adapter_mlp``.
            warmup_lr: Extra multiplier applied to every update.

        Returns:
            A list with one weight dict per refinement step, keyed by the
            target net's parameter names.
        """
        base_weight_dict = dict(self.target_net.named_parameters())

        # Project task embeddings to the weight token dimension.
        ftask_proj = self.ftask_adapter_mlp(ftask)

        # Initial weight embeddings, then decode them to real weight space.
        weight_embs = self.init_weight_hypernet(ftask_proj)
        weight_dicts = [decoder(w) for (decoder, w) in zip(self.decoders, weight_embs)]

        final_weight_dicts = []
        for i in range(self.num_layers):
            if self.enable_each_layer_lr:
                weight_upd_embs, lrs = self.forward_block(
                    ftask_proj, weight_embs, self.opt_layer
                )
            else:
                # forward_block returns only the updates in this mode; the
                # learned scalar for step i is shared by all sub-modules.
                weight_upd_embs = self.forward_block(
                    ftask_proj, weight_embs, self.opt_layer
                )
                lrs = [self.dynamic_lrs[i]] * len(weight_upd_embs)

            weight_upd_dicts = [
                decoder(w) for (decoder, w) in zip(self.decoders, weight_upd_embs)
            ]
            weight_dicts = [
                {
                    k: weights[k] + self._scale_update(upds[k], lr) * warmup_lr
                    for k in weights
                }
                for weights, upds, lr in zip(weight_dicts, weight_upd_dicts, lrs)
            ]
            weight_embs = [
                encoder(weight_dict)
                for encoder, weight_dict in zip(self.encoders, weight_dicts)
            ]

            weight_dict = self.target_net.merge_submodule_weights(weight_dicts)
            # Modulate the target net's own parameters element-wise.
            weight_dict = {
                k: einops.einsum(v, base_weight_dict[k], "b ..., ... -> b ...")
                for k, v in weight_dict.items()
            }
            final_weight_dicts.append(weight_dict)

        return final_weight_dicts

    def generate_weights(self, ftask, warmup_lr=1.0):
        """Run ``forward_blocks`` and cache the result in ``generated_weights``."""
        final_weight_dicts = self.forward_blocks(ftask, warmup_lr)
        self.generated_weights = final_weight_dicts
        return final_weight_dicts

    def forward(
        self,
        *inputs,
        final_weight_dicts=None,
        early_sup=False,
        act_idx=None,
        chunk_size=None,
    ):
        """Evaluate the target net with generated weights.

        Args:
            *inputs: Positional inputs of the target net; each gets a leading
                axis of size 1 that is mapped together with the weights'
                batch axis.
            final_weight_dicts: Weights to use; defaults to the ones cached by
                ``generate_weights``.
            early_sup: If true, evaluate with the weights of every refinement
                step and stack the results; otherwise use the last step only.
            act_idx: Optional batch index selecting a single weight set;
                defaults to the index stored by ``activate_idx``.
            chunk_size: ``vmap`` chunk size for the non-``early_sup`` path
                (default 32).
        """
        if final_weight_dicts is None:
            final_weight_dicts = getattr(self, "generated_weights", None)
        # Checked before any use so a missing weight set fails with a clear message.
        assert final_weight_dicts is not None, (
            "no weights available: call generate_weights() or pass final_weight_dicts"
        )

        if act_idx is None:
            act_idx = getattr(self, "act_idx", None)
        if act_idx is not None:
            final_weight_dicts = [
                {k: v[act_idx : act_idx + 1] for k, v in weight_dict.items()}
                for weight_dict in final_weight_dicts
            ]

        # A tuple, because functional_call treats any non-tuple (including a
        # list) as one single positional argument of the module.
        batched_inputs = tuple(x[None] for x in inputs)

        if early_sup:
            return torch.stack(
                [
                    torch.vmap(
                        torch.func.functional_call,
                        in_dims=(None, 0, 0),
                        randomness="different",
                    )(self.target_net, generated_weight, batched_inputs)
                    for generated_weight in final_weight_dicts
                ]
            )

        # Use chunk_size to avoid memory issues in distributed training
        if chunk_size is None:
            chunk_size = 32  # Default chunk size to prevent OOM

        if len(batched_inputs) > 1:
            results = torch.vmap(
                torch.func.functional_call,
                in_dims=(None, 0, tuple([0] * len(batched_inputs))),
                chunk_size=chunk_size,
                randomness="different",
            )(self.target_net, final_weight_dicts[-1], batched_inputs)
            return [r[0] for r in results]

        results = torch.vmap(
            torch.func.functional_call,
            in_dims=(None, 0, 0),
            chunk_size=chunk_size,
            randomness="different",
        )(self.target_net, final_weight_dicts[-1], batched_inputs[0])
        # Strip the leading axis that was added to the inputs above.
        if isinstance(results, tuple):
            return [r[0] for r in results]
        return results[0]

    def activate_idx(self, idx):
        """Store the default batch index used by ``forward``."""
        self.act_idx = idx

    def cleanup(self):
        """Clean up generated weights and indices to free memory"""
        if getattr(self, "generated_weights", None) is not None:
            self.generated_weights = None
        if hasattr(self, "act_idx"):
            self.act_idx = None



class Toy(TargetNet):
    """Two-layer example target network: 4 -> 10 -> 5 with a leaky ReLU in between."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 10)
        self.fc2 = nn.Linear(10, 5)

    def forward(self, x):
        hidden = F.leaky_relu(self.fc1(x))
        return self.fc2(hidden)

    def get_in_dims(self):
        """Input widths of the sub-modules, in order."""
        return [4, 10]

    def get_out_dims(self):
        """Output widths of the sub-modules, in order."""
        return [10, 5]

    def get_submodules(self):
        """Sub-modules whose weights are generated, in forward order."""
        return [self.fc1, self.fc2]


class NGPradianceField_tar(TargetNet):
    """Radiance-field target network made of bias-free linear layers.

    Positions are normalised with the bounding box ``aabb`` (or contracted
    when ``unbounded``), frequency-encoded, and passed through the linear
    stack with ReLU between layers. The last layer has 16 outputs, of which
    only the first four are used: three for colour (through a sigmoid) and one
    for density (through ``density_activation``). Density is zeroed for
    positions outside the open unit cube after normalisation.

    Args:
        aabb: Bounding box as ``[min_x, min_y, min_z, max_x, max_y, max_z]``.
        density_activation: Maps raw density to density.
        unbounded: Use ``contract_to_unisphere`` instead of plain
            normalisation.
        n_hidden_layers: Number of ``n_neurons``-wide layers (including the
            input layer).
        n_neurons: Hidden width.
        encoding_size: Number of frequency bands of the positional encoding.
    """

    def __init__(
        self,
        aabb: Union[torch.Tensor, List[float]],
        density_activation: Callable = lambda x: trunc_exp(x - 1),
        unbounded: bool = False,
        n_hidden_layers=3,
        n_neurons=64,
        encoding_size=24
    ) -> None:
        super(NGPradianceField_tar, self).__init__()

        if not isinstance(aabb, torch.Tensor):
            aabb = torch.tensor(aabb, dtype=torch.float32)
        self.register_buffer("aabb", aabb, persistent=False)

        self.encoder, self.embedding_dim = get_embedder(encoding_size, include_input=False)

        # input Linear
        self.layers = nn.ModuleList([
            nn.Linear(self.embedding_dim, n_neurons, bias=False),
        ])

        # hidden Linear
        for _ in range(n_hidden_layers - 1):
            self.layers.append(nn.Linear(n_neurons, n_neurons, bias=False))

        # output Linear: 16 outputs, only the first 4 are consumed downstream
        self.layers.append(nn.Linear(n_neurons, 16, bias=False))
        self.activation = torch.nn.ReLU()
        self.density_activation = density_activation
        self.D = len(self.layers)
        self.unbounded = unbounded

    def _normalize_positions(self, positions: torch.Tensor):
        """Map positions to the unit cube and flag the ones strictly inside it."""
        if self.unbounded:
            positions = self.contract_to_unisphere(positions, self.aabb)
        else:
            aabb_min, aabb_max = torch.split(self.aabb, 3, dim=-1)
            positions = (positions - aabb_min) / (aabb_max - aabb_min)
        selector = ((positions > 0.0) & (positions < 1.0)).all(dim=-1)
        return positions, selector

    def _raw_outputs(self, positions: torch.Tensor):
        """Encode positions, run the linear stack, keep the first four channels."""
        out = self.encoder(positions)
        # ReLU after each Linear except the last one.
        for i, layer in enumerate(self.layers):
            out = layer(out)
            if i < self.D - 1:
                out = self.activation(out)
        return out[..., :4]

    def _density(self, raw: torch.Tensor, selector: torch.Tensor):
        """Activate the density channel and zero it outside the unit cube."""
        density = self.density_activation(raw[..., 3:])
        return density * selector[..., None]

    def query_density(self, positions: torch.Tensor):
        """Return only the density for ``positions`` (last axis of size 3)."""
        positions, selector = self._normalize_positions(positions)
        raw = self._raw_outputs(positions)
        return self._density(raw, selector)

    def forward(self, positions: torch.Tensor):
        """Return ``(rgb, density)`` for ``positions`` (last axis of size 3)."""
        positions, selector = self._normalize_positions(positions)
        raw = self._raw_outputs(positions)

        rgb = torch.sigmoid(raw[..., :3])
        density = self._density(raw, selector)
        return rgb, density

    @staticmethod
    def contract_to_unisphere(
        x: torch.Tensor,
        aabb: torch.Tensor,
        eps: float = 1e-6,
        derivative: bool = False,
    ):
        """Contract unbounded positions into the unit cube.

        Points inside the box are mapped linearly to ``[0.25, 0.75]``; points
        with norm greater than one (in box coordinates scaled to ``[-1, 1]``)
        are squashed so that everything lands in ``[0, 1]``.

        Args:
            x: Positions, last axis of size 3.
            aabb: Bounding box, min corner followed by max corner.
            eps: Lower clamp for the derivative.
            derivative: If true, return the per-coordinate derivative term
                instead of the contracted positions.
        """
        aabb_min, aabb_max = torch.split(aabb, 3, dim=-1)
        x = (x - aabb_min) / (aabb_max - aabb_min)
        x = x * 2 - 1  # aabb is at [-1, 1]
        mag = x.norm(dim=-1, keepdim=True)
        mask = mag.squeeze(-1) > 1

        if derivative:
            dev = (2 * mag - 1) / mag**2 + 2 * x**2 * (
                1 / mag**3 - (2 * mag - 1) / mag**4
            )
            dev[~mask] = 1.0
            dev = torch.clamp(dev, min=eps)
            return dev
        else:
            x[mask] = (2 - 1 / mag[mask]) * (x[mask] / mag[mask])
            x = x / 4 + 0.5  # [-inf, inf] is at [0, 1]
            return x

    def get_last_layer_idx(self):
        return self.D - 1

    def get_last_layer(self):
        """Return the last layer's weight (with bias appended, if any), transposed."""
        weight = self.layers[-1].weight
        bias = self.layers[-1].bias
        if bias is None:
            # Layers are built with bias=False, so there is nothing to append.
            return weight.t()
        return torch.cat([weight, bias[:, None]], dim=1).t()

    def get_submodules(self):
        """Yield the linear layers in forward order."""
        yield from self.layers

    def print_params(self):
        """Print the parameter count (in millions) of every layer."""
        for i, layer in enumerate(self.layers):
            # Sum over existing parameters; the bias may be absent.
            nparam = sum(p.numel() for p in layer.parameters()) / 1e6
            print(f"siren layer {i} param cnt {nparam}M")

    def construct_opt_blocks(
        self,
        ftask_dim,  # Task feature dimension for hypernetwork
        weight_dim,  # Weight embedding dimension
        deriv_hidden_dim,  # Hidden dimension for derivative network
        driv_num_layers,  # Number of layers in derivative network
        *args,
        in_dim=64,  # Input dimension for optimization blocks
        out_dim=64,  # Output dimension for optimization blocks
        **kwargs,
    ):
        """
        Construct optimization blocks for hypernetwork weight updates
        Args:
            ftask_dim: Task feature dimension
            weight_dim: Weight embedding dimension
            deriv_hidden_dim: Hidden dimension for derivative network
            driv_num_layers: Number of layers in derivative network
            in_dim: Input dimension for optimization blocks
            out_dim: Output dimension for optimization blocks
        Returns:
            ``(opt_blocks, forward_in, dloss_dout)``: one opt block per layer,
            chained sequentially, plus one boundary block at each end.
        """
        # One opt block per linear layer.
        layer_blocks = [
            OptBlock(
                layer,
                ftask_dim,
                out_dim,
                weight_dim,
                deriv_hidden_dim,
                driv_num_layers,
                *args,
                **kwargs,
            )
            for layer in self.layers
        ]
        opt_blocks = nn.ModuleList()
        opt_blocks.extend(layer_blocks)

        # Link blocks sequentially
        for current, following in zip(layer_blocks[:-1], layer_blocks[1:]):
            current.link(following)

        # Boundary blocks feeding the first block and closing the last one.
        forward_in = nn.ModuleList(
            [
                FIN_FOUT(
                    ftask_dim,
                    in_dim,
                    hidden_dim=deriv_hidden_dim,
                    num_layers=driv_num_layers,
                )
            ]
        )
        dloss_dout = nn.ModuleList(
            [
                FIN_FOUT(
                    ftask_dim,
                    in_dim,
                    hidden_dim=deriv_hidden_dim,
                    num_layers=driv_num_layers,
                )
            ]
        )

        # Link input and output blocks
        forward_in[0].link(layer_blocks[0])
        layer_blocks[-1].link(dloss_dout[0])

        return opt_blocks, forward_in, dloss_dout

class SIREN_tar(TargetNet):
    """Sine-activated MLP whose linear weights and biases are modulated.

    Each layer ``i`` owns an ``nn.Linear`` plus three learnable modulation
    tensors. The effective parameters used in ``forward`` are::

        weight = layer.weight * modulation_weight_scale[i]
        bias   = layer.bias * modulation_bias_scale[i] + modulation_bias_shift[i]

    Hidden layers apply ``sin(omega * x)``; the last layer is linear and has
    ``out_bias`` added to its output.
    """

    def __init__(self, D, W, input_ch, output_ch, out_bias=0.0, omega=30.0, *args, **kwargs):
        """Build ``D`` linear layers and their modulation parameters.

        Args:
            D: Number of linear layers.
            W: Width of every layer except the last.
            input_ch: Size of the last axis of the input.
            output_ch: Size of the last axis of the output.
            out_bias: Constant added to the output of the final layer.
            omega: Frequency factor applied inside the sine activation.
            *args, **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.omega = omega
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.output_ch = output_ch
        self.out_bias = out_bias

        # Denominator used by update_modulation_factor.
        self.total_steps = 1000

        layers = []
        modulation_weight_scale = []
        modulation_bias_scale = []
        modulation_bias_shift = []
        last_dim = input_ch

        # NOTE: the nn.Linear is created before the uniform_ draw on purpose so
        # that the random-number consumption order stays the same per layer.
        for i in range(D):
            cur_dim = W if i < D - 1 else output_ch
            layers.append(nn.Linear(last_dim, cur_dim))

            # The first layer uses a fan-in bound; every other layer (hidden
            # and final alike) uses the omega-scaled bound.
            if i == 0:
                bound = 3 * 1 / last_dim
            else:
                bound = 3.5 * math.sqrt(6 / last_dim) / self.omega

            modulation_weight_scale.append(
                nn.Parameter(torch.empty(cur_dim, last_dim).uniform_(-bound, bound))
            )
            modulation_bias_scale.append(nn.Parameter(torch.ones((cur_dim)) * 1e-3))
            modulation_bias_shift.append(nn.Parameter(torch.zeros(cur_dim)))
            last_dim = cur_dim

        self.layers = nn.ModuleList(layers)
        self.modulation_weight_scale = nn.ParameterList(modulation_weight_scale)
        self.modulation_bias_scale = nn.ParameterList(modulation_bias_scale)
        self.modulation_bias_shift = nn.ParameterList(modulation_bias_shift)
        self.print_params()

    def update_modulation_factor(self, current_step):
        """Set ``modulation_factor`` to ``1 - min(current_step / total_steps, 1)``.

        NOTE(review): ``self.modulation_factor`` is not created anywhere in
        this class; unless the base class provides it, calling this method
        raises ``AttributeError``. The only call site visible here (in
        ``forward``) is commented out -- confirm before re-enabling.
        """
        progress = min(float(current_step) / self.total_steps, 1.0)
        self.modulation_factor.fill_(1.0 - progress)

    def siren_activation(self, x):
        """SIREN non-linearity with frequency scaling."""
        return torch.sin(self.omega * x)

    def forward(self, x, global_step=None):
        """Evaluate the modulated MLP.

        Args:
            x: Tensor whose first axis is the batch and whose last axis is the
                feature axis; any axes in between are flattened for the
                computation and restored on the output.
            global_step: Currently unused (the modulation-factor update is
                disabled); kept for interface compatibility.

        Returns:
            Tensor with the same leading axes as ``x`` and the final layer's
            output size as its last axis.
        """
        # self.update_modulation_factor(global_step)
        B, query_shape = x.shape[0], x.shape[1:-1]
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.reshape(B, -1, x.shape[-1])
        for i, layer in enumerate(self.layers):
            bias = layer.bias * self.modulation_bias_scale[i] + self.modulation_bias_shift[i]
            weight = layer.weight * self.modulation_weight_scale[i]
            # for ablation
            # bias = layer.bias
            # weight = layer.weight

            x = F.linear(x, weight, bias)

            if i < self.D - 1:
                x = self.siren_activation(x)
            else:
                x = x + self.out_bias

        return x.view(B, *query_shape, -1)

    def get_last_layer_idx(self):
        """Return the index of the final linear layer."""
        return self.D - 1

    def get_last_layer(self):
        """Return the final layer's raw weight and bias as one transposed matrix.

        The bias is appended as an extra column to the ``(out, in)`` weight and
        the result is transposed, giving shape ``(in + 1, out)``. Modulation is
        not applied here.
        """
        weight = self.layers[-1].weight
        bias = self.layers[-1].bias
        # ``.t`` without the call returned the bound method, not a tensor.
        return torch.cat([weight, bias[:, None]], dim=1).t()

    def get_submodules(self):
        """Yield the linear layers in order."""
        yield from self.layers

    def print_params(self):
        """Print the weight+bias parameter count of every layer, in millions."""
        for i, layer in enumerate(self.layers):
            nparam = (layer.weight.numel() + layer.bias.numel()) / 1e6
            print(f"siren layer {i} param cnt {nparam}M")

    def construct_opt_blocks(
        self,
        ftask_dim,  # Task feature dimension for hypernetwork
        weight_dim,  # Weight embedding dimension
        deriv_hidden_dim,  # Hidden dimension for derivative network
        driv_num_layers,  # Number of layers in derivative network
        *args,
        in_dim=64,  # Input dimension for optimization blocks
        out_dim=64,  # Output dimension for optimization blocks
        **kwargs,
    ):
        """
        Construct optimization blocks for hypernetwork weight updates

        One ``OptBlock`` is created per linear layer and the blocks are linked
        in layer order; a ``FIN_FOUT`` block is linked in front of the first
        one and another is linked after the last one.

        Args:
            ftask_dim: Task feature dimension
            weight_dim: Weight embedding dimension
            deriv_hidden_dim: Hidden dimension for derivative network
            driv_num_layers: Number of layers in derivative network
            *args, **kwargs: Forwarded unchanged to every ``OptBlock``
            in_dim: Dimension passed to both ``FIN_FOUT`` blocks
            out_dim: Dimension passed to every ``OptBlock``

        Returns:
            Tuple ``(opt_blocks, forward_in, dloss_dout)`` of ``nn.ModuleList``s.
        """
        opt_blocks = nn.ModuleList()

        # Create opt blocks for all layers
        layer_blocks = [
            OptBlock(
                layer,
                ftask_dim,
                out_dim,
                weight_dim,
                deriv_hidden_dim,
                driv_num_layers,
                *args,
                **kwargs,
            )
            for layer in self.layers
        ]

        opt_blocks.extend(layer_blocks)

        # Link blocks sequentially
        for block, next_block in zip(layer_blocks, layer_blocks[1:]):
            block.link(next_block)

        # Create forward_in and dloss_dout blocks
        forward_in = nn.ModuleList(
            [
                FIN_FOUT(
                    ftask_dim,
                    in_dim,
                    hidden_dim=deriv_hidden_dim,
                    num_layers=driv_num_layers,
                )
            ]
        )
        # NOTE(review): this block is also sized with in_dim rather than
        # out_dim; the defaults are equal so it is harmless today -- confirm
        # this is intended.
        dloss_dout = nn.ModuleList(
            [
                FIN_FOUT(
                    ftask_dim,
                    in_dim,
                    hidden_dim=deriv_hidden_dim,
                    num_layers=driv_num_layers,
                )
            ]
        )

        # Link input and output blocks
        forward_in[0].link(layer_blocks[0])
        layer_blocks[-1].link(dloss_dout[0])

        return opt_blocks, forward_in, dloss_dout
    
class TokenCompressor(nn.Module):
    """Compress a feature map into a fixed grid of tokens with learned queries.

    Every spatial position of the input is projected to ``dim`` features and
    used as both key and value of a single-head dot-product attention whose
    queries are ``num_queries`` learned embeddings. The resulting tokens are
    laid out as a square feature map.

    Args:
        input_dim: Channel count of the incoming feature map.
        dim: Feature size of the projected keys/values and of the queries.
        num_queries: Number of output tokens; must be a perfect square so the
            tokens can be arranged on a ``sqrt(N) x sqrt(N)`` grid.

    Raises:
        ValueError: If ``num_queries`` is not a perfect square.
    """

    def __init__(self, input_dim=3, dim=256, num_queries=256):
        super().__init__()
        # Fail at construction time instead of with an opaque shape error
        # during the first forward pass.
        side = math.isqrt(num_queries)
        if side * side != num_queries:
            raise ValueError(
                f"num_queries must be a perfect square, got {num_queries}"
            )
        self.project_kv = nn.Linear(input_dim, dim)
        self.num_queries = num_queries
        self.dim = dim
        # Learnable queries
        self.query_embed = nn.Parameter(torch.randn(1, num_queries, dim))

    def forward(self, quant):
        """Attend from the learned queries to all spatial positions.

        Args:
            quant: Tensor of shape ``[B, C, H, W]`` with ``C == input_dim``.

        Returns:
            Tensor of shape ``[B, dim, s, s]`` with ``s = sqrt(num_queries)``;
            position ``(i, j)`` holds the feature vector of token ``i * s + j``.
        """
        B, C, H, W = quant.shape
        # reshape (not view) so non-contiguous inputs are accepted as well.
        kv = quant.reshape(B, C, H * W).permute(0, 2, 1)  # [B, HW, C]
        kv = self.project_kv(kv)                          # [B, HW, dim]

        # expand query to batch size
        query = self.query_embed.expand(B, -1, -1)        # [B, N, dim]

        attn = torch.matmul(query, kv.transpose(-2, -1))  # [B, N, HW]
        attn = F.softmax(attn / (self.dim ** 0.5), dim=-1)
        out = torch.matmul(attn, kv)                      # [B, N, dim]

        # Move the feature axis in front of the token axis before folding the
        # tokens into a grid; viewing [B, N, dim] directly as [B, dim, s, s]
        # would reinterpret memory and mix tokens with channels.
        side = math.isqrt(self.num_queries)
        return out.permute(0, 2, 1).reshape(B, self.dim, side, side)

class DownsampleConv(nn.Module):
    """Two 3x3, stride-2, padding-1 convolutions, each followed by ReLU.

    Channels go ``in_channels -> hidden_channels -> out_channels``; every
    convolution roughly halves the spatial size, so the output is about a
    quarter of the input resolution along each axis.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        channel_plan = (
            (in_channels, hidden_channels),
            (hidden_channels, out_channels),
        )
        stages = []
        for c_in, c_out in channel_plan:
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))
            stages.append(nn.ReLU())
        # Sequential indices stay 0..3 (conv, relu, conv, relu).
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both downsampling stages to ``x`` of shape ``[B, C, H, W]``."""
        downsampled = self.conv(x)
        return downsampled