import numpy as np
import torch
import torch.distributed as dist
import torchmetrics
from omegaconf import DictConfig
from torch import nn

from meds_torch.models import (
    BACKBONE_EMBEDDINGS_KEY,
    MODEL_BATCH_LOSS_KEY,
    MODEL_EMBEDDINGS_KEY,
    MODEL_LOGITS_KEY,
)
from meds_torch.models.base_model import BaseModule
from meds_torch.models.utils import GatherLayer


class EBCLModule(BaseModule):
    """Event-based contrastive learning (EBCL) pretraining module.

    Each batch holds two windows per sample, stored under
    ``cfg.pre_window_name`` and ``cfg.post_window_name``. Both windows go
    through the shared ``input_encoder`` and backbone, are projected by their
    own linear head and L2-normalised, and are then scored against each other.
    The loss is a symmetric CLIP-style cross-entropy: for row ``i`` the correct
    match is column ``i`` and every other sample in the batch (gathered across
    ranks when ``torch.distributed`` is initialised) acts as a negative.

    Config keys read here: ``batch_size``, ``world_size``, ``token_dim``,
    ``tau`` (initial value of the multiplier applied to the cosine
    similarities; the learnable parameter ``t`` holds ``log(tau)``),
    ``pre_window_name`` and ``post_window_name``.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__(cfg)
        # Both windows share one backbone: pre_model and post_model alias the
        # model that is available once BaseModule.__init__ has run.
        self.pre_model = self.model
        self.post_model = self.model
        self.world_size = cfg.world_size
        # Embeddings are gathered across ranks before scoring, so every sample of
        # the *global* batch is one candidate "class" for the retrieval metrics.
        self.num_classes = cfg.batch_size * cfg.world_size

        def accuracy():
            return torchmetrics.Accuracy(num_classes=self.num_classes, task="multiclass")

        def auroc():
            return torchmetrics.AUROC(num_classes=self.num_classes, task="multiclass")

        # Metrics: one accuracy/AUROC pair per split and per retrieval direction.
        self.train_pre_acc = accuracy()
        self.train_pre_auc = auroc()
        self.train_post_acc = accuracy()
        self.train_post_auc = auroc()

        self.val_pre_acc = accuracy()
        self.val_pre_auc = auroc()
        self.val_post_acc = accuracy()
        self.val_post_auc = auroc()

        self.test_pre_acc = accuracy()
        self.test_pre_auc = auroc()
        self.test_post_acc = accuracy()
        self.test_post_auc = auroc()

        # Model components
        self.pre_projection = nn.Linear(cfg.token_dim, cfg.token_dim)
        self.post_projection = nn.Linear(cfg.token_dim, cfg.token_dim)

        # Learnable log logit scale: logits are multiplied by exp(t), which starts at tau.
        self.t = nn.Parameter(torch.ones(1).reshape(-1, 1) * np.log(cfg.tau))
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, batch):
        """Encode both windows and compute the symmetric contrastive loss.

        Args:
            batch: mapping holding the pre and post window inputs under
                ``cfg.pre_window_name`` / ``cfg.post_window_name``.

        Returns:
            The same ``batch``, with these keys added:

            - ``"pre"`` and ``"post"``: backbone outputs for each window.
            - ``MODEL_EMBEDDINGS_KEY``: normalised pre and post embeddings,
              concatenated on dim 1.
            - ``MODEL_BATCH_LOSS_KEY``: the mean of both cross-entropy
              directions.
            - ``MODEL_LOGITS_KEY``: scaled post-vs-pre similarities. Row ``i``
              is post window ``i`` scored against every pre window.
        """
        pre_batch = self.pre_model(self.input_encoder(batch[self.cfg.pre_window_name]))
        post_batch = self.post_model(self.input_encoder(batch[self.cfg.post_window_name]))

        pre_embeds = self.pre_projection(pre_batch[BACKBONE_EMBEDDINGS_KEY])
        post_embeds = self.post_projection(post_batch[BACKBONE_EMBEDDINGS_KEY])

        # Unit-normalise so the dot product is a cosine similarity; normalize's eps
        # avoids NaNs when an embedding is all zeros.
        pre_norm_embeds = nn.functional.normalize(pre_embeds, dim=-1)
        post_norm_embeds = nn.functional.normalize(post_embeds, dim=-1)

        # Gather embeddings across all devices so negatives come from the global batch.
        if dist.is_initialized():
            pre_norm_embeds = GatherLayer.apply(pre_norm_embeds)
            post_norm_embeds = GatherLayer.apply(post_norm_embeds)

        logits = torch.mm(post_norm_embeds, pre_norm_embeds.T) * torch.exp(self.t)
        # The positive pair for row i sits on the diagonal, i.e. class i.
        labels = torch.arange(logits.shape[0], device=logits.device)
        loss_post = self.criterion(logits, labels)
        loss_pre = self.criterion(logits.T, labels)
        loss = (loss_pre + loss_post) / 2

        batch["pre"] = pre_batch
        batch["post"] = post_batch
        batch[MODEL_EMBEDDINGS_KEY] = torch.concat([pre_norm_embeds, post_norm_embeds], dim=1)
        batch[MODEL_BATCH_LOSS_KEY] = loss
        batch[MODEL_LOGITS_KEY] = logits
        return batch

    def _update_metrics(self, logits, pre_acc, pre_auc, post_acc, post_auc):
        """Update one split's retrieval metrics from post-vs-pre ``logits``.

        The two directions read the logits differently:

        - "post" metrics rank pre windows for each post window, using the rows
          of ``logits``.
        - "pre" metrics rank post windows for each pre window, using the rows
          of ``logits.T``.

        The correct match for row ``i`` is column ``i``.

        Accuracy is fed argmax predictions, so it works for any (possibly
        short, final) square batch. AUROC needs one score per configured class,
        so it is only updated when ``logits.shape[1] == self.num_classes``.
        """
        logits = logits.detach()
        labels = torch.arange(logits.shape[0], device=logits.device)
        full_batch = logits.shape[1] == self.num_classes
        for direction_logits, acc, auc in (
            (logits.T, pre_acc, pre_auc),
            (logits, post_acc, post_auc),
        ):
            acc.update(direction_logits.argmax(dim=1), labels)
            if full_batch:
                auc.update(direction_logits, labels)

    def training_step(self, batch):
        """Run the forward pass, update training metrics and log the loss."""
        output = self.forward(batch)
        self._update_metrics(
            output[MODEL_LOGITS_KEY],
            self.train_pre_acc,
            self.train_pre_auc,
            self.train_post_acc,
            self.train_post_auc,
        )
        loss = output[MODEL_BATCH_LOSS_KEY]
        self.log("train/loss", loss, batch_size=self.cfg.batch_size)
        return loss

    def validation_step(self, batch):
        """Run the forward pass, update validation metrics and log the loss."""
        output = self.forward(batch)
        self._update_metrics(
            output[MODEL_LOGITS_KEY],
            self.val_pre_acc,
            self.val_pre_auc,
            self.val_post_acc,
            self.val_post_auc,
        )
        loss = output[MODEL_BATCH_LOSS_KEY]
        self.log("val/loss", loss, batch_size=self.cfg.batch_size, sync_dist=True)
        return loss

    def test_step(self, batch):
        """Run the forward pass, update test metrics and log the loss."""
        output = self.forward(batch)
        self._update_metrics(
            output[MODEL_LOGITS_KEY],
            self.test_pre_acc,
            self.test_pre_auc,
            self.test_post_acc,
            self.test_post_auc,
        )
        loss = output[MODEL_BATCH_LOSS_KEY]
        self.log("test/loss", loss, batch_size=self.cfg.batch_size)
        return loss

    def _log_epoch_metrics(self, metrics):
        """Log each metric in ``metrics`` (key -> Metric) at epoch level.

        Metrics that received no update this epoch are skipped. An example is
        AUROC when no batch was full. Computing such a metric would fail on
        empty state.
        """
        for key, metric in metrics.items():
            if metric.update_called:
                self.log(key, metric, on_epoch=True, batch_size=self.cfg.batch_size)

    def on_train_epoch_end(self):
        self._log_epoch_metrics(
            {
                "train/pre/acc": self.train_pre_acc,
                "train/pre/auc": self.train_pre_auc,
                "train/post/acc": self.train_post_acc,
                "train/post/auc": self.train_post_auc,
            }
        )

    def on_validation_epoch_end(self):
        # Chain to the parent hook so any behaviour defined upstream is kept.
        super().on_validation_epoch_end()
        self._log_epoch_metrics(
            {
                "val/pre/acc": self.val_pre_acc,
                "val/pre/auc": self.val_pre_auc,
                "val/post/acc": self.val_post_acc,
                "val/post/auc": self.val_post_auc,
            }
        )

    def on_test_epoch_end(self):
        metrics = {
            "test/pre/acc": self.test_pre_acc,
            "test/pre/auc": self.test_pre_auc,
            "test/post/acc": self.test_post_acc,
            "test/post/auc": self.test_post_auc,
        }
        self._log_epoch_metrics(metrics)
        # Echo the final test numbers to stdout, one line per direction.
        for direction in ("pre", "post"):
            parts = []
            for name in ("acc", "auc"):
                key = f"test/{direction}/{name}"
                metric = metrics[key]
                if metric.update_called:
                    parts += [key, metric.compute()]
            if parts:
                print(*parts)
