# External imports
import json
import logging
import os
import random
import time
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
import numpy as np

from pathlib import Path
from tqdm import tqdm

# Project imports
from nps.data import load_input_file, get_minibatch, shuffle_dataset, KarelDataset
from nps.network import IOs2Seq
from nps.reinforce import EnvironmentClasses, RewardCombinationFun
from nps.training_functions import (do_supervised_minibatch,
                                    do_syntax_weighted_minibatch,
                                    do_rl_minibatch,
                                    do_rl_minibatch_two_steps,
                                    do_beam_rl)
from nps.evaluate import evaluate_model
from syntax.checker import PySyntaxChecker
from karel.consistency import Simulator

import torch.distributed as dist
from torch.utils.data.dataloader import DataLoader
from torch.nn.utils.rnn import pad_sequence

from torch.utils.tensorboard import SummaryWriter
import nps.params as params

class TrainSignal(object):
    """String identifiers for the supported sources of training gradients."""
    SUPERVISED = "supervised"
    RL = "rl"
    BEAM_RL = "beam_rl"


# Accepted values for the ``--signal`` CLI option. Derived from TrainSignal so
# the two definitions cannot drift apart; the first entry is the CLI default.
signals = [TrainSignal.SUPERVISED, TrainSignal.RL, TrainSignal.BEAM_RL]

def add_train_cli_args(parser):
    """Register the training command-line options on ``parser``.

    Two argument groups are added: "Training" (optimisation, data and output
    options) and "RL-specific training options".

    Args:
        parser: An ``argparse.ArgumentParser`` (or compatible object exposing
            ``add_argument_group``) that is extended in place.
    """
    train_group = parser.add_argument_group("Training",
                                            description="Training options")
    train_group.add_argument('--signal', type=str,
                             choices=signals,
                             default=signals[0],
                             help="Where to get gradients from. "
                             "Default: %(default)s")
    train_group.add_argument('--nb_ios', type=int,
                             default=5)
    train_group.add_argument('--nb_epochs', type=int,
                             default=2,
                             help="How many epochs to train the model for. "
                             "Default: %(default)s")
    train_group.add_argument('--optim_alg', type=str,
                             default='Adam',
                             choices=['Adam', 'RMSprop', 'SGD'],
                             help="What optimization algorithm to use. "
                             "Default: %(default)s")
    train_group.add_argument('--batch_size', type=int,
                             default=32,
                             help="Batch Size for the optimization. "
                             "Default: %(default)s")
    train_group.add_argument('--learning_rate', type=float,
                             default=1e-3,
                             help="Learning rate for the optimization. "
                             "Default: %(default)s")
    train_group.add_argument("--train_file", type=str,
                             default="data/1m_6ex_karel/train.json",
                             help="Path to the training data. "
                             " Default: %(default)s")
    train_group.add_argument("--val_file", type=str,
                             default="data/1m_6ex_karel/val.json",
                             help="Path to the validation data. "
                             " Default: %(default)s")
    train_group.add_argument("--vocab", type=str,
                             default="data/1m_6ex_karel/new_vocab.vocab",
                             help="Path to the output vocabulary."
                             " Default: %(default)s")
    train_group.add_argument("--nb_samples", type=int,
                             default=0,
                             help="Max number of samples to look at. "
                             "If 0, look at the whole dataset.")
    train_group.add_argument("--result_folder", type=str,
                             default="exps/fake_run",
                             help="Where to store the results. "
                             " Default: %(default)s")
    train_group.add_argument("--init_weights", type=str,
                             default=None)
    train_group.add_argument("--use_grammar", action="store_true")
    train_group.add_argument('--beta', type=float,
                             default=1e-3,
                             help="Gain applied to syntax loss. "
                             "Default: %(default)s")
    train_group.add_argument("--val_frequency", type=int,
                             default=1,
                             help="Frequency (in epochs) of validation.")

    rl_group = parser.add_argument_group("RL-specific training options")
    rl_group.add_argument("--environment", type=str,
                          choices=EnvironmentClasses.keys(),
                          default="BlackBoxGeneralization",
                          help="What type of environment to get a reward from. "
                          "Default: %(default)s.")
    rl_group.add_argument("--reward_comb", type=str,
                          choices=RewardCombinationFun.keys(),
                          default="RenormExpected",
                          help="How to aggregate the reward over several samples.")
    rl_group.add_argument('--nb_rollouts', type=int,
                          default=100,
                          help="When using RL, "
                          "how many trajectories to sample per example. "
                          "Default: %(default)s")
    rl_group.add_argument('--rl_beam', type=int,
                          default=50,
                          help="Size of the beam when doing reward"
                          " maximization over the beam. "
                          "Default: %(default)s")
    rl_group.add_argument('--rl_inner_batch', type=int,
                          default=2,
                          help="Size of the batch on expanded candidates")
    rl_group.add_argument('--rl_use_ref', action="store_true")

def save_checkpoint(model, optimizer, epoch, best_val_acc, ckpt_path):
    """Write a training checkpoint to ``ckpt_path``.

    The checkpoint is a dict with the keys ``'model'``, ``'optimizer'``,
    ``'epoch'`` and ``'best_val_acc'``. When ``model`` exposes a ``module``
    attribute (as parallel wrappers do), the wrapped module's state dict is
    stored instead of the wrapper's.

    Only the process of rank 0 writes the file; if ``torch.distributed`` is
    unavailable or not initialised, the calling process is treated as rank 0.

    Args:
        model: The network, possibly wrapped in an object exposing ``module``.
        optimizer: The optimizer whose state is stored.
        epoch: Index of the epoch being checkpointed.
        best_val_acc: Best validation accuracy seen so far.
        ckpt_path: Destination passed to ``torch.save``.
    """
    unwrapped = model.module if hasattr(model, "module") else model

    # State dicts are gathered on every rank, even though only one rank saves.
    state = {
        'model': unwrapped.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'best_val_acc': best_val_acc,
    }

    if dist.is_available() and dist.is_initialized():
        current_rank = dist.get_rank()
    else:
        current_rank = 0

    if current_rank == 0:
        torch.save(state, ckpt_path)

def load_checkpoint(model, optimizer, ckpt_path):
    """Restore model and optimizer state from a checkpoint file.

    The file is expected to hold a dict with the keys ``'model'``,
    ``'optimizer'``, ``'epoch'`` and ``'best_val_acc'``.

    Args:
        model: The network to restore. If it exposes a ``module`` attribute,
            the weights are loaded into that wrapped module.
        optimizer: The optimizer to restore.
        ckpt_path: Path of the checkpoint file.

    Returns:
        A tuple ``(model, optimizer, start_epoch, best_val_acc)`` where
        ``model`` and ``optimizer`` are the objects passed in (updated in
        place) and ``start_epoch`` is the stored epoch plus one.
    """
    # Deserialize onto the CPU: without ``map_location`` the tensors are
    # restored onto the device they were saved from, so every process would
    # allocate memory on the saving GPU (and CPU-only hosts could not load the
    # file at all). ``load_state_dict`` copies the values onto whatever
    # devices the live parameters are on.
    weight_ckpt = torch.load(ckpt_path, map_location="cpu")

    raw_model = model.module if hasattr(model, "module") else model
    raw_model.load_state_dict(weight_ckpt['model'])
    optimizer.load_state_dict(weight_ckpt['optimizer'])

    # The stored epoch was completed, so training resumes with the next one.
    start_epoch = weight_ckpt['epoch'] + 1
    best_val_acc = weight_ckpt['best_val_acc']
    return model, optimizer, start_epoch, best_val_acc

def init_distributed_mode():
    """Initialise ``torch.distributed`` from the process environment.

    Two launch styles are recognised:

    * ``RANK`` / ``WORLD_SIZE`` / ``LOCAL_RANK`` environment variables;
    * SLURM, detected through ``SLURM_PROCID``. The world size is then taken
      from ``WORLD_SIZE`` if set, otherwise from ``SLURM_NTASKS``.

    When neither is present a message is printed and nothing is initialised.

    Returns:
        The index of the CUDA device selected for this process, or ``None``
        when distributed mode is not used.

    Raises:
        KeyError: If a required environment variable (``LOCAL_RANK``, or
            ``SLURM_NTASKS`` in the SLURM case) is missing.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ['WORLD_SIZE'])
        gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        rank = int(os.environ['SLURM_PROCID'])
        # world_size must be bound on this path as well, otherwise the
        # init_process_group call below fails with a NameError.
        world_size = int(os.environ.get('WORLD_SIZE')
                         or os.environ['SLURM_NTASKS'])
        gpu = rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        return None

    torch.cuda.set_device(gpu)
    dist_backend = 'nccl'
    dist_url = 'env://'
    print('| distributed init (rank {}): {}'.format(
        rank, dist_url), flush=True)
    torch.distributed.init_process_group(backend=dist_backend, init_method=dist_url,
                                         world_size=world_size, rank=rank)
    # Silence print() on every process except rank 0.
    setup_for_distributed(rank == 0)
    return gpu
    
def setup_for_distributed(is_master):
    """Make the builtin ``print`` silent on non-master processes.

    The builtin ``print`` is replaced, process-wide, by a wrapper that only
    forwards to the previous ``print`` when ``is_master`` is true or when the
    call passes ``force=True``. The ``force`` keyword is consumed by the
    wrapper and never reaches the underlying ``print``.

    Args:
        is_master: Whether the current process is allowed to print.
    """
    import builtins

    original_print = builtins.print

    def print(*args, **kwargs):
        forced = kwargs.pop('force', False)
        if not (is_master or forced):
            return
        original_print(*args, **kwargs)

    builtins.print = print


def train_seq2seq_model(
        # Optimization
        signal, nb_ios, nb_epochs, optim_alg,
        batch_size, learning_rate, use_grammar, beta, val_frequency,
        # Model
        kernel_size, conv_stack, fc_stack,
        tgt_embedding_size, lstm_hidden_size, nb_lstm_layers,
        learn_syntax,
        # RL specific options
        environment, reward_comb, nb_rollouts,
        rl_beam, rl_inner_batch, rl_use_ref,
        # What to train
        train_file, val_file, vocab_file, nb_samples, initialisation,
        # Where to write results
        result_folder, args_dict,
        # Run options
        use_cuda, log_frequency):
    """Train an ``IOs2Seq`` model, checkpointing and validating along the way.

    All outputs (``args.json``, ``logs.txt``, ``train_loss.json``, TensorBoard
    events, ``Weights/`` and ``eval/``) are written to the directory that
    contains ``train_file``. If ``Weights/latest.model`` exists there,
    training resumes from it.

    The function runs under ``torch.distributed`` when
    ``init_distributed_mode`` initialises it, and as a single process
    otherwise. File outputs other than the log file are written by rank 0.

    Args:
        signal: One of the ``TrainSignal`` values; selects the supervised, RL
            or beam-RL minibatch function.
        nb_ios: Forwarded to ``KarelDataset``.
        nb_epochs: Training stops once the epoch index reaches this value.
        optim_alg: Name of an optimizer class in ``torch.optim``.
        batch_size: Minibatch size of the training loader; also forwarded to
            ``shuffle_dataset`` and ``evaluate_model``.
        learning_rate: Learning rate handed to the optimizer.
        use_grammar: If true, a ``PySyntaxChecker`` is built and attached to
            the model; also forwarded to ``evaluate_model``.
        beta: Forwarded to ``do_syntax_weighted_minibatch``.
        val_frequency: Validation runs every ``val_frequency`` epochs and
            after the last epoch.
        kernel_size, conv_stack, fc_stack, tgt_embedding_size,
        lstm_hidden_size, nb_lstm_layers, learn_syntax: Forwarded to the
            ``IOs2Seq`` constructor. ``learn_syntax`` additionally selects the
            syntax-weighted supervised loss.
        environment: Key into ``EnvironmentClasses``. Names containing
            "Consistency" get the training worlds, names containing
            "Generalization" or "Perf" get the test worlds.
        reward_comb: Key into ``RewardCombinationFun`` (beam-RL only).
        nb_rollouts: Forwarded to ``do_rl_minibatch``; also sets the reward
            normalisation to ``1 / nb_rollouts`` for the RL signal.
        rl_beam, rl_inner_batch, rl_use_ref: Forwarded to ``do_beam_rl``.
        train_file: Path of the training data; its parent directory is the
            output directory.
        val_file: Path of the validation data, forwarded to
            ``evaluate_model``.
        vocab_file: Path of the vocabulary file.
        nb_samples: If positive, the dataset is shuffled with a fixed seed and
            truncated to this many samples.
        initialisation: Not used by this function.
        result_folder: Not used by this function.
        args_dict: JSON-serialisable dict dumped to ``args.json``.
        use_cuda: If true, minibatch tensors and the supervised loss criterion
            are moved to CUDA.
        log_frequency: Number of minibatches between two loss log entries.

    Raises:
        Exception: If ``signal`` is not a known training signal.
        NotImplementedError: If ``environment`` matches none of the known
            environment families.
    """
    gpu = init_distributed_mode()

    # init_distributed_mode() may fall back to single-process mode, in which
    # case the torch.distributed queries below would raise. Resolve rank and
    # world size once and make the barrier a no-op when not distributed.
    distributed = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if distributed else 0
    world_size = dist.get_world_size() if distributed else 1
    is_main = rank == 0

    def barrier():
        if distributed:
            dist.barrier()

    #############################
    # Admin / Bookkeeping stuff #
    #############################
    # Results are written next to the training file.
    save_dir = Path(train_file).parent
    print('saved to:', save_dir)

    tb_writer = None
    if is_main:
        if not save_dir.exists():
            os.makedirs(str(save_dir))
        tb_writer = SummaryWriter(save_dir)
    barrier()

    result_dir = save_dir

    # Dumping the arguments (identical on every rank, so one writer suffices)
    if is_main:
        args_dump_path = result_dir / "args.json"
        with open(str(args_dump_path), "w") as args_dump_file:
            json.dump(args_dict, args_dump_file, indent=2)
    # Setting up the logs
    # NOTE(review): every rank opens the same log file with filemode='w';
    # confirm this is intended before restricting it to rank 0.
    log_file = result_dir / "logs.txt"
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        filename=str(log_file),
        filemode='w'
    )
    train_loss_path = result_dir / "train_loss.json"
    models_dir = result_dir / "Weights"
    print(models_dir)
    if is_main and not models_dir.exists():
        os.makedirs(str(models_dir))
        time.sleep(1)  # Let some time for the dir to be created
    barrier()

    #####################################
    # Load Model / Dataset / Vocabulary #
    #####################################
    # Load-up the dataset
    print(train_file)
    dataset, vocab = load_input_file(train_file, vocab_file)

    if use_grammar:
        syntax_checker = PySyntaxChecker(vocab["tkn2idx"], use_cuda)
    # Reduce the number of samples in the dataset, if needed
    if nb_samples > 0:
        # Randomize the dataset to shuffle it, because I'm not sure that there
        # is no meaning in the ordering of the samples
        random.seed(0)
        dataset = shuffle_dataset(dataset, batch_size)
        dataset = {
            'sources': dataset['sources'][:nb_samples],
            'targets': dataset['targets'][:nb_samples],
        }

    vocabulary_size = len(vocab["tkn2idx"])
    # -np.inf rather than np.NINF: the alias was removed in NumPy 2.0.
    best_val_acc = -np.inf
    start_epoch = 0
    model = IOs2Seq(kernel_size, conv_stack, fc_stack,
                    vocabulary_size, tgt_embedding_size,
                    lstm_hidden_size, nb_lstm_layers,
                    learn_syntax)
    if use_grammar:
        model.set_syntax_checker(syntax_checker)
    if distributed:
        # Convert BatchNorm layers before wrapping the model in DDP.
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if torch.cuda.is_available():
        device = torch.cuda.current_device()
        print(device)
        model = model.to(device)
        print(gpu)
        if distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[gpu], find_unused_parameters=True)

    # Setup the optimizers
    optimizer_cls = getattr(optim, optim_alg)
    optimizer = optimizer_cls(model.parameters(),
                              lr=learning_rate)
    # Resume from the latest checkpoint if one exists
    latest_ckpt_path = models_dir / "latest.model"
    if latest_ckpt_path.exists():
        model, optimizer, start_epoch, best_val_acc = load_checkpoint(
            model, optimizer, str(latest_ckpt_path))
    print(start_epoch)

    tgt_start = vocab["tkn2idx"]["<s>"]
    tgt_end = vocab["tkn2idx"]["m)"]
    tgt_pad = vocab["tkn2idx"]["<pad>"]

    ############################################
    # Setup Loss / Optimizer / Eventual Critic #
    ############################################
    if signal == TrainSignal.SUPERVISED:
        # Create a mask to not penalize bad prediction on the padding
        weight_mask = torch.ones(vocabulary_size)
        weight_mask[tgt_pad] = 0
        # Setup the criterion
        loss_criterion = nn.CrossEntropyLoss(weight=weight_mask)
        if use_cuda:
            loss_criterion.cuda()
    elif signal == TrainSignal.RL or signal == TrainSignal.BEAM_RL:
        simulator = Simulator(vocab["idx2tkn"])

        if signal == TrainSignal.BEAM_RL:
            reward_comb_fun = RewardCombinationFun[reward_comb]
    else:
        raise Exception("Unknown TrainingSignal.")

    #####################
    # ################# #
    # # Training Loop # #
    # ################# #
    #####################

    def my_collate(batch):
        """Assemble a list of dataset items into one minibatch."""
        inp_grids = torch.stack([item[0] for item in batch], 0)
        out_grids = torch.stack([item[1] for item in batch], 0)

        in_tgt_seq = pad_sequence([item[2] for item in batch],
                                  batch_first=True, padding_value=tgt_pad)
        # Drop the last column of the padded decoder input.
        in_tgt_seq = in_tgt_seq[:, :-1]
        out_tgt_seq = pad_sequence([item[4] for item in batch],
                                   batch_first=True, padding_value=tgt_pad)
        # The list form of the decoder input is rebuilt from the padded
        # tensor, so item[3] is not read.
        input_lines = in_tgt_seq.tolist()

        inp_worlds = [item[5] for item in batch]
        out_worlds = [item[6] for item in batch]
        targets = [item[7] for item in batch]
        inp_test_worlds = [item[8] for item in batch]
        out_test_worlds = [item[9] for item in batch]

        return [inp_grids, out_grids, in_tgt_seq, input_lines, out_tgt_seq,
                inp_worlds, out_worlds, targets, inp_test_worlds, out_test_worlds]

    losses = []
    recent_losses = []
    train_dataset = KarelDataset(dataset, tgt_start, tgt_end, tgt_pad, nb_ios, shuffle=False)
    # num_replicas/rank are passed explicitly so the sampler also works when
    # torch.distributed is not initialised (1 replica, rank 0).
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=world_size, rank=rank)
    train_loader = DataLoader(train_dataset, shuffle=False, pin_memory=False,
                              batch_size=batch_size, sampler=train_sampler,
                              num_workers=0, collate_fn=my_collate)
    nb_batches = len(train_loader)
    for epoch_idx in range(start_epoch, nb_epochs):
        # Reseed the sampler so each epoch sees a different ordering
        train_sampler.set_epoch(epoch_idx)

        pbar = tqdm(enumerate(train_loader), total=nb_batches)
        for batch_idx, (inp_grids, out_grids,
                        in_tgt_seq, in_tgt_seq_list, out_tgt_seq,
                        inp_worlds, out_worlds,
                        targets,
                        inp_test_worlds, out_test_worlds) in pbar:

            optimizer.zero_grad()
            model.train()

            if use_cuda:
                inp_grids, out_grids = inp_grids.cuda(), out_grids.cuda()
                in_tgt_seq, out_tgt_seq = in_tgt_seq.cuda(), out_tgt_seq.cuda()

            if signal == TrainSignal.SUPERVISED:
                if learn_syntax:
                    minibatch_loss = do_syntax_weighted_minibatch(model,
                                                                  inp_grids, out_grids,
                                                                  in_tgt_seq, in_tgt_seq_list,
                                                                  out_tgt_seq,
                                                                  loss_criterion, beta)
                else:
                    minibatch_loss = do_supervised_minibatch(model,
                                                             inp_grids, out_grids,
                                                             in_tgt_seq, in_tgt_seq_list,
                                                             out_tgt_seq, loss_criterion)
                recent_losses.append(minibatch_loss)
            elif signal == TrainSignal.RL or signal == TrainSignal.BEAM_RL:
                # We use 1/nb_rollouts as the reward to normalize wrt the
                # size of the rollouts
                if signal == TrainSignal.RL:
                    reward_norm = 1 / float(nb_rollouts)
                elif signal == TrainSignal.BEAM_RL:
                    reward_norm = 1
                else:
                    raise NotImplementedError("Unknown training signal")

                lens = [len(target) for target in targets]
                max_len = max(lens) + 10
                env_cls = EnvironmentClasses[environment]
                if "Consistency" in environment:
                    envs = [env_cls(reward_norm, trg_prog, sp_inp_worlds, sp_out_worlds, simulator)
                            for trg_prog, sp_inp_worlds, sp_out_worlds
                            in zip(targets, inp_worlds, out_worlds)]
                # Both substrings must be tested against `environment`;
                # `"Generalization" or "Perf" in environment` is always truthy.
                elif "Generalization" in environment or "Perf" in environment:
                    envs = [env_cls(reward_norm, trg_prog, sp_inp_test_worlds, sp_out_test_worlds, simulator)
                            for trg_prog, sp_inp_test_worlds, sp_out_test_worlds
                            in zip(targets, inp_test_worlds, out_test_worlds)]
                else:
                    raise NotImplementedError("Unknown environment type")

                if signal == TrainSignal.RL:
                    minibatch_reward = do_rl_minibatch(model,
                                                       inp_grids, out_grids,
                                                       envs,
                                                       tgt_start, tgt_end, max_len,
                                                       nb_rollouts)
                elif signal == TrainSignal.BEAM_RL:
                    minibatch_reward = do_beam_rl(model,
                                                  inp_grids, out_grids, targets,
                                                  envs, reward_comb_fun,
                                                  tgt_start, tgt_end, tgt_pad,
                                                  max_len, rl_beam, rl_inner_batch, rl_use_ref)
                else:
                    raise NotImplementedError("Unknown Environment type")
                recent_losses.append(minibatch_reward)
            else:
                raise NotImplementedError("Unknown Training method")
            optimizer.step()

            # Log every `log_frequency` minibatches and on the last minibatch
            # of the epoch, so no losses are left unreported.
            is_last_batch = batch_idx == nb_batches - 1
            if recent_losses and \
               (batch_idx % log_frequency == log_frequency - 1 or is_last_batch):
                mean_recent_loss = sum(recent_losses) / len(recent_losses)
                logging.info('Epoch : %d Minibatch : %d Loss : %.5f' % (
                    epoch_idx, batch_idx, mean_recent_loss)
                )
                # Exact number of minibatches per epoch, so steps of
                # consecutive epochs never overlap.
                global_step = batch_idx + epoch_idx * nb_batches
                if is_main:
                    tb_writer.add_scalar("LGRL/loss", mean_recent_loss, global_step)
                barrier()
                losses.extend(recent_losses)
                recent_losses = []
                # Dump the training losses
                if is_main:
                    with open(str(train_loss_path), "w") as train_loss_file:
                        json.dump(losses, train_loss_file, indent=2)

                if signal == TrainSignal.BEAM_RL:
                    # RL is much slower so we dump more frequently
                    path_to_weight_dump = models_dir / ("weights_%d.model" % epoch_idx)
                    save_checkpoint(model, optimizer, epoch_idx, best_val_acc,
                                    str(path_to_weight_dump))

            pbar.set_description(f"epoch {epoch_idx}")

        # Dump the weights at the end of the epoch
        path_to_weight_dump = models_dir / ("weights_%d.model" % epoch_idx)
        save_checkpoint(model, optimizer, epoch_idx, best_val_acc,
                        str(path_to_weight_dump))
        print(epoch_idx)
        save_checkpoint(model, optimizer, epoch_idx, best_val_acc,
                        str(models_dir / "latest.model"))
        # Dump the training losses
        if is_main:
            with open(str(train_loss_path), "w") as train_loss_file:
                json.dump(losses, train_loss_file, indent=2)
        barrier()

        logging.info("Done with epoch %d." % epoch_idx)

        if (epoch_idx+1) % val_frequency == 0 or (epoch_idx+1) == nb_epochs:
            # Evaluate the model on the validation set
            out_path = str(result_dir / ("eval/epoch_%d/val_.txt" % epoch_idx))
            val_acc, exact_acc = evaluate_model(model, str(path_to_weight_dump), vocab_file,
                                                val_file, 5, 0, use_grammar,
                                                out_path, 100, 50, batch_size,
                                                use_cuda, False)
            val_acc = 0 if val_acc is None else val_acc
            exact_acc = 0 if exact_acc is None else exact_acc
            logging.info("Epoch : %d ValidationAccuracy : %f." % (epoch_idx, exact_acc))
            if is_main:
                tb_writer.add_scalar("LGRL/generalization_acc", val_acc, epoch_idx)
                tb_writer.add_scalar("LGRL/exact_acc", exact_acc, epoch_idx)
            barrier()
            if exact_acc > best_val_acc:
                # NOTE(review): the best model is selected on exact_acc but
                # this message reports val_acc; confirm which is intended.
                logging.info("Epoch : %d ValidationBest : %f." % (epoch_idx, val_acc))
                best_val_acc = exact_acc
                save_checkpoint(model, optimizer, epoch_idx, best_val_acc,
                                str(models_dir / "best.model"))
