import torch

import hienet._keys as KEY
from hienet.error_recorder import ErrorRecorder
from hienet.train.loss import get_loss_functions_from_config
from hienet.train.optim import optim_dict, scheduler_dict
from hienet.train.DeNS_utils import add_gaussian_noise_to_position

import wandb
from torch_ema import ExponentialMovingAverage
import numpy as np

from hienet.error_recorder import ErrorRecorder
from hienet.hienet_logger import Logger

import lightning as L


class LightningModel(L.LightningModule):
    """Lightning wrapper that trains ``model`` using settings from a ``config`` dict.

    The optimizer, LR scheduler, loss terms and error recorders are all built
    from ``config``. Validation runs the wrapped model under exponential
    moving-average (EMA) weights kept by ``torch_ema``. Per-epoch train/valid
    metrics are written to a CSV file through ``Logger``, to Lightning's
    ``self.log`` and, on rank 0, to Weights & Biases when a W&B run is active.
    """

    def __init__(self, model, config: dict, experiment_name: str, init_csv: bool = True):
        """Build optimizer, scheduler, recorders and losses from ``config``.

        Args:
            model: Network to train. It must provide ``set_is_batch_data`` and
                be callable on a batch.
            config: Training configuration indexed by ``hienet._keys`` constants.
            experiment_name: Run name passed to ``wandb.init``. If its first
                ``.``-separated component is ``"testing"``, W&B is not started.
            init_csv: If True, write the CSV log header immediately.
        """
        super().__init__()
        self.distributed = config[KEY.IS_DDP]
        self.model = model
        self.model.set_is_batch_data(True)
        self.config = config
        self.experiment_name = experiment_name

        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        optimizer_cls = optim_dict[config[KEY.OPTIMIZER].lower()]
        self.optimizer = optimizer_cls(trainable_params, **config[KEY.OPTIM_PARAM])

        scheduler_cls = scheduler_dict[config[KEY.SCHEDULER].lower()]
        self.scheduler = scheduler_cls(self.optimizer, **config[KEY.SCHEDULER_PARAM])

        self.train_recorder = ErrorRecorder.from_config(config)
        self.valid_recorder = ErrorRecorder.from_config(config)
        self.best_metric = config[KEY.BEST_METRIC]

        self.use_denoising = config[KEY.USE_DENOISING]
        self.denoising_prob = 1.0

        # EMA of the model weights; created in on_fit_start. It stays None
        # until then so that checkpoint saving can tell whether it exists.
        self.ema = None

        # This should be outside of the trainer(?)
        # list of tuples (loss_definition, weight)
        self.loss_functions = get_loss_functions_from_config(config)

        self.csv_fname = config[KEY.CSV_LOG]
        if init_csv:
            self.init_csv()

    def configure_optimizers(self):
        """Return the optimizer and an epoch-interval scheduler.

        The scheduler monitors ``val``, which ``on_train_epoch_end`` logs.
        """
        return [self.optimizer], [{"scheduler": self.scheduler,
                                   "interval": "epoch", "monitor": "val"}]

    def init_csv(self):
        """Write the CSV header: epoch, LR, then one Train_/Valid_ column pair per metric."""
        csv_header = ['Epoch', 'Learning_rate']
        for metric in self.train_recorder.get_metric_dict().keys():
            csv_header += [f'Train_{metric}', f'Valid_{metric}']
        Logger().init_csv(self.csv_fname, csv_header)

    def forward(self, batch):
        """Run the wrapped model on ``batch``."""
        return self.model(batch)

    def on_fit_start(self):
        """Start a W&B run on rank 0 (unless this is a "testing" run) and create the EMA."""
        is_testing_run = self.experiment_name.split('.')[0] == "testing"
        if not is_testing_run and self.global_rank == 0:
            wandb.init(project="hienet", name=self.experiment_name, config=self.config)

        self.ema = ExponentialMovingAverage(
            self.model.parameters(), decay=self.config[KEY.EMA_DECAY]
        )

    def on_train_epoch_start(self):
        """Start the epoch timer and print an epoch banner with the current LR."""
        epoch = self.current_epoch
        last_epoch = self.trainer.max_epochs - 1
        lr = self.get_lr()
        Logger().timer_start('epoch')
        Logger().bar()
        # ".8f" (precision), not "8f" (a no-op width): with the default six
        # decimals, small learning rates would print as 0.000000.
        Logger().write(f'Epoch {epoch}/{last_epoch}  lr: {lr:.8f}\n')
        Logger().bar()

    def _maybe_add_noise(self, batch):
        """Return ``batch`` with Gaussian position noise (std 0.05) when denoising fires.

        Noise is added when ``use_denoising`` is set and a uniform draw from
        ``np.random.rand()`` is below ``denoising_prob``. Otherwise ``batch``
        is returned unchanged, and no random number is drawn when denoising is
        disabled.
        """
        if self.use_denoising and np.random.rand() < self.denoising_prob:
            batch = add_gaussian_noise_to_position(batch, std=0.05)
        return batch

    def training_step(self, batch, batch_idx):
        """Forward one (optionally noised) batch and return the weighted loss sum.

        Returns:
            A 1-element tensor on ``self.device``, equal to the sum of
            ``loss_def.get_loss(output, model) * weight`` over ``loss_functions``.
        """
        batch = self._maybe_add_noise(batch)
        output = self.model(batch)
        self.train_recorder.update(output)
        total_loss = torch.tensor([0.0], device=self.device)
        for loss_def, weight in self.loss_functions:
            total_loss += loss_def.get_loss(output, self.model) * weight
        return total_loss

    def on_train_batch_end(self, losses, batch, batch_idx):
        """Update the EMA shadow weights after every optimizer step."""
        self.ema.update()

    def on_validation_start(self):
        """Re-enable autograd for the validation loop."""
        # NOTE(review): presumably needed because the model differentiates
        # through its inputs (e.g. forces). This cannot undo
        # torch.inference_mode, so the Trainer presumably runs with
        # inference_mode=False -- confirm.
        torch.set_grad_enabled(True)

    def validation_step(self, batch, batch_idx):
        """Evaluate one batch with EMA weights and record its errors.

        Returns:
            The model output for the batch.
        """
        # NOTE(review): validation batches are noised exactly like training
        # batches (denoising_prob is 1.0). Validation metrics, and the "val"
        # value the scheduler monitors, are therefore measured on perturbed
        # positions. Confirm this is intended.
        batch = self._maybe_add_noise(batch)
        with self.ema.average_parameters():
            output = self.model(batch)
            # Lightning's sanity-check pass runs validation before the first
            # training epoch. Recording it would mix those batches into
            # epoch 0's validation metrics.
            if not self.trainer.sanity_checking:
                self.valid_recorder.update(output)
        return output

    def on_train_epoch_end(self):
        """Finalize epoch metrics, log them to all sinks, and log ``val``.

        Train/valid recorders are all-reduced under DDP before
        ``epoch_forward``. Results go to Lightning's logger, the CSV file, the
        ``Logger`` table and, on rank 0, W&B. ``val`` is the first validation
        metric whose name contains ``best_metric``.
        """
        if self.distributed:
            self.recorder_all_reduce(self.train_recorder)
            self.recorder_all_reduce(self.valid_recorder)
        train_err = self.train_recorder.epoch_forward()
        valid_err = self.valid_recorder.epoch_forward()

        epoch = self.current_epoch
        lr = self.get_lr()
        self.log('lr', lr, sync_dist=self.distributed)

        if self.global_rank == 0:
            self._log_epoch_to_wandb(epoch, lr, train_err, valid_err)

        csv_values = [epoch, lr]
        for metric in train_err:
            csv_values += [train_err[metric], valid_err[metric]]
            # Lightning metric names are kept ASCII; non-ASCII chars become '?'.
            a_metric = metric.encode('ascii', 'replace').decode('ascii')
            self.log('train_' + a_metric, train_err[metric], sync_dist=self.distributed)
            self.log('valid_' + a_metric, valid_err[metric], sync_dist=self.distributed)
        Logger().append_csv(self.csv_fname, csv_values)

        Logger().write_full_table([train_err, valid_err], ['Train', 'Valid'])

        # Monitored by the scheduler configured in configure_optimizers.
        self.log('val', self._best_metric_value(valid_err), sync_dist=self.distributed)

        Logger().timer_end('epoch', message=f'Epoch {epoch} elapsed')

    def _log_epoch_to_wandb(self, epoch, lr, train_err, valid_err):
        """Send one epoch's LR, parameter/gradient norms and metrics to W&B in one call.

        If no W&B run is active, print a notice instead. Logging is
        best-effort: failures are reported but never abort training.
        """
        if wandb.run is None:
            print(f"\n Wandb was not initialized, but epoch {epoch} was successfully finished")
            return
        try:
            payload = {"epoch": epoch, "lr": lr}
            for name, param in self.model.named_parameters():
                if param.grad is not None:
                    payload[f"gradients/{name}"] = param.grad.norm().item()
                payload[f"weights/{name}"] = param.norm().item()
            for metric in train_err:
                payload["train_" + metric] = train_err[metric]
                payload["valid_" + metric] = valid_err[metric]
            wandb.log(payload, step=epoch)
        except Exception as exc:  # boundary: W&B problems must not stop training
            print(f"\n Wandb logging failed at epoch {epoch}: {exc}")

    def _best_metric_value(self, valid_err):
        """Return the first validation value whose name contains ``best_metric``.

        Matching is a loose substring test, e.g. "Energy" matches
        "TotalEnergy" or "Energy_Loss".

        Raises:
            AssertionError: If no metric name contains ``best_metric``.
        """
        val = next(
            (valid_err[name] for name in valid_err if self.best_metric in name),
            None,
        )
        assert (
            val is not None
        ), f'Metric {self.best_metric} not found in {valid_err}'
        return val

    def lr_scheduler_step(self, scheduler, metric):
        """Step ``scheduler``, passing ``metric`` only to ReduceLROnPlateau."""
        if scheduler is None:
            return
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(metric)
        else:
            scheduler.step()

    def get_lr(self):
        """Return the learning rate of the optimizer's first parameter group."""
        return self.optimizer.param_groups[0]['lr']

    def recorder_all_reduce(self, recorder: ErrorRecorder):
        """All-reduce every metric of ``recorder`` across DDP ranks on ``self.device``."""
        for metric in recorder.metrics:
            metric.ddp_reduce(self.device)

    def on_save_checkpoint(self, checkpoint):
        """Add config, epoch, and model/optimizer/scheduler (and EMA) states to ``checkpoint``."""
        checkpoint.update({
            'config': self.config,
            'epoch': self.current_epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
        })
        if self.ema is not None:
            # Validation (and therefore the monitored "val") uses the EMA
            # weights, so keep them with the checkpoint. on_load_checkpoint
            # does not restore them.
            checkpoint['ema_state_dict'] = self.ema.state_dict()

    def on_load_checkpoint(self, checkpoint):
        """Restore the scheduler state stored by ``on_save_checkpoint``."""
        self.load_state_dicts(
            checkpoint['model_state_dict'],
            checkpoint['optimizer_state_dict'],
            checkpoint['scheduler_state_dict'],
            strict=False,
        )

    def load_state_dicts(
        self,
        model_state_dict,
        optimizer_state_dict,
        scheduler_state_dict,
        strict=True,
    ):
        """Load saved states into this module.

        Only ``scheduler_state_dict`` is applied (when not None).
        ``model_state_dict``, ``optimizer_state_dict`` and ``strict`` are
        accepted for interface compatibility but currently ignored: loading of
        the model and optimizer was disabled, presumably because Lightning
        restores them from its own checkpoint keys -- confirm before
        re-enabling.
        """
        if scheduler_state_dict is not None:
            self.scheduler.load_state_dict(scheduler_state_dict)
