import numpy as np
import torch
import os
import math
import gym
import sys
import random
import time
import json
import copy

from collections import deque
import multiprocessing

from utils import make_batch_env, parse_args, set_seed_everywhere, ReplayBuffer, make_batch_target_env, make_batch_partition_env
from envs.Lock import Lock

from algs.lsvi_ucb import LSVI_UCB

from source_train import source_train, make_rep_learner, rep_train, evaluate

from algs.multi_rep_learn import MultiRepLearn

def make_multirep_learner(env, device, args, path):
    """Build one ``MultiRepLearn`` instance per time step of the horizon.

    Args:
        env: Environment exposing ``observation_space.shape``, ``state_dim``
            and ``action_dim``; these are forwarded to every learner.
        device: Device object handed to every learner.
        args: Parsed arguments. Reads ``horizon``, the two temperatures
            (``phi0_temperature`` / ``target_temperature``), the ``rep_*``
            update counts, the optimizer learning rates / betas,
            ``batch_size``, ``rep_lamb``, ``optimizer``, ``softmax``,
            ``reuse_weights`` and ``num_sources``.
        path: Passed to every learner as ``temp_path``.

    Returns:
        A list of ``args.horizon`` learners, ordered by time step.
    """

    def _build_learner(step):
        # Step 0 gets its own temperature; all later steps share the
        # target temperature.
        tau = args.phi0_temperature if step == 0 else args.target_temperature
        return MultiRepLearn(
            env.observation_space.shape[0],
            env.state_dim,
            env.action_dim,
            args.hidden_dim,
            args.rep_num_update,
            args.rep_num_feature_update,
            args.rep_num_adv_update,
            device,
            discriminator_lr=args.discriminator_lr,
            discriminator_beta=args.discriminator_beta,
            feature_lr=args.feature_lr,
            feature_beta=args.feature_beta,
            weight_lr=args.linear_lr,
            weight_beta=args.linear_beta,
            batch_size=args.batch_size,
            lamb=args.rep_lamb,
            tau=tau,
            optimizer=args.optimizer,
            softmax=args.softmax,
            reuse_weights=args.reuse_weights,
            temp_path=path,
            num_sources=args.num_sources,
        )

    return [_build_learner(step) for step in range(args.horizon)]

def load_source(args, noise_types, device):
    """Rebuild every source environment and its saved LSVI-UCB agent.

    For each of the ``args.num_sources`` sources this creates the batch
    environment, restores ``opt_a`` / ``opt_b`` from ``.npy`` files under
    ``<args.load_path>/<sub-directory>/<args.seed>``, loads the per-step
    representation learners via ``load_phi`` and finally the agent weights
    via ``load_weight``.

    Args:
        args: Parsed arguments. Reads ``num_sources``, ``partition``,
            ``load_path``, ``seed``, ``horizon``, ``alpha``,
            ``lsvi_recent_size`` and ``lsvi_lamb`` (plus whatever the
            environment / learner factories read).
        noise_types: Sequence indexed by source number; used both to build
            the environment and as the checkpoint sub-directory when
            ``args.partition`` is false.
        device: Device object handed to the learners and agents.

    Returns:
        ``(source_policies, source_envs)``: two lists of length
        ``args.num_sources`` in source order.
    """
    policies, envs = [], []

    for src_idx in range(args.num_sources):
        if not args.partition:
            subdir = noise_types[src_idx]
            src_env, _ = make_batch_env(args, subdir)
        else:
            src_env, _ = make_batch_partition_env(args, src_idx)
            # NOTE(review): every partition source reads its checkpoint from
            # "partition_0" regardless of src_idx -- confirm this is intended.
            subdir = "partition_{}".format(0)

        ckpt_dir = os.path.join(args.load_path, subdir, str(args.seed))
        src_env.opt_a = np.load(os.path.join(ckpt_dir, "opt_a.npy"))
        src_env.opt_b = np.load(os.path.join(ckpt_dir, "opt_b.npy"))
        envs.append(src_env)

        # Echo the restored values so the run log shows what was loaded.
        for label, value in (("a:", src_env.opt_a), ("b:", src_env.opt_b)):
            print(subdir, label)
            print(value)

        learners = make_rep_learner(src_env, device, args, ckpt_dir)
        for step in range(args.horizon):
            learners[step].load_phi(step)

        policy = LSVI_UCB(
            src_env.observation_space.shape[0],
            src_env.state_dim,
            src_env.action_dim,
            args.horizon,
            args.alpha,
            device,
            learners,
            recent_size=args.lsvi_recent_size,
            lamb=args.lsvi_lamb,
        )
        policy.load_weight(ckpt_dir)
        policies.append(policy)

    return policies, envs

def main(args):
    """Train an LSVI-UCB agent on the target environment and log to wandb.

    Steps, in order:

    1. Seed everything, load the source environments / agents with
       ``load_source`` and build the target and evaluation environments
       from the source environments.
    2. Save the target environment's ``opt_a``, ``opt_b`` and ``rotation``
       under ``args.temp_path`` and create one ``buffer_<h>`` directory per
       time step.
    3. Fill one replay buffer per time step with uniformly random actions
       (warm start).
    4. For ``num_runs`` rounds: for every step ``h`` roll the agent in for
       ``h`` steps, take random actions, store the transitions, and save
       the buffers. Every ``args.update_frequency`` rounds, retrain the
       representation learners in worker processes, update the agent,
       evaluate it, log metrics and save weights and counts.

    Side effects on ``args``: ``args.alpha`` is overwritten based on
    ``args.dense`` and ``args.num_warm_start`` is forced to 200 when
    ``args.horizon >= 50``.

    Args:
        args: Parsed arguments (as produced by ``parse_args``).

    Returns:
        None. Returns early once the mean of the recent evaluation returns
        equals 1, unless ``args.variable_latent`` or ``args.dense`` is set.

    Raises:
        AssertionError: if ``args.horizon`` is not a multiple of
            ``args.num_threads`` when a representation update is due.
    """
    # wandb is otherwise only imported under the ``__main__`` guard; importing
    # it here keeps main() usable when this module is imported.
    import wandb

    set_seed_everywhere(args.seed)
    noise_types = ["hadamard_gaussian", "hadamard_uniform", "hadamard_bernoulli",  "hadamard_uniposneg", "hadamard_berposneg"]
    os.makedirs(args.temp_path, exist_ok=True)

    device = torch.device("cpu")

    # The source agents are not used below; only the environments feed into
    # the target environment.
    source_policies, source_envs = load_source(args, noise_types, device)

    env, eval_env = make_batch_target_env(args, source_envs)

    num_actions = env.action_space.n

    os.makedirs(args.load_path, exist_ok=True)

    temp_path = args.temp_path

    np.save(os.path.join(temp_path, "opt_a"), env.opt_a)
    np.save(os.path.join(temp_path, "opt_b"), env.opt_b)
    np.save(os.path.join(temp_path, "rotation"), env.rotation)

    for h in range(args.horizon):
        os.makedirs(os.path.join(temp_path, "buffer_{}".format(h)), exist_ok=True)

    num_runs = int(args.num_episodes / args.horizon / args.num_envs)

    # Apply the long-horizon override BEFORE sizing the buffers: the capacity
    # below budgets num_warm_start * num_envs slots for the warm-start data,
    # so it must see the value the warm-start loop will actually use.
    if args.horizon >= 50:
        args.num_warm_start = 200

    # Each buffer receives 2 * num_runs * num_envs transitions from the main
    # loop plus the warm-start transitions.
    buffer_capacity = int(args.num_episodes / args.horizon) * 2 + args.num_warm_start * args.num_envs
    buffers = [
        ReplayBuffer(env.observation_space.shape,
                     env.action_space.n,
                     buffer_capacity,
                     args.batch_size,
                     device,
                     recent_size=args.recent_size)
        for _ in range(args.horizon)
    ]

    rep_learners = make_rep_learner(env, device, args, temp_path)

    if args.dense:
        args.alpha = args.horizon / 50
    else:
        args.alpha = args.horizon / 5

    agent = LSVI_UCB(env.observation_space.shape[0],
                     env.state_dim,
                     env.action_dim,
                     args.horizon,
                     args.alpha,
                     device,
                     rep_learners,
                     recent_size = args.lsvi_recent_size,
                     lamb = args.lsvi_lamb)

    # Warm start: full episodes of uniformly random actions.
    for _ in range(args.num_warm_start):
        obs, state = env.reset()
        for h in range(args.horizon):
            action = np.random.randint(0, num_actions, args.num_envs)
            next_obs, next_state, reward, _, _ = env.step(action)
            buffers[h].add_batch(obs,state,action,reward,next_obs,next_state,args.num_envs)
            obs = next_obs
            state = next_state

    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the
    # documented replacement.
    counts = np.zeros((args.horizon,3),dtype=int)

    # Window of recent evaluation returns (used for smoothing / early stop).
    if args.variable_latent:
        returns = deque(maxlen=50)
    else:
        returns = deque(maxlen=5)

    inference_start_time = time.time()

    for n in range(num_runs):

        for h in range(args.horizon):
            # Roll in with the current agent for h steps.
            t = 0
            obs, state = env.reset()
            while t < h:
                action = agent.act_batch(obs, t)
                next_obs,next_state, _, _, _ = env.step(action)
                obs = next_obs
                state = next_state
                t += 1

            # Random action at step h.
            action = np.random.randint(0, num_actions, args.num_envs)
            next_obs, next_state, reward, _, _ = env.step(action)
            buffers[h].add_batch(obs,state,action,reward,next_obs,next_state,args.num_envs)

            count = env.get_counts()
            counts[h] = counts[h] + count

            if h != args.horizon - 1:
                # A second random action, stored for step h + 1.
                obs = next_obs
                state = next_state
                action = np.random.randint(0, num_actions, args.num_envs)
                next_obs, next_state, reward, _, _ = env.step(action)
                buffers[h+1].add_batch(obs,state,action,reward,next_obs,next_state,args.num_envs)

                count = env.get_counts()
                counts[h+1] = counts[h+1] + count

            else:
                # At the last step there is no h + 1; collect an extra random
                # transition for step 0 from a fresh episode instead.
                obs, state = env.reset()
                action = np.random.randint(0, num_actions, args.num_envs)
                next_obs, next_state, reward, _, _ = env.step(action)
                buffers[0].add_batch(obs,state,action,reward,next_obs,next_state,args.num_envs)

                count = env.get_counts()
                counts[0] = count + counts[0]

        for b, buffer in enumerate(buffers):
            buffer.save(os.path.join(temp_path, "buffer_{}".format(b)))

        if n % args.update_frequency == 0:

            inference_time = time.time() - inference_start_time

            assert args.horizon % args.num_threads == 0
            start_time = time.time()
            num_multi_runs = int(args.horizon / args.num_threads)

            # Losses reported by the workers; collected but not logged.
            feature_loss_list = []
            adv_loss_list = []

            # Train the per-step learners in groups of num_threads processes.
            for m in range(num_multi_runs):
                queue = multiprocessing.Queue()
                workers = []
                for i in range(args.num_threads):
                    h = m*args.num_threads + i
                    worker_args = (rep_learners[h], buffers[h], h, queue)
                    workers.append(multiprocessing.Process(target=rep_train, args=worker_args))
                for worker in workers:
                    worker.start()

                for _ in workers:
                    pid, feature_loss, adv_loss = queue.get()
                    feature_loss_list.append(feature_loss)
                    adv_loss_list.append(adv_loss)
                    # Training ran in another process, so refresh the
                    # learner held by this process.
                    rep_learners[pid].load_phi(pid)

                # The queue is drained, so joining cannot block on unread
                # items; joining reaps the children instead of leaving them
                # as zombies for the rest of the run.
                for worker in workers:
                    worker.join()

            rep_learn_time = time.time() - start_time

            start_time = time.time()
            agent.update(buffers)
            lsvi_time = time.time() - start_time


            start_time = time.time()

            eval_return = evaluate(eval_env, agent, args)

            returns.append(eval_return)

            eval_time = time.time() - start_time

            # First step whose first two count columns sum to less than 5;
            # stays 0 when no step satisfies that.
            reached = 0
            for h in range(args.horizon):
                if counts[h,:2].sum() < 5:
                    reached = h
                    break

            # The original dict listed "episode:" twice; only the later value
            # (n * num_envs * horizon) was ever logged, so that one is kept.
            wandb.log({"rep_learn_time": rep_learn_time,
                        "lsvi_time": lsvi_time,
                        "eval": np.mean(list(returns)) if args.variable_latent else eval_return,
                        "episode:": n * args.num_envs * args.horizon,
                        "reached": reached,
                        "state 0": counts[-1,0],
                        "state 1": counts[-1,1],
                        "sampling time": inference_time,
                        "eval time": eval_time})


            agent.save_weight(temp_path)

            np.save("{}/counts".format(temp_path), counts)

            inference_start_time = time.time()

            # Early stop once the recent evaluations all reach a return of 1.
            if np.mean(list(returns)) == 1 and not args.variable_latent and not args.dense:
                return


if __name__ == '__main__':
    # Script entry point: parse the command line, open a wandb run and train.
    args = parse_args()

    # Imported only when executed as a script, after argument parsing.
    import wandb

    # Partition experiments are logged to their own wandb project.
    project_name = "target_train_partition" if args.partition else "target_train"

    run = wandb.init(
        project=project_name,
        job_type="ratio_search",
        config=vars(args),
        name=args.exp_name,
    )
    # Used as a context manager so the run is closed when main() returns
    # or raises.
    with run:
        main(args)









