import math
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from utils import cosine_anneal, linear_warmup
from torch.nn.utils import clip_grad_norm_



def compute_loss_coef(config, epoch):
    """Return the weight applied to each loss term when building the training loss.

    Both ``config`` and ``epoch`` are accepted but currently ignored: the
    result is always a fresh dict that maps ``'nll'`` to the weight ``1``.
    """
    coefficients = dict(nll=1)
    return coefficients




def _lr_schedule_factors(config, global_step):
    """Return ``(warmup_enc, warmup_dec, decay)`` learning-rate multipliers.

    When ``config['lrs']`` is falsy (it defaults to ``True``) all three
    factors are ``1.0``. Otherwise the warm-up factors come from
    ``linear_warmup`` (imported from ``utils``; its exact semantics are not
    visible here) and the decay factor is ``0.5 ** (global_step / half_life)``
    with ``half_life = config['lr_half_life']`` (default 250000).
    """
    if not config.get('lrs', True):
        return 1.0, 1.0, 1.0

    warmup_enc = linear_warmup(
        global_step,
        0.,
        1.0,
        0.,
        config.get('lr_warmup_steps', 30000))
    warmup_dec = linear_warmup(
        global_step,
        0.,
        1.0,
        0,
        config.get('lr_warmup_steps', 30000))
    decay = math.exp(global_step / config.get('lr_half_life', 250000) * math.log(0.5))
    return warmup_enc, warmup_dec, decay


def train_model(config, data_loaders, net):
    """Train ``net``, logging to TensorBoard and checkpointing after every epoch.

    Args:
        config: Mapping of settings. Required keys read here: ``folder_out``,
            ``file_ckpt``, ``file_model``, ``resume``, ``folder_log``,
            ``num_epochs``, ``dataset``, ``model_name``, ``phase_param``,
            ``summ_image_intvl`` and ``summ_scalar_intvl`` (plus
            ``summ_image_count`` via ``compute_overview``). Optional keys with
            defaults: ``lrs`` (True), ``lr_warmup_steps`` (30000),
            ``lr_half_life`` (250000), ``lr_enc`` (3e-5), ``lr_dec`` (1e-4),
            ``gradient_accumulation_steps`` (1) and ``clip`` (0.05).
        data_loaders: Mapping with a ``'train'`` loader and optionally a
            ``'valid'`` loader. Each batch must be indexable by ``'image'``;
            ``data['image'].shape[0]`` is used as the batch size.
        net: Model called as ``net(data, phase_param)`` (and once per summary
            with ``requires_results=True``), returning ``(results, losses)``
            where ``losses`` is a dict of tensors. Only parameters whose names
            contain ``'modelencoder'`` or ``'modeldecoder'`` are optimized.

    Returns:
        None. Side effects: writes TensorBoard summaries, a per-epoch
        checkpoint at ``folder_out/file_ckpt`` and, whenever the ``'nll'``
        loss of the save phase improves, the model weights at
        ``folder_out/file_model``.
    """
    # Learning rates start at 0.0 and are overwritten on every batch below.
    optimizer = optim.Adam([
        {'params': (x[1] for x in net.named_parameters() if 'modelencoder' in x[0]), 'lr': 0.0},
        {'params': (x[1] for x in net.named_parameters() if 'modeldecoder' in x[0]), 'lr': 0.0},
    ])

    phase_list = ['train']
    # NOTE(review): without a 'valid' loader, save_phase stays 'test', which is
    # never in phase_list, so the best-model file is never written. Confirm
    # this is intended.
    save_phase = 'test'
    if 'valid' in data_loaders:
        phase_list.append('valid')
        save_phase = 'valid'

    path_ckpt = os.path.join(config['folder_out'], config['file_ckpt'])
    path_model = os.path.join(config['folder_out'], config['file_model'])

    if config['resume']:
        checkpoint = torch.load(path_ckpt)
        start_epoch = checkpoint['epoch'] + 1
        best_epoch = checkpoint['best_epoch']
        best_loss = checkpoint['best_loss']
        net.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print('Resume training from epoch {}'.format(start_epoch))
    else:
        start_epoch = 0
        best_epoch = -1
        best_loss = float('inf')
        print('Start training')
    print()

    # Loop-invariant settings, looked up once.
    grad_accum_steps = config.get('gradient_accumulation_steps', 1)
    clip_value = config.get('clip', 0.05)
    base_lr_enc = config.get('lr_enc', 3e-5)
    base_lr_dec = config.get('lr_dec', 1e-4)

    # Number of backward passes performed so far; carried across epochs.
    accumulate_steps = 0
    with SummaryWriter(log_dir=config['folder_log'], purge_step=start_epoch) as writer:
        for epoch in range(start_epoch, config['num_epochs']):
            loss_coef = compute_loss_coef(config, epoch)
            print('Data:{}, Model:{}, Epoch: {}/{}'.format(config['dataset'], config['model_name'], epoch, config['num_epochs'] - 1))
            for phase in phase_list:
                is_train = phase == 'train'
                phase_param = config['phase_param'][phase]
                net.train(is_train)
                sum_losses = {}
                num_data = 0
                datasets = data_loaders[phase]
                print('training Model')
                for idx_batch, data in enumerate(tqdm(datasets, bar_format='{l_bar}{bar:20}{r_bar}{bar:-2000b}')):
                    # The schedule is also evaluated during validation; that is
                    # harmless because no optimizer step is taken there.
                    global_step = epoch * len(datasets) + idx_batch
                    warmup_enc, warmup_dec, decay = _lr_schedule_factors(config, global_step)
                    optimizer.param_groups[0]['lr'] = decay * warmup_enc * base_lr_enc
                    optimizer.param_groups[1]['lr'] = decay * warmup_dec * base_lr_dec

                    batch_size = data['image'].shape[0]

                    # On the first batch of selected epochs, run an extra
                    # gradient-free forward pass to render the image summary.
                    if idx_batch == 0 and epoch % config['summ_image_intvl'] == 0:
                        with torch.set_grad_enabled(False):
                            results, losses = net(data, phase_param, requires_results=True)
                        overview = compute_overview(config, data['image'], results)
                        writer.add_image('{}/overview'.format(phase.capitalize()), overview, global_step=epoch)
                        writer.flush()

                    with torch.set_grad_enabled(is_train):
                        results, losses = net(data, phase_param)

                    # Accumulate batch-size-weighted losses for the epoch mean.
                    for key, val in losses.items():
                        sum_losses[key] = sum_losses.get(key, 0) + val.sum().item() * batch_size
                    num_data += batch_size

                    if is_train:
                        loss_opt = torch.stack(
                            [loss_coef[key] * val for key, val in losses.items()]).sum()
                        # Scale so the accumulated gradient averages the micro-batches.
                        loss_opt = loss_opt / grad_accum_steps
                        loss_opt.backward()
                        accumulate_steps += 1
                        # Step only after a full window of `grad_accum_steps`
                        # backward passes (the counter is already incremented).
                        if accumulate_steps % grad_accum_steps == 0:
                            clip_grad_norm_(net.parameters(), clip_value, 'inf')
                            optimizer.step()
                            optimizer.zero_grad()

                mean_losses = {key: val / num_data for key, val in sum_losses.items()}

                if epoch % config['summ_scalar_intvl'] == 0:
                    for key, val in mean_losses.items():
                        writer.add_scalar('{}/loss_{}'.format(phase.capitalize(), key), val, global_step=epoch)
                    writer.flush()
                print(phase.capitalize())

                # Keep the weights of the epoch with the lowest 'nll' on the save phase.
                if phase == save_phase:
                    if mean_losses['nll'] < best_loss:
                        best_loss = mean_losses['nll']
                        best_epoch = epoch
                        torch.save(net.state_dict(), path_model)

            # Full training state, overwritten every epoch so training can resume.
            save_dict = {
                'epoch': epoch,
                'best_epoch': best_epoch,
                'best_loss': best_loss,
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }
            torch.save(save_dict, path_ckpt)
            print('Best Epoch: {}'.format(best_epoch))
            print()


def compute_overview(config, images, results, dpi=150):
    """Render a summary image comparing inputs, reconstructions and segmentations.

    One two-row matplotlib figure is drawn for each of the first
    ``min(config['summ_image_count'], images.shape[0])`` samples. The top row
    shows the scene, ``results['recon']``, ``results['ins_seg_gt']``,
    ``results['ins_seg']`` and one panel per entry of ``results['m_x']``.
    The bottom row repeats the first three panels and shows the ``infer_``
    variants (``'infer_ins_seg'``, ``'infer_m_x'``) in the remaining ones.
    The number of per-object panels is ``results['apc_all'].shape[1]``.

    Args:
        config: Mapping providing ``'summ_image_count'``.
        images: Batch of input images as a torch tensor; each image is treated
            as channel-first (axis 0 is moved to the last axis for plotting).
        results: Mapping of torch tensors with the keys listed above, each
            indexed by sample along axis 0.
        dpi: Resolution used for every rendered figure.

    Returns:
        A ``uint8`` numpy array in channel-first layout ``(3, H_total, W)``
        with the per-sample figures stacked vertically.
    """
    def to_numpy(tensor):
        # Keep only the samples that will be drawn before copying to host memory.
        return tensor[:summ_image_count].detach().cpu().numpy()

    def convert_image(image):
        # Channel-first -> channel-last; single-channel images are tiled to RGB.
        image = np.moveaxis(image, 0, 2)
        if image.shape[2] == 1:
            image = np.repeat(image, 3, axis=2)
        return image

    def plot_image(ax, image, xlabel=None, ylabel=None, color=None):
        plot = ax.imshow(image, interpolation='bilinear')
        ax.set_xticks([])
        ax.set_yticks([])
        label_color = 'k' if color is None else color
        if xlabel:
            ax.set_xlabel(xlabel, color=label_color, fontfamily='monospace')
        if ylabel:
            ax.set_ylabel(ylabel, color=label_color, fontfamily='monospace')
        ax.xaxis.set_label_position('top')
        return plot

    def get_overview(fig_idx):
        rows, cols = 2, num_slots + 4
        fig, axes = plt.subplots(rows, cols, figsize=(cols, rows + 0.2), dpi=dpi)
        try:
            image = convert_image(image_batch[fig_idx])
            recon = convert_image(recon_batch[fig_idx])
            ins_seg_gt = convert_image(ins_seg_gt_batch[fig_idx])

            # Top row: labelled panels.
            plot_image(axes[0, 0], image, xlabel='scene')
            plot_image(axes[0, 1], recon, xlabel='recon')
            plot_image(axes[0, 2], ins_seg_gt, xlabel='segment_gt')
            plot_image(axes[0, 3], convert_image(ins_seg_batch[fig_idx]), xlabel='segment')
            # Bottom row: same references, "infer" segmentation in column 3.
            plot_image(axes[1, 0], image)
            plot_image(axes[1, 1], recon)
            plot_image(axes[1, 2], ins_seg_gt)
            plot_image(axes[1, 3], convert_image(infer_ins_seg_batch[fig_idx]))

            mx = mx_batch[fig_idx]
            infer_mx = infer_mx_batch[fig_idx]
            for idx in range(num_slots):
                plot_image(axes[0, idx + 4], convert_image(mx[idx]), xlabel='obj_{}'.format(idx))
                plot_image(axes[1, idx + 4], convert_image(infer_mx[idx]))

            fig.tight_layout(pad=0)
            fig.canvas.draw()
            # `tostring_rgb` is deprecated in recent Matplotlib; read the RGBA
            # buffer instead, drop alpha, and copy so the pixels outlive the
            # figure once it is closed.
            return np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
        finally:
            # Always release the figure, even if plotting fails.
            plt.close(fig)

    summ_image_count = min(config['summ_image_count'], images.shape[0])
    # Only the size of axis 1 of 'apc_all' is needed, so it is not copied.
    num_slots = results['apc_all'].shape[1]

    image_batch = to_numpy(images)
    recon_batch = to_numpy(results['recon'])
    mx_batch = to_numpy(results['m_x'])
    infer_mx_batch = to_numpy(results['infer_m_x'])
    ins_seg_gt_batch = to_numpy(results['ins_seg_gt'])
    infer_ins_seg_batch = to_numpy(results['infer_ins_seg'])
    ins_seg_batch = to_numpy(results['ins_seg'])

    overview_list = [get_overview(idx) for idx in range(summ_image_count)]
    # Stack figures vertically, then convert HWC -> CHW.
    overview = np.concatenate(overview_list, axis=0)
    overview = np.moveaxis(overview, 2, 0)
    return overview