from typing import List, Tuple, Dict

import os
from contextlib import contextmanager
import numpy as np
import torch
from torch import nn

from options import EmbeddingsConfig
class PositionalEmbedding(nn.Module):
    """Sin/cos positional embeddings for (parameter index, output-channel index) pairs.

    Two embedding types are selected through ``config.type``:

    * ``'basic'`` -- every coordinate of the position is multiplied by
      ``base ** level * pi`` for ``level`` in ``range(enc_levels)`` and passed
      through ``sin`` and ``cos``.
    * ``'ffn'`` -- the position vector is multiplied by ``2 * pi``, projected
      with the fixed Gaussian matrix ``ffn_B`` and passed through ``sin`` and
      ``cos``.

    Besides ``forward``, the class can enumerate the indices of a set of
    parameter tensors (``_fit_param_indices``) and compute, cache on disk and
    reload the embeddings of those indices
    (``_calculate_positional_embeddings``).
    """

    def __init__(self, config: "EmbeddingsConfig", device: torch.device):
        """Copy the settings from ``config`` and draw the ``ffn`` projection matrix.

        Args:
            config: Provides ``type``, ``fusion_mode``, ``normalization_mode``,
                ``num_idxs``, ``base``, ``enc_levels``,
                ``embedding_cache_folder`` and ``gauss_scale``.
            device: Device the computed / loaded embeddings are moved to.
        """
        super().__init__()
        self.device = device
        # NOTE: this instance attribute shadows ``nn.Module.type()``.
        self.type = config.type
        self.embedding_fusion_mode = config.fusion_mode
        self.normalization_mode = config.normalization_mode
        self.num_idxs = config.num_idxs
        self.base = config.base
        self.levels = config.enc_levels
        self.embedding_cache_folder = config.embedding_cache_folder
        self.gauss_scale = config.gauss_scale
        self.output_size = self._calculate_output_size()

        # Filled by _fit_param_indices().
        self.indices = None
        self.normalized_indices = None

        # Filled by _calculate_positional_embeddings().
        self.positional_embeddings = None
        self.normalized_positional_embeddings = None

        # Drawn under a fixed seed so the projection matrix does not depend on
        # the caller's RNG state; requires_grad=False keeps it constant.
        with zero_seed():
            self.ffn_B = nn.Parameter(
                torch.randn((self.levels * self.num_idxs, self.num_idxs))
                * torch.Tensor(self.gauss_scale),
                requires_grad=False)

    @property
    def normalize_indices(self):
        """Legacy spelling of ``normalized_indices``, kept for existing callers."""
        return self.normalized_indices

    @normalize_indices.setter
    def normalize_indices(self, value):
        self.normalized_indices = value

    def _calculate_output_size(self):
        """Return the length of one embedding vector, or ``None`` if unsupported.

        ``None`` is returned for combinations of type and fusion mode that
        ``forward`` rejects with ``NotImplementedError``.
        """
        if self.embedding_fusion_mode == 'concat':
            return self.levels * 2 * self.num_idxs
        if self.embedding_fusion_mode == 'sum' and self.type == 'basic':
            return self.levels * 2
        return None

    def forward(self, pos):
        """Embed a single position.

        Args:
            pos: Sequence (or tensor) of coordinates.

        Returns:
            A flat 1-D tensor with interleaved sin / cos values.

        Raises:
            NotImplementedError: For an unknown embedding type or a fusion
                mode the selected type does not support.
        """
        if self.type == 'basic':
            return self._basic(pos)
        if self.type == 'ffn':
            return self._ffn(pos)
        raise NotImplementedError(f'Unsupported positional embedding type {self.type}')

    def _ffn(self, pos):
        """Embed ``pos`` via the fixed Gaussian projection ``ffn_B`` ('concat' only)."""
        if self.embedding_fusion_mode != 'concat':
            raise NotImplementedError(f'Unsupported embedding fusion mode {self.embedding_fusion_mode} for type {self.type}')
        # as_tensor avoids the copy-construct warning when `pos` is already a
        # tensor; placing it on ffn_B's device keeps the matmul valid after
        # the module has been moved.
        positions = torch.as_tensor(pos, device=self.ffn_B.device)
        projected = (positions * 2 * np.pi) @ self.ffn_B.T
        # dstack + flatten interleaves the values as sin_0, cos_0, sin_1, ...
        return torch.dstack([torch.sin(projected), torch.cos(projected)]).flatten()

    def _basic(self, pos):
        """Embed ``pos`` with per-level frequencies ``base ** level * pi``."""
        if self.embedding_fusion_mode == 'concat':
            positions = torch.as_tensor(pos)
            frequencies = self.base ** torch.arange(self.levels, device=positions.device)
            # Rows are coordinates, columns are levels.
            x = positions.unsqueeze(-1) * frequencies * np.pi
            return torch.dstack([torch.sin(x), torch.cos(x)]).flatten()
        if self.embedding_fusion_mode == 'sum':
            frequencies = self.base ** torch.arange(self.levels)
            pe_list = []
            # One interleaved sin/cos vector per coordinate, summed afterwards.
            for p in pos:
                pe_levels = p * frequencies * np.pi
                pe_list.append(torch.dstack([torch.sin(pe_levels), torch.cos(pe_levels)]).flatten())
            return torch.vstack(pe_list).sum(dim=0)
        raise NotImplementedError(f'Unsupported embedding fusion mode {self.embedding_fusion_mode} for type {self.type}')

    def __hash__(self):
        # Built from numbers only; it is used as part of the cache directory name.
        pe_type = {
            'basic': 0,
            'ffn': 1
        }
        return hash((self.base, self.levels, self.num_idxs, self.output_size, pe_type[self.type],
                     *tuple(self.gauss_scale)))

    def __str__(self):
        return f"PE_{self.type}_{self.base}_{self.levels}_{self.num_idxs}_{self.output_size}_{self.embedding_fusion_mode}"

    def _fit_param_indices(self, param_layer_names: List[str], learnable_weights_shapes: Dict[str, torch.Size]) -> Tuple[Dict[str, List[Tuple]], Dict[str, List[Tuple]]]:
        """Enumerate ``(param_idx, output_channel_idx)`` pairs for every parameter.

        A parameter whose shape has at most one dimension contributes a single
        output channel; otherwise the first dimension is the channel count.

        Args:
            param_layer_names: Parameter names, in the order that defines
                ``param_idx``.
            learnable_weights_shapes: Shape of each named parameter.

        Returns:
            ``(indices, normalized_indices)``, both mapping parameter name to
            a list of index tuples. The same dicts are stored on ``self``.

        Raises:
            ValueError: If ``normalization_mode`` is not ``"None"``,
                ``"global"`` or ``"local"``.
        """
        assert len(param_layer_names) == len(learnable_weights_shapes.keys()), "Number of param layer names must match number of learnable weights"

        self.indices = {}
        self.normalized_indices = {}

        # Largest dimension over all shapes; the defaults keep an empty
        # mapping or a 0-dim shape from raising.
        max_index = max((max(shape, default=1) for shape in learnable_weights_shapes.values()), default=1)
        num_params = len(param_layer_names)

        for param_idx, name in enumerate(param_layer_names):
            shape = learnable_weights_shapes[name]
            curr_output_channels = shape[0] if len(shape) > 1 else 1

            curr_param_indices = []
            curr_normalized_param_indices = []
            for output_channel_idx in range(curr_output_channels):
                # Generate an index for each output channel.
                curr_param_indices.append((param_idx, output_channel_idx))

                # Process the index according to the normalization mode.
                if self.normalization_mode == "None":
                    curr_normalized_param_indices.append((param_idx, output_channel_idx))
                elif self.normalization_mode == "global":
                    curr_normalized_param_indices.append(
                        (param_idx / max_index, output_channel_idx / max_index))
                elif self.normalization_mode == "local":
                    curr_normalized_param_indices.append(
                        (param_idx / num_params, output_channel_idx / curr_output_channels))
                else:
                    raise ValueError(f"Unsupported normalization mode {self.normalization_mode}")

            self.indices[name] = curr_param_indices
            self.normalized_indices[name] = curr_normalized_param_indices

        return self.indices, self.normalized_indices

    def _cache_dir(self) -> str:
        """Return the directory holding this configuration's cached embeddings."""
        return os.path.join(self.embedding_cache_folder, f"{str(self)}_embeddings_{hash(self)}")

    @staticmethod
    def _cache_file(cache_dir: str, prefix: str, param_name: str, normalized: bool) -> str:
        """Return the cache file path for one parameter's embeddings."""
        kind = "normalized_param" if normalized else "param"
        return os.path.join(cache_dir, f"{prefix}_{kind}_{param_name}.pt")

    def _embed_indices(self, indices_by_param, label: str) -> Dict[str, torch.Tensor]:
        """Run ``forward`` on every index and stack the results per parameter."""
        embeddings = {}
        for param_name, param_indices in indices_by_param.items():
            print(f"Calculating {label} {param_name} embeddings")
            stacked = torch.stack([self.forward(idx) for idx in param_indices])
            embeddings[param_name] = stacked.to(self.device)
        return embeddings

    def _calculate_positional_embeddings(self, prefix: str = '') -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """Load the embeddings of the fitted indices from cache, or compute and cache them.

        Args:
            prefix: Prepended to every cache file name.

        Returns:
            ``(positional_embeddings, normalized_positional_embeddings)``,
            both mapping parameter name to a stacked tensor on ``self.device``.
            The same dicts are stored on ``self``.
        """
        assert self.indices is not None, "Indices must be fitted before calculating positional embeddings"
        # NOTE(review): the cache location depends only on __str__/__hash__,
        # `prefix` and the parameter name. Neither normalization_mode nor the
        # fitted shapes are part of it, so a cache written under a different
        # mode/shape would be loaded as-is -- confirm callers vary `prefix`
        # or the cache folder accordingly.
        cache_dir = self._cache_dir()
        os.makedirs(cache_dir, exist_ok=True)
        try:
            print("Trying to load precomputed embeddings")
            positional_embeddings = {}
            for param_name in self.indices:
                # map_location lets a cache written on another device be read here.
                positional_embeddings[param_name] = torch.load(
                    self._cache_file(cache_dir, prefix, param_name, normalized=False),
                    map_location=self.device, weights_only=True).to(self.device)
                print(f"Loaded positional embeddings for {prefix} {param_name}")

            normalized_embeddings = {}
            for param_name in self.normalized_indices:
                normalized_embeddings[param_name] = torch.load(
                    self._cache_file(cache_dir, prefix, param_name, normalized=True),
                    map_location=self.device, weights_only=True).to(self.device)
                print(f"Loaded normalized positional embeddings for {prefix} {param_name}")

            print("Finished loading precomputed embeddings")
            self.positional_embeddings = positional_embeddings
            self.normalized_positional_embeddings = normalized_embeddings

            return self.positional_embeddings, self.normalized_positional_embeddings
        except IOError:
            # Any missing / unreadable file discards the partial load.
            print("Couldn't load precomputed embeddings, computing embeddings")

        self.positional_embeddings = self._embed_indices(self.indices, "param")
        self.normalized_positional_embeddings = self._embed_indices(self.normalized_indices, "normalized param")

        for param_name, positional_embedding in self.positional_embeddings.items():
            print(f"Saving positional embeddings for param {prefix} {param_name}")
            torch.save(positional_embedding, self._cache_file(cache_dir, prefix, param_name, normalized=False))

        for param_name, positional_embedding in self.normalized_positional_embeddings.items():
            print(f"Saving normalized positional embeddings for param {prefix} {param_name}")
            torch.save(positional_embedding, self._cache_file(cache_dir, prefix, param_name, normalized=True))

        return self.positional_embeddings, self.normalized_positional_embeddings

    def get_indices_and_positional_embeddings(self) -> Tuple[Dict[str, List[Tuple]], Dict[str, torch.Tensor], Dict[str, List[Tuple]], Dict[str, torch.Tensor]]:
        """Return ``(indices, embeddings, normalized_indices, normalized_embeddings)``.

        Entries are ``None`` until the corresponding fit / calculate step ran.
        """
        return self.indices, self.positional_embeddings, self.normalized_indices, self.normalized_positional_embeddings

@contextmanager
def zero_seed():
    """Run the ``with`` body under torch seed 0, then restore the previous RNG state.

    ``torch.manual_seed`` reseeds the CUDA generators as well as the CPU one,
    so the CUDA states are saved and restored too when CUDA has already been
    initialised. An uninitialised CUDA runtime is left untouched rather than
    being initialised just to snapshot its state.
    """
    # Capture state before entering the try block so a failure here is not
    # masked by the restore step.
    cpu_state = torch.random.get_rng_state()
    cuda_states = None
    if torch.cuda.is_available() and torch.cuda.is_initialized():
        cuda_states = torch.cuda.get_rng_state_all()

    torch.manual_seed(0)
    try:
        yield
    finally:
        torch.random.set_rng_state(cpu_state)
        if cuda_states is not None:
            torch.cuda.set_rng_state_all(cuda_states)
