import copy
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim.lr_scheduler as lr_scheduler
import wandb
from matplotlib import cm
from tenacity import retry, wait_exponential, stop_after_attempt
from timm.models.vision_transformer import vit_small_patch16_224
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import models

from meta.meta_dataset import CIFAR10_oneclass, CIFAR10_nineclass
from central_data.dataset import CIFAR10
from meta.args_ens import args_parser
from central_data.wandb_ens_pickle import get_wandb_runs, load_runs_from_file, process_wandb_runs, load_state_dicts, CustomDataset, custom_loss, vanilla_loss #* # is the args_parser being imported here? Yes. It is an override. 
from evaluations.train_test import model_evaluate_imagedata, model_train_imagedata, initialize_optimizer

########## Step 1: Set up hyperparameter configuration sweep
# Parse the command-line arguments once; the sweep below reads its candidate
# lists from this namespace and writes the selected values back onto it.
args = args_parser()
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
project_name = args.project_name
wandb_entity = 'pyt-geo'

# Size of the full hyperparameter grid: the product of the lengths of every
# swept list (one factor per sweep axis).
total_configs = 1
for _swept_values in (
    args.lr_list,
    args.eps_list,
    args.beta_1_list,
    args.beta_2_list,
    args.warmup_epochs_list,
    args.optimzer_choice_list,
    args.weight_decay_list,
    args.p_list,
    args.num_ingredients_list,
    args.epochs_list,
    args.z_list,
    args.shuffletrue_list,
    args.batch_size_list,
):
    total_configs *= len(_swept_values)

# Function to compute the index of the configuration
def compute_config_index(slurm_id, max_array):
    """Return the grid indices assigned to one SLURM array task.

    Starting from ``slurm_id``, indices are taken every ``max_array + 1``
    positions and wrapped modulo the module-level ``total_configs``.

    Because of the wrap-around, when ``total_configs`` is not a multiple of
    ``max_array + 1`` some indices are returned for more than one task.

    Args:
        slurm_id: Offset added to every stride position.
        max_array: Stride minus one (``max_array + 1`` is the step size).

    Returns:
        A list of integer configuration indices in ``[0, total_configs)``;
        empty when ``total_configs`` is 0.
    """
    stride = max_array + 1
    assigned = []
    for offset in range(0, total_configs, stride):
        assigned.append((slurm_id + offset) % total_configs)
    return assigned

# Retrieve the configuration indices to process
# Grid indices this job will process, derived from the SLURM id and array size given on the command line
config_indices = compute_config_index(slurm_id=args.slurm_id, max_array=args.max_array)

########## Step 2: Initialize ensemble scheduler
def ensemble_scheduler(optimizer, p = 1):
    """Build a ``LambdaLR`` scheduler whose multiplier is ``1 / (step + 1) ** p``.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        p: Exponent applied to ``step + 1``; ``p = 1`` gives the harmonic
            sequence 1, 1/2, 1/3, ...

    Returns:
        The ``torch.optim.lr_scheduler.LambdaLR`` instance wrapping ``optimizer``.
    """
    def near_harmonic_schedule(epoch):
        # The denominator is printed each time the scheduler queries the multiplier.
        denominator = (epoch + 1) ** p
        print(denominator)
        return 1 / denominator

    return lr_scheduler.LambdaLR(optimizer, lr_lambda=near_harmonic_schedule)

########## Step 3: Retrieve and Filter Runs from Weights & Biases, load datasets
# Fetch the runs of the source project together with their evaluation metadata
_wandb_summary = process_wandb_runs(
    project_name,
    sort_metric=args.sort_metric,
    wandb_entity=wandb_entity,
    reload=args.reload,
)
all_runs, len_eval_steps, eval_every, dataset_name, eval_dataset_name = _wandb_summary
args.eval_every = eval_every

# Evaluation dataset (e.g. CIFAR10_Lp): the class is looked up by name among
# this module's globals, so it must be imported at the top of the file.
eval_dataset_class = globals()[eval_dataset_name]
eval_dataset = eval_dataset_class(batch_size=64, random_seed=1, data_randseed=1)
eval_train_loader, eval_val_loader, eval_test_loader = eval_dataset.load_datasets()

########## Step 4: Loop over each configuration index and set corresponding args, start sweep
# Sweep axes as (attribute written onto ``args``, candidate values), listed from
# the fastest-varying axis to the slowest. The order matches the factor order
# of ``total_configs`` so that every configuration index decodes uniquely.
_SWEEP_AXES = (
    ('lr', args.lr_list),
    ('eps', args.eps_list),
    ('beta1', args.beta_1_list),
    ('beta2', args.beta_2_list),
    ('warmup_epochs', args.warmup_epochs_list),
    ('opt', args.optimzer_choice_list),
    ('weight_decay', args.weight_decay_list),
    ('p', args.p_list),
    ('num_ingredients', args.num_ingredients_list),
    ('epochs', args.epochs_list),
    ('z', args.z_list),
    ('shuffletrue', args.shuffletrue_list),
    ('batch_size', args.batch_size_list),
)


def _apply_sweep_config(config_index):
    """Decode ``config_index`` as a mixed-radix number and set the chosen values on ``args``."""
    remaining = config_index
    for attr_name, candidates in _SWEEP_AXES:
        # The remainder selects the entry on this axis; the quotient carries on to the next axis.
        remaining, position = divmod(remaining, len(candidates))
        setattr(args, attr_name, candidates[position])


def _accumulate_and_step(net, opt, batch, loss_fn):
    """Accumulate ``loss_fn`` gradients over ``batch`` and take one optimizer step.

    Returns the loss of the last element of ``batch`` (``None`` if ``batch`` is empty).
    """
    last_item_loss = None
    opt.zero_grad()
    for x in batch:
        last_item_loss = loss_fn(net, x)
        last_item_loss.backward()
    opt.step()
    return last_item_loss


def _evaluate_on_splits(net, prefix):
    """Evaluate ``net`` on the eval train/val/test loaders.

    Returns a dict of wandb metrics whose keys start with ``prefix``: the loss
    and accuracy of each split plus the datapoint-weighted totals over all
    three splits.
    """
    cross_entropy = torch.nn.functional.cross_entropy
    train_loss, train_accuracy, train_datapoints = model_evaluate_imagedata(net, eval_train_loader, device, loss_function=cross_entropy)
    val_loss, val_accuracy, val_datapoints = model_evaluate_imagedata(net, eval_val_loader, device, loss_function=cross_entropy)
    test_loss, test_accuracy, test_datapoints = model_evaluate_imagedata(net, eval_test_loader, device, loss_function=cross_entropy)

    total_datapoints = train_datapoints + val_datapoints + test_datapoints
    total_accuracy = (train_accuracy * train_datapoints + val_accuracy * val_datapoints + test_accuracy * test_datapoints) / total_datapoints
    total_loss = (train_loss * train_datapoints + val_loss * val_datapoints + test_loss * test_datapoints) / total_datapoints

    return {
        f'{prefix}eval_train_loss': train_loss,
        f'{prefix}eval_train_accuracy': train_accuracy,
        f'{prefix}eval_val_loss': val_loss,
        f'{prefix}eval_val_accuracy': val_accuracy,
        f'{prefix}eval_test_loss': test_loss,
        f'{prefix}eval_test_accuracy': test_accuracy,
        f'{prefix}eval_total_traindata_accuracy': total_accuracy,
        f'{prefix}eval_total_traindata_loss': total_loss,
    }


for config_index in config_indices:
    # Override args with the hyperparameter configuration selected by this index
    _apply_sweep_config(config_index)

    # Initialize wandb: the target project name encodes greedy mode and the source project
    project_name1 = args.base_project_name
    if args.greedy:
        project_name1 += 'Grdy'
    project_name1 += 'Ens_'
    project_name1 += args.project_name
    run_name1 = f"{args.opt} {args.lr} {args.eps} {args.beta1} {args.beta2} {args.epochs} {args.p} {args.batch_size} {args.weight_decay} {args.num_ingredients} {args.z} {args.shuffletrue}"

    # Set up Weights & Biases logging
    os.makedirs('/directory/wandblog/', exist_ok=True)
    os.environ['WANDB_DIR'] = '/directory/wandblog/'
    os.environ["WANDB__SERVICE_WAIT"] = "300"
    wandb.init(
        project=project_name1,
        name=run_name1,
        config=args.__dict__
    )

    # Load the ingredient models and wrap them in a dataset for the ensembling loops
    epoch_idx = args.epoch_idx
    models_collection, latest_model_path = load_state_dicts(all_runs, project_name, args.num_ingredients, device, epoch_idx, meta=False)
    dataset = CustomDataset(models_collection, (args.num_ingredients - 1) - np.array(range(len(models_collection))))
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=bool(args.shuffletrue))

    # ---- Standard soup: one SGD pass over the ingredients with a 1/(step + 1) learning rate ----
    dataloader_soup = DataLoader(dataset, batch_size=1, shuffle=False)
    # torch.load returns a full module here (it is moved to a device and optimized below).
    model_soup = torch.load(latest_model_path, map_location=device)
    model_soup.to(device)
    optimizer_SGD = torch.optim.SGD(model_soup.parameters(), lr=1)
    scheduler_soup = ensemble_scheduler(optimizer_SGD, p=1)

    for batch in dataloader_soup:
        for param_group in optimizer_SGD.param_groups:
            print(f"Learning Rate for souping SGD: {param_group['lr']}")
        _accumulate_and_step(model_soup, optimizer_SGD, batch, vanilla_loss)
        scheduler_soup.step()

    # The souped model is the starting point of the ensemble training below
    model = model_soup
    del dataloader_soup
    if device.type == 'cuda':
        torch.cuda.empty_cache()

    # Evaluate the soup before it is trained any further
    soup_metrics = _evaluate_on_splits(model, 'soup_')

    # ---- Greedy soup: keep an ingredient only if validation accuracy does not drop ----
    dataloader_soup = DataLoader(dataset, batch_size=1, shuffle=False)
    model_greedy_soup = torch.load(latest_model_path, map_location=device)
    model_greedy_soup.to(device)
    optimizer_SGD = torch.optim.SGD(model_greedy_soup.parameters(), lr=1)
    scheduler_soup = ensemble_scheduler(optimizer_SGD, p=1)

    # Evaluate the initial model
    _, best_val_accuracy, _ = model_evaluate_imagedata(model_greedy_soup, eval_val_loader, device, loss_function=torch.nn.functional.cross_entropy)
    # state_dict() hands back tensors that share storage with the live
    # parameters, so a deep copy is required for the snapshot to survive the
    # in-place optimizer updates that follow.
    best_model_greedy_soup_state = copy.deepcopy(model_greedy_soup.state_dict())

    for batch in dataloader_soup:
        for param_group in optimizer_SGD.param_groups:
            print(f"Learning Rate for souping SGD: {param_group['lr']}")

        # Train with the current batch, then score the candidate
        _accumulate_and_step(model_greedy_soup, optimizer_SGD, batch, vanilla_loss)
        _, new_GS_val_accuracy, _ = model_evaluate_imagedata(model_greedy_soup, eval_val_loader, device, loss_function=torch.nn.functional.cross_entropy)

        if new_GS_val_accuracy >= best_val_accuracy:
            # Keep the new model; the learning rate only advances on retained steps
            best_val_accuracy = new_GS_val_accuracy
            best_model_greedy_soup_state = copy.deepcopy(model_greedy_soup.state_dict())
            scheduler_soup.step()
            print("New model retained with accuracy: ", new_GS_val_accuracy)
        else:
            # Revert to the previous best model
            model_greedy_soup.load_state_dict(best_model_greedy_soup_state)
            print("Model reverted to previous best with accuracy: ", best_val_accuracy)

    del dataloader_soup
    if device.type == 'cuda':
        torch.cuda.empty_cache()

    # Evaluate greedy model soup
    greedy_soup_metrics = _evaluate_on_splits(model_greedy_soup, 'greedy_soup_')

    # ---- Ensemble training, starting from the standard soup ----
    optimizer = initialize_optimizer(args, model)
    scheduler = ensemble_scheduler(optimizer, p=args.p)
    num_epochs = args.epochs

    for epoch in range(num_epochs):
        # NOTE: this prints the checkpoint epoch index (args.epoch_idx), not the loop counter.
        print(f"Epoch is {epoch_idx}")

        model.train()
        last_loss = None

        if args.greedy == 1:
            # Greedy mode: a batch update is kept only if validation accuracy does not drop
            best_val_accuracy = float('-inf')
            # Deep copies: both state dicts reference tensors that later steps update in place.
            best_model_state = copy.deepcopy(model.state_dict())
            best_optimizer_state = copy.deepcopy(optimizer.state_dict())

            for batch in dataloader:
                for param_group in optimizer.param_groups:
                    print(f"Epoch {epoch} Learning Rate: {param_group['lr']}")

                # Train/ensemble
                last_loss = _accumulate_and_step(model, optimizer, batch, custom_loss)

                # Evaluate model on validation data after each batch
                _, new_val_accuracy, _ = model_evaluate_imagedata(model, eval_val_loader, device, loss_function=torch.nn.functional.cross_entropy)

                if new_val_accuracy >= best_val_accuracy:
                    best_val_accuracy = new_val_accuracy
                    # Step the scheduler (only for retained updates) BEFORE the
                    # snapshot, so that a later revert does not roll the
                    # learning rate back to its pre-step value.
                    scheduler.step()
                    best_model_state = copy.deepcopy(model.state_dict())
                    best_optimizer_state = copy.deepcopy(optimizer.state_dict())
                    print("New model retained with accuracy: ", new_val_accuracy)
                else:
                    # Revert to the previous best model and optimizer states. The
                    # optimizer gets its own copy so the snapshot stays pristine.
                    model.load_state_dict(best_model_state)
                    optimizer.load_state_dict(copy.deepcopy(best_optimizer_state))
                    print("Model reverted to previous best with accuracy: ", best_val_accuracy)
        else:
            # Non-greedy mode: every batch update is kept and the scheduler always advances
            for batch in dataloader:
                for param_group in optimizer.param_groups:
                    print(f"Epoch {epoch} Learning Rate: {param_group['lr']}")

                # Train/ensemble
                last_loss = _accumulate_and_step(model, optimizer, batch, custom_loss)
                scheduler.step()

        # Nothing to report if the dataloader produced no loss this epoch
        if last_loss is not None:
            print(f"Epoch {epoch+1}/{num_epochs}, Custom Loss: {last_loss.item()}")

        # Evaluate the ensemble and log it next to the (fixed) soup baselines
        ensemble_metrics = _evaluate_on_splits(model, '')
        wandb.log({**ensemble_metrics, **soup_metrics, **greedy_soup_metrics})
    wandb.finish()