import argparse
import torch
import os

from init import Initial
from predictor import PredictorModel
from sampler import TaskSampler

# Values of --proxy that the entry point routes to the zero-cost code paths
# (search_zero_cost for 'search', get_e2e_acc_zerocost for 'end2end').
ZERO_COST_PROXY_LIST = ['grad_norm', 'snip', 'grasp', 'fisher', 'jacob_cov', 'plain', 'synflow']

def get_e2e_acc_zerocost(args, topk=1, log_root='/d1/code/exp/resnet/resnet42/nas'):
    """Read the logged validation accuracy of a network picked by a zero-cost proxy.

    Args:
        args: Namespace providing ``proxy``, ``ds_name`` and ``topk_idx``.
            ``topk_idx[str(proxy)][str(ds_name)][topk]`` must yield the index
            of the network whose finetuning log is read.
        topk: Key used for the ``topk_idx`` lookup and for the ``top-{topk}``
            directory name.
        log_root: Directory containing the
            ``{ds_name}/{proxy}/top-{topk}/net-{index}/finetuning/logs.pt``
            tree. Defaults to the location that was previously hard-coded, so
            existing callers are unaffected.

    Returns:
        Tuple ``(final_acc, best_acc)``: the last entries of the
        ``'valid/top1'`` and ``'valid/best_acc'`` series of the log, each
        converted with ``.item()``.
    """
    net_index = args.topk_idx[str(args.proxy)][str(args.ds_name)][topk]
    log_path = (
        f'{log_root}/{args.ds_name}/{args.proxy}'
        f'/top-{topk}/net-{net_index}'
        '/finetuning/logs.pt'
    )
    logs = torch.load(log_path)
    final_acc = logs['valid/top1'][-1].item()
    best_acc = logs['valid/best_acc'][-1].item()
    return final_acc, best_acc

def search_zero_cost(args, net_info_list, log_root='/d1/code/zero-cost-nas/exp/nas'):
    """Rank candidate networks by a precomputed zero-cost proxy score.

    For every position ``i`` in ``net_info_list`` the file
    ``{log_root}/resnet/prtype-{zero_cost_pr_type}/num_batches-{num_batches}/net-{i}/logs.pt``
    is loaded and the value stored under ``args.proxy`` is used as the score.

    Args:
        args: Namespace providing ``proxy``, ``topk``, ``zero_cost_pr_type``
            and ``num_batches``.
        net_info_list: Candidate networks. Only the number of entries is
            used here; the entries themselves are not read.
        log_root: Directory containing the per-network score logs. Defaults
            to the location that was previously hard-coded.

    Returns:
        Tensor holding the indices of the ``args.topk`` highest-scoring
        networks, as returned by ``torch.topk(..., largest=True)``.
    """
    # NOTE(review): the log path does not include args.ds_name, so calling this
    # once per dataset reads the same logs every time -- confirm this is intended.
    score_list = []
    for net_idx, _ in enumerate(net_info_list):
        log_path = (
            f'{log_root}/resnet'
            f'/prtype-{args.zero_cost_pr_type}/num_batches-{args.num_batches}'
            f'/net-{net_idx}/logs.pt'
        )
        score_list.append(torch.load(log_path)[args.proxy])
    _, net_indices = torch.topk(torch.tensor(score_list), args.topk, largest=True)
    return net_indices

def search_flops_params(args, net_info_list):
    """Rank candidate networks by FLOPs or parameter count.

    Args:
        args: Namespace providing ``proxy`` (``'flops'`` or ``'params'``) and
            ``topk``.
        net_info_list: Sequence of per-network records; element ``0`` of a
            record is used as the score for ``'flops'`` and element ``1`` for
            ``'params'``.

    Returns:
        Tensor holding the indices of the ``args.topk`` records with the
        largest selected value, as returned by ``torch.topk(..., largest=True)``.

    Raises:
        ValueError: If ``args.proxy`` is neither ``'flops'`` nor ``'params'``.
            Previously this case produced an empty score list and failed
            later inside ``torch.topk`` with an unrelated-looking error.
    """
    # Position of each supported metric inside a network record.
    columns = {'flops': 0, 'params': 1}
    if args.proxy not in columns:
        raise ValueError(
            f"unsupported proxy {args.proxy!r}; expected 'flops' or 'params'"
        )
    column = columns[args.proxy]
    score_list = [net[column] for net in net_info_list]
    _, net_indices = torch.topk(torch.tensor(score_list), args.topk, largest=True)
    return net_indices

def search_kd_predictor(args):
    """Run the trained predictor and keep the top-k predictions per dataset.

    Args:
        args: Namespace providing ``ours1_ckpt`` (checkpoint passed to
            ``PredictorModel.load_model``), ``ours1_log`` (file loaded with
            ``torch.load``; its ``'args'`` entry configures the predictor)
            and ``topk``.

    Returns:
        Dict keyed by the formatted dataset name. Each value is the complete
        result of ``torch.topk`` (values and indices) over that dataset's
        predictions, not the indices alone.
    """
    ckpt_path, log_path = args.ours1_ckpt, args.ours1_log
    predictor_args = torch.load(log_path)['args']

    predictor = PredictorModel(predictor_args)
    predictor.load_model(ckpt_path)
    # The second value returned by nas() is not used here.
    y_pred_all_dict, _elapsed_time_dict = predictor.nas()

    return {
        f'{ds_name}': torch.topk(torch.tensor(y_pred_all), args.topk, largest=True)
        for ds_name, y_pred_all in y_pred_all_dict.items()
    }


def get_net_acc(args, topk=1, load_root='/d1/code/exp/nas'):
    """Read the logged validation metrics of a selected network.

    Args:
        args: Namespace providing ``proxy``, ``ds_name`` and ``topk_idx``.
            ``topk_idx[proxy][topk]`` must yield the index of the network
            whose log is read.
        topk: Key used for the ``topk_idx`` lookup.
        load_root: Directory containing the
            ``{proxy}/{ds_name}/prtype-copy_paste_first/lit/net-{index}/logs.pt``
            tree. Defaults to the location that was previously hard-coded.

    Returns:
        Tuple ``(final_acc, best_acc, final_loss)``: the last entries of the
        ``'valid/top1'`` and ``'valid/best_acc'`` series converted with
        ``.item()``, and the last entry of ``'valid/loss'`` as stored.
    """
    net_index = args.topk_idx[args.proxy][topk]
    # The '/' between the dataset directory and 'prtype-...' was missing, which
    # fused the two path components into a single directory name.
    load_path = (
        f'{load_root}/{args.proxy}/{args.ds_name}'
        f'/prtype-copy_paste_first/lit/net-{net_index}/logs.pt'
    )
    log = torch.load(load_path)
    final_acc = log['valid/top1'][-1].item()
    best_acc = log['valid/best_acc'][-1].item()
    final_loss = log['valid/loss'][-1]
    return final_acc, best_acc, final_loss


if __name__ == '__main__':
    # Command-line options; the parsed namespace is handed to Initial together
    # with the base config file names, and Initial's ``args`` is used afterwards.
    cli = argparse.ArgumentParser()
    cli.add_argument('--sampled_net_info_path', type=str, default=None)
    cli.add_argument('--task', type=str, default='search', help='search | end2end')
    cli.add_argument(
        '--proxy',
        type=str,
        default='grad_norm',
        help='grad_norm, ... | ours | flops | params',
    )
    cli.add_argument('--topk', type=int, default=3)
    cli.add_argument('--pr_type', type=str, default='copy_paste_first')
    cli.add_argument('--zero_cost_pr_type', type=str, default='random_init')
    cli.add_argument('--num_batches', type=int, default=0)
    cli.add_argument(
        '--nas_datasets',
        nargs='+',
        type=str,
        default=['cub', 'cifar100', 'dtd', 'stl10'],
    )
    cli.add_argument('--ds_name', type=str, default='cub')
    args = cli.parse_args()

    args = Initial(args, base_configs=['ckpt.yaml', 'nas.yaml']).args

    # Per the original note, each record is laid out as:
    # flops, params, depth_config, channel_widths, image_size
    net_info_list = torch.load(args.sampled_net_info_path)

    uses_zero_cost_proxy = args.proxy in ZERO_COST_PROXY_LIST

    if args.task == 'search':
        if uses_zero_cost_proxy:
            for ds_name in args.nas_datasets:
                # The search helper receives the current dataset through args.
                args.ds_name = ds_name
                net_indices = search_zero_cost(args, net_info_list)
                print(f'{args.ds_name} | {args.proxy} | {net_indices}')
        elif args.proxy in ('flops', 'params'):
            net_indices = search_flops_params(args, net_info_list)
            print(args.proxy, net_indices)
    elif args.task == 'end2end' and uses_zero_cost_proxy:
        for ds_name in args.nas_datasets:
            args.ds_name = ds_name
            final_acc, best_acc = get_e2e_acc_zerocost(args, topk=1)
            print(f'{args.proxy} | {args.ds_name} | Acc: {final_acc:.2f}')
    