import os
from utils.configuration import Configuration
import time
import torch as th
import numpy as np
from einops import rearrange, repeat, reduce
import cv2
from pytorch_lightning import Callback
from pytorch_lightning import Trainer, LightningModule

class PeriodicCheckpoint(Callback):
    """Overwrite a single rolling checkpoint file ("quicksave.ckpt") every N steps.

    :param save_path: directory the checkpoint file is written into
    :param save_every_n_steps: save whenever ``trainer.global_step + 1`` is a
        multiple of this value; must be positive
    :raises ValueError: if ``save_every_n_steps`` is not positive
    """

    def __init__(self, save_path, save_every_n_steps = 3000):
        super().__init__()
        if save_every_n_steps <= 0:
            raise ValueError(f"save_every_n_steps must be positive, got {save_every_n_steps}")
        self.save_path = save_path
        self.save_every_n_steps = save_every_n_steps
        # global_step only advances on optimizer steps, so with gradient
        # accumulation several consecutive batches report the same value.
        # Remember the step we last saved at so each step is saved once.
        self._last_saved_step = None

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        global_step = trainer.global_step
        if (global_step + 1) % self.save_every_n_steps != 0:
            return
        if global_step == self._last_saved_step:
            return
        self._last_saved_step = global_step

        checkpoint_file = os.path.join(self.save_path, "quicksave.ckpt")
        print(f"Saving checkpoint to {checkpoint_file}")
        trainer.save_checkpoint(checkpoint_file)

class CosineAnnealingCheckpoint(Callback):
    """Save a checkpoint at the end of each doubling-length cycle of epochs.

    Cycles last 1, 2, 4, 8, ... epochs, so cycle ``k`` (1-based) ends after
    ``1 + 2 + ... + 2**(k-1) = 2**k - 1`` completed epochs; checkpoints are
    written after epochs 1, 3, 7, 15, ... The epoch counter is exposed through
    ``state_dict``/``load_state_dict`` so it survives resuming training from a
    checkpoint instead of restarting at 0.

    :param save_path: directory the checkpoint files are written into
    :param prefix: file-name prefix for the checkpoint files
    :param num_cycles: number of cycles to save checkpoints for (15 matches the
        previous hard-coded limit)
    """

    def __init__(self, save_path, prefix = "checkpoint", num_cycles = 15):
        super().__init__()
        self.save_path = save_path
        self.current_epoch = 0  # training epochs completed so far (counted by this callback)
        self.prefix = prefix
        self.num_cycles = num_cycles

    def state_dict(self):
        """Return the state to be stored with this callback in a checkpoint."""
        return {"current_epoch": self.current_epoch}

    def load_state_dict(self, state_dict):
        """Restore the epoch counter when training resumes from a checkpoint."""
        self.current_epoch = state_dict.get("current_epoch", self.current_epoch)

    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        self.current_epoch += 1

        # A cycle ends exactly when current_epoch == 2**k - 1, i.e. when
        # current_epoch + 1 is a power of two; the exponent k is the cycle number.
        boundary = int(self.current_epoch) + 1
        if boundary & (boundary - 1):
            return
        cycle = boundary.bit_length() - 1
        if cycle > self.num_cycles:
            return

        checkpoint_file = os.path.join(self.save_path, f"{self.prefix}_epoch_{self.current_epoch}_cycle_{cycle}.ckpt")
        trainer.save_checkpoint(checkpoint_file)
        print(f"Saving checkpoint to {checkpoint_file}")

class Timer:
    """Report a smoothed iteration rate each time the object is formatted with ``str()``.

    Every ``str(timer)`` call counts as one iteration. The time elapsed since the
    previous call is folded into an exponentially decayed average (factor 0.99).
    The result is rendered as ``"X.XXs/it"`` when an iteration takes more than one
    second, and otherwise as ``"X.XXit/s"``.
    """

    def __init__(self):
        # perf_counter is monotonic and high-resolution; time.time() can have
        # coarse ticks (zero-length intervals) and can jump backwards.
        self.last   = time.perf_counter()
        self.passed = 0   # decayed sum of elapsed seconds
        self.sum    = 0   # decayed iteration count

    def __str__(self):
        # Read the clock once so no time is lost between measuring and resetting.
        now = time.perf_counter()
        self.passed = self.passed * 0.99 + now - self.last
        self.sum    = self.sum * 0.99 + 1
        self.last   = now
        passed      = self.passed / self.sum

        if passed > 1:
            return f"{passed:.2f}s/it"
        if passed <= 0:
            # No measurable time elapsed; avoid ZeroDivisionError.
            return "inf it/s"
        return f"{1.0/passed:.2f}it/s"

class UEMA:
    """Bias-corrected ("unbiased") exponential moving average.

    The running value and the running weight are decayed by the same factor.
    ``float(ema)`` is therefore a normalised weighted mean even after only a few
    updates, with no warm-up bias towards the initial 0.

    :param memory: decay time constant in updates; each update scales the
        history by ``exp(-1 / memory)``
    """

    def __init__(self, memory = 100):
        self.value  = 0
        # Tiny non-zero weight so float() before any update yields 0.0 instead of dividing by zero.
        self.sum    = 1e-30
        self.decay  = np.exp(-1 / memory)

    def update(self, value):
        """Fold one new observation into the average."""
        self.value = self.value * self.decay + value
        self.sum   = self.sum   * self.decay + 1

    def __float__(self):
        # self.decay is a numpy scalar, so the ratio is numpy.float64. __float__
        # must return an exact float; a float subclass triggers a DeprecationWarning.
        return float(self.value / self.sum)

class BinaryStatistics:
    """Accumulate confusion-matrix counts for binary predictions.

    Predictions are rounded to the nearest integer before comparison. Only
    labels equal to 1 (positive) or 0 (negative) contribute to the counts.
    Derived metrics are returned as percentages.
    """

    def __init__(self):
        self.true_positive  = 0
        self.true_negative  = 0
        self.false_positive = 0
        self.false_negative = 0

    def update(self, outputs, labels):
        """Add the counts for one batch of ``outputs`` against ``labels``."""
        predicted = th.round(outputs)
        hit      = (predicted == labels).float()
        miss     = (predicted != labels).float()
        positive = (labels == th.ones_like(labels)).float()
        negative = (labels == th.zeros_like(labels)).float()

        self.true_positive  += th.sum(hit * positive).item()
        self.true_negative  += th.sum(hit * negative).item()
        self.false_positive += th.sum(miss * negative).item()
        self.false_negative += th.sum(miss * positive).item()

    def accuracy(self):
        """Percentage of counted samples that were classified correctly."""
        correct = self.true_positive + self.true_negative
        total = correct + self.false_positive + self.false_negative
        return 100 * correct / (total + 1e-30)

    def sensitivity(self):
        """True-positive rate (recall) as a percentage."""
        actual_positive = self.true_positive + self.false_negative
        return 100 * self.true_positive / (actual_positive + 1e-30)

    def specificity(self):
        """True-negative rate as a percentage."""
        actual_negative = self.true_negative + self.false_positive
        return 100 * self.true_negative / (actual_negative + 1e-30)


def model_path(cfg: Configuration, overwrite=False, move_old=True):
    """
    Compute the output directory for a model under ``out/`` and clear the way for it.

    Only ``out/`` (and, when archiving, the ``<name>_old`` folder) is created here;
    the returned model directory itself is not created.
    :param cfg: Configuration whose ``model_path`` attribute names the model directory
    :param overwrite: if True, return ``out/<model_path>`` as-is, even if it already exists
    :param move_old: when not overwriting, move an existing ``out/<model_path>`` into
        ``out/<model_path>_old/`` (suffixing ``_1``, ``_2``, ... to avoid clashes);
        if False, pick the first free ``out/<model_path>_<i>`` instead
    :return: Model path
    """
    root = 'out'
    target = os.path.join(root, cfg.model_path)

    if not os.path.exists(root):
        os.makedirs(root)

    if overwrite:
        return target

    if not move_old:
        # Leave the existing directory alone and choose the next unused numbered name.
        suffix = 0
        while os.path.exists(target):
            suffix += 1
            target = os.path.join(root, f'{cfg.model_path}_{suffix}')
        return target

    if os.path.exists(target):
        # Archive the existing directory so the original name becomes free again.
        archive_dir = os.path.join(root, f'{cfg.model_path}_old')
        if not os.path.exists(archive_dir):
            os.makedirs(archive_dir)
        destination = os.path.join(archive_dir, cfg.model_path)
        suffix = 0
        while os.path.exists(destination):
            suffix += 1
            destination = os.path.join(archive_dir, f'{cfg.model_path}_{suffix}')
        os.renames(target, destination)

    return target

class SequenceToImgs():
    """Buffer a sequence of image tensors and write them out as numbered JPEG files.

    Each ``update`` call appends one frame (and an optional mask). Once ``length``
    frames are buffered, they are post-processed: optional mask normalisation,
    optional masked image normalisation, optional mask colouring, and padding of
    2-channel images to 3 channels. Each frame is then written to
    ``{prefix}-{counter:05d}-{i:03d}.jpg`` with ``cv2.imwrite`` (values scaled by 255).
    Finally the buffers are cleared and ``counter`` is incremented.
    """

    # Lower bound on normalisation denominators, so constant masks/images give a
    # finite result instead of 0/0 = NaN.
    _EPS = 1e-8

    def __init__(self, length, prefix, normalize = False,  normalize_mask=False, color_masked=False):
        self.normalize_mask = normalize_mask
        self.color_masked   = color_masked

        self.length    = length
        self.seqence   = []   # (sic) buffered frames; attribute name kept for compatibility
        self.mask      = []   # buffered masks, one per frame
        self.prefix    = prefix
        self.counter   = 0    # number of sequences written so far (used in file names)
        self.normalize = normalize
        # NOTE(review): mean/mean2/sum are not read anywhere in this class; kept in
        # case external code accesses them.
        self.mean      = 0
        self.mean2     = 0
        self.sum       = 0

    def update(self, img, mask = None):
        """Buffer one frame and write the whole sequence once ``length`` frames are collected.

        :param img: image tensor of shape (C, H, W); frames are stacked to (B, C, H, W)
        :param mask: optional mask tensor; defaults to ones shaped like ``img[:1]``
        """
        self.seqence.append(img.detach().cpu())
        if mask is not None:
            self.mask.append(mask.detach().cpu())
        else:
            self.mask.append(th.ones_like(img[:1]).cpu())

        if len(self.seqence) >= self.length:
            self._flush()

    def _flush(self):
        """Post-process the buffered frames, write them to disk and reset the buffers."""
        imgs = self._postprocess(th.stack(self.seqence), th.stack(self.mask))
        imgs = rearrange(imgs, 'b c h w -> b h w c').detach().cpu().numpy()

        for i, frame in enumerate(imgs):
            cv2.imwrite(f'{self.prefix}-{self.counter:05d}-{i:03d}.jpg', frame * 255)

        self.seqence  = []
        self.mask     = []
        self.counter += 1

        print(f'saveing imgs: {self.prefix}')

    def _postprocess(self, imgs, mask):
        """Apply the configured normalisation/colouring to stacked (B, C, H, W) frames."""
        if self.normalize_mask:
            # A constant mask (e.g. the default all-ones mask) has std 0; clamp it.
            mask_std = th.clamp(th.std(mask), min=self._EPS)
            mask = th.clip((mask - th.mean(mask)) / (2 * mask_std), -1, 1) * 0.5 + 0.5

        if self.normalize:
            # Mask-weighted mean/std. Clamp the denominators so an all-zero mask
            # or a constant image does not produce NaN.
            weight = th.clamp(th.sum(mask), min=self._EPS)
            mean = th.sum(imgs * mask) / weight
            std  = th.sqrt(th.sum(mask * (imgs - mean)**2) / weight)
            std  = th.clamp(std, min=self._EPS)

            imgs = th.clip((imgs - mean) / (2 * std), -1, 1) * 0.5 + 0.5

        if self.color_masked:
            # Dim channels 0 and 1 where the mask is low; channel 2 also gets a
            # constant 0.5 offset there.
            imgs = th.cat((
                imgs[:,0:1] * (mask * 0.5 + 0.5),
                imgs[:,1:2] * (mask * 0.5 + 0.5),
                imgs[:,2:3] * (mask * 0.5 + 0.5) + (1 - mask) * 0.5
            ), dim=1)

        if imgs.shape[1] == 2:
            # Pad 2-channel images with an empty third channel so they can be saved as colour JPEGs.
            imgs = th.cat((imgs, th.zeros_like(imgs)[:,:1]), dim = 1)

        return imgs

