import os
import random
import pyrallis
import json
from tqdm import tqdm
from typing import Dict, List
import psutil

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from loaders.image_loader import load_images
from loaders.param_loader import load_paired_params
from models.base import load_model
from models.diffusion import load_diff
from models.ae import load_ae
from models.position import load_positional_embedding
from models.base.model_helper import ModelHelper
from optimization.optimizer import load_optimizer
from optimization.scheduler import load_lr_scheduler
from loss.task_loss import load_task_loss
from loss.reconstruction_loss import load_reconstruction_loss
from utils.log_utils import log_scalar_dict, create_experiment_dir
from utils.visualize_util import visualize_weights
from utils.visualize_util import visualize_weights_layerwise, visualize_weights_channelwise

from options import DiffusionTrainConfig

@pyrallis.wrap()
def main(cfg: DiffusionTrainConfig):
    """Train (or only evaluate/visualize) a diffusion model over autoencoder weight latents.

    Steps, in order:
      1. Seed all RNGs and pick the device (CUDA unless ``cfg.no_cuda`` or unavailable).
      2. Create a TensorBoard ``SummaryWriter`` unless ``cfg.logging.disable_logging``,
         and create the experiment directory used for checkpoints.
      3. Load the autoencoder from ``cfg.ae.checkpoint_path`` (required) in eval mode.
      4. Build the diffusion model, optionally restoring ``cfg.diffusion.checkpoint_path``.
      5. Pre-compute positional embeddings for every entry of ``cfg.model_task_dict``.
      6. Build the paired-parameter loader, task loss, optimizer and LR scheduler.
      7. If ``cfg.logging.visualize`` is set, run a single ``test_diff`` pass and return;
         otherwise train for ``cfg.epochs`` epochs with periodic logging, evaluation,
         best/periodic checkpointing and scheduler stepping.

    Args:
        cfg: Parsed ``DiffusionTrainConfig`` (supplied by ``pyrallis.wrap``).

    Raises:
        ValueError: if ``cfg.ae.checkpoint_path`` is not set.
    """
    init_seed(2025)
    # ----------------------------------------
    # basic configuration
    # ----------------------------------------
    use_cuda = not cfg.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # ----------------------------------------
    # logging configuration
    # ----------------------------------------
    if not cfg.logging.disable_logging:
        logger = SummaryWriter(log_dir=os.path.join(cfg.logging.log_dir, "runs", cfg.logging.exp_name, cfg.logging.sub_exp_name))
        logger.add_text("Config", json.dumps(pyrallis.encode(cfg), indent=4))
    else:
        logger = None

    exp_dir_path = create_experiment_dir(cfg.logging.log_dir, cfg.logging.exp_name, 'diffusion', cfg.logging.sub_exp_name)

    # ----------------------------------------
    # ae configuration
    # ----------------------------------------
    # Validate before building the AE so a missing checkpoint fails fast
    # (and is not silently skipped when running under `python -O`).
    if cfg.ae.checkpoint_path is None:
        raise ValueError("Please provide a checkpoint path for the autoencoder.")
    ae = load_ae(cfg.ae).to(device)
    ae.load(cfg.ae.checkpoint_path)
    ae.eval()

    # ----------------------------------------
    # diffusion model configuration
    # ----------------------------------------
    model_diff = load_diff(
        unet_name=cfg.diffusion.unet_name,
        dim=cfg.diffusion.input_dim,
        dim_cond=cfg.diffusion.condition_dim,
        self_cond=True,
        layer_channels=cfg.diffusion.layer_channels,
        init_ch=cfg.diffusion.init_ch,
        final_ch=cfg.diffusion.final_ch,
        timesteps=cfg.diffusion.num_timesteps
    ).to(device)
    if cfg.diffusion.checkpoint_path is not None:
        model_diff.load(cfg.diffusion.checkpoint_path)
        print(f"Loaded diffusion model from {cfg.diffusion.checkpoint_path}")
    model_diff.train()
    if cfg.diffusion.init_ch == 3:
        print("Diffusion use positional embeddings")
    else:
        print("Diffusion do not use positional embeddings")

    print("Diffusion Loss Type: ", cfg.loss_type)

    # ----------------------------------------
    # positional embeddings configuration
    # ----------------------------------------
    positional_embedding_dict = _build_positional_embedding_dict(cfg, device)

    if not cfg.ae.use_embeddings:
        print("Do not use positional embeddings")
    else:
        print("Use positional embeddings" if not cfg.ae.use_normalization_embedding else "Use normalization positional embeddings")
        print("Use mu std normalization positional embeddings" if cfg.ae.mu_std_normalization_embedding else "Do not use mu std normalization positional embeddings")

    # ----------------------------------------
    # param loader configuration
    # ----------------------------------------
    param_loader = load_paired_params(cfg.param_dir, cfg.param_names, cfg.src_tgt_dict, None)

    # ----------------------------------------
    # optimization configuration
    # ----------------------------------------
    task_loss = load_task_loss(cfg.task)

    optimizer = load_optimizer(model_diff.parameters(), cfg.optim)
    scheduler = load_lr_scheduler(optimizer, cfg.optim)

    try:
        if cfg.logging.visualize:
            # Evaluation/visualization-only mode: one sampling pass, then stop.
            test_diff(positional_embedding_dict, model_diff, ae, cfg.model_task_dict, task_loss, cfg.logging, device)
            return

        # ----------------------------------------
        # each epoch
        # ----------------------------------------
        # Start below any reachable accuracy so the first evaluation always writes a "best" checkpoint.
        max_acc = float('-inf')
        for epoch in tqdm(range(1, cfg.epochs + 1), desc='Epoch'):
            model_diff.train()

            loss = train_diff(param_loader, positional_embedding_dict, model_diff, ae, optimizer, cfg.loss_type, device)

            if torch.isnan(loss).item():
                # Skips logging, evaluation, checkpointing and the scheduler step for this
                # epoch only; the for-loop still advances, so this cannot loop forever.
                print("Loss is NaN. Skipping this batch.")
                continue

            if epoch % cfg.logging.log_interval == 0 and logger is not None:
                # `reconstruction_loss` mirrors `loss`; kept so existing TensorBoard tags stay unchanged.
                loss_dict = dict(
                    loss=loss,
                    reconstruction_loss=loss
                )
                _log_training(logger, epoch, loss_dict, scheduler.get_last_lr(), cfg.epochs)

            if cfg.eval_epochs_interval is not None and epoch % cfg.eval_epochs_interval == 0:
                test_losses, test_accuracies = test_diff(positional_embedding_dict, model_diff, ae, cfg.model_task_dict, task_loss, cfg.logging, device)
                if logger is not None:
                    log_scalar_dict(test_losses, title='eval_loss', iteration=epoch, logger=logger)
                    log_scalar_dict(test_accuracies, title='eval_accuracy', iteration=epoch, logger=logger)

                # Guard against an empty task dict (would otherwise divide by zero).
                if test_accuracies:
                    avg_acc = sum(test_accuracies.values()) / len(test_accuracies)
                    if avg_acc > max_acc:
                        max_acc = avg_acc
                        model_diff.save(os.path.join(exp_dir_path, f"diffusion_{cfg.logging.sub_exp_name}_best.pth"))

            if cfg.save_epochs_interval is not None and epoch % cfg.save_epochs_interval == 0:
                model_diff.save(os.path.join(exp_dir_path, f"diffusion_{cfg.logging.sub_exp_name}_{epoch}.pth"))

            scheduler.check_and_step(epoch)
    finally:
        # Flush pending TensorBoard events on every exit path.
        if logger is not None:
            logger.close()


def _build_positional_embedding_dict(cfg, device):
    """Compute per-parameter positional embeddings for every entry of ``cfg.model_task_dict``.

    Args:
        cfg: The training config; reads ``cfg.model_task_dict`` (keys of the form
            ``"<model_name>&<num_classes>"``, values with a ``'cluster_cfg_path'``) and the
            ``cfg.ae`` embedding options.
        device: Device passed to the model and positional-embedding constructors.

    Returns:
        Dict mapping each ``model_task_dict`` key to the ``{param_name: embedding}`` dict
        returned by the positional-embedding model (optionally mean/std normalized).
    """
    positional_embedding_dict = {}
    for mc, info in cfg.model_task_dict.items():
        model_name, num_classes = mc.split('&')

        cluster_cfg_path = info['cluster_cfg_path']
        with open(cluster_cfg_path, 'r', encoding='utf-8') as f:
            cluster_names = list(json.load(f)['struct'].keys())

        # Only the learnable-weight shapes are read from this freshly built model.
        original_model = load_model(model_name, int(num_classes)).to(device)
        model_helper = ModelHelper(original_model)
        model_helper.set_cluster(True, cluster_cfg_path)

        positional_embedding_model = load_positional_embedding(cfg.ae.embedding, device)
        positional_embedding_model._fit_param_indices(cluster_names, model_helper.get_learnable_weights_shapes())
        positional_embedding_model._calculate_positional_embeddings(prefix=mc)
        _, positional_embeddings, _, normalized_positional_embeddings = positional_embedding_model.get_indices_and_positional_embeddings()
        # NOTE(review): this selects the *un-normalized* embeddings when
        # use_normalization_embedding is True, which looks inverted relative to the log
        # message in main(). Kept as-is so embeddings match whatever the AE checkpoint was
        # trained with — confirm against the AE training script before changing.
        positional_embeddings = positional_embeddings if cfg.ae.use_normalization_embedding else normalized_positional_embeddings
        if cfg.ae.mu_std_normalization_embedding:
            # Per-parameter standardization (zero mean, unit std).
            for name, positional_embedding in positional_embeddings.items():
                positional_embeddings[name] = (positional_embedding - positional_embedding.mean()) / positional_embedding.std()
        positional_embedding_dict[mc] = positional_embeddings
    return positional_embedding_dict
        

def train_diff(param_loader, positional_embedding_dict, model_diff, ae, optimizer, loss_type, device):
    """Run one training epoch of the diffusion model over ``param_loader``.

    For each sample:
      - the source and target weights are encoded by the (frozen) autoencoder;
      - a condition is built from the source latent, concatenated with the positional
        embedding along dim 1 when ``model_diff.model.init_ch == 3``;
      - one random timestep per batch element is drawn;
      - ``model_diff.p_losses`` is evaluated on the target latent and an optimizer step is taken.

    It also reports peak GPU memory (CUDA only) and the largest per-sample RSS increase
    observed around the encode + forward step. The CPU figure is a delta in process RSS,
    not a true peak.

    Args:
        param_loader: Iterable of ``(src_data, tgt_data, src_type, tgt_type, param_name,
            model_classnum)`` tuples; must also support ``len()``.
        positional_embedding_dict: ``{model_classnum: {param_name: embedding tensor}}``.
        model_diff: Diffusion wrapper exposing ``.model.init_ch``, ``.timesteps`` and ``.p_losses``.
        ae: Autoencoder exposing ``_encode_weights``; not optimized here.
        optimizer: Optimizer over the diffusion model's parameters.
        loss_type: Forwarded to ``model_diff.p_losses``.
        device: Target device (``torch.device`` or string).

    Returns:
        0-dim detached tensor: mean loss over the samples of the epoch.

    Raises:
        ValueError: if ``model_diff.model.init_ch`` is neither 2 nor 3.
    """
    init_ch = model_diff.model.init_ch
    if init_ch not in (2, 3):
        raise ValueError(f"Unsupported diffusion init_ch={init_ch}; expected 2 or 3.")

    # Memory-stat APIs below require CUDA; skip them on CPU-only runs.
    track_gpu = torch.device(device).type == 'cuda'

    total_loss = torch.tensor(0.).to(device)

    process = psutil.Process(os.getpid())
    MB = 1024 * 1024
    # Running maxima over the epoch, in bytes.
    ae_gpu_peak = 0
    ae_cpu_peak = 0

    for sample in param_loader:
        src_param_data, tgt_param_data, src_param_type, tgt_param_type, param_name, model_classnum = sample

        positional_embedding = positional_embedding_dict[model_classnum][param_name]

        src_param_data = src_param_data.to(device)
        tgt_param_data = tgt_param_data.to(device)
        positional_embedding = positional_embedding.to(device)
        if positional_embedding.dim() == 1:
            # Add a leading batch dimension so it matches the encoded latents.
            positional_embedding = positional_embedding.unsqueeze(0)

        # -------------------------
        # Measure GPU / CPU memory for the encode + forward step
        # -------------------------
        if track_gpu:
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
        cpu_before = process.memory_info().rss

        # The AE is frozen and not part of the optimizer, so do not build an autograd
        # graph through it (that would also accumulate never-zeroed grads on its params).
        with torch.no_grad():
            src_encode_param = ae._encode_weights(src_param_data)
            tgt_encode_param = ae._encode_weights(tgt_param_data)

        assert src_encode_param.shape == tgt_encode_param.shape == positional_embedding.shape, (
            f"shape mismatch: src={tuple(src_encode_param.shape)}, "
            f"tgt={tuple(tgt_encode_param.shape)}, pos={tuple(positional_embedding.shape)}"
        )

        # Insert a channel dimension at dim 1.
        src_encode_param_ch = src_encode_param.unsqueeze(1)
        tgt_encode_param_ch = tgt_encode_param.unsqueeze(1)
        if init_ch == 3:
            cond = torch.cat((src_encode_param_ch, positional_embedding.unsqueeze(1)), dim=1)
        else:
            cond = src_encode_param_ch
        batch_size = src_encode_param.shape[0]

        t = torch.randint(0, model_diff.timesteps, (batch_size,), device=device).long()
        loss = model_diff.p_losses(tgt_encode_param_ch, t=t, c=cond, loss_type=loss_type)

        cpu_after = process.memory_info().rss
        batch_gpu_peak = torch.cuda.max_memory_allocated() if track_gpu else 0
        ae_gpu_peak = max(ae_gpu_peak, batch_gpu_peak)
        ae_cpu_peak = max(ae_cpu_peak, cpu_after - cpu_before)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Detach so the running sum does not keep every sample's autograd graph alive.
        total_loss += loss.detach()

    print(f"AE Training GPU Peak Memory (MB): {ae_gpu_peak / MB:.2f}")
    print(f"AE Training CPU Peak Memory (MB): {ae_cpu_peak / MB:.2f}")

    return total_loss / len(param_loader)


@torch.no_grad()
def test_diff(positional_embedding_dict, model_diff, ae, model_task_dict, task_loss, visualize_cfg, device):
    """Evaluate diffusion-generated weights on each task's test split.

    For every ``model_classnum -> task_info`` entry of ``model_task_dict``:
      1. load the source and target models (``src_model_path`` / ``tgt_model_path``) with
         clustering from ``cluster_cfg_path``;
      2. encode each learnable source weight with the AE;
      3. sample a latent from the diffusion model conditioned on it (plus the positional
         embedding when ``model_diff.model.init_ch == 3``) and decode it back to a weight;
      4. optionally save src / tgt / generated weights when ``visualize_cfg.visualize`` is set;
      5. load the generated weights into the target model and evaluate on the test split.

    Args:
        positional_embedding_dict: ``{model_classnum: {param_name: embedding tensor}}``.
        model_diff: Diffusion wrapper exposing ``.model.init_ch`` and ``.sample``; put in eval mode.
        ae: Autoencoder exposing ``_encode_weights`` / ``_decode_weights``.
        model_task_dict: ``{"<model_name>&<num_classes>": task_info}``; each ``task_info``
            provides ``src_model_path``, ``tgt_model_path``, ``cluster_cfg_path``,
            ``task_name`` and ``data_dir``.
        task_loss: Callable ``(output, target) -> loss tensor``.
        visualize_cfg: Object with ``visualize`` (bool) and ``visualize_save_file`` (dir path).
        device: Device for models and data.

    Returns:
        ``(test_losses, test_accuracies)``, both keyed by ``"<model_name>_<task_name>"``;
        accuracy is in percent.

    Raises:
        ValueError: if ``model_diff.model.init_ch`` is neither 2 nor 3.
    """
    model_diff.eval()
    init_ch = model_diff.model.init_ch
    if init_ch not in (2, 3):
        raise ValueError(f"Unsupported diffusion init_ch={init_ch}; expected 2 or 3.")

    test_losses = {}
    test_accuracies = {}

    if visualize_cfg.visualize:
        os.makedirs(visualize_cfg.visualize_save_file, exist_ok=True)

    for model_classnum, task_info in model_task_dict.items():
        positional_embeddings = positional_embedding_dict[model_classnum]
        model_name, num_classes = model_classnum.split('&')
        num_classes = int(num_classes)

        src_model_path = task_info['src_model_path']
        tgt_model_path = task_info['tgt_model_path']
        cluster_cfg_path = task_info['cluster_cfg_path']
        task_name = task_info['task_name']
        data_dir = task_info['data_dir']

        tgt_exp_name = f"{model_name}_{task_name}"

        test_loader = load_images(data_dir, task_name, data_type='test', batch_size=128)

        src_model = load_model(model_name, num_classes=num_classes).to(device)
        src_model.eval()
        tgt_model = load_model(model_name, num_classes=num_classes).to(device)
        tgt_model.eval()

        model_helper = ModelHelper(src_model)
        model_helper.load(src_model_path, device)
        model_helper.set_cluster(True, cluster_cfg_path)

        tgt_model_helper = ModelHelper(tgt_model)
        tgt_model_helper.load(tgt_model_path, device)
        tgt_model_helper.set_cluster(True, cluster_cfg_path)

        tgt_weights = tgt_model_helper.get_learnable_weights()
        learnable_weights = model_helper.get_learnable_weights()

        encode_weights = {name: ae._encode_weights(weight) for name, weight in learnable_weights.items()}

        pure_weights = {}
        pure_encode_weights = {}
        for name, encode_weight in encode_weights.items():
            # Moved to `device` for consistency with train_diff (no-op if already there).
            positional_embedding = positional_embeddings[name].to(device)
            if positional_embedding.dim() == 1:
                positional_embedding = positional_embedding.unsqueeze(0)

            # Insert a channel dimension at dim 1.
            encode_weight_ch = encode_weight.unsqueeze(1)
            if init_ch == 3:
                cond = torch.cat((encode_weight_ch, positional_embedding.unsqueeze(1)), dim=1)
            else:
                cond = encode_weight_ch

            shape = (encode_weight.size(0), 1, encode_weight.size(1))
            embed, _ = model_diff.sample(shape, condition=cond)
            embed = embed.squeeze(1)

            pure_encode_weights[name] = embed
            pure_weights[name] = ae._decode_weights(embed, positional_embedding, learnable_weights[name].shape)

        if visualize_cfg.visualize:
            # NOTE(review): all tasks write into the same src/tgt/pure folders, so a later
            # task overwrites parameters with the same name from an earlier one.
            save_weights(learnable_weights, visualize_cfg.visualize_save_file, 'src')
            save_weights(tgt_weights, visualize_cfg.visualize_save_file, 'tgt')
            save_weights(pure_weights, visualize_cfg.visualize_save_file, 'pure')
            # visualize_weights_layerwise(list(tgt_weights.values()), list(pure_weights.values()), list(learnable_weights.keys()), os.path.join(visualize_cfg.visualize_save_file, f"{tgt_exp_name}_layerwise"))

            # visualize_weights_channelwise(list(tgt_weights.values()), list(pure_weights.values()), list(learnable_weights.keys()), visualize_cfg.visualize_save_file)

            # visualize_weights_layerwise(list(encode_weights.values()), list(pure_encode_weights.values()), list(learnable_weights.keys()), os.path.join(visualize_cfg.visualize_save_file, f"{tgt_exp_name}_encode", f"{tgt_exp_name}_encode_layerwise"))

            # visualize_weights_channelwise(list(encode_weights.values()), list(pure_encode_weights.values()), list(learnable_weights.keys()), os.path.join(visualize_cfg.visualize_save_file, f"{tgt_exp_name}_encode"))

        print(f'\n Starting eval {src_model_path} on test set.')

        tgt_model_helper.update_weights(pure_weights)

        # Gradients are already disabled by the @torch.no_grad() decorator.
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = tgt_model_helper.model(data)
            # Summed over batches, then divided by the sample count below. This is a true
            # per-sample mean only if task_loss uses sum reduction — TODO confirm.
            test_loss += task_loss(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)  # index of the highest score
            correct += pred.eq(target.view_as(pred)).sum().item()

        num_samples = len(test_loader.dataset)
        test_loss /= num_samples
        accuracy = 100. * correct / num_samples

        test_losses[tgt_exp_name] = test_loss
        test_accuracies[tgt_exp_name] = accuracy

        print('Test set on {}: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
            tgt_exp_name,
            test_loss, correct, num_samples,
            accuracy))

    return test_losses, test_accuracies


def _log_training(logger: SummaryWriter, training_step: int, loss_dict: dict, lr: float, epochs: int):
    """Write training losses and the learning rate to the logger and print a one-line summary.

    Args:
        logger: TensorBoard writer passed through to ``log_scalar_dict``.
        training_step: Step/epoch index used as the logging x-axis.
        loss_dict: Name -> 0-dim tensor; must contain ``'loss'``. All values must support ``.item()``.
        lr: Learning rate as given by the caller. NOTE(review): main() passes
            ``scheduler.get_last_lr()``, which may be a list — confirm ``log_scalar_dict`` handles it.
        epochs: Total number of epochs, shown in the printed summary.
    """
    scalar_groups = (
        (loss_dict, 'training_diff_loss'),
        (dict(learning_rate=lr), "learning_rate"),
    )
    for scalars, section in scalar_groups:
        log_scalar_dict(scalars, title=section, iteration=training_step, logger=logger)

    # Every non-'loss' entry, each followed by ", " (the trailing separator is intentional output).
    extras = ''.join(
        f'{key} = {value.item():.8f}, '
        for key, value in loss_dict.items()
        if key != 'loss'
    )
    main_loss = loss_dict['loss'].item()
    print(f"[{training_step}/{epochs}] Loss = {main_loss:.8f} ({extras})")

def _visualize_z_p(z_src: Dict[str, torch.Tensor], z_tgt: Dict[str, torch.Tensor], p: Dict[str, torch.Tensor], save_dir: str):
    """Pass ``z_src``, ``z_tgt`` and ``p`` to ``visualize_weights`` (mode ``"both"``).

    The output for each dict goes to its own sub-directory of ``save_dir`` (``z_src``,
    ``z_tgt``, ``p``), and ``save_dir`` is created if missing.
    """
    os.makedirs(save_dir, exist_ok=True)
    targets = (
        ("z src", "z_src", z_src),
        ("z tgt", "z_tgt", z_tgt),
        ("p", "p", p),
    )
    for label, subdir_name, tensors in targets:
        out_dir = os.path.join(save_dir, subdir_name)
        print(f"Visualizing {label} in {out_dir}")
        visualize_weights(tensors, out_dir, mode="both")

def save_weights(weights, save_dir, type, filename='model_cont10.pth'):
    """Save each entry of ``weights`` to ``<save_dir>/<type>/<name>/<filename>``.

    The ``<save_dir>/<type>`` directory is always created, even when ``weights`` is empty.

    Args:
        weights: Mapping from parameter name to an object ``torch.save`` accepts (a tensor here).
        save_dir: Root output directory.
        type: Sub-directory tag, e.g. ``'src'``, ``'tgt'`` or ``'pure'`` as used by
            ``test_diff``. Name kept for backward compatibility although it shadows the builtin.
        filename: File name written inside each per-parameter directory. Defaults to the
            historical ``'model_cont10.pth'`` so existing callers are unaffected.
    """
    root = os.path.join(save_dir, type)
    os.makedirs(root, exist_ok=True)
    for name, weight in weights.items():
        param_dir = os.path.join(root, name)
        os.makedirs(param_dir, exist_ok=True)
        torch.save(weight, os.path.join(param_dir, filename))

def init_seed(seed=2025):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs and force deterministic cuDNN.

    Args:
        seed: Value applied to every random generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Prefer reproducible cuDNN kernels over autotuned (potentially non-deterministic) ones.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

if __name__ == "__main__":
    # Script entry point. main() is called without arguments; the @pyrallis.wrap()
    # decorator on it supplies the `cfg` argument.
    main()
