import logging

import numpy as np
from sklearn.preprocessing import StandardScaler
import torch
from torch import nn
from tqdm import tqdm
import wandb

from text_ood.method import Method
from ood_core import HopfieldOODDetector, logmeanexp


logger = logging.getLogger(__name__)

class HopfieldClassifierMethod(Method):
    """OOD method that trains a detector to separate ID data from auxiliary outliers.

    The ``detector`` factory is instantiated with the feature/head configuration
    given to ``__init__`` and trained in :meth:`fit` with a binary objective: the
    mean detector score of in-distribution (ID) batches is minimized, while
    ``-log(1 - exp(-score))`` is minimized for auxiliary (outlier) batches, which
    pushes auxiliary scores upwards.

    Dataset items are expected to be ``(embeddings, input_ids, masks, texts)``
    tuples; after batching, ``embeddings`` is (batch, seq, features) and
    ``masks`` is (batch, seq).
    """

    # Presumably read by the surrounding framework to decide which datasets
    # are passed to fit() -- confirm against the Method base class.
    requires_aux: bool = True
    requires_additional_data: bool = False

    def __init__(
        self,
        n_heads,
        n_queries,
        n_features,
        n_steps,
        batch_size,
        detector,
        beta=None,
        optimizer=None,
        log_training=True,
        device='cpu',
        similarity='dot',
        use_scheduler=True,
        lamb=1.,
        average_dimensions=True,
        init_averaging_params=False,
        share_params=False,
        standardize_inputs=False,
        init_std=1.,
        inference_method='mean',
        orthogonalize_params=False,
        learn_betas=False,
        apply_normalizer=False,
    ):
        """Build the detector, move the module to ``device`` and create the optimizer.

        Args:
            n_heads: Number of detector heads (forwarded to the detector factory).
            n_queries: Number of detector queries (forwarded to the detector factory).
            n_features: Embedding dimension; forwarded as ``feature_dim``.
            n_steps: Total number of optimizer steps performed by :meth:`fit`.
            batch_size: Batch size used for every DataLoader.
            detector: Factory (e.g. a class) called with keyword arguments to
                create the detector module.
            beta: Forwarded unchanged to the detector factory.
            optimizer: Required callable mapping ``self.parameters()`` to an
                optimizer, e.g. ``functools.partial(torch.optim.Adam, lr=1e-3)``.
            log_training: Log loss, scores and learning rate to wandb every step.
            device: Device the module and all batches are moved to.
            similarity: Forwarded unchanged to the detector factory.
            use_scheduler: Cosine-anneal the learning rate over ``n_steps`` steps.
            lamb: Weight of the auxiliary (outlier) loss term.
            average_dimensions: If False, detector scores are multiplied by
                ``n_heads`` and the normalizer is summed instead of averaged.
            init_averaging_params: Forwarded unchanged to the detector factory.
            share_params: Forwarded unchanged to the detector factory.
            standardize_inputs: Standardize embeddings per feature with
                statistics fitted on the unmasked ID tokens.
            init_std: Forwarded unchanged to the detector factory.
            inference_method: ``'mean'`` or ``'max'``; see :meth:`predict`.
            orthogonalize_params: Call ``detector.orthogonalize()`` before
                training and after every optimizer step.
            learn_betas: Forwarded unchanged to the detector factory.
            apply_normalizer: Subtract ``detector.normalizer()`` from the loss.

        Raises:
            TypeError: If ``optimizer`` is None.
        """
        super(HopfieldClassifierMethod, self).__init__()
        self.device = device
        self.n_heads = n_heads
        self.n_queries = n_queries
        self.n_features = n_features
        self.n_steps = n_steps
        self.batch_size = batch_size
        self.lamb = lamb
        self.average_dimensions = average_dimensions
        self.standardize_inputs = standardize_inputs
        self.inference_method = inference_method
        self.orthogonalize_params = orthogonalize_params
        self.apply_normalizer = apply_normalizer

        # Per-feature standardization statistics; populated by fit_standard_scaler().
        self.mean = None
        self.std = None

        self.detector = detector(
            feature_dim=n_features,
            n_heads=n_heads,
            n_queries=n_queries,
            beta=beta,
            similarity=similarity,
            init_averaging_params=init_averaging_params,
            share_params=share_params,
            init_std=init_std,
            learn_betas=learn_betas,
        )

        self.use_scheduler = use_scheduler
        self.log_training = log_training
        self.to(self.device)
        if optimizer is None:
            raise TypeError(
                'optimizer must be a callable that builds an optimizer from the module parameters'
            )
        self.optimizer = optimizer(self.parameters())

    def _make_loader(self, dataset, shuffle, drop_last, num_workers):
        """Return a DataLoader over ``dataset`` using this method's batch size."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            num_workers=num_workers,
            # persistent_workers is only valid with worker processes.
            persistent_workers=num_workers > 0,
        )

    def _standardize(self, embeddings):
        """Apply the fitted per-feature standard scaler to ``embeddings``.

        Raises:
            RuntimeError: If :meth:`fit_standard_scaler` has not been run yet.
        """
        mean = getattr(self, 'mean', None)
        std = getattr(self, 'std', None)
        if mean is None or std is None:
            raise RuntimeError(
                'standardize_inputs=True but fit_standard_scaler() has not been called'
            )
        # The epsilon keeps constant features (std == 0) from dividing by zero.
        return (embeddings - mean) / (std + 1e-10)

    def _compute_mean(self, id_dataset):
        """Feed every ID batch to ``detector.partial_fit_mean`` (no gradients)."""
        loader = self._make_loader(id_dataset, shuffle=False, drop_last=False, num_workers=16)
        logger.info('Computing mean value')
        with torch.no_grad():
            for embeddings, _, masks, _ in tqdm(loader):
                embeddings = embeddings.to(self.device).float()
                masks = masks.to(self.device)
                if self.standardize_inputs:
                    embeddings = self._standardize(embeddings)
                self.detector.partial_fit_mean(embeddings, masks)

    @torch.no_grad()
    def fit_standard_scaler(self, id_dataset):
        """Fit per-feature mean/std over all unmasked tokens of ``id_dataset``.

        Stores the results in ``self.mean`` and ``self.std`` (both shaped
        ``(n_features,)`` on ``self.device``).

        Raises:
            ValueError: If the dataset contains no unmasked tokens.
        """
        feature_sum = torch.zeros(self.n_features, device=self.device)
        feature_sq_sum = torch.zeros(self.n_features, device=self.device)
        n = 0
        id_loader = self._make_loader(id_dataset, shuffle=False, drop_last=False, num_workers=16)
        logger.info('Fitting standard scaler')
        for embeddings, _, masks, _ in tqdm(id_loader):
            embeddings = embeddings.to(self.device).float()
            masks = masks.to(self.device)
            tokens = embeddings[masks.bool()]  # (n_unmasked_tokens, F)
            feature_sum += torch.sum(tokens, dim=0)
            feature_sq_sum += torch.sum(tokens ** 2, dim=0)
            n += tokens.shape[0]
        if n == 0:
            raise ValueError('cannot fit standard scaler: id_dataset contains no unmasked tokens')
        mean = feature_sum / n
        # E[x^2] - E[x]^2 can go slightly negative through float cancellation;
        # clamp so sqrt() does not produce NaNs for (near-)constant features.
        var = torch.clamp(feature_sq_sum / n - mean ** 2, min=0.)
        self.mean = mean
        self.std = torch.sqrt(var)

    def fit(self, id_dataset, aux_dataset):
        """Train the detector on paired ID / auxiliary batches, then fit its mean.

        Steps: optionally fit the standard scaler on ``id_dataset``; repeat
        ``aux_dataset`` until it is at least as long as ``id_dataset``; run
        ``n_steps`` optimizer steps over shuffled, paired batches (starting new
        passes over the loaders as needed); finally call
        ``detector.partial_fit_mean`` on all ID batches.

        Raises:
            ValueError: If ``aux_dataset`` is empty, or if the loaders yield no
                batches (a dataset smaller than ``batch_size``, since
                ``drop_last=True``).
        """
        self.train()

        if self.standardize_inputs:
            self.fit_standard_scaler(id_dataset)

        optimizer = self.optimizer

        if self.use_scheduler:
            # Cosine annealing over the whole run: the scheduler steps once per
            # optimizer step, so T_max is the total number of steps.
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.n_steps)
        else:
            scheduler = None

        logger.info('Fitting parameters')

        if len(aux_dataset) == 0:
            raise ValueError('aux_dataset is empty')
        if len(aux_dataset) < len(id_dataset):
            # Repeat the auxiliary dataset so it is at least as long as the ID
            # dataset; otherwise zip() below would cut every pass short.
            n_repeats = len(id_dataset) // len(aux_dataset) + 1
            aux_dataset = torch.utils.data.ConcatDataset([aux_dataset] * n_repeats)

        id_loader = self._make_loader(id_dataset, shuffle=True, drop_last=True, num_workers=8)
        aux_loader = self._make_loader(aux_dataset, shuffle=True, drop_last=True, num_workers=8)

        if self.orthogonalize_params:
            self.detector.orthogonalize()

        total_steps = 0
        with tqdm(total=self.n_steps) as pbar:
            while total_steps < self.n_steps:
                steps_before_pass = total_steps
                for (id_embeddings, _, id_masks, _), (aux_embeddings, _, aux_masks, _) in zip(id_loader, aux_loader):
                    if total_steps >= self.n_steps:
                        break
                    self._training_step(
                        id_embeddings, id_masks, aux_embeddings, aux_masks, optimizer, scheduler
                    )
                    total_steps += 1
                    pbar.update()
                # A full pass without a single batch would otherwise spin forever.
                if total_steps == steps_before_pass:
                    raise ValueError(
                        'training loaders yielded no batches; the datasets must contain '
                        'at least batch_size samples because drop_last=True'
                    )

        self._compute_mean(id_dataset)

    def _training_step(self, id_embeddings, id_masks, aux_embeddings, aux_masks, optimizer, scheduler):
        """Run one optimization step on an ID batch and an auxiliary batch."""
        id_embeddings = id_embeddings.to(self.device).float()  # B, S, F
        id_masks = id_masks.to(self.device)  # B, S
        aux_embeddings = aux_embeddings.to(self.device).float()
        aux_masks = aux_masks.to(self.device)

        if self.standardize_inputs:
            id_embeddings = self._standardize(id_embeddings)
            aux_embeddings = self._standardize(aux_embeddings)

        n_id = len(id_embeddings)
        embeddings = torch.cat([id_embeddings, aux_embeddings], dim=0)
        masks = torch.cat([id_masks, aux_masks], dim=0)
        # True for ID rows; built on the same device as the rest of the batch.
        is_id = torch.cat([
            torch.ones(n_id, dtype=torch.bool, device=self.device),
            torch.zeros(len(aux_embeddings), dtype=torch.bool, device=self.device),
        ], dim=0)

        scores = self.detector(embeddings, masks, use_for_mean=is_id)  # B

        if not self.average_dimensions:
            # The detector returns an average (presumably over heads); rescale
            # by n_heads to turn it into a sum.
            scores = scores * self.detector.n_heads

        id_scores = scores[:n_id]
        aux_scores = scores[n_id:]
        id_loss = torch.mean(id_scores, dim=0)
        # -log(1 - exp(-s)) via expm1: accurate when s is close to 0, where
        # 1 - exp(-s) suffers from cancellation. Assumes aux scores are >= 0.
        aux_loss = torch.mean(-torch.log(-torch.expm1(-aux_scores)), dim=0)

        total_loss = id_loss + self.lamb * aux_loss

        if self.apply_normalizer:
            if self.average_dimensions:
                total_loss = total_loss - torch.mean(self.detector.normalizer())
            else:
                total_loss = total_loss - torch.sum(self.detector.normalizer())

        if self.log_training:
            wandb.log({
                'total_loss': total_loss.item(),
                'ood_scores': scores.detach().cpu(),
                'lr': optimizer.param_groups[0]['lr'],
            })

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        if self.orthogonalize_params:
            self.detector.orthogonalize()

    @torch.no_grad()
    def predict(self, embeddings, input_ids, masks, texts):
        """Score a batch; returns a CPU tensor with one value per sample.

        Inputs are moved to ``self.device`` and cast to float, as in training.

        * ``'mean'``: the negated detector score (training drives ID scores
          down, so larger outputs indicate more ID-like inputs).
        * ``'max'``: with ``detector.return_features`` enabled, computes
          ``features**2 - normalizer`` per entry, ignores entries whose
          normalizer is infinite, takes the max over the last dimension and
          negates it.

        Raises:
            ValueError: If ``inference_method`` is neither ``'mean'`` nor ``'max'``.
        """
        self.eval()
        self.detector.eval()

        embeddings = embeddings.to(self.device).float()
        masks = masks.to(self.device)

        if self.standardize_inputs:
            embeddings = self._standardize(embeddings)

        if self.inference_method == 'mean':
            score = self.detector(embeddings, masks)
            return -score.cpu()
        if self.inference_method == 'max':
            self.detector.return_features = True
            try:
                features = self.detector(embeddings, masks)
                squared_distances = features ** 2
                normalizer = self.detector.normalizer()
                scores = squared_distances - normalizer.unsqueeze(0)

                # when normalizer is infinity, ignore the entry
                scores = torch.where(torch.isinf(normalizer), -torch.inf, scores)

                scores, _ = torch.max(scores, dim=-1)
            finally:
                # Always restore the detector's default output mode, even on error.
                self.detector.return_features = False
            return -scores.cpu()
        raise ValueError(
            f"unknown inference_method {self.inference_method!r}; expected 'mean' or 'max'"
        )
