# Adapted from the PyTorch MNIST example:
# https://github.com/pytorch/examples/blob/master/mnist/main.py
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torch.utils.data as data
from torchvision.models.resnet import resnet18, ResNet18_Weights, resnet50, ResNet50_Weights

import ray
from ray import train, tune
from ray.tune.schedulers import AsyncHyperBandScheduler

from dataset import PickleDataset
import wandb
import pickle as pkl

# Location of the pickled dataset that the __main__ block loads with PickleDataset.
# Set the BUSI_DATA_PATH environment variable to use a different copy without
# editing this file; the default below is the original hard-coded path.
DATA_PATH = os.environ.get(
    "BUSI_DATA_PATH",
    "/Data/ict04/dev/medical/classification/surgical/cached_data/BUSI.pkl",
)

def train_func(model, optimizer, train_loader, device=None):
    """Run one pass over ``train_loader``, updating ``model`` in place.

    Args:
        model: the network to train; it is switched to training mode first.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        train_loader: iterable of ``(inputs, targets)`` batches.
        device: device each batch is moved to; falls back to CPU when falsy.

    Returns:
        None. The loss is cross-entropy between the model output and targets.
    """
    target_device = device if device else torch.device("cpu")
    model.train()
    for inputs, labels in train_loader:
        inputs = inputs.to(target_device)
        labels = labels.to(target_device)
        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = F.cross_entropy(logits, labels)
        batch_loss.backward()
        optimizer.step()


def test_func(model, data_loader, device=None):
    """Return the top-1 accuracy of ``model`` over ``data_loader``.

    Args:
        model: classifier; its output is reduced with ``argmax`` over dim 1,
            so it is assumed to be ``(batch, n_classes)`` scores — confirm for
            new models.
        data_loader: iterable of ``(inputs, targets)`` batches.
        device: device each batch is moved to; falls back to CPU when falsy.

    Returns:
        Fraction of samples whose predicted class equals the target, in [0, 1].

    Raises:
        ValueError: if ``data_loader`` yields no samples. An empty split would
            otherwise surface as an unexplained ``ZeroDivisionError``.
    """
    device = device or torch.device("cpu")
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    if total == 0:
        raise ValueError(
            "test_func: data_loader yielded no samples; cannot compute accuracy"
        )
    return correct / total


# The name matches pytest's default ``test_*`` pattern. Stop pytest from
# collecting this helper as a test when a test module imports it.
test_func.__test__ = False

def get_data_loaders(dataset, batch_size=64, eval_batch_size=512, seed=42):
    """Split ``dataset`` 60/20/20 into train/val/test and wrap each in a DataLoader.

    Args:
        dataset: map-style dataset supporting ``len()`` and exposing a
            ``classes`` sequence, whose length is taken as the number of classes.
        batch_size: batch size of the (shuffled) training loader.
        eval_batch_size: batch size of the validation and test loaders.
        seed: seed of the generator used for the split. It is fixed, so every
            call with the same dataset and seed produces the same split.

    Returns:
        ``(train_loader, val_loader, test_loader, n_classes)``.
    """
    n_classes = len(dataset.classes)
    train_len = int(len(dataset) * 0.6)
    val_len = int(len(dataset) * 0.2)
    # The test split absorbs the rounding remainder so the lengths sum to len(dataset).
    test_len = len(dataset) - train_len - val_len
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset,
        [train_len, val_len, test_len],
        generator=torch.Generator().manual_seed(seed),
    )

    def _make_loader(subset, size, shuffle):
        # Build one loader; num_workers=0 loads batches in the calling process.
        return data.DataLoader(
            dataset=subset,
            batch_size=size,
            shuffle=shuffle,
            num_workers=0,
            pin_memory=True,
        )

    train_loader = _make_loader(train_dataset, batch_size, shuffle=True)
    val_loader = _make_loader(val_dataset, eval_batch_size, shuffle=False)
    test_loader = _make_loader(test_dataset, eval_batch_size, shuffle=False)
    return train_loader, val_loader, test_loader, n_classes

def train_mnist(config):
    """Ray Tune trainable: fine-tune the model from ``get_model`` with one sampled config.

    The name comes from the PyTorch MNIST example this script was adapted
    from. It trains on whatever dataset ``config["data_ref"]`` refers to.

    Args:
        config: Tune config dict with keys ``"data_ref"`` (Ray object ref to
            the dataset, resolved with ``ray.get``), plus ``"lr"`` and
            ``"weight_decay"`` for Adam.

    The loop never exits on its own. After every training pass it reports
    validation accuracy as ``"mean_accuracy"`` and relies on Ray Tune (stop
    criteria / scheduler) to end the trial.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    dataset = ray.get(config["data_ref"])
    # The second loader is the validation split; the test split is not used while tuning.
    train_loader, val_loader, _unused_test_loader, n_classes = get_data_loaders(dataset=dataset)

    model = get_model(device, n_classes)
    optimizer = optim.Adam(
        model.parameters(),
        lr=config["lr"],
        weight_decay=config["weight_decay"],
    )

    while True:
        train_func(model, optimizer, train_loader, device)
        train.report({"mean_accuracy": test_func(model, val_loader, device)})

def get_model(device, n_classes):
    """Return a ResNet-50 with an ``n_classes``-way head, placed on ``device``.

    The backbone is loaded with torchvision's ``ResNet50_Weights.DEFAULT``
    pretrained weights. Its final ``fc`` layer is replaced by a freshly
    initialised ``nn.Linear`` of the same input width.
    """
    backbone = resnet50(weights=ResNet50_Weights.DEFAULT)
    head_in_features = backbone.fc.in_features
    backbone.fc = nn.Linear(head_in_features, n_classes)
    return backbone.to(device)

if __name__ == "__main__":

    ray.init()

    # Asynchronous HyperBand scheduler, used to stop weak trials early.
    sched = AsyncHyperBandScheduler()

    # Per-trial resources; a fractional GPU lets two trials share one device.
    resources_per_trial = {"cpu": 8, "gpu": 0.5}
    dataset = PickleDataset(DATA_PATH)
    # Store the dataset once in the Ray object store; each trial resolves it
    # with ray.get(config["data_ref"]).
    data_ref = ray.put(dataset)

    tuner = tune.Tuner(
        tune.with_resources(train_mnist, resources=resources_per_trial),
        tune_config=tune.TuneConfig(
            metric="mean_accuracy",
            mode="max",
            scheduler=sched,
            num_samples=50,
        ),
        run_config=train.RunConfig(
            name="exp",
            stop={
                "mean_accuracy": 0.98,
                "training_iteration": 50,
            },
        ),
        param_space={
            "lr": tune.loguniform(1e-6, 1e-1),
            # Bounds were previously reversed as (1e-3, 1e-6). Sampling is now
            # log-uniform, like "lr": a uniform draw over three decades almost
            # never explores the low end of the range.
            "weight_decay": tune.loguniform(1e-6, 1e-3),
            "data_ref": data_ref,
        },
    )
    results = tuner.fit()

    # Drop the ObjectRef from the printed config; only hyperparameters are of interest.
    best_config = {
        key: value
        for key, value in results.get_best_result().config.items()
        if key != "data_ref"
    }
    print("Best config is:", best_config)

    # TODO: re-train with the best lr/weight_decay, evaluate on the held-out
    # test split (third loader returned by get_data_loaders(dataset)), and log
    # the resulting accuracy to wandb.

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if results.errors:
        raise RuntimeError(
            f"{len(results.errors)} trial(s) failed: {results.errors}"
        )