"""
Auxiliary handlers for use during training.
"""


import os
import csv
# import numpy as np
import torch
import ignite.handlers as hdlr
from ignite.handlers import EarlyStopping
from ignite.engine import Events


class ModelCheckpoint(hdlr.ModelCheckpoint):
    """ignite ``ModelCheckpoint`` with helpers to inspect what was saved."""

    @property
    def last_checkpoint_state(self):
        """Load and return the object stored in ``self.last_checkpoint``.

        The file is opened in binary mode and deserialised with
        ``torch.load``.
        """
        with open(self.last_checkpoint, mode='rb') as checkpoint_file:
            return torch.load(checkpoint_file)

    @property
    def all_paths(self):
        """List ``(filename, full_path)`` tuples for every entry in ``_saved``.

        The full path is the entry's filename joined onto the save handler's
        ``dirname``.
        """
        return [
            (item.filename,
             os.path.join(self.save_handler.dirname, item.filename))
            for item in self._saved
        ]


class LRScheduler(object):
    """Handler that steps a scheduler with a metric read from the engine.

    Args:
        scheduler: object exposing ``step(value)``.
        loss: key looked up in ``engine.state.metrics`` to obtain the value
            handed to ``scheduler.step``.
    """

    def __init__(self, scheduler, loss):
        self.loss = loss
        self.scheduler = scheduler

    def __call__(self, engine):
        """Step the scheduler with the monitored metric's current value."""
        self.scheduler.step(engine.state.metrics[self.loss])

    def attach(self, engine):
        """Register this handler on ``Events.COMPLETED``; return ``self``."""
        engine.add_event_handler(Events.COMPLETED, self)
        return self


class Tracer(object):
    """Record the per-epoch training loss and validation metrics of a run.

    The training loss of an epoch is the example-weighted mean of the batch
    outputs seen during that epoch.  Each trace is kept as a list attribute:
    ``loss`` for the training loss and ``val_<name>`` for every name in
    ``val_metrics``.  The attribute names are listed in ``self.metrics``.

    Args:
        val_metrics: iterable of validation metric names to trace.
        save_path: directory the traces are written to after every epoch.
            When ``None`` nothing is saved automatically.
        save_interval: stored on the instance, but currently not consulted;
            traces are written after every epoch.
    """

    def __init__(self, val_metrics, save_path=None, save_interval=1):
        self.metrics = ['loss']
        self.loss = []
        self.save_path = save_path
        self.save_interval = save_interval
        self._running_loss = 0
        self._n_inputs = 0
        for k in val_metrics:
            name = 'val_{}'.format(k)
            setattr(self, name, [])
            self.metrics.append(name)

    def _initalize_traces(self, engine):
        """Empty every trace (in place) when a run starts."""
        for k in self.metrics:
            getattr(self, k).clear()

    def _save_batch_loss(self, engine):
        """Accumulate the batch output weighted by the batch size."""
        # The second element of the batch must expose ``size(0)``
        # (presumably the target tensor -- confirm against the data loader).
        n_examples = engine.state.batch[1].size(0)
        self._running_loss += engine.state.output * n_examples
        self._n_inputs += n_examples

    def _compute_training_loss(self, engine):
        """Append the epoch's mean loss and reset the accumulators."""
        epoch_loss = self._running_loss / self._n_inputs
        self.loss.append(epoch_loss)
        self._running_loss = 0.0
        self._n_inputs = 0

    def _trace_validation(self, engine):
        """Append each reported metric to its ``val_<name>`` trace.

        Raises:
            AttributeError: if the engine reports a metric whose name was
                not listed in ``val_metrics``.
        """
        for k, v in engine.state.metrics.items():
            getattr(self, 'val_{}'.format(k)).append(v)

    def attach(self, trainer, evaluator=None):
        """Register the tracing handlers and return ``self``.

        Args:
            trainer: engine whose outputs provide the training loss.
            evaluator: optional engine whose metrics are traced each time it
                completes.
        """
        trainer.add_event_handler(Events.STARTED, self._initalize_traces)
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED, self._save_batch_loss)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, self._compute_training_loss)

        if evaluator is not None:
            evaluator.add_event_handler(
                Events.COMPLETED, self._trace_validation)

        if self.save_path is not None:
            trainer.add_event_handler(
                Events.EPOCH_COMPLETED, self._save_at_interval)

        return self

    def _save_at_interval(self, engine):
        """Write all traces to disk (``save_interval`` is not applied)."""
        self.save_traces()

    def save_traces(self):
        """Write every trace to ``<save_path>/<trace name>.csv``.

        Each row holds the 1-based position in the trace and the value.

        Raises:
            ValueError: if no ``save_path`` was configured.
        """
        if self.save_path is None:
            raise ValueError('save_path must be set to save traces')
        for name in self.metrics:
            trace = getattr(self, name)
            path = os.path.join(self.save_path, '{}.csv'.format(name))
            # newline='' lets the csv module control line endings itself.
            with open(path, mode='w', newline='') as f:
                wr = csv.writer(f, quoting=csv.QUOTE_ALL)
                wr.writerows(
                    [i, v] for i, v in enumerate(trace, start=1))



class Tracer2(object):
    """Trace several named training losses plus validation metrics.

    The trainer's output is indexed by each name in ``loss_list``; the
    per-epoch value of a loss is the example-weighted mean over the epoch
    and is appended to the list attribute of the same name.  Validation
    metrics are stored in list attributes named ``val_<metric>``, whose
    names are collected in ``self.metrics``.

    Args:
        val_metrics: iterable of validation metric names to trace.
        loss_list: names of the entries of ``engine.state.output`` to trace.
        save_path: directory the traces are written to after every epoch.
            When ``None`` nothing is saved automatically.
        save_interval: stored on the instance, but currently not consulted;
            traces are written after every epoch.
        suffix: text appended to every trace file name (before ``.csv``).
    """

    def __init__(self, val_metrics, loss_list=('loss',), save_path=None,
                 save_interval=1, suffix=""):
        self.save_path = save_path
        self.save_interval = save_interval
        self.suffix = suffix
        # Copy so the default (or the caller's sequence) is never shared.
        self.loss_list = list(loss_list)
        for name in self.loss_list:
            setattr(self, name, [])
            setattr(self, f'_running_loss_{name}', 0)
        self._n_inputs = 0

        self.metrics = []
        for kname in val_metrics:
            # Must match the 'val_{}' template used by _trace_validation.
            name = f'val_{kname}'
            setattr(self, name, [])
            self.metrics.append(name)

    def _initalize_traces(self, engine):
        """Empty every loss and validation trace (in place)."""
        for name in self.loss_list:
            getattr(self, name).clear()
        for name in self.metrics:
            getattr(self, name).clear()

    def _save_batch_loss(self, engine):
        """Accumulate each traced loss weighted by the batch size."""
        # The second element of the batch must expose ``size(0)``
        # (presumably the target tensor -- confirm against the data loader).
        n_examples = engine.state.batch[1].size(0)
        self._n_inputs += n_examples
        for name in self.loss_list:
            attr = f'_running_loss_{name}'
            running_loss = getattr(self, attr)
            setattr(self, attr,
                    engine.state.output[name] * n_examples + running_loss)

    def _compute_training_loss(self, engine):
        """Append each loss's epoch mean and reset the accumulators."""
        for name in self.loss_list:
            attr = f'_running_loss_{name}'
            epoch_loss = getattr(self, attr) / self._n_inputs
            getattr(self, name).append(epoch_loss)
            setattr(self, attr, 0)
        self._n_inputs = 0

    def _trace_validation(self, engine):
        """Append each reported metric to its ``val_<name>`` trace.

        Raises:
            AttributeError: if the engine reports a metric whose name was
                not listed in ``val_metrics``.
        """
        for k, v in engine.state.metrics.items():
            getattr(self, 'val_{}'.format(k)).append(v)

    def attach(self, trainer, evaluator=None):
        """Register the tracing handlers and return ``self``.

        Args:
            trainer: engine whose outputs provide the training losses.
            evaluator: optional engine whose metrics are traced each time it
                completes.
        """
        trainer.add_event_handler(Events.STARTED, self._initalize_traces)
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED, self._save_batch_loss)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, self._compute_training_loss)

        if evaluator is not None:
            evaluator.add_event_handler(
                Events.COMPLETED, self._trace_validation)

        if self.save_path is not None:
            trainer.add_event_handler(
                Events.EPOCH_COMPLETED, self._save_at_interval)

        return self

    def _save_at_interval(self, engine):
        """Write all traces to disk (``save_interval`` is not applied)."""
        self.save_traces()

    def save_traces(self):
        """Write every trace to ``<save_path>/<trace name><suffix>.csv``.

        Both the training losses and the validation metrics are written.
        Each row holds the 1-based position in the trace and the value.

        Raises:
            ValueError: if no ``save_path`` was configured.
        """
        if self.save_path is None:
            raise ValueError('save_path must be set to save traces')
        for name in self.loss_list + self.metrics:
            trace = getattr(self, name)
            path = os.path.join(
                self.save_path, '{}{}.csv'.format(name, self.suffix))
            # newline='' lets the csv module control line endings itself.
            with open(path, mode='w', newline='') as f:
                wr = csv.writer(f, quoting=csv.QUOTE_ALL)
                wr.writerows(
                    [i, v] for i, v in enumerate(trace, start=1))


class Saver_concat(object):
    """Log scalars to CSV and dump sample images for an ignite engine.

    Per iteration, the entries of ``engine.state.output`` named in ``output``
    are recorded; per epoch, their example-weighted mean and the engine's
    metrics are appended to ``<prefix>scalars.csv``.  At the end of every
    epoch the engine's current output is also dumped with ``dill`` and
    rendered as PNG images (prediction, ground truth and template side by
    side) under ``<save_folder>/images/epoch_<n>/``.

    Args:
        save_folder: directory that receives the CSV files and the
            ``images`` sub-directory.
        name: optional tag; used as ``<name>_`` prefix for the CSV files and
            ``_<name>`` suffix for image/dump files.
        output: names of the ``engine.state.output`` entries to record.
        every_iter: when true, also append one row per iteration to
            ``<prefix>scalars_iter.csv`` and keep the full per-iteration
            history in ``variables_iter_full_list``.
        train: when true, additionally save the group-action image sets
            (see ``save_sample_group_action``).
    """

    def __init__(self, save_folder, name="", output=None, every_iter=False,
                 train=False):
        if name:
            self.suffix = f'_{name}'
            self.prefix = f'{name}_'
            self.name = name
        else:
            self.suffix = ""
            self.prefix = ""
            self.name = "No_Name"
        self.save_folder = save_folder
        self.epoch_count = 0
        self.iter_count = 0
        self.batch_count = 0

        # Per-epoch scalars (engine metrics): name -> list of values.
        self.variable_list = []
        self.variables = {}
        self.train = train
        # Per-iteration scalars taken from engine.state.output.
        self.variable_list_iter = output if output else []
        self.variables_iter = {}
        self.variables_iter_full_list = {}
        self.variables_iter_epoch_cumul = {}
        self.variables_iter_epoch_mean = {}

        self.every_iter = every_iter
        self.scalar_csv_path = os.path.join(
            self.save_folder, f'{self.prefix}scalars.csv')
        self.scalar_iter_csv_path = os.path.join(
            self.save_folder, f'{self.prefix}scalars_iter.csv')

    def epoch_count_add(self, engine):
        """Advance the epoch counter."""
        self.epoch_count += 1

    def iter_count_add(self, engine):
        """Advance the within-epoch iteration counter."""
        self.iter_count += 1

    def iter_clear(self, engine):
        """Reset the within-epoch iteration counter."""
        self.iter_count = 0

    def attach(self, engine):
        """Register all handlers on ``engine`` and return ``self``."""
        engine.add_event_handler(Events.EPOCH_STARTED, self.epoch_count_add)
        engine.add_event_handler(Events.EPOCH_STARTED, self.iter_clear)

        engine.add_event_handler(Events.ITERATION_STARTED, self.iter_count_add)
        # log_output must run before _compute_loss ever sees the epoch end,
        # because _compute_loss clears the lists log_output creates.
        engine.add_event_handler(Events.ITERATION_COMPLETED, self.log_output)
        engine.add_event_handler(
            Events.ITERATION_COMPLETED, self._save_batch_loss)

        engine.add_event_handler(Events.EPOCH_COMPLETED, self._compute_loss)
        engine.add_event_handler(Events.EPOCH_COMPLETED, self.save_sample)
        if self.train:
            engine.add_event_handler(
                Events.EPOCH_COMPLETED, self.save_sample_group_action)
        engine.add_event_handler(Events.EPOCH_COMPLETED, self.log_metric)

        return self

    def _save_batch_loss(self, engine):
        """Accumulate each tracked output weighted by the batch size."""
        # The second element of the batch must expose ``size(0)``
        # (presumably the target tensor -- confirm against the data loader).
        n_examples = engine.state.batch[1].size(0)
        self.batch_count += n_examples
        for name in self.variable_list_iter:
            weighted = engine.state.output[name] * n_examples
            if name in self.variables_iter_epoch_cumul:
                self.variables_iter_epoch_cumul[name] += weighted
            else:
                self.variables_iter_epoch_cumul[name] = weighted

    def _compute_loss(self, engine):
        """Store each tracked output's epoch mean and reset epoch state."""
        for name in self.variable_list_iter:
            epoch_loss = (
                self.variables_iter_epoch_cumul[name] / self.batch_count)
            self.variables_iter_epoch_mean.setdefault(
                name, []).append(epoch_loss)
            self.variables_iter_epoch_cumul[name] = 0
            if self.every_iter:
                # Keep the whole per-iteration history across epochs.
                self.variables_iter_full_list.setdefault(
                    name, []).extend(self.variables_iter[name])
            self.variables_iter[name].clear()
        self.batch_count = 0

    def save_sample(self, engine):
        """Dump the current output and save ``y_pred | y | template`` images."""
        self._save_image_group(
            engine.state.output, 'y_pred', 'y', 'template')

    def save_sample_group_action(self, engine):
        """Save the three group-action image sets of the current output.

        For each set the full output is dumped and images are written with
        the set's infix in the file name.
        """
        # (file-name infix, prediction key, ground-truth key, template key)
        groups = [
            ('_inv_hg_', 'recon_hg', 'data2_inv_hg', 'data1_inv_hg'),
            ('_inv_gh_', 'recon_gh', 'data2_inv_gh', 'data1_inv_gh'),
            ('_assoc_', 'recon_gk', 'data2_assoc_gk', 'data1_assoc_gk'),
        ]
        for midfix, pred_name, gt_name, temp_name in groups:
            self._save_image_group(
                engine.state.output, pred_name, gt_name, temp_name, midfix)

    def _save_image_group(self, output, pred_name, gt_name, temp_name,
                          midfix=""):
        """Dump ``output`` and write one concatenated PNG per batch sample.

        The prediction is passed through a sigmoid; prediction, ground truth
        and template are then concatenated along axis 2 of each sample.
        Relies on the module-level ``np``, ``dill`` and ``Image`` names.

        Args:
            output: mapping of names to tensors (``engine.state.output``).
            pred_name: key of the prediction (pre-sigmoid) tensor.
            gt_name: key of the ground-truth tensor.
            temp_name: key of the template tensor.
            midfix: text inserted before the instance suffix in file names.
        """
        x2_predicted = torch.sigmoid(output[pred_name]).cpu().detach().numpy()
        x2 = output[gt_name].cpu().detach().numpy()
        x1 = output[temp_name].cpu().detach().numpy()

        epoch_save_path = os.path.join(
            self.save_folder, 'images', f'epoch_{self.epoch_count}')
        os.makedirs(epoch_save_path, exist_ok=True)

        dump_path = os.path.join(
            epoch_save_path, f'output{midfix}{self.suffix}.dill')
        with open(dump_path, 'wb') as f:
            dill.dump(output, f)

        for i in range(x2_predicted.shape[0]):
            stri = str(i).zfill(3)
            concated = np.concatenate(
                [x2_predicted[i], x2[i], x1[i]], axis=2)
            # Scaling by 255 presumes values in [0, 1] -- TODO confirm for
            # the ground-truth and template tensors.
            pixels = (concated * 255).astype(np.uint8)
            if pixels.shape[0] == 1:
                # Single leading channel: drop it to get a 2-D image.
                pixels = pixels.squeeze(0)
            else:
                # Move the leading (channel) axis last, as PIL expects.
                pixels = pixels.transpose(1, 2, 0)
            image_save_path = os.path.join(
                epoch_save_path, f'{stri}{midfix}{self.suffix}.png')
            Image.fromarray(pixels).save(image_save_path)

    def add(self, var_name, var):
        """Append ``var`` to the per-epoch history of ``var_name``."""
        if var_name not in self.variable_list:
            self.variables[var_name] = [var]
            self.variable_list.append(var_name)
        else:
            self.variables[var_name].append(var)

    def add_iter(self, var_name, var):
        """Append ``var`` to the per-iteration history of ``var_name``."""
        self.variables_iter.setdefault(var_name, []).append(var)

    def make_csv(self):
        """Create the per-epoch CSV file containing only the header row."""
        # newline='' lets the csv module control line endings itself.
        with open(self.scalar_csv_path, mode='w', newline='') as f:
            wr = csv.writer(f, quoting=csv.QUOTE_ALL)
            wr.writerow(
                ['Epoch'] + self.variable_list + self.variable_list_iter)

    def make_csv_iter(self):
        """Create the per-iteration CSV file containing only the header."""
        with open(self.scalar_iter_csv_path, mode='w', newline='') as f:
            wr = csv.writer(f, quoting=csv.QUOTE_ALL)
            wr.writerow(['Epoch/Iteration'] + self.variable_list_iter)

    def save_scalars(self):
        """Append the latest per-epoch values as one row of the epoch CSV.

        The file (with header) is created first if it does not exist; an
        existing file is appended to.  Does nothing when no scalar is
        tracked.
        """
        if not (self.variable_list or self.variable_list_iter):
            return
        if not os.path.isfile(self.scalar_csv_path):
            self.make_csv()
        with open(self.scalar_csv_path, mode='a', newline='') as f:
            wr = csv.writer(f, quoting=csv.QUOTE_ALL)
            row_to_write = [self.epoch_count]
            row_to_write.extend(
                self.variables[var_name][-1]
                for var_name in self.variable_list)
            row_to_write.extend(
                self.variables_iter_epoch_mean[var_name][-1]
                for var_name in self.variable_list_iter)
            wr.writerow(row_to_write)

    def save_scalars_iter(self):
        """Append the latest per-iteration values to the iteration CSV.

        The first column is ``<epoch>/<iteration>``.  Does nothing when no
        output is tracked.
        """
        if not self.variable_list_iter:
            return
        if not os.path.isfile(self.scalar_iter_csv_path):
            self.make_csv_iter()
        with open(self.scalar_iter_csv_path, mode='a', newline='') as f:
            wr = csv.writer(f, quoting=csv.QUOTE_ALL)
            row_to_write = [f'{self.epoch_count}/{self.iter_count}']
            row_to_write.extend(
                self.variables_iter[var_name][-1]
                for var_name in self.variable_list_iter)
            wr.writerow(row_to_write)

    def log_metric(self, engine):
        """Record every engine metric and write the per-epoch CSV row."""
        for metric, value in engine.state.metrics.items():
            self.add(f'{metric}', value)
        self.save_scalars()

    def log_output(self, engine):
        """Record the tracked outputs of the iteration that just finished."""
        for output_name in self.variable_list_iter:
            self.add_iter(output_name, engine.state.output[output_name])
        if self.every_iter:
            self.save_scalars_iter()






class Saver(object):
    """Collect named value traces and write each one to a CSV file.

    Every tracked name becomes a list attribute on the instance, so names
    should not collide with the method names of this class.

    Args:
        variable_list: optional names to create empty traces for up front.
        suffix: accepted for interface compatibility; currently unused.
        save_path: default directory used by ``save`` and ``save_all``.
    """

    def __init__(self, variable_list=None, suffix="", save_path=None):
        self.save_path = save_path
        # Copy: ``add`` appends to this list and must not touch the caller's.
        self.variable_list = list(variable_list) if variable_list else []
        for name in self.variable_list:
            setattr(self, name, [])

    def add(self, var_name, var):
        """Append ``var`` to the trace ``var_name``, creating it if needed."""
        if var_name not in self.variable_list:
            self.variable_list.append(var_name)
            setattr(self, var_name, [])
        getattr(self, var_name).append(var)

    def save(self, var_name, save_path=None):
        """Write the trace ``var_name`` to ``<save_path>/<var_name>.csv``.

        Each row holds the 1-based position in the trace and the value.

        Args:
            var_name: name of the trace to write.
            save_path: target directory; defaults to ``self.save_path``.

        Raises:
            TypeError: if neither ``save_path`` nor ``self.save_path`` is
                set (raised by ``os.path.join``).
        """
        if save_path is None:
            save_path = self.save_path
        trace = getattr(self, var_name)
        path = os.path.join(save_path, f'{var_name}.csv')
        # newline='' lets the csv module control line endings itself.
        with open(path, mode='w', newline='') as f:
            wr = csv.writer(f, quoting=csv.QUOTE_ALL)
            wr.writerows([i, v] for i, v in enumerate(trace, start=1))

    def save_all(self, save_path=None):
        """Write every tracked trace; see ``save`` for the file layout."""
        for var_name in self.variable_list:
            self.save(var_name, save_path)

import numpy as np
import dill
from PIL import Image
class Image_Saver(object):
    """Dump an engine's output and save sample images after every epoch.

    For each epoch a directory ``<save_folder>/epoch_<n>`` is created that
    receives a ``dill`` dump of ``engine.state.output`` and one PNG per batch
    sample showing prediction, ground truth and template side by side.

    Args:
        save_folder: directory that receives the per-epoch sub-directories.
        suffix: text appended to the dump and image file names.
    """

    def __init__(self, save_folder, suffix=""):
        self.suffix = suffix
        self.save_folder = save_folder
        self.count = 0

    def execute(self):
        """Advance the internal epoch counter."""
        self.count += 1

    def attach(self, engine):
        """Register ``save_sample`` on epoch completion and return ``self``."""
        engine.add_event_handler(Events.EPOCH_COMPLETED, self.save_sample)
        return self

    def save_sample(self, engine):
        """Dump the current output and write one PNG per batch sample.

        Reads the ``'y_pred'``, ``'y'`` and ``'template'`` entries of
        ``engine.state.output`` and concatenates them along axis 2 of each
        sample.  Unlike ``Saver_concat``, no sigmoid is applied to the
        prediction here.
        """
        self.execute()
        output = engine.state.output
        x2_predicted = output['y_pred'].cpu().detach().numpy()
        x2 = output['y'].cpu().detach().numpy()
        x1 = output['template'].cpu().detach().numpy()

        epoch_save_path = os.path.join(
            self.save_folder, f'epoch_{self.count}')
        os.makedirs(epoch_save_path, exist_ok=True)

        dump_path = os.path.join(epoch_save_path, f'output{self.suffix}.dill')
        with open(dump_path, 'wb') as f:
            dill.dump(output, f)

        for i in range(x2_predicted.shape[0]):
            stri = str(i).zfill(3)
            concated = np.concatenate(
                [x2_predicted[i], x2[i], x1[i]], axis=2)
            # Scaling by 255 presumes values in [0, 1] -- TODO confirm.
            pixels = (concated * 255).astype(np.uint8)
            if pixels.shape[0] == 1:
                # Single leading channel: drop it to get a 2-D image.
                pixels = pixels.squeeze(0)
            else:
                # Move the leading (channel) axis last, as PIL expects;
                # without this a channel-first array is misread by PIL.
                pixels = pixels.transpose(1, 2, 0)
            image_save_path = os.path.join(
                epoch_save_path, f'{stri}{self.suffix}.png')
            Image.fromarray(pixels).save(image_save_path)

