from typing import List, Tuple, Dict

import os
from contextlib import contextmanager
import numpy as np
import torch
from torch import nn

from options import EmbeddingsConfig
class PositionalEmbedding(nn.Module):
    """Sinusoidal ("basic") positional embeddings of parameter index tuples.

    Every learnable weight tensor is described by one index tuple per output
    channel, ``(param_idx, output_channel_idx, channel_dim)`` (built by
    ``_fit_param_indices``). For each source in ``param_index_dict`` the
    source's integer index is prepended to those tuples and the result is
    embedded with ``forward``. Computed embedding tables are cached as ``.pt``
    files below ``config.embedding_cache_folder``.
    """

    def __init__(self, config: "EmbeddingsConfig", param_index_dict: Dict[str, int], device: torch.device):
        """
        Args:
            config: embedding options; reads ``type``, ``fusion_mode``,
                ``normalization_mode``, ``num_idxs``, ``base``, ``enc_levels``,
                ``embedding_cache_folder`` and ``gauss_scale``.
            param_index_dict: maps a source name to the integer index that is
                prepended to every parameter index tuple of that source.
            device: device the computed / loaded embedding tables are moved to.
        """
        super().__init__()
        self.device = device
        # NOTE: this instance attribute shadows nn.Module.type() on this object.
        self.type = config.type
        self.embedding_fusion_mode = config.fusion_mode
        self.normalization_mode = config.normalization_mode
        self.num_idxs = config.num_idxs
        self.base = config.base
        self.levels = config.enc_levels
        self.embedding_cache_folder = config.embedding_cache_folder
        self.gauss_scale = config.gauss_scale
        self.output_size = self._calculate_output_size()
        self.param_index_dict = param_index_dict

        # Filled by _fit_param_indices: {param_name: [index tuple per output channel]}.
        self.indices = None
        self.normalized_indices = None

        # Filled by _calculate_positional_embeddings:
        # {source name: {param_name: stacked embedding tensor}}.
        self.positional_embeddings = None
        self.normalized_positional_embeddings = None

    @property
    def normalize_indices(self):
        """Backward-compatible alias of ``normalized_indices``."""
        return self.normalized_indices

    @normalize_indices.setter
    def normalize_indices(self, value):
        self.normalized_indices = value

    def _calculate_output_size(self):
        """Return the length of one vector produced by ``forward``.

        ``concat`` produces ``2 * levels`` values (sin, cos) for each of the
        ``num_idxs`` components (assumes ``num_idxs`` equals the number of
        components passed to ``forward``); ``sum`` with the ``basic`` type
        adds the per-component vectors, giving ``2 * levels``. Any other
        combination returns ``None``.
        """
        if self.embedding_fusion_mode == 'concat':
            return self.levels * 2 * self.num_idxs
        if self.embedding_fusion_mode == 'sum' and self.type == 'basic':
            return self.levels * 2
        return None

    def forward(self, pos):
        """Embed the index tuple ``pos``; only the ``'basic'`` type is implemented.

        Raises:
            NotImplementedError: for any other embedding type.
        """
        if self.type == 'basic':
            return self._basic(pos)
        raise NotImplementedError(f'Unsupported positional embedding type {self.type}')

    def _basic(self, pos):
        """Sinusoidal embedding of ``pos``.

        Each component ``p`` is expanded to ``p * base**k * pi`` for
        ``k in range(levels)`` and encoded as interleaved sin/cos pairs
        ``[sin_0, cos_0, sin_1, cos_1, ...]``. ``concat`` flattens the
        per-component encodings into ``len(pos) * 2 * levels`` values; ``sum``
        adds them into ``2 * levels`` values.

        Raises:
            NotImplementedError: for an unsupported fusion mode.
        """
        if self.embedding_fusion_mode == 'concat':
            # Vectorized over all components: (len(pos), levels).
            x = (torch.tensor(pos).unsqueeze(-1) * (self.base ** torch.arange(self.levels)) * np.pi)
            # dstack -> (len(pos), levels, 2); flatten interleaves sin/cos per level.
            final_embeddings = torch.dstack([torch.sin(x), torch.cos(x)]).flatten()
        elif self.embedding_fusion_mode == 'sum':
            pe_list = []
            # Per-component loop (not vectorized).
            for p in pos:
                pe_levels = p * (self.base ** torch.arange(self.levels)) * np.pi
                # Interleave sin and cos over the levels.
                pe_list.append(torch.dstack([torch.sin(pe_levels), torch.cos(pe_levels)]).flatten())
            final_embeddings = torch.vstack(pe_list).sum(dim=0)
        else:
            raise NotImplementedError(f'Unsupported embedding fusion mode {self.embedding_fusion_mode} for type {self.type}')
        return final_embeddings

    def __hash__(self):
        """Hash of the embedding hyper-parameters; part of the cache folder name.

        Equality is left identity-based (no ``__eq__`` override). Assumes
        ``gauss_scale`` is an iterable of numbers — TODO confirm against the
        config definition. ``fusion_mode`` is not hashed here; it appears in
        ``__str__``, which is also part of the cache folder name.
        """
        pe_type = {
            'basic': 0,
        }
        return hash((self.base, self.levels, self.num_idxs, self.output_size, pe_type[self.type],
                     *tuple(self.gauss_scale)))

    def __str__(self):
        return f"PE_{self.type}_{self.base}_{self.levels}_{self.num_idxs}_{self.output_size}_{self.embedding_fusion_mode}"

    def _fit_param_indices(self, param_layer_names: List[str], learnable_weights_shapes: Dict[str, torch.Size]):
        """Build one index tuple per output channel of every named weight.

        Each tuple is ``(param_idx, output_channel_idx, channel_dim)`` where
        ``channel_dim`` is the product of all weight dimensions after the
        first. The normalized tuples rescale the first two entries according
        to ``self.normalization_mode``:

        * ``"None"``: unchanged.
        * ``"global"``: both divided by the largest dimension of any weight.
        * ``"local"``: divided by the number of parameters and by the layer's
          number of output channels, respectively.

        Returns:
            ``(indices, normalized_indices)``, each ``{name: [tuple, ...]}``;
            both are also stored on ``self``.

        Raises:
            ValueError: for an unsupported normalization mode.
        """
        assert len(param_layer_names) == len(learnable_weights_shapes.keys()), "Number of param layer names must match number of learnable weights"

        mode = self.normalization_mode
        if mode not in ("None", "global", "local"):
            raise ValueError(f"Unsupported normalization mode {self.normalization_mode}")

        # default=1 lets an empty parameter list return empty dicts instead of crashing.
        max_index = max((max(weights_shape) for weights_shape in learnable_weights_shapes.values()), default=1)
        num_params = len(param_layer_names)

        indices = {}
        normalized_indices = {}
        for param_idx, name in enumerate(param_layer_names):
            shape = learnable_weights_shapes[name]
            assert len(shape) >= 2, "Learnable weights must have at least 2 dimensions"

            curr_output_channels = shape[0]
            # Weights feeding one output channel. Computed for every rank so
            # 3-D / 5-D weights neither crash nor reuse the previous layer's value.
            curr_channel_dim = int(np.prod(shape[1:]))

            curr_param_indices = []
            curr_normalized_param_indices = []
            for output_channel_idx in range(curr_output_channels):
                # One index tuple per output channel.
                raw = (param_idx, output_channel_idx, curr_channel_dim)
                curr_param_indices.append(raw)

                # Apply the configured normalization to the index.
                if mode == "global":
                    normalized = (param_idx / max_index, output_channel_idx / max_index, curr_channel_dim)
                elif mode == "local":
                    normalized = (param_idx / num_params, output_channel_idx / curr_output_channels, curr_channel_dim)
                else:
                    normalized = raw
                curr_normalized_param_indices.append(normalized)

            indices[name] = curr_param_indices
            normalized_indices[name] = curr_normalized_param_indices

        self.indices = indices
        self.normalized_indices = normalized_indices
        return self.indices, self.normalized_indices

    def _cache_dir(self, p_src: str) -> str:
        """Folder holding the cached embedding tables of source ``p_src``."""
        return os.path.join(self.embedding_cache_folder, f"{str(self)}_{p_src}_embeddings_{hash(self)}")

    def _param_cache_file(self, p_src: str, prefix: str, param_name: str) -> str:
        """Cache file of the raw-index embeddings of one parameter."""
        return os.path.join(self._cache_dir(p_src), f"{prefix}_param_{param_name}.pt")

    def _normalized_param_cache_file(self, p_src: str, prefix: str, param_name: str) -> str:
        """Cache file of the normalized-index embeddings of one parameter.

        The normalization mode is part of the name because it is not part of
        the folder key, so different modes must not share files.
        """
        return os.path.join(self._cache_dir(p_src),
                            f"{prefix}_normalized_{self.normalization_mode}_param_{param_name}.pt")

    def _embed_index_table(self, p_index: int, indices) -> Dict[str, torch.Tensor]:
        """Embed every index tuple (prefixed with ``p_index``), stacked per parameter."""
        return {
            param_name: torch.stack([self.forward((p_index, ) + idx) for idx in param_indices]).to(self.device)
            for param_name, param_indices in indices.items()
        }

    def _load_cached_embeddings(self, prefix: str):
        """Load all cached tables; raises if any file is missing or unreadable."""
        positional_embeddings = {}
        normalized_embeddings = {}
        for p_src in self.param_index_dict:
            positional_embeddings[p_src] = {
                name: torch.load(self._param_cache_file(p_src, prefix, name), weights_only=True).to(self.device)
                for name in self.indices
            }
            normalized_embeddings[p_src] = {
                name: torch.load(self._normalized_param_cache_file(p_src, prefix, name), weights_only=True).to(self.device)
                for name in self.normalized_indices
            }
        return positional_embeddings, normalized_embeddings

    def _calculate_positional_embeddings(self, prefix: str = '') -> Tuple[Dict[str, Dict[str, torch.Tensor]], Dict[str, Dict[str, torch.Tensor]]]:
        """Load the embedding tables from the cache, or compute and cache them.

        Args:
            prefix: prepended to every cache file name.

        Returns:
            ``(positional_embeddings, normalized_positional_embeddings)``, each
            ``{source name: {param_name: tensor}}``; also stored on ``self``.
        """
        assert self.indices is not None, "Indices must be fitted before calculating positional embeddings"
        assert self.normalized_indices is not None, "Indices must be fitted before calculating positional embeddings"
        for p_src in self.param_index_dict:
            os.makedirs(self._cache_dir(p_src), exist_ok=True)

        try:
            print("Trying to load precomputed embeddings")
            loaded = self._load_cached_embeddings(prefix)
        except (OSError, RuntimeError, EOFError):
            # Missing files raise OSError; truncated/corrupt files (e.g. an
            # interrupted torch.save) raise RuntimeError/EOFError -> recompute.
            print("Couldn't load precomputed embeddings, computing embeddings")
        else:
            print("Finished loading precomputed embeddings")
            self.positional_embeddings, self.normalized_positional_embeddings = loaded
            return self.positional_embeddings, self.normalized_positional_embeddings

        print("Calculating positional embeddings")
        positional_embeddings = {}
        normalized_positional_embeddings = {}
        for p_src, p_index in self.param_index_dict.items():
            positional_embeddings[p_src] = self._embed_index_table(p_index, self.indices)
            normalized_positional_embeddings[p_src] = self._embed_index_table(p_index, self.normalized_indices)

        self.positional_embeddings = positional_embeddings
        self.normalized_positional_embeddings = normalized_positional_embeddings

        print("Finished calculating positional embeddings")
        print(f"Saving positional embeddings {self.embedding_cache_folder}")

        for p_src in self.param_index_dict:
            for param_name, positional_embedding in self.positional_embeddings[p_src].items():
                torch.save(positional_embedding, self._param_cache_file(p_src, prefix, param_name))
            for param_name, positional_embedding in self.normalized_positional_embeddings[p_src].items():
                torch.save(positional_embedding, self._normalized_param_cache_file(p_src, prefix, param_name))

        return self.positional_embeddings, self.normalized_positional_embeddings

    def get_indices_and_positional_embeddings(self):
        """Return ``(indices, positional_embeddings, normalized_indices, normalized_positional_embeddings)``."""
        return self.indices, self.positional_embeddings, self.normalized_indices, self.normalized_positional_embeddings

@contextmanager
def zero_seed():
    """Run the ``with`` body with torch's default generator seeded to 0.

    The CPU RNG state in effect on entry is restored on exit, also when the
    body raises. Note that ``torch.manual_seed`` reseeds every device's
    generator, while only the CPU state is saved and restored here.
    """
    # Capture before ``try`` so ``finally`` always restores a real state
    # (the former ``0`` placeholder is not a valid RNG state).
    saved_state = torch.random.get_rng_state()
    try:
        torch.manual_seed(0)
        yield
    finally:
        torch.random.set_rng_state(saved_state)
