from turtle import done
import numpy as np
import torch
import os
import math
import gym
import sys
import random
import time
import json
import copy

from collections import deque

import multiprocessing

from utils import parse_args, sample, set_seed_everywhere, ReplayBufferSingle, make_batch_env

from briee.lsvi_ucb_cce import LSVI_UCB
from briee.lsvi_ucb_oracle import LSVI_UCB_ORA
from briee.rep_learn import RepLearn

from briee.lsvi_ucb_gensum import LSVI_UCB_GENSUM


from main import make_rep_learner

def evaluate_trained(env, agent, args):
    """Roll out ``agent`` for both players over one batched episode.

    Args:
        env: batched environment. ``reset()`` returns a batch of observations and
            ``step(action)`` returns ``(next_obs, reward1, reward2, done, info)``.
        agent: policy whose ``act_batch(obs, h)`` returns the joint action batch
            for step ``h``.
        args: namespace providing ``num_eval`` and ``horizon``. Presumably
            ``num_eval`` equals the number of parallel envs in ``env``
            (TODO confirm).

    Returns:
        Tuple ``(mean return of player 1, mean return of player 2)``, each
        averaged over the batch.
    """
    returns1 = np.zeros((args.num_eval, 1))
    returns2 = np.zeros((args.num_eval, 1))

    obs = env.reset()
    for h in range(args.horizon):
        action = agent.act_batch(obs, h)
        # The done flag and info are ignored: the rollout always runs the full
        # horizon. Binding them to `_` also avoids shadowing the module-level
        # name `done`.
        obs, reward1, reward2, _, _ = env.step(action)
        returns1 += reward1
        returns2 += reward2

    return np.mean(returns1), np.mean(returns2)


def evaluate(env, agent, exploiter, args, agent1=False):
    """Return the exploiter's mean return against ``agent`` over one batched episode.

    At every step the exploiter's action overwrites one column of the joint
    action produced by ``agent``: column 0 when ``agent1`` is true, otherwise
    column 1. The reward accumulated is the one belonging to that column's
    player: ``reward1`` when ``agent1`` is true, otherwise ``reward2``.
    """
    slot = 0 if agent1 else 1
    totals = np.zeros((args.num_eval, 1))

    observation = env.reset()
    for step in range(args.horizon):
        joint = agent.act_batch(observation, step)
        best_response = exploiter.act_batch(observation, step)
        joint[:, slot] = best_response
        observation, r_first, r_second, _, _ = env.step(joint)
        totals += r_first if agent1 else r_second

    return np.mean(totals)

def run(args, agent, env, eval_env, num_runs, device, agent1=False,
        num_random_envs=5, num_random_actions=3, num_final_evals=10):
    """Train a best-response exploiter against a fixed ``agent`` and report its return.

    For ``num_runs`` batched episodes, ``agent`` picks the joint action. The
    exploiter (an ``LSVI_UCB_ORA`` learner) then overwrites one player's column:
    column 0 when ``agent1`` is true, otherwise column 1. The transitions seen
    by the exploiter are stored in one replay buffer per step ``h``. The
    exploiter is refit on these buffers every ``args.update_frequency``
    episodes, starting after the first episode.

    Args:
        args: experiment namespace (horizon, num_envs, num_players, batch_size,
            lsvi_lamb, update_frequency, num_eval, ...).
        agent: fixed policy exposing ``act_batch(obs, h)``.
        env: batched training environment.
        eval_env: batched environment used for the final evaluation.
        num_runs: number of batched episodes to collect.
        device: torch device passed to the buffers and the exploiter.
        agent1: if true, exploit in place of player 1 (column 0); otherwise
            in place of player 2 (column 1).
        num_random_envs: the first this-many envs of each batch take a uniformly
            random exploiter action instead of the learned one (default 5).
        num_random_actions: random actions are drawn from
            ``range(num_random_actions)`` (default 3). Presumably this must match
            the per-player action count (TODO confirm).
        num_final_evals: number of final ``evaluate`` calls to average (default 10).

    Returns:
        Mean exploiter return over the final evaluations (also printed).
    """
    slot = 0 if agent1 else 1

    # One buffer per step h. Each buffer receives num_envs transitions per
    # episode, so num_runs * num_envs (+ slack) is enough to hold everything.
    capacity = num_runs * args.num_envs + 100
    buffers = [
        ReplayBufferSingle(env.observation_space.shape,
                           env.action_space.n,
                           args.num_players,
                           capacity,
                           args.batch_size,
                           device)
        for _ in range(args.horizon)
    ]

    # NOTE(review): 0.1 is passed positionally to LSVI_UCB_ORA; it appears to
    # fill the same slot as `args.alpha` for LSVI_UCB_GENSUM. Confirm its meaning.
    exploiter = LSVI_UCB_ORA(env.observation_space.shape[0],
                             env.state_dim,
                             env.action_dim,
                             args.horizon,
                             0.1,
                             device,
                             lamb=args.lsvi_lamb)

    for n in range(num_runs):
        obs = env.reset()

        for h in range(args.horizon):
            action = agent.act_batch(obs, h)
            exploit = exploiter.act_batch(obs, h)

            # Keep some exploration in the exploiter's data: the first
            # num_random_envs envs take a uniformly random action. A full batch
            # of num_envs draws is taken so the RNG stream is unchanged.
            rand_action = np.random.randint(0, num_random_actions, args.num_envs)
            exploit[:num_random_envs] = rand_action[:num_random_envs]

            action[:, slot] = exploit
            next_obs, reward1, reward2, _, _ = env.step(action)
            reward = reward1 if agent1 else reward2

            buffers[h].add_batch(obs, exploit, reward, next_obs, args.num_envs)
            obs = next_obs

        if n % args.update_frequency == 0:
            exploiter.update(buffers)

    final_eval = [evaluate(eval_env, agent, exploiter, args, agent1=agent1)
                  for _ in range(num_final_evals)]
    final_mean = np.mean(final_eval)
    print("final", final_mean)
    return final_mean


def main(args):
    """Load a trained general-sum agent, report its returns, then exploit it.

    Steps:
      1. Seed all RNGs and build the batched training and evaluation envs.
      2. Build one representation learner per step ``h`` and call
         ``load_phi(h)`` on each.
      3. Build the ``LSVI_UCB_GENSUM`` agent and load its weights from
         ``args.temp_path``.
      4. Print the agent's mean returns for both players over 10
         ``evaluate_trained`` calls.
      5. Train a best-response exploiter in place of player 2, then in place
         of player 1 (see ``run``).
    """
    set_seed_everywhere(args.seed)

    env, eval_env = make_batch_env(args, gensum=True)
    device = torch.device("cpu")
    num_runs = args.num_episodes

    rep_learners = make_rep_learner(env, device, args)
    for h in range(args.horizon):
        rep_learners[h].load_phi(h)

    agent = LSVI_UCB_GENSUM(env.observation_space.shape[0],
                            env.state_dim,
                            env.action_dim,
                            args.horizon,
                            args.alpha,
                            device,
                            rep_learners,
                            recent_size=args.lsvi_recent_size,
                            lamb=args.lsvi_lamb)
    agent.load_weight(args.temp_path)

    # Baseline: the trained agent's own returns for player 1 and player 2.
    pre1 = []
    pre2 = []
    for _ in range(10):
        return1, return2 = evaluate_trained(eval_env, agent, args)
        pre1.append(return1)
        pre2.append(return2)
    print("pre:", np.mean(pre1), np.mean(pre2))

    print("exploit agent 2:")
    run(args, agent, env, eval_env, num_runs, device, agent1=False)

    print("exploit agent 1:")
    run(args, agent, env, eval_env, num_runs, device, agent1=True)

    #print("oracle v:", np.mean(oracle_v[0], axis=0))


if __name__ == '__main__':
    # Script entry point: parse the CLI, then run `main` inside a wandb run.
    # wandb is imported here, so it is only needed when this file is executed
    # directly.
    args = parse_args()

    import wandb

    # Run wandb in offline mode.
    os.environ['WANDB_MODE'] = 'offline'
    run_config = vars(args)
    with wandb.init(project="bmdp_h{}".format(args.horizon),
                    job_type="ratio_search",
                    config=run_config,
                    name=args.exp_name):
        main(args)
