from functools import partial
import argparse
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
from torch.utils.data import random_split

import torchvision
from torchvision import datasets, transforms, models
import ray
from ray import tune
import ray.train as train
from ray.tune import CLIReporter

from ray.tune.schedulers import PopulationBasedTraining
from ray.tune.schedulers.pbt import PopulationBasedTrainingReplay

from ray.tune.schedulers.hb_bohb import HyperBandForBOHB
from ray.tune.suggest.bohb import TuneBOHB
from ray.tune.schedulers import ASHAScheduler
from ray.tune.schedulers import HyperBandScheduler

import pandas as pd

from ray.tune.integration.torch import (DistributedTrainableCreator, distributed_checkpoint_dir)


# nats bench imports
import nats_bench.cells
import nats_bench.cell_operations
from nats_bench.InferTinyCellNet import DynamicShapeTinyNet
import nats_bench.genotypes as genotypes
from nats_bench.DownsampledImageNet import ImageNet16


def parsIni():
    """Build the command-line parser for the Ray Tune sweep.

    The parser is returned without parsing anything, so the caller decides
    when (and on which argument list) to call ``parse_args``.

    Returns:
        argparse.ArgumentParser: parser exposing ``--num-samples``,
        ``--par-workers``, ``--scheduler``, ``--dataset``, ``--seed`` and
        ``--scale-bs``.
    """
    parser = argparse.ArgumentParser(description='Ray Tune ImageNet Example')
    parser.add_argument('--num-samples', type=int, default=24, metavar='N',
                        help='number of samples to train (default: 24)')
    parser.add_argument('--par-workers', type=int, default=1, metavar='N',
                        help='parallel workers to train on a single trial (default: 1)')
    parser.add_argument('--scheduler', type=str, default='RAND',
                        help='scheduler for tuning: RAND, ASHA, HB or BOHB (default: RAND)')
    parser.add_argument('--dataset', type=str, default='cifar-10',
                        help='dataset to evaluate (default: cifar-10)')
    # Integer options get integer defaults; argparse would convert a string
    # default through ``type`` anyway, but the literal should match the type.
    parser.add_argument('--seed', type=int, default=111,
                        help='seed to use (default: 111)')
    parser.add_argument('--scale-bs', type=int, default=0,
                        help='scale batch size with num workers (default: 0=False)')
    return parser

def load_data(dataset="cifar-10", seed=111, data_dir=None):
    """Build the train / validation / test datasets for one benchmark.

    The training split is instantiated twice: once with the augmenting
    transform and once with the plain test-time transform. A single seeded
    random permutation then sends the first 80% of the indices to the
    training subset and the remaining 20% to the validation subset, so the
    two subsets are disjoint and validation images are never augmented.

    Args:
        dataset: One of ``"cifar-10"``, ``"cifar-100"`` or ``"imagenet-16"``.
        seed: Seed applied to NumPy and PyTorch before the permutation is
            drawn.
        data_dir: Dataset root forwarded unchanged to the dataset
            constructors.
            NOTE(review): the default ``None`` is passed straight through;
            confirm the installed torchvision / ImageNet16 accept it.

    Returns:
        Tuple ``(trainset, valset, testset)``. The first two are
        ``torch.utils.data.Subset`` views of the training split; the third is
        the dataset built with ``train=False``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Per-channel mean / std in 0-255 pixel units, keyed by dataset name.
    pixel_stats = {
        "cifar-10": ([125.3, 123.0, 113.9], [63.0, 62.1, 66.7]),
        "cifar-100": ([129.3, 124.1, 112.4], [68.2, 65.4, 70.4]),
        "imagenet-16": ([122.68, 116.66, 104.01], [63.22, 61.26, 65.09]),
    }
    # Fail with a clear message instead of an UnboundLocalError further down.
    if dataset not in pixel_stats:
        raise ValueError(
            "Unknown dataset {!r}; expected one of {}".format(
                dataset, sorted(pixel_stats)))

    np.random.seed(seed)
    torch.manual_seed(seed)

    mean_255, std_255 = pixel_stats[dataset]
    mean = [x / 255 for x in mean_255]
    std = [x / 255 for x in std_255]

    def to_normalized_tensor():
        # Shared tail of both pipelines: tensor conversion + normalisation.
        return [transforms.ToTensor(), transforms.Normalize(mean, std)]

    if dataset == "imagenet-16":
        # Flip-then-crop order is kept exactly as this dataset had it.
        augmentations = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(16, padding=2),
        ]

        def build(train, transform):
            # 120 is ImageNet16's fourth positional argument -- presumably
            # the class-count limit (train_cifar uses 120 classes for this
            # dataset); verify against the ImageNet16 signature.
            return ImageNet16(data_dir, train, transform, 120)
    else:
        augmentations = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ]
        if dataset == "cifar-10":
            cifar_cls = torchvision.datasets.CIFAR10
        else:
            cifar_cls = torchvision.datasets.CIFAR100

        def build(train, transform):
            return cifar_cls(
                root=data_dir, train=train, download=True, transform=transform)

    transform_train = transforms.Compose(augmentations + to_normalized_tensor())
    transform_test = transforms.Compose(to_normalized_tensor())

    trainset = build(True, transform_train)
    # Same underlying images as trainset, but without augmentation.
    valset = build(True, transform_test)
    testset = build(False, transform_test)

    # Compute the train/val split from one shared permutation.
    val_fraction = 0.2
    num_total = len(trainset)
    num_val = int(num_total * val_fraction)
    # Slice at an explicit boundary: the negative-index form
    # (indices[:-num_val]) would yield an empty training set when num_val
    # rounds down to 0.
    split = num_total - num_val
    # Plain Python ints so the wrapped datasets are indexed with ints rather
    # than 0-d tensors.
    indices = torch.randperm(num_total).tolist()
    trainset = torch.utils.data.Subset(trainset, indices[:split])
    valset = torch.utils.data.Subset(valset, indices[split:])

    return trainset, valset, testset
    

def _count_correct(model, loader, device):
    """Return the number of samples in ``loader`` that ``model`` gets right.

    The model is switched to eval mode and gradients are disabled, so
    evaluation neither updates train-mode layer statistics nor builds an
    autograd graph. The caller is responsible for switching back to train
    mode afterwards.
    """
    model.eval()
    correct = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            pred = outputs.argmax(dim=1, keepdim=True)
            correct += pred.eq(labels.view_as(pred)).sum().item()
    return correct


def train_cifar(config, checkpoint_dir=None):
    """Ray Tune trainable: train one network for 90 epochs.

    Every 10th epoch the model is evaluated on the validation and test sets,
    a checkpoint is written through ``tune.checkpoint_dir`` and the
    accuracies are reported with ``tune.report``.

    Args:
        config: Trial configuration. Keys read here: ``seed``, ``dataset``
            (``"cifar-10"``, ``"cifar-100"`` or ``"imagenet-16"``),
            ``channels_1`` .. ``channels_5``, ``scale_bs`` (0 or 1) and,
            when ``scale_bs`` is 1, ``workers``.
        checkpoint_dir: Optional directory holding a ``"checkpoint"`` file
            written by this function (a ``(model_state, optimizer_state)``
            tuple); when given, both states are restored before training.

    Raises:
        ValueError: If ``config["dataset"]`` or ``config["scale_bs"]`` is not
            one of the supported values.
    """
    # Validate up front so a bad config fails with a clear message before
    # any model or data is built.
    num_classes_by_dataset = {"cifar-10": 10, "cifar-100": 100, "imagenet-16": 120}
    if config["dataset"] not in num_classes_by_dataset:
        raise ValueError(
            "Unknown dataset {!r}; expected one of {}".format(
                config["dataset"], sorted(num_classes_by_dataset)))
    if config["scale_bs"] not in (0, 1):
        raise ValueError(
            "scale_bs must be 0 or 1, got {!r}".format(config["scale_bs"]))

    np.random.seed(config["seed"])
    torch.manual_seed(config["seed"])

    ## model definition ##

    genotype_structure = "|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|"
    genotype_structure = genotypes.Structure.str2fullstructure(genotype_structure)

    num_classes = num_classes_by_dataset[config["dataset"]]
    channels = [config["channels_{}".format(i)] for i in range(1, 6)]
    model = DynamicShapeTinyNet(channels=channels, genotype=genotype_structure, num_classes=num_classes)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    criterion = nn.CrossEntropyLoss()

    if config["scale_bs"] == 0:
        batch_size = 256
        learning_rate = 0.1
    else:
        # Batch size and learning rate are scaled by the same factor,
        # 8 / workers.
        scale = 8 / config["workers"]
        batch_size = 256 * scale
        learning_rate = 0.1 * scale
    batch_size = int(batch_size)

    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0005, nesterov=True)

    # NOTE(review): no data_dir is passed, so load_data receives its default
    # root of None -- confirm that is what the dataset classes expect.
    trainset, valset, testset = load_data(dataset=config["dataset"], seed=config["seed"])

    # checkpointing
    # NOTE(review): only model and optimizer state are restored; the epoch
    # loop below still starts at 0 and a fresh LR scheduler is created.
    if checkpoint_dir:
        model_state, optimizer_state = torch.load(
            os.path.join(checkpoint_dir, "checkpoint"))
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=16)

    valloader = torch.utils.data.DataLoader(
        valset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=16)

    testloader = torch.utils.data.DataLoader(
        testset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=16)

    num_epochs = 90
    report_every = 10

    # The scheduler is stepped once per batch, so its period is expressed in
    # optimizer steps rather than epochs.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs * len(trainloader))

    for epoch in range(num_epochs):  # loop over the dataset multiple times
        # Evaluation leaves the model in eval mode; switch back every epoch.
        model.train()
        running_correct_train = 0

        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            pred = outputs.argmax(dim=1, keepdim=True)
            loss.backward()
            optimizer.step()

            # Training accuracy is accumulated on the fly, i.e. on augmented
            # batches while the weights are still changing.
            running_correct_train += pred.eq(labels.view_as(pred)).sum().item()

            lr_scheduler.step()

        # perform tune reporting and validation only every 10 epochs
        if (epoch + 1) % report_every != 0:
            continue

        running_correct_val = _count_correct(model, valloader, device)
        running_correct_test = _count_correct(model, testloader, device)

        # save checkpoint (separate name so the argument is not shadowed)
        with tune.checkpoint_dir(epoch) as save_dir:
            path = os.path.join(save_dir, "checkpoint")
            torch.save((model.state_dict(), optimizer.state_dict()), path)

        # report results to tune
        tune.report(
            train_acc=running_correct_train / len(trainset),
            val_acc=running_correct_val / len(valset),
            test_acc=running_correct_test / len(testset))

    print("Finished Training")
    

def main(args, num_samples=32, max_num_epochs=9):
    """Run the Ray Tune sweep and print the best trial.

    Args:
        args: Parsed CLI namespace. Attributes read here: ``scheduler``
            (``"ASHA"``, ``"BOHB"``, ``"HB"`` or ``"RAND"``), ``seed``,
            ``dataset``, ``par_workers`` and ``scale_bs``.
        num_samples: Number of trials passed to ``tune.run``.
        max_num_epochs: Upper bound on ``training_iteration`` used both as
            the schedulers' ``max_t`` and as the stop criterion. It counts
            Tune reports, not raw epochs (``train_cifar`` reports once every
            10 epochs).

    Raises:
        ValueError: If ``args.scheduler`` is not a supported name.
    """
    # Reject unknown scheduler names before connecting to the cluster;
    # otherwise `scheduler` would simply be unbound further down.
    supported_schedulers = ("ASHA", "BOHB", "HB", "RAND")
    if args.scheduler not in supported_schedulers:
        raise ValueError(
            "Unknown scheduler {!r}; expected one of {}".format(
                args.scheduler, list(supported_schedulers)))

    ray.init(address='auto')

    np.random.seed(args.seed)

    channel_choices = [8, 16, 24, 32, 40, 48, 56, 64]
    config = {
        "channels_1": tune.choice(channel_choices),
        "channels_2": tune.choice(channel_choices),
        "channels_3": tune.choice(channel_choices),
        "channels_4": tune.choice(channel_choices),
        "channels_5": tune.choice(channel_choices),
        # Fixed settings are wrapped as single-option choices so they are
        # declared the same way as the searched channel widths.
        "dataset": tune.choice([str(args.dataset)]),
        "workers": tune.choice([args.par_workers]),
        "scale_bs": tune.choice([args.scale_bs]),
        "seed": tune.choice([args.seed])
    }

    if args.scheduler == "ASHA":
        scheduler = ASHAScheduler(
            max_t=max_num_epochs,
            grace_period=2,
            reduction_factor=3)
        search_alg = None
    elif args.scheduler == "BOHB":
        scheduler = HyperBandForBOHB(
            time_attr="training_iteration",
            max_t=max_num_epochs,
            reduction_factor=3,
            stop_last_trials=True)
        search_alg = TuneBOHB(seed=args.seed)
    elif args.scheduler == "HB":
        scheduler = HyperBandScheduler(
            time_attr="training_iteration",
            max_t=max_num_epochs,
            reduction_factor=3,
            stop_last_trials=True)
        search_alg = None
    else:
        # "RAND": leave both unset so Tune falls back to its defaults.
        scheduler = None
        search_alg = None

    reporter = CLIReporter(
        metric_columns=["train_acc", "val_acc", "test_acc", "training_iteration", "time_this_iter_s", "time_total_s"],
        max_report_frequency=60)

    result = tune.run(
        train_cifar,
        # One node's resources are divided evenly among concurrent trials.
        # NOTE(review): 128 is presumably the per-node CPU count -- confirm.
        resources_per_trial={"cpu": int(128/args.par_workers), "gpu": 1/args.par_workers},
        local_dir=os.path.join(os.path.abspath(os.getcwd()), "ray_results"),
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        search_alg=search_alg,
        name=str("eval" + (args.scheduler)),
        metric="val_acc",
        mode="max",
        stop={"training_iteration": max_num_epochs},
        progress_reporter=reporter,
        verbose=2
        )

    best_trial = result.get_best_trial("val_acc", "max", "last")
    # Guard against a missing best trial (e.g. no trial reported val_acc)
    # instead of failing on attribute access.
    if best_trial is None:
        print("No trial reported val_acc; no best trial available")
        return

    print("Best trial config: {}".format(best_trial.config))
    print("Best trial final validation accuracy: {}".format(
        best_trial.last_result["val_acc"]))
    print("Best trial final test accuracy: {}".format(
        best_trial.last_result["test_acc"]))
    


if __name__ == "__main__":
    # Script entry point: parse the CLI options and launch the sweep.
    # Per-trial CPU/GPU shares are derived from --par-workers inside main();
    # max_num_epochs=9 caps each trial at 9 Tune training iterations.
    args = parsIni().parse_args()
    main(args, num_samples=args.num_samples, max_num_epochs=9)