import math
import os
from typing import Any, Dict, Optional

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchmetrics
from numpy import identity
from torch import Tensor


class BaseModel(pl.LightningModule):
    """Shared LightningModule base class.

    Provides hyperparameter saving, export of logged metrics as plain Python
    values, extra checkpoint fields (seed, epoch, datamodule hparams, metrics)
    and helpers that subclasses call to set up loader names and accuracy
    metrics.

    Several helpers read attributes this class does not define itself
    (``datamodule``, ``num_classes``, ``top_k``); subclasses are expected to
    set them before calling those helpers.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters(ignore=["datamodule", "embedding_model"])
        # note metrics are automatically tracked in self.trainer.logged_metrics
        # self.results stores additional results
        self.results: Dict[str, Any] = dict()

    @property
    def model_name(self) -> str:
        """Name of the concrete class, used to tag exported metrics."""
        return self.__class__.__name__

    @property
    def logged_metrics(self) -> Dict[str, Any]:
        """Converts all logged metrics to json serializable pure Python values.

        Returns:
            A dict that always contains ``"model_name"``; when a trainer is
            attached it also contains every entry of
            ``self.trainer.logged_metrics``, with tensors converted to a
            Python scalar (single element) or a nested list (otherwise).
        """
        metrics: Dict[str, Any] = {"model_name": self.model_name}
        if not self.trainer:
            return metrics
        for name, value in self.trainer.logged_metrics.items():
            if isinstance(value, torch.Tensor):
                # Decide on the element count instead of catching the error
                # raised by ``item()``: the exception type for multi-element
                # tensors is not the same across PyTorch versions.
                value = value.item() if value.numel() == 1 else value.tolist()
            metrics[name] = value
        return metrics

    def on_save_checkpoint(self, checkpoint) -> None:
        """Adds seed, epoch, datamodule hparams and logged metrics to ``checkpoint``.

        The seed is read from the ``PL_GLOBAL_SEED`` environment variable and
        falls back to 0 when the variable is unset.
        """
        checkpoint["seed"] = int(os.getenv("PL_GLOBAL_SEED", default=0))
        checkpoint["current_epoch"] = self.current_epoch
        checkpoint["datamodule_hparams"] = self.datamodule_hparams
        # Keep metrics already present in the checkpoint; entries with the
        # same name are overwritten by the trainer's current values.
        checkpoint.setdefault("metrics", {}).update(self.trainer.logged_metrics)

    @property
    def datamodule_hparams(self) -> Optional[dict]:
        """Hyperparameters of the attached datamodule.

        Returns:
            ``self.datamodule.hparams``, or ``None`` when the instance has no
            ``datamodule`` attribute.
        """
        if not hasattr(self, "datamodule"):
            return None
        return self.datamodule.hparams

    def on_load_checkpoint(self, checkpoint) -> None:
        """Restores the seed and re-logs the metrics stored in ``checkpoint``.

        Raises:
            KeyError: if ``checkpoint`` has no ``"seed"`` entry.
        """
        self.seed = checkpoint["seed"]
        # A missing "metrics" entry is treated as "no metrics" (and an empty
        # dict is stored back into the checkpoint, as before).
        for metric, value in checkpoint.setdefault("metrics", {}).items():
            self.log(metric, value)

    def _setup_loader_names(self) -> None:
        """Sets train/val/test loader names and ``num_classes``.

        With a truthy ``self.datamodule`` the names are copied from it, and
        ``num_classes`` is taken from it when not already set. Otherwise the
        loader-name lists are empty and ``num_classes`` defaults to 15 when
        not already set.
        """
        if self.datamodule:
            self.train_loader_names = self.datamodule.train_loader_names
            self.val_loader_names = self.datamodule.val_loader_names
            self.test_loader_names = self.datamodule.test_loader_names
            if self.num_classes is None:
                self.num_classes = self.datamodule.num_classes
        else:
            print("loader names not loaded from datamodule")
            if self.num_classes is None:
                print("assuming dataset contains 15 classes")
                self.num_classes = 15
            self.train_loader_names = []
            self.val_loader_names = []
            self.test_loader_names = []

    def setup_accuracy_metrics(self) -> None:
        """Registers one top-k accuracy metric per (k, loader name) pair.

        Each metric is stored as the attribute ``{loader}_top_{k}_accuracy``.
        Requires ``self.top_k`` and the three loader-name lists to be set.
        """
        loader_types = (
            self.train_loader_names + self.val_loader_names + self.test_loader_names
        )

        for k in self.top_k:
            for data_type in loader_types:
                setattr(
                    self,
                    f"{data_type}_top_{k}_accuracy",
                    torchmetrics.Accuracy(top_k=k),
                )


class ShapesBaseModel(BaseModel):
    """Module for training on Shapes.

    Inherit and implement:
        - load_backbone()
        - forward()

    ``configure_optimizers`` and ``sgd`` read ``self.optimizer``,
    ``self.learning_rate``, ``self.momentum`` and ``self.weight_decay``; this
    class does not set them, so subclasses must.
    """

    def __init__(self, *args, **kwargs):
        """Reads ``top_k``, ``datamodule`` (required) and ``num_classes`` (optional) from kwargs.

        Raises:
            KeyError: if ``top_k`` or ``datamodule`` is not passed as a keyword.
        """
        super().__init__(*args, **kwargs)
        self.top_k = kwargs["top_k"]
        # num_classes is set within _setup_loader_names()
        self.num_classes = kwargs.get("num_classes", None)
        self.datamodule = kwargs["datamodule"]

        self._setup_loader_names()
        # A falsy top_k (None or empty) means no accuracy metrics are tracked.
        if self.top_k:
            self.setup_accuracy_metrics()

    def load_backbone(self):
        """Abstract: subclasses build their backbone here."""
        raise NotImplementedError("model needs to implement a backbone")

    def forward(self, *args, **kwargs):
        """Abstract: subclasses compute class logits here.

        Accepts arbitrary arguments so that calling an unimplemented model
        raises ``NotImplementedError`` rather than a ``TypeError`` about the
        argument count.
        """
        raise NotImplementedError("model needs to implement a forward method")

    def shared_step(self, batch: Any, stage: str = "train"):
        """Runs one forward pass, logs the loss and top-k accuracies, returns the loss.

        Args:
            batch: a 3-item sequence ``(x, y, _)``; the third item is ignored.
            stage: prefix for the logged names; the step methods of this class
                pass a loader name here.

        Returns:
            The cross-entropy loss between ``self(x)`` and ``y``.
        """
        x, y, _ = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        batch_size = x.shape[0]
        self.log(
            f"{stage}_loss",
            loss,
            sync_dist=True,
            # loader names are used instead
            add_dataloader_idx=False,
            batch_size=batch_size,
        )
        # __init__ accepts a falsy top_k (no metrics registered); iterate an
        # empty tuple in that case instead of failing on None.
        for k in self.top_k or ():
            accuracy_metric = getattr(self, f"{stage}_top_{k}_accuracy")
            accuracy_metric(F.softmax(y_hat, dim=-1), y)
            self.log(
                f"{stage}_top_{k}_accuracy",
                accuracy_metric,
                prog_bar=True,
                sync_dist=True,
                on_epoch=True,
                on_step=False,
                batch_size=batch_size,
                # loader names are used instead
                add_dataloader_idx=False,
            )
        return loss

    def training_step(self, loaders, loader_idx):
        """Sums the ``shared_step`` loss over every named loader in ``loaders``.

        Args:
            loaders: mapping from loader name to that loader's batch.
            loader_idx: unused by this method.
        """
        loss = 0
        for loader_name in loaders:
            batch = loaders[loader_name]
            loss += self.shared_step(batch, stage=loader_name)
        return loss

    def validation_step(self, batch, batch_idx, loader_idx=0):
        """Runs ``shared_step`` under the validation loader name at ``loader_idx``."""
        loader_name = self.val_loader_names[loader_idx]
        loss = self.shared_step(batch, stage=loader_name)
        return loss

    def test_step(self, batch, batch_idx, loader_idx=0):
        """Runs ``shared_step`` under the test loader name at ``loader_idx``."""
        loader_name = self.test_loader_names[loader_idx]
        loss = self.shared_step(batch, stage=loader_name)
        return loss

    def configure_optimizers(self):
        """Builds the optimizer selected by ``self.optimizer`` ("adam" or "sgd").

        Raises:
            ValueError: for any other optimizer name.
        """
        if self.optimizer == "adam":
            return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        elif self.optimizer == "sgd":
            return self.sgd()
        raise ValueError(f"optimizer {self.optimizer} not implemented")

    def sgd(self):
        """Returns SGD plus a StepLR schedule (x0.1 every 30 scheduler steps)."""
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.learning_rate,
            momentum=self.momentum,
            weight_decay=self.weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
        return [optimizer], [scheduler]


class LieModelMixins:
    """Utilities for Lie model.

    Mixin: several methods read attributes that the host class must provide
    (``t_inference_network``, ``t_bounds``, ``truncate_matrix_exp``,
    ``train_loader_names``, ``device``).
    """

    def initialize_L_basis(
        self, num_lie_generators: int, embedding_dim: int
    ) -> torch.nn.parameter.Parameter:
        """Creates the generator basis; currently delegates to the uniform initializer."""
        return self.initialize_L_basis_uniform(num_lie_generators, embedding_dim)

    def initialize_L_basis_uniform(
        self, num_lie_generators: int, embedding_dim: int
    ) -> torch.nn.parameter.Parameter:
        """Returns a ``[num_lie_generators, embedding_dim, embedding_dim]`` parameter.

        Entries are drawn uniformly from ``[-sqrt(1/d), sqrt(1/d))`` with
        ``d = embedding_dim``.
        """
        bound = math.sqrt(1.0 / embedding_dim)
        init_L_basis = torch.empty(
            num_lie_generators, embedding_dim, embedding_dim, dtype=torch.float32
        ).uniform_(-bound, bound)
        return torch.nn.parameter.Parameter(data=init_L_basis)

    def initialize_L_basis_identity(
        self, num_lie_generators: int, embedding_dim: int, epsilon=1e-6
    ) -> torch.nn.parameter.Parameter:
        """Returns a stack of ``num_lie_generators`` identity matrices.

        ``epsilon`` is added to every entry of each matrix (off-diagonal
        entries included).
        """
        identities = torch.stack(
            [torch.eye(embedding_dim) + epsilon for _ in range(num_lie_generators)]
        )
        return torch.nn.parameter.Parameter(data=identities)

    def create_t_inference_network(
        self, num_lie_generators: int, embedding_dim: int
    ) -> torch.nn.Sequential:
        """Builds the two-layer MLP mapping ``2 * embedding_dim + 1`` features to one output per generator."""
        in_features = 2 * embedding_dim + 1
        return torch.nn.Sequential(
            torch.nn.Linear(in_features, in_features, bias=True),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(in_features, num_lie_generators, bias=True),
        )

    def infer_t(self, z, z_prime, delta):
        """Predicts coefficients ``t`` in (-1, 1) from ``(z, z_prime, delta)``.

        The three inputs are concatenated on the last dimension and passed,
        detached, through ``self.t_inference_network``. Rows where ``delta``
        is 0 get ``t = 0``.

        NOTE(review): the masking multiplies ``t`` by a tensor shaped like
        ``delta``; presumably ``delta`` has a trailing dimension of 1 so it
        broadcasts over generators - confirm against callers.

        Raises:
            ValueError: if ``z_prime`` is None.
        """
        if z_prime is None:
            raise ValueError("z_prime cannot be None")

        v = torch.cat([z, z_prime, delta], dim=-1)
        # Breaks back propagation wrt encoder on z (just learn the weights of h)
        t = 2 * torch.sigmoid(self.t_inference_network(v.detach().clone())) - 1

        # Hard code that t is 0 where delta is 0
        t = t * (1 - (delta == 0).float())
        return t

    def compute_rep(self, x):
        """Produces a representation z without any projection"""
        raise NotImplementedError

    def on_train_start(self) -> None:
        """Moves ``self.t_bounds`` to the module's device."""
        self.t_bounds = self.t_bounds.to(self.device)

    def on_train_epoch_start(self) -> None:
        """
        Resets bounds at each epoch
        """
        self.t_bounds[:] = 0.0

    def matrix_exponential(self, A):
        """Returns exp(A), or its first-order approximation ``I + A``.

        Args:
            A: square matrix or batch of square matrices ``[..., n, n]``.

        When ``self.truncate_matrix_exp`` is truthy the series is truncated
        after the linear term. The identity matrix (not the scalar 1) is
        added, so only the diagonal is shifted.
        """
        if self.truncate_matrix_exp:
            # uses a linearization: exp(A) ~= I + A
            eye = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device)
            return eye + A
        return torch.matrix_exp(A)

    def on_train_batch_start(self, batch, batch_idx, dataloader_idx=0) -> None:
        """Resets ``self.g_matrix`` to ``None`` for every train loader name.

        ``dataloader_idx`` is unused; it defaults to 0 so the hook can also
        be called with only ``(batch, batch_idx)``.
        """
        self.g_matrix = {name: None for name in self.train_loader_names}

    def online_probe_forward(self, x1_online, x2_online, g, stage="train"):
        """Computes representations for online probe.

        Note dim_m = projection dimension

        Returns:
            z1, z2 as a single concatenated tensor of shape [batch_size 2, proj dim]

            if stage is not canonical, also concats g(z1) g(z2) -> [batch_size, 4, proj dim]
        """
        raise NotImplementedError

    def lie_nt_xent_loss(self, out_1, out_2, out_3, temperature, eps=1e-6):
        """NT-Xent style contrastive loss with an extra set of negatives.

        DOES NOT assume out_1 and out_2 are normalized
        out_1: [batch_size, dim]
        out_2: [batch_size, dim]
        out_3: [batch_size, dim]

        ``out_1[i]`` and ``out_2[i]`` form the positive pair; every other row
        of ``out_1``, ``out_2`` and ``out_3`` (gathered across processes when
        torch.distributed is initialized) contributes to the denominator.

        Raises:
            ValueError: if the resulting loss is negative.
        """
        # gather representations in case of distributed training
        # out_1_dist: [batch_size * world_size, dim]
        # out_2_dist: [batch_size * world_size, dim]
        # out_3_dist: [batch_size * world_size, dim]
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            out_1_dist = SyncFunction.apply(out_1)
            out_2_dist = SyncFunction.apply(out_2)
            out_3_dist = SyncFunction.apply(out_3)
        else:
            out_1_dist = out_1
            out_2_dist = out_2
            out_3_dist = out_3

        # out: [2 * batch_size, dim]
        # out_dist: [3 * batch_size * world_size, dim]
        out = torch.cat([out_1, out_2], dim=0)
        out_dist = torch.cat([out_1_dist, out_2_dist, out_3_dist], dim=0)

        # cov and sim: [2 * batch_size, 3 * batch_size * world_size]
        # neg: [2 * batch_size]
        cov = torch.mm(out, out_dist.t().contiguous())
        sim = torch.exp(cov / temperature)
        neg = sim.sum(dim=-1)

        # Remove each row's similarity with itself. That entry of cov is the
        # dot product <out_i, out_i> = ||out_i||^2, which only equals the
        # norm itself for unit-length inputs.
        row_sub = torch.exp(torch.sum(out * out, dim=-1) / temperature)
        neg = torch.clamp(neg - row_sub, min=eps)  # clamp for numerical stability

        # Positive similarity, pos becomes [2 * batch_size]
        pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
        pos = torch.cat([pos, pos], dim=0)
        loss = -torch.log(pos / (neg + eps)).mean()
        if loss < 0.0:
            print("Lie Contrastive loss can't be negative")
            raise ValueError("Lie Contrastive loss can't be negative")
        return loss


class SyncFunction(torch.autograd.Function):
    """all_gather across processes with a hand-written backward pass.

    Forward concatenates the tensors of all ranks along dim 0. Backward sums
    the incoming gradient over all ranks and returns the rows that belong to
    the calling rank. Requires an initialized torch.distributed process group.
    """

    @staticmethod
    def forward(ctx, tensor):
        # Remembered so backward can locate this rank's rows.
        ctx.batch_size = tensor.shape[0]

        # One receive buffer per rank, each shaped like the local tensor.
        world_size = torch.distributed.get_world_size()
        buffers = [torch.zeros_like(tensor) for _ in range(world_size)]
        torch.distributed.all_gather(buffers, tensor)

        return torch.cat(buffers, 0)

    @staticmethod
    def backward(ctx, grad_output):
        # Sum the gradient of the gathered tensor over every rank (blocking).
        summed_grad = grad_output.clone()
        torch.distributed.all_reduce(
            summed_grad, op=torch.distributed.ReduceOp.SUM, async_op=False
        )

        # Rows [rank * batch_size, (rank + 1) * batch_size) are this rank's share.
        start = torch.distributed.get_rank() * ctx.batch_size
        stop = start + ctx.batch_size
        return summed_grad[start:stop]
