import os
import dmc_envs
import torch
import robel
import gym
import numpy as np
from utilis.config import ARGConfig
from utilis.default_config import default_config, dmc_config
from model.algorithm import BAC
from utilis.Replaybuffer import ReplayMemory
import datetime
import itertools
from copy import copy
import shutil
import wandb
import csv

from torch.utils.tensorboard import SummaryWriter
import yaml
import ipdb

def _make_result_dirs(config, msg):
    """Create this run's output directories and return ``(result_path, checkpoint_path)``.

    The run directory is
    ``./results/<env_name>/<msg>/<timestamp>_<policy>_<seed>_<"autotune"|"">_<config.msg>``
    (relative to the current working directory). A ``checkpoint`` sub-directory is
    created inside it, and ``str(config)`` is written to ``config.log``.
    """
    result_path = './results/{}/{}/{}_{}_{}_{}_{}'.format(
        config.env_name, msg,
        datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
        config.policy, config.seed,
        "autotune" if config.automatic_entropy_tuning else "",
        config.msg,
    )
    checkpoint_path = result_path + '/' + 'checkpoint'

    # Creating the nested checkpoint directory also creates result_path.
    os.makedirs(checkpoint_path, exist_ok=True)
    with open(os.path.join(result_path, "config.log"), 'w') as f:
        f.write(str(config))
    return result_path, checkpoint_path


def _snapshot_code(result_path, files_to_save=('main.py', 'model')):
    """Copy the ``files_to_save`` entries of this script's directory into ``<result_path>/code``.

    Copying happens from the script directory (not the CWD), so the ignore list,
    which is built from that directory's listing, matches the tree being copied.
    Note that ``shutil.ignore_patterns`` applies at every level: a nested entry
    whose name equals an ignored top-level name is skipped too.
    """
    current_path = os.path.dirname(os.path.abspath(__file__))
    ignore_files = [x for x in os.listdir(current_path) if x not in files_to_save]
    shutil.copytree(current_path, result_path + '/code',
                    ignore=shutil.ignore_patterns(*ignore_files))


def _log_update(stats, step):
    """Log the 11-tuple returned by ``agent.update_parameters`` to wandb at ``step``."""
    (q_critic_1_loss, q_critic_2_loss, v_loss, policy_loss, ent_loss, alpha,
     q_grad, lambda_q_d, Q_value, Q_value_e, Q_value_d) = stats
    wandb.log(
        data={
            'loss/q_critic_1': q_critic_1_loss,
            'loss/q_critic_2': q_critic_2_loss,
            'loss/v_loss': v_loss,
            'loss/policy_loss': policy_loss,
            'loss/entropy_loss': ent_loss,
            'parameter/alpha': alpha,
            'parameter/q_grad': q_grad,
        },
        step=step,
    )
    # These values are logged via .item(), so they are presumably 1-element tensors.
    wandb.log(
        data={
            'parameter/lambda_q_d': lambda_q_d.item(),
            'Q value comparison': {'q_value': Q_value.item(), 'q_exploration': Q_value_e.item(), 'q_data': Q_value_d.item()}
        },
        step=step,
    )


def _evaluate(env, agent, n_episodes):
    """Run ``n_episodes`` evaluation episodes and return ``(avg_reward, avg_success)``.

    Actions come from ``agent.select_action(state, evaluate=True)``. Success is read
    from ``info['score/success']`` of each episode's last step when that key exists;
    otherwise the episode counts as 0. ``n_episodes == 0`` raises ZeroDivisionError.
    """
    total_reward = 0.
    total_success = 0.
    for _ in range(n_episodes):
        state = env.reset()
        episode_reward = 0
        done = False
        while not done:
            action = agent.select_action(state, evaluate=True)
            state, reward, done, info = env.step(action)
            episode_reward += reward
        total_reward += episode_reward
        if 'score/success' in info:
            total_success += float(info['score/success'])
    return total_reward / n_episodes, total_success / n_episodes


def train_loop(config, msg = "default"):
    """Train a BAC agent on ``config.env_name`` and log progress to wandb.

    The env's seeds, the action-space seed, torch and numpy are seeded with
    ``config.seed``. Setup writes the result directory, ``config.log`` and a code
    snapshot. Training then collects one transition per env step into a replay buffer.
    Actions are random for the first ``config.start_steps`` steps and come from the
    policy afterwards. Once the buffer holds more than ``config.batch_size``
    transitions, ``config.updates_per_step`` gradient updates run per step.

    The loop stops after the first episode in which ``total_numsteps`` exceeds
    ``config.num_steps``. Every 10th episode, if ``config.eval is True``, the agent is
    evaluated for ``config.eval_episodes`` episodes. If ``config.save_checkpoint`` is
    enabled, a checkpoint is saved whenever the average reward ties or beats the best
    seen so far.

    Args:
        config: config object exposing the attributes read here (env_name, seed,
            policy, start_steps, batch_size, updates_per_step, num_steps, replay_size,
            eval, eval_episodes, save_checkpoint, automatic_entropy_tuning, msg, ...).
        msg: sub-directory name under ``./results/<env_name>/``.
    """
    env = gym.make(config.env_name)
    try:
        env.seed(config.seed)
        env.action_space.seed(config.seed)

        torch.manual_seed(config.seed)
        np.random.seed(config.seed)

        agent = BAC(env.observation_space.shape[0], env.action_space, config)

        result_path, checkpoint_path = _make_result_dirs(config, msg)
        _snapshot_code(result_path)
        memory = ReplayMemory(config.replay_size, config.seed)

        # The step limit does not change during training, so look it up once
        # instead of rebuilding dir(env) on every step. 1000 is the fallback when
        # the env exposes no `_max_episode_steps`.
        max_episode_steps = env._max_episode_steps if '_max_episode_steps' in dir(env) else 1000

        total_numsteps = 0
        updates = 0
        best_reward = -1e6
        for i_episode in itertools.count(1):
            episode_reward = 0
            episode_steps = 0
            done = False
            state = env.reset()

            while not done:
                if config.start_steps > total_numsteps:
                    action = env.action_space.sample()
                else:
                    action = agent.select_action(state)

                if len(memory) > config.batch_size:
                    for _ in range(config.updates_per_step):
                        _log_update(agent.update_parameters(memory, config.batch_size, updates), total_numsteps)
                        updates += 1

                next_state, reward, done, _ = env.step(action)
                total_numsteps += 1
                episode_steps += 1
                episode_reward += reward

                # A transition that ends exactly at the step limit gets mask=1
                # (not treated as terminal); otherwise mask = not done.
                mask = 1 if episode_steps == max_episode_steps else float(not done)
                memory.push(state, action, reward, next_state, mask)

                state = next_state

            if total_numsteps > config.num_steps:
                break

            wandb.log(
                data={
                    'reward/train_reward': episode_reward
                },
                step=total_numsteps,
            )
            print("Episode: {}, total numsteps: {}, episode steps: {}, reward: {}".format(i_episode, total_numsteps, episode_steps, round(episode_reward, 2)))

            # Periodic evaluation. `is True` / `== True` are kept deliberately so
            # non-bool config values behave exactly as before.
            if i_episode % 10 == 0 and config.eval is True:
                avg_reward, avg_success = _evaluate(env, agent, config.eval_episodes)

                if config.save_checkpoint == True and avg_reward >= best_reward:
                    best_reward = avg_reward
                    agent.save_checkpoint(checkpoint_path, int(total_numsteps))

                wandb.log(
                    data={
                        'reward/test_avg_reward': avg_reward,
                        'reward/success_rate': avg_success
                    },
                    step=total_numsteps,
                )

                print("----------------------------------------")
                print("Env: {}, Test Episodes: {}, Avg. Reward: {}".format(config.env_name, config.eval_episodes, round(avg_reward, 2)))
                print("----------------------------------------")
    finally:
        # Release the environment even if training raises.
        env.close()



def main():
    """Parse CLI arguments, load the run config, start a wandb run and call ``train_loop``.

    When ``WANDB_RUN_ID`` is set (i.e. under a wandb sweep agent), ``lambda`` and
    ``quantile`` are overridden from the YAML file at ``WANDB_SWEEP_PARAM_PATH``.
    """
    arg = ARGConfig()
    arg.add_arg("env_name", "DogRun-v0", "Environment name")
    arg.add_arg("device", "0", "Computing device")
    arg.add_arg("policy", "Gaussian", "Policy Type: Gaussian | Deterministic (default: Gaussian)")
    arg.add_arg("tag", "default", "Experiment tag")
    arg.add_arg("algo", "BAC", "choose algorithm (BAC, SAC, TD3, TD3-BEE)")
    arg.add_arg("start_steps", 10000, "Number of start steps")
    arg.add_arg("automatic_entropy_tuning", True, "Automaically adjust α (default: True)")
    arg.add_arg("quantile", 0.7, "the quantile regression for value function (default: 0.9)")
    arg.add_arg("seed", 123456, "experiment seed")
    arg.add_arg("lambda", "fixed_0.5", "method to calculated lambda, fixed_x, ada, min, max")
    arg.add_arg("des", "", "short description for the experiment")
    arg.add_arg("num_steps", 1000001, "total number of steps")
    arg.add_arg("save_checkpoint", False, "save checkpoint or not")
    arg.add_arg("replay_size", 1000000, "size of replay buffer")
    arg.parser()

    config = dmc_config
    config.update(arg)
    #* load config file
    # NOTE(review): the assignment below replaces `config` entirely, so the CLI
    # arguments and dmc_config merged above are NOT used by this run. Confirm
    # that the hard-coded config file is intended.
    config_path = os.path.join("utilis", "configs", "quadrupedwalk.log")
    config = ARGConfig().load_saved(config_path)
    algorithm = config.algo

    if 'WANDB_RUN_ID' in os.environ:
        # Running under a wandb sweep agent: take lambda/quantile from the sweep.
        with open(os.environ['WANDB_SWEEP_PARAM_PATH'], "r") as f:
            sweep_params = yaml.load(f.read(), Loader=yaml.FullLoader)
        config["lambda"] = sweep_params['lambda']['value']
        config["quantile"] = sweep_params['quantile']['value']

    # Compute shared pieces once so the run id and the Settings table agree.
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    lambda_tag = config["lambda"].replace("/", "_")

    experiment_name = "{}-{}-{}-{}-q{}-{}|{}".format(
        algorithm,
        config['env_name'],
        str(config["seed"]),
        lambda_tag,
        str(config["quantile"]),
        config["automatic_entropy_tuning"],
        config["des"]
    )

    run_id = "{}_{}_{}_{}_q{}_{}_{}".format(
        algorithm,
        config['env_name'],
        str(config["seed"]),
        lambda_tag,
        str(config["quantile"]),
        config["automatic_entropy_tuning"],
        timestamp
    )

    run = wandb.init(
        project = config["project_name"],
        config = {
            "env_name": config['env_name'],
            "automatic_entropy_tuning": config["automatic_entropy_tuning"],
            "quantile": config["quantile"],
            "algorithm" : algorithm,
            "seed": config["seed"],
            "lambda": config["lambda"],
            "num_steps": config["num_steps"]
        },
        name = experiment_name,
        id = run_id,
        save_code = False
    )

    show_config = wandb.Table(
        columns= ["env_name", "algorithm", "seed", "lambda", "quantile",  "automatic_entropy_tuning", "time"],
        data=[[config['env_name'],
        algorithm,
        str(config["seed"]),
        lambda_tag,
        str(config["quantile"]),
        config["automatic_entropy_tuning"],
        timestamp]]
    )
    run.log({"Settings": show_config})

    print(f">>>> Training {algorithm} on {config.env_name} environment, on {config.device}")
    train_loop(config, msg=algorithm)
    wandb.finish()


if __name__ == "__main__":
    # Start training only when this file is executed as a script, not when imported.
    main()
