from turtle import done
import numpy as np
import torch
import os
import math
import gym
import sys
import random
import time
import json
import copy

from collections import deque

import multiprocessing

from utils import parse_args, sample, set_seed_everywhere, ReplayBufferSingle, make_batch_env

from briee.lsvi_ucb_cce import LSVI_UCB
from briee.lsvi_ucb_oracle import LSVI_UCB_ORA
from briee.rep_learn import RepLearn

from main import make_rep_learner

def evaluate(env, agent, exploiter, args):
    """Roll out one batched episode and return the mean cumulative reward.

    At every step the joint action comes from ``agent``, except column 0,
    which is overwritten with ``exploiter``'s choice. No random override is
    applied here.

    Args:
        env: batched environment exposing ``reset()`` and ``step(action)``.
        agent: policy whose ``act_batch(obs, h)`` returns a 2-D action array.
        exploiter: policy whose ``act_batch(obs, h)`` fills action column 0.
        args: namespace providing ``num_eval`` (number of parallel envs) and
            ``horizon`` (episode length).

    Returns:
        The mean, over all envs, of the summed per-step rewards.
    """
    # One accumulator row per evaluation env; rewards are assumed to broadcast
    # into shape (num_eval, 1) -- NOTE(review): confirm env.step's reward shape.
    total = np.zeros((args.num_eval, 1))

    obs = env.reset()
    for step in range(args.horizon):
        joint_action = agent.act_batch(obs, step)
        joint_action[:, 0] = exploiter.act_batch(obs, step)
        obs, step_reward, _, _ = env.step(joint_action)
        total += step_reward

    return np.mean(total)


def _load_model_matrices(env, eval_env, temp_path):
    """Attach the saved transition/reward matrices in ``temp_path`` to both envs.

    Each ``.npy`` file is read once; the evaluation env receives its own copy
    so the two envs never share (and can never mutate) the same array.

    Raises:
        FileNotFoundError: if ``trans.npy`` or ``reward.npy`` is missing.
    """
    trans_file = os.path.join(temp_path, "trans.npy")
    reward_file = os.path.join(temp_path, "reward.npy")
    for path in (trans_file, reward_file):
        if not os.path.isfile(path):
            raise FileNotFoundError(
                "expected pre-computed model file {}".format(path))

    trans = np.load(trans_file)
    reward = np.load(reward_file)
    env.trans_prob_matrices = trans
    env.reward_matrices = reward
    eval_env.trans_prob_matrices = trans.copy()
    eval_env.reward_matrices = reward.copy()


def _make_buffers(env, args, device):
    """Create one ReplayBufferSingle per horizon step, sized for every episode."""
    # Room for every transition collected at this step, plus a little slack.
    capacity = args.num_episodes * args.num_envs + 100
    return [
        ReplayBufferSingle(env.observation_space.shape,
                           env.action_space.n,
                           args.num_players,
                           capacity,
                           args.batch_size,
                           device,
                           recent_size=args.recent_size)
        for _ in range(args.horizon)
    ]


def main(args, *, num_random_envs=5, num_final_evals=10):
    """Train an exploiter for action column 0 against a fixed, pre-trained agent.

    The agent's representation (``load_phi``) and weights (``load_weight``)
    are loaded from disk, and the model matrices are loaded from
    ``args.temp_path``. Each episode overrides column 0 of the agent's joint
    action with the exploiter's action and stores the transitions. Every
    ``args.update_frequency`` episodes the exploiter is refit on the buffers
    and evaluated.

    Args:
        args: parsed command-line namespace (see ``utils.parse_args``).
        num_random_envs: how many of the batched envs take a uniformly random
            column-0 action during training, to keep exploring (default 5).
        num_final_evals: number of evaluation rollouts averaged at the end.

    Raises:
        FileNotFoundError: if ``trans.npy`` / ``reward.npy`` are missing from
            ``args.temp_path``.
    """
    set_seed_everywhere(args.seed)

    env, eval_env = make_batch_env(args)
    # Per-player action count: the buffers below store single-player actions
    # using this same value as their action dimension.
    num_actions = env.action_space.n
    device = torch.device("cpu")

    _load_model_matrices(env, eval_env, args.temp_path)

    oracle_v, _ = env.get_nash_strategy()
    buffers = _make_buffers(env, args, device)

    rep_learners = make_rep_learner(env, device, args)
    for h in range(args.horizon):
        rep_learners[h].load_phi(h)

    agent = LSVI_UCB(env.observation_space.shape[0],
                     env.state_dim,
                     env.action_dim,
                     args.horizon,
                     args.alpha,
                     device,
                     rep_learners,
                     recent_size=args.lsvi_recent_size,
                     lamb=args.lsvi_lamb)
    agent.load_weight(args.temp_path)

    # NOTE(review): the exploiter uses a fixed 0.1 where the agent uses
    # args.alpha -- confirm this is intentional.
    exploiter = LSVI_UCB_ORA(env.observation_space.shape[0],
                             env.state_dim,
                             env.action_dim,
                             args.horizon,
                             0.1,
                             device,
                             lamb=args.lsvi_lamb)

    for n in range(args.num_episodes):
        obs = env.reset()

        for h in range(args.horizon):
            action = agent.act_batch(obs, h)
            exploit = exploiter.act_batch(obs, h)
            # Draw for every env (keeps the seeded RNG stream unchanged) but
            # only override the first `num_random_envs` with random actions.
            rand_action = np.random.randint(0, num_actions, args.num_envs)
            exploit[:num_random_envs] = rand_action[:num_random_envs]
            action[:, 0] = exploit
            next_obs, reward, _, _ = env.step(action)
            buffers[h].add_batch(obs, exploit, reward, next_obs, args.num_envs)
            obs = next_obs

        if n % args.update_frequency == 0:
            exploiter.update(buffers)
            eval_return = evaluate(eval_env, agent, exploiter, args)
            print(n, eval_return)

    final_eval = [evaluate(eval_env, agent, exploiter, args)
                  for _ in range(num_final_evals)]
    print("final", np.mean(final_eval))
    print("oracle v:", np.mean(oracle_v[0], axis=0))
        

if __name__ == '__main__':
    # Parse first so `--help` / bad arguments never depend on wandb.
    args = parse_args()

    import wandb

    # Offline mode: the run is recorded locally rather than uploaded.
    os.environ['WANDB_MODE'] = 'offline'

    run_settings = {
        "project": "bmdp_h{}".format(args.horizon),
        "job_type": "ratio_search",
        "config": vars(args),
        "name": args.exp_name,
    }
    with wandb.init(**run_settings):
        main(args)
