# MOMEMTO
from math import ceil
from typing import Optional, Tuple
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

# MOMENT
from momentfm import MOMENTPipeline
from momentfm.utils.masking import Masking

# Patch-based memory module
from .Patch_based_memory_module import MemoryModule as PMM

class MOMEMTO(nn.Module):
    """MOMENT encoder combined with a patch-based memory module.

    The pretrained MOMENT pipeline is loaded in ``"embedding"`` mode and its
    ``embed`` method is replaced by :meth:`new_embed`, which returns the
    encoder's per-patch hidden states instead of a pooled embedding.  Those
    patch-level queries are passed to the memory module together with a
    memory tensor that is initialised in :meth:`init` and carried through
    :meth:`fit`.

    Typical usage is ``fit(...)`` followed by ``pred(...)``.
    """

    def __init__(
        self,
        freeze_enc: bool = False,
        top_k: int = 3,
        device: Optional[str] = None,
    ):
        """Load the pretrained encoder and build the memory module.

        Args:
            freeze_enc: Forwarded to MOMENT as both ``freeze_encoder`` and
                ``freeze_embedder``.
            top_k: Forwarded to the memory module constructor.
            device: Device on which the encoder and memory module are placed.
                ``None`` (default) selects ``"cuda"`` when available and
                ``"cpu"`` otherwise.
        """
        super().__init__()
        self.device = self._resolve_device(device)
        self.enc = MOMENTPipeline.from_pretrained(
            "AutonLab/MOMENT-1-large",
            model_kwargs={
                "task_name": "embedding",
                "freeze_encoder": freeze_enc,
                "freeze_embedder": freeze_enc,
            },
        )
        self.enc.init()

        # Replace the default embed method with a patch-level embedding method
        self._install_patch_embed(self.enc)

        # Try to enable gradient checkpointing on the encoder to save memory.
        # The keyword accepted for ``use_reentrant`` differs between library
        # versions, so both spellings are attempted; if neither is accepted,
        # checkpointing is explicitly switched off.
        if hasattr(self.enc, "encoder") and hasattr(self.enc.encoder, "gradient_checkpointing_enable"):
            try:
                self.enc.encoder.gradient_checkpointing_enable(
                    gradient_checkpointing_kwargs={"use_reentrant": False}
                )
            except TypeError:
                try:
                    self.enc.encoder.gradient_checkpointing_enable(use_reentrant=False)
                except TypeError:
                    if hasattr(self.enc.encoder, "gradient_checkpointing_disable"):
                        self.enc.encoder.gradient_checkpointing_disable()

        self.enc = self.enc.to(self.device).float()
        self.MemoryModule = PMM(top_k=top_k).to(self.device)

        # State populated by init() / fit().
        self.initial_memory: Optional[torch.Tensor] = None  # one row per domain
        self.updated_M: Optional[torch.Tensor] = None       # memory after training
        self.Trainloader: Optional[DataLoader] = None
        self.num_domains: Optional[int] = None
        # Kept for backward compatibility; not used inside this class.
        self.num_groups: Optional[int] = None

    @staticmethod
    def _resolve_device(device: Optional[str]) -> str:
        """Return ``device`` if given, else ``"cuda"`` when available, else ``"cpu"``."""
        if device is not None:
            return device
        return "cuda" if torch.cuda.is_available() else "cpu"

    def _install_patch_embed(self, enc) -> None:
        """Bind :meth:`new_embed` to ``enc`` and install it as ``enc.embed``.

        The plain function is looked up on the class and bound to ``enc``.
        Calling ``__get__`` on the already-bound ``self.new_embed`` would
        return that bound method unchanged, leaving ``self`` inside
        ``new_embed`` pointing at this wrapper instead of the encoder.
        """
        enc.embed = type(self).new_embed.__get__(enc, type(enc))

    @torch.no_grad()
    def init(self,
             train_data: torch.Tensor,
             train_mask: torch.Tensor,
             train_domain: torch.Tensor,
             ratio: float = 0.1,
        ) -> None:
        """Create the training ``DataLoader`` and the initial memory.

        For every distinct value in ``train_domain`` a random subset of that
        domain's samples (at least one, otherwise ``ceil(n * ratio)``) is
        encoded, L2-normalised along the last dimension and averaged over the
        sample dimension.  The per-domain means are stacked, in ascending
        label order, into ``self.initial_memory``.

        Args:
            train_data: Training series; indexed along dim 0 per sample.
            train_mask: Input mask aligned with ``train_data`` along dim 0.
            train_domain: Domain label of each sample.
            ratio: Fraction of each domain's samples used for the memory.
        """
        # DataLoader for subsequent training
        domain_labels = torch.unique(train_domain, sorted=True)
        self.num_domains = len(domain_labels)
        train_dataset = TensorDataset(train_data, train_mask, train_domain)
        self.Trainloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
        print(f"# of domains : {self.num_domains}")

        # Build initial memory
        print("Memory initialization...")
        gen = torch.Generator()
        self.enc.eval()

        initial_memory_lst = []
        # Iterate over the labels actually present so that non-contiguous or
        # non-zero-based label sets do not yield empty domains.
        for label in domain_labels.tolist():
            # CPU indices can index tensors on any device.
            domain_idx = torch.nonzero(train_domain == label, as_tuple=False).flatten().cpu()
            n_total = domain_idx.numel()
            min_sample = max(1, ceil(n_total * ratio))

            perm_idx = torch.randperm(n_total, generator=gen)[:min_sample]
            sub_idx = domain_idx[perm_idx]

            # Move only the selected subset to the model's device.
            sub_data = train_data[sub_idx].to(self.device)
            sub_mask = train_mask[sub_idx].to(self.device)

            output = self.enc(x_enc=sub_data, input_mask=sub_mask)
            output = output.squeeze(1)              # drop the channel dim when it is 1
            output = F.normalize(output, dim=-1)
            domain_mean = output.mean(dim=0)        # average over sampled series
            initial_memory_lst.append(domain_mean)

        self.initial_memory = torch.stack(initial_memory_lst, dim=0).contiguous().to(self.device)
        print(f"Initial memory: {tuple(self.initial_memory.shape)}")

    def fit(self,
            train_data: torch.Tensor,
            train_mask: torch.Tensor,
            train_domain: torch.Tensor,
            epochs: int = 2,
            lr: float = 1e-4,
        ) -> None:
        """Train the encoder and memory module with a masked reconstruction loss.

        If :meth:`init` has not been run yet it is invoked with the given
        tensors.  NOTE: when the model is already initialised, the existing
        ``Trainloader`` and ``initial_memory`` are reused and the arguments
        passed here are not read.

        The memory returned by the memory module is fed back in on the next
        batch; its final value is stored (detached) in ``self.updated_M``.

        Args:
            train_data: Training series (see :meth:`init`).
            train_mask: Input mask aligned with ``train_data``.
            train_domain: Domain label of each sample.
            epochs: Number of passes over the training loader.
            lr: Adam learning rate, also used as ``max_lr`` for ``OneCycleLR``.
        """
        if self.initial_memory is None or self.Trainloader is None:
            self.init(train_data, train_mask, train_domain)

        print("Train...")
        params = list(self.enc.parameters()) + list(self.MemoryModule.parameters())
        optimizer = torch.optim.Adam(params, lr=lr)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=lr, total_steps=epochs * len(self.Trainloader)
        )

        M = self.initial_memory.contiguous().clone().to(self.device)
        self.enc.train()
        for epoch in range(epochs):
            epoch_loss, num_batches = 0.0, 0

            for batch_x, batch_mask, _ in tqdm(self.Trainloader):
                # The loader yields tensors on the dataset's device; the
                # models live on self.device.
                batch_x = batch_x.to(self.device)
                batch_mask = batch_mask.to(self.device)

                Q = self.enc(x_enc=batch_x, input_mask=batch_mask).squeeze(1)
                patch_masks = Masking.convert_seq_to_patch_view(batch_mask)
                outputs, M = self.MemoryModule(Q, M, patch_masks)

                loss = self.calculate_loss(outputs, batch_x.squeeze(1), batch_mask)
                optimizer.zero_grad(set_to_none=True)
                loss.backward()

                optimizer.step()
                scheduler.step()  # OneCycleLR is stepped once per batch
                epoch_loss += loss.item()
                num_batches += 1

            print(f"Epoch {epoch+1}, Loss = {epoch_loss / max(1, num_batches):.4f}")

        # Inference runs under no_grad, so drop any autograd history here
        # instead of keeping the last training graph alive.
        self.updated_M = M.detach()

    @torch.no_grad()
    def pred(self,
             test_data: torch.Tensor,
             test_mask: torch.Tensor,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the trained model on test data.

        Args:
            test_data: Test series, same layout as the training data.
            test_mask: Input mask aligned with ``test_data``.

        Returns:
            The ``(recon, mse)`` pair produced by the memory module's
            ``pred`` method.

        Raises:
            AssertionError: If :meth:`fit` has not been called.
        """
        assert self.updated_M is not None, "Call fit(...) before pred(...)."

        self.enc.eval()
        test_data = test_data.to(self.device)
        test_mask = test_mask.to(self.device)

        test_patch_mask = Masking.convert_seq_to_patch_view(test_mask)
        test_Q = self.enc(x_enc=test_data, input_mask=test_mask).squeeze(1)

        recon, mse = self.MemoryModule.pred(
            Q=test_Q,
            M=self.updated_M,
            test_data=test_data,
            test_mask=test_mask,
            patch_mask=test_patch_mask,
        )
        return recon, mse

    def new_embed(
        self,
        *,
        x_enc: torch.Tensor,
        input_mask: Optional[torch.Tensor] = None,
        reduction: str = "mean",
        **kwargs,
    ) -> torch.Tensor:
        """Patch-level replacement for the MOMENT pipeline's ``embed``.

        NOTE: this function is bound to the MOMENT pipeline (``self.enc`` of
        the wrapper), so ``self`` here is the pipeline: ``normalizer``,
        ``tokenizer``, ``patch_embedding``, ``encoder`` and ``config`` are its
        attributes, not attributes of ``MOMEMTO``.

        Args:
            x_enc: Input of shape ``(batch_size, n_channels, seq_len)``.
            input_mask: Optional ``(batch_size, seq_len)`` mask; defaults to
                all ones.
            reduction: Accepted for signature compatibility; not used.
            **kwargs: Accepted for signature compatibility; not used.

        Returns:
            Encoder hidden states reshaped to
            ``(batch_size, n_channels, n_patches, d_model)``.
        """
        batch_size, n_channels, seq_len = x_enc.shape

        if input_mask is None:
            input_mask = torch.ones((batch_size, seq_len), device=x_enc.device)

        x_enc = self.normalizer(x=x_enc, mask=input_mask, mode="norm")
        # Replace NaN/inf values left by normalisation with zeros.
        x_enc = torch.nan_to_num(x_enc, nan=0.0, posinf=0.0, neginf=0.0)

        x_enc = self.tokenizer(x=x_enc)
        enc_in = self.patch_embedding(x_enc, mask=input_mask)

        n_patches = enc_in.shape[2]
        # Fold channels into the batch dimension for the encoder.
        enc_in = enc_in.reshape((batch_size * n_channels, n_patches, self.config.d_model))

        patch_view_mask = Masking.convert_seq_to_patch_view(input_mask)
        # One attention-mask row per (sample, channel) pair, matching enc_in.
        attention_mask = patch_view_mask.repeat_interleave(n_channels, dim=0)

        outputs = self.encoder(inputs_embeds=enc_in, attention_mask=attention_mask)
        enc_out = outputs.last_hidden_state
        # Unfold channels back out of the batch dimension.
        enc_out = enc_out.reshape((-1, n_channels, n_patches, self.config.d_model))

        return enc_out

    def calculate_loss(self, outputs, targets, mask):
        """Return the masked mean squared error.

        The squared error is multiplied by ``mask``, summed over dim 1 and
        divided by the mask's sum over dim 1 (plus ``1e-8`` to avoid division
        by zero); the result is averaged over the remaining dimensions.
        """
        loss = (outputs - targets) ** 2
        masked = loss * mask
        tmp = masked.sum(dim=1) / (mask.sum(dim=1) + 1e-8)
        return tmp.mean()

