import os
import os.path as osp
import sys
import time
import argparse
import importlib

import tqdm
import yaml

import ipdb
import numpy as np

import torch

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from attrdict import AttrDict

from npf.utils.launch import launch
from npf.utils.paths import results_path

from nxcl.config import save_config
from bayeso import acquisition
import logging

logger = logging.getLogger()

def range_normalize(values, min_values, max_values, log=False):
    """Rescale ``values`` so that ``min_values`` maps to -2 and ``max_values`` to 2.

    Args:
        values: Scalar or array to rescale.
        min_values: Lower bound(s), broadcast against ``values``.
        max_values: Upper bound(s), broadcast against ``values``.
        log: When True, ``values`` and both bounds go through ``np.log`` first,
            so the rescaling is done in log space (inputs must then be positive).

    Returns:
        ``-2 + 4 * (values - min_values) / (max_values - min_values)``, computed
        on the (optionally log-transformed) inputs.
    """
    vals, lo, hi = values, min_values, max_values
    if log:
        vals, lo, hi = np.log(vals), np.log(lo), np.log(hi)
    return -2 + (vals - lo) * 4 / (hi - lo)

def range_denormalize(values, min_values, max_values, log=False):
    """Invert ``range_normalize``: map values in ``[-2, 2]`` back to ``[min_values, max_values]``.

    Args:
        values: Normalized scalar or array.
        min_values: The same raw lower bound(s) that were passed to
            ``range_normalize``; broadcast against ``values``.
        max_values: The same raw upper bound(s) that were passed to
            ``range_normalize``; broadcast against ``values``.
        log: Must match the flag used when normalizing. When True, the affine
            map is undone in log space (using log-bounds) and the result is then
            exponentiated, which is the exact inverse of the log-space forward
            transform. Bounds must then be positive.

    Returns:
        The values on the original scale.
    """
    if log:
        # range_normalize(log=True) rescaled log(values) between log(min) and
        # log(max). Undo that affine step in log space first, then exponentiate.
        # Exponentiating before the affine step (as was done previously) does
        # not invert the forward transform.
        log_lo = np.log(min_values)
        log_hi = np.log(max_values)
        return np.exp(log_lo + (values + 2) * (log_hi - log_lo) / 4)
    return min_values + (values + 2) * (max_values - min_values) / 4

def load_checkpoint(model, checkpoint, logger, args):
    """Load every compatible tensor from ``checkpoint`` into ``model`` in place.

    The checkpoint file must contain a dict with a ``"model"`` entry holding a
    state dict. Entries whose key is not in ``model.state_dict()``, or whose
    shape differs from the model's tensor, are logged and skipped. The rest are
    loaded with ``strict=False``, so model entries absent from the checkpoint
    keep their current values.

    Args:
        model: ``torch.nn.Module`` to populate.
        checkpoint: Path accepted by ``torch.load``.
        logger: Logger used to report skipped entries.
        args: Unused; kept for call-site compatibility.

    Returns:
        The same ``model`` object.
    """
    # Load onto the CPU. load_state_dict copies each tensor into the model's
    # existing parameters/buffers on whatever device they live, so this also
    # works on machines without CUDA.
    model_ckpt = torch.load(checkpoint, map_location="cpu")["model"]
    state_dict = model.state_dict()

    for key in list(model_ckpt.keys()):
        if key not in state_dict:
            # Indexing state_dict[key] here would raise KeyError.
            logger.info(f"- \"{key}\" not loaded (not present in model)")
            model_ckpt.pop(key)
        elif state_dict[key].shape != model_ckpt[key].shape:
            logger.info(f"- \"{key}\" not loaded (shape not compatible)")
            model_ckpt.pop(key)

    model.load_state_dict(model_ckpt, strict=False)
    return model

def setup_argparse(parser: argparse.ArgumentParser):
    """Register all command-line options of this script on ``parser``.

    Options are declared as ``(flags, keyword-arguments)`` pairs and added in
    order, so the generated ``--help`` listing follows the table below.
    """
    option_table = [
        # Context sizes, checkpoints and plotting.
        (('--max_num_points',), {'type': int, 'default': 50}),
        (('--min_num_points',), {'type': int, 'default': 5}),
        (('-c', '--checkpoint'), {'type': str, 'required': True}),
        (('-pc', '--plot_checkpoint'), {'type': str}),
        (('--plot_dir',), {'type': str, 'default': './plots'}),
        # Training / sampling hyperparameters.
        (('--train_seed',), {'type': int, 'default': 0}),
        (('--train_batch_size',), {'type': int, 'default': 16}),
        (('--train_num_samples',), {'type': int, 'default': 4}),
        (('--train_num_bs',), {'type': int, 'default': -1}),
        (('--lr',), {'type': float, 'default': 5e-4}),
        (('--num_steps',), {'type': int, 'default': 100000}),
        (('--print_freq',), {'type': int, 'default': 200}),
        (('--eval_freq',), {'type': int, 'default': 5000}),
        (('--save_freq',), {'type': int, 'default': 1000}),
        (('--bo_num_samples',), {'type': int, 'default': 10}),
        # Run mode.
        (('--mode',), {'choices': ['bo', 'plot'], 'default': 'bo'}),
        (('--time_comparison',), {'action': 'store_true', 'default': False}),
        # Bayesian-optimisation loop.
        (('--acquisition',), {'choices': ['ucb', 'ei'], 'default': 'ucb'}),
        (('--num_task',), {'type': int, 'default': 10}),
        (('--num_iter',), {'type': int, 'default': 200}),
        (('--num_initial_design',), {'type': int, 'default': 10}),
        (('--seed',), {'type': int, 'default': 1}),
        # Evaluation.
        (('--eval_num_batches',), {'type': int, 'default': 3000}),
        (('--eval_batch_size',), {'type': int, 'default': 16}),
        (('--eval_num_samples',), {'type': int, 'default': 50}),
        # Kernels, noise and BO seeding.
        (('--train_kernel',), {'type': str, 'default': 'rbf'}),
        (('--eval_kernel',), {'type': str, 'default': 'rbf'}),
        (('--t_noise',), {'type': float, 'default': None}),
        (('--bo_seed',), {'type': int, 'default': 0}),
        (('--num_init',), {'type': int, 'default': 20}),
    ]
    for flags, options in option_table:
        parser.add_argument(*flags, **options)

def build_model(cfg):
    """Instantiate the model described by ``cfg.model`` and move it to CUDA.

    ``cfg.model.name`` selects the module ``npf.models.<name>``. That module's
    class named ``<name>.upper()`` is constructed with every other
    ``cfg.model`` entry as a keyword argument.

    Args:
        cfg: Config whose ``model`` entry has a ``name`` attribute and an
            ``items()`` method.

    Returns:
        The constructed model, after ``.cuda()``.

    Raises:
        ValueError: if the module cannot be imported or lacks the expected
            class. The original ImportError/AttributeError is chained as
            ``__cause__``.
    """
    model_name = cfg.model.name
    try:
        module = importlib.import_module(f"npf.models.{model_name}")
        model_cls = getattr(module, model_name.upper())
    except (ImportError, AttributeError) as e:
        # Chain the underlying error so that, e.g., a missing dependency inside
        # the model module is visible instead of only "Invalid model".
        raise ValueError(f'Invalid model {model_name}') from e

    print(cfg)
    model = model_cls(**{k: v for k, v in cfg.model.items() if k != 'name'})
    model.cuda()
    return model

def load_data():
    """Load ``results_cnn_0.npy`` .. ``results_cnn_3.npy`` and normalize them.

    Each file is read from the current working directory and must unwrap to a
    dict with list-like entries ``'x'`` (one input vector per evaluation) and
    ``'y'`` (one objective value per evaluation). ``y`` is negated so that lower
    is better. Both arrays are then rescaled to [-2, 2] with
    ``range_normalize``: ``x`` per column in log space, ``y`` linearly.

    Returns:
        ``(x, y, x_min, y_min, x_max, y_max)``. ``x`` and ``y`` are normalized.
        The bounds are those of the raw, pre-normalization data, which are what
        ``range_denormalize`` needs to map values back. Previously the bounds of
        the normalized arrays (always -2 / 2) were returned, making every
        denormalization a no-op.
    """
    import ast

    xs, ys = [], []
    for file_idx in range(4):
        # SECURITY: allow_pickle=True unpickles arbitrary objects; only load
        # trusted result files.
        payload = np.load(f'results_cnn_{file_idx}.npy', allow_pickle=True).item()
        if isinstance(payload, str):
            # The previous str() -> eval() round trip also accepted a dict
            # stored as its text repr; keep that working without eval.
            payload = ast.literal_eval(payload)
        xs.extend(payload['x'])
        ys.extend(-value for value in payload['y'])

    x = np.array(xs)
    y = np.array(ys)

    # Capture the raw bounds before normalizing.
    x_min, x_max = np.min(x, axis=0), np.max(x, axis=0)
    y_min, y_max = np.min(y), np.max(y)

    x = range_normalize(x, x_min, x_max, log=True)
    y = range_normalize(y, y_min, y_max)

    return x, y, x_min, y_min, x_max, y_max

def _predict_mean_std(model, model_name, xc, yc, xt, num_samples):
    """Return ``model``'s predictive mean and std at ``xt`` as numpy arrays.

    ``tnpd`` and ``canp`` are queried without ``num_samples``; every other model
    is queried with ``num_samples=num_samples``. Dim 0 of the returned
    mean/scale is squeezed (when it has size 1). If the result is still 4-D, it
    has a leading sample axis, and the samples are pooled as an equal-weight
    mixture: mean of the means and variance ``E[sigma^2] + E[mu^2] - E[mu]^2``.
    """
    with torch.no_grad():
        if model_name in ['tnpd', 'canp']:
            py = model.predict(xc=xc, yc=yc, xt=xt)
        else:
            py = model.predict(xc=xc, yc=yc, xt=xt, num_samples=num_samples)
        mu, sigma = py.mean.squeeze(0), py.scale.squeeze(0)

    if mu.dim() == 4:
        var = sigma.pow(2).mean(dim=0) + mu.pow(2).mean(dim=0) - mu.mean(dim=0).pow(2)
        sigma = var.sqrt().squeeze(0)
        mu = mu.mean(dim=0).squeeze(0)
    return mu.cpu().numpy(), sigma.cpu().numpy()


def bayesian_optimization(args, cfg, x, y, X_MIN, Y_MIN, X_MAX, Y_MAX):
    """Run pool-based Bayesian optimisation over ``x``/``y`` with a neural-process surrogate.

    For each of ``args.num_task`` tasks:

    1. Draw ``args.num_init`` pool points as the initial context.
    2. For ``args.num_iter`` iterations, predict over the whole pool, pick the
       point with the largest ``acquisition.ei`` score, and append its recorded
       ``y`` to the context.

    Lower ``y`` is better.

    The per-task regret curve is the denormalized best-so-far value minus the
    denormalized pool minimum, for 0..num_iter acquisitions. It is saved as
    ``cnn_bo_results_<model>_<bo_seed>.npy`` in the current working directory.

    Args:
        args: Parsed command-line options. ``args.model`` is set to the
            lower-cased model class name.
        cfg: Config passed to ``build_model``.
        x: Normalized pool inputs, shape ``(N, x_dim)``.
        y: Normalized pool objectives, shape ``(N,)``.
        X_MIN, Y_MIN, X_MAX, Y_MAX: Bounds passed to ``range_denormalize`` to
            map ``x`` (log scale) and ``y`` back for reporting and regret.
    """
    torch.manual_seed(args.train_seed)
    torch.cuda.manual_seed(args.train_seed)
    np.random.seed(args.bo_seed)

    model = build_model(cfg)
    model_name = model.__class__.__name__.lower()
    args.model = model_name

    logger.info(f"Load checkpoint from {args.checkpoint}")
    model = load_checkpoint(model, args.checkpoint, logger, args)
    model.eval()  # inference only; nothing below switches back to train mode

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    root = osp.join(results_path, 'cnn_bo')
    os.makedirs(root, exist_ok=True)

    regrets = np.zeros((args.num_task, args.num_iter + 1))
    # Elapsed seconds per iteration; collected but not currently saved or returned.
    times = np.zeros((args.num_task, args.num_iter + 1))

    # The candidate pool and its optimum are the same for every task.
    X_test = torch.tensor(x, dtype=torch.float32, device=device)
    true_best_value = range_denormalize(np.min(y), Y_MIN, Y_MAX, log=False)

    for i_seed in tqdm.tqdm(range(1, args.num_task + 1), unit='task', ascii=True):
        # NOTE(review): with the default --bo_seed 0 every task gets torch seed 0.
        # Initial designs still differ because np.random is seeded only once above.
        seed_ = args.bo_seed * i_seed
        torch.manual_seed(seed_)
        torch.cuda.manual_seed(seed_)

        indices = np.random.choice(len(x), args.num_init, replace=False)
        X_train = torch.tensor(x[indices], dtype=torch.float32, device=device)
        Y_train = torch.tensor(y[indices], dtype=torch.float32, device=device).unsqueeze(-1)

        # Denormalized best-so-far objective after 0, 1, ..., num_iter acquisitions.
        best_values = [range_denormalize(Y_train.min().item(), Y_MIN, Y_MAX, log=False)]
        time_list = [0]
        start_time = time.time()

        for _ in tqdm.tqdm(range(args.num_iter)):
            mu, sigma = _predict_mean_std(
                model, model_name,
                X_train.unsqueeze(0), Y_train.unsqueeze(0), X_test.unsqueeze(0),
                args.bo_num_samples,
            )

            # NOTE: args.acquisition is not consulted; EI is always used. The
            # score is negated so that argmin selects the point with the largest EI.
            acq_vals = -1.0 * acquisition.ei(np.ravel(mu), np.ravel(sigma), Y_train.cpu().numpy())
            ind_ = np.argmin(acq_vals)

            x_new = X_test[ind_].unsqueeze(0)  # (1, x_dim)
            y_new = torch.tensor([y[ind_]], dtype=torch.float32, device=device)

            # Report the incumbent (best context point before this acquisition).
            best_x = range_denormalize(
                X_train[np.argmin(Y_train.cpu().numpy())].cpu().numpy(), X_MIN, X_MAX, log=True)
            best_y = range_denormalize(Y_train.min().cpu().numpy(), Y_MIN, Y_MAX, log=False)
            print(f"Best value: {best_y} at {best_x}")

            X_train = torch.cat((X_train, x_new))
            Y_train = torch.cat((Y_train, y_new.unsqueeze(0)))

            best_values.append(range_denormalize(Y_train.min().cpu().numpy(), Y_MIN, Y_MAX, log=False))
            time_list.append(time.time() - start_time)

        # Regret of the best-so-far value w.r.t. the pool optimum. Every entry,
        # including the initial one, uses the same sign convention.
        regrets[i_seed - 1, :] = np.array(best_values) - true_best_value
        times[i_seed - 1, :] = np.array(time_list)

    np.save(f'cnn_bo_results_{model_name}_{args.bo_seed}.npy', regrets)

    print(f"Best value: {Y_train.min().item()} at {X_train[Y_train.argmin()].cpu().numpy()}")

def normalize(values):
    """Min-max scale ``values`` to [0, 1].

    Args:
        values: Array-like of numbers (non-empty).

    Returns:
        A float array in which the minimum maps to 0 and the maximum to 1. When
        all values are equal (zero range), an all-zeros array is returned
        instead of the NaNs a 0/0 division would produce.

    Raises:
        ValueError: if ``values`` is empty (raised by ``np.min``).
    """
    values = np.asarray(values, dtype=float)
    min_val = np.min(values)
    span = np.max(values) - min_val
    if span == 0:
        return np.zeros_like(values)
    return (values - min_val) / span

def _style_axis(ax):
    """Give ``ax`` a white background, a dotted gray grid and thin black spines."""
    ax.set_facecolor('white')
    ax.grid(ls=':', color='gray', linewidth=0.5)
    for side in ('bottom', 'top', 'right', 'left'):
        ax.spines[side].set_color('black')
        ax.spines[side].set_linewidth(0.8)


def plot():
    """Plot normalized simple and cumulative regret curves for several NP models.

    For each model, load ``cnn_bo_results_<model>_<seed>.npy`` from the current
    working directory, with seed 0 for ``danp`` and 3 otherwise. The loaded
    regrets are averaged over tasks (axis 0). The mean curve and the cumulative
    mean curve are each min-max normalized and drawn with a band of ±0.1×
    (normalized) std. The figure is saved to ``gp_plot.pdf``.

    NOTE(review): ``plt`` (matplotlib.pyplot) is not imported in this file, so
    this function raises NameError until ``import matplotlib.pyplot as plt`` is
    added at the top of the file.
    """
    all_kernels = ['rbf']
    all_models = ["anp", "banp", "canp", "mpanp", "tnpd", "danp"]
    model_names = ["ANP", "BANP", "CANP", "MPANP", "TNP", "DANP"]
    colors = ['#962FbF', '#00B159', '#FFCC5C', '#3366FF', '#FF4D00', '#D62976']

    # squeeze=False keeps `axes` 2-D with shape (1, 2 * len(all_kernels)) for
    # any number of kernels, so axes[0, j] indexing is always valid.
    fig, axes = plt.subplots(1, len(all_kernels) * 2, figsize=(12, 6), squeeze=False)

    for k_id, _kernel in enumerate(all_kernels):
        ax_mean = axes[0, k_id * 2]
        ax_cumulative = axes[0, k_id * 2 + 1]
        for i, (model, model_name) in enumerate(zip(all_models, model_names)):
            seed = 0 if model == 'danp' else 3
            # Same naming scheme as the np.save call in bayesian_optimization.
            logfile = f'cnn_bo_results_{model}_{seed}.npy'
            regrets = np.load(logfile, allow_pickle=True)
            regrets = np.stack(list(regrets), axis=0)

            mean_regret = np.mean(regrets, axis=0)
            std_regret = np.std(regrets, axis=0)
            cumulative_regret = np.cumsum(mean_regret)
            cumulative_std_regret = np.sqrt(np.cumsum(std_regret**2))
            steps = np.arange(regrets.shape[1])

            normalized_mean_regret = normalize(mean_regret)
            normalized_cumulative_regret = normalize(cumulative_regret)
            normalized_std_regret = normalize(std_regret)
            normalized_cumulative_std_regret = normalize(cumulative_std_regret)

            ax_mean.plot(steps, normalized_mean_regret, label=model_name, color=colors[i], lw=2.0)
            ax_mean.fill_between(
                steps,
                normalized_mean_regret - 0.1 * normalized_std_regret,
                normalized_mean_regret + 0.1 * normalized_std_regret,
                alpha=0.1,
                color=colors[i])

            ax_cumulative.plot(steps, normalized_cumulative_regret, label=model_name, color=colors[i], lw=2.0)
            ax_cumulative.fill_between(
                steps,
                normalized_cumulative_regret - 0.1 * normalized_cumulative_std_regret,
                normalized_cumulative_regret + 0.1 * normalized_cumulative_std_regret,
                alpha=0.1,
                color=colors[i])

        # Styling is identical for every model, so apply it once per axis.
        _style_axis(ax_mean)
        _style_axis(ax_cumulative)

        ax_mean.set_xlabel('Iterations', fontsize=20)
        ax_cumulative.set_xlabel('Iterations', fontsize=20)

        min_step, max_step = steps.min(), steps.max()
        ax_mean.set_xlim(min_step, max_step)
        ax_cumulative.set_xlim(min_step, max_step)

    axes[0, 0].set_ylabel('Normalized Regret', fontsize=20)
    axes[0, 1].set_ylabel('Normalized Cumulative Regret', fontsize=20)

    plt.subplots_adjust(bottom=0.24, wspace=0.2)
    plt.suptitle('CNN BO', fontsize=24, fontweight='bold')

    # De-duplicate legend entries by label.
    handles, labels = ax_mean.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    fig.legend(
        by_label.values(), by_label.keys(),
        loc="lower center", fancybox=True, shadow=True, ncol=6, fontsize=16, facecolor='white'
    )
    save_dir = 'gp_plot.pdf'
    plt.savefig(save_dir, dpi=500, bbox_inches='tight', format='pdf')

def train(args, cfg, logger, save_dir, link_output_dir):
    """Entry point passed to ``launch``: record the run, then dispatch on ``args.mode``.

    Steps:

    1. Save ``cfg`` to ``save_dir / "config.yaml"`` and seed torch with
       ``args.train_seed``.
    2. Link an output sub-directory named after this script and log the
       command line.
    3. For ``--mode bo``, load the CNN results and run
       ``bayesian_optimization``.

    Args:
        args: Parsed command-line options; ``args.root`` is set to ``save_dir``.
        cfg: Run configuration (saved and passed on to model construction).
        logger: Logger for run metadata.
        save_dir: Path-like output directory (supports ``/``).
        link_output_dir: Callable mapping a sub-path to the linked experiment path.

    Raises:
        NotImplementedError: for any mode other than ``'bo'``.
    """
    save_config(cfg, save_dir / "config.yaml")

    torch.manual_seed(args.train_seed)
    torch.cuda.manual_seed(args.train_seed)

    # Script name without directory or extension. Splitting the full path on
    # "." would cut at the first dot anywhere in it (e.g. "/work/proj.v2/...").
    script_name = osp.splitext(osp.basename(__file__))[0]
    exp_sub_path = osp.join(script_name, 'cnn_bo')
    exp_path = link_output_dir(exp_sub_path)
    logger.info(f"Experiment path: \"{exp_path}\"")
    line = ' '.join(sys.argv)
    logger.info(f"code: {line}")

    args.root = save_dir

    if args.mode != 'bo':
        raise NotImplementedError(f"mode {args.mode!r} is not supported")

    x, y, X_MIN, Y_MIN, X_MAX, Y_MAX = load_data()
    # y is normalized; report its minimum on the original (negated) scale.
    print(f'True best: {range_denormalize(np.min(y), Y_MIN, Y_MAX)}')
    bayesian_optimization(args, cfg, x, y, X_MIN, Y_MIN, X_MAX, Y_MAX)
    logger.info("end")

if __name__ == "__main__":
    code = launch(
        train,
        setup_argparse,
        aliases={},
    )
    # Use sys.exit: the bare exit() builtin is installed by the `site` module
    # and is not available when Python runs with -S.
    sys.exit(code)