from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import json
import os
import random
import multiprocessing
import time 
from pathlib import Path
from tqdm import tqdm
import numpy as np

import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.autograd import Variable
from torch import nn
import torch.nn.functional as F

import params
from model.combinar_MI7 import CombinarMI
from cuda import use_cuda, LongTensor, FloatTensor
from env.env import ProgramEnv
from env.operator import Operator, operator_to_index
from env.statement import Statement, statement_to_index
from dsl.program import Program
from dsl.example import Example

from utils_hd import *

from sklearn import preprocessing

from pytorch_metric_learning.losses import NTXentLoss

from dsl import constraint
from dsl.value import NULLVALUE

# Seed every random number generator used by this script (Python's `random`,
# NumPy, torch CPU and all CUDA devices) with the same params.seed, so runs
# configured with the same seed draw the same random numbers.
for _seed_fn in (random.seed, np.random.seed,
                 torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(params.seed)


def generate_prog_data(line: str):
    """Replay one JSON-encoded program in a ProgramEnv and build its training record.

    `line` is parsed with `json.loads`; the resulting dict is handed to
    `Example.from_line`, and its 'program' entry is parsed with `Program.parse`.
    The program's statements are then executed one by one in a `ProgramEnv`,
    collecting per-step targets along the way.

    Returns a 10-tuple of lists, in this order:
        new_inputs, statements, drop, operators, programs, steps,
        input_nums, types, program_seq, program_lengths

    NOTE(review): every per-step list is sliced to `len(program_seq)`, which is
    always 1. Only the FIRST step's statement/drop/operator/step entries are
    returned, paired with the env encoding taken AFTER the last statement
    (`new_inputs`). If the program has zero statements, `new_inputs` still has
    one entry while the other per-step lists are empty, so the rows would be
    misaligned once concatenated by the caller. Confirm this is intended.

    Raises:
        ValueError: if a type in `typ` is not LIST, INT or NULL.
        IndexError: from `random.choice` when no variable slot is droppable.
    """
    data = json.loads(line.rstrip())
    examples = Example.from_line(data)
    env = ProgramEnv(examples)
    program = Program.parse(data['program'])

    # NOTE(review): `inputs` is filled every step but never returned.
    inputs = []
    statements = []
    drop = []
    operators = []
    programs = []
    steps = []
    input_nums = []
    types = []
    program_lengths = []  # overwritten after the loop
    new_inputs = []


    for i, statement in enumerate(program.statements):
        # Env encoding BEFORE executing statement i (not used in the return value).
        inputs.append(env.get_encoding())

        # Translate absolute indices to post-drop indices
        f, args = statement.function, list(statement.args)
        for j, arg in enumerate(args):
            if isinstance(arg, int):
                # Integer args are variable references; map them to their
                # current slot position in the env.
                args[j] = env.real_var_idxs.index(arg)

        statement = Statement(f, args)
        statements.append(statement_to_index[statement])

        # All variable references made by this statement and every later one.
        used_args = []
        for next_statement in program.statements[i:]:
            used_args += [x for x in next_statement.args if isinstance(x, int)]

        # Per-slot mask of length params.max_program_vars: 1 if the slot is
        # empty (j >= env.num_vars) or its variable is never referenced again,
        # otherwise 0.
        to_drop = []
        for j in range(params.max_program_vars):
            if j >= env.num_vars or env.real_var_idxs[j] not in used_args:
                to_drop.append(1)
            else:
                to_drop.append(0)

        drop.append(to_drop)
        # Drawn on EVERY step, even when unused below, so each step consumes
        # one `random` draw; raises IndexError if no slot is flagged.
        rand_idx = random.choice([j for j in range(len(to_drop)) if to_drop[j] > 0])

        operator = Operator.from_statement(statement)
        operators.append(operator_to_index[operator])

        if env.num_vars < params.max_program_vars:
            env.step(statement)
        else:
            # Choose a random var (that is not used anymore) to drop.
            env.step(statement, rand_idx)

        # The entries below do not depend on the step's state except `steps`;
        # they are appended once per step to stay aligned with the lists above.
        programs.append(str(program))
        # print("Inputs Shape:", [inp for inp in inputs])
        # print("Statements:", statements)
        # print("Drop:", drop)
        # print("Operators:", operators)
        steps.append(i)
        input_nums.append(len(program.input_types))
        # Input types padded with 'NULL' to 3 slots, followed by the last entry
        # of var_types (presumably the program's output type -- confirm).
        # With more than 3 inputs no padding is added (negative repeat -> []).
        typ = program.input_types + (3 - len(program.input_types)) * ['NULL'] + program.var_types[-1:]
        # One-hot encode each type: INT=[1,0,0], LIST=[0,1,0], NULL=[0,0,1].
        typ_token = []
        # for choice in query
        for item in typ:
            if str(item) == 'LIST':
                typ_token.append([0, 1, 0])
            elif str(item) == 'INT':
                typ_token.append([1, 0, 0])
            elif str(item) == 'NULL':
                typ_token.append([0, 0, 1])
            else:
                raise ValueError('bad type {}'.format(item))
        types.append([typ_token] * 5) # 5 ios same types
    # Env encoding after the whole program has been executed.
    new_inputs.append(env.get_encoding())
    # Single token row: 1, each statement index shifted by 3, 2, then 0-padding
    # to a total length of params.max_prog_len + 2. (1/2/0 presumably act as
    # start/end/pad tokens, hence the +3 offset -- confirm with the model.)
    program_seq = [[1] + [idx + 3 for idx in statements] + [2] + [0] * (params.max_prog_len - len(statements))]
    # Sequence length including the two 1/2 marker tokens.
    program_lengths = [len(statements) + 2]
#    print('program_seq:', program_seq)
    # len(program_seq) == 1, so every list is truncated to its first entry.
    return new_inputs[:len(program_seq)], statements[:len(program_seq)], drop[:len(program_seq)], operators[:len(program_seq)], programs[:len(program_seq)], steps[:len(program_seq)], input_nums[:len(program_seq)], types[:len(program_seq)], program_seq, program_lengths

def load_data(fileobj, max_len):
    """Parse a JSON-lines program file into parallel numpy arrays.

    Each line of `fileobj` is converted with `generate_prog_data`, and the
    per-line lists it returns are concatenated field by field.

    Args:
        fileobj: open text file; each line is one JSON problem accepted by
            `generate_prog_data`.
        max_len: if not None, randomly sample this many lines (without
            replacement) before parsing. A value larger than the number of
            lines is clamped, so the whole file is used (in sampled order)
            instead of `random.sample` raising ValueError.

    Returns:
        Tuple of 10 numpy arrays, one per field of `generate_prog_data`'s
        return value and in the same order.
    """
    lines = fileobj.read().splitlines()
    if max_len is not None:
        # random.sample raises ValueError when k > len(population); clamp so
        # requesting more samples than exist falls back to every line.
        lines = random.sample(lines, min(max_len, len(lines)))

    records = [generate_prog_data(line) for line in tqdm(lines, total=len(lines))]

    # One accumulator per field of generate_prog_data's 10-tuple.
    num_fields = 10
    columns = [[] for _ in range(num_fields)]
    for record in records:
        for column, values in zip(columns, record):
            column.extend(values)

    return tuple(np.array(column) for column in columns)


def _write_problems(out_path, ios):
    """Write generated problems to `out_path` as JSON lines.

    Only the 'program' and 'examples' entries of each item are written, one
    JSON object per line. The file is closed even if iteration or writing fails.
    """
    with open(out_path, 'w') as out_file:
        for item in ios:
            problem = dict(program=item['program'], examples=item['examples'])
            out_file.write(json.dumps(problem) + '\n')


def train():
    """Dump model-generated I/O problems for each evaluation program file.

    Despite its name, this function does not train. For every file in the
    list below it:
      1. loads the programs with `load_data` and label-encodes their strings;
      2. builds a `CombinarMI` model wrapped in SyncBatchNorm + DDP (on every
         rank, because DDP construction is collective);
      3. on rank 0 only, restores `save_dir / 'model-best'`, runs
         `generate_ios`, and writes the result to `save_dir / <file name>`;
      4. synchronizes all ranks with `dist.barrier()`.
    """
    # Imported explicitly: this module otherwise relies on `dist` leaking in
    # through `from utils_hd import *`.
    import torch.distributed as dist

    params.gpus = init_distributed_mode()
    print(params.gpus)
    save_dir = Path(params.model_output_path + params.dump_dataset)
    print('saved to:', save_dir)

    print('here1')
    print('here2')
    for val_path in ['data/len_4', 'data/len_5']:
        with open(val_path, 'r') as f:
            # Only the drop targets, programs, steps and types feed generate_ios.
            (_data, _statement_target, val_drop_target, _operator_target,
             val_program, val_step, _input_num, val_typ, _program_seq,
             _plengths) = load_data(f, params.max_len)

        # Map each program string to an integer label; the encoder is also
        # handed to the model.
        le = preprocessing.LabelEncoder()
        val_program = le.fit_transform(val_program)

        model = CombinarMI(le)
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        device = torch.cuda.current_device()
        model = model.to(device)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[params.gpus], find_unused_parameters=False)

        models = [model]
        models_name = ['model']
        optims = []
        optims_name = []

        # Tensors consumed by generate_ios. The drop mask is cast to the
        # FloatTensor type exported by the project's `cuda` module.
        val_program = torch.LongTensor(val_program).view(-1, 1)
        val_step = torch.LongTensor(val_step)
        val_typ = torch.LongTensor(val_typ)
        val_drop_target = torch.FloatTensor(val_drop_target).type(FloatTensor)

        ############## dump dataset #################
        model.eval()
        with torch.no_grad():
            if dist.get_rank() == 0:
                ckpt_path = save_dir / 'model-best'
                print("=> loading checkpoint '{}'".format(ckpt_path))
                loaded_models, _, _, _ = load_checkpoint(
                    models, optims, models_name, optims_name, ckpt_path)
                model = loaded_models[0]
                model.eval()

                ios = generate_ios(val_program, val_typ, model, val_step, val_drop_target)
                out_path = save_dir / Path(val_path).name
                print(out_path)
                _write_problems(out_path, ios)
            # Non-zero ranks wait here until rank 0 has finished writing.
            dist.barrier()
        #############################################


if __name__ == "__main__":
    # Script entry point: run the dataset-dump pipeline defined in train().
    train()