from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import json
import os
import random
import multiprocessing
import time 
from pathlib import Path
from tqdm import tqdm
import numpy as np

import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.autograd import Variable
from torch import nn
import torch.nn.functional as F

import params
from model.combinar_MI6 import CombinarMI
from cuda import use_cuda, LongTensor, FloatTensor
from env.env import ProgramEnv
from env.operator import Operator, operator_to_index
from env.statement import Statement, statement_to_index
from dsl.program import Program
from dsl.example import Example

from utils_hd import *

from sklearn import preprocessing

from pytorch_metric_learning.losses import NTXentLoss

from dsl import constraint
from dsl.value import NULLVALUE

# Seed every RNG this script draws from -- torch (CPU), torch (all CUDA
# devices), NumPy and Python's `random` -- from the single configured seed so
# data featurization (random.choice/random.sample) is reproducible.
_SEED = params.seed
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
    _seed_fn(_SEED)
del _seed_fn


def generate_prog_data(line):
    """Featurize one JSON-encoded program line.

    Parses ``line`` as JSON, builds the IO examples and a ``ProgramEnv`` from
    it, parses ``data['program']``, and replays the program statement by
    statement in the env, recording per-step targets.

    Args:
        line: One line of the dataset file: a JSON object with a ``'program'``
            key plus whatever ``Example.from_line`` reads.

    Returns:
        A 10-tuple of lists ``(new_inputs, statements, drop, operators,
        programs, steps, input_nums, types, program_seq, program_lengths)``.
        ``program_seq`` always has exactly one entry and every other list is
        truncated to ``len(program_seq)``, so at most one element survives:
        ``new_inputs`` holds the env encoding *after* the whole program has run,
        while ``statements``/``drop``/``operators``/``programs``/``steps``/
        ``input_nums``/``types`` keep only the first step's values.

    Note:
        ``random.choice`` is called once per statement (even when its result is
        unused), so this consumes the global ``random`` state; moving that call
        changes which variables get dropped for a given seed.
        NOTE(review): if the program has no statements, the per-step lists come
        back empty while ``new_inputs``/``program_seq``/``program_lengths`` each
        have one element, so the columns built by ``load_data`` would misalign.
    """
    data = json.loads(line.rstrip())
    examples = Example.from_line(data)
    env = ProgramEnv(examples)
    program = Program.parse(data['program'])

    inputs = []  # per-step encodings before each step; collected but never returned
    statements = []
    drop = []
    operators = []
    programs = []
    steps = []
    input_nums = []
    types = []
    program_lengths = []  # overwritten after the loop
    new_inputs = []


    for i, statement in enumerate(program.statements):
        inputs.append(env.get_encoding())

        # Translate absolute indices to post-drop indices
        f, args = statement.function, list(statement.args)
        for j, arg in enumerate(args):
            if isinstance(arg, int):
                args[j] = env.real_var_idxs.index(arg)

        statement = Statement(f, args)
        statements.append(statement_to_index[statement])

        # Absolute var indices still referenced by this statement or any later
        # one (taken from the original, un-remapped program statements).
        used_args = []
        for next_statement in program.statements[i:]:
            used_args += [x for x in next_statement.args if isinstance(x, int)]

        # 1 = slot may be dropped: it is beyond the live vars, or its var is
        # never read again; 0 = slot must be kept.
        to_drop = []
        for j in range(params.max_program_vars):
            if j >= env.num_vars or env.real_var_idxs[j] not in used_args:
                to_drop.append(1)
            else:
                to_drop.append(0)

        drop.append(to_drop)
        # Picked every step (keeps RNG consumption stable) but only used when
        # the env is full; raises IndexError if no slot is droppable.
        rand_idx = random.choice([j for j in range(len(to_drop)) if to_drop[j] > 0])

        operator = Operator.from_statement(statement)
        operators.append(operator_to_index[operator])

        if env.num_vars < params.max_program_vars:
            env.step(statement)
        else:
            # Choose a random var (that is not used anymore) to drop.
            env.step(statement, rand_idx)

        programs.append(str(program))
        # print("Inputs Shape:", [inp for inp in inputs])
        # print("Statements:", statements)
        # print("Drop:", drop)
        # print("Operators:", operators)
        steps.append(i)
        input_nums.append(len(program.input_types))
        # Input types padded to 3 slots with 'NULL', followed by the type of the
        # last program variable (presumably the program output).
        typ = program.input_types + (3 - len(program.input_types)) * ['NULL'] + program.var_types[-1:]
        typ_token = []
        # One-hot type token per slot: INT=[1,0,0], LIST=[0,1,0], NULL=[0,0,1].
        for item in typ:
            if str(item) == 'LIST':
                typ_token.append([0, 1, 0])
            elif str(item) == 'INT':
                typ_token.append([1, 0, 0])
            elif str(item) == 'NULL':
                typ_token.append([0, 0, 1])
            else:
                raise ValueError('bad type {}'.format(item))
        types.append([typ_token] * 5) # same type tokens repeated for 5 IO examples (assumes 5 per problem)
    # Env encoding after the whole program has executed.
    new_inputs.append(env.get_encoding())
    # Token sequence: 1, statement indices shifted by 3, 2, then zero padding;
    # total length is params.max_prog_len + 2. 1/2 appear to be start/end
    # markers and 0 padding -- confirm against the model's vocabulary.
    program_seq = [[1] + [idx + 3 for idx in statements] + [2] + [0] * (params.max_prog_len - len(statements))]
    program_lengths = [len(statements) + 2]  # +2 counts the two marker tokens
#    print('program_seq:', program_seq)
    # Every list is cut to len(program_seq) == 1 entry.
    return new_inputs[:len(program_seq)], statements[:len(program_seq)], drop[:len(program_seq)], operators[:len(program_seq)], programs[:len(program_seq)], steps[:len(program_seq)], input_nums[:len(program_seq)], types[:len(program_seq)], program_seq, program_lengths

def load_data(fileobj, max_len):
    """Load and featurize every program line from ``fileobj``.

    Args:
        fileobj: Readable text file object; each line is one JSON-encoded
            problem, featurized by ``generate_prog_data``.
        max_len: If not None, randomly sample this many lines (without
            replacement) before featurizing. A value larger than the number of
            available lines is clamped, so the whole file is used (in random
            order) instead of ``random.sample`` raising ``ValueError``.

    Returns:
        A 10-tuple of NumPy arrays, in order: env encodings, statement targets,
        drop masks, operator targets, program strings, step indices, input
        counts, type tokens, program token sequences and program lengths. Each
        array is the concatenation of the corresponding per-line lists.
    """
    lines = fileobj.read().splitlines()
    if max_len is not None:
        # random.sample raises ValueError when k exceeds the population size;
        # treat an oversized max_len as "use every line".
        lines = random.sample(lines, min(max_len, len(lines)))

    encodings, statement_targets, drop_targets, operator_targets, programs = [], [], [], [], []
    steps, input_nums, types, program_seqs, program_lengths = [], [], [], [], []

    for line in tqdm(lines, total=len(lines)):
        # Tuple unpacking (rather than zip) so a wrong arity still fails loudly.
        (encoding, statement, drop, operator, program, step,
         input_num, typ, program_seq, program_length) = generate_prog_data(line)
        encodings.extend(encoding)
        statement_targets.extend(statement)
        drop_targets.extend(drop)
        operator_targets.extend(operator)
        programs.extend(program)
        steps.extend(step)
        input_nums.extend(input_num)
        types.extend(typ)
        program_seqs.extend(program_seq)
        program_lengths.extend(program_length)

    return (np.array(encodings), np.array(statement_targets), np.array(drop_targets),
            np.array(operator_targets), np.array(programs), np.array(steps),
            np.array(input_nums), np.array(types), np.array(program_seqs),
            np.array(program_lengths))

def _write_ios(path, ios):
    """Write generated IO problems to ``path``, one JSON object per line.

    Each item of ``ios`` must provide ``'program'`` and ``'examples'`` keys;
    only those two keys are serialized.
    """
    with open(path, 'w') as f:
        for item in ios:
            problem = dict(program=item['program'], examples=item['examples'])
            f.write(json.dumps(problem) + '\n')


def generate_random_ios():
    """Featurize the train/val datasets and dump randomly generated IO problems.

    Reads ``params.train_path`` and ``params.val_path`` through ``load_data``,
    label-encodes the program strings of both splits jointly, builds a
    ``CombinarMI`` model on the current CUDA device, and calls
    ``generate_ios(..., random=True)`` for each split. Results are written to
    ``trained_models/<params.dataset>/random/train_gps`` and ``.../val_gps``
    (the directory is created if missing).
    """
    save_dir = Path('trained_models/' + params.dataset + '/random')
    with open(params.train_path, 'r') as f:
        (_, _, train_drop_target, _, train_program, train_step, _, train_typ,
         _, _) = load_data(f, params.max_len)
    print('here2')
    with open(params.val_path, 'r') as f:
        (_, _, val_drop_target, _, val_program, val_step, _, val_typ,
         _, _) = load_data(f, params.max_len)

    print(train_program[0:1])

    # Encode program strings jointly so both splits share one label space.
    le = preprocessing.LabelEncoder()
    programs = le.fit_transform(np.concatenate([train_program, val_program]))
    num_train = len(train_program)
    train_program = programs[:num_train]
    val_program = programs[num_train:]

    # Define model
    model = CombinarMI(le)

    device = torch.cuda.current_device()
    model = model.to(device)

    # Only the tensors generate_ios consumes are built. The plain (CPU) tensor
    # types are used on purpose - most GPUs can't handle so much memory.
    train_program = torch.LongTensor(train_program).view(-1, 1)
    train_step = torch.LongTensor(train_step)
    train_typ = torch.LongTensor(train_typ)
    train_drop_target = torch.FloatTensor(train_drop_target)
    val_program = torch.LongTensor(val_program).view(-1, 1)
    val_step = torch.LongTensor(val_step)
    val_typ = torch.LongTensor(val_typ)
    # NOTE(review): unlike the train split, the val drop targets are cast to the
    # project's `cuda.FloatTensor` type (presumably a GPU tensor when CUDA is on).
    # Preserved from the original; confirm generate_ios expects this asymmetry.
    val_drop_target = Variable(torch.FloatTensor(val_drop_target).type(FloatTensor))

    model.eval()
    # open(..., 'w') fails if the output directory does not exist yet.
    save_dir.mkdir(parents=True, exist_ok=True)
    with torch.no_grad():
        ios = generate_ios(train_program, train_typ, model, train_step, train_drop_target, random=True)
        _write_ios(save_dir / 'train_gps', ios)

        ios = generate_ios(val_program, val_typ, model, val_step, val_drop_target, random=True)
        _write_ios(save_dir / 'val_gps', ios)


def generate_random_test():
    """Smoke-test IO generation for every program in ``params.train_path``.

    For each JSON line, parses the examples and the program, then asks
    ``constraint.get_input_output_examples`` for 95 IO examples (up to 1000
    tries) and prints the result. When generation returns None, ``'io is none'``
    is printed before the (None) result. Nothing is written to disk.
    """
    # Read everything up front so the file is not held open during generation.
    with open(params.train_path, 'r') as f:
        lines = f.read().splitlines()

    for line in tqdm(lines, total=len(lines)):
        data = json.loads(line.rstrip())
        # NOTE(review): the env itself is never used; construction is kept in
        # case Example/ProgramEnv reject malformed lines by raising.
        ProgramEnv(Example.from_line(data))
        program = Program.parse(data['program'])

        input_output_examples = constraint.get_input_output_examples(
            program, num_examples=95, num_tries=1000)
        if input_output_examples is None:
            print('io is none')
        print(input_output_examples)


#    print(train_program[0:1])
#
#    le = preprocessing.LabelEncoder()
#    programs = le.fit_transform(np.concatenate([train_program, val_program]))
#    train_program = programs[:len(train_program)]
#    val_program = programs[len(train_program):]
#
#    # Define model
#    
#    model = CombinarMI(le)
#
#    device = torch.cuda.current_device()
#    model = model.to(device)
#
#    #Convert to appropriate types
#    # The cuda types are not used here on purpose - most GPUs can't handle so much memory
#    train_data, train_statement_target, train_drop_target, train_operator_target, train_program, train_step, train_input_num, train_typ, train_program_seq, train_plengths = \
#        torch.LongTensor(train_data), torch.LongTensor(train_statement_target), \
#        torch.FloatTensor(train_drop_target), torch.LongTensor(train_operator_target), \
#        torch.LongTensor(train_program).view(-1, 1), torch.LongTensor(train_step), torch.LongTensor(train_input_num), torch.LongTensor(train_typ), torch.LongTensor(train_program_seq), torch.LongTensor(train_plengths)
#    val_data, val_statement_target, val_drop_target, val_operator_target, val_program, val_step, val_input_num, val_typ, val_program_seq, val_plengths = \
#        torch.LongTensor(val_data), torch.LongTensor(val_statement_target), \
#        torch.FloatTensor(val_drop_target), torch.LongTensor(val_operator_target), \
#        torch.LongTensor(val_program).view(-1, 1), torch.LongTensor(val_step), torch.LongTensor(val_input_num), torch.LongTensor(val_typ), torch.LongTensor(val_program_seq), torch.LongTensor(val_plengths)
#
#    val_data = Variable(val_data.type(LongTensor))
#    val_statement_target = Variable(val_statement_target.type(LongTensor))
#    val_drop_target = Variable(val_drop_target.type(FloatTensor))
#    val_operator_target = Variable(val_operator_target.type(LongTensor))
#
#    train_dataset = TensorDataset(train_data, train_statement_target, train_drop_target, train_operator_target, train_program, train_step, train_input_num, train_typ, train_program_seq, train_plengths)
#    train_data_loader = DataLoader(train_dataset, batch_size=params.batch_size,
#                            shuffle=False, pin_memory=False)
#
#    val_dataset = TensorDataset(val_data, val_statement_target, val_drop_target, val_operator_target, val_program, val_step, val_input_num, val_typ, val_program_seq, val_plengths)
#    val_data_loader = DataLoader(val_dataset, batch_size=params.batch_size,
#                            shuffle=False, pin_memory=False)
#
#    ios = []
#    example_batch = []
#    p_str = le.inverse_transform(p.view(-1).tolist())
#    for prog_sp in p_str:
#        prog_sp = Program.parse(prog_sp.rstrip())
#        input_output_examples = constraint.get_input_output_examples(prog_sp, num_examples=95,
#                                                                    num_tries=1000)
#        example = []
#        if input_output_examples is None:
#            print('io is none')
#        print(input_output_examples)
##            input_output_examples = [[[NULLVALUE], NULLVALUE]] * 5
##        for exp in input_output_examples:
##            inp, out = exp
##            encoded_inp = [var.encoded for var in inp]
##            encoded_out = out.encoded
##            if len(encoded_inp) < 3:
##                encoded_inp.extend([NULLVALUE.encoded] * (3 - len(encoded_inp)))
##            example.append(encoded_inp + [encoded_out])
##        example = torch.Tensor(example)
##        example_batch.append(example)
##    x = torch.stack(example_batch, 0).cuda().long()
#
#    num_examples = random.randint(1, 5) 
#    typ = torch.cat([x[:,:num_examples,:3,:2], x[:,:num_examples,-1:,:2]], -2)
#    x = torch.cat([x[:,:num_examples,:3,2:], x[:,:num_examples,-1:,2:]], -2)
#
#    model.eval()
#    with torch.no_grad():
#        ios = generate_ios(train_program, train_typ, model, train_step, train_drop_target, random=True)
#        f = open(save_dir / 'train_gps', 'w')
#        for item in ios:
#            problem = dict(program=item['program'], examples=item['examples'])
#            f.write(json.dumps(problem) + '\n')
#        f.close()
#
#        ios = generate_ios(val_program, val_typ, model, val_step, val_drop_target, random=True)
#        f = open(save_dir / 'val_gps', 'w')
#        for item in ios:
#            problem = dict(program=item['program'], examples=item['examples'])
#            f.write(json.dumps(problem) + '\n')
#        f.close()

if __name__ == '__main__':
    # Default entry point: the IO-generation smoke test. Call
    # generate_random_ios() instead to regenerate the train/val IO files.
    generate_random_test()