import os
import argparse

import sys
sys.path.insert(0, '../../')
sys.path.insert(0, '../../../')
import torch
from network_designer.trainner.nasbench101.datasets.cifar10 import prepare_dataset
from network_designer.design_space.nasbench101.network import Network
from network_designer.design_space.nasbench101.nasbench1_spec  import ModelSpec
import logging
import pickle
import pandas as pd
import numpy as np
import random
import torch.backends.cudnn as cudnn

def pick_gpu_lowest_memory():
    """Return the index of the GPU with the lowest fraction of memory in use.

    Queries every visible GPU through ``gpustat`` and compares
    ``memory_used / memory_total``; on a tie the GPU reported first wins.

    Returns:
        int: the GPU index as reported by ``gpustat`` (``entry['index']``).

    Raises:
        ValueError: if ``gpustat`` reports no GPUs (empty ``min``).
    """
    import gpustat

    usage = [
        (int(gpu.entry['index']), float(gpu.memory_used) / float(gpu.memory_total))
        for gpu in gpustat.GPUStatCollection.new_query()
    ]
    best_index, _best_ratio = min(usage, key=lambda pair: pair[1])
    return best_index

def remove_zero_rows_cols(array):
    """Drop every node whose row and column are both entirely zero.

    Args:
        array: square 2-D numpy array (e.g. an adjacency matrix).

    Returns:
        A new array with those rows and the matching columns removed.
    """
    # A node is removable only if it has no entries in its row AND its column.
    both_empty = np.all(array == 0, axis=1) & np.all(array == 0, axis=0)
    drop = np.arange(array.shape[0])[both_empty]

    without_rows = np.delete(array, drop, axis=0)
    return np.delete(without_rows, drop, axis=1)

def load_graph_from_pickle(space_root, space, id):
    """Load one architecture from a pickled search space and decode it.

    Reads ``{space_root}/{space}.pkl``, wraps its contents in a DataFrame and
    takes the row at position ``id``. The row's ``adj_matrix`` is reduced to
    its strict upper triangle and pruned of isolated nodes; the rows of
    ``ops_features`` are mapped to names in ``OPS`` via ``argmax``.

    Args:
        space_root: directory containing the pickle file.
        space: pickle file name without the ``.pkl`` extension.
        id: positional row index (``iloc``) of the architecture.

    Returns:
        tuple: ``(adjacency, labels)`` where ``adjacency`` is the pruned
        adjacency matrix as a nested list of ints and ``labels`` is the list
        of operation names for the nodes kept by the label filter.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load space
    # files that come from a trusted source.
    with open('{}/{}.pkl'.format(space_root, space), 'rb') as f:
        data = pickle.load(f)
    dataset = pd.DataFrame(data)

    row = dataset.iloc[id]
    print(row)
    adj = row['adj_matrix']
    ops = row['ops_features']

    o = ops.astype(int)
    # Keep only edges i -> j with i < j (strict upper triangle).
    a = np.triu(adj.astype(int), 1)

    labels = []
    for idx, ops_f in enumerate(list(o)):
        # Nodes with an all-zero op row are skipped (presumably padding).
        if sum(ops_f) > 0:
            # NOTE(review): 0 and 6 are hard-coded as the input/output node
            # positions, which assumes 7-node matrices. Also, a node with
            # outgoing but no incoming edges is skipped here yet kept by
            # remove_zero_rows_cols, which would misalign matrix and labels.
            if sum(a[:, idx]) > 0 or idx in (0, 6):
                labels.append(OPS[np.argmax(ops_f)])

    a = remove_zero_rows_cols(a).tolist()
    print(a)
    print(labels)
    return a, labels

def seed_torch(seed=0):
    """Seed every RNG used here and configure cuDNN for reproducibility.

    Seeds Python's ``random``, NumPy, and torch (CPU and all CUDA devices),
    sets ``PYTHONHASHSEED`` in the environment, and forces cuDNN into
    deterministic, non-autotuning mode.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    # Only affects interpreters started after this point; the hash seed of
    # the running process is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and possibly pick different
    # algorithms between runs, which defeats the deterministic setting above.
    cudnn.benchmark = False

def _interquartile_mean(samples):
    """Mean of the samples lying in the inclusive 25th-75th percentile band."""
    lower = np.percentile(samples, 25)
    upper = np.percentile(samples, 75)
    return np.mean([s for s in samples if lower <= s <= upper])


def calc_stats(values):
    """Robust two-level average of grouped measurements.

    Every group in ``values`` is first reduced to the mean of its samples
    inside the group's [25th, 75th] percentile range; the resulting group
    means are then reduced the same way.

    Args:
        values: iterable of groups, each a sequence of numbers.

    Returns:
        The interquartile mean of the per-group interquartile means.
    """
    group_means = [_interquartile_mean(group) for group in values]
    return _interquartile_mean(group_means)

def measure_latency(model, input_data, num_runs=1000, num_repeats=1):
    """Time repeated forward passes of ``model`` on ``input_data``.

    The model is moved to ``input_data``'s device and switched to eval mode.
    After one untimed warm-up pass, each forward pass is timed individually:
    with CUDA events when the input is on a CUDA device, otherwise with
    ``time.perf_counter``. Mean/min/max over all timed runs are printed.

    Args:
        model: ``torch.nn.Module`` to benchmark (moved to the input device).
        input_data: input tensor passed to the model on every run.
        num_runs: timed forward passes per repeat.
        num_repeats: number of timing sessions; each becomes one row of the
            result. Defaults to 1.

    Returns:
        ``np.ndarray`` of shape ``(num_repeats, num_runs)`` holding per-run
        latencies in milliseconds. Each row is one group of measurements, so
        the result can be passed directly to ``calc_stats``.
    """
    import time

    device = input_data.device
    on_cuda = device.type == 'cuda'
    model = model.to(device)
    model.eval()

    if on_cuda:
        # Record events on the stream the model's kernels are queued to,
        # even when `device` is not the current CUDA device.
        stream = torch.cuda.current_stream(device)
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

    run_times = np.empty((num_repeats, num_runs), dtype=np.float64)
    with torch.no_grad():
        # Warm-up run keeps one-off costs (lazy init, allocator growth,
        # algorithm selection) out of the measurements.
        model(input_data)
        for rep in range(num_repeats):
            for run in range(num_runs):
                if on_cuda:
                    start_event.record(stream)
                    model(input_data)
                    end_event.record(stream)
                    # elapsed_time is only valid once both events completed.
                    torch.cuda.synchronize(device)
                    run_times[rep, run] = start_event.elapsed_time(end_event)
                else:
                    start = time.perf_counter()
                    model(input_data)
                    run_times[rep, run] = (time.perf_counter() - start) * 1000.0

    total_runs = run_times.size
    print(f'Average inference time over {total_runs} runs: {np.mean(run_times):.3f} ms')
    print(f'Minimum inference time over {total_runs} runs: {np.min(run_times):.3f} ms')
    print(f'Maximum inference time over {total_runs} runs: {np.max(run_times):.3f} ms')

    return run_times


# Operation names, indexed by the argmax of a node's `ops_features` row
# (see load_graph_from_pickle).
OPS = ['conv3x3-bn-relu', 'conv1x1-bn-relu', 'maxpool3x3', 'input', 'output']


def _str2bool(value):
    """Parse a command-line boolean value.

    ``argparse`` with ``type=bool`` turns any non-empty string (including
    ``'False'``) into ``True``, so the common spellings are parsed explicitly.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got {!r}.'.format(value))


parser = argparse.ArgumentParser(description='NASBench')
parser.add_argument('--random_state', default=1, type=int, help='Random seed.')
parser.add_argument('--data_root', default='./data/', type=str, help='Path where cifar will be downloaded.')
parser.add_argument('--in_channels', default=3, type=int, help='Number of input channels.')
parser.add_argument('--stem_out_channels', default=128, type=int, help='output channels of stem convolution')
parser.add_argument('--num_stacks', default=3, type=int, help='#stacks of modules')
parser.add_argument('--num_modules_per_stack', default=3, type=int, help='#modules per stack')
parser.add_argument('--batch_norm_momentum', default=0.997, type=float, help='Batch normalization momentum')
parser.add_argument('--batch_norm_eps', default=1e-5, type=float, help='Batch normalization epsilon')
parser.add_argument('--num_labels', default=10, type=int, help='#classes')
parser.add_argument('--device', default='cuda', type=str, help='Device for network training.')
parser.add_argument('--tf_like', default=False, type=_str2bool,
                    help='If true, use same weight initialization as in the tensorflow version.')
parser.add_argument('--space_root', type=str, default='../experiments/DEMO/step_2/')
parser.add_argument('--space', type=str, default='', help="pickle file name for sampled space")
parser.add_argument('--id', type=int, default=0, help="id for architecutres in space")
# args.exp_path was read below but never declared, which crashed the script.
parser.add_argument('--exp_path', type=str, default='./exp',
                    help='Root directory under which per-architecture log folders are created.')

args = parser.parse_args()

# Log directory: <exp_path>/<space>/<random_state>-<id>
args.save = '{}/{}/{}-{}'.format(args.exp_path, args.space, args.random_state, args.id)

os.makedirs(args.save, exist_ok=True)

# Log to stdout and to <save>/log.txt using the same message and date format.
log_format = '%(asctime)s %(message)s'
date_format = '%m/%d %I:%M:%S %p'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt=date_format)
log_file = 'log.txt'
fh = logging.FileHandler(os.path.join(args.save, log_file), mode='w')
fh.setFormatter(logging.Formatter(log_format, datefmt=date_format))
logging.getLogger().addHandler(fh)

if __name__ == '__main__':
    # Benchmark on the least-loaded GPU. Make it the current CUDA device so
    # modules/tensors moved with a bare `.cuda()` actually land on it.
    args.device = pick_gpu_lowest_memory()
    torch.cuda.set_device(args.device)

    matrix, operations = load_graph_from_pickle(args.space_root, args.space, args.id)

    # model
    spec = ModelSpec(matrix, operations)
    net = Network(spec, num_labels=args.num_labels, in_channels=args.in_channels,
                  stem_out_channels=args.stem_out_channels, num_stacks=args.num_stacks,
                  num_modules_per_stack=args.num_modules_per_stack,
                  momentum=args.batch_norm_momentum, eps=args.batch_norm_eps, tf_like=args.tf_like)

    # Single random sample of shape (1, in_channels, 32, 32) on the chosen GPU.
    input_data = torch.rand(1, args.in_channels, 32, 32).cuda()

    run_stats = measure_latency(net, input_data)

    # calc_stats expects a sequence of groups of per-run timings.
    final_stats = calc_stats(run_stats)
    logging.info('Final_Latency:{}'.format(final_stats))