import os
import ujson as json
import zipfile
import numpy as np
import pickle
import random
import torch
import shutil



def load_DiT(file_path):
    """Load a serialized checkpoint from ``file_path`` onto the CPU.

    Args:
        file_path: Location of the checkpoint, passed straight to ``torch.load``.

    Returns:
        Whatever object ``torch.load`` deserializes from the file
        (presumably a state dict, judging by the function's use -- confirm
        against the checkpoints actually stored).

    NOTE(review): ``torch.load`` relies on unpickling; only point this at
    checkpoint files from a trusted source.
    """
    cpu_device = torch.device('cpu')
    # Remap every stored tensor to the CPU regardless of where it was saved.
    return torch.load(file_path, map_location=cpu_device)


def load_pretrained_weights_auto_report(model, pretrained_model_state_dict, blk_prefix, DiT_blks):
    """Copy matching pretrained tensors into ``model`` in place and report the mapping.

    Names in ``model.state_dict()`` are split on ``'.'``; the first three
    components are treated as a wrapper prefix and dropped when building the
    pretrained key to look up.

    * ``t_embedder`` entries: the pretrained key is the model name without
      its first three components.
    * block entries: for custom block ``i`` the pretrained key is
      ``<component 3>.<blk_prefix[i]>.<components 5 onward>``, i.e. custom
      block ``i`` is initialised from pretrained block ``blk_prefix[i]``.

    Args:
        model: Object exposing ``state_dict()`` whose tensors support
            in-place ``copy_``.
        pretrained_model_state_dict: Mapping from pretrained parameter name
            to tensor.
        blk_prefix: Sequence where ``blk_prefix[i]`` is the pretrained block
            index used for custom block ``i``.
        DiT_blks: Number of custom blocks; must equal ``len(blk_prefix)``.

    Returns:
        dict mapping each pretrained name that was copied to the model
        parameter name that received it.

    Raises:
        AssertionError: If ``len(blk_prefix) != DiT_blks``.
    """
    assert len(blk_prefix) == DiT_blks, 'len(blk_prefix) must be the same as DiT_blks!'

    report = {}
    model_state_dict = model.state_dict()
    # Split every model name once instead of once per comparison.
    name_parts = {name: name.split('.') for name in model_state_dict}

    # t_embedder
    for model_name, parts in name_parts.items():
        if 't_embedder' not in model_name:
            continue
        prev_name = '.'.join(parts[3:])
        # Direct dict lookup replaces a linear scan over every pretrained key.
        if 't_embedder' in prev_name and prev_name in pretrained_model_state_dict:
            model_state_dict[model_name].copy_(pretrained_model_state_dict[prev_name])
            report[prev_name] = model_name

    # blocks
    for custom_blk_idx in range(DiT_blks):
        # The trailing '.' makes the index match exact: without it
        # 'blocks.1' would also match 'blocks.10', 'blocks.11', ... and load
        # those blocks from the wrong pretrained block.
        blk_marker = f'blocks.{custom_blk_idx}.'
        for model_name, parts in name_parts.items():
            if blk_marker not in f'{model_name}.':
                continue
            if len(parts) < 4:
                # No component 3 to build a pretrained key from; nothing to load.
                continue
            expect_prev_layer_name = '.'.join(
                [parts[3], str(blk_prefix[custom_blk_idx])] + parts[5:]
            )
            if expect_prev_layer_name in pretrained_model_state_dict:
                model_state_dict[model_name].copy_(
                    pretrained_model_state_dict[expect_prev_layer_name]
                )
                report[expect_prev_layer_name] = model_name
    return report



# for model param stats
def compute_layer_parameter_stats(model):
    """Summarise the weight parameters whose name contains ``'net_d'``.

    Only parameters whose name contains both ``'weight'`` and ``'net_d'``
    are considered; each is keyed by its full parameter name.

    Args:
        model: Object exposing ``named_parameters()``.

    Returns:
        dict mapping parameter name to ``{'mean': float, 'std': float}``,
        where each value is the average of the per-parameter statistics
        collected under that name.
    """
    collected = {}
    for param_name, tensor in model.named_parameters():
        if 'net_d' not in param_name or 'weight' not in param_name:
            continue
        bucket = collected.setdefault(param_name, {'mean': [], 'std': []})
        bucket['mean'].append(tensor.mean().item())
        bucket['std'].append(tensor.std().item())

    # Collapse the collected values for each name into a single scalar.
    return {
        key: {
            'mean': torch.tensor(values['mean']).mean().item(),
            'std': torch.tensor(values['std']).mean().item(),
        }
        for key, values in collected.items()
    }

# Only works with huggingface param names
def freeze_layers_clip(model, freeze_layer_num):
    """Turn off gradients for the lower part of ``model.clip``.

    Parameters whose name marks them as a top layer (final/post layer norm,
    text/visual projection, logit scale) are left untouched, as are encoder
    layers whose index is ``>= freeze_layer_num``. Every other parameter is
    printed and gets ``requires_grad = False``.

    Args:
        model: Object with a ``clip`` attribute exposing ``named_parameters()``.
        freeze_layer_num: Integer in ``[-1, 12]``; ``-1`` leaves everything
            untouched.

    Raises:
        AssertionError: If ``model`` has no ``clip`` attribute or
            ``freeze_layer_num`` is outside ``[-1, 12]``.
    """
    assert hasattr(model, 'clip')
    assert -1 <= freeze_layer_num <= 12

    if freeze_layer_num == -1:
        return

    # top layers always need to train
    always_trainable = (
        'final_layer_norm',
        'text_projection',
        'post_layernorm',
        'visual_projection',
        'logit_scale',
    )
    encoder_markers = ('text_model.encoder.layers', 'vision_model.encoder.layers')

    for param_name, tensor in model.clip.named_parameters():
        if any(tag in param_name for tag in always_trainable):
            continue  # need to train

        if any(tag in param_name for tag in encoder_markers):
            # The layer index is the first component after '.layers.'.
            depth = int(param_name.split('.layers.')[1].split('.')[0])
            if depth >= freeze_layer_num:
                continue  # need to train

        print(param_name)
        tensor.requires_grad = False
        

def load_json(filename):
    """Parse the JSON document stored at ``filename`` and return the result.

    The file is opened explicitly as UTF-8 (the encoding JSON interchange
    uses) instead of the platform's locale default, so non-ASCII content
    decodes the same way on every system.

    Args:
        filename: Path of the JSON file to read.

    Returns:
        The object produced by ``json.load`` (``json`` is bound to ``ujson``
        by this module's imports).
    """
    with open(filename, "r", encoding="utf-8") as f:
        return json.load(f)


def read_lines(filepath):
    """Return the lines of the text file at ``filepath`` as a list of strings.

    Newline characters are stripped from both ends of every line; other
    whitespace is kept.

    Args:
        filepath: Path of the text file to read.

    Returns:
        list of str, one entry per line in file order.
    """
    stripped = []
    with open(filepath, "r") as handle:
        for raw_line in handle:
            stripped.append(raw_line.strip("\n"))
    return stripped


def mkdirp(p):
    """Create directory ``p`` (and any missing parents) if nothing exists there.

    The directory is created directly rather than after an existence check,
    so a concurrent creator between "check" and "create" no longer makes
    this fail. As before, a path that already exists is left untouched.

    Args:
        p: Directory path to create.

    Raises:
        OSError: Propagated from ``os.makedirs`` for failures other than
            the path already existing.
    """
    try:
        os.makedirs(p)
    except FileExistsError:
        # Something already occupies the path: that is the no-op case.
        # Re-raise only if the path still does not resolve to anything.
        if not os.path.exists(p):
            raise

def deletedir(p):
    """Recursively remove the directory tree at ``p`` if the path exists.

    Args:
        p: Directory path to delete. A path that does not exist is ignored.
    """
    if not os.path.exists(p):
        return
    shutil.rmtree(p)
