# import torch
# import pytorch_lightning as pl
# from pytorch_lightning import LightningModule, Trainer
# from pytorch_lightning.callbacks.progress import TQDMProgressBar
# from torch.utils.data import DataLoader, TensorDataset
# from torchmetrics import Metric, MetricCollection
# import copy
# import logging

# class KLModel(LightningModule):
#     def __init__(self, model):
#         super().__init__()
#         self.save_hyperparameters()
#         self.model = model

#     def compile(self, optimizer, loss, metrics=None, scheduler=None):
#         self.koptimizer = optimizer
#         self.kloss = loss
#         self.metrics = torch.nn.ModuleDict({i: MetricCollection(copy.deepcopy(metrics)) for i in ('train_set', 'val_set', 'test_set')})
#         self.kscheduler = scheduler
    
#     def configure_optimizers(self):
#         # if self.kscheduler:
#         #     return [self.koptimizer], [{"scheduler": self.kscheduler, "interval": "epoch"}]
#         # else:
#         return self.koptimizer

#     # Binary Neural Network (BNN)
#     def forward(self, x):
#         with torch.no_grad():
#             for parameter in self.parameters():
#                 if parameter.requires_grad and 'bnn' in parameter.__dict__:
#                     if parameter.bnn:
#                         parameter.clamp_(-1, 1)
#         return self.model(x)
    
#     def training_step(self, batch, batch_idx):
#         x, y = batch
#         output = self(x)
#         loss = self.kloss(output, y)

#         with torch.no_grad():
#             for metric_name, metric in self.metrics['train_set'].items():
#                 metric.update(output, y)
#                 self.log(metric_name, metric.compute(), prog_bar=True)
#             if self.kscheduler:
#                 self.log('lr', self.kscheduler.get_lr()[0], prog_bar=True)

#         return loss

#     def training_epoch_end(self, training_step_outputs):
#         for metric_name, metric in self.metrics['train_set'].items():
#             metric.reset()
#         if self.kscheduler:
#             self.kscheduler.step()

#     def validation_step(self, batch, batch_idx):
#         x, y = batch
#         output = self(x)
#         loss = self.kloss(output, y)
#         self.log("val_loss", loss, prog_bar=True)

#         for metric_name, metric in self.metrics['val_set'].items():
#             metric.update(output, y)
#             self.log(f"val_{metric_name}", metric.compute(), prog_bar=True)

#     def validation_epoch_end(self, validation_step_outputs):
#         for metric_name, metric in self.metrics['val_set'].items():
#             metric.reset()
#         # print()

#     def test_step(self, batch, batch_idx):
#         x, y = batch
#         output = self(x)
#         loss = self.kloss(output, y)
#         self.log("test_loss", loss, prog_bar=True)

#         for metric_name, metric in self.metrics['test_set'].items():
#             metric.update(output, y)
#             self.log(f"test_{metric_name}", metric.compute(), prog_bar=True)

#     def test_epoch_end(self, test_step_outputs):
#         for metric_name, metric in self.metrics['val_set'].items():
#             metric.reset()
#         # print()

#     def fit(self, x_train=None, y_train=None, x_val=None, y_val=None, x_test=None, y_test=None,
#             batch_size=32, epochs=1, dataset_train=None, dataset_val=None, dataset_test=None, 
#             shuffle=True, verbose=True, test_checkpoint=None, **kwargs):


#         if x_train is not None and y_train is not None:
#             if type(x_train) is not torch.Tensor:
#                 x_train = torch.tensor(x_train)
#             if type(y_train) is not torch.Tensor:
#                 y_train = torch.tensor(y_train)
#             self.dataset_train = TensorDataset(x_train, y_train)
#         elif dataset_train is not None:
#             self.dataset_train = dataset_train
#         else:
#             raise "No training set provided."

#         if x_val is not None and y_val is not None:
#             if type(x_val) is not torch.Tensor:
#                 x_val = torch.tensor(x_val)
#             if type(y_val) is not torch.Tensor:
#                 y_val = torch.tensor(y_val)
#             self.dataset_val = TensorDataset(x_val, y_val)
#         elif dataset_val is not None:
#             self.dataset_val = dataset_val
#         else:
#             self.dataset_val = None

#         if x_test is not None and y_test is not None:
#             if type(x_test) is not torch.Tensor:
#                 x_test = torch.tensor(x_test)
#             if type(y_test) is not torch.Tensor:
#                 y_test = torch.tensor(y_test)
#             self.dataset_test = TensorDataset(x_test, y_test)
#         elif dataset_test is not None:
#             self.dataset_test = dataset_test
#         else:
#             self.dataset_test = None

#         self.batch_size = batch_size

#         # if verbose:
#         #     logging.getLogger("pytorch_lightning").setLevel(logging.NOTSET)
#         # else:
#         #     logging.getLogger("pytorch_lightning").setLevel(100)

#         callbacks = kwargs.pop('callbacks', [])
#         if verbose:
#             callbacks.extend([
#                TQDMProgressBar(refresh_rate=20),
#                pl.callbacks.ModelSummary(max_depth=2),
#             ])

#         trainer = Trainer(
#             accelerator = kwargs.pop('accelerator', 'auto'),
#             devices = kwargs.pop('devices', torch.cuda.device_count()) if torch.cuda.is_available() else None,
#             max_epochs = epochs,
#             callbacks = callbacks, 
#             enable_checkpointing = kwargs.pop('enable_checkpointing', False),
#             logger = kwargs.pop('logger', False),
#             enable_model_summary = True if verbose else False,
#             enable_progress_bar=True if verbose else False,
#             **kwargs,
#         )

#         if self.dataset_val is not None:
#             trainer.fit(self, DataLoader(self.dataset_train, batch_size=batch_size, shuffle=shuffle), DataLoader(self.dataset_val, batch_size=batch_size))
#         else:
#             trainer.fit(self, DataLoader(self.dataset_train, batch_size=batch_size))

#         results = None

#         self.eval()

#         if self.dataset_test is not None:
#             ckpt_path = f"{test_checkpoint}.ckpt" if test_checkpoint else None
#             results = trainer.test(self, dataloaders=DataLoader(self.dataset_test, batch_size=batch_size), verbose=verbose)[0]
#             results = trainer.test(self, dataloaders=DataLoader(self.dataset_test, batch_size=batch_size), verbose=verbose, ckpt_path=ckpt_path)[0]
        
#         # pred = trainer.predict(self, dataloaders=DataLoader(self.dataset_test, batch_size=batch_size))[0]

#         return results#, p red


# class SparseCategoricalAccuracy(Metric):
#     def __init__(self):
#         super().__init__()
#         self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
#         self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

#     def update(self, output, target):
#         preds = torch.argmax(output, dim=1)
#         assert preds.shape == target.shape

#         self.correct += torch.sum(preds == target)
#         self.total += target.numel()

#     def compute(self):
#         return self.correct.float() / self.total


# class MSE(Metric):
#     def __init__(self):
#         super().__init__()
#         self.add_state("loss", default=torch.tensor(0., dtype=torch.float32), dist_reduce_fx="sum")
#         self.add_state("n", default=torch.tensor(0., dtype=torch.float32), dist_reduce_fx="sum")

#     def update(self, output, target):
#         if output.ndim == target.ndim + 1:
#             target = target.unsqueeze(-1)
#         assert output.shape == target.shape
#         self.loss += torch.sum((output - target)**2)
#         self.n += target.numel()

#     def compute(self):
#         return self.loss / self.n

# class RMSE(MSE):
#     def __init__(self):
#         super().__init__()

#     def compute(self):
#         return (self.loss / self.n)**0.5

# class SRMSE(RMSE):
#     def __init__(self, std):
#         super().__init__()
#         self.std = std

#     def compute(self):
#         return self.std * (self.loss / self.n)**0.5

# class View(torch.nn.Module):
#     def __init__(self, shape):
#         self.shape = shape

#     def forward(self, input):
#         return input.view(input.size(0), *self.shape)

import torch
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from torch.utils.data import DataLoader, TensorDataset
from torchmetrics import Metric, MetricCollection
import copy
import logging
import time

class KLModel(LightningModule):
    """Keras-style convenience wrapper around a network, driven by a Lightning ``Trainer``.

    Intended usage: construct with the network, call :meth:`compile` to attach
    the optimizer, loss, metrics and optional scheduler, then call :meth:`fit`
    to train and (optionally) validate and test.
    """

    # Keys of the per-split metric containers created by ``compile``.
    _SPLITS = ('train_set', 'val_set', 'test_set')

    def __init__(self, model, print_time=False):
        """Store the wrapped network.

        Args:
            model: Network that :meth:`forward` delegates to.
            print_time: If true, :meth:`training_step` prints the seconds
                elapsed since the previous training step.
        """
        super().__init__()
        self.save_hyperparameters()
        self.model = model
        # Timestamp of the previous training step. It starts at 0, so the very
        # first printed interval is ``time.time() - 0``, not a step duration.
        self.last_t = 0
        self.print_time = print_time

    def compile(self, optimizer, loss, metrics=None, scheduler=None):
        """Attach the training configuration.

        Args:
            optimizer: Optimizer returned by :meth:`configure_optimizers`.
            loss: Callable invoked as ``loss(output, y)``.
            metrics: Metrics handed to ``MetricCollection``; each split
                (train/val/test) receives its own deep copy. ``None`` means
                "no metrics".
            scheduler: Optional LR scheduler, stepped once per training epoch
                in :meth:`training_epoch_end`.
        """
        self.koptimizer = optimizer
        self.kloss = loss
        # With ``metrics=None`` an empty ModuleDict is used per split instead
        # of passing None to MetricCollection; the step/epoch hooks only rely
        # on ``.items()``, which both container types provide.
        self.metrics = torch.nn.ModuleDict({
            split: (
                MetricCollection(copy.deepcopy(metrics))
                if metrics is not None
                else torch.nn.ModuleDict()
            )
            for split in self._SPLITS
        })
        self.kscheduler = scheduler

    def configure_optimizers(self):
        """Return only the optimizer.

        The scheduler is deliberately not registered with Lightning: it is
        stepped manually in :meth:`training_epoch_end`.
        """
        return self.koptimizer

    # Binary Neural Network (BNN)
    def forward(self, x):
        """Clamp BNN-flagged parameters to ``[-1, 1]`` in place, then run the network.

        A parameter is clamped when it requires grad and carries a truthy
        ``bnn`` attribute.
        """
        with torch.no_grad():
            for parameter in self.parameters():
                if parameter.requires_grad and getattr(parameter, 'bnn', False):
                    parameter.clamp_(-1, 1)
        return self.model(x)

    def _current_lr(self):
        """Return the first learning rate reported by the scheduler."""
        scheduler = self.kscheduler
        # PyTorch schedulers expose get_last_lr() for reading the value last
        # computed; get_lr() is kept as a fallback for scheduler objects that
        # only provide the older accessor.
        read_lr = getattr(scheduler, 'get_last_lr', None) or scheduler.get_lr
        return read_lr()[0]

    def training_step(self, batch, batch_idx):
        """Run one training batch, log running train metrics (and LR), return the loss."""
        if self.print_time:
            new_t = time.time()
            print(new_t - self.last_t)
            self.last_t = new_t
        x, y = batch
        output = self(x)
        loss = self.kloss(output, y)

        # Metric bookkeeping must not contribute to the autograd graph.
        with torch.no_grad():
            for metric_name, metric in self.metrics['train_set'].items():
                metric.update(output, y)
                self.log(metric_name, metric.compute(), prog_bar=True)
            if self.kscheduler:
                self.log('lr', self._current_lr(), prog_bar=True)

        return loss

    def training_epoch_end(self, training_step_outputs):
        """Reset the train metrics and advance the scheduler by one epoch."""
        for _, metric in self.metrics['train_set'].items():
            metric.reset()
        if self.kscheduler:
            self.kscheduler.step()

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` and the running ``val_<metric>`` values for one batch."""
        x, y = batch
        output = self(x)
        loss = self.kloss(output, y)
        self.log("val_loss", loss, prog_bar=True)

        for metric_name, metric in self.metrics['val_set'].items():
            metric.update(output, y)
            self.log(f"val_{metric_name}", metric.compute(), prog_bar=True)

    def validation_epoch_end(self, validation_step_outputs):
        """Reset the validation metrics."""
        for _, metric in self.metrics['val_set'].items():
            metric.reset()

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` and the running ``test_<metric>`` values for one batch."""
        x, y = batch
        output = self(x)
        loss = self.kloss(output, y)
        self.log("test_loss", loss, prog_bar=True)

        for metric_name, metric in self.metrics['test_set'].items():
            metric.update(output, y)
            self.log(f"test_{metric_name}", metric.compute(), prog_bar=True)

    def test_epoch_end(self, test_step_outputs):
        """Reset the test metrics (the ones ``test_step`` accumulates into)."""
        for _, metric in self.metrics['test_set'].items():
            metric.reset()

    @staticmethod
    def _as_dataset(x, y, dataset):
        """Build a ``TensorDataset`` from ``(x, y)`` if both are given, else return ``dataset``."""
        if x is None or y is None:
            return dataset
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x)
        if not isinstance(y, torch.Tensor):
            y = torch.tensor(y)
        return TensorDataset(x, y)

    def fit(self, x_train=None, y_train=None, x_val=None, y_val=None, x_test=None, y_test=None,
            batch_size=32, epochs=1, dataset_train=None, dataset_val=None, dataset_test=None,
            shuffle=True, verbose=True, test_checkpoint=None, **kwargs):
        """Train the model, optionally validating during training and testing afterwards.

        For each split, raw ``x_*``/``y_*`` pairs take precedence over the
        corresponding ``dataset_*`` argument; non-tensor inputs are converted
        with ``torch.tensor``.

        Args:
            x_train, y_train: Training inputs/targets.
            x_val, y_val: Optional validation inputs/targets.
            x_test, y_test: Optional test inputs/targets.
            batch_size: Batch size used for every DataLoader.
            epochs: Passed to the trainer as ``max_epochs``.
            dataset_train, dataset_val, dataset_test: Ready-made datasets used
                when the matching ``x``/``y`` pair is not supplied.
            shuffle: Whether the training DataLoader shuffles.
            verbose: Enables the progress bar, the model summary and verbose
                test output.
            test_checkpoint: Accepted for compatibility; currently unused.
            **kwargs: Extra ``Trainer`` arguments. ``callbacks``,
                ``accelerator``, ``gpus``, ``devices``,
                ``enable_checkpointing`` and ``logger`` override the defaults
                chosen here.

        Returns:
            The first result dict from ``trainer.test`` when a test set was
            provided, otherwise ``None``.

        Raises:
            ValueError: If neither ``x_train``/``y_train`` nor
                ``dataset_train`` is provided.
        """
        train_set = self._as_dataset(x_train, y_train, dataset_train)
        if train_set is None:
            raise ValueError("No training set provided.")
        self.dataset_train = train_set
        self.dataset_val = self._as_dataset(x_val, y_val, dataset_val)
        self.dataset_test = self._as_dataset(x_test, y_test, dataset_test)

        self.batch_size = batch_size

        # Work on a private list so the caller's ``callbacks`` object is never
        # mutated by the extend() below.
        callbacks = kwargs.pop('callbacks', None)
        if callbacks is None:
            callbacks = []
        elif isinstance(callbacks, (list, tuple)):
            callbacks = list(callbacks)
        else:
            callbacks = [callbacks]
        if verbose:
            callbacks.extend([
                TQDMProgressBar(refresh_rate=20),
                pl.callbacks.ModelSummary(max_depth=2),
            ])

        cuda_available = torch.cuda.is_available()
        # Always pop ``devices`` so it can never reach Trainer twice (once
        # explicitly and once through **kwargs).
        devices = kwargs.pop('devices', None)
        if devices is None and cuda_available:
            devices = torch.cuda.device_count()
        # Only request "all GPUs" (-1) when CUDA is actually present.
        gpus = kwargs.pop('gpus', -1 if cuda_available else None)

        trainer = Trainer(
            accelerator=kwargs.pop('accelerator', 'auto'),
            gpus=gpus,
            devices=devices,
            max_epochs=epochs,
            callbacks=callbacks,
            enable_checkpointing=kwargs.pop('enable_checkpointing', False),
            logger=kwargs.pop('logger', False),
            enable_model_summary=bool(verbose),
            enable_progress_bar=bool(verbose),
            **kwargs,
        )

        # ``shuffle`` applies to the training loader whether or not a
        # validation set exists.
        train_loader = DataLoader(self.dataset_train, batch_size=batch_size, shuffle=shuffle)
        if self.dataset_val is not None:
            val_loader = DataLoader(self.dataset_val, batch_size=batch_size)
            trainer.fit(self, train_loader, val_loader)
        else:
            trainer.fit(self, train_loader)

        results = None

        self.eval()

        if self.dataset_test is not None:
            test_loader = DataLoader(self.dataset_test, batch_size=batch_size)
            results = trainer.test(self, dataloaders=test_loader, verbose=verbose)[0]

        return results


class SparseCategoricalAccuracy(Metric):
    """Accuracy of arg-max predictions (taken along dim 1) against the targets.

    Keeps two running counters, ``correct`` and ``total``, both registered as
    metric states with a ``"sum"`` reduction.
    """

    def __init__(self):
        super().__init__()
        # A fresh zero tensor is created for each counter.
        for state_name in ("correct", "total"):
            self.add_state(state_name, default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, output, target):
        """Accumulate matches between ``argmax(output, dim=1)`` and ``target``."""
        predicted = output.argmax(dim=1)
        assert predicted.shape == target.shape

        matches = predicted == target
        self.correct += matches.sum()
        self.total += target.numel()

    def compute(self):
        """Return ``correct / total`` as a floating-point tensor."""
        hit_count = self.correct.float()
        return hit_count / self.total


class MSE(Metric):
    """Mean squared error accumulated over all elements seen so far.

    ``loss`` holds the running sum of squared differences and ``n`` the
    running element count; both are float32 states with a ``"sum"`` reduction.
    """

    def __init__(self):
        super().__init__()
        for state_name in ("loss", "n"):
            self.add_state(
                state_name,
                default=torch.tensor(0., dtype=torch.float32),
                dist_reduce_fx="sum",
            )

    def update(self, output, target):
        """Add the squared error between ``output`` and ``target`` to the running sums."""
        # When the output has exactly one more dimension than the target,
        # give the target a trailing axis so the shapes can match.
        if output.ndim - target.ndim == 1:
            target = target.unsqueeze(-1)
        assert output.shape == target.shape
        residual = output - target
        self.loss += residual.pow(2).sum()
        self.n += target.numel()

    def compute(self):
        """Return the accumulated squared error divided by the element count."""
        squared_error_sum, element_count = self.loss, self.n
        return squared_error_sum / element_count

class RMSE(MSE):
    """Root mean squared error: the square root of the value ``MSE`` computes."""

    def __init__(self):
        super().__init__()

    def compute(self):
        """Return ``(loss / n) ** 0.5``."""
        mean_squared_error = self.loss / self.n
        return mean_squared_error ** 0.5

class SRMSE(RMSE):
    """RMSE multiplied by a fixed factor ``std``.

    NOTE(review): presumably ``std`` is the standard deviation used to
    normalise the targets, so the result is in original units — confirm with
    callers.
    """

    def __init__(self, std):
        super().__init__()
        self.std = std

    def compute(self):
        """Return ``std * (loss / n) ** 0.5``."""
        root_mean_squared_error = (self.loss / self.n) ** 0.5
        return self.std * root_mean_squared_error

class View(torch.nn.Module):
    """Module that views its input as ``(input.size(0), *shape)``.

    The first dimension of the input is kept; the remaining dimensions are
    replaced by the ``shape`` values given at construction.
    """

    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, input):
        """Return ``input`` viewed with its first dimension followed by ``self.shape``."""
        leading = input.size(0)
        target_dims = (leading, *self.shape)
        return input.view(*target_dims)