import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import anndata as ad
import pytorch_lightning as pl
from tqdm import tqdm
from typing import Callable
from functools import partial
from scipy.sparse import csr_matrix
from torch.optim.lr_scheduler import LambdaLR

from .util import *
from celldiff.evaluate import evaluate_annotation, integration_eval
from celldiff.util import instantiate_from_config


class DiffusionWrapper(pl.LightningModule):
    """Thin ``LightningModule`` holding a denoising network built from a config.

    The network is created with ``instantiate_from_config`` and stored as
    ``diffusion_model``. Calling the wrapper forwards every argument to that
    network positionally, in the order of the ``forward`` signature.
    """

    def __init__(self, diff_model_config):
        super().__init__()
        self.diffusion_model = instantiate_from_config(diff_model_config)

    def forward(self, x_orig, x, t, pe_input=None, conditions=None, input_gene_list=None, target_gene_list=None, mask=True):
        """Delegate to the wrapped network and return its output unchanged."""
        return self.diffusion_model(
            x_orig, x, t, pe_input, conditions, input_gene_list, target_gene_list, mask,
        )

class OrdinalJumpDiffusion(pl.LightningModule):
    def __init__(
            self,
            model_config,
            scheduler_config=None,
            lbd: float = 1.,
            pred_type: str = "x_0",
            loss_type: str = "kl",
            timesteps: int = 1000,
            clip_range: tuple = None,
            input_clip: tuple = None,
            normalize: tuple = None,  # torchvision-like data normalization
            z_rescale: bool = False,
            p_self_cond: float = 0.,
            time_diff: float = 0.,
            momentum: float = 0.,
            beta_start=0.001,
            beta_end='auto',
            beta_schedule='linear',
            logsnr_end=-12,
            signal_stat=1.,
            input_key="input",
            pe_input_key="coord",
            cond_key="cond",
            cond_names=None,
            raw_input_key="raw_input",
            input_gene_list_key="input_gene_list",
            t_sample=500,
            monitor="val/loss",
            integrate_flag=True,
            classify_flag=False,
            classifier_config=None,
            guidance_flag=False,
            guidance_config=None,
            test_target_sum=1e3,
            cond_to_ignore: list = None,
            **kwargs
    ):
        """Store hyper-parameters, build sub-modules and register the schedule.

        Args:
            model_config: Config passed to ``DiffusionWrapper`` to build the
                denoising network (stored as ``self.model``).
            scheduler_config: Optional LR-scheduler config; when given,
                ``self.use_scheduler`` is True and ``configure_optimizers``
                builds a ``LambdaLR`` from it.
            lbd: Multiplier applied to the ``alphas``/``deltas`` schedules to
                obtain Poisson rates.
            pred_type: What the network output represents. ``train_loss`` and
                ``p_sample_step`` accept ``"x_0"``, ``"eps"``,
                ``"eps_anscombe"`` and ``"eps_freeman_tukey"``.
            loss_type: Loss variant; ``train_loss`` accepts ``"kl"``,
                ``"kl_rev"``, ``"kl_simple"`` and ``"kl_alpha"``.
            timesteps: Number of diffusion steps (length of the schedule).
            clip_range: Optional range applied to the predicted ``x_0`` while
                sampling; parsed with ``parse_range``.
            input_clip: Optional range applied to the rescaled network input
                (only when ``z_rescale`` is set); parsed with ``parse_range``.
            normalize: Optional pair indexed as ``[0]`` (subtracted) and
                ``[1]`` (divisor) in ``train_loss``; parsed with
                ``parse_range``.
            z_rescale: Divide the noisy state by its Poisson rate before
                feeding it to the network.
            p_self_cond, time_diff: Stored on the instance; not read by any
                active code in this class (the self-conditioning code paths
                are commented out).
            momentum: Momentum of the running ``pred_x_0`` average kept during
                sampling.
            beta_start, beta_end, beta_schedule, logsnr_end, signal_stat:
                Forwarded to ``register_schedule``.
            input_key, pe_input_key, cond_key, raw_input_key,
            input_gene_list_key: Keys used to read entries from a batch.
            cond_names: Names of the condition columns used for classification
                metrics; only consulted when ``classify_flag`` is set.
            t_sample: Timestep passed to ``p_sample`` in ``test_step``.
            monitor: Metric name stored as ``self.monitor`` (only when not
                None).
            integrate_flag: Whether ``on_test_epoch_end`` runs
                ``integration_eval``.
            classify_flag: Whether to build a classifier from
                ``classifier_config``.
            classifier_config: Required when ``classify_flag`` is set.
            guidance_flag: Whether to build a guidance module from
                ``guidance_config``.
            guidance_config: Required when ``guidance_flag`` is set.
            test_target_sum: ``target_sum`` used by ``scnormalize`` in
                ``on_test_epoch_end`` when no raw input is available.
            cond_to_ignore: Optional column indices of the conditions that are
                zeroed in ``p_sample_step``.
            **kwargs: Accepted and ignored by this constructor.

        Raises:
            AssertionError: If a flag is set without its matching config.
        """
        super().__init__()
        self.lbd = lbd
        self.timesteps = timesteps
        self.pred_type = pred_type
        self.loss_type = loss_type
        self.input_key = input_key
        self.pe_input_key = pe_input_key
        self.cond_key = cond_key
        self.raw_input_key = raw_input_key
        self.input_gene_list_key = input_gene_list_key
        self.clip_range = parse_range(clip_range)
        self.input_clip = parse_range(input_clip)
        self.normalize = parse_range(normalize)
        self.z_rescale = z_rescale
        self.p_self_cond = p_self_cond
        self.time_diff = time_diff
        self.momentum = momentum
        self.t_sample = t_sample
        self.test_target_sum = test_target_sum
        self.cond_to_ignore = cond_to_ignore
        self.integrate_flag = integrate_flag
        self.model = DiffusionWrapper(model_config)
        self.use_scheduler = scheduler_config is not None
        # ``scheduler_config`` and ``monitor`` only exist as attributes when
        # they were actually provided.
        if self.use_scheduler:
            self.scheduler_config = scheduler_config
        if monitor is not None:
            self.monitor = monitor

        self.classify_flag = classify_flag
        if self.classify_flag:
            assert classifier_config is not None, "Classify flag set but classifier config not passed"
            # The classifier receives this diffusion module itself as ``model``.
            self.classifier = instantiate_from_config(classifier_config, model=self)

            if cond_names is None:
                cond_names = ["batch", "celltype"]
                warnings.warn(
                    f"Condition names cond_names not set, using default of {cond_names!r}. "
                    "This is needed for specifying metric names for classification. "
                    "Please specify to suppress the warning.",
                    UserWarning,
                    stacklevel=2,
                )
            # NOTE: ``self.cond_names`` is only defined when ``classify_flag`` is set.
            self.cond_names = cond_names

        self.guidance_flag = guidance_flag
        if self.guidance_flag:
            assert guidance_config is not None, "Guidance flag set but guidance config not passed"
            self.guidance = instantiate_from_config(guidance_config)

        self.register_schedule(beta_start, beta_end, beta_schedule, timesteps, logsnr_end, signal_stat, lbd)
        # Starts empty; filled by ``maybe_record_conditions`` during training.
        self.register_buffer("unique_conditions", None)

        # Per-batch outputs collected by ``test_step`` and consumed in
        # ``on_test_epoch_end``.
        self.test_step_outputs = []

    def register_schedule(self, beta_start=0.001, beta_end='auto', beta_schedule="linear", timesteps=1000,
                          logsnr_end=-12., signal_stat=1., lbd=1.):
        """Build the decay schedule and register it as float32 buffers.

        The schedule comes from ``get_decay_schedule`` (requested as
        ``f'beta_{beta_schedule}'`` with ``diffusion_type='ordinal_jump'``).
        Registered buffers: ``betas``, ``alphas``, ``alphas_prev`` (``alphas``
        shifted by one step with a leading 1), and ``deltas`` /
        ``deltas_asym``, both equal to ``(alphas_prev - alphas) * lbd``.

        Raises:
            AssertionError: If the schedule does not hold exactly
                ``timesteps`` alpha values.
        """
        schedule, _ = get_decay_schedule(
            f'beta_{beta_schedule}',
            timesteps,
            diffusion_type='ordinal_jump',
            signal_stat=signal_stat,
            beta_start=beta_start,
            beta_end=beta_end,
            logsnr_end=logsnr_end,
            lbd=lbd,
        )
        betas, alphas = schedule['betas'], schedule['alphas']
        # alpha "before the first step" is taken to be 1.
        alphas_prev = np.append(1., alphas[:-1])
        jump_rates = (alphas_prev - alphas) * lbd
        assert alphas.shape[0] == timesteps, 'alphas have to be defined for each timestep'

        # Symmetric and asymmetric deltas currently share the same values but
        # are kept as two separate buffers.
        named_arrays = (
            ('betas', betas),
            ('alphas', alphas),
            ('alphas_prev', alphas_prev),
            ('deltas', jump_rates),
            ('deltas_asym', jump_rates),
        )
        for buffer_name, values in named_arrays:
            self.register_buffer(buffer_name, torch.tensor(values, dtype=torch.float32))

    def _get_alpha(self, x, t):
        """Return ``extract(self.alphas, x, t)``, i.e. the alpha schedule looked up at ``t``."""
        return extract(self.alphas, x, t)

    def _get_delta(self, x, t, asym=False):
        """Look up the delta schedule at ``t`` through ``extract``.

        ``asym=True`` reads ``self.deltas_asym``; otherwise ``self.deltas``.
        """
        if asym:
            schedule = self.deltas_asym
        else:
            schedule = self.deltas
        return extract(schedule, x, t)

    def q_sample(self, x_0, t, return_rate=False):
        """Draw the noisy state ``z_t ~ Poisson(lbd * alpha_t * x_0)``.

        Returns ``z_t``, or ``(z_t, lbd * alpha_t)`` when ``return_rate`` is
        true.
        """
        intensity = self._get_alpha(x_0, t) * self.lbd
        noisy = torch.poisson(intensity * x_0)
        if return_rate:
            return noisy, intensity
        return noisy

    def q_posterior(self, x_0, z_t, t, asym=False, rng=None):
        """Sample the previous state as ``z_t + Poisson(delta_t * x_0)``.

        ``asym`` selects which delta schedule is used; ``rng`` is an optional
        ``torch.Generator`` handed to ``torch.poisson``.
        """
        increment_rate = self._get_delta(x_0, t, asym=asym) * x_0
        increment = torch.poisson(increment_rate, generator=rng)
        return increment + z_t

    def _pred_x_0_from_eps(self, z_t, eps, t):
        """Recover ``x_0`` from a predicted ``eps``: ``z_t / r - eps / sqrt(r)`` with ``r = lbd * alpha_t``."""
        intensity = self._get_alpha(z_t, t) * self.lbd
        mean_estimate = z_t.div(intensity)
        noise_term = eps.div(intensity.sqrt())
        return mean_estimate - noise_term

    def _pred_x_0_from_eps_anscombe(self, z_t, eps, t):
        """Recover ``x_0`` from ``eps`` expressed in Anscombe-transformed space.

        Computes ``((sqrt(z_t + 3/8) - eps / 2) ** 2 - 3/8) / (lbd * alpha_t)``.
        """
        intensity = self._get_alpha(z_t, t) * self.lbd
        offset = 3. / 8
        transformed = (z_t + offset).sqrt()
        return ((transformed - .5 * eps) ** 2 - offset) / intensity

    def _pred_x_0_from_eps_freeman_tukey(self, z_t, eps, t):
        """Recover ``x_0`` from ``eps`` expressed in Freeman-Tukey-transformed space.

        Computes ``(sqrt(z_t + 1) + sqrt(z_t) - eps) ** 2 / (4 * lbd * alpha_t)``.
        """
        intensity = self._get_alpha(z_t, t) * self.lbd
        transformed = torch.sqrt(z_t + 1) + torch.sqrt(z_t)
        return (transformed - eps).pow(2).div(4 * intensity)

    def scnormalize(self, x, target_sum=1e4, eps=1e-8):
        x = x * target_sum / (x.sum(1, keepdim=True) + eps)
        x = torch.log1p(x)
        return x

    def invervse_scnormalize(self, x, library_size=1e4, eps=1e-8):
        x = torch.exp(x) - 1
        x = x * library_size / (x.sum(1, keepdim=True) + eps)
        return x

    def p_sample_step(self, x_0, z_t, t, pe_input=None, conditions=None, input_gene_list=None,
                      accum_x_0=None, return_pred=False, rng=None):
        """Run one reverse step: predict ``x_0`` from ``z_t`` and sample the previous state.

        Args:
            x_0: Original input handed to the network (or guidance module)
                as its first argument.
            z_t: Current noisy state.
            t: Timestep of this step; used to index ``self.alphas`` and
                broadcast to one value per row of ``z_t`` for the network.
            pe_input: Optional extra network input, forwarded as is.
            conditions: Optional condition tensor. When ``cond_to_ignore`` is
                set it is indexed as ``conditions[:, idx]``, so it must be
                2-D in that case.
            input_gene_list: Optional gene list forwarded to the network.
            accum_x_0: Running average of earlier predictions. Currently
                unused (the self-conditioning path is disabled).
            return_pred: Also return the clipped ``x_0`` prediction.
            rng: Optional ``torch.Generator`` for the Poisson draw.

        Returns:
            ``z_prev``, or ``(z_prev, pred_x_0)`` when ``return_pred`` is true.

        Raises:
            NotImplementedError: For an unknown ``self.pred_type``.
            AssertionError: If ``cond_to_ignore`` holds out-of-range indices.
        """
        _t = torch.full((z_t.shape[0],), fill_value=t, device=z_t.device)
        _z_t = z_t
        if self.z_rescale:
            # Dividing by the Poisson rate gives an estimate on the scale of x_0.
            _z_t = z_t.div(self.lbd * self.alphas[t])
            if self.input_clip is not None:
                _z_t.clamp_(*self.input_clip)  # input clipping
        if self.guidance_flag:
            model_out = self.guidance(self.model, sample_flag=True, x_orig=x_0, x_noised=_z_t,
                                      t=t, pe_input=pe_input, conditions=conditions,
                                      input_gene_list=input_gene_list)
        else:
            if self.cond_to_ignore is not None and conditions is not None:
                n_cond = conditions.shape[1]
                assert len(self.cond_to_ignore) <= n_cond
                assert all(0 <= idx < n_cond for idx in self.cond_to_ignore)
                # Mask a private copy: the caller keeps using ``conditions``
                # afterwards (e.g. as ground truth for evaluation), so the
                # ignored columns must not be zeroed in the caller's tensor.
                conditions = conditions.clone()
                conditions[:, self.cond_to_ignore] = 0
            model_out, _ = self.model(x_0, _z_t, _t, pe_input, conditions, input_gene_list, mask=False)
        if self.pred_type == "x_0":
            pred_x_0 = model_out
        elif self.pred_type == "eps":
            pred_x_0 = self._pred_x_0_from_eps(z_t, eps=model_out, t=_t)
        elif self.pred_type == "eps_anscombe":
            pred_x_0 = self._pred_x_0_from_eps_anscombe(z_t, eps=model_out, t=_t)
        elif self.pred_type == "eps_freeman_tukey":
            pred_x_0 = self._pred_x_0_from_eps_freeman_tukey(z_t, eps=model_out, t=_t)
        else:
            raise NotImplementedError(self.pred_type)
        # Poisson rates must be non-negative.
        pred_x_0 = pred_x_0.clamp(min=0)
        if self.clip_range is not None:
            pred_x_0.clamp_(*self.clip_range)  # prediction clipping
        z_prev = self.q_posterior(pred_x_0, z_t, t, asym=True, rng=rng)
        return (z_prev, pred_x_0) if return_pred else z_prev

    @torch.no_grad()
    def p_sample(self, x_start, t_start, pe_input=None, conditions=None, input_gene_list=None, 
                 seed=None, return_pred=False):
        """Sample by running the reverse process from ``t_start`` down to step 0.

        Args:
            x_start: Input data. It is noised with ``q_sample`` at ``t_start``
                to obtain the starting state, unless ``t_start`` equals
                ``self.timesteps``, in which case the chain starts from zeros.
            t_start: Number of reverse steps to run (int or one-element
                integer tensor); must not exceed ``self.timesteps``.
            pe_input, conditions, input_gene_list: Forwarded to
                ``p_sample_step``.
            seed: Optional seed for a dedicated ``torch.Generator`` created on
                the model's device.
            return_pred: Also return the last ``x_0`` prediction.

        Returns:
            ``(x_0, pred_x_0)`` where ``pred_x_0`` is ``None`` unless
            ``return_pred`` is true. Both are de-normalized when
            ``self.normalize`` is set.
        """
        assert t_start <= self.timesteps
        n_steps = int(t_start)
        z_T = torch.zeros_like(x_start) if t_start == self.timesteps else self.q_sample(x_0=x_start, t=t_start)
        z_t = z_T
        pred_x_0 = torch.zeros_like(z_t)
        accum_x_0 = torch.zeros_like(z_t)
        rng = None
        if seed is not None:
            device = next(self.model.parameters()).device
            rng = torch.Generator(device).manual_seed(seed)
        # The loop runs exactly ``n_steps`` iterations, so that is the
        # progress-bar total (previously one too many, so the bar never
        # reached 100%).
        for t in tqdm(reversed(range(n_steps)), desc='Sampling t', total=n_steps):
            # Exponential moving average of the x_0 predictions.
            accum_x_0 += (1 - self.momentum) * (pred_x_0 - accum_x_0)
            z_t, pred_x_0 = self.p_sample_step(
                x_start, z_t, t, pe_input, conditions, input_gene_list,
                accum_x_0=accum_x_0, return_pred=True, rng=rng
            )
        x_0 = z_t.div_(self.lbd)
        if self.normalize is not None:
            x_0.mul_(self.normalize[1]).add_(self.normalize[0])
            pred_x_0.mul_(self.normalize[1]).add_(self.normalize[0])
        return x_0, (pred_x_0 if return_pred else None)

    @torch.no_grad()
    def p_sample_progressive(self, x_start, t_start, pe_input=None, conditions=None, freq=10, seed=None,
                             input_gene_list=None):
        """Like ``p_sample`` but record intermediate predictions along the way.

        The reverse chain starts at ``t_start`` (matching ``p_sample``). It
        used to always run all ``self.timesteps`` steps, even when the
        starting state had only been noised to ``t_start``. For
        ``t_start == self.timesteps`` the result is unchanged.

        Args:
            x_start: Input data; noised at ``t_start`` to get the starting
                state, or replaced by zeros when ``t_start`` equals
                ``self.timesteps``.
            t_start: Number of reverse steps to run (int or one-element
                integer tensor); must not exceed ``self.timesteps``.
            pe_input, conditions, input_gene_list: Forwarded to
                ``p_sample_step``.
            freq: A frame is recorded whenever ``(t + 1) % freq == 0``.
            seed: Optional seed for a dedicated ``torch.Generator`` created on
                the model's device.

        Returns:
            CPU tensor of shape ``(t_start // freq + 1, B, *D)``: the recorded
            ``x_0`` predictions followed by the final sample
            (``z_0 / lbd``), de-normalized when ``self.normalize`` is set.
        """
        assert t_start <= self.timesteps
        n_steps = int(t_start)
        z_T = torch.zeros_like(x_start) if t_start == self.timesteps else self.q_sample(x_0=x_start, t=t_start)
        z_t = z_T
        pred_x_0 = torch.zeros_like(z_t)
        accum_x_0 = torch.zeros_like(z_t)
        B, *D = z_t.shape
        # One frame per ``freq`` steps plus one for the final sample.
        n_frames = n_steps // freq + 1
        sample_path = torch.empty((n_frames, B, *D))
        n = 0
        rng = None
        if seed is not None:
            device = next(self.model.parameters()).device
            rng = torch.Generator(device).manual_seed(seed)
        for t in range(n_steps - 1, -1, -1):
            # Exponential moving average of the x_0 predictions.
            accum_x_0 += (1 - self.momentum) * (pred_x_0 - accum_x_0)
            z_t, pred_x_0 = self.p_sample_step(
                x_start, z_t, t, pe_input, conditions, input_gene_list,
                accum_x_0=accum_x_0, return_pred=True, rng=rng)
            if (t + 1) % freq == 0:
                sample_path[n] = pred_x_0.cpu()
                n += 1
        sample_path[n] = z_t.div(self.lbd).cpu()
        if self.normalize is not None:
            sample_path.mul_(self.normalize[1]).add_(self.normalize[0])
        return sample_path

    def train_loss(self, x_0, t, pe_input=None, conditions=None, input_gene_list=None):
        """Compute the training loss for data ``x_0`` at timestep(s) ``t``.

        Steps: optionally normalize ``x_0``, draw ``z_t`` with ``q_sample``,
        run the network (or the guidance module) on it, mask both the output
        and the target with the returned mask, convert the output into an
        ``x_0`` prediction according to ``self.pred_type`` and evaluate the
        Poisson KL variant selected by ``self.loss_type``:

        * ``"kl"``: ``KL(delta * x_0 || delta * pred)``
        * ``"kl_rev"``: ``KL(delta * pred || delta * x_0)``
        * ``"kl_simple"``: ``KL(x_0 || pred)``
        * ``"kl_alpha"``: ``KL(x_0 || pred) * sqrt(alpha_t)``

        Raises:
            NotImplementedError: For an unknown ``pred_type`` or ``loss_type``.
        """
        if self.normalize is not None:
            shift, scale = self.normalize[0], self.normalize[1]
            x_0 = x_0.sub(shift).div(scale)

        z_t, rate = self.q_sample(x_0, t, return_rate=True)
        net_input = z_t
        if self.z_rescale:
            # z_t / rate is an unbiased estimator of the mean x_0.
            net_input = z_t.div(rate)
            if self.input_clip is not None:
                net_input.clamp_(*self.input_clip)

        if self.guidance_flag:
            model_out, mask = self.guidance(
                self.model, sample_flag=False, x_orig=x_0, x_noised=net_input,
                t=t, pe_input=pe_input, conditions=conditions,
                input_gene_list=input_gene_list,
            )
        else:
            model_out, mask = self.model(x_0, net_input, t, pe_input, conditions, input_gene_list)
        model_out = model_out * mask
        x_0 = x_0 * mask

        if self.pred_type == "x_0":
            pred_x_0 = model_out
        else:
            eps_decoders = {
                "eps": self._pred_x_0_from_eps,
                "eps_anscombe": self._pred_x_0_from_eps_anscombe,
                "eps_freeman_tukey": self._pred_x_0_from_eps_freeman_tukey,
            }
            if self.pred_type not in eps_decoders:
                raise NotImplementedError(self.pred_type)
            pred_x_0 = eps_decoders[self.pred_type](z_t, eps=model_out, t=t)

        # Poisson rates cannot be negative.
        pred_x_0 = pred_x_0.clamp(min=0)

        # Target: KL(q(z_{t-1} | z_t, x_0) || p_theta(z_{t-1} | z_t)), where
        # z_{t-1} | z_t, x_0 = z_t + Pois(lbd * (alpha_{t-1} - alpha_t) * x_0).
        if self.loss_type in ("kl", "kl_rev"):  # no re-weighting
            delta = self._get_delta(z_t, t)
            if self.loss_type == "kl":
                loss = flat_mean(poisson_kl(delta * x_0, delta * pred_x_0))
            else:
                loss = flat_mean(poisson_kl(delta * pred_x_0, delta * x_0))
        elif self.loss_type == "kl_simple":  # weight: 1 / delta
            loss = flat_mean(poisson_kl(x_0, pred_x_0))
        elif self.loss_type == "kl_alpha":  # weight: sqrt(alpha) / delta
            alpha = self._get_alpha(x_0, t)
            loss = flat_mean(poisson_kl(x_0, pred_x_0)) * alpha.sqrt()
        else:
            raise NotImplementedError(self.loss_type)

        return loss

    def forward(self, x, *args, **kwargs):
        """Sample a random timestep and return ``train_loss`` for ``x``.

        A single timestep (tensor of shape ``(1,)``) is drawn per call, not
        one per sample. Extra arguments are passed through to ``train_loss``.
        """
        timestep = torch.randint(
            low=0, high=self.timesteps, size=(1,), dtype=torch.int64, device=self.device,
        )
        return self.train_loss(x, timestep, *args, **kwargs)

    def get_input(self, batch, k):
        if k in batch.keys():
            x = batch[k]
            if isinstance(x, torch.Tensor):
                x = x.to(memory_format=torch.contiguous_format).float()
        else:
            x = None
        return x

    def maybe_record_conditions(self, batch):
        """Accumulate the distinct condition rows seen during the first epoch.

        Does nothing after epoch 0 or when the batch carries no conditions.
        Otherwise the batch's conditions are merged into
        ``self.unique_conditions`` and de-duplicated along dim 0.
        """
        conditions = self.get_input(batch, self.cond_key)
        if self.current_epoch != 0 or conditions is None:
            return

        if self.unique_conditions is None:
            candidates = conditions
        else:
            candidates = torch.cat((self.unique_conditions, conditions))
        self.unique_conditions = candidates.unique(dim=0)

    def shared_step(self, batch):
        """Unpack a batch with the configured keys and return the loss from ``forward``."""
        x = self.get_input(batch, self.input_key)
        extras = {
            'pe_input': self.get_input(batch, self.pe_input_key),
            'conditions': self.get_input(batch, self.cond_key),
            'input_gene_list': self.get_input(batch, self.input_gene_list_key),
        }
        return self(x, **extras)

    def training_step(self, batch, batch_idx):
        """Run one training step: record conditions, compute and log the mean loss.

        Also logs the global step and, when an LR scheduler is in use, the
        learning rate of the first parameter group.
        """
        self.maybe_record_conditions(batch)

        loss = self.shared_step(batch).mean()
        self.log('train/loss', loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)

        per_step = dict(prog_bar=True, logger=True, on_step=True, on_epoch=False)
        self.log('global_step', self.global_step, **per_step)
        if self.use_scheduler:
            current_lr = self.optimizers().param_groups[0]['lr']
            self.log('lr_abs', current_lr, **per_step)
        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Compute the mean loss of a validation batch and log it per epoch as ``val/loss``."""
        val_loss = self.shared_step(batch).mean()
        self.log(
            'val/loss', val_loss,
            prog_bar=False, logger=True, on_step=False, on_epoch=True,
        )

    # TODO: load raw counts, inverse_normalize, 
    @torch.no_grad()
    def test_step(self, batch, batch_idx):
        """Reconstruct a test batch and stash the results for the epoch-end metrics.

        The batch is encoded with the network's ``get_latent``, reconstructed
        with ``p_sample`` starting at ``self.t_sample``, and (when
        ``classify_flag`` is set) classified. All tensors are moved to CPU,
        appended to ``self.test_step_outputs`` and also returned.
        """
        x = self.get_input(batch, self.input_key)
        pe_input = self.get_input(batch, self.pe_input_key)
        conditions = self.get_input(batch, self.cond_key)
        input_gene_list = self.get_input(batch, self.input_gene_list_key)

        latent = self.model.diffusion_model.get_latent(x, pe_input, input_gene_list)
        # One-element integer tensor placed on the same device as ``x``.
        t_sample = torch.tensor([self.t_sample]).to(x).int()
        recon, _ = self.p_sample(x, t_sample, pe_input, conditions, input_gene_list)
        raw_x = self.get_input(batch, self.raw_input_key)

        pred_conditions = None
        if self.classify_flag:
            pred_conditions = self.classifier(x, conditions).cpu()

        out = {
            'x': x.cpu(),
            'raw_x': None if raw_x is None else raw_x.cpu(),
            'recon': recon.cpu(),
            'latent': latent.cpu(),
            'pred_conditions': pred_conditions,
            'conditions': conditions.cpu(),
        }
        self.test_step_outputs.append(out)

        return out

    @torch.no_grad()
    def on_test_epoch_end(self):
        """Aggregate the collected test outputs, compute metrics and log them.

        Metrics: per-condition classification metrics and the overall
        accuracy (when ``classify_flag`` is set), Poisson NLL, RMSE, RMSE on
        normalized values, and integration metrics (when ``integrate_flag``
        is set). The collected outputs are cleared afterwards.
        """
        outputs = self.test_step_outputs

        def stacked(key):
            # Concatenate one field over all collected test batches.
            return torch.cat([step_out[key].cpu() for step_out in outputs])

        def rmse_between(a, b):
            return np.sqrt(F.mse_loss(a, b).item())

        x = stacked('x')
        recon = stacked('recon')
        conditions = stacked('conditions').numpy()

        metrics_dict = {}
        if self.classify_flag:
            pred_conditions = stacked('pred_conditions').numpy()
            pred_correct = pred_conditions == conditions

            # Classification metrics for individual conditions
            for col, cond_name in enumerate(self.cond_names):
                pred, true = pred_conditions[:, col], conditions[:, col]
                metrics_dict.update(evaluate_annotation(true, pred, cond_name))

            # Accuracy for combined conditions (global avg, not macro)
            metrics_dict['overall_acc'] = pred_correct.all(1).mean()

        nll_fn = nn.PoissonNLLLoss(log_input=False, full=True)
        if outputs[0]['raw_x'] is not None:
            # Raw counts available: compare in count space after undoing the
            # normalization, and in normalized space directly.
            raw_x = stacked('raw_x')
            recon_inv = self.invervse_scnormalize(recon, library_size=raw_x.sum(1, keepdim=True))
            poisson_nll = nll_fn(recon_inv, raw_x).item()
            rmse = rmse_between(recon_inv, raw_x)
            rmse_normed = rmse_between(recon, x)
        else:
            # No raw counts: compare directly, and after normalizing both sides.
            poisson_nll = nll_fn(recon, x).item()
            rmse = rmse_between(recon, x)
            x_normed = self.scnormalize(x, target_sum=self.test_target_sum)
            recon_normed = self.scnormalize(recon, target_sum=self.test_target_sum)
            rmse_normed = rmse_between(recon_normed, x_normed)

        if self.integrate_flag:
            obs_columns = ['batch', 'cell_type']
            adata = ad.AnnData(csr_matrix(recon.numpy()))
            adata.obs[obs_columns] = conditions
            adata.obs[obs_columns] = adata.obs[obs_columns].astype('category')
            metrics_dict.update(integration_eval(adata, use_rep=None))
            del adata

        metrics_dict['poisson_nll'] = poisson_nll
        metrics_dict['rmse'] = rmse
        metrics_dict['rmse_normed'] = rmse_normed

        epoch_only = dict(prog_bar=True, logger=True, on_step=False, on_epoch=True)
        self.log_dict(metrics_dict, **epoch_only)
        self.log('t_sample', self.t_sample, **epoch_only)

        self.test_step_outputs.clear()

    def configure_optimizers(self):
        """Create the AdamW optimizer and, if configured, a per-step ``LambdaLR``.

        Only the parameters of ``self.model`` are optimized. The learning
        rate is read from ``self.learning_rate``, which this class's
        constructor does not assign; it has to be set on the instance
        beforehand.

        Returns:
            The optimizer alone, or ``([optimizer], [scheduler_dict])`` when
            a scheduler config was given.
        """
        lr = self.learning_rate
        optimizer = torch.optim.AdamW(list(self.model.parameters()), lr=lr)
        if not self.use_scheduler:
            return optimizer

        assert 'target' in self.scheduler_config
        lr_schedule = instantiate_from_config(self.scheduler_config)

        print("Setting up LambdaLR scheduler...")
        scheduler_entry = {
            'scheduler': LambdaLR(optimizer, lr_lambda=lr_schedule.schedule),
            'interval': 'step',
            'frequency': 1,
        }
        return [optimizer], [scheduler_entry]


class BitJumpDiffusion(OrdinalJumpDiffusion):
    def __init__(
            self,
            alphas: torch.Tensor = None,
            alpha_fn: Callable = None,
            lbd: float = 1.,
            pred_type: str = "x_0",
            loss_type: str = "const",
            timesteps: int = 1000,
            clip_range: tuple = None,
            input_clip: tuple = None,
            normalize: tuple = None,
            z_rescale: bool = False,
            p_self_cond: float = 0.,
            time_diff: float = 0.,
            momentum: float = 0.,
            num_bits: int = 8,
            **kwargs
    ):
        """Jump diffusion over the binary expansion of integer data.

        Args:
            alphas, alpha_fn: Accepted for backward compatibility and unused.
                ``OrdinalJumpDiffusion`` builds its schedule in
                ``register_schedule`` from the ``beta_*`` arguments and has
                no parameters with these names.
            lbd, pred_type, loss_type, timesteps, clip_range, input_clip,
            normalize, z_rescale, p_self_cond, time_diff, momentum: Forwarded
                to ``OrdinalJumpDiffusion.__init__``.
            num_bits: Number of bits used per value by ``int_to_bits`` and
                ``bits_to_int``.
            **kwargs: Forwarded to ``OrdinalJumpDiffusion.__init__``. This is
                how the required ``model_config`` (and any other parent
                option) must be supplied. These were previously dropped,
                which made construction fail with ``TypeError`` because
                ``model_config`` never reached the parent.

        NOTE(review): the default ``loss_type="const"`` is not among the
        values handled by ``OrdinalJumpDiffusion.train_loss``; confirm the
        intended default.
        """
        super().__init__(
            lbd=lbd, pred_type=pred_type, loss_type=loss_type, timesteps=timesteps,
            clip_range=clip_range, input_clip=input_clip, normalize=normalize, z_rescale=z_rescale,
            p_self_cond=p_self_cond, time_diff=time_diff, momentum=momentum, **kwargs)
        self.num_bits = num_bits

    def int_to_bits(self, x):
        shift = torch.arange(self.num_bits - 1, -1, -1, device=x.device).reshape((1, 1, -1) + (1,) * (x.ndim - 2))
        bits = torch.fmod(torch.bitwise_right_shift(
            x.int().unsqueeze(2), shift), 2).reshape((x.shape[0], -1) + x.shape[2:])
        return bits.float()

    def bits_to_int(self, bits):
        bases = 2 ** torch.arange(
            self.num_bits - 1, -1, -1, device=bits.device
        ).reshape((1, 1, -1) + (1,) * (bits.ndim - 2))
        x = ((bits > 0).int().reshape((bits.shape[0], -1, self.num_bits) + bits.shape[2:]) * bases).sum(dim=2)
        return x.float()

    def train_loss(self, model, x_0, t, **kwargs):
        """Expand ``x_0`` into bits and delegate to the parent's ``train_loss``.

        NOTE(review): ``OrdinalJumpDiffusion.train_loss`` as defined in this
        file takes ``(x_0, t, pe_input, conditions, input_gene_list)`` and has
        no ``model`` parameter, so the ``super()`` call below raises
        ``TypeError``. ``**kwargs`` is accepted but not forwarded. The
        signature also differs from what ``OrdinalJumpDiffusion.forward``
        passes (``x, t, ...``). Needs to be reconciled with the parent API.
        """
        bits = self.int_to_bits(x_0)
        return super().train_loss(model=model, x_0=bits, t=t)

    @torch.no_grad()
    def p_sample(self, model, z_T, seed=None, return_pred=False):
        """Sample in bit space via the parent, then convert rounded bits back to integers.

        Returns ``(x_0, pred_x_0)`` where ``pred_x_0`` is ``None`` unless
        ``return_pred`` is true.

        NOTE(review): ``OrdinalJumpDiffusion.p_sample`` as defined in this
        file takes ``(x_start, t_start, ...)`` and accepts neither ``model``
        nor ``z_T``, so the ``super()`` call below raises ``TypeError``.
        Needs to be reconciled with the parent API.
        """
        x_0, pred_x_0 = super().p_sample(model=model, z_T=z_T, seed=seed, return_pred=True)
        return self.bits_to_int(x_0.round()), (self.bits_to_int(pred_x_0.round()) if return_pred else None)

    @torch.no_grad()
    def p_sample_progressive(self, model, z_T, freq=10, seed=None):
        """Record a sampling path in bit space via the parent and decode it to integers.

        The rounded path has its first two dims merged before
        ``bits_to_int`` is applied, and the result is reshaped to the path's
        original shape.

        NOTE(review): ``OrdinalJumpDiffusion.p_sample_progressive`` as defined
        in this file takes ``(x_start, t_start, ...)`` and accepts neither
        ``model`` nor ``z_T``, so the ``super()`` call below raises
        ``TypeError``. Separately, ``bits_to_int`` divides dim 1 by
        ``num_bits``, so the final reshape back to ``sample_path.shape``
        looks inconsistent — confirm.
        """
        sample_path = super().p_sample_progressive(model=model, z_T=z_T, freq=freq, seed=seed).round()
        return self.bits_to_int(sample_path.reshape((-1,) + sample_path.shape[2:])).reshape(sample_path.shape)
