import os
import torch
from shutil import copyfile, copytree
import torch.nn as nn
import argparse
import givenData
import numpy as np
from gym.envs.registration import register

def init(module, weight_init, bias_init, gain=1):
    """Initialise ``module.weight`` and ``module.bias`` in place and return ``module``.

    Args:
        module: a layer exposing ``weight`` and ``bias`` attributes.
        weight_init: callable invoked as ``weight_init(tensor, gain=gain)``.
        bias_init: callable invoked as ``bias_init(tensor)``.
        gain: forwarded to ``weight_init`` as the ``gain`` keyword.
    """
    weight_tensor = module.weight.data
    weight_init(weight_tensor, gain=gain)
    bias_tensor = module.bias.data
    bias_init(bias_tensor)
    return module

class AddBias(nn.Module):
    """Learnable bias added along the last dimension of the input.

    The bias is stored as a column parameter named ``_bias`` (shape ``(N, 1)``
    when ``bias`` is a 1-D tensor of length ``N`` -- assumed by the
    ``unsqueeze(1)`` below). The name and shape are part of the saved
    state dict, so they are kept as-is.
    """

    def __init__(self, bias):
        super(AddBias, self).__init__()
        self._bias = nn.Parameter(bias.unsqueeze(1))

    def forward(self, x):
        """Return ``x + bias`` for 1-D, 2-D or 3-D ``x``.

        For 1-D and 2-D inputs the bias is reshaped to ``(1, N)``; note that a
        1-D input therefore comes back with a leading dimension of size 1.
        For 3-D inputs the bias is reshaped to ``(1, 1, N)``.

        Raises:
            ValueError: if ``x`` has any other number of dimensions.
        """
        ndim = x.dim()
        if ndim in (1, 2):
            bias = self._bias.t().view(1, -1)
        elif ndim == 3:
            bias = self._bias.t().view(1, 1, -1)
        else:
            # An explicit raise (instead of ``assert False``) still fires under
            # ``python -O``; otherwise ``bias`` would be unbound below.
            raise ValueError('AddBias expects a 1-D, 2-D or 3-D input, got {}-D'.format(ndim))

        return x + bias

def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
    """Set every param group's learning rate to a linearly decayed value.

    The rate is ``initial_lr`` at ``epoch == 0`` and reaches 0 at
    ``epoch == total_num_epochs``.
    """
    progress = epoch / float(total_num_epochs)
    new_lr = initial_lr - (initial_lr * progress)
    for group in optimizer.param_groups:
        group['lr'] = new_lr

def backup(timeStr, args, upper_policy = None):
    """Snapshot the project sources (and optionally initial weights) for a run.

    Copies the listed source files from the current working directory into
    ``./logs/evaluation/<timeStr>`` (when ``args.evaluate``) or
    ``./logs/experiment/<timeStr>``. It also copies the environment package
    directory derived from ``args.id``. If ``upper_policy`` is given, its state
    dict is saved to
    ``<args.model_save_path>/<timeStr>/upper-first-<timeStr>.pt``.

    Raises:
        FileNotFoundError: if a source file or the environment directory is missing.
        FileExistsError: from ``shutil.copytree`` if the environment copy already
            exists in the target directory.
    """
    source_files = (
        'attention_model.py',
        'distributions.py',
        'envs.py',
        'evaluation.py',
        'evaluation_tools.py',
        'givenData.py',
        'graph_encoder.py',
        'kfac.py',
        'main.py',
        'model.py',
        'storage.py',
        'tools.py',
        'train_tools.py',
    )

    log_root = './logs/evaluation' if args.evaluate else './logs/experiment'
    targetDir = os.path.join(log_root, timeStr)
    os.makedirs(targetDir, exist_ok=True)

    for file_name in source_files:
        copyfile(file_name, os.path.join(targetDir, file_name))

    # 'PctDiscrete-v1' -> 'PctDiscrete1': the env id with its '-v' separator removed.
    # NOTE(review): registration_envs maps the v1 ids to the *PctDiscrete0* /
    # *PctContinuous0* modules, so confirm that ./pct_envs/<name>1 actually exists.
    gymPath = './pct_envs'
    envName = args.id.split('-v')
    envName = envName[0] + envName[1]
    envPath = os.path.join(gymPath, envName)
    copytree(envPath, os.path.join(targetDir, envName))

    if upper_policy is not None:
        # The save location is independent of targetDir (e.g. in evaluate mode or
        # with a custom --model-save-path), so make sure it exists before saving.
        save_dir = os.path.join(args.model_save_path, timeStr)
        os.makedirs(save_dir, exist_ok=True)
        torch.save(upper_policy.state_dict(), os.path.join(save_dir, 'upper-first-' + timeStr + ".pt"))

# Parse PCT nodes from the state returned by the environment
def get_leaf_nodes(observation, internal_node_holder, leaf_node_holder):
    """Reshape ``observation`` to ``(batch, nodes, 9)`` and slice out the leaf nodes.

    Leaf nodes are the rows ``[internal_node_holder, internal_node_holder + leaf_node_holder)``
    along the node axis. Returns ``(reshaped_observation, leaf_nodes)``.
    """
    batch = observation.shape[0]
    nodes = observation.reshape((batch, -1, 9))
    first_leaf = internal_node_holder
    last_leaf = first_leaf + leaf_node_holder
    return nodes, nodes[:, first_leaf:last_leaf, :]

def get_leaf_nodes_with_factor(observation,  batch_size, internal_node_holder, leaf_node_holder):
    """Like ``get_leaf_nodes`` but with an explicit ``batch_size`` for the reshape.

    Despite the name, no scaling factor is applied. Returns
    ``(observation reshaped to (batch_size, nodes, 9), leaf_nodes)``.
    """
    leaf_end = internal_node_holder + leaf_node_holder
    nodes = observation.reshape((batch_size, -1, 9))
    leaves = nodes[:, internal_node_holder:leaf_end, :]
    return nodes, leaves

'''
Parsing the raw state returned by the environment:

internal_nodes    : A packed-item vector, [x1, y1, z1, x2, y2, z2, density(optional)]
                    x1, y1, z1 are the coordinates of a packed item;
                    x2 = x1 + x, y2 = y1 + y, z2 = z1 + z,
                    where x, y, z are the sizes of the packed item (slightly different from the original paper;
                    both descriptions perform similarly).
leaf_nodes        : A placement vector, [x1, y1, z1, x2, y2, z2]
                    x1, y1, z1 are the coordinates of a placement;
                    x2 = x1 + x, y2 = y1 + y, z2 = z1 + z,
                    where x, y, z are the sizes of the current item after an axis-aligned orientation (slightly different
                    from the original paper; both descriptions perform similarly).
next_item         : The next item to be packed, [density(optional), 0, 0, x, y, z]
                    x, y, z are the sizes of the current item.
invalid_leaf_nodes: The mask indicating whether each placement is feasible.
full_mask         : The mask indicating whether each node should be encoded by the GAT.
'''
def observation_decode_leaf_node_actor(observation, internal_node_holder, internal_node_length, leaf_node_holder,num_seen_box=None):
    """Split a ``(batch, nodes, features)`` observation into its PCT components.

    Node layout along axis 1: ``internal_node_holder`` internal nodes, then
    ``leaf_node_holder`` leaf nodes, then the current box row(s).

    Returns ``(internal_nodes, leaf_nodes, current_box, valid_flag, full_mask)``:
    internal nodes keep their first ``internal_node_length`` features, leaf
    nodes their first 8, box rows their first 6. ``valid_flag`` is feature 8
    of each leaf row, and ``full_mask`` is the last feature of every row considered.
    When ``num_seen_box`` is None, all rows after the leaves count as box rows.
    """
    leaf_start = internal_node_holder
    leaf_end = leaf_start + leaf_node_holder
    # A slice bound of None means "through the last row".
    box_end = None if num_seen_box is None else leaf_end + num_seen_box

    internal_nodes = observation[:, :leaf_start, :internal_node_length]
    leaf_nodes = observation[:, leaf_start:leaf_end, :8]
    valid_flag = observation[:, leaf_start:leaf_end, 8]
    current_box = observation[:, leaf_end:box_end, :6]
    full_mask = observation[:, :box_end, -1]
    return internal_nodes, leaf_nodes, current_box, valid_flag, full_mask


def observation_decode_leaf_node_search(observation, internal_node_holder, internal_node_length, leaf_node_holder):
    """Split an observation into PCT components, assuming exactly one box row.

    Returns ``(internal_nodes, leaf_nodes, current_box, valid_flag, full_mask)``
    with the same slicing as the actor decoder; ``current_box`` is the single
    row right after the leaves, and ``full_mask`` stops after that row.
    """
    leaf_lo = internal_node_holder
    leaf_hi = leaf_lo + leaf_node_holder
    box_hi = leaf_hi + 1

    return (
        observation[:, :leaf_lo, :internal_node_length],
        observation[:, leaf_lo:leaf_hi, :8],
        observation[:, leaf_hi:box_hi, :6],
        observation[:, leaf_lo:leaf_hi, 8],
        observation[:, :box_hi, -1],
    )


def process_search_space(observation, internal_node_holder, leaf_node_holder, indices, sol_action = None):
    """Keep only the selected leaf nodes of each observation in the batch.

    Args:
        observation: ``(batch, nodes, features)`` tensor laid out as internal
            nodes, then leaf nodes, then the remaining (current box) rows.
        internal_node_holder: number of internal-node rows.
        leaf_node_holder: number of leaf-node rows.
        indices: ``(batch, search_num)`` integer tensor of leaf positions to keep
            (relative to the first leaf row). It may be non-contiguous.
        sol_action: optional ``(batch, features)`` tensor appended as one extra
            leaf per batch entry.

    Returns:
        ``(new_observation, processed_leaf_nodes)``. ``processed_leaf_nodes`` has
        shape ``(batch, search_num[+1], features)``. ``new_observation`` is the
        internal nodes, the processed leaves and the trailing rows, concatenated
        along the node axis.
    """
    leaf_end = internal_node_holder + leaf_node_holder
    internal_nodes = observation[:, :internal_node_holder, :]
    leaf_nodes = observation[:, internal_node_holder:leaf_end, :]
    current_box = observation[:, leaf_end:, :]

    bsz = indices.size(0)
    # A (bsz, 1) batch index broadcasts against (bsz, search_num) indices, so the
    # gather yields (bsz, search_num, features) directly. This avoids building the
    # repeated index by sorting, and avoids .view() failing on non-contiguous indices.
    batch_idx = torch.arange(bsz, device=observation.device).unsqueeze(1)
    processed_leaf_nodes = leaf_nodes[batch_idx, indices]

    if sol_action is not None:
        processed_leaf_nodes = torch.cat((processed_leaf_nodes, sol_action.unsqueeze(dim=1)), dim=1)
    return torch.cat((internal_nodes, processed_leaf_nodes, current_box), dim=1), processed_leaf_nodes

def construct_training_dataset(dataset, num_steps, tag, seq_len_set):
    """Merge the per-sequence-length tensors of ``dataset`` into one batch.

    ``dataset`` maps each entry of ``seq_len_set`` to a tensor. For
    ``tag == 'dist'`` the tensors are used whole; otherwise each is truncated
    to its first ``num_steps`` entries along axis 2. The pieces are stacked and
    flattened along the first two axes, so they must share a shape.
    """
    if tag == 'dist':
        pieces = [dataset[seq_len] for seq_len in seq_len_set]
    else:
        pieces = [dataset[seq_len][:, :, :num_steps] for seq_len in seq_len_set]

    per_item_shape = pieces[0].size()[1:]
    return torch.stack(pieces, dim=0).view(-1, *per_item_shape)

def construct_training_set_for_current_epoch(dataset, train_steps):
    """Draw up to ``train_steps`` random rows (without replacement) from ``dataset``.

    The same random row indices are applied to ``'instances'``,
    ``'distributions'`` and ``'sols'``, which are the only keys returned.
    """
    total = dataset['instances'].size(0)
    picked = torch.randperm(total)[:train_steps]
    return {key: dataset[key][picked] for key in ('instances', 'distributions', 'sols')}


def load_policy(load_path, upper_policy):
    """Load a pretrained checkpoint into ``upper_policy`` (strictly) and return it.

    The checkpoint may be a bare state dict or a ``(state_dict, ob_rms)``
    pair; ``ob_rms`` is discarded. Keys are normalized before loading:
    - For keys containing ``actor.embedder.layers``, ``module.weight`` becomes ``weight``.
    - For all other keys, every ``module.`` is removed.
    - In all keys, ``add_bias.`` is dropped and ``_bias`` is renamed to ``bias``.
    Tensors of rank <= 3 are squeezed on their last dimension (only affects a
    trailing size-1 axis).

    Raises:
        FileNotFoundError: if ``load_path`` does not exist.
    """
    if not os.path.exists(load_path):
        raise FileNotFoundError('File does not exist: {}'.format(load_path))
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(load_path, map_location='cpu')
    # Some checkpoints are saved as (state_dict, ob_rms). Test the container
    # type rather than len() == 2, which also matches a state dict that
    # happens to contain exactly two tensors.
    if isinstance(checkpoint, (tuple, list)):
        checkpoint = checkpoint[0]

    load_dict = {}
    for k, v in checkpoint.items():
        if 'actor.embedder.layers' in k:
            k = k.replace('module.weight', 'weight')
        else:
            k = k.replace('module.', '')
        k = k.replace('add_bias.', '').replace('_bias', 'bias')
        if v.dim() <= 3:
            # AddBias stores its bias as an (N, 1) column; the target expects (N,).
            v = v.squeeze(dim=-1)
        load_dict[k] = v

    upper_policy.load_state_dict(load_dict, strict=True)
    print('Loading pre-train upper model', load_path)
    return upper_policy

def get_args():
    """Parse the PCT command line (``sys.argv``) and add derived settings.

    Besides the CLI options, the returned namespace carries:
    - ``container_size`` and ``item_size_set``, taken from ``givenData``.
    - ``id``: ``'PctContinuous-v1'`` or ``'PctDiscrete-v1'``.
    - ``internal_node_length``: 7 for setting 3, 6 otherwise.
    - ``normFactor``: 1 / largest container dimension.
    - Default sampling bounds when ``--sample-from-distribution`` is given
      without them.
    ``num_processes`` is forced to 1 under ``--evaluate``.
    """

    def _str2bool(value):
        # argparse's ``type=bool`` turns any non-empty string (including
        # "False") into True, so parse the usual spellings explicitly.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError('Boolean value expected, got {!r}'.format(value))

    parser = argparse.ArgumentParser(description='PCT arguments')
    parser.add_argument('--setting', type=int, default=2, choices=(1, 2, 3), help='Experiment setting, please see our paper for details')
    parser.add_argument('--lnes', type=str, default='EMS', help='Leaf Node Expansion Schemes: EMS (recommend), EV, EP, CP, FC')
    parser.add_argument('--internal-node-holder', type=int, default=80, help='Maximum number of internal nodes')
    parser.add_argument('--leaf-node-holder', type=int, default=50, help='Maximum number of leaf nodes')
    parser.add_argument('--shuffle', type=_str2bool, default=True, help='Randomly shuffle the leaf nodes')
    parser.add_argument('--continuous', action='store_true', help='Use continuous enviroment, otherwise the enviroment is discrete')

    parser.add_argument('--no-cuda', action='store_true', help='Forbidden cuda')
    parser.add_argument('--device', type=int, default=0, help='Which GPU will be called')
    parser.add_argument('--seed', type=int, default=4, help='Random seed')
    parser.add_argument('--num-seen-box', type=int, default=1, help='Number of boxes that could be seen')
    parser.add_argument('--eval-idx', type=int, default=1, help='index of evaluated model')

    # Registered exactly once: a second add_argument('--dataset-path', ...)
    # makes argparse raise "conflicting option string".
    parser.add_argument('--dataset-path', type=str, default='', help='Path to Your Dataset')
    parser.add_argument('--least-size', type=int, default=1, help='item size boundary')
    parser.add_argument('--cross-range', type=int, default=6, help='item size range')
    parser.add_argument('--use-acktr', type=_str2bool, default=True, help='Use acktr, otherwise A2C')
    parser.add_argument('--num-processes', type=int, default=64, help='The number of parallel processes used for training')
    parser.add_argument('--next-holder', type=int, default=10, help='The number of next objects that could be seen')
    parser.add_argument('--num-steps', type=int, default=10, help='The rollout length for n-step training')
    parser.add_argument('--learning-rate', type=float, default=1e-6, metavar='η', help='Learning rate, only works for A2C')
    parser.add_argument('--actor-loss-coef', type=float, default=1.0, help='The coefficient of actor loss')
    parser.add_argument('--critic-loss-coef', type=float, default=1.0, help='The coefficient of critic loss')
    parser.add_argument('--encoder-loss-coef', type=float, default=1.0, help='The coefficient of encoder loss')
    parser.add_argument('--bc-loss-coef', type=float, default=1.0, help='The coefficient of behavior cloning loss')
    parser.add_argument('--max-grad-norm', type=float, default=0.5, help='Max norm of gradients')
    parser.add_argument('--embedding-size', type=int, default=64, help='Dimension of input embedding')
    parser.add_argument('--hidden-size', type=int, default=128, help='Dimension of hidden layers')
    parser.add_argument('--gat-layer-num', type=int, default=1, help='The number GAT layers')
    parser.add_argument('--gamma', type=float, default=1.0, metavar='γ', help='Discount factor')

    parser.add_argument('--model-save-interval', type=int, default=200, help='How often to save the model')
    # argparse does not apply ``type`` to non-string defaults, so the default
    # must itself be an int (20e3 would stay a float).
    parser.add_argument('--model-update-interval', type=int, default=20000, help='How often to create a new model')
    parser.add_argument('--model-save-path', type=str, default='./logs/experiment', help='The path to save the trained model')
    parser.add_argument('--print-log-interval', type=int, default=10, help='How often to print training logs')

    parser.add_argument('--evaluate', action='store_true', default=False, help='Evaluate only')
    parser.add_argument('--output-mask', action='store_true', default=False, help='Train only')
    parser.add_argument('--is-evaluate', action='store_true', default=False, help='Evaluate only')
    parser.add_argument('--evaluation-episodes', type=int, default=100, metavar='N', help='Number of episodes evaluated')
    parser.add_argument('--load-model', action='store_true', help='Load the trained model')
    parser.add_argument('--model-path', type=str, help='The path to load model')
    parser.add_argument('--sub-time-str', type=str, help='which model to load')
    parser.add_argument('--load-dataset', action='store_true', help='Load an existing dataset, otherwise the data is generated on the fly')
    parser.add_argument('--keep-prev', action='store_true', help='keep previous information')
    parser.add_argument('--search-num', type=int, default=5, help='size of searched space')

    parser.add_argument('--sample-from-distribution', action='store_true', help='Sample continuous item size from a uniform distribution U(a,b), otherwise sample items from \'item_size_set\' in \'givenData.py\'')
    parser.add_argument('--sample-left-bound', type=float, metavar='a', help='The parametre a of distribution U(a,b)')
    parser.add_argument('--sample-right-bound', type=float, metavar='b', help='The parametre b of distribution U(a,b)')

    parser.add_argument('--use-gae', action='store_true', default=False,
                        help='choose whether to use GAE for advantage approximation or not')
    parser.add_argument('--use-teacher', action='store_true', default=False,
                        help='choose whether to use teacher or not')
    parser.add_argument('--use-meta', action='store_true', default=False,
                        help='choose whether to use meta training')
    # NOTE(review): this help text duplicates --use-gae; confirm the intended meaning.
    parser.add_argument('--learnable-init', action='store_true', default=False,
                        help='choose whether to use GAE for advantage approximation or not')
    parser.add_argument('--ema-use', action='store_true', default=False,
                        help='use ema to update teacher model or not')
    parser.add_argument('--copy-use', action='store_true', default=False,
                        help='copy parameters to update teacher model or not')
    parser.add_argument('--allow-dist-input', action='store_true', default=False,
                        help='allow input distribution.')
    parser.add_argument('--use-dataset', action='store_true', default=False,
                        help='use dataset to train.')
    parser.add_argument('--gae-lambda', type=float, default=0.95,
                        help='gae lambda parameter (default: 0.95)')
    parser.add_argument('--use-proper-time-limits', action='store_true', default=False,
                        help='compute returns taking into account time limits')
    parser.add_argument('--ppo-epoch', type=int, default=1,
                        help='number of ppo epochs (default: 1)')
    parser.add_argument('--num-mini-batch', type=int, default=32,
                        help='number of batches for ppo (default: 32)')
    parser.add_argument('--value-loss-coef', type=float, default=1.,
                        help='The coefficient of value loss of PPO')
    parser.add_argument('--entropy-coef', type=float, default=0.01,
                        help='The coefficient of entropy of PPO')
    parser.add_argument('--clip-param', type=float, default=0.1,
                        help='ppo clip parameter (default: 0.1)')
    parser.add_argument('--max-model-num', type=int, default=20)

    args = parser.parse_args()

    if args.no_cuda:
        args.device = 'cpu'

    args.container_size = givenData.container_size
    args.item_size_set = givenData.item_size_set

    if args.sample_from_distribution:
        if args.sample_left_bound is None:
            args.sample_left_bound = 0.1 * min(args.container_size)
        if args.sample_right_bound is None:
            args.sample_right_bound = 0.5 * min(args.container_size)

    # The v0 ids ('PctContinuous-v0' / 'PctDiscrete-v0') are no longer selected.
    args.id = 'PctContinuous-v1' if args.continuous else 'PctDiscrete-v1'

    # Setting 3 uses 7 features per internal node (presumably the optional
    # density channel); settings 1 and 2 use 6.
    args.internal_node_length = 7 if args.setting == 3 else 6
    if args.evaluate:
        args.num_processes = 1
    args.normFactor = 1.0 / np.max(args.container_size)

    return args

def get_args_heuristic():
    """Parse the command line (``sys.argv``) for the heuristic baselines.

    The namespace always has ``evaluate=True`` and ``num_processes=1``, plus
    ``container_size``/``item_size_set`` from ``givenData``.
    ``internal_node_length`` is 7 for setting 3 and 6 otherwise, and
    ``normFactor`` is 1 / largest container dimension.

    Exits via ``parser.error`` (SystemExit) if a heuristic other than LSAH,
    OnlineBPH or BR is requested for the continuous environment.
    """
    parser = argparse.ArgumentParser(description='Heuristic baseline arguments')

    parser.add_argument('--continuous', action='store_true', help='Use continuous enviroment, otherwise the enviroment is discrete')
    parser.add_argument('--setting', type=int, default=2, choices=(1, 2, 3), help='Experiment setting, please see our paper for details')
    parser.add_argument('--evaluation-episodes', type=int, default=10, metavar='N', help='Number of episodes evaluated')
    parser.add_argument('--load-dataset', action='store_true', help='Load an existing dataset, otherwise the data is generated on the fly')
    parser.add_argument('--dataset-path', type=str, help='The path to load dataset')
    parser.add_argument('--heuristic', type=str, default='LSAH', help='Options: LSAH DBL MACS OnlineBPH HM BR RANDOM')

    args = parser.parse_args()
    args.container_size = givenData.container_size
    args.item_size_set = givenData.item_size_set
    args.evaluate = True

    # Validate user input with parser.error rather than assert, which is
    # stripped under ``python -O``.
    if args.continuous and args.heuristic not in ('LSAH', 'OnlineBPH', 'BR'):
        parser.error('only LSAH, OnlineBPH, and BR allowed for continuous environment')

    args.internal_node_length = 7 if args.setting == 3 else 6
    # Heuristic runs are always evaluation runs (see args.evaluate above).
    args.num_processes = 1
    args.normFactor = 1.0 / np.max(args.container_size)

    return args

def registration_envs():
    """Register the four PCT gym environments (discrete/continuous, v0/v1).

    Ids follow gym's ``Name-vN`` convention. Entry points are
    ``module:ClassName`` strings resolved by gym when the env is created. Both
    versions of each family live in the same ``*0`` module.
    """
    env_specs = (
        ('PctDiscrete-v0', 'pct_envs.PctDiscrete0:PackingDiscrete'),
        ('PctDiscrete-v1', 'pct_envs.PctDiscrete0:PackingDiscreteV1'),
        ('PctContinuous-v0', 'pct_envs.PctContinuous0:PackingContinuous'),
        ('PctContinuous-v1', 'pct_envs.PctContinuous0:PackingContinuousV1'),
    )
    for env_id, entry in env_specs:
        register(id=env_id, entry_point=entry)


def load_dataset(save_path, if_sol = False):
    """Load every file in ``save_path``, ordered by the integer in its name.

    The integer is the 4th ``'-'``-separated field (before any ``'.'``) when
    ``if_sol`` is true, e.g. ``MCTS-sol-instance-3.pt``; otherwise it is the
    2nd field. With ``if_sol`` the loaded tensors are stacked and returned.
    Otherwise each file must hold a dict with ``'dist'`` and ``'ins'``, and the
    stacked ``(distributions, instances)`` pair is returned.

    NOTE: ``torch.load`` unpickles the files -- only use on trusted data.
    """
    field = 3 if if_sol else 1

    def file_index(file_name):
        return int(file_name.split('-')[field].split('.')[0])

    ordered_names = sorted(os.listdir(save_path), key=file_index)

    if if_sol:
        return torch.stack(
            [torch.load(os.path.join(save_path, name)) for name in ordered_names],
            dim=0,
        )

    dist_list, ins_list = [], []
    for name in ordered_names:
        record = torch.load(os.path.join(save_path, name))
        dist_list.append(record['dist'])
        ins_list.append(record['ins'])
    return torch.stack(dist_list, dim=0), torch.stack(ins_list, dim=0)
    
def load_multiple_datasets(save_path, seq_len_set, sol_tag = None):
    """Load one dataset sub-directory per sequence length under ``save_path``.

    Sub-directories are named ``num-box-<len>``, or ``<sol_tag>-num-box-<len>``
    when ``sol_tag`` is given.

    Returns:
        With ``sol_tag``: a dict mapping each sequence length to its stacked solutions.
        Otherwise: ``(distributions, instances, box_set)``, where the first two
        are dicts keyed by sequence length and ``box_set`` is loaded from
        ``<save_path>/box-set.pt``.
    """
    if_sol = sol_tag is not None
    prefix = '{}-num-box'.format(sol_tag) if if_sol else 'num-box'

    if if_sol:
        sols = {}
        for seq_len in seq_len_set:
            sub_dir = os.path.join(save_path, '{}-{}'.format(prefix, seq_len))
            sols[seq_len] = load_dataset(sub_dir, if_sol)
        return sols

    distributions, instances = {}, {}
    for seq_len in seq_len_set:
        sub_dir = os.path.join(save_path, '{}-{}'.format(prefix, seq_len))
        distributions[seq_len], instances[seq_len] = load_dataset(sub_dir, if_sol)
    return distributions, instances, torch.load(os.path.join(save_path, 'box-set.pt'))

def save_sol(save_path, sol, seq_len, idx, tag = 'MCTS-sol'):
    """Save ``sol`` to ``<save_path>/<tag>-num-box-<seq_len>/<tag>-instance-<idx>.pt``.

    The per-length directory is created when it does not exist yet.
    """
    target_dir = os.path.join(save_path, '{}-num-box-{}'.format(tag, seq_len))
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    file_name = '{}-instance-{}.pt'.format(tag, idx)
    torch.save(sol, os.path.join(target_dir, file_name))


def load_model_mcts(policy,model_save_path, tag, sub_time_str):
    """Strictly load ``<model_save_path>/<tag>-<sub_time_str>.pt`` into ``policy``.

    Applies the same key normalization as ``load_policy``:
    - For keys containing ``actor.embedder.layers``, ``module.weight`` becomes ``weight``.
    - For all other keys, every ``module.`` is removed.
    - In all keys, ``add_bias.`` is dropped and ``_bias`` is renamed to ``bias``.
    - Tensors of rank <= 3 are squeezed on their last dimension.
    A ``(state_dict, ob_rms)`` checkpoint is also accepted; ``ob_rms`` is ignored.

    Returns:
        ``policy`` (returned for convenience; it is modified in place).
    """
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
    # load_state_dict then copies the values onto the policy's own devices.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(os.path.join(model_save_path, '{}-{}.pt'.format(tag, sub_time_str)), map_location='cpu')
    if isinstance(checkpoint, (tuple, list)):
        checkpoint = checkpoint[0]

    load_dict = {}
    for k, v in checkpoint.items():
        if 'actor.embedder.layers' in k:
            k = k.replace('module.weight', 'weight')
        else:
            k = k.replace('module.', '')
        k = k.replace('add_bias.', '').replace('_bias', 'bias')
        if v.dim() <= 3:
            # AddBias stores its bias as an (N, 1) column; the target expects (N,).
            v = v.squeeze(dim=-1)
        load_dict[k] = v

    policy.load_state_dict(load_dict, strict=True)
    return policy