import os.path
from abc import ABC, abstractmethod

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchmetrics as tm
import wandb
from torchvision.utils import save_image

import gcip.utils.io as playbook_io
import gcip.utils.preparator.utils as pb_prep
import gcip.utils.wandb_local as wandb_local
from gcip.utils.constants import Cte
from gcip.utils.entropy import *
from gcip.utils.exceptions import WrongLoss
from gcip.utils.io import dict_to_cn
from gcip.utils.io import makedirs
from torchlikelihoods import scalers_dict, HeterogeneousScaler, HeterogeneousObjectScaler


class BaseDatasetPreparator(ABC):
    """Abstract base class that loads, splits, scales and serves a dataset.

    Subclasses supply the dataset-specific pieces (raw loading, splitting,
    data loaders, plotting, metric names, label extraction, ...) through the
    abstract methods below. This base class implements the shared plumbing:
    train/val/test bookkeeping, optional oversampling and class weights,
    scaler construction, loss/activation selection, classification metrics
    and figure saving.
    """

    def __init__(self, name,
                 splits,
                 shuffle_train,
                 single_split,
                 task,
                 k_fold,
                 root,
                 loss,
                 scale,
                 use_weight,
                 include_idx,
                 balance=None,
                 device='cpu'):
        """Store the configuration and create ``root`` if it does not exist.

        Args:
            name: Dataset name.
            splits: Fractions for the train/val/test splits; must sum to 1.
            shuffle_train: Default ``shuffle`` flag for the training loader.
            single_split: If one of ``'train'``, ``'val'``, ``'test'``, every
                split is replaced by that split in :meth:`prepare_data`.
            task: Stored as ``self.task``.
            k_fold: Stored as ``self.k_fold``.
            root: Directory for dataset files; created if missing.
            loss: Loss identifier, resolved through the subclass' :meth:`_loss`.
            scale: Scaling option; when it is a string, :meth:`get_scaler_info`
                drives scaler construction in :meth:`_get_scaler`.
            use_weight: If true, class weights are computed from the training
                split in :meth:`prepare_data`.
            include_idx: Stored as ``self.include_idx``.
            balance: ``'oversample'`` to rebalance the training split.
            device: Device used for metrics and class weights.
        """
        self.name = name

        self.include_idx = include_idx

        self.split = splits
        self.split_names = ['train', 'val', 'test']
        self.current_split = None
        assert np.isclose(sum(splits), 1.0), f"Splits: {splits} {sum(splits)}"
        self.shuffle_train = shuffle_train
        self.single_split = single_split
        self.task = task
        self.k_fold = k_fold

        self.scale = scale
        self.scaler = None
        self.loss = self._loss(loss)

        self.datasets = None

        self.device = device

        if not os.path.exists(root):
            makedirs(root)

        self.root = root

        self.weight = None
        self.use_weight = use_weight
        self.balance = balance

    @classmethod
    def params(cls, dataset):
        """Extract constructor keyword arguments from a dataset config.

        Args:
            dataset: A config object exposing the attributes below, or a dict
                (converted with ``dict_to_cn``).

        Returns:
            dict with the keys ``splits``, ``k_fold``, ``shuffle_train``,
            ``single_split``, ``include_idx``, ``loss``, ``root``, ``scale``,
            ``use_weight`` and ``balance``.
        """

        if isinstance(dataset, dict):
            dataset = dict_to_cn(dataset)

        return {
            'splits': dataset.splits,
            'k_fold': dataset.k_fold,
            'shuffle_train': dataset.shuffle_train,
            'single_split': dataset.single_split,
            'include_idx': dataset.include_idx,
            'loss': dataset.loss,
            'root': dataset.root,
            'scale': dataset.scale,
            'use_weight': dataset.use_weight,
            'balance': dataset.balance
        }

    @property
    def type_of_data(self):
        """Identifier of the data type; ``'default'`` in the base class."""
        return 'default'

    @property
    def dims_scaler(self):
        """``dims`` forwarded to ``scaler.fit``; ``None`` in the base class."""
        return None  # single scalar per scaler parameter

    # Abstract methods
    @abstractmethod
    def _batch_element_list(self):
        """Return the names of the dataset-specific elements of a batch.

        They are matched positionally against the batch tuple in
        :meth:`get_batch_elements`, before the generic names from
        :meth:`batch_element_list`.
        """

    @abstractmethod
    def _data_loader(self, dataset, batch_size, shuffle, num_workers=0):
        """Return an iterable of batches over ``dataset``."""

    @abstractmethod
    def _get_dataset_raw(self):
        """Return the raw dataset consumed by :meth:`_split_dataset`."""

    @abstractmethod
    def _metric_names(self):
        """Return the container of metric names used by :meth:`compute_metrics`."""

    @abstractmethod
    def _plot_data(self, batch, **kwargs):
        """Plot ``batch`` and return a matplotlib figure or an image tensor."""

    @abstractmethod
    def _split_dataset(self, dataset_raw):
        """Split the raw dataset into a ``list`` of [train, val, test] datasets."""

    @abstractmethod
    def _x_dim(self):
        """Return the input dimension."""

    @abstractmethod
    def compute_psnr(self, x, x_recons):
        """Return the PSNR between ``x`` and its reconstruction ``x_recons``."""

    @abstractmethod
    def dim_coordinates(self):
        """Return the dimension of the coordinates."""

    @abstractmethod
    def dim_features(self):
        """Return the dimension of the features."""

    @abstractmethod
    def get_dataset_train(self):
        """Return the training dataset."""

    @abstractmethod
    def get_features_train(self):
        """Return the training features the scaler is fitted on."""

    @abstractmethod
    def get_scaler_info(self):
        """Return scaler specs: a list of ``(name, size)`` pairs or a dict of them."""

    @abstractmethod
    def get_y_from_dataset(self, dataset):
        """Return the label tensor of ``dataset``."""

    @abstractmethod
    def num_samples(self):
        """Return the number of samples."""

    @abstractmethod
    def label_dim(self):
        """Return the label dimension (an ``int``)."""

    @abstractmethod
    def _loss(self, loss):
        """Resolve/validate the ``loss`` identifier given to the constructor."""

    # Not implemented methods

    def data_converter(self):
        """Optional hook; not implemented in the base class."""
        raise NotImplementedError

    def _get_target(self, batch):
        """Extract the raw target from ``batch``; must be overridden to use targets."""
        raise NotImplementedError

    def get_target(self, batch, dtype=None):
        """Extract the target from ``batch`` and cast it for the loss.

        For cross-entropy the target is flattened first.

        Args:
            batch: A batch as produced by the data loaders.
            dtype: ``'float'`` or ``'long'`` to force a cast. When not a string,
                the cast follows the loss: float for BCE-with-logits, long for
                cross-entropy, unchanged otherwise.

        Raises:
            NotImplementedError: If ``dtype`` is any other string.
        """
        target = self._get_target(batch)

        if self.loss in [Cte.CE]:
            target = target.flatten()

        if isinstance(dtype, str):
            if dtype == 'float':
                return target.float()
            elif dtype == 'long':
                return target.long()
            else:
                raise NotImplementedError
        else:
            if self.loss in [Cte.BCELOGITS]:
                return target.float()
            elif self.loss in [Cte.CE]:
                return target.long()
            else:
                return target

    def grid_resolution(self, multiplier):
        """Optional hook; not implemented in the base class."""
        raise NotImplementedError

    def monitor(self):
        """Return the ``(metric, mode)`` pair to monitor."""
        return 'val_loss', 'min'

    # Implemented methods

    def non_linearity(self):
        """Return the output non-linearity of the configured scaler."""

        scaler = self._get_scaler()
        if scaler is None:
            return nn.Identity()
        else:
            return scaler.non_linearity()

    def batch_element_list(self):
        """Generic batch element names appended after the dataset-specific ones."""
        return ['idx', 'modulation']

    def get_batch_elements(self, batch, elements):
        """Select named elements from a positional ``batch``.

        Args:
            batch: Indexable batch whose positions follow
                ``_batch_element_list() + batch_element_list()``.
            elements: Names of the elements to select, in output order.

        Returns:
            list of the selected batch entries.

        Raises:
            ValueError: If a requested name is unknown.
        """
        # Build a fresh list: extending the list returned by
        # ``_batch_element_list`` in place would mutate it if a subclass
        # returns a shared (e.g. class-level) list.
        batch_elements = [*self._batch_element_list(), *self.batch_element_list()]

        return [batch[batch_elements.index(el)] for el in elements]

    @torch.no_grad()
    def on_start(self, device):
        """Record the device to use from now on."""

        self.device = device

    def _transform_dataset_pre_split(self, dataset_raw):
        """Hook applied to the raw dataset before splitting; identity here."""
        return dataset_raw

    def x_dim(self, use_modulation=False):
        """Return the input dimension; ``use_modulation`` is ignored in the base class."""
        return self._x_dim()

    def prepare_data(self):
        """Load, split and post-process the dataset into ``self.datasets``.

        Steps: load the raw data, apply :meth:`_transform_dataset_pre_split`,
        split it, optionally oversample the training split
        (``balance == 'oversample'``), optionally replace every split by
        ``single_split``, and finally apply :meth:`_transform_after_split`
        (which computes class weights when ``use_weight`` is set).
        """
        dataset_raw = self._get_dataset_raw()
        dataset_raw = self._transform_dataset_pre_split(dataset_raw=dataset_raw)
        datasets = self._split_dataset(dataset_raw)

        if self.balance == 'oversample':
            y = self.get_y_from_dataset(datasets[0])

            if y.ndim == 2:
                assert y.shape[-1] == 1
            # NOTE(review): ``.numpy()`` assumes ``y`` lives on the CPU.
            y_np = y.flatten().numpy()
            idx_balanced = pb_prep.balance_dataset_indexes(y_np)
            # Assumes the training dataset supports index-array selection.
            datasets[0] = datasets[0][idx_balanced]

        if self.single_split in self.split_names:
            idx = self.split_names.index(self.single_split)
            for i in range(len(datasets)):
                if i != idx:
                    datasets[i] = datasets[idx]
        datasets = self._transform_after_split(datasets)
        self.datasets = datasets
        return

    def set_current_split(self, i):
        """Set ``current_split`` to ``single_split`` (if a str) or the i-th split name."""
        if isinstance(self.single_split, str):
            self.current_split = self.single_split
        else:
            self.current_split = self.split_names[i]

    def _set_weights(self, dataset_tr):
        """Compute loss class weights from the training labels into ``self.weight``.

        * BCE-with-logits with ``label_dim() == 1``: scalar
          ``#negatives / #positives`` (used as ``pos_weight``).
        * Cross-entropy: per-class ``#not-class / #class`` vector.

        Raises:
            ValueError: For BCE when the training split has no positive sample.
            NotImplementedError: For any other loss/label configuration.
        """

        y = self.get_y_from_dataset(dataset_tr)
        if self.label_dim() == 1 and self.loss in [Cte.BCELOGITS]:
            num_pos = (y == 1).sum().item()
            num_neg = (y == 0).sum().item()
            if num_pos == 0:
                raise ValueError("Cannot compute pos_weight: the training split has no positive samples.")
            self.weight = torch.tensor(num_neg / num_pos, device=self.device)
        elif self.loss in [Cte.CE]:
            y_one_hot = torch.nn.functional.one_hot(y, num_classes=self.label_dim())
            num_pos = (y_one_hot == 1).sum(0)
            num_neg = (y_one_hot == 0).sum(0)
            if (num_pos == 0).any():
                # Classes absent from the training split get an infinite weight.
                playbook_io.print_warning(f"Some classes have no training samples: counts={num_pos.flatten().tolist()}")
            # Keep the weight on the configured device, like the BCE branch.
            self.weight = (num_neg / num_pos).flatten().to(self.device)
        else:
            raise NotImplementedError

    def _transform_after_split(self, datasets):
        """Hook applied after splitting; computes class weights if ``use_weight``."""
        if self.use_weight:
            self._set_weights(datasets[0])

        return datasets

    def _get_scaler_list(self, scalers_info):
        """Build a scaler from a list of ``(name, size)`` pairs.

        A single entry yields that scaler directly; several entries are combined
        into a ``HeterogeneousScaler`` with the given sizes as splits.
        """
        if len(scalers_info) == 1:
            name, _ = scalers_info[0]
            return scalers_dict[name]()

        scalers = [scalers_dict[name]() for name, _ in scalers_info]
        splits = [size for _, size in scalers_info]
        return HeterogeneousScaler(scalers, splits)

    def _get_scaler(self):
        """Build (without fitting) the scaler described by :meth:`get_scaler_info`.

        ``get_scaler_info`` is only consulted when ``self.scale`` is a string;
        otherwise an identity scaler is returned. A list spec becomes a
        (possibly heterogeneous) scaler; a dict spec becomes a
        ``HeterogeneousObjectScaler`` whose values are either a string (kept
        as-is), a list spec, or a ``(name, size)`` pair.
        """
        if isinstance(self.scale, str):
            scalers_info = self.get_scaler_info()
        else:
            scalers_info = None

        if isinstance(scalers_info, list):
            scaler = self._get_scaler_list(scalers_info=scalers_info)

        elif isinstance(scalers_info, dict):
            sca_info = {}
            for attr_name, value in scalers_info.items():
                if isinstance(value, str):
                    sca_info[attr_name] = value  # Reuse scaler from attribute value
                elif isinstance(value, list):
                    sca_info[attr_name] = self._get_scaler_list(scalers_info=value)
                else:
                    # ``value`` is a ``(scaler_name, domain_size)`` pair; only the name is used here.
                    sca_info[attr_name] = scalers_dict[value[0]]()
            scaler = HeterogeneousObjectScaler(scalers_dict=sca_info)

        else:
            scaler = scalers_dict['identity']()

        return scaler

    def get_scaler(self, fit=True):
        """Build the scaler, optionally fit it on the training features, and store it.

        Args:
            fit: If true, fit on :meth:`get_features_train` with ``dims_scaler``.

        Returns:
            The scaler (also stored in ``self.scaler``).
        """

        dims = self.dims_scaler

        scaler = self._get_scaler()

        if fit:
            x = self.get_features_train()
            scaler.fit(x, dims=dims)

        self.scaler = scaler

        return scaler

    def __compute_metric(self, fn, fn_name, metrics, predictions_dict, targets_dict):
        """Evaluate metric ``fn`` for every prediction/target pair into ``metrics``."""
        fn = fn.to(self.device)
        for name in predictions_dict.keys():
            preds = predictions_dict[name]
            target = targets_dict[name]
            metrics[f'{fn_name}{name}'] = fn(preds=preds, target=target)

    def compute_metrics(self, **kwargs):
        """Compute the classification metrics listed by :meth:`_metric_names`.

        Keyword arguments whose name contains ``'logits'`` are turned into hard
        predictions; those containing ``'target'`` are used as targets. Pairs
        are matched on the name with that substring removed (e.g.
        ``logits_val`` / ``target_val`` → suffix ``_val``).

        Returns:
            dict mapping ``f'{metric}{suffix}'`` to a Python scalar.

        Raises:
            NotImplementedError: If ``'mae'`` is requested.
        """

        predictions_dict = {}
        targets_dict = {}
        logits_to_hard_pred_fn = self.get_logits_to_hard_pred_fn(keepdim=False)

        for name, values in kwargs.items():
            if 'logits' in name:
                predictions_dict[name.replace('logits', '')] = logits_to_hard_pred_fn(values)
            if 'target' in name:
                targets_dict[name.replace('target', '')] = values

        preds_target_provided = len(predictions_dict) > 0 and len(targets_dict) > 0
        metric_names = self._metric_names()
        metrics = {}

        clf_requested = [m for m in ('accuracy', 'precision', 'recall') if m in metric_names]
        if preds_target_provided and clf_requested:
            task = self.get_clf_task()
            num_classes = self.get_num_classes()
            # torchmetrics requires ``num_labels`` (not ``num_classes``) for multilabel tasks.
            num_labels = self.label_dim() if task == 'multilabel' else None

            if 'accuracy' in metric_names:
                accuracy = tm.Accuracy(task=task,
                                       num_classes=num_classes,
                                       num_labels=num_labels)
                self.__compute_metric(accuracy, 'accuracy', metrics, predictions_dict, targets_dict)
            if 'precision' in metric_names:
                precision = tm.Precision(task=task,
                                         threshold=0.5,
                                         num_classes=num_classes,
                                         num_labels=num_labels,
                                         average='micro',
                                         multidim_average='global',
                                         top_k=1)
                self.__compute_metric(precision, 'precision', metrics, predictions_dict, targets_dict)
            if 'recall' in metric_names:
                recall = tm.Recall(task=task,
                                   threshold=0.5,
                                   num_classes=num_classes,
                                   num_labels=num_labels,
                                   average='micro',
                                   multidim_average='global',
                                   top_k=1)
                self.__compute_metric(recall, 'recall', metrics, predictions_dict, targets_dict)

        if 'mae' in metric_names and preds_target_provided:
            raise NotImplementedError("The 'mae' metric is not implemented.")

        return {key: value.item() for key, value in metrics.items()}

    def get_clf_task(self):
        """Return the torchmetrics task: ``'binary'``, ``'multiclass'`` or ``'multilabel'``.

        Raises:
            NotImplementedError: For unsupported loss/label-dimension combinations.
        """
        if self.label_dim() == 1 and self.loss in [Cte.BCELOGITS]:
            task = 'binary'
        elif self.label_dim() > 1 and self.loss in [Cte.CE]:
            task = 'multiclass'
        elif self.label_dim() > 1 and self.loss in [Cte.BCELOGITS]:
            task = 'multilabel'
        else:
            raise NotImplementedError
        return task

    def get_num_classes(self):
        """Return the number of classes for the metric; mirrors :meth:`get_clf_task`.

        Raises:
            NotImplementedError: For unsupported loss/label-dimension combinations.
        """
        if self.label_dim() == 1 and self.loss in [Cte.BCELOGITS]:
            return 2
        elif self.label_dim() > 1 and self.loss in [Cte.CE]:
            return self.label_dim()
        elif self.label_dim() > 1 and self.loss in [Cte.BCELOGITS]:
            return 2
        else:
            raise NotImplementedError

    def get_ckpt_name(self, ckpt_file):
        """Return a short name for a checkpoint file.

        The extension-less base name is returned, or ``ckpt_{epoch}`` when the
        name contains ``'epoch'`` (parsed as ``-``-separated key/value pairs).
        """
        ckpt_name = os.path.splitext(os.path.basename(ckpt_file))[0]
        if 'epoch' in ckpt_name:
            ckpt_dict = wandb_local.str_to_dict(my_str=ckpt_name,
                                                sep='-',
                                                remove_ext=False)
            ckpt_name = f"ckpt_{ckpt_dict['epoch']}"

        return ckpt_name

    def get_modulations_folder(self, root, ckpt_name, batch_size, split):
        """Return ``root/modulations/<ckpt_name>/<batch_size>/<split>``.

        If ``batch_size`` is None, the single sub-directory of the checkpoint
        folder is used.

        Raises:
            NotImplementedError: If ``batch_size`` is None and there is not
                exactly one sub-directory.
        """
        ckpt_folder = os.path.join(root, 'modulations', ckpt_name)
        if batch_size is None:
            batch_size_list = [f for f in os.listdir(ckpt_folder) if os.path.isdir(os.path.join(ckpt_folder, f))]
            if len(batch_size_list) != 1:
                playbook_io.print_warning(f"batch_size_list: {batch_size_list} {ckpt_folder}")
                raise NotImplementedError
            batch_size = batch_size_list[0]
        return os.path.join(ckpt_folder, str(batch_size), split)

    def get_dataloader_train(self, batch_size, num_workers=0, shuffle=None):
        """Return the training loader; ``shuffle`` defaults to ``shuffle_train``."""
        assert isinstance(self.datasets, list)

        dataset = self.datasets[0]
        shuffle = self.shuffle_train if shuffle is None else shuffle
        return self._data_loader(dataset, batch_size, shuffle=shuffle, num_workers=num_workers)

    def get_dataloaders(self, batch_size, num_workers=0):
        """Return ``[train_loader, *other_loaders]``; only the train loader may shuffle."""
        assert isinstance(self.datasets, list)
        loaders = [self.get_dataloader_train(batch_size, num_workers)]
        for dataset in self.datasets[1:]:
            loaders.append(self._data_loader(dataset, batch_size,
                                             shuffle=False,
                                             num_workers=num_workers))
        return loaders

    def get_loss_fn(self):
        """Return the unreduced loss module for the configured loss.

        Class weights from :meth:`_set_weights` (if any) are passed as
        ``pos_weight`` (BCE) or ``weight`` (CE).

        Raises:
            WrongLoss: For an unknown loss.
        """
        if self.loss == Cte.BCELOGITS:
            loss = nn.BCEWithLogitsLoss(reduction='none', pos_weight=self.weight)
        elif self.loss == Cte.CE:
            loss = nn.CrossEntropyLoss(reduction='none', weight=self.weight)
        elif self.loss == Cte.FORWARD:
            loss = lambda x: x
        else:
            raise WrongLoss
        return loss

    def get_output_act_fn(self):
        """Return the output activation (sigmoid / softmax / identity).

        Raises:
            WrongLoss: For an unknown loss.
        """
        if self.loss == Cte.BCELOGITS:
            return torch.nn.Sigmoid()
        elif self.loss == Cte.CE:
            return torch.nn.Softmax(dim=-1)
        elif self.loss == Cte.FORWARD:
            return torch.nn.Identity()
        else:
            raise WrongLoss

    def get_entropy(self):
        """Return the entropy object matching the loss.

        Raises:
            NotImplementedError: For the forward loss.
            WrongLoss: For an unknown loss.
        """
        if self.loss == Cte.BCELOGITS:
            return BinaryEntropy()
        elif self.loss == Cte.CE:
            return CategoricalEntropy()
        elif self.loss == Cte.FORWARD:
            raise NotImplementedError
        else:
            raise WrongLoss

    def get_logits_to_hard_pred_fn(self, keepdim=True):
        """Return a function mapping logits to hard predictions.

        BCE: ``sigmoid(logits) > 0.5`` as int. CE: ``argmax`` of the softmax on
        the last dim (``keepdim`` controls the kept dimension). Forward: the
        identity-activated logits.

        Raises:
            WrongLoss: For an unknown loss.
        """
        act_fn = self.get_output_act_fn()
        if self.loss == Cte.BCELOGITS:
            def my_fn(logits):
                pred = act_fn(logits)
                return (pred > 0.5).int()
        elif self.loss == Cte.CE:
            def my_fn(logits):
                pred = act_fn(logits)
                return torch.argmax(pred, dim=-1, keepdim=keepdim)
        elif self.loss == Cte.FORWARD:
            def my_fn(logits):
                pred = act_fn(logits)
                return pred
        else:
            raise WrongLoss

        return my_fn

    def plot_data(self, split='train',
                  num_samples=1,
                  shuffle=False,
                  batch_idx=0,
                  folder=None,
                  filename=None,
                  show=False,
                  **kwargs):
        """Plot batch ``batch_idx`` of ``split`` (batches of ``num_samples``).

        If the loader has fewer batches than requested, a warning is printed
        and the last batch is used.

        Raises:
            ValueError: If the loader yields no batch at all.
        """

        loader = self._data_loader(dataset=self.datasets[self.split_names.index(split)],
                                   batch_size=num_samples,
                                   shuffle=shuffle,
                                   num_workers=0)

        missing = object()
        batch = missing
        for i, batch in enumerate(loader):
            if i == batch_idx:
                break
        else:
            # Loop exhausted without reaching ``batch_idx``.
            if batch is missing:
                raise ValueError(f"No batches available for split '{split}'.")
            playbook_io.print_warning(f"batch_idx={batch_idx} out of range for split '{split}'; using the last batch.")

        return self.plot_data_batch(batch, folder=folder, filename=filename, show=show, **kwargs)

    def plot_data_batch(self, batch, folder=None, filename=None, show=False, **kwargs):
        """Plot ``batch`` via :meth:`_plot_data`, then save/show/close it."""

        fig = self._plot_data(batch, **kwargs)
        if not isinstance(fig, torch.Tensor):
            plt.tight_layout(rect=[0, 0.03, 1, 0.95])

        self.post_plotting(fig, folder=folder, filename=filename, show=show)

        return fig

    def post_plotting(self, fig, folder=None, filename=None, show=False):
        """Save ``fig`` when both ``folder`` and ``filename`` are given, then show or close all figures."""
        # Tensor "figures" are images, not matplotlib figures: do not lay out
        # (and thereby create) a stray current figure for them.
        if not isinstance(fig, torch.Tensor):
            plt.tight_layout()
        if folder and filename:
            self.save_fig(fig, folder, filename)
        if show:
            plt.show()
        else:
            plt.close('all')

    def get_number_of_rows_and_cols(self, num_samples, batch_size):
        """Return ``(num_samples, nrow, ncol)`` for a plotting grid.

        A tuple ``batch_size`` fixes the grid as ``(nrow, ncol)``; an int caps
        ``num_samples``. Otherwise the grid is roughly square (1 row for 2 samples).
        """
        if isinstance(batch_size, tuple):
            batch_size_total = np.prod(batch_size)
            num_samples = min(num_samples, batch_size_total)
            ncol = batch_size[1]
            nrow = batch_size[0]
        elif isinstance(batch_size, int):
            num_samples = min(num_samples, batch_size)
            ncol = int(np.ceil(np.sqrt(num_samples)))
            nrow = 1 if num_samples == 2 else ncol
        else:
            ncol = int(np.ceil(np.sqrt(num_samples)))
            nrow = 1 if num_samples == 2 else ncol
        return num_samples, nrow, ncol

    def save_fig(self, fig, folder, filename):
        """Save ``fig`` to ``folder/filename`` and log it to wandb if possible.

        The wandb key is ``figures/<name>``, where ``<name>`` is ``filename``
        without its ``epoch``/``now`` fields. wandb failures only print a warning.
        """
        my_dict = wandb_local.str_to_dict(my_str=filename, sep='--', remove_ext=True)
        name = wandb_local.dict_to_str(my_dict=my_dict, keys_remove=['epoch', 'now'])

        full_filename = os.path.join(folder, f'{filename}')

        if os.path.exists(full_filename):
            playbook_io.print_warning(f'Overwriting file: {full_filename}')
        if isinstance(fig, torch.Tensor):
            save_image(fig, full_filename)
            try:
                image = wandb.Image(fig)
                wandb.log({f'figures/{name}': image})
            except Exception:  # best-effort logging; never abort saving
                playbook_io.print_warning(f"wandb not ready to plot image")
        else:
            try:
                wandb.log({f'figures/{name}': wandb.Image(fig)})
            except Exception:  # best-effort logging; never abort saving
                playbook_io.print_warning(f"wandb not ready to plot figure")
            fig.savefig(full_filename)

    def add_title(self, title_el, ax):
        """Set ``ax``'s title from a scalar or 1-D array.

        Arrays are shown with two decimals; the font size shrinks with the
        title length (minimum 5).
        """

        if title_el.ndim == 0:
            my_str = str(title_el)
            num_characters = 0
        else:
            my_str = ', '.join([f"{t:.2f}" for t in title_el])
            my_str = f"[{my_str}]"
            num_characters = len(my_str)
        fontsize = max(14 - num_characters, 5)
        ax.set_title(f"{my_str}", fontsize=fontsize)

    def select_axis(self, nrow, ncol, i, j, axes):
        """Return the axis at ``(i, j)`` from ``plt.subplots``-style ``axes``.

        Handles the 1x1 (bare axis) and single-row (1-D array) layouts.
        """
        if nrow == 1 and ncol == 1:
            ax_ij = axes
        elif nrow == 1:
            ax_ij = axes[j]
        else:
            ax_ij = axes[i, j]

        return ax_ij