import os
import yaml
import argparse
from datetime import datetime

# Working directory captured at import time; not referenced anywhere in the
# visible part of this file.
current_path = os.getcwd()

from fqf_iqn_qrdqn.env import make_pytorch_env
from fqf_iqn_qrdqn.agent import QRDQNAgent, EEAgent, MYAgent, PQRAgent

import ray
from ray import tune
import wandb
from ray.tune.integration.wandb import wandb_mixin

# import sys
# sys.path.append('./')

import time
import numpy as np
import torch
import gym
import CustomNChain
import cv2


class WarpFramePyTorch(gym.ObservationWrapper):
    """Observation wrapper that emits 84x84 single-channel, channel-first frames."""

    def __init__(self, env):
        """
        Wrap ``env`` and advertise a (1, 84, 84) box observation space.
        :param env: (Gym Environment) the environment to wrap
        """
        super().__init__(env)
        self.width, self.height = 84, 84
        frame_shape = (1, self.height, self.width)
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=frame_shape,
            dtype=env.observation_space.dtype)

    def observation(self, state=None):
        """
        Build the observation from the frame returned by ``self.get_obs()``.
        :param state: unused; the frame is always fetched through
            ``self.get_obs()`` (not defined in this class -- presumably
            provided by the wrapped environment, confirm)
        :return: the frame converted with COLOR_RGB2GRAY, resized to
            (height, width), with a leading axis of size 1 added
        """
        rgb = self.get_obs()
        gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
        target_size = (self.width, self.height)  # cv2.resize takes (width, height)
        warped = cv2.resize(gray, target_size, interpolation=cv2.INTER_AREA)
        return warped[None, :, :]

def print_Q_table(self, small=None, large=None, std=None):
    """
    Print the online network's Q-values along a short CustomNChain rollout.

    A fresh 'CustomNChain-v0' environment is created, wrapped in
    ``WarpFramePyTorch``, reset with ``start_state=0`` and then stepped with
    action 0; the Q-values for the observation at each of the five visited
    steps are printed.

    :param self: agent exposing ``online_net`` with ``eval()`` and
        ``calculate_q(observation)``
    :param small: ``small`` kwarg for the environment; ``args.small`` if None
    :param large: ``large`` kwarg for the environment; ``args.large`` if None
    :param std: ``std`` kwarg for the environment; ``args.std`` if None
    :raises NameError: if a value is left as None and the module-level
        ``args`` (only set when the file runs as a script) does not exist
    """
    # The environment settings used to be read from bare names that are never
    # defined in this file; take them from the parsed CLI arguments, like
    # run() does, unless the caller passes them explicitly.
    if small is None:
        small = args.small
    if large is None:
        large = args.large
    if std is None:
        std = args.std

    # Make Q(s,a) table
    env2 = gym.make('CustomNChain-v0', small=small, large=large, std=std)
    env2 = WarpFramePyTorch(env2)
    state = env2.reset(start_state=0)

    # Follow the network's own device instead of hard-coding CUDA so the
    # table can also be printed on CPU-only machines.
    # NOTE(review): assumes online_net is a torch.nn.Module with at least one
    # parameter -- confirm.
    device = next(self.online_net.parameters()).device
    self.online_net.eval()

    print("#########################################")
    for i in range(5):
        batch = np.expand_dims(state, axis=0)
        # Scale the byte-valued frame into [0, 1] before feeding the network.
        observation = torch.ByteTensor(batch).to(device).float() / 255
        with torch.no_grad():
            Q_value = self.online_net.calculate_q(observation).tolist()

        print("Q(s{},a) :".format(i), Q_value)
        state = env2.step(0)[0]
    print("#########################################")

def run_test(self):
    """Print the ``steps`` attribute of the given object."""
    current_steps = self.steps
    print(current_steps)

def run_NChain(agent, max_steps=3001, eval_interval=100):
    """
    Train ``agent`` episode by episode until its step counter passes ``max_steps``.

    This file used to define ``run_NChain`` twice; the second definition
    replaced the first and looped forever without ever evaluating or
    stopping. The two are merged here into a single terminating loop.

    :param agent: object exposing ``train_episode()``, ``evaluate()`` and the
        counters ``episodes`` and ``steps``
    :param max_steps: training stops after the first episode that leaves
        ``agent.steps`` greater than this value
    :param eval_interval: every time ``agent.episodes`` is a multiple of this
        value the agent is evaluated and its Q-table is printed
    """
    while True:
        # Always train at least one episode before checking the stop condition.
        agent.train_episode()

        if agent.episodes % eval_interval == 0:
            agent.evaluate()
            iteration = agent.episodes // eval_interval
            print('')
            print("#########################################")
            print('{} evaluation    : {} steps '.format(iteration, agent.steps))
            print_Q_table(agent)

        if agent.steps > max_steps:
            break


def run_main():
    """Create an agent via ``QRDQN()`` and hand it to ``run_NChain``."""
    # NOTE(review): no name ``QRDQN`` is imported or defined in the visible
    # part of this file (only ``QRDQNAgent`` is) -- confirm where it is
    # supposed to come from before relying on this entry point.
    qrdqn_agent = QRDQN()
    run_NChain(qrdqn_agent)





@wandb_mixin
def run(config):
    """
    Trainable passed to ``tune.run``: build the CustomNChain environments,
    create the agent selected by ``args.agent`` and run it.

    Reads the module-level ``args`` (parsed CLI options) and ``config_yaml``
    (agent keyword arguments loaded from the YAML file).

    :param config: trial configuration; must provide ``'batch_size'``,
        ``'seed'`` and ``'lr'``. ``lr`` is only used in the log directory
        name here; it is not passed to the agent by this function.
    :raises argparse.ArgumentError: if ``args.agent`` is not one of
        'QRDQN', 'EE', 'MY' or 'PQR'
    """
    # Imported inside the function on purpose. The old body only reached
    # these imports through a bare ``except:``, and the function-local
    # ``import gym`` made the preceding ``gym.make`` call fail every time, so
    # in practice they always ran; do it up front and explicitly instead.
    # NOTE(review): presumably importing CustomNChain is what registers
    # 'CustomNChain-v0' with gym in the process running the trial -- confirm.
    import gym
    import CustomNChain  # noqa: F401

    # agent name -> (agent class, log-directory prefix, print start banner?)
    agent_specs = {
        'QRDQN': (QRDQNAgent, 'qrdqn', True),
        'EE': (EEAgent, 'EE', True),
        'MY': (MYAgent, 'MY', False),
        'PQR': (PQRAgent, 'PQR', False),
    }
    # Reject an unknown agent before any environment is created.
    # ArgumentError needs (argument, message); raising the bare class fails
    # with a TypeError instead of reporting the actual problem.
    if args.agent not in agent_specs:
        raise argparse.ArgumentError(
            None,
            f"unknown agent '{args.agent}'; expected one of {sorted(agent_specs)}")
    agent_cls, log_prefix, announce = agent_specs[args.agent]

    use_cuda = isinstance(args.cuda, int)
    # torch.cuda.set_device only accepts CUDA devices, so skip it entirely
    # when CUDA is not available instead of handing it a CPU device.
    if use_cuda and torch.cuda.is_available():
        torch.cuda.set_device(torch.device(f'cuda:{args.cuda}'))

    # Create environments.
    env_kwargs = dict(small=args.small, large=args.large, std=args.std)
    env = WarpFramePyTorch(gym.make('CustomNChain-v0', **env_kwargs))
    test_env = WarpFramePyTorch(gym.make('CustomNChain-v0', **env_kwargs))

    # Specify the directory to log.
    time_local = datetime.now().strftime("%Y%m%d-%H%M%S")
    batch_size = config['batch_size']
    seed = config['seed']
    lr = config['lr']
    log_dir = os.path.join(
        'logs', 'CustomNChain',
        f'{log_prefix}-batchsize{batch_size}-lr{lr}-seed{seed}-{time_local}')

    # Create the agent and run.
    if announce:
        print(f"#### {args.agent} START ! ####")
    agent = agent_cls(
        env=env, test_env=test_env, log_dir=log_dir, seed=seed,
        cuda=use_cuda, batch_size=batch_size, **config_yaml)
    agent.run()

def parser():
    """
    Parse the command-line options of this script.

    Options that are not declared here are ignored rather than rejected.

    :return: namespace with ``config``, ``agent``, ``small``, ``large``,
        ``std`` and ``cuda``
    """
    default_config = os.path.join('config', 'NChain.yaml')
    # (flag, type, default) for every supported option.
    option_table = (
        ('--config', str, default_config),
        ('--agent', str, 'QRDQN'),
        ('--small', int, 1),
        ('--large', int, 10),
        ('--std', int, 0),
        ('--cuda', int, None),
    )

    cli = argparse.ArgumentParser()
    for flag, kind, fallback in option_table:
        cli.add_argument(flag, type=kind, default=fallback)

    known, _unknown = cli.parse_known_args()
    return known

# Script entry point: parse the CLI, load the agent settings from YAML, and
# launch the Ray Tune experiment that calls run() for every trial.
if __name__ == '__main__':
    start = time.time()

    # Parse the CLI before starting Ray so a bad command line (or --help)
    # exits without spinning up a Ray runtime first.
    args = parser()
    ray.init()

    ray_config = {
        'batch_size': tune.grid_search([32]),
        'lr': tune.grid_search([5e-5]),  # 25e-5
        'seed': tune.grid_search([234]),  # 234,567
        'args': args,
    }

    # Honor --config: the path used to be hard-coded here, which silently
    # ignored the option. The option's default is that same path, so running
    # without --config behaves as before.
    path = args.config
    with open(path) as f:
        config_yaml = yaml.load(f, Loader=yaml.SafeLoader)

    # The wandb group follows the agent name unless double Q-learning is
    # switched on in the YAML settings.
    if config_yaml['double_q_learning']:
        group_name = "double_q_QRDQN"
    else:
        group_name = args.agent

    ray_config.update({
        'wandb': {
            "project": "DRL_NChain",
            "group": group_name,
        }
    })
    print("GROUP NAME :", group_name)

    analysis = tune.run(
        run,
        config=ray_config,
        resources_per_trial={"cpu": 2, "gpu": 1},
        name="NChain-experiment",
    )

    finish = time.time()
    print("Time :", finish - start)
