import random

import numpy as np
import torch
import wandb
from torch import optim

from data_loading import *
from models import *
from train import *

'''
To run experiments, edit the configuration in the `if __name__ == "__main__":` block below and run this script.
It assumes wandb is installed and your wandb API key is configured (e.g. via `wandb login`).
'''

if __name__ == "__main__":
    # Seed every RNG the experiment touches so runs are reproducible.
    torch.manual_seed(0)
    random.seed(0)
    np.random.seed(0)

    # Prefer the GPU, but fall back to CPU instead of crashing when CUDA is unavailable.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Model type can be one of "FC", "Residual", "conv", "conv_residual".
    # - "conv" / "conv_residual" are only used by the image datasets ("MNIST", "CIFAR");
    #   the other datasets fall back to a fully connected model.
    # - "div_mod_p" always uses EmbeddingConcatFFModel, whatever model_type says.
    model_type = "FC"

    # dset_name can be one of "synthetic_classification", "div_mod_p", "MNIST", "CIFAR",
    # "synth_images". The easiest way to adjust the model is to edit the model_params
    # dictionary in the branch below for the dataset you are using.
    dset_name = "MNIST"

    # Reject typos up front rather than silently falling back to a different model.
    valid_model_types = ("FC", "Residual", "conv", "conv_residual")
    if model_type not in valid_model_types:
        raise ValueError(f"Unknown model_type {model_type!r}; expected one of {valid_model_types}")

    use_conv = model_type in ("conv", "conv_residual")
    # Fully connected model used whenever a conv model is not selected. Both classes are
    # constructed with the same keyword arguments (num_classes, layers, use_batch_norm,
    # bias, weight_init), so "Residual" is honored for every dataset that uses FC models.
    fc_class = FullyConnectedResidualNN if model_type == "Residual" else FullyConnectedNN

    if dset_name == "synthetic_classification":
        # Not used extensively, included for posterity.
        dset_params = dict(n_samples=10000, n_features=32, n_informative=8, n_redundant=2,
                           n_repeated=0, n_classes=2, n_clusters_per_class=3,
                           weights=None, flip_y=0, class_sep=1, hypercube=True,
                           shift=0.0, scale=1.0, shuffle=True, random_state=None)
        num_classes = dset_params["n_classes"]
        num_features = dset_params["n_features"]
        mod_class = fc_class
        model_params = dict(num_classes=num_classes, layers=[num_features, 128, 64],
                            use_batch_norm=True, bias=False, weight_init="xavier")
        float_input = True
        special_params = {}

    elif dset_name == "div_mod_p":
        # Not used extensively, included for posterity. The model class is fixed here.
        dset_params = dict(p=53, eq_token=53, op_token=54)
        num_classes = dset_params["p"]
        mod_class = EmbeddingConcatFFModel
        special_params = {"p": dset_params["p"], "embed_dim": 128, "hidden": 256}
        model_params = special_params
        num_features = dset_params["p"]
        float_input = False

    elif dset_name == "MNIST":
        # Conv networks need unflattened images; fully connected models consume flat
        # 28*28 vectors, so only ask the loader to flatten for them.
        # (Make sure your data-loading function honors the "flatten" flag.)
        dset_params = dict(num_samples=None, flatten=not use_conv)
        num_classes = 10

        if use_conv:
            # When using ConvResNet, the first conv_channels entry is the depth between blocks.
            mod_class = SimpleConvNet if model_type == "conv" else ConvResNet
            model_params = dict(num_classes=num_classes,
                                input_channels=1,  # grayscale images
                                conv_channels=[16, 32],
                                weight_init="xavier",
                                use_batch_norm=True,
                                bias=False)
        else:
            mod_class = fc_class
            num_features = 28 * 28
            model_params = dict(num_classes=num_classes,
                                layers=[num_features, 128, 64],
                                use_batch_norm=True,
                                bias=False,
                                weight_init="xavier")
        float_input = True
        special_params = {}

    elif dset_name == "CIFAR":
        # Conv networks need unflattened images; fully connected models consume flat
        # vectors, so only ask the loader to flatten for them.
        # (Make sure your data-loading function honors the "flatten" flag.)
        dset_params = dict(flatten=not use_conv, num_samples=10000)
        num_classes = 10

        if use_conv:
            mod_class = SimpleConvNet if model_type == "conv" else ConvResNet
            # NOTE: [16, 32, 64] is an alternative (smaller) conv_channels setting the
            # author left commented out in an earlier version.
            model_params = dict(num_classes=num_classes,
                                input_channels=3,  # RGB images
                                conv_channels=[16, 32, 64, 128],
                                weight_init="xavier",
                                use_batch_norm=True,
                                bias=False)
        else:
            mod_class = fc_class
            # Images are RGB (input_channels=3 above), so a flattened image has
            # 3 * 32 * 32 features, not 32 * 32.
            num_features = 3 * 32 * 32
            model_params = dict(num_classes=num_classes,
                                layers=[num_features, 128, 64],
                                use_batch_norm=True,
                                bias=False,
                                weight_init="xavier")
        float_input = True
        special_params = {}

    elif dset_name == "synth_images":
        # Not used extensively, included for posterity.
        dset_params = dict(num_images=10000, width=32, height=32, flatten=True, class_weights=None)
        num_classes = 3
        num_features = dset_params["width"] * dset_params["height"] * 3
        mod_class = fc_class
        # This branch previously defined no model_params, so building run_params below
        # raised NameError; use the same FC architecture as the other datasets.
        model_params = dict(num_classes=num_classes,
                            layers=[num_features, 128, 64],
                            use_batch_norm=True,
                            bias=False,
                            weight_init="xavier")
        float_input = True
        special_params = {}

    else:
        raise ValueError(
            f"Unknown dset_name {dset_name!r}; expected one of 'synthetic_classification', "
            "'div_mod_p', 'MNIST', 'CIFAR', 'synth_images'"
        )

    optimizer_params = dict(lr=0.001, weight_decay=0)

    run_params = {
        "dset": dset_name,
        "dset_params": dset_params,
        "batch_size": 64,
        "epochs": 50,
        "test_size": 0.5,
        "model_type": model_type,
        "model_params": model_params,
        "optimizer": "sgd",
        "optimizer_params": optimizer_params,
    }

    # Initialize the model using the chosen parameters and move it to the device.
    model = mod_class(**run_params["model_params"]).to(device)

    possible_optimizers = {"adam": optim.Adam, "sgd": optim.SGD, "adamw": optim.AdamW}
    opt = possible_optimizers[run_params["optimizer"]]
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = opt(model.parameters(), **run_params["optimizer_params"])

    # NOTE: "logarithmic_conv_llc_CIFAR" is a project name used for earlier runs.
    project_name = f"synthetic_llc_{dset_name}"
    # Recording run_params as wandb config makes the hyperparameters filterable in the UI;
    # the wandb.log call is kept so they still appear in the run history as before.
    run = wandb.init(project=project_name, config=run_params)
    wandb.log(run_params)

    # Load the dataset using the provided data-loading functions.
    dset_func = DATASET_DICT[run_params["dset"]]
    X_train, X_test, train_labs, test_labs = dset_func(**run_params["dset_params"], test_size=run_params["test_size"])
    train_dset, test_dset = make_tensor_datasets(X_train, X_test, train_labs, test_labs)
    train_loader, test_loader = make_dataloaders(train_dset, test_dset, batch_size=run_params["batch_size"])

    train_and_analyze(
        model, train_loader, test_loader, criterion, optimizer, device,
        epochs=run_params["epochs"], eval_step=100, wandb_run=run,
    )
