import os.path
import sys

from Utils import environments
import gym
import d4rl
import time
import datetime
import torch
import itertools
import numpy as np
from loguru import logger

import argparse

from diffusion_predictor.Predictor_config import DMBP_config, update_DMBP_config
from diffusion_predictor.Predictor_net import Diffusion_instance
from diffusion_predictor.render_img import MuJoCoRenderer
from Utils.Batch_Buffer import batch_buffer
from Utils.seed import setup_seed, seed_env

def train_DMBP(args):
    """Train (or load) the detector/denoiser pair and export a processed dataset.

    The run configuration is built with ``update_DMBP_config`` from
    ``args``. When ``config["load_model_path"]`` is falsy, the "naive"
    detector and the denoiser are trained for ``config["total_epoch"]`` epochs
    of ``config["steps_per_epoch"]`` steps each and then saved under
    ``models_cp/``. Otherwise both models are loaded from files prefixed with
    ``config["load_model_path"]``. In either case ``Buffer.get_dataset`` is
    finally called with the two models and an output path under
    ``datasets_cp/``.

    Args:
        args: Parsed command-line namespace. This function reads
            ``env_name``, ``dataset``, ``seed``, ``device`` and ``dn_stp``
            directly; the whole namespace is also handed to
            ``update_DMBP_config``.
    """
    env_name, dataset_name, seed = args.env_name, args.dataset, args.seed

    config = update_DMBP_config(env_name, DMBP_config, args)
    # Used for the buffer input type and for every checkpoint file name.
    attack_element = config["attack_element"]

    device = torch.device(args.device)
    setup_seed(seed)

    Buffer = batch_buffer(config, env_name, dataset_name, device, input_type=attack_element, buffer_mode='normal', buffer_normalization=False)
    # NOTE(review): detector_ambient is never used below. It is kept so that
    # the construction order of the models (and any random draws made while
    # building them) stays unchanged -- confirm whether it can be removed.
    detector_ambient = Diffusion_instance(Buffer.obs_dim, Buffer.act_dim, device, config, "ambient")
    detector_naive   = Diffusion_instance(Buffer.obs_dim, Buffer.act_dim, device, config, "naive")

    denoiser = Diffusion_instance(Buffer.obs_dim, Buffer.act_dim, device, config, "naive")

    if not config["load_model_path"]:
        steps_per_epoch = config["steps_per_epoch"]
        total_epochs = config["total_epoch"]
        batch_size = config['batch_size']

        detector_path = f"models_cp/{env_name}_{dataset_name}_detector_{attack_element}.pkl"
        denoiser_path = f"models_cp/{env_name}_{dataset_name}_denoiser_{attack_element}.pkl"
        # Create the checkpoint directory up front so a missing folder cannot
        # make the run fail only after training has finished.
        os.makedirs(os.path.dirname(detector_path), exist_ok=True)

        for _ in range(total_epochs):
            # NOTE(review): the meaning of the 0.1 / 0.5 arguments is defined
            # by Diffusion_instance.train and is not visible from this file.
            detector_naive.train(Buffer, steps_per_epoch, batch_size, None, 0.1, "naive")
            denoiser.train(Buffer, steps_per_epoch, batch_size, detector_naive, 0.5)
            if config["debug"]:
                Buffer.simple_detect(denoiser, detector_naive)

        # Save once after the final epoch; with zero epochs nothing was
        # trained, so nothing is written.
        if total_epochs > 0:
            detector_naive.save_model(detector_path)
            denoiser.save_model(denoiser_path)
    else:
        model_path = config["load_model_path"]
        # When loading, the dataset name used for the output path below is
        # taken from the config rather than from args.dataset.
        dataset_name = config["dataset"]
        detector_naive.load_model(f"{model_path}_detector_{attack_element}.pkl")
        denoiser.load_model(f"{model_path}_denoiser_{attack_element}.pkl")

    dataset_out = f"datasets_cp/{env_name}_{dataset_name}_{args.dn_stp}"
    # Make sure the output directory exists before the dataset is written.
    os.makedirs(os.path.dirname(dataset_out), exist_ok=True)
    # NOTE(review): 0.7 is passed through to Buffer.get_dataset; its meaning
    # is not visible from this file.
    Buffer.get_dataset(denoiser, detector_naive, dataset_out, 0.7)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', default='hopper', type=str)
    parser.add_argument('--dataset', default='medium-expert', type=str)
    parser.add_argument('--device', default="cuda:0", type=str)
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--load_model_path', default=None, type=str)
    parser.add_argument('--debug', default=False, type=bool)
    parser.add_argument('--detector_training_loops', default=1000000, type=int)
    parser.add_argument('--dataset_path', default=None, type=str)

    # Configs need inputs
    parser.add_argument('--attack_element', default="transition", type=str)

    parser.add_argument('--dn_stp', default=100, type=int)

    args = parser.parse_args()
    train_DMBP(args)