

import gym
import pybullet_envs
import os
import numpy as np
from collections import deque
import torch
import wandb
import argparse
# from buffer import ReplayBuffer
import glob
from utils import save, collect_random, parse_args, make_batch_env,\
     collect_uniform_random_buffer_lock, collect_offline_buffer, ReplayBuffer
import random
from collections import deque
from agent import CQLAgent
import pickle

def _rollin_and_explore(env, agent, buffer, offline_buffer, num_envs, horizon, num_actions):
    """Collect one hybrid roll-in: greedy steps for a random length, then one random action.

    The roll-in length is drawn uniformly from ``[0, horizon - 1]``. Every
    online batch written to ``buffer`` is followed by copying ``num_envs``
    transitions from ``offline_buffer`` (via ``ReplayBuffer.add_from_buffer``).

    Args:
        env: batched environment (``reset``/``step`` API).
        agent: object exposing ``get_actions(state, epsilon=...)``.
        buffer: online replay buffer that receives the transitions.
        offline_buffer: offline replay buffer mixed into ``buffer``.
        num_envs: number of parallel environments in ``env``.
        horizon: episode horizon; bounds the random roll-in length.
        num_actions: size of the discrete action space used for exploration.
    """
    # Draw the roll-in length before resetting, matching the original RNG order.
    roll_in_steps = random.randint(0, horizon - 1)
    state = env.reset()
    for _step in range(roll_in_steps):
        action = agent.get_actions(state, epsilon=0)
        # NOTE(review): `done` is ignored during roll-in, as in the original.
        next_state, reward, done, _info = env.step(action)
        buffer.add_batch(state, action, reward, next_state, done, num_envs)
        buffer.add_from_buffer(offline_buffer, num_envs)
        state = next_state

    # Exploration step: one uniformly random action per environment, drawn
    # from the same action space the buffer and agent were sized with.
    action = np.random.randint(0, num_actions, num_envs)
    next_state, reward, done, _info = env.step(action)
    buffer.add_batch(state, action, reward, next_state, done, num_envs)
    buffer.add_from_buffer(offline_buffer, num_envs)


def _greedy_episode_with_updates(env, agent, buffer, num_envs):
    """Run one greedy episode and take one ``agent.learn`` step per environment step.

    Transitions from this episode are NOT added to ``buffer``; it is used only
    as the source of training batches.

    Args:
        env: batched environment (``reset``/``step``/``get_state`` API).
        agent: object exposing ``get_actions`` and ``learn``.
        buffer: replay buffer sampled for each learning step.
        num_envs: number of parallel environments in ``env``.

    Returns:
        Tuple ``(mean_reward, reached, episode_steps, loss, cql_loss,
        bellmann_error)``, where the three losses come from the last
        ``agent.learn`` call of the episode.
    """
    state = env.reset()
    rewards = np.zeros((num_envs, 1))
    reached = 0
    episode_steps = 0
    while True:
        # Counts steps whose mean environment state is below 2.
        # NOTE(review): domain-specific progress metric; confirm its meaning in the env.
        if np.mean(env.get_state()) < 2:
            reached += 1
        action = agent.get_actions(state, epsilon=0)
        next_state, reward, done, _info = env.step(action)
        # Batch size is hard-coded here and is independent of config.batch_size.
        loss, cql_loss, bellmann_error = agent.learn(buffer.sample(batch_size=512))
        state = next_state
        rewards += reward
        episode_steps += 1
        if done:
            break
    return np.mean(rewards), reached, episode_steps, loss, cql_loss, bellmann_error


def train(config):
    """Train a CQL agent on a mix of offline data and hybrid online roll-ins.

    Steps:
        1. Builds the batched environment with ``make_batch_env(config)``.
        2. Loads the offline buffer from ``'mixed_distribution'`` and
           warm-starts the online buffer from it.
        3. Runs ``num_runs - 1`` iterations. Each iteration performs a hybrid
           roll-in/exploration phase, then one greedy episode in which every
           step triggers one ``agent.learn`` call.

    Metrics go to ``wandb``. The caller is expected to have an active run open;
    the ``__main__`` block wraps this call in ``wandb.init``.

    Args:
        config: parsed CLI namespace. Fields read here: ``seed``,
            ``num_episodes``, ``batch_size``, ``horizon``, ``num_envs``
            (``make_batch_env`` and ``save`` may read more).
    """
    # Build the environments under a fixed seed, then reseed every RNG from
    # config.seed for the training run itself.
    np.random.seed(1234)
    random.seed(1234)
    torch.manual_seed(1234)
    env, _eval_env = make_batch_env(config)  # evaluation env is unused here
    np.random.seed(config.seed)
    random.seed(config.seed)
    torch.manual_seed(config.seed)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    obs_shape = env.observation_space.shape
    num_actions = env.action_space.n

    buffer = ReplayBuffer(obs_shape,
                          num_actions,
                          int(config.num_episodes) * 6 + 1,
                          config.batch_size,
                          device,
                          recent_size=0)

    # The offline dataset is kept on CPU.
    offline_path = 'mixed_distribution'
    offline_buffer = ReplayBuffer(obs_shape,
                                  num_actions,
                                  int(config.num_episodes) * 2 + 1,
                                  config.batch_size,
                                  torch.device("cpu"),
                                  recent_size=0)
    offline_buffer.load(offline_path)

    num_runs = int(config.num_episodes / config.horizon / config.num_envs)

    # Warm start: seed the online buffer with offline transitions (requested
    # count 1,000,000; presumably capped by what offline_buffer holds).
    buffer.add_from_buffer(offline_buffer, 1000000)

    # Agent is created after the warm start so the RNG consumption order is unchanged.
    agent = CQLAgent(state_size=obs_shape,
                     action_size=num_actions,
                     device=device)
    wandb.watch(agent.network, log="gradients", log_freq=10)

    steps = 0
    average10 = deque(maxlen=10)
    for i in range(1, num_runs):
        _rollin_and_explore(env, agent, buffer, offline_buffer,
                            config.num_envs, config.horizon, num_actions)

        (rewards, reached, episode_steps,
         loss, cql_loss, bellmann_error) = _greedy_episode_with_updates(
            env, agent, buffer, config.num_envs)
        steps += episode_steps

        average10.append(rewards)
        print("Episode: {} | Reward: {} | Q Loss: {} | Steps: {}".format(i, rewards, loss, steps,))

        # A single "Steps" entry. The original literal listed it twice, and the
        # later value (steps) won.
        wandb.log({"Reward": rewards,
                   "Average10": np.mean(average10),
                   'Reached': reached,
                   "Steps": steps,
                   "Q Loss": loss,
                   "CQL Loss": cql_loss,
                   "Bellmann error": bellmann_error,
                   "Episode": i})

        # Early stop when the running mean over the last (up to) 10
        # evaluation episodes equals exactly 1.
        if np.mean(average10) == 1:
            break

        if i % 5 == 0:
            # ep is always 0. Presumably each save overwrites the previous
            # checkpoint; verify in utils.save.
            save(config, save_name="CQL-DQN", model=agent.network, wandb=wandb, ep=0)

if __name__ == "__main__":
    # Uncomment to keep the W&B run local instead of syncing it to the server.
    # os.environ['WANDB_MODE'] = 'offline'
    config = parse_args()
    config.run_name = 'CQL-comblock'

    # Settings for the W&B run that wraps the whole training call.
    wandb_run_settings = dict(
        project="discounted fqi",
        job_type="ratio_search",
        config=vars(config),
        name='hybrid-fqi',
    )
    with wandb.init(**wandb_run_settings):
        train(config)

    # with open('epsilon_distribution', 'rb') as fb:
    #     offline_buffer = pickle.load(fb)
