from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import json
import os
import random
import multiprocessing
import time 
from pathlib import Path
from tqdm import tqdm
import numpy as np

import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.autograd import Variable
from torch import nn
import torch.nn.functional as F

import params
from model.combinar_MI import CombinarMI
from cuda import use_cuda, LongTensor, FloatTensor
from env.env import ProgramEnv
from env.operator import Operator, operator_to_index
from env.statement import Statement, statement_to_index
from dsl.program import Program
from dsl.example import Example

from utils_hd import *

from sklearn import preprocessing


# Seed the torch (CPU and all CUDA devices), numpy and stdlib ``random`` RNGs
# from ``params.seed``, in that order.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
    _seed_fn(params.seed)
del _seed_fn


def generate_prog_data(line):
    """Turn one JSON dataset line into per-step training samples.

    The program in ``line`` is replayed statement by statement in a
    ``ProgramEnv``. Before each statement the environment encoding is
    recorded together with that step's targets.

    Args:
        line: one line of the dataset file. It is parsed as JSON; the
            ``'program'`` key is parsed by ``Program.parse``, and the whole
            object is passed to ``Example.from_line``.

    Returns:
        A 10-tuple of lists. Each list has one entry per program statement:
        ``inputs`` (``env.get_encoding()`` before the step),
        ``statements`` (statement index, with args remapped to post-drop
        slots), ``drop`` (0/1 mask of length ``params.max_program_vars``;
        1 = droppable), ``operators`` (operator index), ``programs``
        (``str(program)``), ``steps`` (step number), ``input_nums`` (number
        of program inputs), ``types`` (one-hot type tokens repeated 5
        times), ``program_seq`` (token sequence of the whole program) and
        ``program_lengths`` (``len(statements) + 2``). An empty program
        yields ten empty lists.

    Raises:
        ValueError: if a type is not LIST, INT or NULL.
    """
    data = json.loads(line.rstrip())
    examples = Example.from_line(data)
    env = ProgramEnv(examples)
    program = Program.parse(data['program'])

    inputs = []
    statements = []
    drop = []
    operators = []
    programs = []
    steps = []
    input_nums = []
    types = []

    for i, statement in enumerate(program.statements):
        inputs.append(env.get_encoding())

        # Translate absolute indices to post-drop indices
        f, args = statement.function, list(statement.args)
        for j, arg in enumerate(args):
            if isinstance(arg, int):
                args[j] = env.real_var_idxs.index(arg)

        statement = Statement(f, args)
        statements.append(statement_to_index[statement])

        # Absolute variable indices referenced by this or any later statement.
        used_args = []
        for next_statement in program.statements[i:]:
            used_args += [x for x in next_statement.args if isinstance(x, int)]

        # 1 = slot may be dropped: it is empty, or its variable is never used again.
        to_drop = []
        for j in range(params.max_program_vars):
            if j >= env.num_vars or env.real_var_idxs[j] not in used_args:
                to_drop.append(1)
            else:
                to_drop.append(0)

        drop.append(to_drop)
        # Drawn on every step, even when unused below. Moving it into the
        # else-branch would change the random stream and thus the dataset.
        rand_idx = random.choice([j for j in range(len(to_drop)) if to_drop[j] > 0])

        operator = Operator.from_statement(statement)
        operators.append(operator_to_index[operator])

        if env.num_vars < params.max_program_vars:
            env.step(statement)
        else:
            # Choose a random var (that is not used anymore) to drop.
            env.step(statement, rand_idx)

        programs.append(str(program))
        steps.append(i)
        input_nums.append(len(program.input_types))
        # Input types padded to 3 slots with 'NULL', followed by the last
        # variable type (presumably the program output).
        typ = program.input_types + (3 - len(program.input_types)) * ['NULL'] + program.var_types[-1:]
        typ_token = []
        for item in typ:
            if str(item) == 'LIST':
                typ_token.append([0, 1, 0])
            elif str(item) == 'INT':
                typ_token.append([1, 0, 0])
            elif str(item) == 'NULL':
                typ_token.append([0, 0, 1])
            else:
                raise ValueError('bad type {}'.format(item))
        types.append([typ_token] * 5)  # 5 ios same types

    # Whole-program token sequence, identical for every step. Presumably
    # 1 = start, 2 = end, 0 = padding, and statement index k -> token k + 3.
    # NOTE(review): the padding assumes at most 4 statements; longer programs
    # give longer, unpadded rows. Confirm against params.
    # len(statements) is used instead of the loop variable so an empty
    # program yields zero samples rather than failing on an unbound name.
    num_steps = len(statements)
    seq = [1] + [idx + 3 for idx in statements] + [2] + [0] * (4 - num_steps)
    program_seq = [list(seq) for _ in range(num_steps)]
    program_lengths = [num_steps + 2] * num_steps
    return inputs, statements, drop, operators, programs, steps, input_nums, types, program_seq, program_lengths


def load_data(fileobj, max_len, max_lines=12800):
    """Read a dataset file and flatten every program into per-step samples.

    Args:
        fileobj: open text file with one JSON program per line.
        max_len: if not None, first draw this many lines at random (without
            replacement, via ``random.sample``). Values larger than the
            number of lines are clamped instead of raising ``ValueError``.
        max_lines: after sampling, only the last ``max_lines`` lines are
            converted; ``None`` keeps them all. The default of 12800
            matches the previously hard-coded limit.

    Returns:
        Ten numpy arrays. Each is the concatenation, over all converted
        lines, of the corresponding list returned by ``generate_prog_data``:
        inputs, statement targets, drop masks, operator targets, program
        strings, steps, input counts, types, program sequences and program
        lengths.
    """
    lines = fileobj.read().splitlines()
    if max_len is not None:
        # random.sample raises ValueError when k exceeds the population size.
        lines = random.sample(lines, min(max_len, len(lines)))

    if max_lines is not None:
        # Equivalent to lines[-max_lines:] for max_lines > 0, but correct for 0.
        lines = lines[max(len(lines) - max_lines, 0):]

    # One accumulator per field returned by generate_prog_data. Lines are
    # processed in order, so random draws happen in the same sequence.
    num_fields = 10
    columns = [[] for _ in range(num_fields)]
    for line in tqdm(lines, total=len(lines)):
        for column, values in zip(columns, generate_prog_data(line)):
            column.extend(values)

    return tuple(np.array(column) for column in columns)


def train():
    """Train ``CombinarMI`` with DistributedDataParallel and log to TensorBoard.

    All configuration is read from ``params``. For each epoch:

    * every training batch runs ``query_num`` query steps
      (``random.randint(2, 5)``). It accumulates the statement, drop and
      operator losses plus the KL terms, and takes one ``main_optimizer``
      step. When ``params.io_recon`` / ``params.program_recon`` is set, a
      second pass on detached features trains with ``info_optimizer``;
    * validation runs 5 query steps per batch. Summed metrics are
      all-reduced across ranks and divided by the dataset sizes;
    * rank 0 logs epoch metrics. ``model-best`` is saved on improvement and
      ``model-latest`` after every epoch that does not stop early. Training
      stops after ``params.patience`` consecutive non-improving epochs.

    NOTE(review): ``dist`` and the checkpoint/loss helpers are not defined
    in this file. They are presumably provided by ``from utils_hd import *``
    (``dist`` likely being ``torch.distributed``); confirm.
    """
    params.gpus = init_distributed_mode()
    print(params.gpus)
    # Define paths for storing tensorboard logs
    if not params.load_from_checkpoint:
        # NOTE(review): each rank formats its own timestamp. Ranks that straddle
        # a minute boundary would pick different directories.
        date = time.strftime('%Y%m%d%H%M', time.localtime(time.time()))
        save_dir = Path(params.model_output_path + date)
    else:
        save_dir = Path(params.model_output_path + params.load_from_checkpoint)
    print('saved to:', save_dir)
    # Only rank 0 creates the directory, but EVERY rank must reach the barrier.
    # Gating the barrier on save_dir.exists() lets a rank that checks after
    # rank 0 created the directory skip it, which mis-pairs later collectives.
    if dist.get_rank() == 0:
        save_dir.mkdir(parents=True, exist_ok=True)
    dist.barrier()
    tb_writer = SummaryWriter(save_dir)

    # Load train and val data
    with open(params.train_path, 'r') as f:
        (train_data, train_statement_target, train_drop_target, train_operator_target,
         train_program, train_step, train_input_num, train_typ, train_program_seq,
         train_plengths) = load_data(f, params.max_len)

    with open(params.val_path, 'r') as f:
        (val_data, val_statement_target, val_drop_target, val_operator_target,
         val_program, val_step, val_input_num, val_typ, _, _) = load_data(f, params.max_len)

    print(train_program[0:1])

    # Encode program strings as integer labels, fitted on train + val together.
    le = preprocessing.LabelEncoder()
    programs = le.fit_transform(np.concatenate([train_program, val_program]))
    train_program = programs[:len(train_program)]
    val_program = programs[len(train_program):]

    print(train_program[0:1])
    print(le.inverse_transform(train_program[0:1]))

    # Define model
    model = CombinarMI(le)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    device = torch.cuda.current_device()
    model = model.to(device)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[params.gpus], find_unused_parameters=False)

    # Two Adam optimizers over the parameter groups exposed by the model.
    main_params = model.module.main_parameters()
    info_params = model.module.info_parameters()
    main_learning_rate = params.learn_rate
    info_learning_rate = params.learn_rate
    main_optimizer = torch.optim.Adam(main_params, lr=main_learning_rate)
    info_optimizer = torch.optim.Adam(info_params, lr=info_learning_rate)

    models = [model]
    models_name = ['model']
    optims = [main_optimizer, info_optimizer]
    optims_name = ['main_optimizer', 'info_optimizer']

    start_epoch = 0
    best_val_error = np.inf
    # Epochs since the last improvement. It must exist before the first
    # non-improving epoch, which is likely to be the very first one when
    # resuming with a loaded best_val_error.
    patience_ctr = 0
    if params.load_from_checkpoint:
        # Log the checkpoint that is actually passed to load_checkpoint.
        print("=> loading checkpoint '{}'".format(params.load_from_checkpoint))
        loaded_models, loaded_opts, start_epoch, best_val_error = \
            load_checkpoint(models, optims, models_name, optims_name, params.load_from_checkpoint)
        model = loaded_models[0]
        main_optimizer, info_optimizer = loaded_opts

    main_lr_sched = torch.optim.lr_scheduler.StepLR(main_optimizer, step_size=params.lr_scheduler_step_size)
    info_lr_sched = torch.optim.lr_scheduler.StepLR(info_optimizer, step_size=params.lr_scheduler_step_size)

    statement_criterion = nn.CrossEntropyLoss()
    drop_criterion = nn.BCELoss()
    operator_criterion = nn.CrossEntropyLoss()
    l2_criterion = nn.MSELoss()

    # Convert to appropriate types
    # The cuda types are not used here on purpose - most GPUs can't handle so much memory
    train_data = torch.LongTensor(train_data)
    train_statement_target = torch.LongTensor(train_statement_target)
    train_drop_target = torch.FloatTensor(train_drop_target)
    train_operator_target = torch.LongTensor(train_operator_target)
    train_program = torch.LongTensor(train_program).view(-1, 1)
    train_step = torch.LongTensor(train_step)
    train_input_num = torch.LongTensor(train_input_num)
    train_typ = torch.LongTensor(train_typ)
    train_program_seq = torch.LongTensor(train_program_seq)
    train_plengths = torch.LongTensor(train_plengths)

    val_data = torch.LongTensor(val_data)
    val_statement_target = torch.LongTensor(val_statement_target)
    val_drop_target = torch.FloatTensor(val_drop_target)
    val_operator_target = torch.LongTensor(val_operator_target)
    val_program = torch.LongTensor(val_program).view(-1, 1)
    val_step = torch.LongTensor(val_step)
    val_input_num = torch.LongTensor(val_input_num)
    val_typ = torch.LongTensor(val_typ)

    val_data = Variable(val_data.type(LongTensor))
    val_statement_target = Variable(val_statement_target.type(LongTensor))
    val_drop_target = Variable(val_drop_target.type(FloatTensor))
    val_operator_target = Variable(val_operator_target.type(LongTensor))

    train_dataset = TensorDataset(train_data, train_statement_target, train_drop_target, train_operator_target, train_program, train_step, train_input_num, train_typ, train_program_seq, train_plengths)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_data_loader = DataLoader(train_dataset, batch_size=params.batch_size,
                            shuffle=False, pin_memory=False, sampler=train_sampler, num_workers=0)

    val_dataset = TensorDataset(val_data, val_statement_target, val_drop_target, val_operator_target, val_program, val_step, val_input_num, val_typ)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    val_data_loader = DataLoader(val_dataset, batch_size=params.batch_size,
                            shuffle=False, pin_memory=False, sampler=val_sampler, num_workers=0)

    # Constant starting query template, every entry set to params.integer_range.
    # It is sliced to the current batch size and one-hot encoded per batch.
    x_s = torch.zeros(params.batch_size, 1, 4, params.max_list_len, device='cuda').long() + params.integer_range
    for epoch in range(start_epoch, params.num_epochs):
        train_sampler.set_epoch(epoch)
        model.train()
        print("Epoch %d" % epoch)

        train_statement_losses, train_drop_losses, train_operator_losses = [], [], []
        train_z_n_kl_loss, train_t_n_kl_loss, train_z_t_kl_loss, train_recon_io_loss, train_recon_p_loss = [], [], [], [], []

        for batch_idx, batch in enumerate(tqdm(train_data_loader)):
            query_num = random.randint(2, 5)
            # batch[0] is not fed to the model (queries start from x_s);
            # only its batch size is needed.
            cur_batch_size = batch[0].shape[0]
            y = Variable(batch[1].type(LongTensor))
            z = Variable(batch[2].type(FloatTensor))
            w = Variable(batch[3].type(LongTensor))
            p = Variable(batch[4].type(LongTensor))
            s = Variable(batch[5].type(LongTensor))
            t = Variable(batch[7].type(LongTensor))
            program_seq = Variable(batch[8].type(LongTensor))
            plengths = Variable(batch[9].type(LongTensor))

            main_optimizer.zero_grad()
            info_optimizer.zero_grad()
            # mutual information loss
            z_n_kl_loss = torch.tensor([0.]).cuda()
            t_n_kl_loss = torch.tensor([0.]).cuda()
            z_t_kl_loss = torch.tensor([0.]).cuda()
            recon_io_loss = torch.tensor([0.]).cuda()
            recon_p_loss = torch.tensor([0.]).cuda()
            # ps loss
            statement_loss = torch.tensor([0.]).cuda()
            drop_loss = torch.tensor([0.]).cuda()
            operator_loss = torch.tensor([0.]).cuda()

            loss = torch.tensor([0.]).cuda()
            info_loss = torch.tensor([0.]).cuda()
            # TODO: distance loss
            var = None
            var_types = None
            x = x_s[:cur_batch_size].clone()
            x = F.one_hot(x, params.integer_range + 1).float()

            io_features_list = []
            program_features_list = []
            for i in range(query_num):
                # NOTE(review): i = 0 and i = 1 both select one type column.
                # Confirm whether i + 1 was intended.
                typ = t[:, :int(max(1, i))]
                if getattr(params, 'dump_query_inputs', False):
                    # Opt-in debug dump. It writes x0.pt / x1.pt to the working
                    # directory on every query step of every batch, so it is slow.
                    torch.save({'x': x, 'typ': typ, 'program_seq': program_seq, 'plengths': plengths},
                               'x%d.pt' % (batch_idx % 2))
                io_features = model.module.query.encode_io(x, typ)
                program_features = model.module.query.encode_program(program_seq, plengths)
                io_features_list.append(io_features)
                program_features_list.append(program_features)
                mus, logvars = model.module.query.encode_into_z2(program_features)
                zs = model.module.query.reparameterize(mus, logvars)
                query_inp, query_index = model.module.query.decode_query(zs, program_features, typ, params.hard_softmax)
                (pred_act, pred_drop, pred_operator), x, var, var_types = \
                    model.module.ps(query_index, query_inp, p, s, z, i, x, var, var_types)
                # PS loss
                statement_loss += statement_criterion(pred_act, y)
                drop_loss += drop_criterion(pred_drop, z)
                operator_loss += operator_criterion(pred_operator, w)
                # Variational loss.
                z_n_kl_loss += gaussian_KL_loss(mus, logvars)
                # t-space KL loss
                if params.t_space:
                    io_features_t = io_features
                    t_mus, t_logvars = model.module.query.encode_into_t(io_features_t)
                    t_n_kl_loss += gaussian_KL_loss(t_mus, t_logvars)
                    z_t_kl_loss += compute_two_gaussian_loss(mus, logvars, t_mus, t_logvars)

            statement_loss /= query_num
            drop_loss /= query_num
            operator_loss /= query_num
            z_n_kl_loss /= query_num
            t_n_kl_loss /= query_num
            z_t_kl_loss /= query_num
            loss += params.lambda_ps * (statement_loss + operator_loss + drop_loss)
            loss += params.lambda_z_n * z_n_kl_loss
            loss += params.lambda_t_n * t_n_kl_loss
            loss += params.lambda_z_t * z_t_kl_loss

            loss.backward()
            main_optimizer.step()

            # Reconstruction, trained on detached features with info_optimizer.
            if params.io_recon or params.program_recon:
                main_optimizer.zero_grad()
                info_optimizer.zero_grad()
                for i in range(query_num):
                    io_features = io_features_list[i]
                    program_features = program_features_list[i]
                    io_targets = io_features.detach()
                    program_targets = program_features.detach()
                    recon_io_features, recon_p_features = model.module.query.reconstruct_inputs(
                        io_targets, program_targets)
                    if params.io_recon:
                        recon_io_loss += l2_criterion(recon_io_features, io_targets)
                    if params.program_recon:
                        recon_p_loss += l2_criterion(recon_p_features, program_targets)

                recon_io_loss /= query_num
                recon_p_loss /= query_num
                info_loss += params.lambda_io * recon_io_loss
                info_loss += params.lambda_program * recon_p_loss
                info_loss.backward()
                info_optimizer.step()

            # Per-batch means scaled by batch size, so the epoch sums can be
            # normalized by the dataset size after all-reduce.
            train_statement_losses.append(statement_loss.item() * y.shape[0])
            train_drop_losses.append(drop_loss.item() * z.shape[0])
            train_operator_losses.append(operator_loss.item() * w.shape[0])

            train_z_n_kl_loss.append(z_n_kl_loss.item() * y.shape[0])
            train_t_n_kl_loss.append(t_n_kl_loss.item() * y.shape[0])
            train_z_t_kl_loss.append(z_t_kl_loss.item() * y.shape[0])

            train_recon_io_loss.append(recon_io_loss.item() * y.shape[0])
            train_recon_p_loss.append(recon_p_loss.item() * y.shape[0])

            # The DistributedSampler shards the data, so this rank sees
            # len(train_data_loader) batches per epoch, not len(train_data) / batch_size.
            global_step = batch_idx + epoch * len(train_data_loader)
            if dist.get_rank() == 0:
                tb_writer.add_scalar("loss_batch/ps_loss", (statement_loss + operator_loss + drop_loss).item(), global_step)
                tb_writer.add_scalar("loss_batch/z_n_kl_loss", z_n_kl_loss.item(), global_step)
                tb_writer.add_scalar("loss_batch/t_n_kl_loss", t_n_kl_loss.item(), global_step)
                tb_writer.add_scalar("loss_batch/z_t_kl_loss", z_t_kl_loss.item(), global_step)
                tb_writer.add_scalar("loss_batch/recon_io_loss", recon_io_loss.item(), global_step)
                tb_writer.add_scalar("loss_batch/recon_p_loss", recon_p_loss.item(), global_step)

        avg_statement_train_loss = np.array(train_statement_losses).sum()
        avg_drop_train_loss = np.array(train_drop_losses).sum()
        avg_operator_train_loss = np.array(train_operator_losses).sum()

        avg_z_n_kl_loss = np.array(train_z_n_kl_loss).sum()
        avg_t_n_kl_loss = np.array(train_t_n_kl_loss).sum()
        avg_z_t_kl_loss = np.array(train_z_t_kl_loss).sum()

        avg_recon_io_loss = np.array(train_recon_io_loss).sum()
        avg_recon_p_loss = np.array(train_recon_p_loss).sum()

        model.eval()

        with torch.no_grad():
            query_num = 5

            total_statement_loss = 0.
            total_drop_loss = 0.
            total_operator_loss = 0.
            total_val_error = 0.
            for batch in tqdm(val_data_loader):
                y = Variable(batch[1].type(LongTensor))
                z = Variable(batch[2].type(FloatTensor))
                w = Variable(batch[3].type(LongTensor))
                p = Variable(batch[4].type(LongTensor))
                s = Variable(batch[5].type(LongTensor))
                t = Variable(batch[7].type(LongTensor))

                var = None
                var_types = None
                x = x_s[:batch[0].shape[0]].clone()
                x = F.one_hot(x, params.integer_range + 1).float()
                # Only the prediction from the last query step is scored.
                for i in range(query_num):
                    output, x, var, var_types = model.module.predict(x, var, var_types, p, s, z, t, i)
                val_statement_pred = output[0]
                val_drop_pred = output[1]
                val_operator_pred = output[2]

                val_statement_loss = statement_criterion(val_statement_pred, y)
                val_drop_loss = drop_criterion(val_drop_pred, z)
                val_operator_loss = operator_criterion(val_operator_pred, w)

                total_statement_loss += (val_statement_loss * y.shape[0]).item()
                total_drop_loss += (val_drop_loss * z.shape[0]).item()
                total_operator_loss += (val_operator_loss * w.shape[0]).item()

                predict = val_statement_pred.data.max(1)[1]
                total_val_error += (predict != y.data).sum().item()

            main_lr_sched.step()
            info_lr_sched.step()

            metrics = [
                avg_statement_train_loss,
                avg_drop_train_loss,
                avg_operator_train_loss,
                avg_z_n_kl_loss,
                avg_t_n_kl_loss,
                avg_z_t_kl_loss,
                avg_recon_io_loss,
                avg_recon_p_loss,
                total_statement_loss,
                total_drop_loss,
                total_operator_loss,
                total_val_error,
            ]
            reduce_tmp = [torch.tensor([metric], dtype=torch.float64, device='cuda') for metric in metrics]
            # The last 4 metrics are validation sums; the rest are training sums.
            num_train_metrics = len(metrics) - 4
            dist.barrier()
            for tensor_idx, tensor in enumerate(reduce_tmp):
                dist.all_reduce(tensor)
                denom = len(train_data) if tensor_idx < num_train_metrics else len(val_data)
                metrics[tensor_idx] = tensor.tolist()[0] / denom

            avg_statement_train_loss, \
            avg_drop_train_loss, \
            avg_operator_train_loss, \
            avg_z_n_kl_loss, \
            avg_t_n_kl_loss, \
            avg_z_t_kl_loss, \
            avg_recon_io_loss, \
            avg_recon_p_loss, \
            val_statement_loss, \
            val_drop_loss, \
            val_operator_loss, \
            val_error = metrics

            print("Train loss: S %f" % avg_statement_train_loss, "D %f" % avg_drop_train_loss,
                  "F %f" % avg_operator_train_loss)
            print("Val loss: S %f" % val_statement_loss, "D %f" % val_drop_loss,
                  "F %f" % val_operator_loss)

            print("Val classification error: %f" % val_error)

            if dist.get_rank() == 0:
                tb_writer.add_scalar("loss/train_statement_loss", avg_statement_train_loss, epoch)
                tb_writer.add_scalar("loss/train_drop_loss", avg_drop_train_loss, epoch)
                tb_writer.add_scalar("loss/train_operator_loss", avg_operator_train_loss, epoch)

                tb_writer.add_scalar("info/train_z_n_kl_loss", avg_z_n_kl_loss, epoch)
                tb_writer.add_scalar("info/train_t_n_kl_loss", avg_t_n_kl_loss, epoch)
                tb_writer.add_scalar("info/train_z_t_kl_loss", avg_z_t_kl_loss, epoch)

                tb_writer.add_scalar("info/train_recon_io_loss", avg_recon_io_loss, epoch)
                tb_writer.add_scalar("info/train_recon_p_loss", avg_recon_p_loss, epoch)

                tb_writer.add_scalar("loss/val_statement_loss", val_statement_loss, epoch)
                tb_writer.add_scalar("loss/val_drop_loss", val_drop_loss, epoch)
                tb_writer.add_scalar("loss/val_operator_loss", val_operator_loss, epoch)

                tb_writer.add_scalar("error/val_error", val_error, epoch)
                tb_writer.add_scalar("lr/lr", main_optimizer.state_dict()['param_groups'][0]['lr'], epoch)

            # NOTE(review): every rank calls save_checkpoint with the same path.
            # Confirm that it writes only on rank 0.
            if val_error < best_val_error:
                ckpt_path = save_dir / 'model-best'
                print("Found new best model")
                best_val_error = val_error
                save_checkpoint(models, optims, models_name, optims_name, epoch, best_val_error, ckpt_path)
                patience_ctr = 0
            else:
                patience_ctr += 1
                if patience_ctr == params.patience:
                    print("Ran out of patience. Stopping training early...")
                    break
            ckpt_path = save_dir / 'model-latest'
            save_checkpoint(models, optims, models_name, optims_name, epoch, best_val_error, ckpt_path)

    # Flush pending TensorBoard events.
    tb_writer.close()

if __name__ == "__main__":
    # Start training only when this file is executed as a script, not on import.
    train()