import torch
import wandb
import os
from tenacity import retry, wait_exponential, stop_after_attempt
import numpy as np
from torch.utils.data import DataLoader
from central_data.dataset import CIFAR10, CIFAR100, MNIST, FMNIST, CIFAR10p1, GLD23K
from central_data.dataset import CIFAR10_Lp, CIFAR100_Lp, MNIST_Lp, FMNIST_Lp, CIFAR10p1_Lp, GLD23K_Lp
from central_data.ensemble_options import eff_args_parser 
import torch.optim.lr_scheduler as lr_scheduler
import pickle
from central_data.wandb_ens_pickle import get_wandb_runs, load_runs_from_file, process_wandb_runs, load_state_dicts, CustomDataset, custom_loss, vanilla_loss
from evaluations.train_test import model_evaluate_imagedata, model_train_imagedata, initialize_optimizer

# Command-line options for this soup sweep.
args = eff_args_parser()

########## Step 1: Set up hyperparameter configuration sweep
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Source W&B project/entity whose runs supply the soup ingredients.
project_name = args.project_name
wandb_entity = 'pyt-geo'
sort_metric = args.sort_metric  # ranking metric handed to process_wandb_runs, e.g. 'val_accuracy'
reload_bool = 0  # forwarded as reload= to process_wandb_runs
# Soup runs are logged to the W&B project base_project_name + 'soups_' + project_name.
base_project_name = 'SP_BaseLine_'
shuffletrue = 0  # NOTE(review): not read anywhere in the visible code

# One soup run per entry; currently just the value given on the command line.
num_ingredients_list = [args.num_ingredients]

########## Step 2: Initialize ensemble scheduler
def ensemble_scheduler(optimizer, p=1):
    """Build a LambdaLR that scales each base learning rate by ``1 / (step + 1) ** p``.

    With ``p=1`` the multipliers are the harmonic sequence 1, 1/2, 1/3, ...; other
    values of ``p`` give a "near harmonic" decay.

    Args:
        optimizer: torch optimizer whose param-group learning rates are scheduled.
        p: exponent of the decay (default 1).

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` wrapping ``optimizer``.
    """
    def near_harmonic_schedule(epoch):
        # LambdaLR passes its internal step counter (0 right after construction).
        # No print here: LambdaLR calls this on construction and on every step().
        return 1 / (epoch + 1) ** p

    return lr_scheduler.LambdaLR(optimizer, lr_lambda=near_harmonic_schedule)

########## Step 3: Retrieve and Filter Runs from Weights & Biases, load datasets
all_runs, len_eval_steps, eval_every, dataset_name, eval_dataset_name = process_wandb_runs(project_name, sort_metric=sort_metric, wandb_entity=wandb_entity, reload=reload_bool)

# Dataset names are strings returned by process_wandb_runs. Resolve them against an
# explicit registry of the imported dataset classes rather than globals(), so an
# unexpected name cannot pick up an unrelated module-level object (torch, args, ...)
# and instead fails with a descriptive error.
_DATASET_CLASSES = {
    'CIFAR10': CIFAR10,
    'CIFAR100': CIFAR100,
    'MNIST': MNIST,
    'FMNIST': FMNIST,
    'CIFAR10p1': CIFAR10p1,
    'GLD23K': GLD23K,
    'CIFAR10_Lp': CIFAR10_Lp,
    'CIFAR100_Lp': CIFAR100_Lp,
    'MNIST_Lp': MNIST_Lp,
    'FMNIST_Lp': FMNIST_Lp,
    'CIFAR10p1_Lp': CIFAR10p1_Lp,
    'GLD23K_Lp': GLD23K_Lp,
}


def _dataset_class_for(name):
    """Return the dataset class registered under ``name``.

    Raises:
        ValueError: if ``name`` is not one of the imported dataset classes.
    """
    try:
        return _DATASET_CLASSES[name]
    except KeyError:
        raise ValueError(
            f"Unknown dataset {name!r}; expected one of {sorted(_DATASET_CLASSES)}"
        ) from None


# Original dataset (e.g. CIFAR10)
dataset = _dataset_class_for(dataset_name)(batch_size=64, random_seed=1, data_randseed=1)
train_loader, val_loader, test_loader = dataset.load_datasets()

# Evaluation dataset (e.g. CIFAR10_Lp)
eval_dataset = _dataset_class_for(eval_dataset_name)(batch_size=64, random_seed=1, data_randseed=1)
eval_train_loader, eval_val_loader, eval_test_loader = eval_dataset.load_datasets()

# ---------------------------------------------------------------------------
# Helpers for the soup sweep (Steps 4-5)
# ---------------------------------------------------------------------------

# (split name, loader on the original dataset, loader on the evaluation dataset).
# Iterated in this order so evaluations run train -> val -> test, original before eval.
_SPLIT_LOADERS = (
    ('train', train_loader, eval_train_loader),
    ('val', val_loader, eval_val_loader),
    ('test', test_loader, eval_test_loader),
)


def _release_cuda_cache():
    """Release cached CUDA memory between soups; no-op when running on CPU."""
    if device.type == 'cuda':
        torch.cuda.empty_cache()


def _evaluate(model, loader):
    """Evaluate ``model`` on ``loader`` with cross-entropy; returns (loss, accuracy, datapoints)."""
    return model_evaluate_imagedata(model, loader, device, loss_function=torch.nn.functional.cross_entropy)


def _snapshot_state(model):
    """Return an independent copy of ``model``'s state dict.

    ``Module.state_dict()`` returns references to the live parameter/buffer tensors, so
    a plain ``state_dict()`` taken before ``optimizer.step()`` silently follows the
    in-place update. Cloning is what makes the greedy soup's revert actually revert.
    """
    return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}


def _load_base_model(model_path):
    """Load the pickled model at ``model_path`` and move it to ``device``."""
    # NOTE(review): this unpickles a full nn.Module; PyTorch >= 2.6 defaults torch.load to
    # weights_only=True, which may reject such files — confirm the torch version in use.
    model = torch.load(model_path, map_location=device)
    model.to(device)
    return model


def _print_soup_lr(optimizer):
    """Print the current learning rate of every param group of ``optimizer``."""
    for param_group in optimizer.param_groups:
        print(f"Learning Rate for souping SGD: {param_group['lr']}")


def _fold_in(model, optimizer, batch):
    """Take one optimizer step on gradients of vanilla_loss accumulated over every item of ``batch``."""
    optimizer.zero_grad()
    for x in batch:
        loss = vanilla_loss(model, x)
        loss.backward()
    optimizer.step()


def _make_standard_soup(soup_dataset, model_path):
    """Build the standard soup: fold every ingredient of ``soup_dataset`` into the base model.

    Ingredients are visited in dataset order, one SGD step each, with the learning rate
    following ensemble_scheduler(p=1): 1, 1/2, 1/3, ...
    NOTE(review): this is a uniform running average of the ingredients only if
    vanilla_loss is 0.5 * ||theta - theta_k||^2 — confirm in central_data.wandb_ens_pickle.
    """
    model = _load_base_model(model_path)
    optimizer = torch.optim.SGD(model.parameters(), lr=1)
    scheduler = ensemble_scheduler(optimizer, p=1)
    for batch in DataLoader(soup_dataset, batch_size=1, shuffle=False):
        _print_soup_lr(optimizer)
        _fold_in(model, optimizer, batch)
        scheduler.step()
    return model


def _make_greedy_soup(soup_dataset, model_path):
    """Build the greedy soup: keep an ingredient only if validation accuracy does not drop.

    Starting from the base model's validation accuracy, each ingredient is folded in with
    one SGD step. If the new validation accuracy is >= the best seen so far, the step is
    kept and the learning-rate schedule advances; otherwise the weights are restored to
    the last accepted state and the learning rate is left unchanged.
    """
    model = _load_base_model(model_path)
    optimizer = torch.optim.SGD(model.parameters(), lr=1)
    scheduler = ensemble_scheduler(optimizer, p=1)

    _, best_val_accuracy, _ = _evaluate(model, val_loader)
    best_state = _snapshot_state(model)

    for batch in DataLoader(soup_dataset, batch_size=1, shuffle=False):
        _print_soup_lr(optimizer)
        _fold_in(model, optimizer, batch)

        # NOTE(review): selection uses val_loader, which is also what greedy_val_* reports.
        _, new_val_accuracy, _ = _evaluate(model, val_loader)
        if new_val_accuracy >= best_val_accuracy:
            best_val_accuracy = new_val_accuracy
            best_state = _snapshot_state(model)
            scheduler.step()  # advance the schedule only for accepted ingredients
            print("New model retained with accuracy: ", new_val_accuracy)
        else:
            # best_state is an independent clone, so this genuinely undoes the step.
            model.load_state_dict(best_state)
            print("Model reverted to previous best with accuracy: ", best_val_accuracy)
    return model


def _soup_metrics(model, prefix=''):
    """Evaluate ``model`` on every split of both datasets and return W&B metrics.

    Keys are ``{prefix}{source}{split}_loss`` / ``..._accuracy`` with source '' (original
    dataset) or 'eval_' (evaluation dataset) and split train/val/test, plus
    ``{prefix}{source}total_traindata_accuracy`` / ``..._loss``: the datapoint-weighted
    average over train + val + test.
    """
    per_source = {'': [], 'eval_': []}
    for split, loader, eval_loader in _SPLIT_LOADERS:
        for source, data_loader in (('', loader), ('eval_', eval_loader)):
            loss, accuracy, datapoints = _evaluate(model, data_loader)
            per_source[source].append((split, loss, accuracy, datapoints))

    metrics = {}
    for source, rows in per_source.items():
        key = f'{prefix}{source}'
        for split, loss, accuracy, _ in rows:
            metrics[f'{key}{split}_loss'] = loss
            metrics[f'{key}{split}_accuracy'] = accuracy
        total_datapoints = sum(n for _, _, _, n in rows)
        metrics[f'{key}total_traindata_accuracy'] = sum(acc * n for _, _, acc, n in rows) / total_datapoints
        metrics[f'{key}total_traindata_loss'] = sum(loss * n for _, loss, _, n in rows) / total_datapoints
    return metrics


########## Step 4: Loop over each configuration index and set corresponding args, start sweep
# W&B output settings are identical for every run, so configure them once.
os.makedirs('/directory/wandblog/', exist_ok=True)
os.environ['WANDB_DIR'] = '/directory/wandblog/'
os.environ["WANDB__SERVICE_WAIT"] = "300"

for num_ingredients in num_ingredients_list:
    wandb.init(
        project=f"{base_project_name}soups_{project_name}",
        name=f"Soup {num_ingredients}",
    )

    # Step 5: Ensemble
    for epoch_idx in range(len_eval_steps):
        print(f"Epoch index is {epoch_idx}")
        models_collection, latest_model_path = load_state_dicts(all_runs, project_name, num_ingredients, device, epoch_idx)

        # Indices count down from num_ingredients - 1 (original note: "Order needs to be
        # from 0"); how CustomDataset consumes them is defined in central_data.wandb_ens_pickle.
        soup_dataset = CustomDataset(models_collection, (num_ingredients - 1) - np.array(range(len(models_collection))))

        model_soup = _make_standard_soup(soup_dataset, latest_model_path)
        _release_cuda_cache()
        model_greedy_soup = _make_greedy_soup(soup_dataset, latest_model_path)
        _release_cuda_cache()

        metrics = _soup_metrics(model_soup)
        metrics.update(_soup_metrics(model_greedy_soup, prefix='greedy_'))
        wandb.log(metrics)

        # Clear memory
        del models_collection, soup_dataset, model_soup, model_greedy_soup
        _release_cuda_cache()
    wandb.finish()