import hydra
import omegaconf
from omegaconf import DictConfig, OmegaConf
from space_gen import get_bounds
import logging
import time
import os
from mlflow import log_params
import torch
import json
from bbo.datasets.base import SimpleDataset
from bbo.datasets.utils import X_transform, Y_transform
import numpy as np
from hpob_handler import HPOBHandler

# Absolute directory of this script. Data paths below are anchored here rather
# than at the current working directory.
this_dir = os.path.abspath(os.path.dirname(__file__))
# HPO-B meta-dataset root and its saved surrogates, passed to HPOBHandler in
# `train`. The trailing slashes are kept exactly as in the original layout.
data_path, surrogate_path = (
    os.path.join(this_dir, sub_dir)
    for sub_dir in ("lamcts/hpob-data/", "lamcts/saved-surrogates/")
)

def read_data_from_json(path):
    """Load a JSON file with top-level ``"X"`` and ``"y"`` keys into NumPy arrays.

    Args:
        path: Path to a JSON file whose object contains ``"X"`` and ``"y"``.

    Returns:
        Tuple ``(X, y)``. ``X`` is ``np.array(json_data["X"])``. ``y`` is
        ``np.array(json_data["y"])`` reshaped to a column vector of shape ``(n, 1)``.

    Raises:
        KeyError: if either ``"X"`` or ``"y"`` is missing from the file.
    """
    # JSON text is UTF-8 by spec (RFC 8259). Pin the encoding so parsing does
    # not depend on the platform/locale default.
    with open(path, 'r', encoding='utf-8') as f:
        json_data = json.load(f)
    X = np.array(json_data['X'])
    y = np.array(json_data['y']).reshape(-1, 1)
    return X, y

def _select_device(preferred_index=1):
    """Return ``cuda:<preferred_index>`` if that GPU exists, else ``cuda:0``, else ``"cpu"``."""
    if not torch.cuda.is_available():
        return "cpu"
    if torch.cuda.device_count() > preferred_index:
        return f"cuda:{preferred_index}"
    # Fewer GPUs than requested. A hard-coded "cuda:1" would fail with
    # "invalid device ordinal" on a single-GPU machine.
    return "cuda:0"


def _load_meta_data(mode):
    """Load the nested ``{search_space_id: {dataset_id: {"X", "y"}}}`` meta-training data for ``mode``."""
    if mode == "hpob":
        hpob_hdlr = HPOBHandler(root_dir=data_path, mode="v3", surrogates_dir=surrogate_path)
        return hpob_hdlr.meta_train_data
    with open(f"lamcts/{mode}/meta_dataset.json", "r", encoding="utf-8") as f:
        return json.load(f)


def _resolve_search_spaces(args, data):
    """Return the search-space ids whose meta-datasets are used for training.

    Unrecognised ``args.similar`` values (and ``mix-both`` / ``unsimilar`` in
    ``hpob`` mode) yield an empty list, i.e. no meta-datasets are loaded.
    """
    similar = args.similar
    if args.mode == "hpob":
        if similar == "similar":
            return [args.search_space_id]
        if similar == "combine":
            # HPO-B search spaces grouped by their dimensionality.
            dim_task = {
                2: ["5860", "5970"],
                3: ["4796"],
                6: ["5859", "5889"],
                8: ["5891"],
                9: ["7607", "7609"],
                16: ["5906", "5971"],
            }
            if args.dims not in dim_task:
                raise KeyError(f"No HPO-B search spaces registered for dims={args.dims}")
            return list(dim_task[args.dims])
        return []
    if similar in ("similar", "mix-both"):
        return [args.search_space_id]
    if similar == "unsimilar":
        search_spaces = list(data.keys())
        # Raises ValueError if the target search space is not in the meta-dataset.
        search_spaces.remove(args.search_space_id)
        return search_spaces
    if similar == "combine":
        return list(data.keys())
    return []


def _make_dataset(X, Y, lb, ub, device):
    """Normalise inputs/targets with the project transforms and wrap them in a SimpleDataset."""
    X = X_transform(X, lb, ub, device)
    Y, _mean, _std = Y_transform(Y, device=device)
    return SimpleDataset(X, Y)


def _load_meta_datasets(data, search_spaces, lb, ub, device):
    """Build ``{key: SimpleDataset}`` from every dataset in the selected search spaces."""
    id2dataset = {}
    # The same dataset id can occur under several search spaces. When more than
    # one space is loaded, prefix the key with the space id so entries do not
    # silently overwrite each other. This mirrors the "subdir+file" keys used
    # for generated data.
    prefix_keys = len(search_spaces) > 1
    for sid in search_spaces:
        for did, record in data[sid].items():
            X = torch.tensor(record["X"], device=device)
            Y = torch.tensor([record["y"]], device=device).reshape(-1, 1)
            key = f"{sid}+{did}" if prefix_keys else did
            id2dataset[key] = _make_dataset(X, Y, lb, ub, device)
    return id2dataset


def _load_generated_datasets(search_space_id, lb, ub, device):
    """Load every JSON file under ``data/generated_data/<id>/{similar,unsimilar}`` as a dataset."""
    id2dataset = {}
    mix_data_dir = f"data/generated_data/{search_space_id}"
    for subdir in ("similar", "unsimilar"):
        sub_path = os.path.join(mix_data_dir, subdir)
        for root, _dirs, files in os.walk(sub_path):
            # os.walk lists files in arbitrary order. Sort for reproducible runs.
            for file_name in sorted(files):
                X, y = read_data_from_json(os.path.join(root, file_name))
                X = torch.tensor(X, device=device)
                Y = torch.tensor(y, device=device).reshape(-1, 1)
                id2dataset[f"{subdir}+{file_name}"] = _make_dataset(X, Y, lb, ub, device)
    return id2dataset


def _print_dataset_info(state, id2dataset):
    """Print the number of samples in each dataset of ``id2dataset``."""
    print('------ {} Dataset info ------'.format(state))
    for k, dataset in id2dataset.items():
        print('{}: {}'.format(k, len(dataset)))
    print('--------------------------')


def _save_checkpoint(model, model_path, best_val):
    """Save ``model.state_dict()`` to ``model_path`` (``.pth`` appended if missing)."""
    save_path = model_path if model_path.endswith('.pth') else model_path + '.pth'
    save_dir = os.path.dirname(save_path)
    if save_dir:
        # makedirs also handles nested directories and a pre-existing dir.
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), save_path)
    print('Save best model, val: {}'.format(best_val))


def train(args):
    """Pre-train the transformer optimiser on meta-datasets and save the best checkpoint.

    Args:
        args: Namespace-like object. This function reads ``args.mode``
            (``"hpob"`` or the name of a ``lamcts/<mode>/meta_dataset.json``
            directory), ``args.similar`` (``"similar"``, ``"unsimilar"``,
            ``"combine"`` or ``"mix-both"``), ``args.search_space_id`` and
            ``args.dims``. ``args`` is also passed to ``get_bounds``.

    Side effects:
        Prints per-dataset sizes and writes the best model's ``state_dict`` to
        ``checkpoints/<similar>-<search_space_id>.pth``.
    """
    device = _select_device()
    model_path = f"checkpoints/{args.similar}-{args.search_space_id}.pth"

    dim = args.dims
    lb, ub = get_bounds(args)
    lb = torch.tensor(lb, device=device)
    ub = torch.tensor(ub, device=device)

    algo_cfg = OmegaConf.load("configs/algorithms/np/transformer_opt.yaml")
    algo = hydra.utils.instantiate(algo_cfg, dim=dim, lb=lb, ub=ub, device=device)

    data = _load_meta_data(args.mode)
    search_spaces = _resolve_search_spaces(args, data)
    train_id2dataset = _load_meta_datasets(data, search_spaces, lb, ub, device)
    if args.similar == "mix-both":
        # Generated keys are "similar+<file>" / "unsimilar+<file>", so they
        # cannot collide with the meta-dataset keys.
        train_id2dataset.update(
            _load_generated_datasets(args.search_space_id, lb, ub, device)
        )
    # The validation split is currently disabled. All data goes to training.
    validation_id2dataset = {}

    for state, id2dataset in zip(['Train', 'Val'], [train_id2dataset, validation_id2dataset]):
        _print_dataset_info(state, id2dataset)

    _model, _val, best_model, best_val = algo.train(train_id2dataset, None)
    _save_checkpoint(best_model, model_path, best_val)