from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import json
import os
import random
import multiprocessing
import time 
from pathlib import Path
from tqdm import tqdm
import numpy as np

import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.autograd import Variable
from torch import nn
import torch.nn.functional as F

import params
from model.combinar_MI6 import CombinarMI
from cuda import use_cuda, LongTensor, FloatTensor

from karel.consistency import Simulator

from nps.utils import *
from utils_hd import *

from sklearn import preprocessing

from data.data import (load_input_file, load_input_file_seq, get_minibatch, shuffle_dataset,
                     KarelDataset, QueryDataset, KarelDatasetNoWorlds, KarelDatasetSeq, BucketingSampler)

# Seed Python's, NumPy's and PyTorch's (CPU and every CUDA device) random
# number generators from the single value configured in params.seed.
random.seed(params.seed)
np.random.seed(params.seed)
torch.manual_seed(params.seed)
torch.cuda.manual_seed_all(params.seed)


def _encode_program_and_io(query, input_grids, output_grids, program_seq, plengths):
    """Encode one batch into flattened program and IO feature matrices.

    The same call sequence is used for training and validation:
    ``encode_io`` -> ``encode_into_t`` -> ``intersection`` for the IO
    examples, then ``encode_program`` for the program tokens.

    Args:
        query: the ``query`` sub-module of the unwrapped model.
        input_grids: input side of the IO examples.
        output_grids: output side of the IO examples.
        program_seq: program token tensor (``batch[2]`` of the loader).
        plengths: presumably the per-program sequence lengths
            (``batch[3]`` of the loader) -- TODO confirm.

    Returns:
        ``(program_features, io_features, logvars_t)`` where
        ``program_features`` is viewed as ``(-1, params.dist_dim)``,
        ``io_features`` is ``[mus_t, logvars_t]`` concatenated on the last
        axis and viewed as ``(-1, params.dist_dim * 2)``, and ``logvars_t``
        is the second value returned by ``intersection``.
    """
    io_encoding = query.encode_io(input_grids, output_grids)
    mus_t, logvars_t = query.encode_into_t(io_encoding)
    mus_t, logvars_t = query.intersection(mus_t, logvars_t)
    program_features = query.encode_program(program_seq, plengths)

    io_features = torch.cat([mus_t, logvars_t], -1)
    io_features = io_features.view(-1, params.dist_dim * 2)
    program_features = program_features.view(-1, params.dist_dim)
    return program_features, io_features, logvars_t


def train():
    """Run distributed training of ``CombinarMI`` with the program/IO loss.

    All configuration is read from the ``params`` module; the function takes
    no arguments and returns ``None``.

    Outline:
      1. Initialise distributed mode and create the output directory and
         TensorBoard writer on rank 0.
      2. Load the train/val files, wrap them in dataset objects and build
         loaders driven by ``DistributedSampler``.
      3. Build the model (SyncBatchNorm + DistributedDataParallel) and an Adam
         optimizer, optionally restoring both from a checkpoint.
      4. For every epoch: optimise ``prob_citerion`` on the first ``num_io``
         (random in 1..5) entries along dim 1 of the grids, then evaluate on
         the validation loader with the full grids, all-reduce the summed
         loss, and write the ``model-best`` / ``model-latest`` checkpoints.
      5. Stop early after ``params.patience`` consecutive epochs without a new
         best validation loss.
    """
    params.gpus = init_distributed_mode()
    print(params.gpus)

    # Define paths for storing tensorboard logs and checkpoints.
    if not params.load_from_checkpoint:
        date = time.strftime('%Y%m%d%H%M', time.localtime(time.time()))
        save_dir = Path(params.model_output_path + date)
    else:
        save_dir = Path(params.model_output_path + params.load_from_checkpoint)
    # NOTE(review): this unconditional override makes the branch above dead and
    # sends every run to "<model_output_path>test". It looks like a debugging
    # leftover; kept to preserve the current output location -- confirm and
    # remove if unintended.
    save_dir = Path(params.model_output_path + 'test')
    print('saved to:', save_dir)
    if dist.get_rank() == 0:
        if not save_dir.exists():
            os.makedirs(str(save_dir))
        tb_writer = SummaryWriter(save_dir)
    # Make the other ranks wait until rank 0 has created the directory.
    dist.barrier()

    # `firstio` is hard-coded, so the non-sequential loader/dataset is used
    # unless this value is edited to contain 'random' or 'fix'.
    firstio = 'dataset'
    use_seq_dataset = 'random' in firstio or 'fix' in firstio
    if use_seq_dataset:
        load_fn, dataset_cls = load_input_file_seq, KarelDatasetSeq
    else:
        load_fn, dataset_cls = load_input_file, KarelDatasetNoWorlds
    train_dataset, vocab = load_fn(params.train_file, params.vocab_file)
    val_dataset, _ = load_fn(params.val_file, params.vocab_file)

    tgt_start = vocab["tkn2idx"]["<s>"]
    tgt_end = vocab["tkn2idx"]["m)"]
    tgt_pad = vocab["tkn2idx"]["<pad>"]
    print(len(vocab["tkn2idx"]))
    nb_ios = params.nb_ios
    batch_size = params.batch_size

    train_dataset = dataset_cls(train_dataset, tgt_start, tgt_end, tgt_pad, nb_ios, shuffle=False)
    val_dataset = dataset_cls(val_dataset, tgt_start, tgt_end, tgt_pad, nb_ios, shuffle=False)

    # Shuffling is delegated to the samplers, hence shuffle=False on the loaders.
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_data_loader = DataLoader(train_dataset, shuffle=False, pin_memory=False,
                                   batch_size=batch_size, sampler=train_sampler,
                                   num_workers=0, collate_fn=query_collate)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    val_data_loader = DataLoader(val_dataset, shuffle=False, pin_memory=False,
                                 batch_size=batch_size, sampler=val_sampler,
                                 num_workers=0, collate_fn=query_collate)

    # Only `prob_citerion` is consumed below; the simulator and the other
    # criteria are constructed exactly as before but are not used here.
    simulator = Simulator(vocab["idx2tkn"])
    info_citerion = NTXentDistLoss(temperature=0.07)
    ns_citerion = NSLoss(gamma=params.ns_gamma)
    ns_citerion_base = NSLoss(gamma=12)
    prob_citerion = NTXentProbLoss2(temperature=0.1)
    latent_criterion = nn.CrossEntropyLoss()

    ### Define model
    model = CombinarMI()
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    device = torch.cuda.current_device()
    model = model.to(device)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[params.gpus], find_unused_parameters=False)

    ### Define optimizer
    main_params = model.module.main_parameters()
    main_learning_rate = params.learn_rate
    main_optimizer = torch.optim.Adam(main_params, lr=main_learning_rate)

    models = [model]
    models_name = ['model']
    optims = [main_optimizer]
    optims_name = ['main_optimizer']

    start_epoch = 0
    # NOTE(review): `best_val_error` is only ever read from / written to the
    # checkpoint and never updated, while `best_val_loss` (the value actually
    # tracked) restarts at +inf on resume -- confirm whether the checkpoint
    # slot was meant to carry `best_val_loss`.
    best_val_error = np.inf
    best_val_loss = np.inf
    if params.load_from_checkpoint:
        print("=> loading checkpoint '{}'".format(params.checkpoint_dir))
        loaded_models, loaded_opts, start_epoch, best_val_error = \
            load_checkpoint(models, optims, models_name, optims_name, params.load_from_checkpoint)
        model = loaded_models[0]
        # Unpack the optimizer the same way as the model. Binding the whole
        # collection would hand a list to StepLR / zero_grad / step below.
        if isinstance(loaded_opts, (list, tuple)):
            main_optimizer = loaded_opts[0]
        else:
            main_optimizer = loaded_opts

    # The scheduler is constructed but its step() call is disabled below, so
    # it does not decay the optimizer's learning rate.
    main_lr_sched = torch.optim.lr_scheduler.StepLR(main_optimizer, step_size=params.lr_scheduler_step_size)

    # Exact number of batches this rank sees per epoch; gives a strictly
    # increasing TensorBoard step without overlap between epochs.
    steps_per_epoch = len(train_data_loader)
    # Must live outside the epoch loop: resetting it every epoch would make
    # early stopping unreachable for any patience greater than 1.
    patience_ctr = 0

    for epoch in range(start_epoch, params.num_epochs):
        train_sampler.set_epoch(epoch)
        model.train()
        print("Epoch %d" % epoch)

        for batch_idx, batch in enumerate(tqdm(train_data_loader)):
            global_step = epoch * steps_per_epoch + batch_idx
            loss = torch.tensor([0.]).cuda()

            # Fixed to a single query per batch; kept as a divisor so the
            # averaging below stays valid if this is made variable again.
            query_num = 1
            x = batch[0].type(FloatTensor)
            y = batch[1].type(FloatTensor)
            program_seq = batch[2].type(LongTensor)
            plengths = batch[3].type(LongTensor)

            main_optimizer.zero_grad()

            # Train on a random-length prefix (1..5) along dim 1 of the grids.
            num_io = random.randint(1, 5)
            input_grids, output_grids = x[:, :num_io], y[:, :num_io]

            # NOTE(review): sub-modules are called through `model.module`
            # rather than `model(...)`, i.e. not via
            # DistributedDataParallel.forward -- confirm gradients are still
            # synchronised across ranks.
            program_features, io_features, logvars_t = _encode_program_and_io(
                model.module.query, input_grids, output_grids, program_seq, plengths)

            ############### prob loss ##############
            loss_infonce = prob_citerion(program_features, io_features)
            print('loss_infonce', loss_infonce.item())
            ########################################

            # Logged for monitoring only; it is not added to the optimised loss.
            loss_var_t = torch.mean(logvars_t.exp())
            loss += loss_infonce

            batch_loss_infonce = loss_infonce.item() / query_num
            batch_loss_var_t = loss_var_t.item() / query_num
            # No latent loss is computed in this loop, so the logged value is 0.
            batch_loss_latent = 0.

            loss /= query_num
            loss.backward()
            main_optimizer.step()

            if dist.get_rank() == 0:
                tb_writer.add_scalar("MI/infonce", batch_loss_infonce, global_step)
                tb_writer.add_scalar("MI/var_t", batch_loss_var_t, global_step)
                tb_writer.add_scalar("MI/total", loss.item(), global_step)
                tb_writer.add_scalar("lr/lr", main_optimizer.state_dict()['param_groups'][0]['lr'], global_step)
                if params.latent_code:
                    tb_writer.add_scalar("MI/latent", batch_loss_latent, global_step)
                if global_step % 30 == 0:
                    x_index = input_grids[0].argmax(-3, keepdim=True) / 15
                    img_batch = x_index
                    tb_writer.add_images('query', img_batch, global_step)

        # main_lr_sched.step()

        model.eval()

        with torch.no_grad():
            total_loss = 0.
            for batch in tqdm(val_data_loader):
                x = batch[0].type(FloatTensor)
                y = batch[1].type(FloatTensor)
                program_seq = batch[2].type(LongTensor)
                plengths = batch[3].type(LongTensor)

                # Validation uses the full grids (no random prefix).
                input_grids, output_grids = x, y

                program_features, io_features, _ = _encode_program_and_io(
                    model.module.query, input_grids, output_grids, program_seq, plengths)

                ############### prob loss ##############
                loss_infonce = prob_citerion(program_features, io_features)
                ########################################
                # Weight the batch loss by its batch size before summing.
                total_loss += (loss_infonce * y.shape[0]).item()

            # Sum the per-rank totals so every rank sees the same validation
            # loss and therefore takes the same early-stopping decision.
            t1 = torch.tensor([total_loss], dtype=torch.float64, device='cuda')
            dist.barrier()
            dist.all_reduce(t1)
            total_loss = t1.tolist()[0] / len(val_dataset)
            if dist.get_rank() == 0:
                tb_writer.add_scalar("val/total_loss", total_loss, epoch)

            if total_loss < best_val_loss:
                ckpt_path = save_dir / 'model-best'
                print("Found new best model")
                best_val_loss = total_loss
                save_checkpoint(models, optims, models_name, optims_name, epoch, best_val_error, ckpt_path)
                patience_ctr = 0
            else:
                patience_ctr += 1
                if patience_ctr == params.patience:
                    print("Ran out of patience. Stopping training early...")
                    break
            ckpt_path = save_dir / 'model-latest'
            save_checkpoint(models, optims, models_name, optims_name, epoch, best_val_error, ckpt_path)

    # Flush and release the event file; the writer only exists on rank 0.
    if dist.get_rank() == 0:
        tb_writer.close()

        ############## dump dataset #################
    #model.eval()
    #with torch.no_grad():
    #    if params.dump_dataset:
    #        if dist.get_rank() == 0:
    #            ckpt_path = save_dir / 'model-best'
    #            print("=> loading checkpoint '{}'".format(ckpt_path))
    #            loaded_models, loaded_opts, start_epoch, best_val_error = \
    #                load_checkpoint(models, optims, models_name, optims_name, ckpt_path)
    #            model = loaded_models[0]
    #            model.eval()
    #            ios = generate_ios(train_program, train_typ, model, train_step, train_drop_target)
    #            f = open(save_dir / 'train_gps', 'w')
    #            for item in ios:
    #                problem = dict(program=item['program'], examples=item['examples'])
    #                f.write(json.dumps(problem) + '\n')
    #            f.close()

    #            ios = generate_ios(val_program, val_typ, model, val_step, val_drop_target)
    #            f = open(save_dir / 'val_gps', 'w')
    #            for item in ios:
    #                problem = dict(program=item['program'], examples=item['examples'])
    #                f.write(json.dumps(problem) + '\n')
    #            f.close()
    #        dist.barrier()
        #############################################

# Script entry point: start training only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    train()