import os
import random
import argparse
import numpy as np
import json
from tqdm import tqdm, trange
from matplotlib import pyplot as plt
import wandb
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
import gpytorch

from surrogate_models.dyhpo import FeatureExtractor, GPRegressionModel
from utils import Logger
from data.data_utils import get_dataset, HPOPlainSampler

# Value passed to the feature extractor as `nr_initial_features`, per benchmark.
_DIM_X_BY_BENCHMARK = {
    "lcbench": 7,
    "odbench": 4,
    "taskset": 8,
    "pd1": 4,
}


def main(args):
    """Pretrain the DyHPO feature extractor and GP model on a meta-train benchmark.

    Each step draws one batch from the meta-train sampler, projects it through
    the feature extractor, installs the projection as the GP training data and
    minimises the negative exact marginal log-likelihood with Adam under a
    cosine-annealed learning rate. Checkpointing and metric reporting are
    delegated to ``Logger``.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``benchmark_name``, ``gpu_id``, ``debug``, ``seed``, ``data_dir``,
            ``batch_size``, ``lr``, ``wd``, ``grad_norm``, ``iteration``,
            ``print_every``, ``save_every``, ``exp_name`` and ``save_dir``.
            The namespace is mutated in place: debug mode overrides
            ``iteration``, ``print_every`` and ``eval_every`` with 1, and a
            ``None`` seed is replaced by a randomly drawn one.

    Raises:
        ValueError: If ``args.benchmark_name`` is not a supported benchmark.
    """
    # Fail fast, before touching the GPU or loading data, instead of hitting a
    # NameError on an unassigned `dim_x` further down.
    try:
        dim_x = _DIM_X_BY_BENCHMARK[args.benchmark_name]
    except KeyError:
        raise ValueError(
            f"Unknown benchmark_name {args.benchmark_name!r}; "
            f"expected one of {sorted(_DIM_X_BY_BENCHMARK)}"
        ) from None

    os.environ["WANDB_SILENT"] = "true"
    # This script requires CUDA: the device is always cuda:<gpu_id>.
    device = torch.device(f"cuda:{args.gpu_id}")
    torch.cuda.set_device(device)

    if args.debug:
        # Single-iteration smoke run.
        args.iteration = 1
        args.print_every = 1
        args.eval_every = 1

    # seed
    if args.seed is None:
        args.seed = random.randint(0, 9999)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # dataset
    meta_train_dataset, _ = get_dataset(args.data_dir, args.benchmark_name)
    meta_train_sampler = HPOPlainSampler(meta_train_dataset, batch_size=args.batch_size, device=device)

    # model and opt
    configuration = {
        'nr_layers': 2,
        'nr_initial_features': dim_x,
        'layer1_units': 64,
        'layer2_units': 128,
        'cnn_nr_channels': 4,
        'cnn_kernel_size': 3,
    }
    feature_extractor = FeatureExtractor(configuration).to(device)

    # Placeholder training data used only to construct the GP; it is replaced
    # by the projected batch via `set_train_data` on every step.
    train_x = torch.ones(args.batch_size, args.batch_size).to(device)
    train_y = torch.ones(args.batch_size).to(device)
    likelihood = gpytorch.likelihoods.GaussianLikelihood().to(device)
    model = GPRegressionModel(train_x=train_x, train_y=train_y, likelihood=likelihood).to(device)
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model).to(device)
    opt = torch.optim.Adam([
        {'params': model.parameters(), 'lr': args.lr, 'weight_decay': args.wd},
        {'params': feature_extractor.parameters(), 'lr': args.lr, 'weight_decay': args.wd},
    ])
    sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.iteration)

    # Every parameter the optimizer updates; gradient clipping must cover the
    # feature extractor as well as the GP model.
    trainable_params = [*model.parameters(), *feature_extractor.parameters()]

    # logger
    logger = Logger(
        args.exp_name,
        save_dir=f"{args.save_dir}/{args.benchmark_name}/{args.exp_name}",
        save_only_last=True,
        print_every=args.print_every,
        save_every=args.save_every,
        total_step=args.iteration,
        print_to_stdout=True,
        wandb_project_name="greybox_dyhpo_pretrain",
        wandb_config=args
    )
    logger.register_model_to_save(feature_extractor, "feature_extractor")
    logger.register_model_to_save(model, "model")
    logger.register_model_to_save(likelihood, "likelihood")
    logger.start()

    # outer loop
    for step in range(1, args.iteration+1):
        x, t, lc, y = meta_train_sampler.sample()

        opt.zero_grad()
        projected_x = feature_extractor(x, t, lc)
        # strict=False lets the new data differ from the placeholder tensors
        # the GP was constructed with.
        model.set_train_data(projected_x, y, strict=False)
        output = model(projected_x)
        loss = -mll(output, model.train_targets)
        loss.backward()
        if args.grad_norm > 0.:
            nn.utils.clip_grad_norm_(trainable_params, args.grad_norm)
        opt.step()
        sch.step()

        # Detach so the logger never holds a reference to the autograd graph.
        logger.meter("meta_train", "loss", loss.detach())
        logger.step()

    logger.finish()

if __name__ == '__main__':
    # Script entry point: declare the CLI as (flag, add_argument options)
    # pairs, register them in order, then hand the parsed namespace to main().
    cli_arguments = (
        # seed
        ('--seed', {'type': int, 'default': 42}),

        # dir
        ('--save_dir', {'type': str, 'default': "./pretrained_surrogate_results"}),
        ('--data_dir', {'type': str, 'default': "./data"}),
        ('--exp_name', {'type': str, 'default': "quicktune"}),

        # hparams for data
        ('--benchmark_name', {'type': str, 'default': 'lcbench'}),

        # hparams for training
        ('--iteration', {'type': int, 'default': 50000}),
        ('--batch_size', {'type': int, 'default': 512}),
        ('--lr', {'type': float, 'default': 1e-3}),
        ('--wd', {'type': float, 'default': 0.0}),
        ('--grad_norm', {'type': float, 'default': 0.0}),

        # hparams for logger
        ('--print_every', {'type': int, 'default': 100}),
        ('--eval_every', {'type': int, 'default': 10000}),
        ('--save_every', {'type': int, 'default': 2000}),

        # gpus
        ('--gpu_id', {'type': int, 'default': 0}),
        ('--debug', {'action': "store_true"}),
    )

    parser = argparse.ArgumentParser(description='Parameter Processing')
    for flag, options in cli_arguments:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    main(args)
