import torch

from pytorch_lightning.core.module import LightningModule
from torch.optim import Adam, SGD, AdamW
from lightning.pytorch.utilities import grad_norm
from models.cgcnn import CGCNN
from losses.clip_loss import clip_loss, cosface_loss
from losses.utils import calc_grobal_top_k_acc, batch_wise_accuracy
from models.text_encoder import HuggingFaceEncoder
from models.utils import normalize_embedding


class ClaspModel(LightningModule):
    """Metric-learning model that aligns crystal and text embeddings.

    Each batch is embedded by a crystal encoder and a text encoder. The
    embeddings are optionally normalized, all-gathered across processes so
    the loss sees the global batch, and scored with the loss chosen by
    ``cfg.loss_fn``.

    Attributes:
        model: Crystal encoder chosen by ``cfg.encoder_name``; only
            ``"cgcnn"`` (``CGCNN``) is supported.
        model_text: ``HuggingFaceEncoder`` producing the text embeddings.
        loss_fn: ``clip_loss`` or ``cosface_loss``, called as
            ``loss_fn(text_embeddings, crystal_embeddings, cfg)``.
        cfg: Config object. Fields read by this class: ``encoder_name``,
            ``loss_fn``, ``embedding_normalize``, ``lr`` and the optional
            ``optimizer`` (defaults to ``'Adam'``).
        validation_step_outputs: Per-batch validation results collected by
            ``validation_step`` and consumed (then cleared) by
            ``on_validation_epoch_end``.
    """

    # Optimizer names accepted in ``cfg.optimizer``, mapped to their classes.
    _OPTIMIZERS = {
        'Adam': Adam,
        'SGD': SGD,
        'AdamW': AdamW,
    }

    def __init__(self, cfg):
        """Build both encoders and select the loss function from ``cfg``.

        Raises:
            ValueError: If ``cfg.encoder_name`` or ``cfg.loss_fn`` is not a
                supported value.
        """
        super().__init__()

        if cfg.encoder_name == "cgcnn":
            self.model = CGCNN(cfg)
        else:
            # ValueError subclasses Exception, so handlers written for the
            # previously raised plain Exception still catch it.
            raise ValueError(f"Invalid cfg.encoder_name: {cfg.encoder_name}")

        self.model_text = HuggingFaceEncoder(cfg)

        if cfg.loss_fn == 'clip_loss':
            self.loss_fn = clip_loss
        elif cfg.loss_fn == 'cosface_loss':
            self.loss_fn = cosface_loss
        else:
            # Report the field that was actually validated; formatting a
            # different attribute here could raise AttributeError instead.
            raise ValueError(f"Invalid cfg.loss_fn: {cfg.loss_fn}")

        self.cfg = cfg
        self.validation_step_outputs = []

    def _gather_embeddings(self, embedding):
        """All-gather ``embedding`` from every process and merge into one batch.

        ``LightningModule.all_gather`` stacks per-process tensors along a new
        leading dimension, but adds no dimension when ``world_size == 1``.
        The flatten is therefore applied only when the gather actually added
        a dimension; otherwise a ``(batch, dim)`` tensor would be collapsed
        to 1-D.
        """
        # sync_grads=True is Lightning's flag to synchronize gradients through
        # the gather operation.
        gathered = self.all_gather(embedding, sync_grads=True)
        if gathered.dim() > embedding.dim():
            gathered = gathered.flatten(0, 1)
        return gathered

    def _shared_step(self, batch):
        """Embed, gather and score ``batch``; shared by training and validation.

        Returns:
            ``(loss, batch_acc, crystal_embeddings, text_embeddings)``, where
            the embeddings are the gathered (global-batch) tensors.
        """
        output_cry, output_text = self.forward(batch)
        output_cry = self._gather_embeddings(output_cry)
        output_text = self._gather_embeddings(output_text)

        loss = self.loss_fn(output_text, output_cry, self.cfg)
        _, batch_acc = batch_wise_accuracy(output_text, output_cry)
        return loss, batch_acc, output_cry, output_text

    def training_step(self, batch, batch_idx):
        """Compute the loss on the gathered batch and log loss and accuracy."""
        loss, batch_acc, _, _ = self._shared_step(batch)

        # 'progress_bar' / 'log' are kept in the returned dict for callers
        # that read the step outputs; Lightning itself only needs 'loss'.
        output = {
            'loss': loss,
            'progress_bar': {'tr/loss': loss, 'tr/acc': batch_acc},
            'log': {'train/loss': loss, 'train/batch_acc': batch_acc},
        }
        # Use the module's own device instead of a hard-coded "cuda" so that
        # CPU and other accelerators work as well.
        self.log('train/loss', loss.to(self.device), sync_dist=True)
        self.log('train/batch_acc', batch_acc.to(self.device), sync_dist=True)
        return output

    def validation_step(self, batch, batch_idx):
        """Score one validation batch and stash results for epoch-end metrics."""
        loss, batch_acc, output_cry, output_text = self._shared_step(batch)

        self.validation_step_outputs.append({
            'val/loss': loss,
            'val/acc': batch_acc.float(),
            # Embeddings are moved to CPU so the epoch-level top-k search does
            # not keep every batch resident on the accelerator.
            'out_cry': output_cry.detach().cpu(),
            'out_text': output_text.detach().cpu(),
        })
        return loss

    def on_validation_epoch_end(self):
        """Aggregate validation results, compute top-k retrieval, and log them."""
        outputs = self.validation_step_outputs
        # torch.stack / torch.cat raise on an empty list (e.g. no validation
        # batches were run), so there is nothing to aggregate.
        if not outputs:
            return {'log': {}}

        avg_loss = torch.stack([x['val/loss'] for x in outputs]).mean()
        acc = torch.stack([x['val/acc'] for x in outputs]).mean()
        cry = torch.cat([x['out_cry'] for x in outputs], dim=0)
        text = torch.cat([x['out_text'] for x in outputs], dim=0)
        # Retrieval over the whole epoch: text embeddings as queries,
        # crystal embeddings as targets.
        topk_acc = calc_grobal_top_k_acc(embedding_query=text, embedding_target=cry, k=10)

        logs = {'val/loss': avg_loss, 'val/acc': acc}
        self.log('val/loss', avg_loss.to(self.device), sync_dist=True)
        self.log('val/acc', acc.to(self.device), sync_dist=True)
        for k, value in enumerate(topk_acc, start=1):
            key = 'val/top%02d' % k
            logs[key] = torch.tensor(value)
            self.log(key, torch.tensor(value).to(self.device), sync_dist=True)

        if self.global_rank == 0:
            print("######")
            print(f"val loss: {avg_loss:.3f}")
            print(f'val acc: {acc*100:4.2f} ')
            print("evaluating text->crystal serch acc...")
            for k, value in enumerate(topk_acc, start=1):
                print(f'top{k}: {value*100:4.2f} ')
            print("######")

        self.validation_step_outputs.clear()
        return {'log': logs}

    def configure_optimizers(self):
        """Create the optimizer named by ``cfg.optimizer`` over both encoders.

        Raises:
            ValueError: If ``cfg.optimizer`` is not one of ``_OPTIMIZERS``.
        """
        # Default to Adam if no optimizer is specified in the config.
        optimizer_name = getattr(self.cfg, 'optimizer', 'Adam')
        if optimizer_name not in self._OPTIMIZERS:
            raise ValueError(
                f"Optimizer '{optimizer_name}' not recognized. "
                f"Available options are: {list(self._OPTIMIZERS.keys())}"
            )

        params = [*self.model.parameters(), *self.model_text.parameters()]
        return self._OPTIMIZERS[optimizer_name](params, lr=self.cfg.lr)

    def on_before_optimizer_step(self, optimizer):
        """Log gradient 2-norms for every parameter of both encoders.

        With mixed precision, gradients are already unscaled at this hook.
        """
        # A single call over the whole module gives each parameter a unique
        # key (prefixed by its submodule name) and one true global total.
        # Calling grad_norm per encoder produced two identical total keys,
        # so the second log_dict overwrote the first encoder's total.
        self.log_dict(grad_norm(self, norm_type=2))

    def forward(self, x):
        """Embed ``x`` with both encoders (local batch only, no all-gather).

        Returns:
            Tuple ``(crystal_embeddings, text_embeddings)``, each passed
            through ``normalize_embedding`` when ``cfg.embedding_normalize``
            is not None.
        """
        output_cry = self.model(x)
        output_text = self.model_text(x)
        if self.cfg.embedding_normalize is not None:
            output_text = normalize_embedding(output_text, self.cfg.embedding_normalize)
            output_cry = normalize_embedding(output_cry, self.cfg.embedding_normalize)
        return output_cry, output_text