import os
import numpy as np
import torch
from torch.nn.utils import clip_grad_value_, clip_grad_norm_
import json

from graph_learning.trainer import Trainer, TrainerConfig
from graph_learning.module import get_module
from graph_learning.utils import merge_dicts, merge_metrics
from graph_learning.logger import TBLogger
from graph_learning.utils import flatten_dict

from functools import partial

@TrainerConfig.register('pytorch-plain',
                        help='plain PyTorch trainer.')
class PtPlainTrainerConfig(TrainerConfig):
    """Configuration object that builds a :class:`PtTrainerPlain`.

    Copies the data, tasker, logger and output directory out of ``context``
    and turns the ``model`` / ``optimizer`` entries into objects through
    ``get_module``. Those entries are presumably copied from the parsed
    arguments by the base ``TrainerConfig``; verify against the base class.
    """

    def __init__(self, args, context):
        super().__init__(args, context)
        shared = context.global_

        self.data = context.data
        self.tasker = context.tasker
        self.logger = shared.logger
        self.output_model_dir = shared.output_model_dir

        # Replace the module keys with the built objects. The model is moved
        # to the configured device before the optimizer factory is resolved.
        built_model = get_module(context, self.model)
        self.model = built_model.to(shared.device)
        self.optimizer = get_module(context, self.optimizer)

    @property
    def builder(self):
        """Trainer class that this configuration instantiates."""
        return PtTrainerPlain

    @classmethod
    def define_parser(cls, parser):
        """Register the plain-trainer command-line options on ``parser``."""
        super().define_parser(parser)
        add = parser.add_argument

        # Module keys, later resolved by get_module in __init__.
        add('--model', default='model')
        add('--optimizer', default='optimizer')

        # Training schedule. A --clip value <= 0 disables gradient clipping
        # in PtTrainerPlain.train_step.
        add('--n-epoches', type=int)
        add('--eval-every', type=int)
        add('--clip', type=float, default=-1,
            help='gradient clipping.')

        # Model selection on the validation split.
        add('--valid-watch', default='loss',
            help='metrics for determining the best model')
        add('--valid-watch-mode', default='min',
            choices=['min', 'max'])
        add('--no-watch', action='store_true')

class PtTrainerPlain(Trainer):
    """Plain PyTorch training loop with periodic validation and checkpointing.

    The best model, judged by ``valid_watch`` (or every validated model when
    ``no_watch`` is set), is written to ``<output_model_dir>/model.pt``. The
    same checkpoint is used to resume training and for the final test.

    Args:
        data: dataset handle. It is stored as ``self.data`` and not otherwise
            read in this class.
        tasker: runs the per-epoch loops through ``run_train_epoch``,
            ``run_eval_epoch`` and ``run_test_epoch``.
        model: the ``torch.nn.Module`` to train. ``model(data)`` is expected
            to return a dict with ``'loss'``, ``'losses'`` and ``'outputs'``.
        optimizer: factory called with the trainable parameters.
        n_epoches: number of epochs to run per call to :meth:`train`.
        eval_every: validate every this many epochs.
        clip: maximum gradient norm. Values ``<= 0`` disable clipping.
        logger: logging sink used through ``log``, ``log_csv`` and
            ``register``.
        output_model_dir: directory that holds ``model.pt``.
        valid_watch: ``'/'``-separated path into the validation outputs that
            names the model-selection metric.
        valid_watch_mode: ``'min'`` or ``'max'``.
        no_watch: if true, every validation counts as the best one.
    """

    def __init__(self, data, tasker, model, optimizer,
                 n_epoches, eval_every, clip, logger,
                 output_model_dir,
                 valid_watch, valid_watch_mode, no_watch):
        self.model = model
        self.valid_watch = valid_watch
        self.valid_watch_mode = valid_watch_mode
        self.no_watch = no_watch

        self.logger = logger

        # TensorBoard logging is hard-wired off; flip this to enable it.
        self.tblog = False

        if self.tblog:
            self.summary_writer = self.logger.register('tensorboard', builder=TBLogger)

        self.tasker = tasker
        self.data = data
        # Only parameters that require gradients are handed to the optimizer.
        self.optimizer = optimizer(
            p for p in self.model.parameters() if p.requires_grad)
        # NOTE(review): the scheduler is built but step() is never called in
        # this class, so it does not change the learning rate during training.
        self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=5, T_mult=2, eta_min=5e-6)

        self.n_epoches = n_epoches
        self.eval_every = eval_every

        self.clip = clip

        self.output_model_dir = output_model_dir

    def train_step(self, data, epoch):
        """Run one optimisation step on ``data``.

        Returns:
            dict with the scalar ``'loss'`` and the model's ``'losses'`` dict.
        """
        self.model.train()
        outputs = self.model(data)
        loss = outputs['loss']

        if self.tblog:
            self.summary_writer.add_scalar('losses/train/loss', loss, epoch)
            for n, l in outputs['losses'].items():
                self.summary_writer.add_scalar(f'losses/train/{n}', l, epoch)

        self.optimizer.zero_grad()
        loss.backward()
        if self.clip > 0:
            clip_grad_norm_(self.model.parameters(), self.clip)

        self.optimizer.step()

        return {'loss': loss.item(), 'losses': outputs['losses']}

    def _checkpoint_path(self):
        """Path of the single checkpoint file used for resume and test."""
        return os.path.join(self.output_model_dir, 'model.pt')

    def save_model(self, epoch):
        """Save the model weights and ``epoch`` to ``model.pt``, overwriting it."""
        # torch.save does not create missing parent directories.
        os.makedirs(self.output_model_dir, exist_ok=True)
        state = {'net': self.model.state_dict(), 'epoch': epoch}
        torch.save(state, self._checkpoint_path())

    def load_model(self):
        """Load weights from ``model.pt`` into the model.

        Keys in the checkpoint that the current model does not have are
        ignored. Model keys missing from the checkpoint keep their current
        values.

        Returns:
            The epoch stored in the checkpoint.

        Raises:
            FileNotFoundError: if no checkpoint has been saved yet.
        """
        # Map to CPU so that a checkpoint saved on another GPU, or loaded on
        # a CPU-only machine, still works. load_state_dict copies the values
        # onto the model's own device.
        checkpoint = torch.load(self._checkpoint_path(), map_location='cpu')

        model_dict = self.model.state_dict()
        pretrained_dict = {k: v for k, v in checkpoint['net'].items()
                           if k in model_dict}
        model_dict.update(pretrained_dict)

        self.model.load_state_dict(model_dict)
        return checkpoint['epoch']

    def on_train_end(self, tasker, outputs, epoch):
        """Aggregate the per-step training outputs of an epoch."""
        return merge_metrics(outputs)

    def is_best_model(self, metrics, mode='max'):
        """Return True if ``metrics`` is at least as good as the best so far.

        Ties count as improvements (``>=`` / ``<=``), so the latest of equally
        good models is kept. With ``no_watch``, every call returns True.

        Raises:
            ValueError: if ``mode`` is neither ``'max'`` nor ``'min'``.
        """
        if getattr(self, '_best_metrics', None) is None:
            # First validation: seed the running best with the current value.
            self._best_metrics = metrics
        if self.no_watch:
            return True
        if mode == 'max':
            improved = metrics >= self._best_metrics
        elif mode == 'min':
            improved = metrics <= self._best_metrics
        else:
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        if improved:
            self._best_metrics = metrics
            return True
        return False

    def valid_step(self, data, epoch):
        """Compute validation metrics and the loss for one batch, without gradients."""
        self.model.eval()

        if self.tblog:
            for name, param in self.model.named_parameters():
                self.summary_writer.add_histogram(
                    name.replace('.', '/'), param.detach().cpu().numpy(), epoch)

        with torch.no_grad():
            model_outputs = self.model(data)

            if self.tblog:
                self.summary_writer.add_scalar('losses/valid/loss', model_outputs['loss'], epoch)
                for n, l in model_outputs['losses'].items():
                    self.summary_writer.add_scalar(f'losses/valid/{n}', l, epoch)

        outputs = data.tasker.valid_metrics(data, model_outputs['outputs'])
        outputs['loss'] = model_outputs['loss']
        return outputs

    def on_valid_end(self, tasker, step_outputs, epoch):
        """Aggregate the per-step validation outputs through the tasker."""
        return tasker.valid_end(step_outputs)

    def test_step(self, data, epoch):
        """Compute test metrics for one batch, without gradients."""
        self.model.eval()
        with torch.no_grad():
            model_outputs = self.model(data)
        return data.tasker.test_metrics(data, model_outputs['outputs'])

    def on_test_end(self, tasker, step_outputs, epoch):
        """Aggregate the per-step test outputs through the tasker."""
        return tasker.test_end(step_outputs)

    def train(self):
        """Run ``n_epoches`` epochs, resuming from ``model.pt`` if present, then test."""
        try:
            start_epoch = self.load_model()
        except FileNotFoundError:
            # No checkpoint yet, so this is a fresh run.
            start_epoch = 0
        except Exception as err:
            # Resuming is best effort. An unreadable or incompatible checkpoint
            # falls back to training from scratch, but it is reported rather
            # than silently swallowed, and Ctrl-C is no longer caught here.
            self.logger.log('train', f'Could not resume from checkpoint: {err!r}')
            start_epoch = 0
        else:
            print(f'Training start from epoch {start_epoch}')

        # Pre-bind so that the final test call is valid when n_epoches == 0.
        epoch = start_epoch
        for epoch in range(start_epoch + 1, start_epoch + self.n_epoches + 1):
            outputs = self.tasker.run_train_epoch(
                train_step=self.train_step,
                on_train_end=self.on_train_end,
                epoch=epoch)
            self.logger.log('train', f'Epoch {epoch}: {outputs}')
            if epoch % self.eval_every == 0:
                self.logger.log('valid', f'Validating on epoch {epoch}...')
                self.evaluate(epoch)
        self.test(epoch)

    def get_output_item(self, outputs, path):
        """Look up a nested entry of ``outputs`` given a ``'a/b/c'`` key path."""
        for key in path.split('/'):
            outputs = outputs[key]
        return outputs

    def evaluate(self, epoch):
        """Validate, log the results and save the model if it is the best so far."""
        self.logger.log('note', f'>>>>>valid {epoch}:')
        outputs = self.tasker.run_eval_epoch(
            eval_step=self.valid_step,
            on_eval_end=self.on_valid_end,
            epoch=epoch)
        # Read the watched metric before 'loss' is popped, so that
        # valid_watch='loss' still works.
        metrics = self.get_output_item(outputs, self.valid_watch)
        loss = outputs.pop('loss')
        self.logger.log('valid', f'Valid loss: {loss}')
        if self.is_best_model(metrics, self.valid_watch_mode):
            self.save_model(epoch)
            self.logger.log('valid', f'Saved best model on epoch {epoch}.')
        self.logger.log('valid', outputs)

        if self.tblog:
            for k, v in flatten_dict({'valid': outputs}).items():
                self.summary_writer.add_scalar(k, v, epoch)

    def test(self, load_epoch=None):
        """Evaluate the saved best model on the test split and log the metrics.

        Args:
            load_epoch: epoch label used only when no checkpoint exists on
                disk. In that case the current in-memory weights are tested.
        """
        try:
            load_epoch = self.load_model()
        except FileNotFoundError:
            # No checkpoint was ever saved (for example, validation never ran).
            # Test the current weights instead of crashing after training.
            self.logger.log('test', 'No saved model found; testing current weights.')
        self.logger.log('note', f'>>>>>test {load_epoch}:')
        outputs = self.tasker.run_test_epoch(
            test_step=self.test_step,
            on_test_end=self.on_test_end,
            epoch=None)
        self.logger.log('test', f'Testing on epoch {load_epoch}...')
        self.logger.log('test', outputs)
        self.logger.log_csv('test_metrics', outputs)

        if self.tblog:
            for k, v in flatten_dict({'test': outputs}).items():
                self.summary_writer.add_scalar(k, v, load_epoch)

