import os
import random
import pyrallis
import json
from tqdm import tqdm

import numpy as np
import torch
import random
from torch.utils.tensorboard import SummaryWriter

from loaders.image_loader import load_images
from loaders.param_loader import load_unpaired_params
from models.base import load_model
from models.ae import load_ae
from models.position import load_positional_embedding
from models.base.model_helper import ModelHelper
from optimization.optimizer import load_optimizer
from optimization.scheduler import load_lr_scheduler
from loss.task_loss import load_task_loss
from loss.reconstruction_loss import load_reconstruction_loss
from utils.log_utils import log_scalar_dict, create_experiment_dir
from utils.visualize_util import visualize_weights_layerwise, visualize_weights_channelwise

from options import AETrainConfig

def _read_cluster_names(cluster_cfg_path):
    """Return the keys of the ``struct`` mapping stored in a cluster config JSON file."""
    with open(cluster_cfg_path, 'r', encoding='utf-8') as f:
        return list(json.load(f)['struct'].keys())


@pyrallis.wrap()
def main(cfg: AETrainConfig):
    """Rebuild a ViT-B/16 checkpoint's weights with a series of pre-trained autoencoders.

    A base model is loaded from ``test_model_path``. Then, for every
    (autoencoder checkpoint, cluster config) pair listed below, the
    autoencoder predicts replacement weights from positional embeddings and
    the model's current learnable weights, and those predictions are written
    back into the model. After all pairs are processed the model is saved to
    a fixed output path.

    Args:
        cfg: Parsed configuration. Only ``cfg.ae`` is read here: it is passed
            to ``load_ae``, and its ``embedding``,
            ``use_normalization_embedding`` and
            ``mu_std_normalization_embedding`` fields control how the
            positional embeddings are built.

    Raises:
        ValueError: If the checkpoint and cluster-config lists differ in length.
    """
    # ----------------------------------------
    # basic configuration
    # ----------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # One autoencoder checkpoint per weight group; cluster_cfg_paths[i] is the
    # cluster description that belongs to ae_checkpoint_paths[i].
    ae_checkpoint_paths = [
        '/nfs196/wjx/projects/PMP/outputs/LT/ae/Imagenet_vit_b_16_train_ae_b01/checkpoints_20_05_2025_153025/ae_Imagenet_vit_b_16_train_ae_b01_best.pth',
        '/nfs196/wjx/projects/PMP/outputs/LT/ae/Imagenet_vit_b_16_train_ae_b23/checkpoints_20_05_2025_153126/ae_Imagenet_vit_b_16_train_ae_b23_best.pth',
        '/nfs196/wjx/projects/PMP/outputs/LT/ae/Imagenet_vit_b_16_train_ae_b45/checkpoints_20_05_2025_153136/ae_Imagenet_vit_b_16_train_ae_b45_best.pth',
        '/nfs196/wjx/projects/PMP/outputs/LT/ae/Imagenet_vit_b_16_train_ae_b67/checkpoints_20_05_2025_153147/ae_Imagenet_vit_b_16_train_ae_b67_best.pth',
        '/nfs196/wjx/projects/PMP/outputs/LT/ae/Imagenet_vit_b_16_train_ae_b89/checkpoints_20_05_2025_153158/ae_Imagenet_vit_b_16_train_ae_b89_best.pth',
        '/nfs196/wjx/projects/PMP/outputs/LT/ae/Imagenet_vit_b_16_train_ae_b1011/checkpoints_20_05_2025_153207/ae_Imagenet_vit_b_16_train_ae_b1011_best.pth'
    ]

    cluster_cfg_paths = [
        '/nfs196/wjx/projects/PMP/outputs/LT/Imagenet_vit_b_16_cluster_data_b01/tgt/vit_b_16_Imagenet/param_info.json',
        '/nfs196/wjx/projects/PMP/outputs/LT/Imagenet_vit_b_16_cluster_data_b23/tgt/vit_b_16_Imagenet/param_info.json',
        '/nfs196/wjx/projects/PMP/outputs/LT/Imagenet_vit_b_16_cluster_data_b45/tgt/vit_b_16_Imagenet/param_info.json',
        '/nfs196/wjx/projects/PMP/outputs/LT/Imagenet_vit_b_16_cluster_data_b67/tgt/vit_b_16_Imagenet/param_info.json',
        '/nfs196/wjx/projects/PMP/outputs/LT/Imagenet_vit_b_16_cluster_data_b89/tgt/vit_b_16_Imagenet/param_info.json',
        '/nfs196/wjx/projects/PMP/outputs/LT/Imagenet_vit_b_16_cluster_data_b1011/tgt/vit_b_16_Imagenet/param_info.json'
    ]

    # The two lists are consumed in lockstep; fail loudly instead of silently
    # dropping (or index-erroring on) an unmatched entry when they are edited.
    if len(ae_checkpoint_paths) != len(cluster_cfg_paths):
        raise ValueError(
            f'ae_checkpoint_paths has {len(ae_checkpoint_paths)} entries but '
            f'cluster_cfg_paths has {len(cluster_cfg_paths)}; they must be paired one-to-one.'
        )

    test_model_path = '/nfs196/wjx/projects/PMP/outputs/LT/cont/vit_b_16_Imagenet_train_model_cont_tgt/checkpoints_12_02_2025_030042/model_cont10.pt'

    model_name = 'vit_b_16'
    num_classes = 1000
    # "<model_name>&<num_classes>" is the key format test_ae() splits on '&';
    # derive it from the values above so the two cannot drift apart.
    embedding_prefix = f'{model_name}&{num_classes}'

    original_model = load_model(model_name, num_classes=num_classes).to(device)
    original_model.eval()

    model_helper = ModelHelper(original_model)
    model_helper.load(test_model_path, device)

    # Pure inference: disable autograd so no graph is kept alive for the
    # reconstructed weights (test_ae() runs the same calls under no_grad).
    # A `with` block is used instead of stacking @torch.no_grad() under
    # @pyrallis.wrap() so that the wrapper still sees main's real signature.
    with torch.no_grad():
        # list(...) lets tqdm know the total so the progress bar shows n/6.
        for ae_ckp_path, cluster_cfg_path in tqdm(list(zip(ae_checkpoint_paths, cluster_cfg_paths))):
            ae = load_ae(cfg.ae).to(device)
            ae.load(ae_ckp_path)
            ae.eval()

            model_helper.set_cluster(True, cluster_cfg_path)

            cluster_names = _read_cluster_names(cluster_cfg_path)

            positional_embedding_model = load_positional_embedding(cfg.ae.embedding, device)
            positional_embedding_model._fit_param_indices(cluster_names, model_helper.get_learnable_weights_shapes())
            positional_embedding_model._calculate_positional_embeddings(prefix=embedding_prefix)

            _, positional_embeddings, _, normalized_positional_embeddings = positional_embedding_model.get_indices_and_positional_embeddings()
            # NOTE(review): this picks the *un-normalized* embeddings when
            # `use_normalization_embedding` is truthy and the normalized ones
            # otherwise, which reads as inverted relative to the flag name.
            # Left unchanged here — confirm against the AE training script.
            positional_embeddings = positional_embeddings if cfg.ae.use_normalization_embedding else normalized_positional_embeddings

            if cfg.ae.mu_std_normalization_embedding:
                # Standardize each embedding tensor by its own mean and std.
                # Only existing keys are reassigned, so mutating while
                # iterating over items() is safe.
                for name, positional_embedding in positional_embeddings.items():
                    positional_embeddings[name] = (positional_embedding - positional_embedding.mean()) / positional_embedding.std()

            # Each pass starts from the model's current weights, i.e. including
            # whatever earlier iterations already wrote back.
            learnable_weights = model_helper.get_learnable_weights()
            reconstructed_weights = ae.predict_all(positional_embeddings, learnable_weights)

            model_helper.update_weights(reconstructed_weights)

    model_helper.save('/nfs196/wjx/projects/PMP/outputs/LT/model.pt')

    
        
@torch.no_grad()
def test_ae(positional_embedding_dict, ae, model_task_dict, task_loss, device):
    """Evaluate autoencoder-reconstructed models on their test sets.

    For every entry of ``model_task_dict`` a fresh model is loaded, its
    learnable weights are replaced by ``ae.predict_all(...)`` output, and the
    resulting model is scored on the task's test split.

    Args:
        positional_embedding_dict: Mapping from a ``"<model_name>&<num_classes>"``
            key to the positional embeddings handed to ``ae.predict_all``.
        ae: Autoencoder exposing ``train()`` and ``predict_all(embeddings, weights)``.
        model_task_dict: Mapping from the same ``"<model_name>&<num_classes>"``
            keys to a dict with the entries ``'test_model_path'``,
            ``'cluster_cfg_path'``, ``'task_name'`` and ``'data_dir'``.
        task_loss: Callable ``task_loss(output, target)`` whose result supports ``.item()``.
        device: Device that models and batches are moved to.

    Returns:
        ``(test_losses, test_accuracies)`` — two dicts keyed by
        ``"<model_name>_<task_name>"``. The loss is the sum of the per-batch
        loss values divided by the number of test samples; the accuracy is a
        percentage.
    """
    # NOTE(review): the autoencoder is deliberately put in *train* mode here
    # (the eval() call was commented out upstream), whereas main() uses
    # ae.eval(). Confirm this is intended before relying on these numbers.
    # ae.eval()
    ae.train()

    losses_by_exp, accuracies_by_exp = {}, {}

    for model_classnum, task_info in model_task_dict.items():
        embeddings = positional_embedding_dict[model_classnum]
        model_name, num_classes = model_classnum.split('&')

        checkpoint_path = task_info['test_model_path']
        cluster_cfg_path = task_info['cluster_cfg_path']
        task_name = task_info['task_name']
        data_dir = task_info['data_dir']

        tgt_exp_name = f"{model_name}_{task_name}"

        test_loader = load_images(data_dir, task_name, data_type='test', batch_size=128)

        # Build the target architecture and wrap it so that its weights can be
        # read and overwritten cluster by cluster.
        net = load_model(model_name, num_classes=int(num_classes)).to(device)
        net.eval()
        model_helper = ModelHelper(net)
        model_helper.load(checkpoint_path, device)
        model_helper.set_cluster(True, cluster_cfg_path)

        print(f'\n Starting eval {test_model_path} on test set.'.replace('{test_model_path}', '') if False else f'\n Starting eval {checkpoint_path} on test set.')

        # Swap the loaded weights for the autoencoder's reconstruction.
        current_weights = model_helper.get_learnable_weights()
        model_helper.update_weights(ae.predict_all(embeddings, current_weights))

        loss_sum = 0
        num_correct = 0
        # Already inside the decorator's no_grad scope; the explicit context is
        # kept so the forward passes stay grad-free regardless of callee state.
        with torch.no_grad():
            for batch, labels in test_loader:
                batch = batch.to(device)
                labels = labels.to(device)
                logits = model_helper.model(batch)
                # NOTE(review): per-batch values are summed and later divided
                # by the sample count; if task_loss returns a batch mean this
                # is not a true per-sample average — confirm its reduction.
                loss_sum += task_loss(logits, labels).item()
                # Index of the highest score per sample, shape-matched to labels.
                predicted = logits.argmax(dim=1, keepdim=True)
                num_correct += predicted.eq(labels.view_as(predicted)).sum().item()

        num_samples = len(test_loader.dataset)
        loss_sum /= num_samples
        accuracy = 100. * num_correct / num_samples

        losses_by_exp[tgt_exp_name] = loss_sum
        accuracies_by_exp[tgt_exp_name] = accuracy

        print('Test set on {}: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
            tgt_exp_name, loss_sum, num_correct, num_samples, accuracy))

    return losses_by_exp, accuracies_by_exp

def init_seed(seed=2025):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU, current CUDA device
    and all CUDA devices) with ``seed``, then switches cuDNN to deterministic
    mode with auto-tuning (benchmark) disabled.

    Args:
        seed: Integer seed shared by all generators. Defaults to 2025.
    """
    # Same seed for every generator, applied in a fixed order.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Benchmark mode lets cuDNN pick algorithms per run, so it is turned off
    # alongside requesting deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Script entry point: main() is wrapped by pyrallis, so it is called without
# arguments and receives its AETrainConfig from the wrapper.
if __name__ == '__main__':
    main()
