"""
Neural Similarity Analysis

Evaluate model representations against neural recordings using multiple
similarity metrics (RSA, SVCCA, TSVD-Regression).

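Usage:
    python main_similarity.py --spiking --arch resnet18 --checkpoint-path <path> --mode single
    python main_similarity.py --spiking --arch resnet18 --train-dataset cifar10 --mode multi

In multi mode, every checkpoint in .cache/ckpts whose file name contains
"<SNN|ANN>_<dataset>_<arch>" and encodes its epoch as "epoch<N>" is
evaluated in increasing epoch order.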
"""

import argparse
import logging
import os
import shutil
import sys
from collections import defaultdict

import numpy as np
import torch

from networks.resnet_ann import SupConResNet, SupCEResNet
from networks.resnet_snn import SupConResNetSNN, SupCEResNetSNN
from benchmark import StaticBenchmark


def get_args(argv=None):
    """Parse the command-line options for the similarity analysis.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. This is useful for tests and programmatic use.

    Returns:
        argparse.Namespace holding the model, dataset and evaluation settings.
    """
    parser = argparse.ArgumentParser(description="Neural Representation Similarity Analysis")

    # Model settings
    parser.add_argument('--spiking', action='store_true', help='Use SNN model')
    parser.add_argument('--timesteps', type=int, default=4,
                        help='Number of timesteps (only used with --spiking)')
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        help='Backbone architecture name')

    # Dataset settings
    # The choices must stay in sync with the datasets main() knows how to size.
    parser.add_argument("--train-dataset", default="cifar10", type=str,
                        choices=["cifar10", "tinyimagenet"],
                        help='Dataset the model was trained on')
    parser.add_argument("--checkpoint-path", default=None, type=str,
                        help='Checkpoint file to evaluate (required for --mode single)')
    parser.add_argument("--batch-size", default=128, type=int,
                        help='Batch size passed to the benchmark')
    parser.add_argument("--gpu-id", default='0', type=str,
                        help='Value assigned to CUDA_VISIBLE_DEVICES')

    # Evaluation mode
    parser.add_argument("--mode", default='single', choices=["single", "multi"], type=str,
                        help='single: one checkpoint; multi: all matching checkpoints')
    parser.add_argument("--UnSup", action='store_true', help='Unsupervised learning mode')
    parser.add_argument("--note", default="", type=str,
                        help='Suffix for the single-mode output directory name')

    return parser.parse_args(argv)


# Input resolution and number of classes for each supported training dataset.
_DATASET_CONFIGS = {
    'cifar10': {'img_size': 32, 'num_classes': 10},
    'tinyimagenet': {'img_size': 64, 'num_classes': 200},
}

# Directory scanned in ``--mode multi`` unless ``--checkpoint-path`` names an
# existing directory.
_DEFAULT_MULTI_CKPT_DIR = '.cache/ckpts'


def _build_model(opt, num_classes):
    """Instantiate the ANN/SNN ResNet selected by ``opt`` and move it to the GPU."""
    if opt.UnSup:
        # Contrastive model: only the encoder is wrapped in DataParallel.
        if opt.spiking:
            model = SupConResNetSNN(name=opt.arch, timestep=opt.timesteps)
        else:
            model = SupConResNet(name=opt.arch)
        model.encoder = torch.nn.DataParallel(model.encoder)
    else:
        # Supervised model: the whole network is wrapped in DataParallel.
        if opt.spiking:
            model = SupCEResNetSNN(name=opt.arch, timestep=opt.timesteps, num_classes=num_classes)
        else:
            model = SupCEResNet(name=opt.arch, num_classes=num_classes)
        model = torch.nn.DataParallel(model)
    model.cuda()
    return model


def _extract_state_dict(checkpoint):
    """Return the weights stored in a loaded checkpoint object.

    Weights nested under a ``'model'`` or ``'state_dict'`` key are unwrapped.
    Anything else is treated as the state dict itself, so a checkpoint that
    is not one makes ``load_state_dict`` fail loudly instead of being ignored.
    """
    if isinstance(checkpoint, dict):
        for key in ('model', 'state_dict'):
            if key in checkpoint:
                return checkpoint[key]
    return checkpoint


def _load_checkpoint(model, path):
    """Load the weights from the checkpoint file at ``path`` into ``model``."""
    # NOTE: weights_only=False unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location="cpu", weights_only=False)
    model.load_state_dict(_extract_state_dict(checkpoint))


def _make_benchmark(opt):
    """Create a StaticBenchmark, passing the timestep count and ``mean=True`` for SNNs."""
    if opt.spiking:
        return StaticBenchmark(timestep=opt.timesteps, mean=True)
    return StaticBenchmark()


def _run_benchmark(benchmark, model, img_size, batch_size, sim_path):
    """Run inference, evaluate the similarity metrics, and append them to ``sim_path``.

    Returns:
        The results returned by ``benchmark.evaluate``.

    The activation cache directory returned by ``benchmark(...)`` is removed
    even if evaluation raises.
    """
    cache_dir = benchmark(model, img_size, batch_size)
    try:
        os.makedirs(os.path.dirname(sim_path), exist_ok=True)
        # Append mode: re-running the same checkpoint adds to the file
        # rather than overwriting it.
        with open(sim_path, 'ab') as f:
            return benchmark.evaluate(save_f=f)
    finally:
        logging.info("Deleting cache files...")
        shutil.rmtree(cache_dir)


def _print_summary(results):
    """Print the max/mean score per (dataset, metric, brain area) over all layers."""
    metric_scores = defaultdict(list)
    for _layer, dataset, metric, brain_area, score in results:
        metric_scores[f"{dataset}_{metric}_{brain_area}"].append(score)

    print("\n" + "=" * 50)
    print("SIMILARITY ANALYSIS RESULTS")
    print("=" * 50)
    for key, scores in metric_scores.items():
        print(f"{key}: max={np.max(scores):.4f}, mean={np.mean(scores):.4f}")
    print("=" * 50)


def _run_single(opt, model, img_size):
    """Evaluate the checkpoint at ``opt.checkpoint_path`` and print a summary."""
    ckpt = opt.checkpoint_path
    logging_dir = f"./save/Analysis/{opt.arch}_{opt.note}"
    os.makedirs(logging_dir, exist_ok=True)
    logging.basicConfig(
        filename=os.path.join(logging_dir, 'bio_score.log'),
        level=logging.DEBUG,
    )

    _load_checkpoint(model, ckpt)
    logging.info('Successfully loaded pretrained model')
    print('Successfully loaded pretrained model')

    # splitext keeps dots inside the stem (e.g. "lr0.1_epoch10.pth"), so
    # distinct checkpoints do not collide in the same append-mode file.
    ckpt_name = os.path.splitext(os.path.basename(ckpt))[0]
    sim_path = os.path.join(logging_dir, 'BioScore', f'{ckpt_name}.sim')
    results = _run_benchmark(_make_benchmark(opt), model, img_size, opt.batch_size, sim_path)

    logging.debug(f"Results: {results}")
    _print_summary(results)


def _collect_checkpoints(ckpt_dir, filter_setting):
    """Return ``(epoch, filename)`` pairs for matching files, sorted by epoch.

    A file matches when its name contains ``filter_setting``. The epoch is the
    text between the last ``'epoch'`` and the following ``'.'``. Matching files
    without a numeric epoch are skipped with a warning instead of aborting
    the run.
    """
    ckpts = []
    for fname in os.listdir(ckpt_dir):
        if filter_setting not in fname:
            continue
        epoch = fname.split('epoch')[-1].split('.')[0]
        try:
            int(epoch)
        except ValueError:
            logging.warning("Skipping %s: cannot parse an epoch number", fname)
            continue
        ckpts.append((epoch, fname))
    ckpts.sort(key=lambda item: int(item[0]))
    return ckpts


def _run_multi(opt, model, img_size):
    """Evaluate every checkpoint for this dataset/architecture, in epoch order."""
    # A user-supplied directory is honoured; anything else falls back to
    # the default location.
    if opt.checkpoint_path and os.path.isdir(opt.checkpoint_path):
        ckpt_dir = opt.checkpoint_path
    else:
        ckpt_dir = _DEFAULT_MULTI_CKPT_DIR

    prefix = 'SNN' if opt.spiking else 'ANN'
    filter_setting = f'{prefix}_{opt.train_dataset}_{opt.arch}'
    ckpts = _collect_checkpoints(ckpt_dir, filter_setting)
    if not ckpts:
        print(f"No checkpoints matching '{filter_setting}' found in {ckpt_dir}")
        return

    for epoch, ckpt_file in ckpts:
        _load_checkpoint(model, os.path.join(ckpt_dir, ckpt_file))
        print(f'Loaded checkpoint: {ckpt_file}')

        sim_path = os.path.join(
            './save/Analysis', ckpt_file.split('_epoch')[0], 'BioScore', f'{epoch}.sim'
        )
        results = _run_benchmark(_make_benchmark(opt), model, img_size, opt.batch_size, sim_path)
        print(f"Epoch {epoch}: {len(results)} results computed")


def main():
    """Build the model selected on the command line and score it against neural data.

    ``--mode single`` evaluates the file given by ``--checkpoint-path``.
    ``--mode multi`` evaluates every matching checkpoint in a directory, in
    epoch order, to track training dynamics.
    """
    opt = get_args()
    # CUDA only honours this if it is set before CUDA is initialised.
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id

    # Validate the options before any expensive model construction.
    dataset_cfg = _DATASET_CONFIGS.get(opt.train_dataset)
    if dataset_cfg is None:
        raise SystemExit(
            f"Unsupported --train-dataset '{opt.train_dataset}'; "
            f"expected one of {sorted(_DATASET_CONFIGS)}"
        )
    if opt.mode == 'single' and not opt.checkpoint_path:
        raise SystemExit("--checkpoint-path is required when --mode single")

    model = _build_model(opt, dataset_cfg['num_classes'])

    if opt.mode == 'single':
        _run_single(opt, model, dataset_cfg['img_size'])
    elif opt.mode == 'multi':
        _run_multi(opt, model, dataset_cfg['img_size'])


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
