import json
import os
import pyrallis
from tqdm import tqdm

import models
import torch
import torch.nn as nn

from options import PrepareParamConfig

def cluster_params(state_dict, cfg: PrepareParamConfig):
    clusters = {}
    
    cluster_cfg = cfg.cluster_cfg
    for i, cluster in enumerate(cluster_cfg):
        cluster_name = cluster[0] + '_cluster'
        param_base = state_dict[cluster[0]]
        if len(param_base.shape) == 4:
            param_base = param_base.view(param_base.shape[0], -1)
            
        for j in range(1, len(cluster)):
            param_name = cluster[j]
            param = state_dict[param_name]
            param = param.view(param.shape[0], -1)
            param_base = torch.cat((param_base, param), dim=1)

        clusters[cluster_name] = param_base
        print(param_base.shape)
    return clusters

@pyrallis.wrap()
def main(cfg: PrepareParamConfig):
    """Extract parameters from a series of checkpoints and write a summary.

    For every ``num`` in ``1..cfg.param_num`` the checkpoint at
    ``cfg.model_path.format(num)`` is loaded on CPU, its selected parameters
    are saved under ``cfg.save_dir`` and their variance range is recorded.
    A ``param_info.json`` file describing the parameter structure, the
    processed file names and the per-file variance info is written to
    ``cfg.save_dir`` at the end.

    Args:
        cfg: Run configuration (parsed by pyrallis).
    """
    print('-' * 50)
    print('MODEL DIR:', cfg.model_path)
    print('SAVE DIR:', cfg.save_dir)
    print('-' * 50)

    # ----------------------------------------
    # basic configuration
    # ----------------------------------------
    os.makedirs(cfg.save_dir, exist_ok=True)
    # ----------------------------------------
    # each model
    # ----------------------------------------
    param_info = {}
    param_var_info = {}
    param_files = []
    # The base-name template does not change between checkpoints.
    file_name_fmt = os.path.basename(cfg.model_path)

    for num in tqdm(range(1, cfg.param_num + 1)):
        # load model's state_dict
        state = torch.load(cfg.model_path.format(num), map_location='cpu', weights_only=True)

        if num == 1:
            # param names configuration: numbering starts at 1, so the first
            # checkpoint is the one used to discover names when none were
            # configured (everything except bn.num_batches_tracked).
            if not cfg.param_names:
                cfg.param_names = [
                    name for name in state
                    if 'num_batches_tracked' not in name
                ]
            print('--> param_names:', cfg.param_names)

            # extract param struct
            param_info = get_param_struct(state, param_info, cfg)

        # extract and save param data
        min_max_var_dict = get_param_data(state, num, cfg)
        file_name = file_name_fmt.format(num)
        param_files.append(file_name)
        param_var_info[file_name] = min_max_var_dict

    # save param info
    param_info['files'] = param_files
    param_info['var_info'] = param_var_info
    param_info['model_classnum'] = cfg.model_classnum
    save_path = os.path.join(cfg.save_dir, 'param_info.json')
    with open(save_path, "w", encoding="utf-8") as json_file:
        json.dump(param_info, json_file, ensure_ascii=False, indent=4)


def get_param_data(state_dict, num: int, cfg: PrepareParamConfig):
    """Save the selected tensors of one checkpoint and report their variance range.

    Depending on ``cfg.param_fmt`` the tensors are either the entries of
    ``state_dict`` whose name is listed in ``cfg.param_names``
    (``'paramwise'``) or the clusters built by ``cluster_params``
    (``'clusterwise'``). Each tensor is written with ``torch.save`` to
    ``<cfg.save_dir>/<tensor name>/<basename of cfg.model_path formatted with num>``.

    Args:
        state_dict: Mapping from parameter name to tensor.
        num: Checkpoint index substituted into the file-name template.
        cfg: Configuration providing ``param_fmt``, ``param_names``,
            ``model_path`` and ``save_dir``.

    Returns:
        dict with keys ``"min_var"`` and ``"max_var"``: the smallest and
        largest ``tensor.var()`` among the saved tensors. When nothing is
        saved (or ``param_fmt`` is not recognised) the initial values
        ``1e9`` and ``0`` are returned.
    """
    lowest, highest = 1e9, 0

    # Pick the tensors to persist according to the configured format.
    mode = cfg.param_fmt
    if mode == 'paramwise':  # apart (by param name)
        template = os.path.basename(cfg.model_path)
        selected = {
            key: tensor
            for key, tensor in state_dict.items()
            if key in cfg.param_names
        }
    elif mode == 'clusterwise':
        selected = cluster_params(state_dict, cfg)
        template = os.path.basename(cfg.model_path)
    else:
        # Unrecognised format: nothing is written.
        return {"min_var": lowest, "max_var": highest}

    for key, tensor in selected.items():
        target_dir = os.path.join(cfg.save_dir, key)
        os.makedirs(target_dir, exist_ok=True)
        torch.save(tensor, os.path.join(str(target_dir), template.format(num)))

        variance = tensor.var().item()
        lowest = min(lowest, variance)
        highest = max(highest, variance)

    return {"min_var": lowest, "max_var": highest}

def get_param_struct(state_dict, param_info, cfg: PrepareParamConfig):
    """Record the shapes of the selected parameters under ``param_info['struct']``.

    Args:
        state_dict: Mapping from parameter name to tensor.
        param_info: Dict to fill in; its ``'struct'`` entry is replaced.
        cfg: Configuration providing ``param_fmt`` and, depending on it,
            ``param_names`` or ``cluster_cfg``.

    Returns:
        The same ``param_info`` object. With ``'paramwise'`` the struct maps
        each selected parameter name to its shape as a list; with
        ``'clusterwise'`` it maps ``<first member>_cluster`` to a dict of
        member name -> shape list. Any other format leaves the struct empty.
    """
    layout = {}
    param_info['struct'] = layout

    mode = cfg.param_fmt
    if mode == 'paramwise':  # apart (by param name)
        layout.update(
            (key, list(tensor.shape))
            for key, tensor in state_dict.items()
            if key in cfg.param_names
        )
    elif mode == 'clusterwise':
        for members in cfg.cluster_cfg:
            # A cluster is named after its first member.
            cluster_key = members[0] + '_cluster'
            shapes = {}
            for member in members:
                shapes[member] = list(state_dict[member].shape)
            layout[cluster_key] = shapes

    return param_info








# Script entry point: run main() only when this file is executed directly.
# main is wrapped by pyrallis.wrap(), so it is called without arguments here.
if __name__ == '__main__':
    main()
