import numpy as np
import math 
import sys
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
from concave_maze import SimpleMazeConcave as ConcaveSimple
from concave_maze_extended import SimpleMazeConcave as ConcaveExtended

import mosgd

from io_utils import get_result_folder, dump_output, get_result_path, save_model, dump_settings

# Maps the ENV setting (see train()) to the maze environment class to instantiate.
environment_dict = dict(extended=ConcaveExtended, simple=ConcaveSimple)

def _targets(threshold, stopping_point_1, stopping_point_2):
    """Return one desired-policy spec as a fresh dict.

    ``threshold`` is passed by ``finish_episode`` to ``mosgd.direction`` as ``ths``;
    ``train`` stops early once its running rewards 1 and 2 exceed
    ``stopping_point_1`` and ``stopping_point_2`` respectively.
    """
    return {
        "threshold": threshold,
        "stopping_point_1": stopping_point_1,
        "stopping_point_2": stopping_point_2,
    }


desired_policy_dict_intrinsic = {
    "Fast": _targets(-7.5, -4.05, -4.05),
    "Mid-NoStop": _targets(-9.0, 3.05, -8.05),
    # "Mid": _targets(-9.0, -9.05, -8.05),  # disabled variant
    "Careful": _targets(-1.0, 0.95, -10.05),
}

desired_policy_dict_intrinsic_flip = {
    "Fast": _targets(-7.5, -4.05, -4.05),
    "Mid-NoStop": _targets(-3.1, 3.05, -8.05),
    "Careful": _targets(-1.0, 0.95, -10.05),
}

desired_policy_dict_non_intrinsic = {
    "Reach": _targets(0.5, 0.95, -2.05),
    "Reach-Nostop": _targets(0.5, 5, -2.05),
    "NoReach": _targets(0.1, -4.05, -4.05),
}

desired_policy_dict_non_intrinsic_flip = {
    "Fast": _targets(-7.5, -4.05, -4.05),
    "Mid": _targets(-3.5, -3.05, -8.05),
    "Careful": _targets(-1.0, 0.95, -10.05),
}

class Policy(nn.Module):
    """Two-layer MLP that maps an observation to a categorical distribution over actions.

    Besides the network, the module carries the per-episode buffers that
    ``select_action`` and ``finish_episode`` fill and drain:
    ``saved_log_probs`` (one entry per step) and ``rewards`` (one list per objective).

    Args:
        temp: softmax temperature; logits are divided by it, so larger values
            give flatter action distributions.
        n_inputs: number of observation features (default 15).
        n_actions: number of discrete actions (default 4).
        hidden_size: width of the hidden layer (default 128).
        dropout_p: dropout probability on the hidden pre-activations (default 0.6).
        n_objectives: number of reward streams kept in ``rewards`` (default 2).
    """

    def __init__(self, temp=1, n_inputs=15, n_actions=4, hidden_size=128, dropout_p=0.6, n_objectives=2):
        super().__init__()
        self.temp = temp
        # Layers are created in the same order as before so seeded initialisation is unchanged.
        self.affine1 = nn.Linear(n_inputs, hidden_size)
        self.dropout = nn.Dropout(p=dropout_p)
        self.affine2 = nn.Linear(hidden_size, n_actions)

        self.saved_log_probs = []
        self.rewards = [[] for _ in range(n_objectives)]

    def forward(self, x):
        """Return action probabilities for ``x``.

        The softmax runs over the last axis, so both a batch ``(B, n_inputs)``
        and a single observation ``(n_inputs,)`` are accepted.
        """
        x = self.affine1(x)
        x = self.dropout(x)
        x = F.relu(x)
        action_scores = self.affine2(x)
        # dim=-1 is identical to dim=1 for 2-D input and also supports unbatched input.
        return F.softmax(action_scores / self.temp, dim=-1)
    
def select_action(state, policy):
    state = torch.from_numpy(state).float().unsqueeze(0)
    probs = policy(state)
    m = Categorical(probs)
    action = m.sample()
    policy.saved_log_probs.append(m.log_prob(action))
    return action.item()    
        

def _discounted_returns(rewards, gamma):
    """Return the discounted reward-to-go G_t = r_t + gamma * G_{t+1} for every step of ``rewards``."""
    returns = []
    running = 0
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    returns.reverse()  # built back-to-front; append + reverse avoids O(n^2) list.insert(0, ...)
    return returns


def _flat_grad(policy):
    """Concatenate every parameter's current ``.grad`` into one flat numpy vector, in parameter order."""
    return np.concatenate([param.grad.detach().numpy().flatten() for param in policy.parameters()])


def finish_episode(policy, optimizer, gamma, eps, desired_policy, active_const):
    """Apply one multi-objective policy-gradient update from the episode buffered on ``policy``.

    For each reward stream in ``policy.rewards`` this builds the REINFORCE loss
    ``sum_t(-log_prob_t * normalised G_t)``, back-propagates it, and records the
    flattened gradient (``gs``) and the discounted return from the first step (``fs``).
    ``mosgd.direction`` combines them into one update direction, which is written
    into the parameters' ``.grad`` and applied with ``optimizer.step()``; when it
    returns ``None`` no step is taken. Episodes with fewer than two steps are
    discarded without an update. In every case the episode buffers on ``policy``
    are cleared.

    Args:
        policy: module exposing ``parameters()``, ``zero_grad()``,
            ``saved_log_probs`` and ``rewards`` (one list per objective).
        optimizer: optimizer over ``policy.parameters()``.
        gamma: discount factor.
        eps: added to the return std to avoid division by zero.
        desired_policy: dict whose "threshold" is passed to ``mosgd.direction`` as ``ths``.
        active_const: forwarded unchanged to ``mosgd.direction``.
    """
    n_objectives = len(policy.rewards)
    if len(policy.saved_log_probs) < 2:
        # The unbiased std of a single return is NaN (and torch.cat([]) fails on an
        # empty episode); a NaN update would poison every parameter, so skip it.
        for obj_rewards in policy.rewards:
            del obj_rewards[:]
        del policy.saved_log_probs[:]
        return

    gs = []  # per-objective flattened gradients of the policy loss
    fs = []  # per-objective discounted return from the episode's first step
    for obj in range(n_objectives):
        returns = _discounted_returns(policy.rewards[obj], gamma)
        fs.append(returns[0] if returns else 0)
        returns = torch.tensor(returns)
        returns = (returns - returns.mean()) / (returns.std() + eps)
        policy_loss = torch.cat(
            [-log_prob * ret for log_prob, ret in zip(policy.saved_log_probs, returns)]
        ).sum()
        optimizer.zero_grad()
        # The log-prob graph is shared by all objectives; free it only after the last one.
        policy_loss.backward(retain_graph=obj < n_objectives - 1)
        gs.append(_flat_grad(policy))
        policy.zero_grad()
        del policy.rewards[obj][:]

    d = mosgd.direction(gs, fs, ths=[desired_policy["threshold"]], active_const=active_const, delta=math.pi / 18)
    if d is not None:
        offset = 0
        for param in policy.parameters():
            size = param.numel()
            step = torch.from_numpy(d[offset:offset + size].reshape(param.shape))
            # An assigned .grad must match the parameter dtype; cast in case the
            # direction comes back as float64.
            param.grad = step.to(dtype=param.dtype)
            offset += size
        optimizer.step()
    del policy.saved_log_probs[:]
    
def train(seed, n_episodes, desired_policy, env_name=None):
    """Train one seeded policy until it reaches ``desired_policy``'s stopping levels.

    Each episode runs at most 99 steps. It appends both rewards per step to
    ``policy.rewards`` and then calls ``finish_episode`` for the update. Training
    stops early once running reward 1 exceeds ``stopping_point_1`` and running
    reward 2 exceeds ``stopping_point_2``. Running rewards are exponential moving
    averages with smoothing 0.05.

    Args:
        seed: torch RNG seed for this run.
        n_episodes: maximum number of episodes to run.
        desired_policy: dict with "threshold" (forwarded to ``finish_episode``)
            and "stopping_point_1"/"stopping_point_2".
        env_name: key into ``environment_dict``; defaults to the module-level ``ENV``.

    Returns:
        ``(solved, last_episode, running_reward1, running_reward2, progressions, settings_dict)``.
        ``last_episode`` is the 0-based index of the last episode run, or -1 when
        ``n_episodes`` is 0.
    """
    env = environment_dict[ENV if env_name is None else env_name](intrinsic_reward=False, flip_rewards=False)
    torch.manual_seed(seed)

    policy = Policy(temp=10)
    optimizer = optim.SGD(policy.parameters(), lr=1e-3)
    eps = np.finfo(np.float32).eps.item()

    active_const = True

    running_reward1 = 0
    running_reward2 = 0

    settings_dict = {"temp": policy.temp, "intrinsic": env.intrinsic, "flip_rews": env.flip_rewards,
                     "policy_settings": desired_policy,
                     # Record the episode budget this run actually used, not the module default.
                     "n_episodes": n_episodes, "n_seeds": NUM_SEEDS,
                     "policy_name": DESIRED_POLICY, "policy_group": DESIRED_POLICY_GROUP,
                     "active_const": active_const}

    progressions = {"running1": [],
                    "running2": [],
                    "time": [],
                    "ep_reward1": [],
                    "ep_reward2": [],
                    "seed": seed,
                    }

    i_episode = -1  # stays -1 when n_episodes == 0 so the summary below never hits an unbound name
    for i_episode in range(n_episodes):
        state, ep_reward1, ep_reward2 = env.reset(), 0, 0
        for t in range(1, 100):  # cap episode length so an untrained policy cannot loop forever
            action = select_action(state, policy)
            state, reward1, reward2, done, _ = env.step(action)
            policy.rewards[0].append(reward1)
            policy.rewards[1].append(reward2)
            ep_reward1 += reward1
            ep_reward2 += reward2
            if done:
                break

        running_reward1 = 0.05 * ep_reward1 + (1 - 0.05) * running_reward1
        running_reward2 = 0.05 * ep_reward2 + (1 - 0.05) * running_reward2

        progressions["running1"].append(running_reward1)
        progressions["running2"].append(running_reward2)
        progressions["ep_reward1"].append(ep_reward1)
        progressions["ep_reward2"].append(ep_reward2)
        progressions["time"].append(t)

        finish_episode(policy, optimizer, gamma=0.99, eps=eps, desired_policy=desired_policy, active_const=active_const)

        if i_episode % 100 == 0:
            print('Seed {}\tEpisode {}\tLast reward: {:.2f}\tAverage reward: {:.2f} \tLast reward: {:.2f}\tAverage reward: {:.2f}'.format(
                  seed, i_episode, ep_reward1, running_reward1, ep_reward2, running_reward2))

        if running_reward1 > desired_policy["stopping_point_1"] and running_reward2 > desired_policy["stopping_point_2"]:
            print("Solved! Running reward is now {} and "
                  "the last episode runs to {} time steps!".format(running_reward1, t))
            return (True, i_episode, running_reward1, running_reward2, progressions, settings_dict)
    print(f"Unsuccessful :(. Tried for {i_episode + 1} episodes but the current running rewards are {running_reward1} and {running_reward2}")
    return (False, i_episode, running_reward1, running_reward2, progressions, settings_dict)



# Run configuration read by train(), evaluate(), evaluate_single() and the __main__ block.
ENV = "simple"  # key into environment_dict
NUM_SEEDS = 10  # seeds 0 .. NUM_SEEDS-1 are trained
NUM_EPISODES = 4000  # per-seed episode budget passed to train()
DESIRED_POLICY_GROUP = desired_policy_dict_non_intrinsic
DESIRED_POLICY = "Reach-Nostop"  # must be a key of DESIRED_POLICY_GROUP

def evaluate():
    """Train every seed in ``range(NUM_SEEDS)`` in parallel with a process pool.

    Returns the per-seed ``train`` results in seed order; an exception raised in
    a worker is re-raised here by ``AsyncResult.get``.
    """
    import multiprocessing as mp

    # cpu_count()//4 is 0 on machines with fewer than 4 cores (and NUM_SEEDS may be 0);
    # Pool() rejects fewer than one process, so always use at least one worker.
    n_workers = max(1, min(mp.cpu_count() // 4, 40, NUM_SEEDS))
    with mp.Pool(n_workers) as pool:
        print(pool, NUM_SEEDS)

        pool_results = []
        for seed in range(NUM_SEEDS):
            print(f"Run seed {seed}")
            pool_results.append(pool.apply_async(train, args=(seed, NUM_EPISODES, DESIRED_POLICY_GROUP[DESIRED_POLICY])))

        # Stop accepting work and wait for every queued run; the context manager's
        # terminate() on exit is then a no-op.
        pool.close()
        pool.join()

        results = [r.get() for r in pool_results]

    return results

def evaluate_single():
    """Train every seed in ``range(NUM_SEEDS)`` one after another in this process.

    Returns the list of per-seed ``train`` results, indexed by seed.
    """
    outcomes = []
    for seed_id in range(NUM_SEEDS):
        print(f"Run seed {seed_id}")
        policy_spec = DESIRED_POLICY_GROUP[DESIRED_POLICY]
        outcomes.append(train(seed_id, NUM_EPISODES, policy_spec))
    return outcomes


if __name__ == '__main__':
    # Usage: python <script> {single|multi}
    #   single -> train all seeds sequentially; multi -> train them in a process pool.
    runners = {"single": evaluate_single, "multi": evaluate}
    if len(sys.argv) < 2 or sys.argv[1] not in runners:
        # Validate before creating the result folder so a bad invocation leaves nothing behind.
        sys.exit(f"usage: {sys.argv[0]} {{single|multi}}")

    result_folder = get_result_folder(os.path.join("lex", ENV))
    results = runners[sys.argv[1]]()

    # Both runners return results indexed by seed; each train() result ends with
    # (..., progressions, settings_dict).
    for seed, result in enumerate(results):
        settings_path = get_result_path(os.path.join(result_folder, f"settings_seed{seed}.json"))
        results_path = get_result_path(os.path.join(result_folder, f"results_seed{seed}.pkl"))
        dump_settings(result[-1], settings_path)
        dump_output(result[-2], results_path)

    # Number of seeds whose run reached the stopping criteria.
    print(sum(result[0] for result in results))
    
    