import os
import argparse

import sys
sys.path.insert(0, '../../')
sys.path.insert(0, '../../../')
import torch
import torch.nn as nn
import torch.optim as optim

from network_designer.trainner.nasbench101.datasets.cifar10 import prepare_dataset
from network_designer.design_space.nasbench101.network import Network
from network_designer.design_space.nasbench101.nasbench1_spec  import ModelSpec
from network_designer.trainner.nasbench101.trainner import train, test

import pickle
import pandas as pd
import numpy as np

def save_checkpoint(net, seed, id, acc, postfix='cifar10'):
    """Write the network's ``state_dict`` to the nb101-like-train experiment folder.

    The output directory is created if missing and the file is named
    ``{seed}_{id}_{acc:.2f}.pt``. Both paths are relative to the current
    working directory.

    Args:
        net: Module whose ``state_dict()`` is serialized.
        seed: Random seed of the run; first component of the file name.
        id: Architecture id; second component of the file name.
        acc: Accuracy value, formatted with two decimals into the file name.
        postfix: Accepted for compatibility; not used by this function.
    """
    print('--- Saving Checkpoint ---')

    out_dir = '../../../experiments/nb101-like-train'
    os.makedirs(out_dir, exist_ok=True)

    target_file = '../../../experiments/nb101-like-train/{}_{}_{:.2f}.pt'.format(seed, id, acc)
    weights = net.state_dict()
    torch.save(weights, target_file)

def reload_checkpoint(path, device=None):
    """Load a serialized checkpoint from ``path`` with ``torch.load``.

    Args:
        path: Location of the checkpoint file. File-like objects are passed
            straight through to ``torch.load`` without an existence check.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        Whatever object ``torch.load`` deserializes from ``path``.

    Raises:
        FileNotFoundError: If ``path`` is a filesystem path that does not
            point to an existing file.
    """
    print('--- Reloading Checkpoint ---')

    # Validate the file that is actually being loaded. The previous check
    # looked for an unrelated 'checkpoint' directory in the working directory,
    # which rejected valid checkpoints stored anywhere else.
    if isinstance(path, (str, bytes, os.PathLike)) and not os.path.isfile(path):
        raise FileNotFoundError('[Error] No checkpoint file found at {!r}!'.format(path))
    return torch.load(path, map_location=device)

def pick_gpu_lowest_memory():
    """Return the index of the GPU with the lowest used/total memory ratio.

    GPU statistics come from a fresh ``gpustat`` query. When several GPUs
    share the lowest ratio, the first one reported wins.
    """
    import gpustat

    usage = [
        (int(gpu.entry['index']), float(gpu.memory_used) / float(gpu.memory_total))
        for gpu in gpustat.GPUStatCollection.new_query()
    ]
    best_index, _ = min(usage, key=lambda pair: pair[1])
    return best_index

# Operation names, looked up by position: load_graph_from_pickle labels a node
# with OPS[argmax(ops_feature_row)], so the order of this list is significant.
OPS=['conv3x3-bn-relu', 'conv1x1-bn-relu', 'maxpool3x3','input', 'output']

def _str2bool(value):
    """Convert a command-line token such as ``True``/``false``/``1``/``no`` to a bool.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy. The empty string maps to
    ``False``, as it did under ``type=bool``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got {!r}.'.format(value))


# Command-line interface of the training script.
parser = argparse.ArgumentParser(description='NASBench')

# Reproducibility, data and network shape.
parser.add_argument('--random_state', default=1, type=int, help='Random seed.')
parser.add_argument('--data_root', default='./data/', type=str, help='Path where cifar will be downloaded.')
parser.add_argument('--in_channels', default=3, type=int, help='Number of input channels.')
parser.add_argument('--stem_out_channels', default=128, type=int, help='output channels of stem convolution')
parser.add_argument('--num_stacks', default=3, type=int, help='#stacks of modules')
parser.add_argument('--num_modules_per_stack', default=3, type=int, help='#modules per stack')

# Training schedule and data loading.
parser.add_argument('--batch_size', default=256, type=int, help='batch size')
parser.add_argument('--test_batch_size', default=256, type=int, help='test set batch size')
parser.add_argument('--epochs', default=108, type=int, help='#epochs of training')
parser.add_argument('--validation_size', default=10000, type=int, help="Size of the validation set to split off.")
parser.add_argument('--num_workers', default=0, type=int, help="Number of parallel workers for the train dataset.")

# Optimization hyper-parameters.
parser.add_argument('--learning_rate', default=0.2, type=float, help='base learning rate')
parser.add_argument('--lr_decay_method', default='COSINE_BY_STEP', type=str, help='learning decay method')
parser.add_argument('--optimizer', default='rmsprop', type=str, help='Optimizer (sgd, rmsprop or rmsprop_tf)')
parser.add_argument('--rmsprop_eps', default=1.0, type=float, help='RMSProp eps parameter.')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--weight_decay', default=1e-4, type=float, help='L2 regularization weight')
parser.add_argument('--grad_clip', default=5, type=float, help='gradient clipping')
# Parsed with _str2bool so that "--grad_clip_off False" really yields False.
parser.add_argument('--grad_clip_off', default=False, type=_str2bool, help='If True, turn off gradient clipping.')
parser.add_argument('--batch_norm_momentum', default=0.997, type=float, help='Batch normalization momentum')
parser.add_argument('--batch_norm_eps', default=1e-5, type=float, help='Batch normalization epsilon')

# Runtime, checkpointing and logging.
parser.add_argument('--load_checkpoint', default='', type=str, help='Reload model from checkpoint')
parser.add_argument('--num_labels', default=10, type=int, help='#classes')
parser.add_argument('--device', default='cuda', type=str, help='Device for network training.')
parser.add_argument('--print_freq', default=100, type=int, help='Batch print frequency.')
parser.add_argument('--tf_like', default=False, type=_str2bool,
                    help='If true, use same weight initialization as in the tensorflow version.')

# Location of the sampled architecture space and the entry to train.
parser.add_argument('--space_root', type=str, default='../experiments/DEMO/step_2/')
parser.add_argument('--space', type=str, default='', help="pickle file name for sampled space")
parser.add_argument('--id', type=int, default=0, help="id for architecutres in space")

# Parsed at module level (i.e. on import as well as on direct execution), so
# `args` is a module global; the __main__ block below reads it and may
# overwrite `args.device`.
args = parser.parse_args()

def remove_zero_rows_cols(array):
    """Drop every index whose row AND column are both entirely zero.

    An index is removed only when both its row and its column contain nothing
    but zeros. The same set of indices is deleted along both axes, so a square
    input stays square.

    Args:
        array: Square 2-D numpy array.

    Returns:
        A new array with the isolated indices removed from rows and columns.
    """
    is_zero = array == 0
    isolated = np.flatnonzero(is_zero.all(axis=0) & is_zero.all(axis=1))

    without_rows = np.delete(array, isolated, axis=0)
    return np.delete(without_rows, isolated, axis=1)

def load_graph_from_pickle(space_root, space, id):
    """Load one architecture from a pickled space and return its graph.

    Reads ``{space_root}/{space}.pkl``, wraps the content in a DataFrame and
    takes row ``id``. The row's ``adj_matrix`` is cast to int and reduced to
    its strictly upper triangle; its ``ops_features`` are cast to int and each
    kept node is labelled ``OPS[argmax(row)]``.

    A node receives a label when its ops row has a positive sum and it either
    has a positive incoming-edge column sum or sits at index 0 or 6.

    Args:
        space_root: Directory containing the pickle file.
        space: Pickle file name without the ``.pkl`` extension.
        id: Positional row index of the architecture inside the space.

    Returns:
        Tuple ``(matrix, labels)``: the trimmed adjacency matrix as a nested
        list and the list of operation names.
    """
    # SECURITY: pickle.load can execute arbitrary code; only point this at
    # space files from a trusted source.
    # The context manager closes the file handle, which was previously leaked.
    with open('{}/{}.pkl'.format(space_root, space), 'rb') as f:
        data = pickle.load(f)
    dataset = pd.DataFrame(data)

    row = dataset.iloc[id]
    print(row)
    adj = row['adj_matrix']
    ops = row['ops_features']

    o = ops.astype(int)
    # Keep only edges above the diagonal.
    a = np.triu(adj.astype(int), 1)

    labels = []
    for idx, ops_f in enumerate(list(o)):
        if sum(ops_f) > 0:
            # NOTE(review): indices 0 and 6 are always kept - presumably the
            # input/output nodes of a 7-node cell; confirm for other sizes.
            if (sum(a[:, idx]) > 0) or (idx in [0, 6]):
                labels.append(OPS[np.argmax(ops_f)])

    # NOTE(review): matrix trimming and label selection use different
    # criteria; verify that they always keep the same set of nodes.
    a = remove_zero_rows_cols(a).tolist()
    print(a)
    print(labels)
    return a, labels

if __name__ == '__main__':
    # Auto-select the least-loaded GPU only for the default 'cuda' setting.
    # An explicit --device (e.g. 'cpu' or 'cuda:1') is now honoured instead of
    # being silently overwritten.
    if args.device == 'cuda':
        args.device = pick_gpu_lowest_memory()
    print(args.device)

    # cifar10 dataset
    dataset = prepare_dataset(args.batch_size, test_batch_size=args.test_batch_size, root=args.data_root,
                              validation_size=args.validation_size, random_state=args.random_state,
                              set_global_seed=True, num_workers=args.num_workers)

    train_loader, test_loader, test_size = dataset['train'], dataset['test'], dataset['test_size']
    valid_loader = dataset['validation'] if args.validation_size > 0 else None

    # Architecture to train: adjacency matrix and per-node operation labels.
    matrix, operations = load_graph_from_pickle(args.space_root, args.space, args.id)

    # model
    spec = ModelSpec(matrix, operations)
    net = Network(spec, num_labels=args.num_labels, in_channels=args.in_channels,
                  stem_out_channels=args.stem_out_channels, num_stacks=args.num_stacks,
                  num_modules_per_stack=args.num_modules_per_stack,
                  momentum=args.batch_norm_momentum, eps=args.batch_norm_eps, tf_like=args.tf_like)

    if args.load_checkpoint != '':
        # Load onto the CPU first so a checkpoint saved on a device that is
        # absent here can still be restored; the model is moved right after.
        net.load_state_dict(reload_checkpoint(args.load_checkpoint, device='cpu'))
    net.to(args.device)

    criterion = nn.CrossEntropyLoss()

    # Resolve the optimizer class and its optimizer-specific keyword arguments.
    optimizer_name = args.optimizer.lower()
    if optimizer_name == 'sgd':
        optimizer = optim.SGD
        optimizer_kwargs = {}
    elif optimizer_name == 'rmsprop':
        optimizer = optim.RMSprop
        optimizer_kwargs = {'eps': args.rmsprop_eps}
    elif optimizer_name == 'rmsprop_tf':
        # Imported lazily so timm is only required when this option is used.
        from timm.optim import RMSpropTF
        optimizer = RMSpropTF
        optimizer_kwargs = {'eps': args.rmsprop_eps}
    else:
        raise ValueError(f"Invalid optimizer {args.optimizer}, possible: SGD, RMSProp, RMSProp_TF")

    optimizer = optimizer(net.parameters(), lr=args.learning_rate, momentum=args.momentum,
                          weight_decay=args.weight_decay, **optimizer_kwargs)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

    result = train(net, train_loader, loss=criterion, optimizer=optimizer, scheduler=scheduler,
                   grad_clip=None if args.grad_clip_off else args.grad_clip,
                   num_epochs=args.epochs, num_validation=args.validation_size, validation_loader=valid_loader,
                   device=args.device, print_frequency=args.print_freq)

    # Keep only the last recorded value of every non-empty metric series.
    last_epoch = {k: v[-1] for k, v in result.items() if len(v) > 0}
    print(f"Final train metrics: {last_epoch}")

    result = test(net, test_loader, loss=criterion, num_tests=test_size, device=args.device)
    print(f"\nFinal test metrics: {result}")

    save_checkpoint(net, seed=args.random_state, id=args.id, acc=result['test_accuracy'])