from datetime import datetime

import pytorch_lightning as pl
import torch
import wandb

import gcip.utils.wandb_local as wandb_local
from gcip.utils.init import get_init_fn
from gcip.utils.utils import params_count

import gcip.utils.io as playbook_io
import torch.optim.lr_scheduler as t_lr

class BaseLightning(pl.LightningModule):
    """Shared LightningModule plumbing for gcip models.

    Handles parameter (re-)initialisation, epoch-level metric aggregation and
    logging (to ``wandb`` and to a local copy via ``wandb_local``),
    learning-rate scheduler stepping, and saving plots through the
    ``preparator``.

    This class reads, but never sets, the following attributes; presumably a
    subclass provides them (confirm against subclasses):
      * ``self.model`` (used by ``reset_parameters``),
      * ``self.input_scaler`` (used by ``on_fit_start`` / ``on_test_start``;
        see ``set_input_scaler``),
      * ``self.num_val_batches`` (used by ``global_val_step``).
    """

    def __init__(self, preparator, init_fn=None, plot=False):
        """
        Args:
            preparator: data helper. This class calls its ``on_start(device)``,
                ``compute_metrics(**metrics)`` and ``plot_data_batch(...)``
                methods and reads its ``type_of_data`` attribute.
            init_fn: optional callable applied to ``self.model`` and all of its
                submodules by ``reset_parameters`` (via ``Module.apply``).
            plot: when False, ``plot`` returns immediately without drawing.
        """
        super().__init__()
        self.preparator = preparator
        self.init_fn = init_fn
        self.plotting = plot
        self.model = None  # expected to be assigned by a subclass

        # Stored by set_optim_config / set_optim_config_2; presumably consumed
        # by a subclass's configure_optimizers — confirm.
        self.optim_config = None
        self.optim_config_2 = None

        # Most recent aggregated metrics, written by validation/test epoch end.
        self.metrics_stats = None
        self.ckpt_name = 'unknown'
        self.save_dir = None

    def get_now(self):
        """Return the current local time formatted as ``YYYY-MM-DD_HH-MM-SS``."""
        return datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    def reset_parameters(self):
        """Apply ``self.init_fn`` to ``self.model`` and its submodules, if set."""
        if self.init_fn is not None:
            self.model.apply(self.init_fn)

    def param_count(self):
        """Delegate to ``gcip.utils.utils.params_count`` for this module."""
        return params_count(self)

    def set_optim_config(self, config):
        """Store the configuration for the first optimizer."""
        self.optim_config = config

    def set_optim_config_2(self, config):
        """Store the configuration for the second optimizer."""
        self.optim_config_2 = config

    def on_fit_start(self) -> None:
        """Move the input scaler to this module's device and notify the preparator."""
        self.input_scaler.to(self.device)
        self.preparator.on_start(self.device)

    def on_test_start(self) -> None:
        """Move the input scaler to this module's device and notify the preparator."""
        self.input_scaler.to(self.device)
        self.preparator.on_start(self.device)

    def update_log_dict(self, log_dict, my_dict, key_id=''):
        """Copy tensors from ``my_dict`` into ``log_dict`` and log their means.

        For each entry, a list value is concatenated with ``torch.cat``, and a
        0-d tensor is promoted to shape ``(1,)``. The detached tensor is stored
        in ``log_dict`` under ``f"{key}{key_id}"``, and its float mean is logged
        to the progress bar under the same name.

        Args:
            log_dict: dict mutated in place.
            my_dict: mapping of name -> tensor or list of tensors.
            key_id: suffix appended to every key.
        """
        for key, value in my_dict.items():
            value_tensor = torch.cat(value) if isinstance(value, list) else value
            if value_tensor.ndim == 0:
                value_tensor = value_tensor.unsqueeze(0)

            my_key = f"{key}{key_id}"
            log_dict[my_key] = value_tensor.detach()
            self.log(my_key, log_dict[my_key].float().mean().item(),
                     prog_bar=True)

    def set_input_scaler(self):
        """Abstract hook; subclasses must implement it.

        Presumably it assigns ``self.input_scaler`` (read by ``on_fit_start`` /
        ``on_test_start``) — confirm against subclasses.
        """
        raise NotImplementedError

    def global_val_step(self, batch_idx):
        """Return a validation batch index that keeps increasing across epochs.

        Computed as ``num_val_batches * current_epoch + batch_idx``.
        ``self.num_val_batches`` is not set in this class.
        """
        return self.num_val_batches * self.current_epoch + batch_idx

    def add_noise(self, x, std):
        """Return standard-normal noise shaped like ``x``, scaled by ``std``.

        NOTE(review): despite the name, ``x`` itself is NOT added; only the
        noise is returned. Callers presumably add it themselves — confirm.
        """
        return torch.randn_like(x) * std

    def compute_metrics_stats(self, outputs):
        """Reduce a list of per-step output dicts into epoch-level metrics.

        Every tensor value is concatenated along dim 0 per key (0-d tensors
        are promoted to shape ``(1,)`` first). For each key accepted by
        ``__is_metric``, the mean is reported: bool tensors are cast to float
        first, and any tensor that is not float32 after that is skipped.
        The result is then updated with ``preparator.compute_metrics(**metrics)``,
        which receives all concatenated tensors (including non-metric keys
        such as logits), so its entries override same-named means.

        Args:
            outputs: iterable of dicts mapping name -> tensor.

        Returns:
            dict of metric name -> value.
        """
        # Collect chunks per key and concatenate once. Calling torch.cat inside
        # the loop re-copies the accumulated tensor for every step output,
        # which is quadratic in the number of steps.
        chunks = {}
        for output in outputs:
            for key, values in output.items():
                if values.ndim == 0:
                    values = values.unsqueeze(0)
                chunks.setdefault(key, []).append(values)
        metrics = {key: torch.cat(parts, dim=0) for key, parts in chunks.items()}

        metrics_stats = {}
        for metric, values in metrics.items():
            if not self.__is_metric(metric):
                continue
            if values.dtype == torch.bool:
                values = values.float()
            # Only float32 tensors are averaged (ints, doubles, halves are skipped).
            if values.dtype != torch.float:
                continue
            metrics_stats[metric] = values.mean().item()

        metrics_stats.update(self.preparator.compute_metrics(**metrics))
        return metrics_stats

    def __is_metric(self, metric):
        """Return False for logits/label/target keys, which are not averaged."""
        return metric not in ['logits', 'label', 'target'] and 'logits' not in metric

    def training_epoch_end(self, outputs) -> None:
        """Aggregate training outputs, record LRs, step schedulers, and log.

        Learning rates are read from the first param group of each optimizer
        before this epoch's scheduler step. They are stored as ``lr`` for a
        single optimizer, or ``lr_<i>`` for several.
        """
        metrics_stats = self.compute_metrics_stats(outputs)

        opt = self.optimizers()
        if isinstance(opt, list):
            for i, o in enumerate(opt):
                metrics_stats[f'lr_{i}'] = o.optimizer.param_groups[0]['lr']
        else:
            metrics_stats['lr'] = opt.optimizer.param_groups[0]['lr']

        output = {'train': metrics_stats, 'epoch': self.current_epoch}

        self.do_scheduler_step(sch=self.lr_schedulers(),
                               monitor=None,
                               epoch_type='train')

        wandb.log(output, step=self.current_epoch)
        wandb_local.log_v2(output, root=self.logger.save_dir)

    def validation_epoch_end(self, outputs):
        """Aggregate validation outputs, log them, and step plateau schedulers.

        This method:
          * stores ``{'val': stats, 'epoch': current_epoch}`` in
            ``self.metrics_stats``;
          * logs it to wandb and the local log;
          * steps any ``ReduceLROnPlateau`` scheduler on ``stats['loss']``
            (``KeyError`` if there is no ``'loss'`` entry);
          * logs every stat as ``val_<name>`` through Lightning at epoch level.
        """
        metrics_stats = self.compute_metrics_stats(outputs)

        output = {'val': metrics_stats, 'epoch': self.current_epoch}
        self.metrics_stats = output

        wandb.log(output, step=self.current_epoch)
        wandb_local.log_v2(output, root=self.logger.save_dir)

        self.do_scheduler_step(sch=self.lr_schedulers(),
                               monitor=metrics_stats['loss'],
                               epoch_type='val')

        for name, value in metrics_stats.items():
            self.log(f"val_{name}", value, on_step=False, on_epoch=True)

    def do_scheduler_step(self, sch, monitor, epoch_type):
        """Step the learning-rate scheduler(s) appropriate for ``epoch_type``.

        On ``'train'``, every scheduler except ``ReduceLROnPlateau`` is stepped
        with no argument. On ``'val'``, only ``ReduceLROnPlateau`` schedulers are
        stepped, with ``monitor``. Any other ``epoch_type`` does nothing.

        NOTE(review): if these schedulers are also registered for Lightning's
        automatic stepping, they may be stepped twice per epoch — confirm.

        Args:
            sch: a scheduler, a list of schedulers, or None.
            monitor: value passed to ``ReduceLROnPlateau.step`` on ``'val'``.
            epoch_type: ``'train'`` or ``'val'``.
        """
        if sch is None:
            schedulers = []
        elif isinstance(sch, list):
            schedulers = sch
        else:
            schedulers = [sch]

        for scheduler in schedulers:
            is_plateau = isinstance(scheduler, t_lr.ReduceLROnPlateau)
            if epoch_type == 'train' and not is_plateau:
                scheduler.step()
            elif epoch_type == 'val' and is_plateau:
                scheduler.step(monitor)

    def test_step(self, batch, batch_idx):
        """Default test step: returns an empty dict (no test outputs)."""
        log_dict = {}
        return log_dict

    def test_epoch_end(self, outputs):
        """Aggregate test outputs into ``self.metrics_stats``.

        Unlike ``validation_epoch_end``, the flat stats dict is stored, without
        a ``{'val': ..., 'epoch': ...}`` wrapper.
        """
        self.metrics_stats = self.compute_metrics_stats(outputs)

    def plot(self, name, batch, epoch, batch_idx, split, batch_size=None, title_elem_idx=1):
        """Save image(s) of ``batch`` via ``preparator.plot_data_batch``.

        Returns immediately when plotting is disabled. Images are written to
        the ``images`` sub-folder returned by ``wandb_local.sub_folder`` for the
        logger's save dir.

        When ``batch_size`` is None, it is chosen from
        ``preparator.type_of_data``: 16 for point_cloud/voxel/polar, otherwise 36.
        Point clouds then get one image per clamp value (0.8, 0.9), and voxels
        one image per threshold (0.5, 0.8). When the caller passes
        ``batch_size``, a single image is produced with no clamp or threshold.
        The filename encodes name, epoch, batch_idx, the clamp/threshold tag,
        split and a timestamp.
        """
        if not self.plotting:
            return

        clamps = None
        thresholds = None
        if batch_size is None:
            data_type = self.preparator.type_of_data
            if data_type == 'point_cloud':
                clamps = [0.8, 0.9]
                batch_size = 16
            elif data_type == 'voxel':
                batch_size = 16
                thresholds = [0.5, 0.8]
            elif data_type == 'polar':
                batch_size = 16
            else:
                batch_size = 36

        now = self.get_now()
        folder = wandb_local.sub_folder(self.logger.save_dir, 'images')

        def _tag(value):
            # Filename-friendly form of a float, e.g. 0.8 -> "08".
            return str(value).replace('.', '')

        # Each variant is (filename tag, extra keyword args for plot_data_batch).
        if isinstance(clamps, list):
            variants = [(f"--clamp={_tag(c)}", {'clamp': c}) for c in clamps]
        elif isinstance(thresholds, list):
            # NOTE(review): the keyword is spelled 'thresthold'. It is kept as-is
            # because plot_data_batch's signature is not visible here. If that
            # parameter is actually named 'threshold', this call raises
            # TypeError or the value is silently swallowed by **kwargs — confirm.
            variants = [(f"--threshold={_tag(th)}", {'thresthold': th}) for th in thresholds]
        else:
            variants = [('', {})]

        for tag, extra_kwargs in variants:
            filename = f"{name}--epoch={epoch}--batch_idx={batch_idx}{tag}--split={split}--now={now}.png"
            self.preparator.plot_data_batch(batch=batch,
                                            folder=folder,
                                            filename=filename,
                                            show=False,
                                            title_elem_idx=title_elem_idx,
                                            batch_size=batch_size,
                                            **extra_kwargs)