import torch
import os
import os.path as osp
import argparse
import time
import json

import matplotlib.pyplot as plt
plt.rcParams['font.family'] = 'DejaVu Sans'

from botorch.acquisition.monte_carlo import qExpectedImprovement
from botorch.acquisition import UpperConfidenceBound

from bo import get_fitted_model, optimize_acqf_and_get_observation
from utils import device, results_path, get_logger

import warnings

# Ignore all warnings
warnings.filterwarnings("ignore")
# torch.autograd.set_detect_anomaly(True)

def _load_vae_components(model_name):
    """Import and return ``(_vae, vae_retrain)`` from ``models.<model_name>``.

    Raises:
        ValueError: if ``model_name`` is not one of the supported VAE variants.
    """
    import importlib

    supported = ('vae_no', 'vae_base', 'vae_hippo', 'vae_gp', 'vae_metric', 'vae_reweigh')
    if model_name not in supported:
        raise ValueError(f"Unknown model {model_name!r}; expected one of {supported}")
    module = importlib.import_module(f"models.{model_name}")
    return module._vae, module.vae_retrain


def _setup_task(args, vae_model):
    """Create the initial dataset and the black-box objective for ``args.task``.

    Returns:
        ``(train_x, train_obj, objective)``. ``train_x`` and ``train_obj`` are
        moved to ``device``. ``objective(new_x)`` scores decoded candidates and
        returns a tensor that the caller concatenates onto ``train_obj``.

    Raises:
        ValueError: if ``args.task`` is neither ``'mnist'`` nor ``'ackley'``.
    """
    if args.task == 'mnist':
        from data.mnist import generate_initial_data, score_func, init_model

        score_model = init_model()
        train_x, train_obj = generate_initial_data(
            vae_model, score_model, args.num_initial_data, args.latent_dim)

        def objective(new_x):
            # A trailing dim is added before the result is concatenated onto train_obj.
            return score_func(score_model, new_x).unsqueeze(-1)

    elif args.task == 'ackley':
        from botorch.test_functions import Ackley
        from data.ackley import generate_initial_data, score_func

        ackley = Ackley(dim=args.input_dim, noise_std=0.0)
        train_x, train_obj = generate_initial_data(ackley, args.num_initial_data, args.input_dim)

        def objective(new_x):
            return score_func(ackley, new_x)

    else:
        raise ValueError(f"Unknown task {args.task!r}; expected 'mnist' or 'ackley'")

    return train_x.to(device), train_obj.to(device), objective


def _build_acquisition(acq_name, gp_model, train_obj):
    """Return the acquisition function named ``acq_name`` ('ei' or 'ucb') for ``gp_model``.

    Raises:
        ValueError: for any other ``acq_name``.
    """
    if acq_name == "ei":
        return qExpectedImprovement(model=gp_model, best_f=train_obj.max())
    if acq_name == "ucb":
        return UpperConfidenceBound(gp_model, beta=0.1)
    raise ValueError(f"Unknown acquisition function {acq_name!r}; expected 'ei' or 'ucb'")


def training(args):
    """Run latent-space Bayesian optimisation with periodic VAE retraining.

    Each of the ``args.num_epoch`` outer epochs does the following:

    * retrain the VAE on all data collected so far;
    * re-encode the inputs into latent codes;
    * run ``args.n_batch`` BO steps in that latent space. Each step fits a GP,
      optimises the acquisition function, decodes the candidate and scores it.

    Side effects:
        * Creates ``results_path/<task>/<model>/<train_seed>/`` if it is missing.
        * Writes a ``train_<YYYYmmdd-HHMM>.log`` log file into that directory.
        * Writes ``results_<YYYYmmdd>.txt``: the initial best value, then the
          running best after every BO step.
        * Writes ``best_observed_list.txt``: the running best after every BO
          step only.
        * Sets ``args.results_file`` to the results-file path, which
          ``plotting`` reads afterwards.

    Raises:
        ValueError: for an unknown ``args.model``, ``args.task`` or
            ``args.acq_func``. These are checked before the result files are
            truncated.

    Returns:
        None
    """
    torch.manual_seed(args.train_seed)

    logfolder = os.path.join(results_path, args.task, args.model, str(args.train_seed))
    if os.path.exists(logfolder):
        print(f"Directory '{logfolder}' already exists.")
    else:
        os.makedirs(logfolder, exist_ok=True)  # also creates missing parent directories
        print(f"Directory '{logfolder}' created.")

    best_observed_path = os.path.join(logfolder, 'best_observed_list.txt')
    logfilename = os.path.join(logfolder, f'train_{time.strftime("%Y%m%d-%H%M")}.log')
    resultfilename = os.path.join(logfolder, f'results_{time.strftime("%Y%m%d")}.txt')

    logger = get_logger(logfilename)
    logger.info(f"Experiment: {args.model}-{args.task}")

    args.results_file = resultfilename

    # Fail fast on configuration errors: before expensive VAE training, and
    # before existing result files are truncated.
    if args.acq_func not in ("ei", "ucb"):
        raise ValueError(f"Unknown acquisition function {args.acq_func!r}; expected 'ei' or 'ucb'")
    _vae, vae_retrain = _load_vae_components(args.model)

    vae_model = _vae(args.input_dim, args.hidden_dim, args.latent_dim)
    vae_model.to(device)
    vae_model.eval()

    best_observed = []

    # `with` guarantees both files are closed, even if training raises.
    with open(resultfilename, 'w') as results_file, open(best_observed_path, 'w') as best_file:
        with torch.no_grad():
            train_x, train_obj, objective = _setup_task(args, vae_model)
            best_value = train_obj.max().item()
        best_observed.append(best_value)
        results_file.write(f"{best_value}\n")
        print("The best observed value so far: ", best_value)

        global_iteration = 0
        # The GP state from the previous fit is passed back into get_fitted_model,
        # including across VAE epochs.
        state_dict = None
        for epoch in range(args.num_epoch):
            logger.info(f"VAE retrain {epoch}.......")
            vae_model = vae_retrain(vae_model, train_x.detach(), train_obj.detach(), logger, args)
            with torch.no_grad():
                train_z = vae_model.encoding(train_x)
            train_z = train_z.to(device)

            logger.info("\nRunning BO ....")
            for iteration in range(args.n_batch):
                gp_model = get_fitted_model(
                    train_x=train_z,
                    train_obj=train_obj,
                    d=args.latent_dim,
                    state_dict=state_dict,
                )
                acq_func = _build_acquisition(args.acq_func, gp_model, train_obj)

                new_z = optimize_acqf_and_get_observation(
                    acq_func, args.latent_dim, args.batch_size, args.num_restarts, args.raw_samples)
                new_x = vae_model.decode(new_z)
                new_obj = objective(new_x)

                train_z = torch.cat((train_z, new_z)).detach()
                train_x = torch.cat((train_x, new_x)).detach()
                train_obj = torch.cat((train_obj, new_obj)).detach()

                # .max() rather than .item() so batch_size > 1 also works.
                best_value = max(best_value, new_obj.max().item())
                best_observed.append(best_value)
                best_file.write(f"{best_value}\n")

                state_dict = gp_model.state_dict()

                global_iteration += 1
                # A one-element tensor logs as a plain float, a batch logs as a list.
                new_obj_repr = new_obj.squeeze().tolist()
                logger.info(f"global_iteration: {global_iteration}"
                            f", BO iteration: {iteration}"
                            f", new_obj = {new_obj_repr}"
                            f", best_value = {best_value}")

                results_file.write(f"{best_value}\n")

    return None

def plotting(args):
    """Plot the best-observed-value trace from ``args.results_file`` and save it as a PNG.

    The results file holds one float per line. Blank lines are ignored. Values
    are plotted against a 1-based step index. ``training`` writes the initial
    best value as the first line, before any BO step.

    The figure is saved to
    ``results_path/<task>/<model>/<train_seed>/results_<YYYYmmdd-HHMM>.png``.
    A dedicated figure is created and closed afterwards, so repeated calls
    neither overlay each other's lines nor leak pyplot figures.

    Returns:
        None
    """
    with open(args.results_file, 'r') as file:
        y = [float(line) for line in file if line.strip()]
    x = list(range(1, len(y) + 1))

    fig, ax = plt.subplots()
    try:
        ax.plot(x, y)
        ax.set_xlabel('BO step')
        ax.set_ylabel('best observed value')
        ax.set_title(args.model)
        fig.savefig(os.path.join(results_path, args.task, args.model, str(args.train_seed),
                                 f'results_{time.strftime("%Y%m%d-%H%M")}.png'))
    finally:
        plt.close(fig)

    return None

# Script entry point: parse CLI args, merge the task/model JSON config, then train and plot.
if __name__ == '__main__':

    # read args
    parser = argparse.ArgumentParser()

    # Experiment
    parser.add_argument('--task', type=str, default='mnist', choices=['mnist', 'ackley'])
    # NOTE: --resume is parsed but is not read by training() or plotting().
    parser.add_argument('--resume', type=str, default=None)

    # Model
    parser.add_argument('--model', type=str, default='vae_gp', choices=['vae_no', 'vae_base', 'vae_metric', 'vae_gp', 'vae_reweigh', 'vae_hippo'])

    # Train
    parser.add_argument('--train_seed', type=int, default=126)
    parser.add_argument('--num_epoch', type=int, default=100)
    parser.add_argument('--n_batch', type=int, default=25)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--num_restarts', type=int, default=5)
    parser.add_argument('--raw_samples', type=int, default=20)
    parser.add_argument('--num_vae_epoch', type=int, default=10)
    parser.add_argument('--vae_batch_size', type=int, default=10)
    parser.add_argument('--acq_func', type=str, default='ucb', choices=['ei', 'ucb'])
    parser.add_argument('--num_initial_data', type=int, default=10)

    args = parser.parse_args()

    # Task/model-specific config. The JSON must supply input_dim, hidden_dim and
    # latent_dim: training() reads them, but they are not CLI flags. Any key in
    # the JSON overrides a CLI value with the same name.
    with open(f'configs/{args.task}/{args.model}.json', 'r') as f:
        config = json.load(f)
    for key, value in config.items():
        setattr(args, key, value)

    if args.task is not None:
        args.root = osp.join(results_path, args.task, args.model)
    else:
        args.root = osp.join(results_path, 'default', args.model)

    # training() only creates the per-seed sub-directory later. Create the root
    # now, or writing config.json fails with FileNotFoundError on a first run.
    os.makedirs(args.root, exist_ok=True)

    # Snapshot the merged arguments next to the results.
    config_file = osp.join(args.root, "config.json")
    with open(config_file, 'w') as json_file:
        json.dump(vars(args), json_file, indent=4)

    # training
    training(args)

    # plot
    plotting(args)

