import argparse
import os
import pickle
import random
import sys

import numpy as np
import torch
import torchvision
import wandb

sys.path.append(os.getcwd())

from pathlib import Path
from utils.data_utils import load_dataset
from utils.training_utils import MixupLoss, BinaryMixupLoss, LSLogLoss, single_train_test
from torch.utils.data import DataLoader

# Command-line interface.
# The three regulariser flags give the *upper end* of the sweep performed
# further down in this script (each grid runs from 0 to the given value), and
# --n-sweep is the number of grid points.
parser = argparse.ArgumentParser(
    description="Logistic regression training parameter upper bounds."
)
parser.add_argument("--task-name", type=str, default="CIFAR10")
parser.add_argument("--n-sweep", type=int, default=20)
# argparse derives the attribute name from the flag ("--low-var" -> low_var),
# so every float-valued option can be registered the same way.
for _flag in ("--label-smoothing", "--weight-decay", "--mixup-alpha", "--low-var"):
    parser.add_argument(_flag, type=float, default=0)
args = parser.parse_args()

# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Test `device.type` rather than the device object itself: a `torch.device`
# does not compare equal to the plain string "cpu", so a direct comparison
# would enter this branch (and query CUDA) even on CPU-only machines.
if device.type != "cpu":
    print("Device count: ", torch.cuda.device_count())
    print("GPU being used: {}".format(torch.cuda.get_device_name(0)))

# Reproducibility: seed Python's, PyTorch's and NumPy's global RNGs with the
# same fixed value (applied in that order).
seed = 42
for _seed_rng in (random.seed, torch.manual_seed, np.random.seed):
    _seed_rng(seed)

# Weights & Biases run. `wandb_run` may be replaced by None to disable
# logging; every later use is guarded by an `is not None` check.
task_name = args.task_name + "Binary"
wandb_config = {
    "label_smoothing": args.label_smoothing,
    "weight_decay": args.weight_decay,
    "mixup_alpha": args.mixup_alpha,
}
# The run name encodes the task and the three sweep upper bounds.
wandb_run_name = f"{task_name}_{args.label_smoothing}_LS_{args.weight_decay}_WD_{args.mixup_alpha}_Mix"
wandb_run = wandb.init(
    project="min-variance-label-aug",
    config=wandb_config,
    name=wandb_run_name,
)

# Output directory for this configuration. The name carries the task and the
# sweep upper bounds, plus a "_low_var_<value>" suffix only when --low-var is
# strictly positive.
runs_path = f"runs/logreg_sweep_{args.task_name}_{args.label_smoothing}_LS_{args.weight_decay}_WD_{args.mixup_alpha}_Mix"
runs_path += f"_low_var_{args.low_var}" if args.low_var > 0 else ""
# Create the directory (and any missing parents); an existing one is reused.
Path(runs_path).mkdir(parents=True, exist_ok=True)

# Fixed training hyperparameters (not exposed on the command line).
epochs, batch_size, n_runs = 200, 500, 5

# Dataset options forwarded verbatim to utils.data_utils.load_dataset; their
# exact semantics are defined there.
dataset_kwargs = dict(
    dataset=args.task_name,
    subsample=0,
    binary=True,
    add_patches=True,
    low_var=args.low_var,
)
train_data, test_data, out_dim = load_dataset(**dataset_kwargs)

# Mini-batch loaders: training batches are reshuffled, test batches keep the
# dataset order.
train_dl = DataLoader(train_data, shuffle=True, batch_size=batch_size)
test_dl = DataLoader(test_data, shuffle=False, batch_size=batch_size)

# Logistic regression: flatten the input, apply one bias-free linear unit and
# squash the result with a sigmoid.
# The input width is the product of the dimensions of the first entry of
# `train_data.data` (presumably one example -- confirm against load_dataset).
in_dim = np.prod(train_data.data[0].shape)
# Unused two-output alternative, kept for reference:
# model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(in_dim, 2)).to(device)
logreg_layers = [
    torch.nn.Flatten(),
    torch.nn.Linear(in_dim, 1, bias=False),
    torch.nn.Sigmoid(),
]
model = torch.nn.Sequential(*logreg_layers).to(device)

# Sweep grids: each regulariser ramps linearly from 0 to its command-line
# upper bound, and step k of the sweep uses the k-th value of all three grids.
wd_sweep, ls_sweep, mix_sweep = (
    np.linspace(0, upper_bound, args.n_sweep)
    for upper_bound in (args.weight_decay, args.label_smoothing, args.mixup_alpha)
)

# One mean/std entry per sweep step for each quantity returned by
# single_train_test.
train_means, train_stds = [], []
test_means, test_stds = [], []
spur_means, spur_stds = [], []

# NOTE(review): the same `model` object is handed to every sweep step; whether
# its weights are re-initialised between steps/runs depends on
# single_train_test -- confirm.
for wd, ls, mix in zip(wd_sweep, ls_sweep, mix_sweep):
    optimizer = torch.optim.Adam(model.parameters(), weight_decay=wd, lr=5e-3)

    # Loss priority: mixup if its strength is positive, otherwise label
    # smoothing if positive, otherwise plain binary cross-entropy.
    if mix > 0:
        criterion = BinaryMixupLoss(alpha=mix, fixed=False)
    elif ls > 0:
        criterion = LSLogLoss(alpha=ls)
    else:
        criterion = torch.nn.BCELoss()

    train_errors, test_errors, spur_norms = single_train_test(
        model=model,
        train_loader=train_dl,
        train_loss_fn=criterion,
        test_loader=test_dl,
        optimizer=optimizer,
        num_epochs=epochs,
        num_runs=n_runs,
        device=device,
    )

    # Record the mean and standard deviation of each returned quantity.
    for samples, means, stds in (
        (train_errors, train_means, train_stds),
        (test_errors, test_means, test_stds),
        (spur_norms, spur_means, spur_stds),
    ):
        means.append(samples.mean())
        stds.append(samples.std())

    # Logging is optional: skip it when wandb has been disabled.
    if wandb_run is None:
        continue
    wandb_run.log(
        {
            "Average Train Error": train_means[-1],
            "Average Test Error": test_means[-1],
            "Average Zero Variance Feature Norm": spur_means[-1],
        }
    )

# Close the wandb run once the whole sweep has completed.
if wandb_run is not None:
    wandb_run.finish()

# Persist everything needed for plotting as one pickle file per object,
# written to "<runs_path>/<name>.p".
plot_artifacts = {
    # X-axis values for plotting.
    "wd_sweep": wd_sweep,
    "ls_sweep": ls_sweep,
    "mix_sweep": mix_sweep,
    # Y-axis values for plotting.
    "train_means": train_means,
    "train_stds": train_stds,
    "test_means": test_means,
    "test_stds": test_stds,
    "spur_means": spur_means,
    "spur_stds": spur_stds,
}
for artifact_name, artifact in plot_artifacts.items():
    # The context manager flushes and closes each file deterministically
    # instead of leaving the handle open until garbage collection.
    with open(f"{runs_path}/{artifact_name}.p", "wb") as artifact_file:
        pickle.dump(artifact, artifact_file)
