import logging

import numpy as np
from sklearn.preprocessing import StandardScaler
import torch
from torch import nn
from tqdm import tqdm
import wandb

from text_ood.method import Method
from ood_core import HopfieldOODDetector


# Module-level logger used for the progress messages emitted in this file.
logger = logging.getLogger(__name__)

class HopfieldMethod(Method):
    """OOD method wrapping a trainable Hopfield-style detector.

    ``fit`` optimises the detector by minimising the mean of its per-sample
    output over in-distribution batches, then feeds every ID batch to the
    detector's ``partial_fit_mean``. ``predict`` returns the negated detector
    output (``inference_method='mean'``) or the negated per-sample maximum of
    ``features**2 - detector.normalizer()`` (``inference_method='max'``).

    Datasets are iterated as ``(embeddings, input_ids, masks, texts)`` tuples,
    with ``embeddings`` of shape ``(B, S, F)`` and ``masks`` of shape ``(B, S)``.
    """

    requires_aux: bool = False
    requires_additional_data: bool = False
    # Worker processes for every DataLoader built by this class. Override on a
    # subclass/instance to change it; 0 loads data in the main process.
    loader_num_workers: int = 16

    def __init__(
        self,
        n_heads,
        n_queries,
        n_features,
        n_steps,
        batch_size,
        detector,
        beta=None,
        optimizer=None,
        log_training=True,
        device='cpu',
        similarity='dot',
        use_scheduler=True,
        normalization=False,
        init_averaging_params=False,
        share_params=False,
        orthogonalize_params=False,
        standardize_inputs=False,
        init_std=1.,
        inference_method='mean',
        learn_betas=False
    ):
        """Build the detector, move the module to ``device`` and create the optimizer.

        Args:
            n_heads: Number of heads, forwarded to ``detector``.
            n_queries: Number of queries, forwarded to ``detector``.
            n_features: Embedding dimensionality ``F``; forwarded as
                ``feature_dim`` and used as the size of the scaler statistics.
            n_steps: Number of optimizer steps run by ``fit``; also the
                cosine-annealing period when ``use_scheduler`` is set.
            batch_size: Batch size of every DataLoader built by this class.
            detector: Callable returning the detector module; called with
                ``feature_dim``, ``n_heads``, ``n_queries``, ``beta``,
                ``similarity``, ``normalization``, ``init_averaging_params``,
                ``share_params``, ``init_std`` and ``learn_betas``.
            beta, similarity, normalization, init_averaging_params,
            share_params, init_std, learn_betas: Forwarded to ``detector``.
            optimizer: Callable taking ``self.parameters()`` and returning a
                torch optimizer. Required despite the ``None`` default.
            log_training: Log per-step training metrics to wandb.
            device: Device the module and all batches are moved to.
            use_scheduler: Use ``CosineAnnealingLR`` with ``T_max=n_steps``.
            orthogonalize_params: Call ``detector.orthogonalize()`` before
                training and after every optimizer step.
            standardize_inputs: Standardise embeddings with per-feature
                statistics estimated by ``fit_standard_scaler``.
            inference_method: ``'mean'`` or ``'max'``; see ``predict``.

        Raises:
            TypeError: If ``optimizer`` is ``None``.
        """
        super(HopfieldMethod, self).__init__()
        self.device = device
        self.n_heads = n_heads
        self.n_queries = n_queries
        self.n_features = n_features
        self.n_steps = n_steps
        self.batch_size = batch_size
        self.standardize_inputs = standardize_inputs
        self.inference_method = inference_method

        self.detector = detector(
            feature_dim=n_features,
            n_heads=n_heads,
            n_queries=n_queries,
            beta=beta,
            similarity=similarity,
            normalization=normalization,
            init_averaging_params=init_averaging_params,
            share_params=share_params,
            init_std=init_std,
            learn_betas=learn_betas,
        )

        self.use_scheduler = use_scheduler
        self.orthogonalize_params = orthogonalize_params
        self.log_training = log_training
        self.to(self.device)
        if optimizer is None:
            raise TypeError(
                'HopfieldMethod requires an optimizer factory, e.g. '
                'functools.partial(torch.optim.Adam, lr=...)'
            )
        self.optimizer = optimizer(self.parameters())

    def _make_loader(self, dataset, *, shuffle, drop_last):
        """Build a DataLoader over ``dataset`` with this method's batch settings."""
        num_workers = self.loader_num_workers
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            num_workers=num_workers,
            # persistent_workers is only valid when worker processes are used.
            persistent_workers=num_workers > 0,
        )

    def _to_device(self, embeddings, masks):
        """Move a batch to ``self.device``, casting embeddings to float32."""
        return embeddings.to(self.device).float(), masks.to(self.device)

    def _maybe_standardize(self, embeddings):
        """Apply the fitted per-feature standardisation when it is enabled."""
        if not self.standardize_inputs:
            return embeddings
        # Small epsilon keeps constant (zero-std) features finite.
        return (embeddings - self.mean) / (self.std + 1e-10)

    def _compute_mean(self, id_dataset):
        """Feed every in-distribution batch to ``detector.partial_fit_mean``."""
        loader = self._make_loader(id_dataset, shuffle=False, drop_last=False)
        logger.info('Computing mean value')
        with torch.no_grad():
            for embeddings, _input_ids, masks, _texts in tqdm(loader):
                embeddings, masks = self._to_device(embeddings, masks)
                embeddings = self._maybe_standardize(embeddings)
                self.detector.partial_fit_mean(embeddings, masks)

    @torch.no_grad()
    def fit_standard_scaler(self, id_dataset):
        """Estimate per-feature mean and std over all unmasked ID tokens.

        Sets ``self.mean`` and ``self.std`` (length ``n_features``, on
        ``self.device``); ``std`` is the population std (ddof=0).

        Per-batch statistics are merged with Chan et al.'s parallel update
        rather than the one-pass ``E[x^2] - E[x]^2`` shortcut, which cancels
        catastrophically in float32 and can produce a negative variance
        (and hence a NaN std).

        Raises:
            ValueError: If the dataset contains no unmasked tokens.
        """
        mean = torch.zeros(self.n_features, device=self.device)
        m2 = torch.zeros(self.n_features, device=self.device)  # sum of squared deviations
        n = 0
        id_loader = self._make_loader(id_dataset, shuffle=False, drop_last=False)
        logger.info('Fitting standard scaler')
        for embeddings, _input_ids, masks, _texts in tqdm(id_loader):
            embeddings, masks = self._to_device(embeddings, masks)
            tokens = embeddings[masks.bool()]  # (N_tokens, F): unmasked positions only
            n_batch = tokens.shape[0]
            if n_batch == 0:
                continue
            batch_mean = tokens.mean(dim=0)
            batch_m2 = ((tokens - batch_mean) ** 2).sum(dim=0)
            n_total = n + n_batch
            delta = batch_mean - mean
            mean = mean + delta * (n_batch / n_total)
            m2 = m2 + batch_m2 + delta ** 2 * (n * n_batch / n_total)
            n = n_total
        if n == 0:
            raise ValueError('Cannot fit standard scaler: id_dataset has no unmasked tokens')
        self.mean = mean
        self.std = torch.sqrt(m2 / n)

    def _log_step(self, total_loss, ood_scores, optimizer):
        """Log the current training step's metrics to wandb."""
        with torch.no_grad():
            # Mean over heads of the log squared Frobenius norm of each head's
            # slice of the 3-D ``linear_params`` tensor (see the einsum).
            params = self.detector.linear_params
            normalizer = torch.mean(torch.log(torch.einsum('hqf,hqf->h', params, params)))
        wandb.log({
            'total_loss': total_loss.detach(),
            'ood_scores': ood_scores.detach(),
            'lr': optimizer.param_groups[0]['lr'],
            'normalizer': normalizer,
            'betas': self.detector.betas.data,
        })

    def fit(self, id_dataset):
        """Train the detector on ``id_dataset`` for ``self.n_steps`` steps.

        1. Optionally fit the input standard scaler.
        2. Run ``n_steps`` optimizer steps minimising the mean detector output
           over shuffled, full batches, re-iterating the loader as needed.
        3. Feed all ID batches to the detector's ``partial_fit_mean``.

        Raises:
            ValueError: If the training loader yields no batches (e.g. fewer
                samples than ``batch_size`` with ``drop_last=True``); the
                step loop would otherwise never terminate.
        """
        self.train()

        if self.standardize_inputs:
            self.fit_standard_scaler(id_dataset)

        optimizer = self.optimizer

        # Cosine annealing over the whole run: one scheduler step per optimizer step.
        scheduler = (
            torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.n_steps)
            if self.use_scheduler
            else None
        )

        logger.info('Fitting parameters')

        id_loader = self._make_loader(id_dataset, shuffle=True, drop_last=True)

        if self.orthogonalize_params:
            self.detector.orthogonalize()

        total_steps = 0
        # NOTE: anomaly detection noticeably slows training; it is kept so that
        # NaNs produced in the backward pass raise immediately.
        with tqdm(total=self.n_steps) as pbar, torch.autograd.set_detect_anomaly(True):
            while total_steps < self.n_steps:
                batches_this_pass = 0
                for embeddings, _input_ids, masks, _texts in id_loader:
                    if total_steps >= self.n_steps:
                        break
                    batches_this_pass += 1
                    embeddings, masks = self._to_device(embeddings, masks)  # (B, S, F), (B, S)
                    embeddings = self._maybe_standardize(embeddings)

                    ood_scores = self.detector(embeddings, masks)  # (B,)
                    total_loss = torch.mean(ood_scores, dim=0)

                    if self.log_training:
                        self._log_step(total_loss, ood_scores, optimizer)

                    optimizer.zero_grad()
                    total_loss.backward()
                    optimizer.step()
                    if scheduler is not None:
                        scheduler.step()
                    if self.orthogonalize_params:
                        self.detector.orthogonalize()
                    total_steps += 1
                    pbar.update()
                if batches_this_pass == 0:
                    raise ValueError(
                        'Training loader yielded no batches; the dataset must contain '
                        f'at least batch_size={self.batch_size} samples (drop_last=True).'
                    )

        self._compute_mean(id_dataset)

    @torch.no_grad()
    def predict(self, embeddings, input_ids, masks, texts):
        """Score one batch and return a CPU tensor of negated detector scores.

        ``inference_method='mean'``: ``-detector(embeddings, masks)``.
        ``inference_method='max'``: with ``detector.return_features`` set,
        ``-max(features**2 - detector.normalizer(), dim=-1)``.

        ``input_ids`` and ``texts`` are unused here.

        Raises:
            ValueError: If ``inference_method`` is neither ``'mean'`` nor ``'max'``.
        """
        self.eval()
        self.detector.eval()

        # Mirror fit(): batches live on self.device as float32, which is also
        # required for them to combine with self.mean / self.std.
        embeddings, masks = self._to_device(embeddings, masks)
        embeddings = self._maybe_standardize(embeddings)

        if self.inference_method == 'mean':
            score = self.detector(embeddings, masks)
            return -score.cpu()
        if self.inference_method == 'max':
            self.detector.return_features = True
            try:
                features = self.detector(embeddings, masks)
                normalizer = self.detector.normalizer()
            finally:
                # Restore the flag even if the forward pass raises.
                self.detector.return_features = False
            scores = features ** 2 - normalizer.unsqueeze(0)
            scores, _ = torch.max(scores, dim=-1)
            return -scores.cpu()
        raise ValueError(
            f"Unknown inference_method {self.inference_method!r}; expected 'mean' or 'max'"
        )