from functools import partial
import argparse
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
from torch.utils.data import random_split

import torchvision
from torchvision import datasets, transforms, models
import ray
from ray import tune
import ray.train as train
from ray.tune import CLIReporter

from ray.tune.schedulers import PopulationBasedTraining
from ray.tune.schedulers.pbt import PopulationBasedTrainingReplay

from ray.tune.schedulers.hb_bohb import HyperBandForBOHB
from ray.tune.suggest.bohb import TuneBOHB
from ray.tune.schedulers import ASHAScheduler
from ray.tune.schedulers import HyperBandScheduler

import pandas as pd

from ray.tune.integration.torch import (DistributedTrainableCreator, distributed_checkpoint_dir)

# nats bench imports
import nats_bench.cells
import nats_bench.cell_operations
from nats_bench.InferTinyCellNet import DynamicShapeTinyNet
import nats_bench.genotypes as genotypes
from nats_bench.DownsampledImageNet import ImageNet16



def parsIni():
    """Build the command-line parser for the tuning script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--num-samples``,
        ``--par-workers``, ``--scheduler``, ``--dataset`` and ``--seed``.
        The caller is responsible for invoking ``parse_args()``.
    """
    parser = argparse.ArgumentParser(description='Ray Tune ImageNet Example')
    parser.add_argument('--num-samples', type=int, default=24, metavar='N',
                    help='number of samples to train (default: 24)')
    parser.add_argument('--par-workers', type=int, default=1, metavar='N',
                    help='parallel workers to train on a single trial (default: 1)')
    parser.add_argument('--scheduler', type=str, default='RAND',
                    help='scheduler for tuning: ASHA, BOHB, HB or RAND (default: RAND)')
    parser.add_argument('--dataset', type=str, default='cifar-10',
                    help='dataset to evaluate (default: cifar-10)')
    # The default is a real int so it matches ``type=int`` without relying on
    # argparse converting string defaults.
    parser.add_argument('--seed', type=int, default=111,
                    help='seed to use (default: 111)')

    return parser

# mean of field over GPGPUs
def par_mean(field):
    """Average ``field`` over every rank of the default process group.

    ``field`` is copied into a new float tensor, summed across ranks with an
    all-reduce, and divided by the world size. All ranks must call this
    function, since the all-reduce is a collective operation.

    Returns:
        torch.Tensor: the cross-rank mean as a float tensor.
    """
    averaged = torch.tensor(field).float()
    # No explicit device move: the tensor stays wherever torch.tensor(field)
    # placed it.
    pending = dist.all_reduce(
        averaged, op=dist.ReduceOp.SUM, group=None, async_op=True)
    pending.wait()  # block until the summed result is available
    averaged.div_(dist.get_world_size())
    return averaged

def par_sum(field):
    """Sum ``field`` over every rank of the default process group.

    ``field`` is copied into a new float tensor which is then all-reduced
    with ``SUM``. All ranks must call this function, since the all-reduce is
    a collective operation.

    Returns:
        torch.Tensor: the cross-rank sum as a float tensor.
    """
    summed = torch.tensor(field).float()
    # No explicit device move: the tensor stays wherever torch.tensor(field)
    # placed it.
    pending = dist.all_reduce(
        summed, op=dist.ReduceOp.SUM, group=None, async_op=True)
    pending.wait()  # block until the summed result is available
    return summed
        
def accuracy(output, target):
    """Count the correct top-1 predictions in one batch.

    Args:
        output: per-sample scores; the prediction is the index of the
            maximum along dimension 1.
        target: ground-truth labels, one per sample along dimension 0.

    Returns:
        tuple: ``(correct, total)`` where ``correct`` is a CPU float tensor
        holding the number of matching predictions and ``total`` is
        ``target.size(0)``.
    """
    # Keep the reduced dimension so the targets can be reshaped to match.
    _, predicted = output.max(1, keepdim=True)
    matches = predicted.eq(target.view_as(predicted))
    num_correct = matches.cpu().float().sum()
    batch_size = target.size(0)
    return num_correct, batch_size


def _build_transforms(train_augment, eval_prepare, mean, std):
    """Return ``(train, eval)`` pipelines ending in ToTensor + Normalize."""
    def finish(steps):
        return transforms.Compose(
            list(steps) + [transforms.ToTensor(), transforms.Normalize(mean, std)])

    return finish(train_augment), finish(eval_prepare)


def load_data(dataset="cifar-10", seed=111, data_dir=None):
    """Build the train / validation / test datasets for one benchmark.

    The validation set is carved out of the training images: the training
    data is instantiated twice (once with the augmenting transform, once
    with the evaluation transform) and a seeded random permutation assigns
    the last 20% of indices to validation and the rest to training.

    Args:
        dataset: one of ``"cifar-10"``, ``"cifar-100"``, ``"imagenet-16"``
            or ``"imagenet-1k"``.
        seed: seed applied to NumPy and PyTorch before the split is drawn.
        data_dir: dataset root. It is handed to the CIFAR datasets as
            ``root`` (with ``download=True``), to ``ImageNet16`` as its first
            argument, and for ``"imagenet-1k"`` must contain ``train`` and
            ``val`` sub-directories.
            NOTE(review): the ``None`` default is forwarded unchanged and
            looks unusable for every branch - confirm callers pass a path.

    Returns:
        tuple: ``(trainset, valset, testset)``; the first two are
        ``torch.utils.data.Subset`` views, the third is the full test set.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    supported = ("cifar-10", "cifar-100", "imagenet-16", "imagenet-1k")
    if dataset not in supported:
        # Fail with a clear message instead of an UnboundLocalError further
        # down when no branch defines the datasets.
        raise ValueError(
            "Unknown dataset {!r}; expected one of: {}".format(
                dataset, ", ".join(supported)))

    np.random.seed(seed)
    torch.manual_seed(seed)

    if dataset in ("cifar-10", "cifar-100"):
        if dataset == "cifar-10":
            dataset_cls = torchvision.datasets.CIFAR10
            mean = [x / 255 for x in [125.3, 123.0, 113.9]]
            std = [x / 255 for x in [63.0, 62.1, 66.7]]
        else:
            dataset_cls = torchvision.datasets.CIFAR100
            mean = [x / 255 for x in [129.3, 124.1, 112.4]]
            std = [x / 255 for x in [68.2, 65.4, 70.4]]

        transform_train, transform_test = _build_transforms(
            [transforms.RandomCrop(32, padding=4),
             transforms.RandomHorizontalFlip()],
            [],
            mean, std)

        trainset = dataset_cls(
            root=data_dir, train=True, download=True, transform=transform_train)
        # Same images as the training set, but without augmentation.
        valset = dataset_cls(
            root=data_dir, train=True, download=True, transform=transform_test)
        testset = dataset_cls(
            root=data_dir, train=False, download=True, transform=transform_test)

    elif dataset == "imagenet-16":
        mean = [x / 255 for x in [122.68, 116.66, 104.01]]
        std = [x / 255 for x in [63.22, 61.26, 65.09]]

        transform_train, transform_test = _build_transforms(
            [transforms.RandomHorizontalFlip(),
             transforms.RandomCrop(16, padding=2)],
            [],
            mean, std)

        # presumably (root, train flag, transform, number of classes) -
        # TODO confirm against nats_bench.DownsampledImageNet.ImageNet16
        trainset = ImageNet16(data_dir, True, transform_train, 120)
        valset = ImageNet16(data_dir, True, transform_test, 120)
        testset = ImageNet16(data_dir, False, transform_test, 120)

    else:  # "imagenet-1k"
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

        transform_train, transform_test = _build_transforms(
            [transforms.RandomHorizontalFlip(),
             transforms.RandomResizedCrop(size=224, scale=(0.1, 1.0), ratio=(0.8, 1.25))],
            [transforms.Resize(256),
             transforms.CenterCrop(224)],
            mean, std)

        # os.path.join works whether or not data_dir ends with a separator;
        # plain string concatenation produced e.g. "/datatrain".
        train_root = os.path.join(data_dir, 'train')
        val_root = os.path.join(data_dir, 'val')
        trainset = torchvision.datasets.ImageFolder(train_root, transform_train)
        valset = torchvision.datasets.ImageFolder(train_root, transform_test)
        testset = torchvision.datasets.ImageFolder(val_root, transform_test)

    # compute train/val split
    val_fraction = 0.2
    num_samples = len(trainset)
    indices = torch.randperm(num_samples)
    num_val = int(num_samples * val_fraction)
    # Use an explicit split point: negative slicing with num_val == 0 would
    # yield an empty training set and put everything into validation.
    num_train = num_samples - num_val
    trainset = torch.utils.data.Subset(trainset, indices[:num_train])
    valset = torch.utils.data.Subset(valset, indices[num_train:])

    return trainset, valset, testset
    
    

def _evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` and reduce the metrics across ranks.

    Returns:
        tuple: ``(accuracy, loss)`` tensors. ``accuracy`` is the globally
        summed correct count divided by the globally summed sample count;
        ``loss`` is each rank's sum of per-batch losses averaged over ranks.
    """
    correct = 0
    total = 0
    loss_sum = 0
    model.eval()
    with torch.no_grad():
        for images, target in loader:
            images = images.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            output = model(images)
            loss_sum += criterion(output, target)

            batch_correct, batch_total = accuracy(output, target)
            correct += batch_correct
            total += batch_total

    # Collective calls: every rank must reach these in the same order.
    correct = par_sum(correct)
    total = par_sum(total)
    loss = par_mean(loss_sum)
    return correct / total, loss


def train_cifar(config, checkpoint_dir=None):
    """Train one sampled network with DistributedDataParallel and report to Tune.

    Expects the default ``torch.distributed`` process group to be initialised
    already (``dist.get_world_size()`` / ``dist.get_rank()`` are called
    directly). Trains for 90 epochs with SGD and a per-step cosine learning
    rate schedule; every 10th epoch the train, validation and test accuracies
    are reduced across ranks and passed to ``tune.report``.

    Args:
        config: trial configuration. Reads ``seed``, ``dataset``,
            ``channels_1`` ... ``channels_5`` and, optionally, ``data_dir``
            (forwarded to ``load_data``; ``None`` when absent).
        checkpoint_dir: accepted for the Tune function API but currently
            unused - checkpoint restore/save is commented out below.

    Raises:
        ValueError: if ``config["dataset"]`` is not a supported dataset name.
    """
    np.random.seed(config["seed"])
    torch.manual_seed(config["seed"])

    gwsize = dist.get_world_size()     # global world size - per run
    rank = dist.get_rank()             # global rank - assign per run
    # The local rank (rank % torch.cuda.device_count()) was computed here but
    # never used, and it divided by zero on hosts without a GPU.

    genotype_structure = "|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|"
    genotype_structure = genotypes.Structure.str2fullstructure(genotype_structure)

    num_classes_by_dataset = {
        "cifar-10": 10,
        "cifar-100": 100,
        "imagenet-16": 120,
        "imagenet-1k": 1000,
    }
    dataset_name = config["dataset"]
    if dataset_name not in num_classes_by_dataset:
        raise ValueError(
            "Unknown dataset {!r}; expected one of: {}".format(
                dataset_name, ", ".join(num_classes_by_dataset)))
    num_classes = num_classes_by_dataset[dataset_name]

    model = DynamicShapeTinyNet(
        channels=[config["channels_1"], config["channels_2"], config["channels_3"],
                  config["channels_4"], config["channels_5"]],
        genotype=genotype_structure,
        num_classes=num_classes)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model.to(device)

    model = torch.nn.parallel.DistributedDataParallel(model)

    criterion = nn.CrossEntropyLoss()
    # The base learning rate of 0.1 is scaled by the number of workers.
    optimizer = optim.SGD(model.parameters(), lr=0.1*gwsize, momentum=0.9, weight_decay=0.0005, nesterov=True)

    # NOTE(review): the original call never passed data_dir, so load_data
    # received None; an optional "data_dir" config entry is forwarded here.
    trainset, valset, testset = load_data(
        dataset=dataset_name, seed=config["seed"], data_dir=config.get("data_dir"))

    # # checkpointing
    # if checkpoint_dir:
    #     model_state, optimizer_state = torch.load(
    #         os.path.join(checkpoint_dir, "checkpoint"))
    #     model.load_state_dict(model_state)
    #     optimizer.load_state_dict(optimizer_state)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        trainset, num_replicas=gwsize, rank=rank)

    val_sampler = torch.utils.data.distributed.DistributedSampler(
        valset, num_replicas=gwsize, rank=rank)

    test_sampler = torch.utils.data.distributed.DistributedSampler(
        testset, num_replicas=gwsize, rank=rank)

    batch_size = 256   # per worker, not divided by the world size
    num_workers = 16

    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=num_workers)

    valloader = torch.utils.data.DataLoader(
        valset,
        batch_size=batch_size,
        sampler=val_sampler,
        num_workers=num_workers)

    testloader = torch.utils.data.DataLoader(
        testset,
        batch_size=batch_size,
        sampler=test_sampler,
        num_workers=num_workers)

    num_epochs = 90
    report_every = 10  # epochs between evaluations / Tune reports

    # The scheduler is stepped once per batch, so the cosine period is
    # expressed in optimizer steps rather than epochs.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs*len(trainloader))

    for epoch in range(num_epochs):  # loop over the dataset multiple times
        # Without this the DistributedSampler yields the same shuffled order
        # in every epoch.
        train_sampler.set_epoch(epoch)

        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0
        for images, target in trainloader:
            images = images.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            # compute output
            optimizer.zero_grad()
            output = model(images)

            # compute loss and accuracy
            loss = criterion(output, target)
            # Detach before accumulating; otherwise the running sum keeps
            # every batch's autograd graph alive until the end of the epoch.
            train_loss += loss.detach()

            batch_correct, batch_total = accuracy(output, target)
            train_correct += batch_correct
            train_total += batch_total

            # optimization step
            loss.backward()
            optimizer.step()

            lr_scheduler.step()

        if (epoch + 1) % report_every != 0:
            continue

        # reduce train metrics across ranks
        train_correct = par_sum(train_correct)
        train_total = par_sum(train_total)
        # The losses are reduced as well but are not reported to Tune.
        train_loss = par_mean(train_loss)
        train_acc = train_correct/train_total

        val_acc, _ = _evaluate(model, valloader, criterion, device)
        test_acc, _ = _evaluate(model, testloader, criterion, device)

        # save checkpoint
        # with distributed_checkpoint_dir(step=epoch) as checkpoint_dir:
        #     path = os.path.join(checkpoint_dir, "checkpoint")
        #     #print("Saving: ", path)
        #     torch.save((model.state_dict(), optimizer.state_dict()), path)

        # report results to tune as plain floats rather than 0-dim tensors
        tune.report(
            train_acc=float(train_acc),
            val_acc=float(val_acc),
            test_acc=float(test_acc))

    print("Finished Training")
    

def main(args, num_samples=32, max_num_epochs=9):
    """Run the Ray Tune search over channel widths and print the best trial.

    Connects to an already running Ray cluster (``address='auto'``), wraps
    ``train_cifar`` in a distributed trainable, runs ``tune.run`` with the
    scheduler selected by ``args.scheduler`` and prints the configuration and
    final validation/test accuracy of the best trial by ``val_acc``.

    Args:
        args: parsed command-line arguments; reads ``seed``, ``dataset``,
            ``par_workers`` and ``scheduler``.
        num_samples: number of configurations to sample.
        max_num_epochs: maximum number of training iterations (Tune reports)
            per trial; also used as ``max_t`` of the schedulers.

    Raises:
        ValueError: if ``args.scheduler`` is not ASHA, BOHB, HB or RAND.
    """
    known_schedulers = ("ASHA", "BOHB", "HB", "RAND")
    if args.scheduler not in known_schedulers:
        # Fail before touching the cluster instead of hitting an unbound
        # ``scheduler`` variable later on.
        raise ValueError(
            "Unknown scheduler {!r}; expected one of: {}".format(
                args.scheduler, ", ".join(known_schedulers)))

    ray.init(address='auto')

    np.random.seed(args.seed)

    if args.dataset == "imagenet-1k":
        channel_choices = [64, 128, 256, 512]
    else:
        channel_choices = [8, 16, 24, 32, 40, 48, 56, 64]

    # train_cifar reads channels_1 ... channels_5 for every dataset, so all
    # five entries must be present (imagenet-1k used to lack channels_5).
    config = {
        "channels_{}".format(i): tune.choice(list(channel_choices))
        for i in range(1, 6)
    }
    # Single-element choices pass constants through to every trial.
    config["dataset"] = tune.choice([str(args.dataset)])
    config["seed"] = tune.choice([args.seed])

    ddp_trainable = DistributedTrainableCreator(
        train_cifar,
        num_workers=args.par_workers,
        num_gpus_per_worker=1,
        num_cpus_per_worker=32,
        #backend="nccl",
        #num_workers_per_host=4
    )

    if args.scheduler == "ASHA":
        scheduler = ASHAScheduler(
               max_t=max_num_epochs,
               grace_period=2,
               reduction_factor=3)
        search_alg = None
    elif args.scheduler == "BOHB":
        scheduler = HyperBandForBOHB(
               time_attr="training_iteration",
               max_t=max_num_epochs,
               reduction_factor=3,
               stop_last_trials=True)
        search_alg = TuneBOHB(seed=args.seed)
    elif args.scheduler == "HB":
        scheduler = HyperBandScheduler(
               time_attr="training_iteration",
               max_t=max_num_epochs,
               reduction_factor=3,
               stop_last_trials=True)
        search_alg = None
    else:  # "RAND": no scheduler and no search algorithm are passed to Tune
        scheduler = None
        search_alg = None

    reporter = CLIReporter(
        metric_columns=["train_acc", "val_acc", "test_acc", "training_iteration", "time_this_iter_s", "time_total_s"],
        max_report_frequency=60)

    result = tune.run(
        ddp_trainable,
        local_dir=os.path.join(os.path.abspath(os.getcwd()), "ray_results"),
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        search_alg=search_alg,
        name=str("eval" + (args.scheduler)),
        metric="val_acc",
        mode="max",
        stop={"training_iteration": max_num_epochs},
        progress_reporter=reporter,
        verbose=2
        )

    best_trial = result.get_best_trial("val_acc", "max", "last")
    print("Best trial config: {}".format(best_trial.config))
    print("Best trial final validation accuracy: {}".format(
        best_trial.last_result["val_acc"]))
    print("Best trial final test accuracy: {}".format(
        best_trial.last_result["test_acc"]))


if __name__ == "__main__":

    # Parse the command line, then launch the search. The number of sampled
    # configurations comes from --num-samples; the per-trial iteration cap is
    # fixed at 9 here.
    args = parsIni().parse_args()
    main(args, num_samples=args.num_samples, max_num_epochs=9)