try:
    from torch.utils.tensorboard import SummaryWriter
except:
    from tensorboardX import SummaryWriter
import os, utils, glob, losses
import sys
from torch.utils.data import DataLoader
from data import datasets, trans
import numpy as np
import torch
from torchvision import transforms
from torch import optim
import torch.nn as nn
import matplotlib.pyplot as plt
from natsort import natsorted
from models.TransMorph import CONFIGS as CONFIGS_TM
import models.TransMorph as TransMorph
import argparse

class Logger(object):
    """Tee-style stream that mirrors everything written to it into a log file.

    An instance is meant to be assigned to ``sys.stdout``: each message is
    forwarded to the stream that was ``sys.stdout`` at construction time and
    appended to ``<save_dir>logfile.log``.

    Args:
        save_dir: Path prefix for the log file. It is concatenated verbatim
            with ``"logfile.log"``, so a directory must end with a separator.
    """

    def __init__(self, save_dir):
        self.terminal = sys.stdout
        # Append mode so that resumed runs extend the existing log.
        self.log = open(save_dir+"logfile.log", "a")

    def write(self, message):
        """Write ``message`` to both the original stream and the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both underlying streams.

        Required by the file-like protocol; actually flushing (instead of
        doing nothing) makes the log readable while the run is in progress.
        """
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Flush and close the log file; the wrapped stream is left open."""
        if not self.log.closed:
            self.log.flush()
            self.log.close()

def _warp_one_hot_seg(model, seg, flow, num_classes):
    """Warp a label map with ``flow``, one one-hot class channel at a time.

    The label map is one-hot encoded, axis 1 is squeezed away and the class
    axis is moved to position 1. Every class channel is then passed through
    ``model.spatial_trans`` separately and the results are concatenated along
    the channel axis.
    """
    # assumes seg is (batch, 1, D, H, W) with labels in [0, num_classes) -- TODO confirm
    seg_oh = nn.functional.one_hot(seg.long(), num_classes=num_classes)
    seg_oh = torch.squeeze(seg_oh, 1)
    seg_oh = seg_oh.permute(0, 4, 1, 2, 3).contiguous()
    flow = flow.float()
    warped = [model.spatial_trans(seg_oh[:, i:i + 1, ...].float(), flow)
              for i in range(num_classes)]
    return torch.cat(warped, dim=1)


def _train_direction(model, optimizer, moving, fixed, moving_seg, fixed_seg,
                     weights, criteria, num_classes):
    """Run one optimisation step registering ``moving`` onto ``fixed``.

    Args:
        weights: ``[ncc_weight, dice_weight, reg_weight]``.
        criteria: ``(criterion_ncc, criterion_dsc, criterion_reg)``.

    Returns:
        Tuple of Python floats ``(total, ncc, dice, reg)`` for this step.
        Disabled terms (weight <= 0) are reported as 0 so they do not skew
        the logged total.
    """
    criterion_ncc, criterion_dsc, criterion_reg = criteria
    output, flow = model(torch.cat((moving, fixed), dim=1))

    loss_ncc = torch.tensor(0.0, device=moving.device)
    if weights[0] > 0:
        loss_ncc = criterion_ncc(output, fixed) * weights[0]

    loss_dsc = torch.tensor(0.0, device=moving.device)
    if weights[1] > 0:
        def_seg = _warp_one_hot_seg(model, moving_seg, flow, num_classes)
        loss_dsc = criterion_dsc(def_seg, fixed_seg.long()) * weights[1]

    loss_reg = criterion_reg(flow, fixed) * weights[2]
    loss = loss_ncc + loss_dsc + loss_reg

    # compute gradient and do the optimizer step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Intermediate tensors are released when this function returns, so no
    # explicit ``del`` is needed between the two registration directions.
    return loss.item(), loss_ncc.item(), loss_dsc.item(), loss_reg.item()


def main(args):
    """Train TransMorph on the OASIS dataset with NCC and/or Dice supervision.

    Every training pair is used in both directions (x -> y, then y -> x).
    After each epoch the model is validated; a checkpoint is written whenever
    the average validation Dice improves, and scalars/figures are logged to
    TensorBoard.

    Args:
        args: Parsed command-line namespace providing ``train_ncc`` and
            ``train_dice`` (loss weights), ``model`` (``'large'``,
            ``'regular'`` or ``'tiny'``) and ``resume`` (bool).
    """
    batch_size = 1
    num_classes = 36  # number of one-hot channels used for the Dice loss
    data_dir = "/mnt/anon_data2/neurite-OASIS/"
    weights = [args.train_ncc, args.train_dice, 1] # loss weights: NCC, Dice, regulariser
    if args.train_dice <= 0 and args.train_ncc <= 0:
        print('At least one of the loss weights should be positive')
        print("we have weights for NCC: {} and DSC: {}".format(args.train_ncc, args.train_dice))
        return
    save_dir = 'TransMorph_ncc_{}_dsc{}_diffusion_{}_{}/'.format(weights[0], weights[1], weights[2], args.model)
    os.makedirs('experiments/'+save_dir, exist_ok=True)
    os.makedirs('logs/'+save_dir, exist_ok=True)
    sys.stdout = Logger('logs/'+save_dir)
    lr = 0.0001 # learning rate
    epoch_start = 0
    max_epoch = 500 # max training epoch
    cont_training = False # if continue training

    '''
    Initialize model
    '''
    config_names = {
        'large': 'TransMorph-Large',
        'regular': 'TransMorph',
        'tiny': 'TransMorph-Tiny',
    }
    model_name = args.model
    if model_name not in config_names:
        raise ValueError("Unknown model: {}".format(model_name))
    config = CONFIGS_TM[config_names[model_name]]
    print("Choose model: ", model_name)
    model = TransMorph.TransMorph(config)
    model.cuda()

    optimizer = None
    best_dsc = 0
    if args.resume:
        print("Resuming, finding best model...")
        all_models = natsorted(glob.glob('experiments/'+save_dir+'*'))
        if all_models:
            # Last entry in natural sort order of the saved checkpoint files.
            checkpoint = torch.load(all_models[-1])
            epoch_start = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            updated_lr = round(lr * np.power(1 - (epoch_start) / max_epoch,0.9),8)
            optimizer = optim.Adam(model.parameters(), lr=updated_lr, weight_decay=0, amsgrad=True)
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Carry the best score over so a resumed run only saves real improvements.
            best_dsc = checkpoint['best_dsc']
            print("Model loaded at epoch {} with best DSC: {}".format(epoch_start, checkpoint['best_dsc']))
        else:
            print("No model found, starting from scratch")

    '''
    Initialize spatial transformation function
    '''
    reg_model = utils.register_model(config.img_size, 'nearest')
    reg_model.cuda()
    reg_model_bilin = utils.register_model(config.img_size, 'bilinear')
    reg_model_bilin.cuda()

    '''
    If continue from previous training
    '''
    if cont_training:
        epoch_start = 201
        model_dir = 'experiments/'+save_dir
        updated_lr = round(lr * np.power(1 - (epoch_start) / max_epoch,0.9),8)
        best_model = torch.load(model_dir + natsorted(os.listdir(model_dir))[-1])['state_dict']
        print('Model: {} loaded!'.format(natsorted(os.listdir(model_dir))[-1]))
        model.load_state_dict(best_model)
    else:
        updated_lr = lr

    '''
    Initialize training
    '''
    train_composed = transforms.Compose([trans.NumpyType((np.float32, np.int16)),
                                         ])
    val_composed = transforms.Compose([trans.NumpyType((np.float32, np.int16))])
    train_set = datasets.OASISNiftiDataset(data_dir, split='train', transforms=train_composed)
    val_set = datasets.OASISNiftiDataset(data_dir, split='val', transforms=val_composed)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=4, pin_memory=True, drop_last=True)

    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=updated_lr, weight_decay=0, amsgrad=True)
    criteria = (losses.NCC_vxm(), losses.DiceLoss(), losses.Grad3d(penalty='l2'))
    writer = SummaryWriter(log_dir='logs/'+save_dir)
    for epoch in range(epoch_start, max_epoch):
        print('Training Starts')
        '''
        Training
        '''
        loss_all = utils.AverageMeter()
        model.train()
        # The schedule depends only on the epoch, so set the rate once per epoch.
        adjust_learning_rate(optimizer, epoch, max_epoch, lr)
        for idx, data in enumerate(train_loader, start=1):
            data = [t.cuda() for t in data]
            x, y, x_seg, y_seg = data[0], data[1], data[2], data[3]

            # Direction 1: x is moving, y is fixed.
            loss_val, _, _, _ = _train_direction(
                model, optimizer, x, y, x_seg, y_seg, weights, criteria, num_classes)
            loss_all.update(loss_val, y.numel())

            # Direction 2: y is moving, x is fixed.
            loss_val, ncc_val, dsc_val, reg_val = _train_direction(
                model, optimizer, y, x, y_seg, x_seg, weights, criteria, num_classes)
            loss_all.update(loss_val, x.numel())

            print('Iter {} of {} loss {:.4f}, Img Sim: {:.6f}, DSC: {:.6f}, Reg: {:.6f}'.format(idx, len(train_loader),
                                                                                                loss_val,
                                                                                                ncc_val,
                                                                                                dsc_val,
                                                                                                reg_val))

        writer.add_scalar('Loss/train', loss_all.avg, epoch)
        print('Epoch {} loss {:.4f}'.format(epoch, loss_all.avg))
        '''
        Validation
        '''
        eval_dsc = utils.AverageMeter()
        model.eval()
        # The grid depends only on the image size; build it once per epoch.
        grid_img = mk_grid_img(8, 1, config.img_size)
        with torch.no_grad():
            for data in val_loader:
                data = [t.cuda() for t in data]
                x, y, x_seg, y_seg = data[0], data[1], data[2], data[3]
                x_in = torch.cat((x, y), dim=1)
                output = model(x_in)
                def_out = reg_model([x_seg.cuda().float(), output[1].cuda()])
                def_grid = reg_model_bilin([grid_img.float(), output[1].cuda()])
                dsc = utils.dice_val_VOI(def_out.long(), y_seg.long())
                eval_dsc.update(dsc.item(), x.size(0))
                print(eval_dsc.avg)
        # new average is best
        if best_dsc < eval_dsc.avg:
            best_dsc = eval_dsc.avg
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_dsc': best_dsc,
                'optimizer': optimizer.state_dict(),
            }, save_dir='experiments/'+save_dir, filename='dsc{:.4f}.pth.tar'.format(eval_dsc.avg))

        writer.add_scalar('DSC/validate', eval_dsc.avg, epoch)
        # Figures are built from the last validation sample of the epoch.
        plt.switch_backend('agg')
        pred_fig = comput_fig(def_out)
        grid_fig = comput_fig(def_grid)
        x_fig = comput_fig(x_seg)
        tar_fig = comput_fig(y_seg)
        writer.add_figure('Grid', grid_fig, epoch)
        plt.close(grid_fig)
        writer.add_figure('input', x_fig, epoch)
        plt.close(x_fig)
        writer.add_figure('ground truth', tar_fig, epoch)
        plt.close(tar_fig)
        writer.add_figure('prediction', pred_fig, epoch)
        plt.close(pred_fig)
        del def_out, def_grid, grid_img, output
    writer.close()

def comput_fig(img):
    """Build a 4x4 matplotlib figure from slices 48..63 of a volume tensor.

    Only the first batch element and first channel of ``img`` are used; the
    slices are taken along the axis that follows the channel axis.

    Args:
        img: Tensor indexed as ``[batch, channel, slice, row, col]``.

    Returns:
        The created ``matplotlib`` figure (the caller is responsible for
        closing it).
    """
    slices = img.detach().cpu().numpy()[0, 0, 48:64, :, :]
    fig = plt.figure(figsize=(12,12), dpi=180)
    for position, plane in enumerate(slices, start=1):
        plt.subplot(4, 4, position)
        plt.axis('off')
        plt.imshow(plane, cmap='gray')
    # Remove the gaps so the slices tile the figure.
    fig.subplots_adjust(wspace=0, hspace=0)
    return fig

def adjust_learning_rate(optimizer, epoch, MAX_EPOCHES, INIT_LR, power=0.9):
    """Apply a polynomial decay schedule to every parameter group.

    The rate is ``INIT_LR * (1 - epoch / MAX_EPOCHES) ** power``, rounded to
    8 decimal places, and is written into ``param_group['lr']`` of each group
    in ``optimizer.param_groups``.
    """
    decay = np.power(1 - (epoch) / MAX_EPOCHES, power)
    new_lr = round(INIT_LR * decay, 8)
    for group in optimizer.param_groups:
        group['lr'] = new_lr

def mk_grid_img(grid_step, line_thickness=1, grid_sz=(160, 192, 224)):
    """Create a binary grid volume used to visualise a deformation field.

    Lines (value 1 on a background of 0) are drawn every ``grid_step``
    voxels along axes 1 and 2 of the volume; axis 0 is left unmarked so
    every slice along it shows the same 2-D grid.

    Args:
        grid_step: Spacing between consecutive grid lines, in voxels.
        line_thickness: Width of each line in voxels. Each line covers
            ``[j, j + line_thickness)`` and is clipped at the volume border.
        grid_sz: Shape of the volume.

    Returns:
        A float64 tensor of shape ``(1, 1, *grid_sz)``, placed on the current
        CUDA device when CUDA is available and on the CPU otherwise.
    """
    grid_img = np.zeros(grid_sz)
    # Slices (instead of a single shifted index) give lines that are really
    # ``line_thickness`` wide and can never index past the border.
    for j in range(0, grid_img.shape[1], grid_step):
        grid_img[:, j:j + line_thickness, :] = 1
    for i in range(0, grid_img.shape[2], grid_step):
        grid_img[:, :, i:i + line_thickness] = 1
    # Add batch and channel axes.
    grid_img = torch.from_numpy(grid_img[None, None, ...])
    if torch.cuda.is_available():
        grid_img = grid_img.cuda()
    return grid_img

def save_checkpoint(state, save_dir='models', filename='checkpoint.pth.tar', max_model_num=8):
    """Save ``state`` into ``save_dir`` and prune old files beyond a limit.

    Args:
        state: Object handed to ``torch.save``.
        save_dir: Directory receiving the checkpoint; created if missing. A
            trailing separator is optional.
        filename: File name of the checkpoint inside ``save_dir``.
        max_model_num: Maximum number of files kept in ``save_dir``. When the
            limit is exceeded, the files that come first in natural sort
            order of their paths are deleted.
    """
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    # os.path.join (rather than string concatenation) keeps the file inside
    # save_dir even when the directory is given without a trailing slash.
    torch.save(state, os.path.join(save_dir, filename))
    # Only regular files are candidates for pruning; os.remove cannot delete
    # a sub-directory.
    model_lists = natsorted(
        path for path in glob.glob(os.path.join(save_dir, '*')) if os.path.isfile(path)
    )
    excess = max(len(model_lists) - max_model_num, 0)
    for stale in model_lists[:excess]:
        os.remove(stale)

if __name__ == '__main__':
    # Command-line options: loss weights, model size and resume flag.
    cli = argparse.ArgumentParser()
    cli.add_argument('--train_ncc', type=float, default=1)
    cli.add_argument('--train_dice', type=float, default=0)
    cli.add_argument('--model', type=str, default='regular', choices=['large', 'regular', 'tiny'])
    cli.add_argument('--resume', action='store_true')
    args = cli.parse_args()
    print(args)

    # GPU configuration: list the visible devices, then select device 0.
    gpu_id = 0
    gpu_count = torch.cuda.device_count()
    print(f'Number of GPU: {gpu_count}')
    for gpu_index in range(gpu_count):
        gpu_name = torch.cuda.get_device_name(gpu_index)
        print(f'     GPU #{gpu_index}: {gpu_name}')
    torch.cuda.set_device(gpu_id)
    gpu_available = torch.cuda.is_available()
    print(f'Currently using: {torch.cuda.get_device_name(gpu_id)}')
    print(f'If the GPU is available? {gpu_available}')
    # Fixed seed for the torch random number generator.
    torch.manual_seed(0)
    main(args)
