""" Analysis Script

Copyright (c) 2025 Anonymous Authors
"""
import os
import shutil
import argparse
import pandas as pd
from datetime import datetime

import yaml

from timm import utils
from timm.analysis import get_layer_difference_logging_list, layer_difference_comparison
from timm.analysis import get_cluster_logging_list, cluster_comparison
from timm.analysis import get_accuracy_logging_list, accuracy_comparison


def str2bool(v):
    """Convert a command-line value into a ``bool`` (argparse ``type=`` helper).

    ``bool`` inputs are returned unchanged. String inputs are matched
    case-insensitively: 'yes', 'true', 't', '1' map to True and
    'no', 'false', 'f', '0' map to False.

    Raises:
        argparse.ArgumentTypeError: if the string matches neither set.
    """
    if isinstance(v, bool):
        return v
    normalized = v.lower()
    if normalized in ('yes', 'true', 't', '1'):
        return True
    if normalized in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


# The first arg parser parses out only the --config argument. That argument points to a
# YAML file whose key-values override the defaults of the main parser below.
config_parser = parser = argparse.ArgumentParser(description='Analysis Config', add_help=False)
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
                    help='YAML config file specifying default arguments')


parser = argparse.ArgumentParser(
    description='Analysis: copy run logs and compare results across experiments')

# Device & distributed
# Apart from whatever utils.init_distributed_device reads, none of these are
# read by this file; presumably kept for CLI compatibility with the training script.
group = parser.add_argument_group('Device parameters')
group.add_argument('--device', default='cuda', type=str,
                   help="Device (accelerator) to use.")
group.add_argument('--amp', type=str2bool, default=True,
                   help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
group.add_argument('--amp-dtype', default='float16', type=str,
                   help='lower precision AMP dtype (default: float16)')
group.add_argument('--amp-impl', default='native', type=str,
                   help='AMP impl to use, "native" or "apex" (default: native)')
group.add_argument('--no-ddp-bb', action='store_true', default=False,
                   help='Force broadcast buffers for native DDP to off.')
group.add_argument('--synchronize-step', action='store_true', default=False,
                   help='torch.cuda.synchronize() end of each step')
group.add_argument("--local_rank", default=0, type=int)
group.add_argument('--device-modules', default=None, type=str, nargs='+',
                   help="Python imports for device backend modules.")

# Analysis
group = parser.add_argument_group('Analysis parameters')
group.add_argument('--exp-label-list', type=str, nargs='+', default=[],
                   help='labels identifying each experiment, passed to the comparison functions')
group.add_argument('--train-list', type=str, nargs='+', default=[],
                   help='training run folders whose logs are copied into the output folder')
group.add_argument('--results-list', type=str, nargs='+', default=[],
                   help='result run folders whose logs are copied and compared')
# layer difference
group.add_argument('--layer-difference-comparison', action='store_true', default=False,
                   help='run the layer-difference comparison (requires exactly two --results-list entries)')
# cluster (each flag requires exactly two --results-list entries)
group.add_argument('--cluster-size-frequency-comparison', action='store_true', default=False,
                   help='enable the cluster size-frequency comparison (requires exactly two --results-list entries)')
group.add_argument('--n-skipped-layer-cluster-comparison', action='store_true', default=False,
                   help='enable the n-skipped-layer cluster comparison (requires exactly two --results-list entries)')
group.add_argument('--n-cluster-per-layer-comparison', action='store_true', default=False,
                   help='enable the n-cluster-per-layer comparison (requires exactly two --results-list entries)')
group.add_argument('--cluster-size-frequency-histogram-comparison', action='store_true', default=False,
                   help='enable the cluster size-frequency histogram comparison '
                        '(requires exactly two --results-list entries)')
# accuracy
group.add_argument('--accuracy-comparison', action='store_true', default=False,
                   help='enable the accuracy comparison across --results-list entries')

group = parser.add_argument_group('Miscellaneous parameters')
group.add_argument('--output', default='', type=str, metavar='PATH',
                   help='path to output folder (default: ./output/analysis)')
group.add_argument('--experiment', default='', type=str, metavar='NAME',
                   help='name of the output sub-folder for this analysis (default: <timestamp>-analysis)')


def _parse_args():
    """Parse command-line arguments, letting an optional YAML file override defaults.

    A first pass with ``config_parser`` extracts only ``-c/--config``. When a
    config file is given, its top-level mapping is applied with
    ``parser.set_defaults``, so flags given explicitly on the command line
    still win. The remaining arguments are then parsed by ``parser``.

    Returns:
        tuple: ``(args, args_text)`` where ``args`` is the parsed namespace and
        ``args_text`` is that namespace dumped as YAML (saved later as
        ``args.yaml`` in the output folder).

    Raises:
        ValueError: if the config file's top level is not a YAML mapping.
    """
    # Do we have a config file to parse?
    args_config, remaining = config_parser.parse_known_args()
    if args_config.config:
        with open(args_config.config, 'r') as f:
            cfg = yaml.safe_load(f)
        # An empty YAML file loads as None; treat it as "no overrides" rather
        # than crashing inside set_defaults(**None).
        if cfg is None:
            cfg = {}
        if not isinstance(cfg, dict):
            raise ValueError(
                f'Config file {args_config.config!r} must contain a YAML mapping, '
                f'got {type(cfg).__name__}')
        parser.set_defaults(**cfg)

    # The main arg parser parses the rest of the args; its defaults have been
    # overridden above if a config file was specified.
    args = parser.parse_args(remaining)

    # Cache the args as a text string to save them in the output dir later
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
    return args, args_text


def copy_files(output_directory, directory, logging_list, optional_files=('hardware.json',)):
    """Copy selected log files from a run folder into the analysis output folder.

    The destination mirrors the last two path components of ``directory``
    (e.g. ``.../train/exp1`` -> ``<output_directory>/train/exp1``).

    Args:
        output_directory: Root folder that receives the copies.
        directory: Source run folder.
        logging_list: File names (may include sub-folders) relative to
            ``directory``.
        optional_files: Collection of base names that are skipped silently
            when missing; every other listed file must exist.

    Raises:
        FileNotFoundError: if a non-optional file is missing.
    """
    # normpath collapses '//' and '.' segments (and unifies separators on
    # Windows) so the last two components really are the run's folder names;
    # otherwise 'a/train//exp1' would map to '' + 'exp1' and drop 'train'.
    parts = os.path.normpath(directory).strip(os.sep).split(os.sep)
    target_directory = os.path.join(output_directory, *parts[-2:])
    os.makedirs(target_directory, exist_ok=True)
    for file_name in logging_list:
        from_file = os.path.join(directory, file_name)
        to_file = os.path.join(target_directory, file_name)
        if os.path.basename(from_file) in optional_files and not os.path.exists(from_file):
            continue
        # file_name may contain sub-folders that do not exist yet at the target.
        dir_name = os.path.dirname(to_file)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        shutil.copy2(from_file, to_file)


def reproduction_copy(train_list, results_list, train_logging_list, results_logging_list, output_directory):
    """Copy logs from every training folder, then from every results folder.

    Each folder in ``train_list`` gets the files in ``train_logging_list``;
    each folder in ``results_list`` gets the files in ``results_logging_list``.
    All copying is delegated to ``copy_files`` with ``output_directory`` as
    the destination root.
    """
    copy_plan = (
        (train_list, train_logging_list),
        (results_list, results_logging_list),
    )
    for run_folders, wanted_files in copy_plan:
        for run_folder in run_folders:
            copy_files(output_directory, run_folder, wanted_files)


def analysis(args, output_directory=''):
    """Copy the logs needed by the requested comparisons, then run them.

    Every run contributes ``args.yaml``, ``summary.csv`` and ``hardware.json``.
    Each enabled comparison adds the extra file names returned by its
    ``get_*_logging_list`` helper. All files are copied with
    ``reproduction_copy`` before any comparison executes.

    Args:
        args: Namespace from ``_parse_args``. Reads ``train_list``,
            ``results_list``, ``exp_label_list`` and the comparison flags.
        output_directory: Folder that receives the copied logs and the
            comparison outputs. Paths are built with ``os.path.join``, so the
            default ``''`` means the current working directory.

    Raises:
        ValueError: if a layer-difference or cluster comparison is requested
            and ``args.results_list`` does not have exactly two entries.
    """
    base_files = ['args.yaml', 'summary.csv', 'hardware.json']
    train_logging_list = list(base_files)
    results_logging_list = list(base_files)

    cluster_flags = (
        args.cluster_size_frequency_comparison,
        args.n_skipped_layer_cluster_comparison,
        args.n_cluster_per_layer_comparison,
        args.cluster_size_frequency_histogram_comparison,
    )
    run_cluster = any(cluster_flags)

    # Both the layer-difference and cluster comparisons need exactly two result
    # runs. Raise instead of assert so the check survives ``python -O``.
    if (args.layer_difference_comparison or run_cluster) and len(args.results_list) != 2:
        raise ValueError(
            'layer-difference and cluster comparisons require exactly two '
            f'--results-list entries, got {len(args.results_list)}')

    if args.layer_difference_comparison:
        cur_train_logging_list, cur_results_logging_list = get_layer_difference_logging_list()
        train_logging_list.extend(cur_train_logging_list)
        results_logging_list.extend(cur_results_logging_list)
    if run_cluster:
        cur_train_logging_list, cur_results_logging_list = get_cluster_logging_list(*cluster_flags)
        train_logging_list.extend(cur_train_logging_list)
        results_logging_list.extend(cur_results_logging_list)
        # os.path.join avoids creating '/cluster' at the filesystem root when
        # output_directory is ''. exist_ok avoids FileExistsError when the
        # folder is already there.
        # NOTE(review): presumably cluster_comparison writes into this folder.
        os.makedirs(os.path.join(output_directory, 'cluster'), exist_ok=True)
    if args.accuracy_comparison:
        cur_train_logging_list, cur_results_logging_list = get_accuracy_logging_list()
        train_logging_list.extend(cur_train_logging_list)
        results_logging_list.extend(cur_results_logging_list)

    reproduction_copy(args.train_list, args.results_list, train_logging_list, results_logging_list, output_directory)

    if args.layer_difference_comparison:
        layer_difference_comparison(args.exp_label_list, args.results_list, output_directory)
    # NOTE(review): these two are called unconditionally and receive their
    # enable flags; presumably they no-op when every flag is False — confirm.
    cluster_comparison(args.exp_label_list, args.results_list, output_directory, *cluster_flags)
    accuracy_comparison(args.results_list, output_directory, args.accuracy_comparison)


def main():
    """Script entry point: parse args, create the output folder, run the analysis.

    Only the process for which ``utils.is_primary(args)`` is true creates the
    output folder, writes ``args.yaml`` there and runs ``analysis``. Other
    processes return after device initialisation. The output folder is
    ``<--output or ./output/analysis>/<--experiment or <timestamp>-analysis>``.
    """
    utils.setup_default_logging()
    args, args_text = _parse_args()

    # The returned device is not used here. The call is kept for its effect on
    # ``args`` — presumably it sets the rank fields utils.is_primary reads (verify).
    utils.init_distributed_device(args)

    if not utils.is_primary(args):
        # Non-primary processes have no output folder; calling analysis with
        # output_directory=None would fail on the first path join.
        return

    if args.experiment:
        exp_name = args.experiment
    else:
        exp_name = f"{datetime.now().strftime('%Y%m%d-%H%M%S')}-analysis"
    output_dir = utils.get_outdir(args.output if args.output else './output/analysis', exp_name)
    with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
        f.write(args_text)

    analysis(args, output_directory=output_dir)


if __name__ == "__main__":
    # Run only when executed as a script, not when imported.
    main()
