import torch
from torch.utils.data import Dataset
import wandb
import os
from tenacity import retry, wait_exponential, stop_after_attempt
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from torch.utils.data import DataLoader
from timm.models.vision_transformer import vit_small_patch16_224 
from torchvision import models
# from torchvision.models import ResNet18_Weights, ResNet50_Weights
from central_data.dataset import CIFAR10, CIFAR100, MNIST, FMNIST, CIFAR10p1, GLD23K 
from central_data.dataset import CIFAR10_Lp, CIFAR100_Lp, MNIST_Lp, FMNIST_Lp, CIFAR10p1_Lp, GLD23K_Lp
from central_data.ensemble_options import eff_args_parser 
import torch.optim.lr_scheduler as lr_scheduler
import pickle
from central_data.wandb_ens_pickle import get_wandb_runs, load_runs_from_file, process_wandb_runs, load_state_dicts, CustomDataset, custom_loss, vanilla_loss #* # is the args_parser being imported here? Yes. It is an override. 
from evaluations.train_test import model_evaluate_imagedata, model_train_imagedata, initialize_optimizer


########## Step 1: Set up hyperparameter configuration sweep
args = eff_args_parser()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
project_name = args.project_name
wandb_entity = 'pyt-geo'

# Size of the full Cartesian sweep: the product of the lengths of every
# hyperparameter candidate list. The main loop decodes indices in
# [0, total_configs) back into one value per list.
total_configs = 1
for _candidates in (
    args.lr_list,
    args.eps_list,
    args.beta_1_list,
    args.beta_2_list,
    args.warmup_epochs_list,
    args.optimzer_choice_list,
    args.weight_decay_list,
    args.p_list,
    args.num_ingredients_list,
    args.epochs_list,
    args.z_list,
    args.shuffletrue_list,
    args.batch_size_list,
):
    total_configs *= len(_candidates)
del _candidates

# Map an array task ID to the list of configuration indices that task should run
def compute_config_index(slurm_id, max_array, total=None):
    """Return the configuration indices assigned to one array task.

    Configurations are dealt out round-robin over ``max_array + 1`` tasks:
    task ``slurm_id`` receives ``slurm_id``, ``slurm_id + (max_array + 1)``,
    ``slurm_id + 2 * (max_array + 1)``, ... for as long as the index stays
    below ``total``. When task IDs run 0..max_array, every configuration is
    assigned to exactly one task.

    Args:
        slurm_id: Index of this array task (assumed 0-based).
        max_array: Highest task index, so ``max_array + 1`` tasks share the sweep.
        total: Number of configurations in the sweep. Defaults to the
            module-level ``total_configs``.

    Returns:
        Increasing list of configuration indices for this task. The list is
        empty when ``slurm_id >= total`` (more tasks than configurations).
    """
    if total is None:
        total = total_configs
    # A plain stride without modulo wrap-around: wrapping (the previous
    # behaviour) handed indices that another task already owns back to this
    # task, so some configurations ran twice.
    return list(range(slurm_id, total, max_array + 1))

# Retrieve the configuration indices to process
# Indices (into the Cartesian sweep) that this array task is responsible for.
config_indices = compute_config_index(slurm_id=args.slurm_id, max_array=args.max_array)

########## Step 2: Initialize ensemble scheduler
def ensemble_scheduler(optimizer, p = 1):
    """Attach a polynomial-decay ``LambdaLR`` scheduler to ``optimizer``.

    After ``t`` calls to ``scheduler.step()``, each parameter group's learning
    rate equals its initial learning rate times ``1 / (t + 1) ** p``. With the
    default ``p = 1`` this is the harmonic sequence 1, 1/2, 1/3, ...

    Args:
        optimizer: The torch optimizer whose learning rate is scheduled.
        p: Exponent of the decay; ``p = 0`` keeps the learning rate constant.

    Returns:
        The ``torch.optim.lr_scheduler.LambdaLR`` instance.
    """
    return lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step_count: 1 / (step_count + 1) ** p,
    )

########## Step 3: Retrieve and Filter Runs from Weights & Biases, load datasets
(
    all_runs,
    len_eval_steps,
    eval_every,
    dataset_name,
    eval_dataset_name,
) = process_wandb_runs(
    project_name,
    sort_metric=args.sort_metric,
    wandb_entity=wandb_entity,
    reload=args.reload,
)
args.eval_every = eval_every

# NOTE(review): `globals()[...]` resolves any module-level name, so the class
# actually used is whatever object the name returned by process_wandb_runs
# refers to in this module. An explicit registry of the imported dataset
# classes would be stricter.

# Original dataset (e.g. CIFAR10), resolved by name.
dataset_class = globals()[dataset_name]
dataset = dataset_class(batch_size=64, random_seed=1, data_randseed=1)
(train_loader, val_loader, test_loader) = dataset.load_datasets()

# Evaluation dataset (e.g. CIFAR10_Lp), resolved the same way and built
# with the same fixed batch size and seeds.
eval_dataset_class = globals()[eval_dataset_name]
eval_dataset = eval_dataset_class(batch_size=64, random_seed=1, data_randseed=1)
(eval_train_loader, eval_val_loader, eval_test_loader) = eval_dataset.load_datasets()

########## Step 4: Loop over each configuration index and set corresponding args, start sweep
import copy

# Sweep axes in mixed-radix order (the first axis varies fastest). Each pair is
# (attribute written on `args`, attribute holding that axis's candidate values).
# The order must match the product used to compute `total_configs`.
_SWEEP_AXES = (
    ('lr', 'lr_list'),
    ('eps', 'eps_list'),
    ('beta1', 'beta_1_list'),
    ('beta2', 'beta_2_list'),
    ('warmup_epochs', 'warmup_epochs_list'),
    ('opt', 'optimzer_choice_list'),
    ('weight_decay', 'weight_decay_list'),
    ('p', 'p_list'),
    ('num_ingredients', 'num_ingredients_list'),
    ('epochs', 'epochs_list'),
    ('z', 'z_list'),
    ('shuffletrue', 'shuffletrue_list'),
    ('batch_size', 'batch_size_list'),
)


def _apply_sweep_config(args, config_index):
    """Decode ``config_index`` one mixed-radix digit per axis and set the chosen values on ``args``."""
    remainder = config_index
    for target_attr, values_attr in _SWEEP_AXES:
        values = getattr(args, values_attr)
        # divmod peels off the fastest-varying digit first, matching
        # index_k = (config_index // prod(len of earlier axes)) % len(axis_k).
        remainder, digit = divmod(remainder, len(values))
        setattr(args, target_attr, values[digit])


def _empty_cuda_cache():
    """Release cached CUDA memory when running on a GPU; no-op otherwise."""
    if device.type == 'cuda':
        torch.cuda.empty_cache()


def _evaluate(model, loader):
    """Return ``model_evaluate_imagedata``'s ``(loss, accuracy, datapoints)`` for ``loader`` using cross-entropy."""
    return model_evaluate_imagedata(model, loader, device, loss_function=torch.nn.functional.cross_entropy)


def _pooled(metric_count_pairs):
    """Datapoint-weighted mean of ``(metric, count)`` pairs."""
    total_count = sum(count for _, count in metric_count_pairs)
    return sum(metric * count for metric, count in metric_count_pairs) / total_count


for config_index in config_indices:
    # Override args with this configuration's hyperparameters.
    _apply_sweep_config(args, config_index)

    # Initialize wandb: project and run names encode the configuration.
    project_name1 = args.base_project_name
    if args.greedy:
        project_name1 += 'Grdy'
    project_name1 += 'Ens_'
    project_name1 += args.project_name
    run_name1 = f"{args.opt} {args.lr} {args.eps} {args.beta1} {args.beta2} {args.epochs} {args.p} {args.batch_size} {args.weight_decay} {args.num_ingredients} {args.z} {args.shuffletrue}"

    # Set up Weights & Biases logging
    os.makedirs('/directory/wandblog/', exist_ok=True)
    os.environ['WANDB_DIR'] = '/directory/wandblog/'
    # Set the environment variable for service wait time
    os.environ["WANDB__SERVICE_WAIT"] = "300"
    wandb.init(
        project=project_name1,
        name=run_name1,
        config=args.__dict__
    )

    # Ensemble once per evaluation step recorded in the source runs.
    for epoch_idx in range(len_eval_steps):
        models_collection, latest_model_path = load_state_dicts(all_runs, project_name, args.num_ingredients, device, epoch_idx)

        # Dataset of ingredient models; the second argument assigns each
        # ingredient an order value counting down from num_ingredients - 1.
        ingredient_dataset = CustomDataset(models_collection, (args.num_ingredients - 1) - np.array(range(len(models_collection))))  # Order needs to be from 0
        dataloader = DataLoader(ingredient_dataset, batch_size=args.batch_size, shuffle=bool(args.shuffletrue))

        # Souped initialization (standard soup): a single pass of SGD with
        # lr = 1 / (t + 1) over the ingredients, one ingredient per step.
        dataloader_soup = DataLoader(ingredient_dataset, batch_size=1, shuffle=False)
        # NOTE(review): torch.load unpickles a full model object; only load
        # checkpoints from trusted sources. PyTorch >= 2.6 defaults to
        # weights_only=True, which rejects pickled modules - confirm the torch
        # version in use.
        model_soup = torch.load(latest_model_path, map_location=device)
        model_soup.to(device)
        optimizer_SGD = torch.optim.SGD(model_soup.parameters(), lr=1)
        scheduler_soup = ensemble_scheduler(optimizer_SGD, p=1)
        for batch in dataloader_soup:
            optimizer_SGD.zero_grad()
            for x in batch:
                loss = vanilla_loss(model_soup, x)
                loss.backward()
            optimizer_SGD.step()
            scheduler_soup.step()

        # Initialize model to souped version
        model = model_soup
        del dataloader_soup
        _empty_cuda_cache()

        # Initialize optimizer
        optimizer = initialize_optimizer(args, model)
        scheduler = ensemble_scheduler(optimizer, p=args.p)
        num_epochs = args.epochs

        model.train()
        if args.greedy == 1:
            # Greedy mode: after every batch, keep the update only if
            # validation accuracy does not drop; otherwise roll back.
            best_val_accuracy = float('-inf')
            # state_dict() returns references to the live tensors, so the
            # snapshots must be deep copies. Otherwise they track every later
            # optimizer.step() and a "revert" silently does nothing.
            best_model_state = copy.deepcopy(model.state_dict())
            best_optimizer_state = copy.deepcopy(optimizer.state_dict())
            loss = None  # guards the epoch print if the dataloader is empty

            for epoch in range(num_epochs):
                for batch in dataloader:
                    # Train/ensemble
                    optimizer.zero_grad()
                    for x in batch:
                        loss = custom_loss(model, x, z=args.z)
                        loss.backward()
                    optimizer.step()

                    # Evaluate model on validation data after each batch
                    _, new_val_accuracy, _ = _evaluate(model, val_loader)
                    # The evaluation helper may switch the model to eval mode;
                    # restore training mode before the next update.
                    model.train()

                    if new_val_accuracy >= best_val_accuracy:
                        best_val_accuracy = new_val_accuracy
                        # Step the scheduler only if the model is retained, and
                        # do it *before* snapshotting the optimizer so the
                        # snapshot carries the decayed learning rate. Otherwise
                        # a later revert restores the pre-decay rate.
                        scheduler.step()
                        best_model_state = copy.deepcopy(model.state_dict())
                        best_optimizer_state = copy.deepcopy(optimizer.state_dict())
                        print("New model retained with accuracy: ", new_val_accuracy)
                    else:
                        # Revert to the previous best model and optimizer states.
                        model.load_state_dict(best_model_state)
                        # Optimizer.load_state_dict is not guaranteed to copy the
                        # state tensors it receives; pass a copy so later steps
                        # cannot mutate the stored snapshot.
                        optimizer.load_state_dict(copy.deepcopy(best_optimizer_state))
                        print("Model reverted to previous best with accuracy: ", best_val_accuracy)

                if loss is not None:
                    print(f"Epoch {epoch+1}/{num_epochs}, Custom Loss: {loss.item()}")

            del best_model_state, best_optimizer_state
        else:
            # Non-greedy mode: every update is kept.
            for epoch in range(num_epochs):
                for batch in dataloader:
                    # Train/ensemble
                    optimizer.zero_grad()
                    for x in batch:
                        loss = custom_loss(model, x, z=args.z)
                        loss.backward()
                    optimizer.step()
                    scheduler.step()

        # Evaluate the ensembled model on every split of both datasets.
        train_loss, train_accuracy, train_datapoints = _evaluate(model, train_loader)
        eval_train_loss, eval_train_accuracy, eval_train_datapoints = _evaluate(model, eval_train_loader)
        val_loss, val_accuracy, val_datapoints = _evaluate(model, val_loader)
        eval_val_loss, eval_val_accuracy, eval_val_datapoints = _evaluate(model, eval_val_loader)
        test_loss, test_accuracy, test_datapoints = _evaluate(model, test_loader)
        eval_test_loss, eval_test_accuracy, eval_test_datapoints = _evaluate(model, eval_test_loader)

        # Pooled metrics over train + val + test, weighted by datapoint count.
        split_accuracy = [(train_accuracy, train_datapoints), (val_accuracy, val_datapoints), (test_accuracy, test_datapoints)]
        split_loss = [(train_loss, train_datapoints), (val_loss, val_datapoints), (test_loss, test_datapoints)]
        total_traindata_accuracy = _pooled(split_accuracy)
        total_traindata_loss = _pooled(split_loss)

        eval_split_accuracy = [(eval_train_accuracy, eval_train_datapoints), (eval_val_accuracy, eval_val_datapoints), (eval_test_accuracy, eval_test_datapoints)]
        eval_split_loss = [(eval_train_loss, eval_train_datapoints), (eval_val_loss, eval_val_datapoints), (eval_test_loss, eval_test_datapoints)]
        eval_total_traindata_accuracy = _pooled(eval_split_accuracy)
        eval_total_traindata_loss = _pooled(eval_split_loss)

        wandb.log({
            'train_loss': train_loss,
            'train_accuracy': train_accuracy,
            'val_loss': val_loss,
            'val_accuracy': val_accuracy,
            'test_loss': test_loss,
            'test_accuracy': test_accuracy,
            'eval_train_loss': eval_train_loss,
            'eval_train_accuracy': eval_train_accuracy,
            'eval_val_loss': eval_val_loss,
            'eval_val_accuracy': eval_val_accuracy,
            'eval_test_loss': eval_test_loss,
            'eval_test_accuracy': eval_test_accuracy,
            'total_traindata_accuracy': total_traindata_accuracy,
            'total_traindata_loss': total_traindata_loss,
            'eval_total_traindata_accuracy': eval_total_traindata_accuracy,
            'eval_total_traindata_loss': eval_total_traindata_loss
        })

        # Clear memory. The optimizers and schedulers hold references to the
        # model's parameters, so they must be dropped too before
        # empty_cache can reclaim that memory.
        del models_collection, dataloader, ingredient_dataset, model, model_soup
        del optimizer, scheduler, optimizer_SGD, scheduler_soup
        _empty_cuda_cache()
    wandb.finish()