import numpy as np
import torch
import os
import math
import gym
import sys
import random
import time
import json
import copy

from collections import deque
import multiprocessing

from utils import make_batch_env, parse_args, set_seed_everywhere, ReplayBuffer, make_batch_target_env, make_batch_partition_env
from envs.Lock import Lock

from algs.lsvi_ucb_oracle import LSVI_UCB_ORA

from source_train import source_train, make_rep_learner

from algs.multi_rep_learn import MultiRepLearn


def evaluate(env, agent, args):
    """Roll out ``agent`` in ``env`` for one episode and return the mean return.

    Args:
        env: Batched environment. ``reset()`` must return ``(obs, state)`` and
            ``step(action)`` a 5-tuple
            ``(next_obs, next_state, reward, done, info)``.
        agent: Policy object exposing ``act_batch(obs, state, h)``.
        args: Namespace providing ``num_eval`` (number of rows in the return
            accumulator) and ``horizon`` (number of steps to roll out).

    Returns:
        The mean over the ``(args.num_eval, 1)`` array of summed rewards.
    """
    total_reward = np.zeros((args.num_eval, 1))

    obs, state = env.reset()
    for step in range(args.horizon):
        chosen = agent.act_batch(obs, state, step)
        # The done flag and info dict are ignored: the rollout always runs
        # for the full horizon.
        obs, state, step_reward, _done, _info = env.step(chosen)
        # In-place add keeps the accumulator's (num_eval, 1) shape and dtype.
        total_reward += step_reward

    return np.mean(total_reward)

def load_source(args, noise_types, device):
    """Build the source environments and restore their saved ``opt_a``/``opt_b``.

    For each of the ``args.num_sources`` sources an environment is created
    (partition-based when ``args.partition`` is truthy, otherwise from
    ``noise_types[i]``), then ``opt_a.npy`` and ``opt_b.npy`` are read from
    ``<args.load_path>/<save_path>/<args.seed>/`` and assigned onto it. The
    loaded arrays are printed to stdout.

    Args:
        args: Namespace providing ``num_sources``, ``partition``,
            ``load_path`` and ``seed`` (and whatever the env factories read).
        noise_types: Sequence of noise-type names, indexed by source index
            when ``args.partition`` is falsy.
        device: Unused in this function; kept for interface compatibility.

    Returns:
        List of the created environments, in source-index order.
    """
    loaded_envs = []

    for src_idx in range(args.num_sources):
        if args.partition:
            src_env, _ = make_batch_partition_env(args, src_idx)
            # NOTE(review): every partition source reads from "partition_0",
            # not from a per-source directory -- confirm this is intended.
            save_path = "partition_{}".format(0)
        else:
            noise_name = noise_types[src_idx]
            src_env, _ = make_batch_env(args, noise_name)
            save_path = noise_name

        checkpoint_dir = os.path.join(args.load_path, save_path, str(args.seed))

        # Only opt_a and opt_b are restored; rotation.npy is not loaded here.
        opt_a_file = os.path.join(checkpoint_dir, "opt_a.npy")
        opt_b_file = os.path.join(checkpoint_dir, "opt_b.npy")
        src_env.opt_a = np.load(opt_a_file)
        src_env.opt_b = np.load(opt_b_file)

        loaded_envs.append(src_env)

        # Echo what was loaded so the run log records the source settings.
        print(save_path, "a:")
        print(src_env.opt_a)
        print(save_path, "b:")
        print(src_env.opt_b)

    return loaded_envs

def main(args):
    """Train an ``LSVI_UCB_ORA`` agent on the target environment.

    Steps performed:
      1. Seed everything and load the source environments via ``load_source``.
      2. Build the target/eval environments from the sources and save the
         target's ``opt_a``, ``opt_b`` and ``rotation`` under
         ``args.temp_path``.
      3. Create one ``ReplayBuffer`` per horizon step and the agent.
      4. Run ``num_runs`` batched episodes; every ``args.update_frequency``
         runs, update the agent, evaluate it, log to wandb and save weights.

    Args:
        args: Namespace from ``parse_args``. Attributes read here: ``seed``,
            ``temp_path``, ``load_path``, ``horizon``, ``num_episodes``,
            ``num_envs``, ``num_warm_start``, ``batch_size``,
            ``recent_size``, ``dense``, ``variable_latent`` and
            ``update_frequency``. ``args.alpha`` is overwritten.

    Returns:
        None. Returns early once the mean of the recent evaluation returns
        equals 1, unless ``args.variable_latent`` or ``args.dense`` is set.
    """
    # Imported locally so main() is also usable when this module is imported
    # rather than executed as a script (the module-level import of wandb only
    # happens under the ``__main__`` guard).
    import wandb

    set_seed_everywhere(args.seed)
    noise_types = ["hadamard_gaussian", "hadamard_uniform", "hadamard_bernoulli",  "hadamard_uniposneg", "hadamard_berposneg"]

    temp_path = args.temp_path
    os.makedirs(temp_path, exist_ok=True)

    device = torch.device("cpu")

    source_envs = load_source(args, noise_types, device)

    env, eval_env = make_batch_target_env(args, source_envs)

    num_actions = env.action_space.n

    os.makedirs(args.load_path, exist_ok=True)

    # Persist the target environment's parameters next to the run artifacts.
    np.save(os.path.join(temp_path, "opt_a"), env.opt_a)
    np.save(os.path.join(temp_path, "opt_b"), env.opt_b)
    np.save(os.path.join(temp_path, "rotation"), env.rotation)

    # One sub-directory per horizon step.
    for h in range(args.horizon):
        os.makedirs(os.path.join(temp_path, "buffer_{}".format(h)), exist_ok=True)

    num_runs = int(args.num_episodes / args.horizon / args.num_envs)

    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the exact
    # type it used to alias.
    counts = np.zeros((args.horizon, 3), dtype=int)

    # The capacity formula reserves room for warm-start samples, although no
    # warm-start data is collected in this function.
    buffer_capacity = (int(args.num_episodes / args.horizon) * 2
                       + args.num_warm_start * args.num_envs)
    buffers = [
        ReplayBuffer(env.observation_space.shape,
                     env.action_space.n,
                     buffer_capacity,
                     args.batch_size,
                     device,
                     recent_size=args.recent_size)
        for _ in range(args.horizon)
    ]

    # Stored on ``args``; not read again inside this function.
    if args.dense:
        args.alpha = args.horizon / 50
    else:
        args.alpha = args.horizon / 5

    agent = LSVI_UCB_ORA(env.observation_space.shape[0],
                     env.state_dim,
                     env.action_dim,
                     args.horizon,
                     1,
                     device,
                     lamb = 1)

    # Sliding window of recent evaluation returns (longer for variable_latent).
    if args.variable_latent:
        returns = deque(maxlen=50)
    else:
        returns = deque(maxlen=5)

    for n in range(num_runs):

        obs, state = env.reset()
        for h in range(args.horizon):
            action = agent.act_batch(obs, state, h)
            # The first environment in the batch always takes a uniformly
            # random action. The full-size draw is kept so the consumption of
            # the global NumPy RNG stream is unchanged.
            rand_action = np.random.randint(0, num_actions, args.num_envs)
            action[0] = rand_action[0]
            next_obs, next_state, reward, done, _ = env.step(action)
            count = env.get_counts()
            counts[h] = counts[h] + count
            buffers[h].add_batch(obs, state, action, reward, next_obs, next_state, args.num_envs)
            obs = next_obs
            state = next_state

        if n % args.update_frequency == 0:
            agent.update(buffers)
            eval_return = evaluate(eval_env, agent, args)

            returns.append(eval_return)
            mean_return = np.mean(list(returns))

            wandb.log({
                        "eval": mean_return if args.variable_latent else eval_return,
                        "episode:": n * args.num_envs}
                    )

            agent.save_weight(temp_path)

            # Early stop once the recent evaluation window averages exactly 1.
            if mean_return == 1 and not args.variable_latent and not args.dense:
                return


if __name__ == '__main__':
    # Script entry point: parse the CLI arguments, open a wandb run and train.
    args = parse_args()

    # wandb is imported only when the file is executed as a script.
    import wandb

    # Partition experiments are logged to their own wandb project.
    project_name = "target_oracle_partition" if args.partition else "target_oracle"

    init_kwargs = dict(
        project=project_name,
        job_type="ratio_search",
        config=vars(args),
        name=args.exp_name,
    )

    # The run object is used as a context manager around the whole training.
    with wandb.init(**init_kwargs):
        main(args)









