from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import json
import os
import random
import multiprocessing
import time 
from pathlib import Path
from tqdm import tqdm
import numpy as np

import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.autograd import Variable
from torch import nn
import torch.nn.functional as F

import params
from model.combinar_MI7 import CombinarMI
from cuda import use_cuda, LongTensor, FloatTensor

from karel.consistency import Simulator

from utils import *
from utils_hd import *

from sklearn import preprocessing

from data.data import (load_input_file, load_input_file_seq, get_minibatch, shuffle_dataset,
                     KarelDataset, QueryDataset, KarelDatasetNoWorlds, KarelDatasetSeq, BucketingSampler)

# Seed every random number generator this script draws from (Python's
# `random`, NumPy, and PyTorch on CPU and on all CUDA devices) with the single
# configured value `params.seed`.
random.seed(params.seed)
np.random.seed(params.seed)
torch.manual_seed(params.seed)
torch.cuda.manual_seed_all(params.seed)


def _latent_embedding(query_step, batch_size):
    """Build the latent code appended to the query decoder input.

    Args:
        query_step: Index of the query currently being generated.
        batch_size: Number of rows to produce.

    Returns:
        A tuple ``(latent_targ, noise_targ, latent_emb)``:

        * ``latent_targ`` -- ``(batch_size, 1)`` float CUDA tensor filled with
          ``query_step``.
        * ``noise_targ`` -- ``(batch_size, 5)`` CUDA tensor drawn from
          ``np.random.uniform(-1, 1)`` when ``params.noise`` is truthy,
          otherwise ``None``.
        * ``latent_emb`` -- ``latent_targ`` with ``noise_targ`` concatenated
          on the last axis when noise is enabled, else ``latent_targ`` itself.
    """
    latent_targ = torch.Tensor([query_step] * batch_size).cuda().view(-1, 1)
    if params.noise:
        # NumPy's RNG is used (not torch's) so the draw follows np.random.seed.
        noise_targ = torch.Tensor(np.random.uniform(-1, 1, (batch_size, 5))).cuda()
        latent_emb = torch.cat([latent_targ, noise_targ], -1)
    else:
        noise_targ = None
        latent_emb = latent_targ
    return latent_targ, noise_targ, latent_emb


def train():
    """Train the ``CombinarMI`` query model under DistributedDataParallel.

    The function takes no arguments; all configuration is read from the
    ``params`` module. In order, it:

    1. Initialises the distributed process group and creates the output
       directory and TensorBoard writer on rank 0.
    2. Loads the train/validation datasets and wraps them in
       ``DistributedSampler``-backed loaders.
    3. Builds the model, optimizer and (optionally) restores the
       ``model-latest`` checkpoint.
    4. For every epoch, generates a random number of queries per batch,
       accumulates the contrastive loss (plus optional Hellinger, latent-code
       and noise losses), and takes one optimizer step per batch.
    5. Evaluates with a fixed number of queries, all-reduces the validation
       loss, writes ``model-best`` / ``model-latest`` checkpoints and stops
       early once ``params.patience`` epochs pass without improvement.

    Returns:
        None.
    """
    params.gpus = init_distributed_mode()
    print(params.gpus)
    is_main = dist.get_rank() == 0

    # ---------------------------------------------------------------- paths
    # NOTE(review): each rank derives the timestamp on its own; if ranks
    # start on different sides of a minute boundary they end up with
    # different save_dir values. Confirm this cannot happen in practice.
    if not params.load_from_checkpoint:
        date = time.strftime('%Y%m%d%H%M')
        save_dir = Path(params.model_output_path + date)
    else:
        save_dir = Path(params.model_output_path + params.load_from_checkpoint)
    print('saved to:', save_dir)

    tb_writer = None
    if is_main:
        save_dir.mkdir(parents=True, exist_ok=True)
        tb_writer = SummaryWriter(save_dir)
    # Other ranks wait until rank 0 has created the directory.
    dist.barrier()

    # ----------------------------------------------------------------- data
    firstio = 'fix'
    if 'random' in firstio or 'fix' in firstio:
        load_fn, dataset_cls = load_input_file_seq, KarelDatasetSeq
    else:
        load_fn, dataset_cls = load_input_file, KarelDatasetNoWorlds
    train_dataset, vocab = load_fn(params.train_file, params.vocab_file)
    val_dataset, _ = load_fn(params.val_file, params.vocab_file)

    vocabulary_size = len(vocab["tkn2idx"])
    tgt_start = vocab["tkn2idx"]["<s>"]
    tgt_end = vocab["tkn2idx"]["m)"]
    tgt_pad = vocab["tkn2idx"]["<pad>"]
    print(vocabulary_size)
    nb_ios = params.nb_ios
    batch_size = params.batch_size

    train_dataset = dataset_cls(train_dataset, tgt_start, tgt_end, tgt_pad, nb_ios, shuffle=False)
    val_dataset = dataset_cls(val_dataset, tgt_start, tgt_end, tgt_pad, nb_ios, shuffle=False)

    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_data_loader = DataLoader(train_dataset, shuffle=False, pin_memory=False,
                                   batch_size=batch_size, sampler=train_sampler,
                                   num_workers=0, collate_fn=query_collate)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    val_data_loader = DataLoader(val_dataset, shuffle=False, pin_memory=False,
                                 batch_size=batch_size, sampler=val_sampler,
                                 num_workers=0, collate_fn=query_collate)
    simulator = Simulator(vocab["idx2tkn"])

    # --------------------------------------------------------------- losses
    # The first three criteria are constructed but not used by the active
    # objective below. They are kept so that start-up still reads the same
    # params attributes and performs the same constructor calls as before.
    info_criterion = NTXentDistLoss(temperature=0.1)
    ns_criterion = NSLoss(gamma=params.ns_gamma)
    ns_criterion_base = NSLoss(gamma=12)
    prob_criterion = NTXentProbLoss2(temperature=0.1)
    latent_criterion = nn.CrossEntropyLoss()
    noise_criterion = nn.MSELoss()
    hellinger_criterion = HellingerLoss()

    # ---------------------------------------------------------------- model
    model = CombinarMI()
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    device = torch.cuda.current_device()
    model = model.to(device)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[params.gpus], find_unused_parameters=False)

    # ------------------------------------------------------------ optimizer
    main_params = model.module.main_parameters()
    main_learning_rate = params.learn_rate
    main_optimizer = torch.optim.Adam(main_params, lr=main_learning_rate)

    models = [model]
    models_name = ['model']
    optims = [main_optimizer]
    optims_name = ['main_optimizer']

    start_epoch = 0
    best_val_loss = np.inf
    patience_ctr = 0
    if params.load_from_checkpoint:
        print("=> loading checkpoint '{}'".format(save_dir))
        loaded_models, loaded_opts, start_epoch, best_val_loss = \
            load_checkpoint(models, optims, models_name, optims_name, save_dir / 'model-latest')
        model = loaded_models[0]
        main_optimizer = loaded_opts[0]
        print('best_val_loss', best_val_loss)

    # The scheduler is constructed (which touches the optimizer's param
    # groups) but main_lr_sched.step() is never called in this function, so
    # the learning rate is not decayed during training.
    main_lr_sched = torch.optim.lr_scheduler.StepLR(main_optimizer, step_size=params.lr_scheduler_step_size)

    # NOTE(review): every forward computation below goes through
    # model.module.* rather than the DDP wrapper's own forward(). DDP arms
    # its gradient all-reduce inside forward(), so confirm that gradients are
    # actually being synchronised across ranks with this call pattern.
    core = model.module
    query_net = core.query

    # Exact number of batches this rank iterates per epoch. Using it keeps
    # global_step strictly increasing across epochs; the previous float
    # estimate len(dataset) / batch_size / world_size could make step ids of
    # consecutive epochs overlap.
    steps_per_epoch = len(train_data_loader)

    for epoch in range(start_epoch, params.num_epochs):
        # Re-seed the sampler's shuffle for this epoch.
        train_sampler.set_epoch(epoch)
        model.train()
        print("Epoch %d" % epoch)

        pbar = tqdm(train_data_loader)
        for batch_idx, batch in enumerate(pbar):
            global_step = epoch * steps_per_epoch + batch_idx
            loss = torch.tensor([0.]).cuda()

            # Curriculum: the upper bound on queries per batch grows with the
            # epoch until it reaches params.query_num.
            query_num = random.randint(1, int(min(params.query_num, epoch / 2 + 1)))
            program_seq = batch[2].type(LongTensor)
            plengths = batch[3].type(LongTensor)

            main_optimizer.zero_grad()

            input_grids, output_grids = get_first_io(program_seq, simulator, vocab)

            batch_loss_infonce = 0.
            batch_loss_var_t = 0.
            batch_loss_latent = 0.
            batch_loss_noise = 0.
            batch_loss_hellinger = 0.

            # Initial distribution parameters, computed from the first IO pair.
            embedding = query_net.encode_io(input_grids, output_grids)
            mus_t, logvars_t = query_net.encode_into_t(embedding)
            mus_t, logvars_t = query_net.intersection(mus_t, logvars_t)

            for query_step in range(query_num):
                # The query decoder is conditioned on the current (mu, logvar).
                embedding = torch.cat([mus_t, logvars_t], -1)
                if params.latent_code:
                    latent_targ, noise_targ, latent_emb = _latent_embedding(
                        query_step, input_grids.shape[0])
                    embedding = torch.cat([embedding, latent_emb], -1)

                query_inp = query_net.decode_process(embedding, params.hard_softmax)
                query_out = core.env_step(query_inp, program_seq, simulator, params.hard_softmax)

                # The first generated query replaces the seed IO pair; later
                # queries are appended along dim 1.
                if query_step > 0:
                    input_grids = torch.cat([input_grids, query_inp], 1)
                    output_grids = torch.cat([output_grids, query_out], 1)
                else:
                    input_grids = query_inp
                    output_grids = query_out

                io_features = query_net.encode_io(input_grids, output_grids)

                # -------- distribution update (+ optional Hellinger term)
                if params.hellinger:
                    # Snapshot of the previous distribution, cut off from the graph.
                    mus_t_old = mus_t.clone().detach().unsqueeze(1)
                    logvars_t_old = logvars_t.clone().detach().unsqueeze(1)
                mus_t, logvars_t = query_net.encode_into_t(io_features)
                if params.hellinger:
                    # Compare the newest entry along dim 1 with the snapshot.
                    loss_hellinger = hellinger_criterion(
                        mus_t[:, -1:], logvars_t[:, -1:], mus_t_old, logvars_t_old)
                    loss += loss_hellinger
                    batch_loss_hellinger += loss_hellinger.item()
                mus_t, logvars_t = query_net.intersection(mus_t, logvars_t)

                dist_features = torch.cat([mus_t, logvars_t], -1)

                # -------- optional latent-code / noise reconstruction losses
                if params.latent_code:
                    latent_pred = query_net.latent_decoder(dist_features)
                    loss_latent = latent_criterion(latent_pred, latent_targ.view(-1).long())
                    loss += loss_latent
                    batch_loss_latent += loss_latent.item()
                    if params.noise:
                        noise_pred = query_net.noise_decoder(dist_features)
                        loss_noise = noise_criterion(noise_pred, noise_targ)
                        loss += loss_noise
                        batch_loss_noise += loss_noise.item()

                # -------- contrastive loss between program and IO features
                # encode_program is re-run on every step, as before; hoisting
                # it could change results if the encoder is stochastic in
                # train mode.
                program_features = query_net.encode_program(program_seq, plengths)
                io_features = dist_features.view(-1, params.dist_dim * 2)
                program_features = program_features.view(-1, params.dist_dim)
                loss_infonce = prob_criterion(program_features, io_features)

                print('loss_infonce', loss_infonce.item())

                # Tracked for logging only; not added to the optimised loss.
                loss_var_t = torch.mean(logvars_t.exp())
                loss += loss_infonce
                batch_loss_infonce += loss_infonce.item()
                batch_loss_var_t += loss_var_t.item()

                # Do not backpropagate through earlier query steps.
                mus_t = mus_t.clone().detach()
                logvars_t = logvars_t.clone().detach()

            pbar.set_description('Epoch %d' % epoch)

            # Average the per-step sums over the number of queries. Terms that
            # are disabled stay at zero, so dividing them is harmless.
            batch_loss_latent /= query_num
            batch_loss_noise /= query_num
            batch_loss_hellinger /= query_num
            batch_loss_infonce /= query_num
            batch_loss_var_t /= query_num

            loss /= query_num
            loss.backward()
            main_optimizer.step()

            if is_main:
                tb_writer.add_scalar("MI/infonce", batch_loss_infonce, global_step)
                tb_writer.add_scalar("MI/var_t", batch_loss_var_t, global_step)
                tb_writer.add_scalar("MI/total", loss.item(), global_step)
                tb_writer.add_scalar("lr/lr", main_optimizer.param_groups[0]['lr'], global_step)
                if params.latent_code:
                    tb_writer.add_scalar("MI/latent", batch_loss_latent, global_step)
                    if params.noise:
                        tb_writer.add_scalar("MI/noise", batch_loss_noise, global_step)
                if params.hellinger:
                    tb_writer.add_scalar("MI/hellinger", batch_loss_hellinger, global_step)
                if global_step % 20000 == 0:
                    # Render the generated queries of the first sample.
                    x_index = input_grids[0].argmax(-3, keepdim=True) / 15
                    img_batch = x_index
                    tb_writer.add_images('query', img_batch, global_step)

        # ----------------------------------------------------- validation
        model.eval()

        with torch.no_grad():
            val_query_num = 5

            total_loss = 0.
            total_count = 0
            for batch in tqdm(val_data_loader):
                program_seq = batch[2].type(LongTensor)
                plengths = batch[3].type(LongTensor)

                input_grids, output_grids = get_first_io(program_seq, simulator, vocab)
                for query_step in range(val_query_num):
                    # Unlike training, the distribution is re-encoded from the
                    # accumulated IO pairs at the start of every step.
                    embedding = query_net.encode_io(input_grids, output_grids)
                    mus_t, logvars_t = query_net.encode_into_t(embedding)
                    mus_t, logvars_t = query_net.intersection(mus_t, logvars_t)
                    embedding = torch.cat([mus_t, logvars_t], -1)
                    if params.latent_code:
                        _, _, latent_emb = _latent_embedding(query_step, input_grids.shape[0])
                        embedding = torch.cat([embedding, latent_emb], -1)

                    query_inp = query_net.decode_process(embedding, params.hard_softmax)
                    query_out = core.env_step(query_inp, program_seq, simulator, params.hard_softmax)
                    if query_step > 0:
                        input_grids = torch.cat([input_grids, query_inp], 1)
                        output_grids = torch.cat([output_grids, query_out], 1)
                    else:
                        input_grids = query_inp
                        output_grids = query_out

                io_features = query_net.encode_io(input_grids, output_grids)
                mus_t, logvars_t = query_net.encode_into_t(io_features)
                mus_t, logvars_t = query_net.intersection(mus_t, logvars_t)
                program_features = query_net.encode_program(program_seq, plengths)

                io_features = torch.cat([mus_t, logvars_t], -1).view(-1, params.dist_dim * 2)
                program_features = program_features.view(-1, params.dist_dim)
                loss_infonce = prob_criterion(program_features, io_features)

                # Weight the batch loss by the batch size so the final figure
                # is a per-sample average.
                total_loss += (loss_infonce * program_seq.shape[0]).item()
                total_count += program_seq.shape[0]

            # Reduce loss and sample count together. DistributedSampler pads
            # each rank's shard to equal length, so the number of evaluated
            # samples can exceed len(val_dataset); dividing by the real count
            # keeps the average unbiased.
            totals = torch.tensor([total_loss, total_count], dtype=torch.float64, device='cuda')
            dist.barrier()
            dist.all_reduce(totals)
            summed_loss, summed_count = totals.tolist()
            total_loss = summed_loss / summed_count
            if is_main:
                tb_writer.add_scalar("val/total_loss", total_loss, epoch)

            # total_loss is identical on every rank after the all-reduce, so
            # all ranks take the same branch (and the same early-stop break).
            # NOTE(review): save_checkpoint is invoked on every rank with the
            # same path; confirm it is safe against concurrent writers or is
            # already restricted to rank 0 internally.
            if total_loss < best_val_loss:
                ckpt_path = save_dir / 'model-best'
                print("Found new best model")
                best_val_loss = total_loss
                save_checkpoint(models, optims, models_name, optims_name, epoch, best_val_loss, ckpt_path)
                patience_ctr = 0
            else:
                patience_ctr += 1
                if patience_ctr == params.patience:
                    print("Ran out of patience. Stopping training early...")
                    break
            ckpt_path = save_dir / 'model-latest'
            save_checkpoint(models, optims, models_name, optims_name, epoch, best_val_loss, ckpt_path)

    # Flush pending TensorBoard events; the writer was never closed before.
    if is_main:
        tb_writer.close()

# Script entry point: run training only when this file is executed directly.
if __name__ == "__main__":
    train()