import random
from abc import ABC, abstractmethod
from typing import List, Tuple, Dict

import einops
import numpy as np
import torch
import torch.nn as nn
from torch import nn, distributions
import torch.nn.functional as F


from options import AEConfig
from models.ae.ae_fc_modules import Encoder, Decoder

class AutoEncoder(nn.Module, ABC):
    """Base class for auto-encoders that reconstruct layer weights row by row.

    A layer weight is viewed as a 2-D matrix with one row per entry of its
    first dimension (1-D tensors become a single row). Each row is expanded
    to ``input_size`` columns, and the rows are run through the model,
    optionally in chunks of ``cfg.param_batch_size`` rows. The output rows
    are then sampled back down to the original row length and reshaped to
    the original shape.

    Subclasses implement ``forward(x, positional_embeddings)`` and must also
    provide ``encode(x)`` and ``decode(x, positional_embeddings)``, which the
    helpers below call.
    """

    def __init__(self, cfg: "AEConfig", input_size: int):
        super().__init__()
        self.cfg = cfg
        # Width every weight row is expanded to before it enters the model.
        self.input_size = input_size

    def save(self, path: str):
        """Write this module's ``state_dict`` to *path* with ``torch.save``."""
        torch.save(self.state_dict(), path)

    def load(self, path: str, map_location=None):
        """Load a ``state_dict`` from *path* into this module.

        Args:
            path: checkpoint file, e.g. one written by :meth:`save`.
            map_location: forwarded to ``torch.load``; for example ``"cpu"``
                loads a GPU-saved checkpoint on a CPU-only machine. The
                default ``None`` keeps ``torch.load``'s own device placement.
        """
        # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
        # objects, so only load checkpoints from trusted sources. save() writes
        # a plain state_dict, which weights_only=True should also be able to read.
        state_dict = torch.load(path, map_location=map_location, weights_only=False)
        self.load_state_dict(state_dict)

    @abstractmethod
    def forward(self, x: torch.Tensor, positional_embeddings: torch.Tensor) -> torch.Tensor:
        """Reconstruct a batch of expanded weight rows.

        ``_predict_weights`` calls this with a 2-D ``x`` of shape
        ``(rows, input_size)`` and a 2-D ``positional_embeddings`` whose
        rows are sliced in step with ``x``.
        """
        raise NotImplementedError()

    @staticmethod
    def _flatten_to_2d(weight: torch.Tensor) -> torch.Tensor:
        """Return *weight* as a ``(rows, row_length)`` tensor.

        1-D tensors become a single row and 2-D tensors are returned as-is.
        Higher-rank tensors (e.g. conv kernels) keep dim 0 and flatten the rest.

        Raises:
            ValueError: for 0-d tensors.
        """
        if weight.dim() == 0:
            raise ValueError("Cannot treat a 0-d tensor as weight rows")
        if weight.dim() == 1:
            return weight.unsqueeze(0)
        if weight.dim() == 2:
            return weight
        return weight.reshape(weight.shape[0], -1)

    @staticmethod
    def _row_length(shape) -> int:
        """Return the row length of the 2-D view ``_flatten_to_2d`` gives a tensor of *shape*."""
        if len(shape) == 0:
            raise ValueError("Cannot treat a 0-d tensor as weight rows")
        if len(shape) == 1:
            return int(shape[0])
        return torch.Size(shape[1:]).numel()

    def _run_batched(self, fn, *tensors: torch.Tensor) -> torch.Tensor:
        """Apply *fn* to row-aligned *tensors*, chunked by ``cfg.param_batch_size``.

        When ``param_batch_size`` is ``None``, *fn* is called once on the full
        tensors. Otherwise it is called on consecutive row slices (the row
        count is taken from the first tensor), and the results are stacked
        with ``torch.vstack``.
        """
        batch_size = self.cfg.param_batch_size
        if batch_size is None:
            return fn(*tensors)
        outputs = []
        for start in range(0, tensors[0].shape[0], batch_size):
            outputs.append(fn(*(t[start:start + batch_size] for t in tensors)))
        return torch.vstack(outputs)

    def _expand_weight(self, layer_original_weight: torch.Tensor, target_size: int, type: str) -> torch.Tensor:
        """Widen each row of a 2-D weight to *target_size* columns.

        Args:
            layer_original_weight: 2-D tensor ``(rows, current_size)``.
            target_size: desired number of columns.
            type: expansion method:
                ``"linear"`` - linear interpolation along the row;
                ``"pad"`` - append copies of each row's own mean;
                ``"pad_zero"`` - zero-pad both ends so the original values
                sit in the centre (same centring as ``"center"`` sampling).

        Returns:
            The expanded tensor. If the rows already have at least
            *target_size* columns, the input is returned unchanged: it is
            not truncated, and *type* is not validated.

        Raises:
            ValueError: for an unknown *type* when expansion is needed.
        """
        assert len(layer_original_weight.shape) == 2

        current_size = layer_original_weight.size(-1)
        if current_size >= target_size:
            return layer_original_weight

        if type == "linear":
            # F.interpolate wants (N, C, L): treat the rows as channels of one sample.
            return F.interpolate(
                layer_original_weight.unsqueeze(0),
                size=target_size,
                mode="linear",
                align_corners=False,
            ).squeeze(0)

        if type == "pad":
            # F.pad only accepts a scalar fill value, so the per-row means are
            # broadcast to the padding width and concatenated explicitly.
            padding = target_size - current_size
            row_means = layer_original_weight.mean(dim=-1, keepdim=True)
            return torch.cat([layer_original_weight, row_means.expand(-1, padding)], dim=-1)

        if type == "pad_zero":
            # Centre the original values; these coordinates match the "center"
            # slice taken in _sample_layer_weights_by_shape.
            min_coord = int(target_size / 2 - current_size / 2)
            max_coord = int(target_size / 2 + current_size / 2)
            return F.pad(layer_original_weight, (min_coord, target_size - max_coord), value=0)

        raise ValueError(f"Unsupported expansion method: {type}")

    def _encode_weights(self, layer_original_weight: torch.Tensor) -> torch.Tensor:
        """Encode every row of a layer weight into the latent space.

        The weight is flattened to 2-D, its rows are expanded to
        ``input_size``, and it is passed through ``self.encode`` (batched
        per ``cfg.param_batch_size``).
        """
        weight_2d = self._flatten_to_2d(layer_original_weight)
        expanded = self._expand_weight(weight_2d, self.input_size, self.cfg.expansion_type)
        return self._run_batched(self.encode, expanded)

    def _decode_weights(self, encode_weight: torch.Tensor, positional_embedding: torch.Tensor, original_shape: torch.Size) -> torch.Tensor:
        """Decode latent rows back into a weight tensor of *original_shape*.

        *encode_weight* and *positional_embedding* must have identical shapes.
        The decoded rows are sampled down to the original row length (see
        ``_sample_layer_weights_by_shape``) and reshaped to *original_shape*.
        """
        assert encode_weight.shape == positional_embedding.shape
        decode_weight = self._run_batched(self.decode, encode_weight, positional_embedding)
        decode_weight = self._sample_layer_weights_by_shape(decode_weight, self._row_length(original_shape))
        return decode_weight.reshape(original_shape)

    def _predict_weights(self, layer_original_weight: torch.Tensor, layer_positional_embeddings: torch.Tensor):
        """Run a full encode/decode pass over a layer weight via ``self.forward``.

        Returns a tensor with the same shape as *layer_original_weight*.
        *layer_positional_embeddings* must be 2-D and is sliced row-wise
        alongside the weight rows.
        """
        weight_2d = self._flatten_to_2d(layer_original_weight)
        assert len(layer_positional_embeddings.shape) == 2

        expanded = self._expand_weight(weight_2d, self.input_size, self.cfg.expansion_type)
        reconstructed = self._run_batched(self.forward, expanded, layer_positional_embeddings)
        reconstructed = self._sample_layer_weights_by_shape(reconstructed, weight_2d.shape[-1])
        return reconstructed.reshape(layer_original_weight.shape)

    def _sample_layer_weights_by_shape(self, layer_reconstructed_weights: torch.Tensor, layer_learnable_output_size: int):
        """Reduce the last dim of 2-D model outputs to *layer_learnable_output_size*.

        Modes (``cfg.sampling_mode``):
            ``"center"`` - keep the centred slice of the requested width;
            ``"average"`` / ``"max"`` - slide a window of width
            ``current - target + 1`` with stride 1 (average / max pooling),
            which yields exactly ``target`` outputs per row.

        Raises:
            ValueError: for an unknown sampling mode.
        """
        curr_predicted_output_size = layer_reconstructed_weights.shape[-1]

        if self.cfg.sampling_mode == "center":
            min_coord = int(curr_predicted_output_size / 2 - layer_learnable_output_size / 2)
            max_coord = int(curr_predicted_output_size / 2 + layer_learnable_output_size / 2)
            return layer_reconstructed_weights[:, min_coord:max_coord]

        kernel_size = curr_predicted_output_size - layer_learnable_output_size + 1
        if self.cfg.sampling_mode == "average":
            return F.avg_pool1d(layer_reconstructed_weights.unsqueeze(1), kernel_size, 1).squeeze(1)
        if self.cfg.sampling_mode == "max":
            return F.max_pool1d(layer_reconstructed_weights.unsqueeze(1), kernel_size, 1).squeeze(1)

        raise ValueError(f"Unsupported sampling mode {self.cfg.sampling_mode}")

    def predict_all(self, positional_embeddings: Dict[str, torch.Tensor],
                    original_weights: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Reconstruct every named weight under ``torch.no_grad()``.

        Iterates over the keys of *positional_embeddings*, and each key must
        also be present in *original_weights*. This does not switch the module
        to eval mode. Subclasses that inject noise while ``self.training``
        (e.g. ``LinearAutoEncoder.forward``) will still do so unless
        ``.eval()`` was called first.
        """
        reconstructed_weights = {}
        with torch.no_grad():
            for name, embedding in positional_embeddings.items():
                reconstructed_weights[name] = self._predict_weights(original_weights[name], embedding)
        return reconstructed_weights
        
        
class LinearAutoEncoder(AutoEncoder):
    """Auto-encoder built from ``Encoder(cfg.enc_dim_list)`` and ``Decoder(cfg.dec_dim_list)``.

    During training (``self.training``), noise can be blended into the input
    rows and/or the latent vectors (see :meth:`add_noise`). Each noise factor
    is either a single number or a ``(low, high)`` pair from which a factor
    is drawn per call.
    """

    def __init__(self, cfg: AEConfig, input_size: int):
        super().__init__(cfg, input_size)
        self.input_noise_factor = cfg.input_noise_factor  # e.g. 0.001
        self.latent_noise_factor = cfg.latent_noise_factor  # e.g. 0.5
        print(f"LinearAutoEncoder: input_noise_factor={self.input_noise_factor}, latent_noise_factor={self.latent_noise_factor}")

        self.encoder = Encoder(cfg.enc_dim_list)
        self.decoder = Decoder(cfg.dec_dim_list)

    def encode(self, x):
        """Map (expanded) weight rows to latent vectors with ``self.encoder``."""
        return self.encoder(x)

    def decode(self, x, positional_embeddings):
        """Map latent rows (plus optional positional embeddings) back to weight rows.

        If ``cfg.latent_z_score`` is set, both inputs are standardized per row.
        If ``cfg.use_embeddings`` is set, the embeddings are concatenated to
        ``x`` on the last dim before the decoder.
        """
        assert x.shape[0] == positional_embeddings.shape[0]
        if self.cfg.latent_z_score:
            x = self.z_score(x)
            positional_embeddings = self.z_score(positional_embeddings)

        if self.cfg.use_embeddings:
            x = torch.cat([x, positional_embeddings], dim=-1)
        return self.decoder(x)

    @staticmethod
    def _noise_enabled(noise_factor) -> bool:
        """Return True if *noise_factor* (a number or a ``(low, high)`` pair) can add noise."""
        if noise_factor is None:
            return False
        if isinstance(noise_factor, (int, float)):
            return noise_factor > 0
        # Range form: a plain "> 0" comparison would raise TypeError on a list/tuple.
        return max(noise_factor) > 0

    def forward(self, x, positional_embeddings) -> torch.Tensor:
        """Encode, optionally perturb, and decode a batch of weight rows."""
        assert x.shape[0] == positional_embeddings.shape[0], f"{x.shape} != {positional_embeddings.shape}"

        if self.training and self._noise_enabled(self.input_noise_factor):
            x = self.add_noise(x, self.input_noise_factor)

        x = self.encode(x)

        if self.cfg.latent_z_score:
            # Standardize before the latent noise so the noise factor acts on
            # unit-variance latents. decode() standardizes again, so the decoder
            # sees standardized input in both training and inference.
            x = self.z_score(x)
            positional_embeddings = self.z_score(positional_embeddings)

        if self.training and self._noise_enabled(self.latent_noise_factor):
            x = self.add_noise(x, self.latent_noise_factor)

        return self.decode(x, positional_embeddings)

    def add_noise(self, x, noise_factor):
        """Blend *x* with Gaussian noise: ``randn_like(x) * f + x * (1 - f)``.

        Args:
            x: tensor to perturb.
            noise_factor: the factor ``f`` as a number, or a ``(low, high)``
                pair from which ``f`` is drawn with ``random.uniform`` on
                every call.

        Raises:
            ValueError: if a non-numeric factor is not a 2-element pair.
        """
        if not isinstance(noise_factor, (int, float)):
            if len(noise_factor) != 2:
                raise ValueError(f"noise_factor must be a number or a (low, high) pair, got {noise_factor!r}")
            noise_factor = random.uniform(noise_factor[0], noise_factor[1])
        return torch.randn_like(x) * noise_factor + x * (1 - noise_factor)

    @property
    def output_size(self) -> int:
        return self.cfg.output_size

    def calc_weights_norms(self, weights: List[torch.Tensor]) -> List[torch.Tensor]:
        """Return ``torch.norm(w, dim=0)`` for each tensor: the L2 norm reduced over dim 0."""
        return [torch.norm(weight, dim=0) for weight in weights]

    def z_score(self, x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
        """Standardize *x* along its last dim (per row): ``(x - mean) / std``.

        ``std`` is the unbiased ``Tensor.std``, clamped to at least *eps*, so
        constant rows (std 0) map to zeros instead of inf/NaN.
        """
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True).clamp(min=eps)
        return (x - mean) / std