import torch
import torch.nn as nn
import os
import warnings
from eos_line_search.experiment import *
from eos_line_search.data import *
from eos_line_search.run import *
from eos_line_search.plot import *
from eos_line_search.model import Model, CNNModel, MLPModel
from python_scripts import (
    optimizers,
    new_optimizers,
    assmpt_optimizers,
    approx_optimizers,
    debug_optimizers,
    sam_optimizers,
    best_c,
    GD_steps,
    delta_ablation_optimizers,
    vit_optimizers,
    warmup_optimizers,
)
import sys
import copy
import argparse

# warnings.filterwarnings("ignore")

# Command-line interface. Every option has a default, so the script can be
# run with no arguments.
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="CIFAR10", help="Dataset name")
parser.add_argument("--model", type=str, default="CNN", help="Model name")
parser.add_argument(
    "--batch_size",
    type=str,
    default="full",
    help="Mini-batch size (an integer) or 'full' for full-batch training",
)
parser.add_argument("--epochs", type=int, default=5000, help="Number of epochs")
parser.add_argument(
    "--mode",
    type=int,
    default=0,
    # Must stay in sync with the "Setup optimizers" dispatch further below.
    choices=[0, 1, 2, 3, 4, 5, 6, 7, 10, 11],
    help=(
        "0: Default optimizers, 1: Debugging, 2: Assmpt optimizers, "
        "3: Approx optimizers, 4: SAM optimizers, 5: Delta ablation optimizers, "
        "6: ViT optimizers, 7: Warmup optimizers, 10: Constant-step-size GD, "
        "11: New optimizers"
    ),
)
args = parser.parse_args()


# Read arguments
dataset = args.dataset
model = args.model
batch_size = args.batch_size  # still a string here; converted in "Setup runs"
epochs = args.epochs
mode = args.mode

# Loss configuration. MSE is paired with one-hot encoded targets
# (one_hot_encode=True is passed to Data below); cross-entropy is paired
# with non-one-hot targets.
loss = "mse"
if loss == "mse":
    one_hot_encode = True
    loss_fn = nn.MSELoss(reduction="mean")
elif loss == "ce":
    one_hot_encode = False
    loss_fn = nn.CrossEntropyLoss(reduction="mean")
else:
    # Without this, loss_fn/one_hot_encode would be undefined and the script
    # would fail later with an unrelated NameError.
    raise ValueError(
        "{} is not a supported loss (expected 'mse' or 'ce')".format(loss)
    )

# Get Path: all experiment output is written relative to the current
# working directory.
path = os.getcwd()

# Create the experiments output folder. exist_ok=True avoids the race
# between a separate existence check and the directory creation.
os.makedirs(os.path.join(path, "experiments"), exist_ok=True)

### Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

### Setup plots and data subsetting
# Metric groups shared by the full-batch and mini-batch configurations; only
# the eigenvalue metrics and the sharpness frequency differ between them.
# ("Approx 7/8/9" are not tracked here; mode 3 below switches them on.)
_head_metrics = ["Training Loss", "Training Accuracy"]
_step_metrics = [
    "Gradient Norm",
    "Final Step Size",
    "Initial Step Size",
    "Backtracks",
    "Function Evaluations",
]
# FIXME (carried over): "Test Accuracy" only works if "Test Loss" is also active.
_tail_metrics = ["Test Loss", "Test Accuracy", "a", "time"]

if batch_size != "full":
    # Mini-batch training uses the whole training set and tracks per-batch
    # eigenvalue statistics every 2 steps.
    train_subset = "full"
    _eig_metrics = [
        "Avg Batch Eigenvalues",
        "Max Batch Eigenvalues",
        "Min Batch Eigenvalues",
    ]
    _sharpness_freq = 2
else:
    # Full-batch training works on a fixed-size training subset per dataset.
    if dataset in ("CIFAR10", "SVHN", "CIFAR100", "EMNIST"):
        train_subset = 5000
    elif dataset in ("Imagenette", "Imagenet"):
        train_subset = 1000
    else:
        raise ValueError("{} is not a valid dataset".format(dataset))
    _eig_metrics = ["Eigenvalues"]
    _sharpness_freq = 100

plot_metrics = Plot(
    metrics=_head_metrics + _eig_metrics + _step_metrics + _tail_metrics,
    num_eigs=1,
    label="Optimizer",
    sharpness_every=_sharpness_freq,
)

### Setup optimizers
# Select the optimizer list for the requested mode. Mode 0 keeps the default
# `optimizers` list imported from python_scripts. Modes 1-3 also replace the
# `plot_metrics` built above, and mode 7 tweaks its sharpness frequency.
if mode == 1:
    # Debugging: many eigenvalues plus gradient, activation and NC metrics.
    optimizers = debug_optimizers
    num_eigs = 30 if dataset == "EMNIST" else 20
    sharpness_every = 2
    # Passed to Plot as `after_it` (presumably the iteration from which the
    # expensive debug metrics start being tracked — confirm in Plot).
    # Only VGG11 runs have non-zero values.
    after_it = {
        ("vgg11", "CIFAR10"): 3200,
        ("vgg11", "CIFAR100"): 1200,
        ("vgg11", "SVHN"): 800,
        ("vgg11", "EMNIST"): 1100,
    }.get((model, dataset), 0)
    plot_metrics = Plot(
        metrics=[
            "Training Loss",
            "Training Accuracy",
            "Eigenvalues",
            "Gradient Norm",
            "Final Step Size",
            "Initial Step Size",
            "Backtracks",
            "Function Evaluations",
            "a",
            "time",
            "debugging",
            "Trace",
            "Bias Grad Norm",
            "Avg Hidden Grad Norm",
            "Std Hidden Grad Norm",
            "Min Hidden Grad Norm",
            "Max Hidden Grad Norm",
            "Zero Grad Entries",
            "Zero Activations",
            "NC1 (CDNV)",
            "NC1 (Pseudo-inv)",
            "NC1 (SVD)",
            "NC1 (Quotient)",
            "NC2 (ETF Error)",
            "NC2 (Global Dist)",
            "NC3 (Dual Error)",
            "NC3 (Uniform Dual)",
            "Test Loss",
            "Test Accuracy",
        ],
        num_eigs=num_eigs,
        label="Optimizer",
        sharpness_every=sharpness_every,
        after_it=after_it,
    )

elif mode == 2:
    # Assmpt optimizers additionally track the "Lw_asmpt" metric.
    optimizers = assmpt_optimizers
    plot_metrics = Plot(
        metrics=[
            "Training Loss",
            "Training Accuracy",
            "Eigenvalues",
            "Gradient Norm",
            "Final Step Size",
            "Initial Step Size",
            "Backtracks",
            "Function Evaluations",
            "Test Loss",
            "Test Accuracy",
            "a",
            "time",
            "Lw_asmpt",
        ],
        num_eigs=1,
        label="Optimizer",
        sharpness_every=100,
    )

elif mode == 3:
    # Approx optimizers additionally track the "Approx 7/8/9" metrics.
    optimizers = approx_optimizers
    plot_metrics = Plot(
        metrics=[
            "Training Loss",
            "Training Accuracy",
            "Eigenvalues",
            "Gradient Norm",
            "Final Step Size",
            "Initial Step Size",
            "Backtracks",
            "Function Evaluations",
            "Test Loss",
            "Test Accuracy",
            "Approx 7",
            "Approx 8",
            "Approx 9",
            "a",
            "time",
        ],
        num_eigs=1,
        label="Optimizer",
        sharpness_every=100,
    )

elif mode == 4:
    optimizers = sam_optimizers

elif mode == 5:
    optimizers = delta_ablation_optimizers

elif mode == 6:
    optimizers = vit_optimizers

elif mode == 7:
    plot_metrics.sharpness_every = 1
    optimizers = warmup_optimizers

elif mode == 10:
    # Constant-step-size GD with the step sizes registered for this
    # (dataset, model) pair in GD_steps.
    constant_GDs = GD_steps.get((dataset, model))
    if not constant_GDs:
        # Previously this only printed a message and went on to run an
        # empty experiment; fail loudly instead.
        raise ValueError(
            "Missing GD constants for dataset={} and model={}".format(dataset, model)
        )
    optimizers = [
        Optim(opt_name="constant_stepsize_GD", step_size=step_size)
        for step_size in constant_GDs
    ]

elif mode == 11:
    optimizers = new_optimizers

elif mode != 0:
    # Any other value used to fall back silently to the default optimizers.
    raise ValueError("{} is not a valid mode".format(mode))


### Setup runs
# "full" is kept as a sentinel for full-batch training; any other value must
# be an integer mini-batch size.
if batch_size != "full":
    batch_size = int(batch_size)
reg_param = 0


def _build_model(model_name):
    """Return a fresh model instance for the given --model name."""
    if model_name == "CNN":
        return CNNModel(
            model_type="CNN", activation_fn=nn.ReLU, pooling=nn.MaxPool2d, window_size=2
        )
    if model_name == "MLP":
        return MLPModel(
            model_type="MLP", activation_fn=nn.ReLU, num_layers=3, width=100
        )
    return Model(model_type=model_name)


runs = []
for optimizer in optimizers:
    # Work on a private copy so the optimizer configs imported from
    # python_scripts are not mutated when the step size is resolved below.
    run_optimizer = copy.deepcopy(optimizer)
    if run_optimizer.step_size == -1:
        # -1 requests the best known step size for this configuration from
        # best_c, falling back to 0.1 when none is recorded.
        maybe_step_size = best_c.get(
            (run_optimizer.opt_name, run_optimizer.forward_option, dataset, model)
        )
        if maybe_step_size:
            run_optimizer.step_size = maybe_step_size
            print("Best known step size: ", run_optimizer.step_size)
        else:
            run_optimizer.step_size = 0.1

    runs.append(
        Run(
            dataset=Data(
                name=dataset,
                train_subset=train_subset,
                stratified=True,
                one_hot_encode=one_hot_encode,
                centered=True,
            ),
            loss_fn=loss_fn,
            optimizer=run_optimizer,
            batch_size=batch_size,
            epochs=epochs,
            reg_param=reg_param,
            # Each run gets its own freshly constructed model.
            model=_build_model(model),
            plot_metrics=plot_metrics,
        )
    )

### Weights and Biases configuration (logging is disabled by default).
use_wb = False  # flip to True to enable Weights and Biases logging
entity = None  # Weights and Biases entity name, or None
project_name = None  # Weights and Biases project name, or None
group = None  # Weights and Biases group name, or None

### Setup and run the experiment over all configured runs
experiment = Experiment(runs=runs, device=device, path=path)
experiment.run_experiment(use_wb, entity, project_name, group)
