import argparse
import os
import pickle
import random
import sys

import numpy as np
import torch
import torchvision
import wandb

sys.path.append(os.getcwd())

from pathlib import Path
from torch.utils.data import DataLoader
from utils.data_utils import load_dataset
from utils.training_utils import MixupLoss, reset_weights, train, test

# Command-line interface. Each float flag is the upper end of a hyperparameter
# sweep that starts at 0; --n-sweep is the number of points in that sweep.
parser = argparse.ArgumentParser(description="Hyperparameters for model training.")
cli_options = (
    ("--n-sweep", "n_sweep", 20, int),
    ("--label-smoothing", "label_smoothing", 0, float),
    ("--weight-decay", "weight_decay", 0, float),
    ("--mixup-alpha", "mixup_alpha", 0, float),
)
for flag, dest_name, default_value, value_type in cli_options:
    parser.add_argument(flag, dest=dest_name, default=default_value, type=value_type)
args = parser.parse_args()

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Compare the device *type*: a torch.device does not compare equal to a plain
# string, so `device != "cpu"` is always true and the CUDA queries below would
# run (and fail) on CPU-only machines.
if device.type != "cpu":
    print("Device count: ", torch.cuda.device_count())
    print("GPU being used: {}".format(torch.cuda.get_device_name(0)))

# Seed every random number generator this script relies on (Python, PyTorch,
# NumPy, in that order) so that repeated invocations are reproducible.
seed = 42
for seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    seed_fn(seed)

# Wandb setup. `wandb_run` can be set to None when wandb is not in use; the
# logging and shutdown code checks for None before touching it.
task_name = "Colored_MNIST"
run_config = {
    "label_smoothing": args.label_smoothing,
    "weight_decay": args.weight_decay,
    "mixup_alpha": args.mixup_alpha,
}
run_name = f"{task_name}_{args.label_smoothing}_LS_{args.weight_decay}_WD_{args.mixup_alpha}_Mix"
wandb_run = wandb.init(
    project="min-variance-label-aug",
    config=run_config,
    name=run_name,
)

# Set up the output directory for this configuration (created if missing) and
# open the training log inside it.
runs_path = f"runs/{task_name}_{args.label_smoothing}_LS_{args.weight_decay}_WD_{args.mixup_alpha}_Mix"
Path(runs_path).mkdir(parents=True, exist_ok=True)
# NOTE(review): the training loop only prints to stdout; nothing writes to
# perf_file, so training.out stays empty — confirm whether logging was intended.
perf_file = open(f"{runs_path}/training.out", "w")

# Load the train/test splits together with the model's output dimension.
# NOTE(review): the run is labelled "Colored_MNIST" elsewhere in this script
# but the dataset requested here is "MNIST" — confirm this is intended.
train_data, test_data, out_dim = load_dataset(dataset="MNIST")

# Fixed training hyperparameters (not swept).
epochs, batch_size, n_runs = 20, 500, 5

# Wrap the datasets in batch loaders; only the training split is shuffled.
train_dl = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

# Set up model: a one-hidden-layer MLP applied to the flattened input.
# Fetch a single sample once and take the product of its full shape, rather
# than indexing the dataset three times and assuming exactly three axes.
sample_shape = train_data[0][0].shape
in_dim = int(np.prod(sample_shape))
model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(in_dim, 2048),
    torch.nn.ReLU(),
    torch.nn.Linear(2048, out_dim),
).to(device)

# Train over a sweep of parameters: each hyperparameter goes linearly from 0
# up to its command-line value, and sweep step i uses the i-th entry of every
# array (the three are varied together, not as a grid).
wd_sweep = np.linspace(0, args.weight_decay, args.n_sweep)
ls_sweep = np.linspace(0, args.label_smoothing, args.n_sweep)
mix_sweep = np.linspace(0, args.mixup_alpha, args.n_sweep)
# No loss is constructed here: the sweep loop builds the criterion for each
# step from these arrays before it is ever used, so a criterion created at
# this point would be dead code.

# For every sweep step, train `n_runs` freshly initialised models and record
# the mean and standard deviation of their final test errors.
test_means, test_stds = [], []
for wd, ls, mix in zip(wd_sweep, ls_sweep, mix_sweep):
    # Mixup replaces label smoothing whenever its alpha is positive.
    if mix > 0:
        criterion = MixupLoss(mixup_alpha=mix, criterion=torch.nn.CrossEntropyLoss())
    else:
        criterion = torch.nn.CrossEntropyLoss(label_smoothing=ls)
    print(f"\nResults for Mixup = {mix}, LS = {ls}, WD = {wd}.")
    test_errors = []
    for j in range(n_runs):
        model.apply(reset_weights)
        # Build a fresh optimizer for every run. Adam keeps per-parameter
        # moment estimates and step counts; reusing one optimizer after the
        # weights are reset would carry that state over from the previous run,
        # so the runs would not be independent.
        optimizer = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=wd)
        print(f"Run {j + 1}")
        for epoch in range(1, epochs + 1):
            avg_batch_loss = train(
                model,
                train_dl,
                criterion,
                optimizer,
                device,
            )
            test_error = test(model, test_dl, device)
            print(f"[Epoch {epoch}] Batch Loss: {avg_batch_loss:.3f} \t Test Error: {test_error:.2f}")
        # Final evaluation of the fully trained model for this run.
        test_errors.append(test(model, test_dl, device))
    run_errors = np.array(test_errors)
    test_means.append(run_errors.mean())
    test_stds.append(run_errors.std())
    if wandb_run is not None:
        wandb_run.log(
            {
                "Mean Test Error": test_means[-1],
                "Test Error Std": test_stds[-1],
            }
        )

# Close the wandb run, if one was started.
if wandb_run is not None:
    wandb_run.finish()

# Persist the per-sweep statistics. Context managers make sure each file is
# flushed and closed instead of leaving the handles to the garbage collector.
with open(f"{runs_path}/test_means.p", "wb") as means_file:
    pickle.dump(test_means, means_file)
with open(f"{runs_path}/test_stds.p", "wb") as stds_file:
    pickle.dump(test_stds, stds_file)

# Release the training log handle opened earlier in the script.
perf_file.close()
