import os
import random
import pyrallis
import json
from tqdm import tqdm
from typing import Dict, List

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from loaders.image_loader import load_images
from loaders.param_loader import load_paired_params
from models.base import load_model
from models.diffusion import load_diff
from models.ae import load_ae
from models.position import load_positional_embedding
from models.base.model_helper import ModelHelper
from optimization.optimizer import load_optimizer
from optimization.scheduler import load_lr_scheduler
from loss.task_loss import load_task_loss
from loss.reconstruction_loss import load_reconstruction_loss
from utils.log_utils import log_scalar_dict, create_experiment_dir
from utils.visualize_util import visualize_weights
from utils.visualize_util import visualize_weights_layerwise, visualize_weights_channelwise
import time

from options import DiffusionTrainConfig

@pyrallis.wrap()
def main(cfg: DiffusionTrainConfig):
    """Load the autoencoder and diffusion model, build positional embeddings, run ``test_diff``.

    Args:
        cfg: Run configuration. This function reads ``cfg.no_cuda``,
            ``cfg.ae``, ``cfg.diffusion``, ``cfg.loss_type``,
            ``cfg.model_task_dict``, ``cfg.task`` and ``cfg.logging``.

    Raises:
        ValueError: If ``cfg.ae.checkpoint_path`` is ``None``.
    """
    init_seed(2025)
    # ----------------------------------------
    # basic configuration
    # ----------------------------------------
    use_cuda = not cfg.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # ----------------------------------------
    # ae configuration
    # ----------------------------------------
    # Validate with an explicit raise rather than ``assert``: asserts are
    # removed under ``python -O``, and checking before the autoencoder is
    # built makes a misconfigured run fail immediately.
    if cfg.ae.checkpoint_path is None:
        raise ValueError("Please provide a checkpoint path for the autoencoder.")
    ae = load_ae(cfg.ae).to(device)
    ae.load(cfg.ae.checkpoint_path)
    ae.eval()

    # ----------------------------------------
    # diffusion model configuration
    # ----------------------------------------
    model_diff = load_diff(
        unet_name=cfg.diffusion.unet_name,
        dim=cfg.diffusion.input_dim,
        dim_cond=cfg.diffusion.condition_dim,
        self_cond=True,
        layer_channels=cfg.diffusion.layer_channels,
        init_ch=cfg.diffusion.init_ch,
        final_ch=cfg.diffusion.final_ch,
        timesteps=cfg.diffusion.num_timesteps
    ).to(device)
    if cfg.diffusion.checkpoint_path is not None:
        model_diff.load(cfg.diffusion.checkpoint_path)
        print(f"Loaded diffusion model from {cfg.diffusion.checkpoint_path}")
    if cfg.diffusion.init_ch == 3:
        print("Diffusion use positional embeddings")
    else:
        print("Diffusion do not use positional embeddings")

    print("Diffusion Loss Type: ", cfg.loss_type)

    # ----------------------------------------
    # positional embeddings configuration
    # ----------------------------------------
    # One embedding dict per "<model_name>&<num_classes>" key.
    positional_embedding_dict = {}
    for mc, info in cfg.model_task_dict.items():
        model_name, num_classes = mc.split('&')

        cluster_cfg_path = info['cluster_cfg_path']
        with open(cluster_cfg_path, 'r', encoding='utf-8') as f:
            cluster_cfg = json.load(f)
        cluster_names = list(cluster_cfg['struct'])

        # The freshly built model is only used for its weight shapes; no
        # checkpoint is loaded into it here.
        original_model = load_model(model_name, int(num_classes)).to(device)
        model_helper = ModelHelper(original_model)
        model_helper.set_cluster(True, cluster_cfg_path)

        positional_embedding_model = load_positional_embedding(cfg.ae.embedding, device)
        positional_embedding_model._fit_param_indices(cluster_names, model_helper.get_learnable_weights_shapes())
        positional_embedding_model._calculate_positional_embeddings(prefix=mc)
        _, positional_embeddings, _, normalized_positional_embeddings = positional_embedding_model.get_indices_and_positional_embeddings()

        # NOTE(review): when ``use_normalization_embedding`` is true this keeps
        # the variable named ``positional_embeddings`` and otherwise the one
        # named ``normalized_positional_embeddings``, which reads as inverted
        # relative to the flag name and the message printed below. Left
        # unchanged because the autoencoder checkpoint may have been trained
        # with the same selection -- confirm against the training script.
        positional_embeddings = positional_embeddings if cfg.ae.use_normalization_embedding else normalized_positional_embeddings
        if cfg.ae.mu_std_normalization_embedding:
            # Standardise each embedding with its own mean and std. Only the
            # values of existing keys are replaced, so updating the dict while
            # iterating over it is safe.
            for name, positional_embedding in positional_embeddings.items():
                positional_embeddings[name] = (positional_embedding - positional_embedding.mean()) / positional_embedding.std()
        positional_embedding_dict[mc] = positional_embeddings

    if not cfg.ae.use_embeddings:
        print("Do not use positional embeddings")
    else:
        print("Use positional embeddings" if not cfg.ae.use_normalization_embedding else "Use normalization positional embeddings")
        print("Use mu std normalization positional embeddings" if cfg.ae.mu_std_normalization_embedding else "Do not use mu std normalization positional embeddings")

    task_loss = load_task_loss(cfg.task)
    test_diff(positional_embedding_dict, model_diff, ae, cfg.model_task_dict, task_loss, cfg.logging, device)
        
@torch.no_grad()
def test_diff(positional_embedding_dict, model_diff, ae, model_task_dict, task_loss, visualize_cfg, device):
    """Profile the memory used by the autoencoder and by diffusion sampling.

    For every entry of ``model_task_dict`` the source and target models are
    loaded, one learnable weight entry (the second-to-last) is pushed through
    the autoencoder encode/decode round trip and then through
    ``model_diff.sample``, and the GPU peak allocation plus the change in
    process RSS are printed for each of the two stages.

    Args:
        positional_embedding_dict: Maps each ``"<model_name>&<num_classes>"``
            key of ``model_task_dict`` to a mapping from weight name to its
            positional embedding.
        model_diff: Diffusion model; must expose ``eval()``, ``sample()`` and
            ``model.init_ch``.
        ae: Autoencoder exposing ``_encode_weights`` and ``_decode_weights``.
        model_task_dict: Maps ``"<model_name>&<num_classes>"`` to a dict with
            the keys ``src_model_path``, ``tgt_model_path``,
            ``cluster_cfg_path``, ``task_name`` and ``data_dir``.
        task_loss: Task loss; only referenced by the disabled evaluation code.
        visualize_cfg: Object with ``visualize`` and ``visualize_save_file``.
        device: Device the models are moved to (``torch.device`` or string).

    Returns:
        None. The measurements are only printed.

    Raises:
        ValueError: If ``model_diff.model.init_ch`` is neither 2 nor 3.
    """
    # Only ``psutil`` is imported locally. Importing ``os`` inside this
    # function would make ``os`` a local name for the whole body, so the
    # ``os.makedirs`` call below would raise UnboundLocalError; the
    # module-level ``os`` import is used instead.
    import psutil

    model_diff.eval()

    test_losses = {}
    test_accuracies = {}

    if visualize_cfg.visualize:
        os.makedirs(visualize_cfg.visualize_save_file, exist_ok=True)

    process = psutil.Process(os.getpid())
    MB = 1024 * 1024

    # CUDA statistics are only meaningful (and only safe to query) when the
    # run is actually on a CUDA device; the caller may fall back to the CPU.
    torch_device = torch.device(device)
    on_cuda = torch_device.type == "cuda"

    def get_cpu_mem():
        return process.memory_info().rss  # bytes

    def reset_gpu_peak():
        if on_cuda:
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats(torch_device)

    def get_gpu_peak():
        # Reported as 0 bytes when no CUDA device is in use.
        return torch.cuda.max_memory_allocated(torch_device) if on_cuda else 0

    for model_classnum, task_info in model_task_dict.items():
        encode_weights = {}
        positional_embeddings = positional_embedding_dict[model_classnum]
        model_name, num_classes = model_classnum.split('&')

        src_model_path = task_info['src_model_path']
        tgt_model_path = task_info['tgt_model_path']
        cluster_cfg_path = task_info['cluster_cfg_path']
        task_name = task_info['task_name']
        data_dir = task_info['data_dir']

        tgt_exp_name = f"{model_name}_{task_name}"

        # The loader, the target model and its weights are not used by the
        # profiling below. They are kept because the disabled evaluation code
        # needs them and because they are part of the measured memory footprint.
        test_loader = load_images(data_dir, task_name, data_type='test', batch_size=128)

        src_model = load_model(model_name, num_classes=int(num_classes)).to(device)
        src_model.eval()
        tgt_model = load_model(model_name, num_classes=int(num_classes)).to(device)
        tgt_model.eval()

        model_helper = ModelHelper(src_model)
        model_helper.load(src_model_path, device)
        model_helper.set_cluster(True, cluster_cfg_path)

        tgt_model_helper = ModelHelper(tgt_model)
        tgt_model_helper.load(tgt_model_path, device)
        tgt_model_helper.set_cluster(True, cluster_cfg_path)

        tgt_weights = tgt_model_helper.get_learnable_weights()

        learnable_weights = model_helper.get_learnable_weights()

        # for name, weight in learnable_weights.items():
        #     encode_weights[name] = ae._encode_weights(weight)

        # pure_weights = {}
        # pure_encode_weights = {}

        # total_encode_time = 0.0
        # total_diffusion_time = 0.0
        # total_decode_time = 0.0

        # for name, encode_weight in encode_weights.items():

        #     # --------------- AE Encode ---------------
        #     start = time.time()
        #     # If you have your own encode function, replace this line with it
        #     encoded = encode_weight    # <-- placeholder; replace if you use ae.encode(name)
        #     total_encode_time += time.time() - start
        #     # -----------------------------------------

        #     positional_embedding = positional_embeddings[name]
        #     if len(positional_embedding.shape) == 1:
        #         positional_embedding = positional_embedding.unsqueeze(0)

        #     encode_weight_ch = encode_weight.unsqueeze(1)
        #     positional_embedding_ch = positional_embedding.unsqueeze(1)

        #     if model_diff.model.init_ch == 3:
        #         input = torch.cat((encode_weight_ch, positional_embedding_ch), dim=1)
        #     elif model_diff.model.init_ch == 2:
        #         input = encode_weight_ch

        #     cond = input
        #     shape = (encode_weight.size(0), 1, encode_weight.size(1))

        #     # --------------- Diffusion ---------------
        #     start = time.time()
        #     embed, embeds = model_diff.sample(shape, condition=cond)
        #     total_diffusion_time += time.time() - start
        #     # -----------------------------------------

        #     embed = embed.squeeze(1)

        #     # --------------- AE Decode ---------------
        #     start = time.time()
        #     decoded = ae._decode_weights(embed, positional_embedding, learnable_weights[name].shape)
        #     total_decode_time += time.time() - start
        #     # -----------------------------------------

        #     pure_encode_weights[name] = embed
        #     pure_weights[name] = decoded


        # print("==== Time Summary ====")
        # print(f"AE Encode Total Time     : {total_encode_time/ 3600:.4f}")
        # print(f"Diffusion Total Time     : {total_diffusion_time/ 3600:.4f}")
        # print(f"AE Decode Total Time     : {total_decode_time/ 3600:.4f}")

        # Pick one layer's parameters to profile (the second-to-last entry).
        # This raises IndexError when the model exposes fewer than two
        # learnable weight entries.
        name, weight = list(learnable_weights.items())[-2]
        pos_embed = positional_embeddings[name]
        if len(pos_embed.shape) == 1:
            pos_embed = pos_embed.unsqueeze(0)

        # -------------------------
        # 1) AE Encode + Decode
        # -------------------------
        reset_gpu_peak()
        cpu_before = get_cpu_mem()

        encode_weight = ae._encode_weights(weight)
        decoded_weight = ae._decode_weights(encode_weight, pos_embed, learnable_weights[name].shape)

        cpu_after = get_cpu_mem()
        ae_gpu_peak = get_gpu_peak()
        # This is the RSS difference between the two samples, not a true peak.
        ae_cpu_peak = cpu_after - cpu_before

        print(f"AE (Encode+Decode) GPU Peak Memory (MB): {ae_gpu_peak / MB:.2f}")
        print(f"AE (Encode+Decode) CPU Peak Memory (MB): {ae_cpu_peak / MB:.2f}")

        # -------------------------
        # 2) Diffusion inference
        # -------------------------
        reset_gpu_peak()
        cpu_before = get_cpu_mem()

        encode_weight_ch = encode_weight.unsqueeze(1)
        pos_embed_ch = pos_embed.unsqueeze(1)
        # Diffusion input: with init_ch == 3 the positional embedding is
        # concatenated as an extra channel, with init_ch == 2 it is omitted.
        init_ch = model_diff.model.init_ch
        if init_ch == 3:
            input_diff = torch.cat((encode_weight_ch, pos_embed_ch), dim=1)
        elif init_ch == 2:
            input_diff = encode_weight_ch
        else:
            # Previously this fell through and failed later with a NameError.
            raise ValueError(f"Unsupported diffusion init_ch: {init_ch!r} (expected 2 or 3)")

        cond = input_diff
        shape = (encode_weight.size(0), 1, encode_weight.size(1))

        embed, embeds = model_diff.sample(shape, condition=cond)
        embed = embed.squeeze(1)

        cpu_after = get_cpu_mem()
        diff_gpu_peak = get_gpu_peak()
        diff_cpu_peak = cpu_after - cpu_before

        print(f"Diffusion GPU Peak Memory (MB): {diff_gpu_peak / MB:.2f}")
        print(f"Diffusion CPU Peak Memory (MB): {diff_cpu_peak / MB:.2f}")

        # for name, encode_weight in encode_weights.items():
        #     positional_embedding = positional_embeddings[name]
        #     if len(positional_embedding.shape) == 1:
        #         positional_embedding = positional_embedding.unsqueeze(0)
                
        #     encode_weight_ch = encode_weight.unsqueeze(1)
        #     positional_embedding_ch = positional_embedding.unsqueeze(1)
            
        #     if model_diff.model.init_ch == 3:
        #         input = torch.cat((encode_weight_ch, positional_embedding_ch), dim=1)
        #     elif model_diff.model.init_ch == 2:
        #         input = encode_weight_ch
                
        #     cond = input
        #     shape = (encode_weight.size(0), 1, encode_weight.size(1))
        #     embed, embeds = model_diff.sample(shape, condition=cond)
        #     embed = embed.squeeze(1)

        #     pure_encode_weights[name] = embed
        #     pure_weights[name] = ae._decode_weights(embed, positional_embedding, learnable_weights[name].shape)
        
        # if visualize_cfg.visualize:
        #     save_weights(learnable_weights, visualize_cfg.visualize_save_file, 'src')
        #     save_weights(tgt_weights, visualize_cfg.visualize_save_file, 'tgt')
        #     save_weights(pure_weights, visualize_cfg.visualize_save_file, 'pure')
            # visualize_weights_layerwise(list(tgt_weights.values()), list(pure_weights.values()), list(learnable_weights.keys()), os.path.join(visualize_cfg.visualize_save_file, f"{tgt_exp_name}_layerwise"))
            
            # visualize_weights_channelwise(list(tgt_weights.values()), list(pure_weights.values()), list(learnable_weights.keys()), visualize_cfg.visualize_save_file)
            
            # visualize_weights_layerwise(list(encode_weights.values()), list(pure_encode_weights.values()), list(learnable_weights.keys()), os.path.join(visualize_cfg.visualize_save_file, f"{tgt_exp_name}_encode", f"{tgt_exp_name}_encode_layerwise"))
            
            # visualize_weights_channelwise(list(encode_weights.values()), list(pure_encode_weights.values()), list(learnable_weights.keys()), os.path.join(visualize_cfg.visualize_save_file, f"{tgt_exp_name}_encode"))
        
        # print(f'\n Starting eval {src_model_path} on test set.')
        
        # tgt_model_helper.update_weights(pure_weights)
        
    #     test_loss = 0
    #     correct = 0
    #     with torch.no_grad():
    #         for data, target in test_loader:
    #             data, target = data.to(device), target.to(device)
    #             output = tgt_model_helper.model(data)
    #             test_loss += task_loss(output, target).item()  # sum up batch loss
    #             pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
    #             correct += pred.eq(target.view_as(pred)).sum().item()

    #     test_loss /= len(test_loader.dataset)
    #     accuracy = 100. * correct / len(test_loader.dataset)
        
    #     test_losses[tgt_exp_name] = test_loss
    #     test_accuracies[tgt_exp_name] = accuracy
        
    #     print('Test set on {}: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
    #         tgt_exp_name,
    #         test_loss, correct, len(test_loader.dataset),
    #         accuracy))

    # return test_losses, test_accuracies

def save_weights(weights, save_dir, type):
    """Write every tensor of ``weights`` to its own sub-directory.

    Each entry is stored at ``<save_dir>/<type>/<name>/model_cont10.pth``.

    Args:
        weights: Mapping from weight name to the object passed to ``torch.save``.
        save_dir: Base output directory.
        type: Sub-directory name under ``save_dir``. The name shadows the
            builtin but is kept because callers may pass it by keyword.
    """
    group_dir = os.path.join(save_dir, type)
    # Created up front so the directory exists even when ``weights`` is empty.
    os.makedirs(group_dir, exist_ok=True)

    for layer_name, tensor in weights.items():
        layer_dir = os.path.join(group_dir, layer_name)
        os.makedirs(layer_dir, exist_ok=True)
        checkpoint_file = os.path.join(layer_dir, 'model_cont10.pth')
        torch.save(tensor, checkpoint_file)

def init_seed(seed = 2025):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Also switches cuDNN to deterministic mode and turns off its benchmark
    mode.

    Args:
        seed: Seed value handed to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Script entry point. ``main`` is called without arguments because it is
# decorated with ``pyrallis.wrap()``, which supplies the ``cfg`` parameter.
if __name__ == '__main__':
    main()
