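"""Determine the effective parameters of the models found in the given folders.

For every folder, the latest checkpoint of each model under
``<path>/<folder>/default/<model>/checkpoints`` is evaluated on the validation
set with the ``early_late_readout`` and ``layer_dropout`` metrics. The
per-layer results are written to ``<savepath>/<folder>.csv``.
"""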

import argparse
import os
import re
import sys

import torch
import numpy as np
import pandas as pd
from tqdm import tqdm

sys.path.append('')
import functions.sgc_resnets as sgc
import functions.datasets as datasets
import functions.sir_metrics as sir

# Column order of the written CSVs. 'folder' comes last to match the layout
# produced by the previous DataFrame.append-based version (its seed frame did
# not declare 'folder'), so existing consumers of the CSVs keep working.
_RESULT_COLUMNS = ['model_folder', 'metric', 'layer', 'criterion', 'value', 'folder']


def _checkpoint_epoch(filename):
    """Return the epoch number encoded in a checkpoint filename, or None.

    The epoch is the leading run of digits after the first '=' in the part of
    the name before the first '.'. Both 'epoch=12.ckpt' and
    'epoch=12-step=3400.ckpt' therefore yield 12. Names without such a number
    (e.g. 'last.ckpt') yield None.
    """
    fields = filename.split('.')[0].split('=')
    if len(fields) < 2:
        return None
    match = re.match(r'\d+', fields[1])
    return int(match.group()) if match else None


def _latest_checkpoint(model_path):
    """Return the full path of the highest-epoch checkpoint in model_path, or None."""
    by_epoch = {}
    for ckpt in os.listdir(model_path):
        epoch = _checkpoint_epoch(ckpt)
        if epoch is not None:
            by_epoch[epoch] = ckpt
    if not by_epoch:
        return None
    return os.path.join(model_path, by_epoch[max(by_epoch)])


def _evaluate_model(model, val_loader, device):
    """Sum the sir metrics of ``model.model`` over every batch of val_loader.

    Returns:
        dict mapping metric name ('early_late_readout', 'layer_dropout') to a
        dict of running sums keyed 'acc', 'crossentropy', 'blockdev' and 'size'.
        Each sum adds up the corresponding entry returned by ``sir.apply``.
    """
    metrics = (
        ('early_late_readout', sir.early_late_readout),
        ('layer_dropout', sir.layer_dropout),
    )
    values = {
        name: {'acc': 0, 'crossentropy': 0, 'blockdev': 0, 'size': 0}
        for name, _ in metrics
    }
    for data in val_loader:
        for name, metric in metrics:
            new_values = sir.apply(model.model, data, metric, device=device)
            for key, value in new_values.items():
                values[name][key] = values[name][key] + value
    return values


def _results_frames(folder, model_folder, values):
    """Build one long-format DataFrame per (metric, criterion) pair.

    Each row holds one layer's summed criterion divided by the metric's summed
    'size'.
    """
    frames = []
    for metric, sums in values.items():
        for criterion in ('acc', 'crossentropy', 'blockdev'):
            per_layer = sums[criterion]
            length = len(per_layer)
            frames.append(pd.DataFrame({
                'folder': [folder] * length,
                'model_folder': [model_folder] * length,
                'metric': [metric] * length,
                'layer': np.linspace(0, length - 1, length),
                'criterion': [criterion] * length,
                # Assumes sir.apply returns torch tensors. .cpu() is a no-op on
                # CPU and is required before .numpy() for tensors on the GPU.
                'value': per_layer.cpu().numpy() / sums['size'],
            }, columns=_RESULT_COLUMNS))
    return frames


def main(args):
    """Evaluate the latest checkpoint of every model in each requested folder.

    Args:
        args: parsed CLI namespace with attributes
            path: root directory containing the experiment folders.
            savepath: directory where one '<folder>.csv' per folder is written
                (created if missing).
            folders: names of the folders under ``path`` to process. Models are
                read from ``<path>/<folder>/default/<model>/checkpoints``.
            device: torch device the models are evaluated on.
            data, train_size: forwarded to ``datasets.get_data``.

    Model entries that are not directories containing a ``checkpoints``
    subdirectory, or that hold no epoch-numbered checkpoint, are skipped with a
    message.
    """
    _, _, val_loader = datasets.get_data(args.data, trainset=args.train_size)
    os.makedirs(args.savepath, exist_ok=True)
    with torch.no_grad():
        for folder in tqdm(args.folders):
            path = os.path.join(args.path, folder, 'default')
            frames = []
            # Sorted so the CSV row order is reproducible across runs.
            for model_folder in sorted(os.listdir(path)):
                model_path = os.path.join(path, model_folder, 'checkpoints')
                if not os.path.isdir(model_path):
                    tqdm.write(f'No checkpoints directory in {model_folder}, skipping')
                    continue
                full_path = _latest_checkpoint(model_path)
                if full_path is None:
                    tqdm.write(f'No checkpoint found in {model_path}, skipping')
                    continue
                tqdm.write(f'Evaluating {full_path}')
                model = sgc.SGCResNetModule.load_from_checkpoint(full_path, reload=True)
                model = model.eval().to(args.device)
                values = _evaluate_model(model, val_loader, args.device)
                frames.extend(_results_frames(folder, model_folder, values))
            # A single concat replaces the removed (and quadratic)
            # DataFrame.append; pd.concat cannot take an empty list.
            if frames:
                df = pd.concat(frames, ignore_index=True)
            else:
                df = pd.DataFrame(columns=_RESULT_COLUMNS)
            df.to_csv(os.path.join(args.savepath, folder + '.csv'))

if __name__ == '__main__':
    # Command line: <path> <savepath> <folder> [<folder> ...]
    #               [--device DEVICE] [--data DATA] [--train_size N]
    # The parsed namespace is handed unchanged to main().
    cli = argparse.ArgumentParser()
    for positional in ('path', 'savepath'):
        cli.add_argument(positional)
    cli.add_argument('folders', type=str, nargs='+')
    for flag, converter, fallback in (
        ('--device', str, 'cuda'),
        ('--data', str, 'cifar10'),
        ('--train_size', int, None),
    ):
        cli.add_argument(flag, type=converter, default=fallback)
    main(cli.parse_args())
