import os
import torch
import json
from torch import nn
from typing import List, Dict

def cluster_params(state_dict, cluster_cfg: List[List[str]]):
    """Pack groups of state-dict tensors into 2-D "cluster" tensors.

    Every tensor named in a cluster is flattened to ``(shape[0], -1)`` (so a
    1-D tensor becomes a single column), and the flattened tensors are
    concatenated along dim 1 in the order they are listed.

    Args:
        state_dict: Mapping from parameter name to tensor.
        cluster_cfg: One list of parameter names per cluster. All tensors in a
            cluster must share the same leading dimension.

    Returns:
        Dict mapping ``<first name in the cluster> + '_cluster'`` to the
        concatenated 2-D tensor.

    Raises:
        ValueError: if a cluster lists no parameter names.
    """
    clusters = {}
    for names in cluster_cfg:
        if not names:
            raise ValueError("each cluster must list at least one parameter name")
        # Every entry, including the first, is flattened the same way so 1-D
        # tensors become one column, matching how decluster_params slices them.
        # reshape (unlike view) also accepts non-contiguous tensors.
        columns = [
            state_dict[name].reshape(state_dict[name].shape[0], -1) for name in names
        ]
        # A single cat over the whole list instead of regrowing pairwise;
        # a one-entry cluster is returned without an extra copy.
        clusters[names[0] + '_cluster'] = (
            columns[0] if len(columns) == 1 else torch.cat(columns, dim=1)
        )
    return clusters

def decluster_params(clusters, cluster_cfg: Dict):
    """Split 2-D cluster tensors back into individually shaped parameters.

    Each cluster tensor is consumed column-wise. Parameters are taken in the
    iteration order of ``cluster_cfg[cluster_name]``. Each one uses
    ``prod(shape[1:])`` columns (one column for a 1-D parameter) and is
    reshaped to its configured shape.

    Args:
        clusters: Mapping from cluster name to a 2-D tensor whose first
            dimension is the parameters' shared leading dimension.
        cluster_cfg: Mapping from cluster name to ``{param_name: shape}``.
            Only clusters named here are decoded.

    Returns:
        Dict mapping each parameter name to its reconstructed tensor.

    Raises:
        KeyError: if a configured cluster is missing from ``clusters``.
        ValueError: if a cluster has more columns than its parameters use.
    """
    state_dict = {}
    for cluster_name, cluster_info in cluster_cfg.items():
        cluster = clusters[cluster_name]
        offset = 0
        for param_name, param_shape in cluster_info.items():
            # Columns per row = product of the trailing dims, for any rank >= 1
            # (an empty product is 1, which covers 1-D parameters).
            param_len = torch.Size(param_shape[1:]).numel()
            piece = cluster[:, offset:offset + param_len]
            # reshape returns a view when possible and copes with the strided slice.
            state_dict[param_name] = piece.reshape(param_shape)
            offset += param_len
        if offset != cluster.shape[1]:
            raise ValueError(
                f"cluster '{cluster_name}' has {cluster.shape[1]} columns but "
                f"its parameters account for {offset}"
            )
    return state_dict

class ModelHelper:
    """Expose and overwrite a selected part of an ``nn.Module``'s state dict.

    Without clustering, the "learnable weights" are the state-dict entries
    whose names are in ``learnable_param_names``. By default that is every
    entry whose name does not contain ``num_batches_tracked``.

    After ``set_cluster(True, path)``, the parameters listed in the JSON
    config's ``"struct"`` section are instead packed into 2-D cluster tensors
    with ``cluster_params`` / ``decluster_params``, and
    ``learnable_param_names`` is ignored.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        # State-dict keys returned by get_learnable_weights() in non-cluster mode.
        self.learnable_param_names = []
        # Truthy after set_cluster(True, ...); None until set_cluster is called.
        self.if_cluster = None
        # Parsed cluster JSON; only its "struct" entry is read by this class.
        self.cluster_cfg = None

        self.init_default_learnable_param_names()

    def get_learnable_weights(self) -> Dict[str, torch.Tensor]:
        """Return the current learnable weights.

        Returns:
            In cluster mode, ``{cluster name: 2-D cluster tensor}`` built from
            ``cluster_cfg["struct"]``. Otherwise, ``{name: tensor}`` for each
            state-dict entry listed in ``learnable_param_names``, in
            state-dict order. These are the tensors returned by
            ``model.state_dict()``, which ``update_weights`` writes into with
            ``copy_``.
        """
        state_dict = self.model.state_dict()
        if self.if_cluster:
            # Clusters always cover every parameter named in the config and
            # are not filtered by learnable_param_names.
            cluster_cfg = [
                list(param_dict.keys())
                for param_dict in self.cluster_cfg["struct"].values()
            ]
            return cluster_params(state_dict, cluster_cfg)
        # Set lookup instead of scanning the list once per state-dict entry.
        wanted = set(self.learnable_param_names)
        return {name: param for name, param in state_dict.items() if name in wanted}

    def get_learnable_weights_shapes(self) -> Dict[str, torch.Size]:
        """Return ``{name: shape}`` for every tensor from get_learnable_weights()."""
        return {name: weight.shape for name, weight in self.get_learnable_weights().items()}

    def update_weights(self, reconstructed_weights: Dict[str, torch.Tensor]):
        """Copy reconstructed tensors into the model in place.

        Args:
            reconstructed_weights: In cluster mode, a mapping from cluster name
                to 2-D cluster tensor, decoded with ``cluster_cfg["struct"]``.
                Otherwise, a mapping that must contain every current learnable
                weight name.

        Raises:
            KeyError: if a required name is missing (from the reconstructed
                weights, or from the model in cluster mode).
            ValueError: if a reconstructed tensor's shape differs from the
                model's. This is checked explicitly because ``copy_`` would
                otherwise broadcast silently.

        NOTE(review): cluster_params names clusters ``<first param>_cluster``,
        while decluster_params looks them up by the keys of "struct". A
        round trip presumably relies on these coinciding; confirm against
        the config.
        """
        if self.if_cluster:
            declustered = decluster_params(reconstructed_weights, self.cluster_cfg["struct"])
            # Build the state dict once; each state_dict() call rebuilds the mapping.
            model_state = self.model.state_dict()
            for name, tensor in declustered.items():
                target = model_state[name]
                if target.shape != tensor.shape:
                    raise ValueError(
                        f"shape mismatch for '{name}': model {tuple(target.shape)}, "
                        f"reconstructed {tuple(tensor.shape)}"
                    )
                # running_mean / running_var buffers are deliberately left unchanged.
                if 'running_mean' not in name and 'running_var' not in name:
                    target.copy_(tensor)
            return

        for name, current in self.get_learnable_weights().items():
            if name not in reconstructed_weights:
                raise KeyError(f"no reconstructed weight for '{name}'")
            predicted = reconstructed_weights[name]
            if current.shape != predicted.shape:
                raise ValueError(
                    f"shape mismatch for '{name}': model {tuple(current.shape)}, "
                    f"reconstructed {tuple(predicted.shape)}"
                )
            current.copy_(predicted)

    def load(self, path: str, device: torch.device):
        """Load a state dict from ``path`` (mapped onto ``device``) into the model."""
        # NOTE(security): weights_only=False unpickles arbitrary objects, so only
        # load trusted files. A plain state dict written by save() should also
        # load with weights_only=True; confirm before switching.
        self.model.load_state_dict(torch.load(path, map_location=device, weights_only=False))

    def save(self, path: str):
        """Save the model's state dict to ``path``, creating parent directories."""
        directory = os.path.dirname(path)
        # dirname is '' for a bare filename, and os.makedirs('') raises.
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(self.model.state_dict(), path)

    def init_default_learnable_param_names(self):
        """Mark every state-dict entry except ``num_batches_tracked`` counters as learnable."""
        self.learnable_param_names = [
            name for name in self.model.state_dict() if 'num_batches_tracked' not in name
        ]

    def set_learnable_param_names(self, learnable_param_names: List[str]):
        """Replace the list of state-dict keys used in non-cluster mode."""
        self.learnable_param_names = learnable_param_names

    def set_cluster(self, if_cluster: bool, cluster_cfg_path: str = None):
        """Enable or disable cluster mode.

        Args:
            if_cluster: Whether to use cluster mode.
            cluster_cfg_path: JSON cluster config to load. Required only when
                ``if_cluster`` is truthy.

        Raises:
            ValueError: if clustering is enabled without a config path.
        """
        self.if_cluster = if_cluster
        if self.if_cluster:
            if cluster_cfg_path is None:
                raise ValueError("cluster_cfg_path is required when if_cluster is True")
            # JSON text is UTF-8; do not depend on the platform's locale encoding.
            with open(cluster_cfg_path, 'r', encoding='utf-8') as f:
                self.cluster_cfg = json.load(f)
            
            
    

    


