import torch
from utils.cav import get_latent_encoding_batch

import pytorch_lightning as pl

from utils.metrics import get_accuracy, get_auc, get_f1

class BasicModel(pl.LightningModule):
    """Base Lightning module: a frozen backbone plus a trainable head.

    Subclasses assign ``self.classifier`` and implement ``forward``. The loss
    function, optimizer and LR-decay milestones are not created here; they are
    injected after construction through ``set_loss``, ``set_optimizer`` and
    ``set_milestones``.

    Args:
        backbone: Feature extractor. It is put in eval mode and all of its
            parameters are frozen.
        n_classes: Number of output classes.
        config: Mapping with at least ``"cav_dim"`` (how latent activations
            are reduced, see ``get_features``) and ``"layer_name"`` (the
            backbone layer whose activations are used).
    """

    def __init__(self, backbone, n_classes, config):
        super().__init__()
        self.loss = None
        self.optim = None
        self.milestones = None
        self.n_classes = n_classes
        self.backbone = backbone.eval()
        self.classifier = None
        self.cav_dim = config["cav_dim"]
        self.layer_name = config["layer_name"]

        # Freeze the backbone's parameters; only the head is trained.
        self.backbone.requires_grad_(False)

    def train(self, mode=True):
        """Switch train/eval mode while always keeping the frozen backbone in eval mode.

        Lightning puts the module in train mode when training starts. Without
        this override, that would switch the backbone's dropout / batch-norm
        layers back to training behaviour and keep updating batch-norm running
        statistics even though the backbone's weights are frozen.
        """
        super().train(mode)
        backbone = getattr(self, "backbone", None)
        if backbone is not None:
            backbone.eval()
        return self

    def get_features(self, x):
        """Return the backbone's activations at ``self.layer_name`` for batch ``x``.

        The activations are reduced according to ``self.cav_dim``:

        * ``1``: flatten every dimension from 2 on and take the max over it
          (for a ``(B, C, H, W)`` map this gives one value per channel);
        * ``3``: flatten every dimension from 1 on into one vector per sample;
        * any other value: returned unchanged.
        """
        latent_features = get_latent_encoding_batch(self.backbone, x, self.layer_name)

        if self.cav_dim == 1:
            latent_features = latent_features.flatten(start_dim=2).max(2).values
        elif self.cav_dim == 3:
            latent_features = latent_features.flatten(start_dim=1)

        return latent_features

    def forward(self, x):
        raise NotImplementedError()

    def default_step(self, x, y, stage):
        """Compute and log loss, accuracy, AUC and F1 for one batch; return the loss.

        Metrics are logged as ``f"{stage}_loss"``, ``f"{stage}_acc"``,
        ``f"{stage}_auc"`` and ``f"{stage}_f1"``, shown in the progress bar and
        synchronised across devices.
        """
        y_hat = self(x)
        # Feed the raw model outputs to the loss; shifting logits by a huge
        # constant destroys float32 precision (and saturates sigmoid losses).
        loss = self.loss(y_hat, y)
        self.log_dict(
            {f"{stage}_loss": loss,
             f"{stage}_acc": get_accuracy(y_hat, y),
             f"{stage}_auc": get_auc(y_hat, y),
             f"{stage}_f1": get_f1(y_hat, y),
             },
            prog_bar=True,
            sync_dist=True,
        )
        return loss

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self.default_step(x, y, stage="train")
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.default_step(x, y, stage="valid")

    def test_step(self, batch, batch_idx):
        x, y = batch
        self.default_step(x, y, stage="test")

    def set_optimizer(self, optim):
        self.optim = optim

    def set_loss(self, loss):
        self.loss = loss

    def set_milestones(self, milestones):
        self.milestones = milestones

    def configure_optimizers(self):
        """Return the injected optimizer and a step-decay LR scheduler.

        The learning rate is multiplied by 0.1 at each milestone in
        ``self.milestones``. The scheduler is registered under the name
        ``"lr_history"`` (the name learning-rate monitors log it as).

        Raises:
            RuntimeError: if ``set_optimizer`` has not been called.
        """
        if self.optim is None:
            raise RuntimeError("No optimizer set; call set_optimizer() before training.")
        sche = torch.optim.lr_scheduler.MultiStepLR(optimizer=self.optim, milestones=self.milestones, gamma=0.1)
        scheduler = {
            "scheduler": sche,
            "name": "lr_history",
        }

        return [self.optim], [scheduler]

    def state_dict(self, **kwargs):
        """Return backbone and classifier weights merged into one flat dict.

        NOTE(review): the keys are the backbone's and the classifier's own keys
        without a ``"backbone."`` / ``"classifier."`` prefix, and ``kwargs``
        (e.g. ``prefix``, ``keep_vars``) are ignored. The result therefore does
        not match what ``load_state_dict`` on this module expects. It also
        requires ``self.classifier`` to have been set by a subclass.
        """
        return {**self.backbone.state_dict(), **self.classifier.state_dict()}
    
class PHCBModel(BasicModel):
    """Concept-bottleneck classifier on top of a frozen backbone.

    Latent features are projected onto a fixed set of concept vectors (one dot
    product per concept), and a linear layer maps those concept scores to
    class logits.

    Args:
        backbone: Frozen feature extractor (see ``BasicModel``).
        concept_bank: Mapping whose values are concept vectors (tensors of
            equal length); their iteration order defines the concept order.
        n_classes: Number of output classes.
        config: Forwarded to ``BasicModel``.
        **kwargs: Accepted for call-site compatibility; unused.

    Raises:
        ValueError: if ``concept_bank`` is empty.
    """

    def __init__(self, backbone, concept_bank, n_classes, config, **kwargs):
        super().__init__(backbone, n_classes, config)

        self.concept_bank = concept_bank
        concepts = list(concept_bank.values())
        if not concepts:
            raise ValueError("concept_bank must contain at least one concept vector")
        # One row per concept. Registered as a non-persistent buffer so that it
        # follows the module's device/dtype moves instead of being copied to
        # the input's device on every forward pass.
        self.register_buffer("concept_matrix", torch.stack(concepts), persistent=False)
        self.classifier = torch.nn.Linear(len(self.concept_matrix), self.n_classes, bias=True)

    def compute_concept_embedding(self, latent_features):
        """Return the dot product of each latent vector with every concept vector."""
        # ``.to`` is a no-op when the buffer is already on the features' device.
        return latent_features @ self.concept_matrix.T.to(latent_features.device)

    def forward(self, x):
        latent_features = self.get_features(x)
        concept_embedding = self.compute_concept_embedding(latent_features)
        out = self.classifier(concept_embedding)
        return out
    
class PHCBHModel(BasicModel):
    """Concept-bottleneck model extended with a residual linear head.

    The logits are the concept-bottleneck logits of ``phcb_model`` plus a
    linear map applied directly to the same latent features. Feature
    extraction uses ``phcb_model``'s layer/reduction settings.

    Args:
        phcb_model: A ``PHCBModel`` whose backbone, concept matrix and
            classifier are reused.
        config: Forwarded to ``BasicModel``.
    """

    def __init__(self, phcb_model, config):
        super().__init__(phcb_model.backbone, phcb_model.n_classes, config)
        self.phcb_model = phcb_model
        # Input width = concept-vector length, i.e. the latent feature width.
        self.residual_classifier = torch.nn.Linear(self.phcb_model.concept_matrix.shape[1],
                                                   self.phcb_model.n_classes, bias=True)

    def forward(self, x):
        latent_features = self.phcb_model.get_features(x)
        concept_embedding = self.phcb_model.compute_concept_embedding(latent_features)
        out = self.phcb_model.classifier(concept_embedding) + self.residual_classifier(latent_features)
        return out

    def state_dict(self, **kwargs):
        """Return backbone, concept-classifier and residual-head weights in one dict.

        Backbone and concept-classifier keys are unprefixed (same layout as
        ``PHCBModel.state_dict``). The residual head's keys carry the
        ``"residual_classifier."`` prefix. Without it, its ``weight``/``bias``
        would collide with, and silently overwrite, the concept classifier's.
        ``kwargs`` are ignored.
        """
        return {**self.phcb_model.backbone.state_dict(),
                **self.phcb_model.classifier.state_dict(),
                **self.residual_classifier.state_dict(prefix="residual_classifier.")}

class LinearProbeModel(BasicModel):
    """Linear classifier trained directly on the frozen backbone's latent features.

    The classifier's input width is discovered by passing one random image of
    shape ``(1, 3, img_size, img_size)`` through ``get_features``.

    Args:
        backbone: Frozen feature extractor (see ``BasicModel``).
        n_classes: Number of output classes.
        config: Forwarded to ``BasicModel``. Also needs ``"img_size"``, and may
            give ``"device"`` for the probe input. By default the probe uses
            the backbone's parameter device, else CUDA if available, else CPU.
        **kwargs: Accepted for call-site compatibility; unused.
    """

    def __init__(self, backbone, n_classes, config, **kwargs):
        super().__init__(backbone, n_classes, config)
        device = config.get('device', self._default_probe_device())
        img_size = config["img_size"]
        dummy_emb = self.get_features(torch.rand(1, 3, img_size, img_size).to(device))
        self.classifier = torch.nn.Linear(dummy_emb.shape[1], self.n_classes, bias=True)

    def _default_probe_device(self):
        """Device for the shape probe: where the backbone's weights live, if any."""
        # Probing on CUDA while the backbone is still on CPU (its usual state
        # before Lightning moves the model) would mix devices.
        first_param = next(self.backbone.parameters(), None)
        if first_param is not None:
            return first_param.device
        return 'cuda' if torch.cuda.is_available() else 'cpu'

    def forward(self, x):
        latent_features = self.get_features(x)
        out = self.classifier(latent_features)
        return out
    