import cv2
import datetime
import logging
import math
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
import time
from tqdm import tqdm

import torch
from torch import autograd
from tensorboardX import SummaryWriter

import sys
def add_path(path):
    """Prepend ``path`` to ``sys.path`` unless it is already listed there."""
    if path in sys.path:
        return
    sys.path.insert(0, path)
from libs.evaluators.ir_eval import Evaluator
import libs.utils.misc as utils
from libs.utils.utils import save_checkpoint

import librosa, librosa.display

def data_loop(data_loader):
    """
    Loop an iterable infinitely.

    Every time ``data_loader`` is exhausted it is iterated again from the
    start, so the returned generator never finishes on its own.

    Raises:
        ValueError: if a complete pass over ``data_loader`` produces no item.
            Without this check an empty loader would make the generator spin
            forever inside ``next()`` without ever yielding.
    """
    while True:
        produced = False
        for x in data_loader:
            produced = True
            yield x
        if not produced:
            raise ValueError("data_loader yielded no items; cannot loop over it")


# TODO: log the info
class Trainer(object):
    """Run training epochs, quick validation and checkpointing for ``render``.

    The trainer holds the model (``render``), the loss (``criterion``), the
    optimizer and the LR scheduler. On rank 0, when an optimizer is given, it
    also owns a tensorboardX ``SummaryWriter`` used for scalars and figures.
    """

    def __init__(self,
                 cfg,
                 render,
                 criterion,
                 optimizer,
                 lr_scheduler,
                 logger,
                 log_dir,
                 performance_indicator='mse',
                 last_iter=-1,
                 rank=0,
                 device='cuda'):
        """Store the training components and set up logging.

        Args:
            cfg: configuration object; this class reads ``cfg.output_dir``,
                ``cfg.render.file`` and, under ``cfg.train``, ``max_epoch``,
                ``print_freq``, ``valiter_interval``, ``val_when_train``,
                ``save_interval`` and ``save_every_checkpoint``.
            render: model to train; may or may not be wrapped in an object
                exposing the real model as ``.module``.
            criterion: loss module; called as ``criterion(pred_ir, gt_ir)`` and
                expected to return a dict of loss terms.
            optimizer: optimizer, or ``None`` when no training is performed.
            lr_scheduler: scheduler stepped once per call to :meth:`train`.
            logger: logger used for progress messages.
            log_dir: base directory for logs; ``cfg.output_dir`` is appended.
                May be falsy when no optimizer is given.
            performance_indicator: key of the validation statistic returned by
                :meth:`quick_val` and used to pick the best checkpoint.
            last_iter: index of the last finished epoch; training resumes at
                ``last_iter + 1``.
            rank: process rank; only rank 0 writes summaries and checkpoints.
            device: device the batches are moved to.

        Raises:
            ValueError: if an optimizer is given on rank 0 without ``log_dir``.
        """
        self.cfg = cfg
        self.render = render
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.criterion = criterion
        self.logger = logger
        # Always define ``log_dir`` and ``epoch`` so later reads cannot fail
        # with AttributeError when the caller passes no log directory.
        self.log_dir = os.path.join(log_dir, self.cfg.output_dir) if log_dir else None
        self.epoch = last_iter + 1
        self.PI = performance_indicator
        self.rank = rank
        self.best_performance = 0.0
        self.is_best = False
        self.max_epoch = self.cfg.train.max_epoch
        self.model_name = self.cfg.render.file
        self.device = device
        self.iter_count = 0
        self.writer = None
        if self.optimizer is not None and rank == 0:
            if self.log_dir is None:
                # Fail early with a clear message instead of crashing later
                # when the checkpoint directory is built.
                raise ValueError("log_dir is required when training on rank 0")
            self.writer = SummaryWriter(self.log_dir, comment=f'_rank{rank}')
            logging.info(f"max epochs = {self.max_epoch} ")
        self.evaluator = Evaluator(self.cfg, 'eval')

    def _unwrapped_render(self):
        """Return ``self.render.module`` when it exists, else ``self.render``."""
        return getattr(self.render, 'module', self.render)

    def _read_inputs(self, batch):
        """Move every entry of ``batch`` to ``self.device``.

        Entries that are lists/tuples become lists of moved elements, dict
        entries become dicts of moved values, and anything else is moved with
        its own ``.to``. A list ``batch`` is updated in place and returned; a
        tuple ``batch`` is copied into a list first.
        """
        if isinstance(batch, tuple):
            batch = list(batch)
        for k, item in enumerate(batch):
            if isinstance(item, (tuple, list)):
                batch[k] = [b.to(self.device) for b in item]
            elif isinstance(item, dict):
                # ``elif`` matters: with a plain ``if`` a list entry converted
                # above would fall into the ``else`` branch and fail on ``.to``.
                batch[k] = {key: value.to(self.device) for key, value in item.items()}
            else:
                batch[k] = item.to(self.device)
        return batch

    def _forward(self, data):
        """Render one training batch and return the criterion's loss dict.

        ``data`` must unpack into ``(source_points, points,
        norm_source_points, norm_points, dirs, gt_ir, b_range)``.
        """
        source_points, points, norm_source_points, norm_points, dirs, gt_ir, b_range = data
        # presumably two rows (channels) per source point - TODO confirm
        gt_ir = gt_ir.reshape(2 * len(source_points), -1)
        pred_ir = self._unwrapped_render().render(
            source_points, points, norm_source_points, norm_points, dirs, b_range)
        loss = self.criterion(pred_ir, gt_ir)
        return loss

    def train(self, train_loader, eval_loader):
        """Train for one epoch, then optionally validate and save a checkpoint.

        Exits the process with status 0 once ``self.epoch`` exceeds
        ``self.max_epoch`` and with status 1 when the loss is not finite.

        Args:
            train_loader: iterable of training batches.
            eval_loader: iterable of validation batches; it is cycled, and one
                batch is consumed per quick validation.
        """
        start_time = time.time()
        self.render.train()
        self.criterion.train()
        metric_logger = utils.MetricLogger(delimiter="  ")
        metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
        header = 'Epoch: [{}]'.format(self.epoch)
        print_freq = self.cfg.train.print_freq
        eval_data_iter = data_loop(eval_loader)
        if self.epoch > self.max_epoch:
            logging.info("Optimization is done !")
            sys.exit(0)
        for data in metric_logger.log_every(train_loader, print_freq, header, self.logger):
            data = self._read_inputs(data)
            loss_dict = self._forward(data)
            losses = sum(loss_dict.values())

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = utils.reduce_dict(loss_dict)
            loss_value = sum(loss_dict_reduced.values()).item()
            if not math.isfinite(loss_value):
                self.logger.info("Loss is {}, stopping training".format(loss_value))
                sys.exit(1)

            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()
            metric_logger.update(loss=loss_value, **loss_dict_reduced)
            metric_logger.update(lr=self.optimizer.param_groups[0]["lr"])

            self.iter_count += 1
            # quick val
            if self.rank == 0 and self.iter_count % self.cfg.train.valiter_interval == 0:
                if self.cfg.train.val_when_train:
                    performance = self.quick_val(eval_data_iter)
                    self.writer.add_scalar(self.PI, performance, self.iter_count)
                    logging.info('Now: {} is {:.4f}'.format(self.PI, performance))

        # gather the stats from all processes
        metric_logger.synchronize_between_processes()
        train_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': self.epoch, 'iter': self.iter_count}
        if self.rank == 0:
            for (key, val) in log_stats.items():
                self.writer.add_scalar(key, val, log_stats['iter'])
        self.lr_scheduler.step()

        # save checkpoint
        if self.rank == 0 and self.epoch > 0 and self.epoch % self.cfg.train.save_interval == 0:
            # evaluation TODO val all the val
            if self.cfg.train.val_when_train:
                performance = self.quick_val(eval_data_iter)
                self.writer.add_scalar(self.PI, performance, self.iter_count)
                # NOTE(review): "greater is better" looks wrong for error-type
                # indicators such as the default 'mse' - confirm the intent.
                if performance > self.best_performance:
                    self.is_best = True
                    self.best_performance = performance
                else:
                    self.is_best = False
                logging.info(f'Now: best {self.PI} is {self.best_performance}')
            else:
                performance = -1

            # remove prefix of multi GPUs
            state_dict = self._unwrapped_render().state_dict()

            if self.cfg.train.save_every_checkpoint:
                filename = f"{self.epoch}.pth"
            else:
                filename = "latest.pth"
            # NOTE(review): self.log_dir already ends with cfg.output_dir, so
            # it is appended twice here - confirm this is the intended layout.
            save_dir = os.path.join(self.log_dir, self.cfg.output_dir)
            save_checkpoint(
                {
                    'epoch': self.epoch,
                    'model': self.model_name,
                    f'performance/{self.PI}': performance,
                    'state_dict': state_dict,
                    'optimizer': self.optimizer.state_dict(),
                },
                self.is_best,
                save_dir,
                filename=f'{filename}'
            )
            # remove previous pretrained model if the number of models is too big
            self._prune_checkpoints(save_dir, 20)

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        self.logger.info('Training time {}'.format(total_time_str))
        self.epoch += 1

    @staticmethod
    def _prune_checkpoints(save_dir, max_keep=20):
        """Delete the oldest ``<epoch>.pth`` in ``save_dir`` if more than ``max_keep`` exist.

        Only files named ``<digits>.pth`` are counted, so ``latest.pth``,
        ``model_best.pth`` and unrelated files are never parsed or removed.
        At most one file is deleted per call. Deletion is best-effort: a
        failure is logged and does not interrupt training.
        """
        epochs = []
        for name in os.listdir(save_dir):
            stem, ext = os.path.splitext(name)
            if ext == '.pth' and stem.isdigit():
                epochs.append(int(stem))
        if len(epochs) <= max_keep:
            return
        oldest = os.path.join(save_dir, '{}.pth'.format(min(epochs)))
        try:
            os.remove(oldest)
        except OSError as err:
            logging.warning('Could not remove old checkpoint %s: %s', oldest, err)

    def quick_val(self, eval_data_iter):
        """Validate on the next batch of ``eval_data_iter`` and log the result.

        The batch must unpack into ``(index, source_points, points,
        norm_source_points, norm_points, dirs, gt_ir, b_range)``. Loss terms
        and the evaluator's latest ``mse`` / ``psnr`` / ``t60`` values are
        written to the summary writer (when one exists) together with a
        comparison figure for one randomly chosen row. The criterion's loss
        dict is expected to contain an ``'mse_loss'`` entry. The model and
        criterion are switched back to training mode before returning.

        Returns:
            The validation statistic stored under ``self.PI``.
        """
        self.render.eval()
        self.criterion.eval()
        val_stats = {}
        plot_stats = {}
        with torch.no_grad():
            val_data = next(eval_data_iter)
            val_data = self._read_inputs(val_data)
            index, source_points, points, norm_source_points, norm_points, dirs, gt_ir, b_range = val_data
            n_rows = 2 * len(source_points)
            gt_ir = gt_ir.reshape(n_rows, -1)
            pred_ir = self._unwrapped_render().render(
                source_points, points, norm_source_points, norm_points, dirs, b_range)
            # Plot a single random row to keep the figure cheap.
            idx = np.random.choice(n_rows)

            plot_stat = self.process_img(pred_ir[idx], gt_ir[idx])
            plot_stats.update(plot_stat)
            loss_dict = self.criterion(pred_ir, gt_ir)
            np_pred_ir = pred_ir.cpu().numpy()
            np_gt_ir = gt_ir.cpu().numpy()
            for cur_pred_ir, cur_gt_ir in zip(np_pred_ir, np_gt_ir):
                self.evaluator.evaluate_t60(cur_pred_ir, cur_gt_ir)
            loss_stats = utils.reduce_dict(loss_dict)
            for k, v in loss_stats.items():
                val_stats.setdefault(k, 0)
                val_stats[k] += v
            # Only the most recent evaluator entries are reported; 100 is the
            # placeholder used when no t60 error has been recorded.
            result = {
                'mse': self.evaluator.mse[-1],
                'psnr': self.evaluator.psnr[-1],
                't60': self.evaluator.t60_error[-1] if len(self.evaluator.t60_error) != 0 else 100
            }
            val_stats.update(result)

        # save metrics and loss
        log_stats = {**{f'eval_{k}': v for k, v in val_stats.items()},
                     'epoch': self.epoch, 'iter': self.iter_count}
        writer = getattr(self, 'writer', None)
        if writer is not None:
            for (key, val) in log_stats.items():
                writer.add_scalar(key, val, log_stats['iter'])

        pattern = 'val_iter/{}'
        for k, v in plot_stats.items():
            if writer is not None:
                writer.add_figure(pattern.format(k), v, log_stats['iter'])
            v.savefig('weighted_mse_val.jpg')
        mse_loss, mse, psnr = val_stats['mse_loss'], val_stats['mse'], val_stats['psnr']
        t60 = val_stats['t60']
        msg = 'mse_loss: {:.4f}, mse: {:.4f}, psnr: {:.4f}, t60: {:.4f}'.format(mse_loss, mse, psnr, t60)
        self.logger.info(msg)

        self.render.train()
        self.criterion.train()
        return val_stats[self.PI]

    @staticmethod
    def process_img(pred_ir, gt_ir, sr=22050, n_fft=2048, hop_length=256):
        """Build a figure comparing a predicted and a ground-truth signal.

        The figure has three rows: both waveforms overlaid (ground truth in
        green, prediction in red), the ground-truth spectrogram in dB and the
        predicted spectrogram in dB.

        Args:
            pred_ir: predicted tensor; flattened before plotting.
            gt_ir: ground-truth tensor; flattened before plotting.
            sr: sample rate used for the spectrogram axes
                (NOTE(review): 22050 is the historical default - confirm it
                matches the data).
            n_fft: FFT size of the STFT.
            hop_length: hop size, used for both the STFT and the display so
                that the time axis matches the computed frames.

        Returns:
            dict with the single key ``'fig_plot'`` mapping to the figure.
        """
        # TODO save pred_ir in a meaningful way for visualization check!
        gt_ir_plot = gt_ir.reshape(-1).detach().cpu().numpy()
        pred_ir_plot = pred_ir.reshape(-1).detach().cpu().numpy()
        # Shared y-range so both waveforms are fully visible.
        min_val = np.minimum(np.min(gt_ir_plot), np.min(pred_ir_plot))
        max_val = np.maximum(np.max(gt_ir_plot), np.max(pred_ir_plot))
        gt_spec = librosa.stft(gt_ir_plot, n_fft=n_fft, hop_length=hop_length)
        pred_spec = librosa.stft(pred_ir_plot, n_fft=n_fft, hop_length=hop_length)

        fig = plt.figure()
        ax_gen = fig.add_subplot(3, 1, 1)
        ax_gen.plot(gt_ir_plot, color='green')
        ax_gen.plot(pred_ir_plot, color='red', alpha=0.7)
        ax_gen.set_ylim(min_val, max_val)

        ax_gt_ir_spec = fig.add_subplot(3, 1, 2)
        gt_spec_img = librosa.display.specshow(
            librosa.amplitude_to_db(abs(gt_spec)), sr=sr, hop_length=hop_length,
            x_axis='time', y_axis='log', ax=ax_gt_ir_spec)

        ax_pred_ir_spec = fig.add_subplot(3, 1, 3)
        librosa.display.specshow(
            librosa.amplitude_to_db(abs(pred_spec)), sr=sr, hop_length=hop_length,
            x_axis='time', y_axis='log', ax=ax_pred_ir_spec)
        # The colorbar is taken from the ground-truth image and shared by both
        # spectrogram axes.
        fig.colorbar(gt_spec_img, ax=[ax_gt_ir_spec, ax_pred_ir_spec])
        fig.tight_layout()
        return {'fig_plot': fig}
