import argparse
import os
import pickle
import sys

import numpy as np
import torch

sys.path.append(os.getcwd())

from utils.plotting_utils import *

# Command-line interface: a single switch that selects the low-variance runs.
parser = argparse.ArgumentParser(description="Whether to plot zero variance or low variance.")
parser.add_argument("--low-var", action="store_true", dest="low_var", default=False)
args = parser.parse_args()

# Output directory for the figures, the tag that marks low-variance run
# directories, and the filename suffix appended when --low-var is given.
plots = "plots"
low_var = "_low_var"
suffix = low_var if args.low_var else ""

# Produce synthetic plots first.
dataset = "Synthetic"
y_label = "High Variance Dimension Norm"

# Select the synthetic run directories whose "_low_var" tag agrees with the
# --low-var flag: tagged runs when the flag is set, untagged runs otherwise.
synth_paths = [
    f.path
    for f in os.scandir("runs/")
    if f.is_dir() and dataset in f.name and (low_var in f.name) == args.low_var
]

for path in synth_paths:
    # Read every pickled artifact through a context manager so each file
    # handle is closed right away instead of lingering until garbage collection.
    run = {}
    for name in ("ls_sweep", "mix_sweep", "wd_sweep", "spur_norm_means", "spur_norm_stds"):
        with open(f"{path}/{name}.p", "rb") as fh:
            run[name] = pickle.load(fh)

    # Wrap the single series in an outer list so it lines up with the
    # one-entry `datasets` list passed to the plotting helper below.
    spur_norm_means = np.array([run["spur_norm_means"]])
    spur_norm_stds = np.array([run["spur_norm_stds"]])

    # Pick the swept hyperparameter: label smoothing or mixup if its final
    # sweep value is positive (checked in that order), weight decay otherwise.
    if run["ls_sweep"][-1] > 0:
        x_label, tag, xs, color = r"Label Smoothing $\alpha$", "LS", run["ls_sweep"], "C0"
    elif run["mix_sweep"][-1] > 0:
        x_label, tag, xs, color = r"Mixup $\alpha$ in Beta($\alpha$, $\alpha$)", "mixup", run["mix_sweep"], "C2"
    else:
        x_label, tag, xs, color = r"Weight Decay $\lambda$", "WD", run["wd_sweep"], "C4"

    plot_multi_dataset_metrics(
        x_label=x_label,
        y_label=y_label,
        fname=f"{plots}/{dataset}_{tag}_sweep{suffix}.png",
        xs=xs,
        metric_means=spur_norm_means,
        metric_stds=spur_norm_stds,
        datasets=[dataset],
        custom_colors=[color],
    )


# Binary classification plots.
datasets = ["CIFAR10", "CIFAR100"]
y_label = "Test Error"
spur_label = "First/Rest Dimension Norm Ratio"
cifar10_norm_means, cifar100_norm_means = [], []
cifar10_norm_stds, cifar100_norm_stds = [], []

sweep_vals = [0.75, 8.0, 0.1]  # LS, Mixup, and WD have unique sweep values.

for val in sweep_vals:
    # Locate, for each dataset, the logreg-sweep run directory whose name
    # contains this sweep value and whose "_low_var" tag agrees with the
    # --low-var flag. The trailing underscore in the prefix keeps "CIFAR10_"
    # from matching CIFAR100 directories.
    run_dirs = []
    for prefix in ("CIFAR10_", "CIFAR100_"):
        matches = [
            f.path
            for f in os.scandir("runs/")
            if f.is_dir()
            and prefix in f.name
            and "logreg_sweep" in f.name
            and str(val) in f.name
            and (low_var in f.name) == args.low_var
        ]
        if not matches:
            # Fail with a message naming the missing run rather than a bare IndexError.
            raise FileNotFoundError(
                f"No directory under runs/ matches '{prefix}', 'logreg_sweep' and '{val}' "
                f"(low_var={args.low_var})."
            )
        run_dirs.append(matches[0])
    cifar10_path, cifar100_path = run_dirs

    # The sweep grids are read from the CIFAR10 run only and reused as the
    # x-axis for both datasets. Files are opened via `with` so that every
    # handle is closed as soon as its pickle has been read.
    sweeps = {}
    for name in ("ls_sweep", "mix_sweep", "wd_sweep"):
        with open(f"{cifar10_path}/{name}.p", "rb") as fh:
            sweeps[name] = pickle.load(fh)

    # Test-error statistics are collected per sweep (one row per dataset);
    # the norm-ratio statistics accumulate across sweeps for the final plots.
    test_means, test_stds = [], []
    for path, norm_means, norm_stds in (
        (cifar10_path, cifar10_norm_means, cifar10_norm_stds),
        (cifar100_path, cifar100_norm_means, cifar100_norm_stds),
    ):
        for stem, sink in (
            ("test_means", test_means),
            ("test_stds", test_stds),
            ("spur_means", norm_means),
            ("spur_stds", norm_stds),
        ):
            with open(f"{path}/{stem}.p", "rb") as fh:
                sink.append(pickle.load(fh))

    test_means = np.array(test_means)
    test_stds = np.array(test_stds)

    # Pick the swept hyperparameter: label smoothing or mixup if its final
    # sweep value is positive (checked in that order), weight decay otherwise.
    if sweeps["ls_sweep"][-1] > 0:
        x_label, tag, xs, colors = r"Label Smoothing $\alpha$", "LS", sweeps["ls_sweep"], ["C0", "C1"]
    elif sweeps["mix_sweep"][-1] > 0:
        x_label, tag, xs, colors = (
            r"Mixup $\alpha$ in Beta($\alpha$, $\alpha$)",
            "mixup",
            sweeps["mix_sweep"],
            ["C2", "C6"],
        )
    else:
        x_label, tag, xs, colors = r"Weight Decay $\lambda$", "WD", sweeps["wd_sweep"], ["C4", "C5"]

    plot_multi_dataset_metrics(
        x_label=x_label,
        y_label=y_label,
        fname=f"{plots}/CIFAR_{tag}_sweep{suffix}.png",
        xs=xs,
        metric_means=test_means,
        metric_stds=test_stds,
        datasets=datasets,
        custom_colors=colors,
    )

cifar10_norm_means = np.array(cifar10_norm_means)
cifar10_norm_stds = np.array(cifar10_norm_stds)
cifar100_norm_means = np.array(cifar100_norm_means)
cifar100_norm_stds = np.array(cifar100_norm_stds)

# One norm-ratio figure per dataset. Rows follow the order of `sweep_vals`
# (label smoothing, mixup, weight decay), and the x-axis is the 1-based index
# of the hyperparameter choice within each sweep.
for name, norm_means, norm_stds, colors in (
    ("CIFAR10", cifar10_norm_means, cifar10_norm_stds, ["C0", "C2", "C4"]),
    ("CIFAR100", cifar100_norm_means, cifar100_norm_stds, ["C1", "C6", "C5"]),
):
    plot_multi_dataset_metrics(
        x_label=r"Hyperparameter Choice",
        y_label=spur_label,
        fname=f"{plots}/{name}_weight_ratios{suffix}.png",
        xs=np.array(range(1, 21), dtype=np.int32),
        metric_means=norm_means,
        metric_stds=norm_stds,
        datasets=["Label Smoothing", "Mixup", "Weight Decay"],
        custom_colors=colors,
    )


# Make plots for resnet results.
epochs = range(1, 101)
hparams = [0.1, 1.0, 0.01]
# Filename tag and line colors paired with each entry of `hparams`, in order:
# label smoothing, mixup, weight decay.
resnet_styles = [("LS", ["C0", "C1"]), ("mixup", ["C2", "C6"]), ("WD", ["C4", "C5"])]

for val, (tag, colors) in zip(hparams, resnet_styles):
    # Locate, for each dataset, the ResNet18 run directory whose name contains
    # this hyperparameter value and whose "_low_var" tag agrees with the
    # --low-var flag. The trailing underscore in the prefix keeps "CIFAR10_"
    # from matching CIFAR100 directories.
    run_dirs = []
    for prefix in ("CIFAR10_", "CIFAR100_"):
        matches = [
            f.path
            for f in os.scandir("runs/")
            if f.is_dir()
            and prefix in f.name
            and "ResNet18" in f.name
            and str(val) in f.name
            and (low_var in f.name) == args.low_var
        ]
        if not matches:
            # Fail with a message naming the missing run rather than a bare IndexError.
            raise FileNotFoundError(
                f"No directory under runs/ matches '{prefix}', 'ResNet18' and '{val}' "
                f"(low_var={args.low_var})."
            )
        run_dirs.append(matches[0])

    test_means, test_stds = [], []
    for path in run_dirs:
        # Open via `with` so the handle is closed as soon as the pickle is read.
        with open(f"{path}/all_run_test_errors.p", "rb") as fh:
            all_test_errors = pickle.load(fh)
        # Reduce over axis 0 -- presumably the independent-runs axis, leaving
        # one value per epoch; TODO confirm against the training script.
        test_means.append(all_test_errors.mean(axis=0))
        test_stds.append(all_test_errors.std(axis=0))
    test_means = np.array(test_means)
    test_stds = np.array(test_stds)

    plot_multi_dataset_metrics(
        x_label=r"Training Epoch",
        y_label=y_label,
        fname=f"{plots}/CIFAR_{tag}_resnet{suffix}.png",
        xs=epochs,
        metric_means=test_means,
        metric_stds=test_stds,
        datasets=datasets,
        custom_colors=colors,
    )


# Create Colored MNIST plots.
n_sweep = 20

# One run directory per regularizer, listed in the same order as the legend
# labels and colors passed to the combined plot below.
cmnist_run_dirs = [
    "runs/Colored_MNIST_0.75_LS_0.0_WD_0.0_Mix",  # Label smoothing sweep.
    "runs/Colored_MNIST_0.0_LS_0.0_WD_8.0_Mix",  # Mixup sweep.
    "runs/Colored_MNIST_0.0_LS_0.1_WD_0.0_Mix",  # Weight decay sweep.
]

# Load the test-error statistics of each run. Files are opened via `with` so
# every handle is closed as soon as its pickle has been read.
cmnist_means, cmnist_stds = [], []
for run_dir in cmnist_run_dirs:
    with open(f"{run_dir}/test_means.p", "rb") as fh:
        cmnist_means.append(np.array(pickle.load(fh)))
    with open(f"{run_dir}/test_stds.p", "rb") as fh:
        cmnist_stds.append(np.array(pickle.load(fh)))

# Only the combined figure is produced. The three sweeps use different
# hyperparameter ranges, so the x-axis is the 1-based index of the
# hyperparameter choice rather than its value.
plot_multi_dataset_metrics(
    x_label=r"Hyperparameter Choice",
    y_label="Test Error",
    fname=f"{plots}/cmnist_combined_sweep.png",
    xs=np.array(range(1, n_sweep + 1), dtype=np.int32),
    metric_means=cmnist_means,
    metric_stds=cmnist_stds,
    datasets=["Label Smoothing", "Mixup", "Weight Decay"],
    custom_colors=["C0", "C2", "C4"],
)
