import os
import random
import pyrallis
import json
from tqdm import tqdm
import psutil

import numpy as np
import torch
import random
from torch.utils.tensorboard import SummaryWriter

from loaders.image_loader import load_images
from loaders.param_loader import load_unpaired_params
from models.base import load_model
from models.ae import load_ae
from models.position import load_positional_embedding
from models.base.model_helper import ModelHelper
from optimization.optimizer import load_optimizer
from optimization.scheduler import load_lr_scheduler
from loss.task_loss import load_task_loss
from loss.reconstruction_loss import load_reconstruction_loss
from utils.log_utils import log_scalar_dict, create_experiment_dir
from utils.visualize_util import visualize_weights_layerwise, visualize_weights_channelwise

from options import AETrainConfig

def _resolve_random_seed(cfg):
    """Return the random seed for this run.

    A configured ``cfg.logging.random_seed`` of -1 requests a fresh seed; it is
    drawn with Python's ``random`` module and printed so the run can be
    reproduced. Any other value is returned unchanged.
    """
    if cfg.logging.random_seed == -1:
        random_seed = random.randint(0, 2**32 - 1)
        print("Random Seed: ", random_seed)
        return random_seed
    return cfg.logging.random_seed


def _build_positional_embedding_dict(cfg, device):
    """Compute positional embeddings for every model listed in ``cfg.model_task_dict``.

    Keys of ``cfg.model_task_dict`` have the form ``"<model_name>&<num_classes>"``;
    each value must provide ``'cluster_cfg_path'``, a JSON file whose ``'struct'``
    object's keys are the cluster names.

    Returns:
        dict mapping each ``model_task_dict`` key to that model's embeddings
        (itself a dict of tensors keyed by name).
    """
    positional_embedding_dict = {}
    for mc, info in cfg.model_task_dict.items():
        model_name, num_classes = mc.split('&')
        cluster_cfg_path = info['cluster_cfg_path']

        with open(cluster_cfg_path, 'r') as f:
            cluster_names = list(json.load(f)['struct'].keys())

        # Within this function the model is only used to obtain the shapes of
        # its learnable weights under the clustering from cluster_cfg_path.
        original_model = load_model(model_name, int(num_classes)).to(device)
        model_helper = ModelHelper(original_model)
        model_helper.set_cluster(True, cluster_cfg_path)

        positional_embedding_model = load_positional_embedding(cfg.ae.embedding, device)
        positional_embedding_model._fit_param_indices(cluster_names, model_helper.get_learnable_weights_shapes())
        positional_embedding_model._calculate_positional_embeddings(prefix=mc)

        _, positional_embeddings, _, normalized_positional_embeddings = positional_embedding_model.get_indices_and_positional_embeddings()
        # NOTE(review): when use_normalization_embedding is True this keeps the
        # *un-normalized* embeddings, yet main() then prints "Use normalization
        # positional embeddings". Either this selection or the unpacking order
        # above looks inverted -- confirm against get_indices_and_positional_embeddings.
        positional_embeddings = positional_embeddings if cfg.ae.use_normalization_embedding else normalized_positional_embeddings

        if cfg.ae.mu_std_normalization_embedding:
            # Standardize each embedding to zero mean / unit std; the dict is
            # updated in place (keys unchanged, so iterating is safe).
            for name, positional_embedding in positional_embeddings.items():
                positional_embeddings[name] = (positional_embedding - positional_embedding.mean()) / positional_embedding.std()
        positional_embedding_dict[mc] = positional_embeddings
    return positional_embedding_dict


def _run_visualization(cfg, ae, positional_embedding_dict, device):
    """Plot original vs. reconstructed weights and print test accuracy per model.

    For each entry of ``cfg.model_task_dict`` the trained model is loaded from
    ``'test_model_path'``, its learnable weights are reconstructed with
    ``ae.predict_all``, layer-wise and channel-wise plots are written next to
    ``cfg.logging.visualize_save_file``, and the model is evaluated on the test
    split with the reconstructed weights.

    Raises:
        ValueError: if ``cfg.ae.checkpoint_path`` is not set.
    """
    if cfg.ae.checkpoint_path is None:
        raise ValueError("Visualization requires cfg.ae.checkpoint_path to be set.")

    save_file = cfg.logging.visualize_save_file
    save_dir = os.path.dirname(save_file)
    if save_dir:  # dirname is '' for a bare file name, and os.makedirs('') raises
        os.makedirs(save_dir, exist_ok=True)

    print("Start visualization.")
    for model_classnum, task_info in cfg.model_task_dict.items():
        positional_embeddings = positional_embedding_dict[model_classnum]
        model_name, num_classes = model_classnum.split('&')

        test_model_path = task_info['test_model_path']
        cluster_cfg_path = task_info['cluster_cfg_path']
        task_name = task_info['task_name']
        data_dir = task_info['data_dir']

        original_model = load_model(model_name, num_classes=int(num_classes)).to(device)
        original_model.eval()

        model_helper = ModelHelper(original_model)
        model_helper.load(test_model_path, device)
        model_helper.set_cluster(True, cluster_cfg_path)

        learnable_weights = model_helper.get_learnable_weights()
        reconstructed_weights = ae.predict_all(positional_embeddings, learnable_weights)

        test_loader = load_images(data_dir, task_name, data_type='test', batch_size=128)

        # NOTE(review): assumes reconstructed_weights iterates in the same key
        # order as learnable_weights -- confirm predict_all preserves it.
        visualize_weights_layerwise(list(learnable_weights.values()), list(reconstructed_weights.values()), list(learnable_weights.keys()), save_file)
        visualize_weights_channelwise(list(learnable_weights.values()), list(reconstructed_weights.values()), list(learnable_weights.keys()), save_dir)

        print(f'\n Starting eval {test_model_path} on test set.')
        model_helper.update_weights(reconstructed_weights)
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model_helper.model(data)
                pred = output.argmax(dim=1, keepdim=True)  # index of the max logit
                correct += pred.eq(target.view_as(pred)).sum().item()

        accuracy = 100. * correct / len(test_loader.dataset)
        print("Accuracy: ", accuracy)
    print("Finish visualization.")


@pyrallis.wrap()
def main(cfg: AETrainConfig):
    """Train the weight auto-encoder, or only inspect it when ``cfg.logging.visualize`` is set.

    Steps:
    1. Seed all RNGs.
    2. Set up TensorBoard logging (unless disabled) and the experiment directory.
    3. Build the AE, optionally loading ``cfg.ae.checkpoint_path``.
    4. Compute positional embeddings for every model in ``cfg.model_task_dict``.
    5. Either run the visualization pass, or train for ``cfg.epochs`` epochs
       with periodic logging, evaluation and checkpointing.

    The checkpoint with the highest mean test accuracy is saved as
    ``ae_<sub_exp_name>_best.pth``. The TensorBoard writer is always closed
    on exit.
    """
    random_seed = _resolve_random_seed(cfg)
    # Write the resolved seed back so it is recorded in the logged config.
    cfg.logging.random_seed = random_seed
    init_seed(random_seed)

    # ----------------------------------------
    # basic configuration
    # ----------------------------------------
    use_cuda = not cfg.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # ----------------------------------------
    # logging configuration
    # ----------------------------------------
    logger = None
    if not cfg.logging.disable_logging:
        logger = SummaryWriter(log_dir=os.path.join(cfg.logging.log_dir, "runs", cfg.logging.exp_name, cfg.logging.sub_exp_name))
        logger.add_text("Config", json.dumps(pyrallis.encode(cfg), indent=4))

    try:
        exp_dir_path = create_experiment_dir(cfg.logging.log_dir, cfg.logging.exp_name, 'ae', cfg.logging.sub_exp_name)

        # ----------------------------------------
        # ae configuration
        # ----------------------------------------
        ae = load_ae(cfg.ae).to(device)
        if cfg.ae.checkpoint_path is not None:
            ae.load(cfg.ae.checkpoint_path)
            print("Load AE weights from: ", cfg.ae.checkpoint_path)
        ae.train()
        print("AE Framework")
        print(ae)

        # ----------------------------------------
        # positional embeddings configuration
        # ----------------------------------------
        positional_embedding_dict = _build_positional_embedding_dict(cfg, device)

        # NOTE(review): in this function use_embeddings only selects the message
        # below; embeddings are computed regardless -- presumably the AE reads
        # cfg.ae.use_embeddings itself.
        if not cfg.ae.use_embeddings:
            print("Do not use positional embeddings")
        else:
            print("Use positional embeddings" if not cfg.ae.use_normalization_embedding else "Use normalization positional embeddings")
            print("Use mu std normalization positional embeddings" if cfg.ae.mu_std_normalization_embedding else "Do not use mu std normalization positional embeddings")

        # ----------------------------------------
        # param loader configuration
        # ----------------------------------------
        param_loader = load_unpaired_params(cfg.param_dir, None, None, cfg.param_types)

        # ----------------------------------------
        # optimization configuration
        # ----------------------------------------
        task_loss = load_task_loss(cfg.task)
        ae_rec_loss = load_reconstruction_loss(cfg)

        optimizer = load_optimizer(ae.parameters(), cfg.optim)
        scheduler = load_lr_scheduler(optimizer, cfg.optim)

        if cfg.logging.visualize:
            _run_visualization(cfg, ae, positional_embedding_dict, device)
            return  # visualization-only run; the finally block closes the logger

        # ----------------------------------------
        # each epoch
        # ----------------------------------------
        # Starting from -inf guarantees a "best" checkpoint after the first
        # evaluation, even if the mean accuracy is 0.
        best_avg_acc = float('-inf')
        for epoch in tqdm(range(1, cfg.epochs + 1), desc='Epoch'):
            loss = train_ae(param_loader, positional_embedding_dict, ae, ae_rec_loss, optimizer, device)

            if loss.isnan().item():
                # Skip logging, evaluation, checkpointing and the LR-scheduler
                # step for this epoch; training continues with the next epoch.
                print("Loss is NaN. Skipping this batch.")
                continue

            if logger is not None and epoch % cfg.logging.log_interval == 0:
                loss_dict = dict(
                    loss=loss,
                    reconstruction_loss=loss
                )
                _log_training(logger, epoch, loss_dict, scheduler.get_last_lr(), cfg.epochs)

            if cfg.eval_epochs_interval is not None and epoch % cfg.eval_epochs_interval == 0:
                test_losses, test_accuracies = test_ae(positional_embedding_dict, ae, cfg.model_task_dict, task_loss, device)
                if logger is not None:
                    log_scalar_dict(test_losses, title='eval_loss', iteration=epoch, logger=logger)
                    log_scalar_dict(test_accuracies, title='eval_accuracy', iteration=epoch, logger=logger)

                # An empty model_task_dict yields no accuracies to average.
                if test_accuracies:
                    avg_acc = sum(test_accuracies.values()) / len(test_accuracies)
                    if avg_acc > best_avg_acc:
                        best_avg_acc = avg_acc
                        ae.save(os.path.join(exp_dir_path, f"ae_{cfg.logging.sub_exp_name}_best.pth"))

            if cfg.save_epochs_interval is not None and epoch % cfg.save_epochs_interval == 0:
                ae.save(os.path.join(exp_dir_path, f"ae_{cfg.logging.sub_exp_name}_{epoch}.pth"))

            scheduler.check_and_step(epoch)
    finally:
        # Flush pending TensorBoard events even if training raises.
        if logger is not None:
            logger.close()
        

def train_ae(param_loader, positional_embedding_dict, ae, ae_rec_loss, optimizer, device):
    """Run one training epoch of the auto-encoder over ``param_loader``.

    Each sample is passed to ``ae._predict_weights`` together with its
    positional embedding. The reconstruction is scored with ``ae_rec_loss``,
    and one optimizer step is taken per sample.

    Two memory figures are printed at the end:
    - peak GPU memory allocated during a forward pass (CUDA devices only;
      0 otherwise);
    - the largest increase in process RSS observed across a single forward pass.

    Args:
        param_loader: iterable of ``(param_data, param_type, param_name,
            param_var_info, model_classnum)`` samples that supports ``len()``.
        positional_embedding_dict: ``{model_classnum: {param_name: embedding}}``.
        ae: auto-encoder exposing ``_predict_weights``.
        ae_rec_loss: reconstruction loss taking lists of tensors plus
            ``min_var``/``max_var`` keyword arguments.
        optimizer: optimizer over the AE parameters.
        device: device (``torch.device`` or string) for data and embeddings.

    Returns:
        Detached 0-d tensor on ``device``: the mean loss per sample.
    """
    process = psutil.Process(os.getpid())
    # The torch.cuda memory APIs initialize CUDA and fail on CPU-only runs,
    # so GPU measurement is only done for CUDA devices.
    measure_gpu = torch.device(device).type == 'cuda'
    MB = 1024 * 1024

    ae.train()

    total_loss = torch.tensor(0., device=device)
    ae_gpu_peak = 0
    ae_cpu_peak = 0

    # param_var_info carries the 'min_var'/'max_var' values used to weight the
    # loss between and within clusters.
    for param_data, param_type, param_name, param_var_info, model_classnum in param_loader:
        positional_embedding = positional_embedding_dict[model_classnum][param_name]

        param_data = param_data.to(device)
        positional_embedding = positional_embedding.to(device)

        # -------------------------
        # Measure GPU/CPU memory around the forward pass.
        # torch.cuda.empty_cache() is intentionally not called here: it only
        # releases cached blocks, does not affect max_memory_allocated, and
        # forces costly re-allocation on every step.
        # -------------------------
        if measure_gpu:
            torch.cuda.reset_peak_memory_stats(device)
        cpu_before = process.memory_info().rss  # bytes

        reconstructed_param = ae._predict_weights(param_data, positional_embedding)
        assert reconstructed_param.shape == param_data.shape

        cpu_after = process.memory_info().rss
        if measure_gpu:
            ae_gpu_peak = max(ae_gpu_peak, torch.cuda.max_memory_allocated(device))
        ae_cpu_peak = max(ae_cpu_peak, cpu_after - cpu_before)

        # -------------------------
        # Compute the loss and backpropagate.
        # -------------------------
        loss = ae_rec_loss(
            [reconstructed_param],
            [param_data],
            min_var=param_var_info['min_var'],
            max_var=param_var_info['max_var']
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Detach so the running sum does not chain every step's autograd graph.
        total_loss += loss.detach()

    avg_loss = total_loss / len(param_loader)

    print(f"AE Training GPU Peak Memory (MB): {ae_gpu_peak / MB:.2f}")
    print(f"AE Training CPU Peak Memory (MB): {ae_cpu_peak / MB:.2f}")

    return avg_loss

@torch.no_grad()
def test_ae(positional_embedding_dict, ae, model_task_dict, task_loss, device):
    """Score the AE by evaluating every task model with AE-reconstructed weights.

    For each ``"<model_name>&<num_classes>"`` key in ``model_task_dict``, the
    task model is prepared as follows:
    - rebuilt and loaded from ``'test_model_path'``;
    - clustered according to ``'cluster_cfg_path'``;
    - given learnable weights replaced by ``ae.predict_all(...)``.

    The model is then run over the ``'test'`` split of ``'data_dir'`` /
    ``'task_name'``. The whole function runs under ``torch.no_grad()`` (via
    the decorator).

    Returns:
        ``(test_losses, test_accuracies)``: two dicts keyed by
        ``"<model_name>_<task_name>"``. The first holds the summed per-batch
        task loss divided by the number of test samples. The second holds
        top-1 accuracy in percent.
    """
    # NOTE(review): eval mode is commented out in favour of train mode for
    # evaluation; confirm this is intentional.
    ae.train()

    test_losses = {}
    test_accuracies = {}

    for model_classnum, task_info in model_task_dict.items():
        embeddings = positional_embedding_dict[model_classnum]
        model_name, num_classes = model_classnum.split('&')

        checkpoint = task_info['test_model_path']
        cluster_cfg_path = task_info['cluster_cfg_path']
        task_name = task_info['task_name']
        data_dir = task_info['data_dir']

        exp_key = f"{model_name}_{task_name}"

        loader = load_images(data_dir, task_name, data_type='test', batch_size=128)

        model = load_model(model_name, num_classes=int(num_classes)).to(device)
        model.eval()

        helper = ModelHelper(model)
        helper.load(checkpoint, device)
        helper.set_cluster(True, cluster_cfg_path)

        print(f'\n Starting eval {checkpoint} on test set.')
        original_weights = helper.get_learnable_weights()
        rebuilt_weights = ae.predict_all(embeddings, original_weights)
        helper.update_weights(rebuilt_weights)

        # The decorator already disables autograd, so no inner no_grad is needed.
        loss_sum, n_correct = 0, 0
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = helper.model(inputs)
            loss_sum += task_loss(logits, labels).item()
            top1 = logits.argmax(dim=1, keepdim=True)
            n_correct += top1.eq(labels.view_as(top1)).sum().item()

        n_samples = len(loader.dataset)
        test_loss = loss_sum / n_samples
        accuracy = 100. * n_correct / n_samples

        test_losses[exp_key] = test_loss
        test_accuracies[exp_key] = accuracy

        print('Test set on {}: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
            exp_key, test_loss, n_correct, n_samples, accuracy))

    return test_losses, test_accuracies

def _log_training(logger: SummaryWriter, training_step: int, loss_dict: dict, lr: float, epochs: int):
    """Log training losses and the learning rate, and print a one-line summary.

    Args:
        logger: TensorBoard writer handed to ``log_scalar_dict``.
        training_step: step/epoch index used as the x-axis value.
        loss_dict: must contain ``'loss'``. Values may be Python numbers or
            single-element tensors.
        lr: learning rate, forwarded unchanged to ``log_scalar_dict``.
            NOTE(review): the caller passes ``scheduler.get_last_lr()``, which
            for torch schedulers is a list -- confirm log_scalar_dict handles it.
        epochs: total number of epochs, used only in the printed progress.
    """
    log_scalar_dict(loss_dict,
                    title='training_loss',
                    iteration=training_step,
                    logger=logger)
    log_scalar_dict(dict(learning_rate=lr),
                    title="learning_rate",
                    iteration=training_step,
                    logger=logger)

    # float() accepts both plain numbers and single-element tensors. Joining
    # with ', ' avoids a trailing separator before the closing parenthesis.
    extras = ', '.join(f'{k} = {float(v):.8f}' for k, v in loss_dict.items() if k != 'loss')
    print(f"[{training_step}/{epochs}] Loss = {float(loss_dict['loss']):.8f} ({extras})")
        
def init_seed(seed=2025):
    """Seed every RNG used in training and make cuDNN deterministic.

    Seeds Python's ``random`` module, NumPy's global RNG and torch (CPU, the
    current CUDA device and all CUDA devices). It then turns on cuDNN
    deterministic mode and turns off cuDNN benchmark mode.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

if __name__ == "__main__":
    # main() is called without arguments: the pyrallis.wrap() decorator builds
    # the AETrainConfig and passes it in.
    main()
