import os
import torch
import torch.nn as nn
from pathlib import Path
from datetime import datetime

try:
    from ..omp import omp_v0
except ImportError:
    # Fallback for direct execution
    import sys
    from pathlib import Path
    sys.path.append(str(Path(__file__).parent.parent))
    from omp import omp_v0

class Autoencoder(nn.Module):
    """Sparse dictionary autoencoder: OMP encoder, linear dictionary decoder.

    The only learnable parameter is ``D`` with shape ``(L, m, n)``: ``L``
    dictionaries of ``n`` atoms, each atom being an ``m``-dimensional column
    that is kept at unit L2 norm.

    Config keys read by this class: ``num_hidden_layers`` (optional, default
    1), ``feature_dim`` or ``head_dim``, ``dictionary_size``, ``sparsity`` and
    ``device``. ``cfg`` may be a plain dict or a ConfigManager-like object
    that supports ``get`` / ``in`` / ``[]`` and optionally ``get_path`` and
    ``get_model_name``.
    """

    def __init__(self, cfg):
        """Build the dictionary parameter and move the module to ``cfg["device"]``.

        Raises:
            KeyError: if neither ``feature_dim`` nor ``head_dim`` is configured,
                or if ``dictionary_size`` / ``sparsity`` / ``device`` is missing.
        """
        super().__init__()
        self.cfg = cfg
        # Two dictionaries per hidden layer.
        self.L = cfg.get("num_hidden_layers", 1) * 2

        # ``feature_dim`` takes precedence; ``head_dim`` is the fallback.
        if "feature_dim" in cfg:
            self.m = cfg["feature_dim"]
        elif "head_dim" in cfg:
            self.m = cfg["head_dim"]
        else:
            raise KeyError("Either 'feature_dim' or 'head_dim' must be specified in config")

        self.n = cfg["dictionary_size"]
        self.s = cfg["sparsity"]

        # Initialise on a plain tensor so the normalisation is never tracked
        # by autograd, then wrap the unit-norm atoms as the parameter.
        atoms = torch.empty(self.L, self.m, self.n)
        nn.init.kaiming_uniform_(atoms)
        atoms /= atoms.norm(dim=-2, keepdim=True)
        self.D = nn.Parameter(atoms)
        self.to(cfg["device"])

    @torch.no_grad()
    def encode(self, k):
        """Return dense sparse codes for ``k`` computed with ``omp_v0``.

        Args:
            k: input tensor; the code treats dim 0 as batch and dim 1 as the
                dictionary index (expected shape ``(batch, L, m)``).

        Returns:
            Tensor of shape ``(batch, L, n)`` that is zero everywhere except
            at the atom indices selected by ``omp_v0``.
        """
        gram = torch.bmm(self.D.permute(0, 2, 1), self.D)
        # Only the first two of the six outputs of omp_v0 are needed here.
        indices, values, *_ = omp_v0(self.D, k.transpose(0, 1), gram, self.s)

        # Allocate the buffer to match what is scattered into it, so the
        # result stays valid even if the module was moved after construction.
        codes = torch.zeros(
            (self.L, k.size(0), self.n),
            dtype=values.dtype,
            device=values.device,
        )
        codes.scatter_(-1, indices.to(torch.int64), values.squeeze(-1))
        return codes.transpose(0, 1)

    def decode(self, y):
        """Reconstruct inputs from codes ``y`` of shape ``(batch, L, n)``.

        Returns:
            Tensor of shape ``(batch, L, m)``.
        """
        return torch.einsum('lmn,bln->blm', self.D, y)

    def forward(self, k):
        """Encode, decode and compute the reconstruction loss.

        Returns:
            ``((recon_loss, decorrelation_loss), k_hat, y)`` where
            ``decorrelation_loss`` is always ``0``; the tuple form mirrors
            ``AutoencoderCustom.forward``.
        """
        y = self.encode(k)
        k_hat = self.decode(y)
        recon_loss = torch.mean((k_hat - k) ** 2)
        decorrelation_loss = 0

        loss = (recon_loss, decorrelation_loss)
        return loss, k_hat, y

    @torch.no_grad()
    def normalise_decoder_weights(self):
        """Renormalise every atom to unit norm and project its gradient.

        The gradient component parallel to each atom is removed so that a
        subsequent optimiser step does not change the atom norms to first
        order. When no gradient has been computed yet, only the
        renormalisation is performed.
        """
        unit_atoms = self.D / self.D.norm(dim=-2, keepdim=True)
        if self.D.grad is not None:
            parallel = (self.D.grad * unit_atoms).sum(-2, keepdim=True) * unit_atoms
            self.D.grad -= parallel
        self.D.data = unit_atoms

    def _model_name(self):
        """Return the model name from a ConfigManager or a plain dict config."""
        if hasattr(self.cfg, 'get_model_name'):
            return self.cfg.get_model_name()
        return self.cfg.get('name', 'model')

    def _output_dir(self, explicit, path_kind, cfg_key, default_name):
        """Resolve an output directory, create it, and return it as a Path."""
        if explicit:
            target = Path(explicit)
        elif hasattr(self.cfg, 'get_path'):
            # Path() is a no-op on Path objects and fixes plain-string returns.
            target = Path(self.cfg.get_path(path_kind))
        else:
            target = Path(self.cfg.get(cfg_key, default_name))
        target.mkdir(parents=True, exist_ok=True)
        return target

    def save(self, save_dir=None):
        """Write the module ``state_dict`` to ``<save_dir>/<model_name>.pt``.

        Args:
            save_dir: target directory; when falsy, the directory comes from
                ``cfg.get_path('checkpoint')`` or the ``checkpoint_dir``
                config key (default ``'checkpoints'``).
        """
        target = self._output_dir(save_dir, 'checkpoint', 'checkpoint_dir', 'checkpoints')
        torch.save(self.state_dict(), target / f"{self._model_name()}.pt")

    def save_dictionary(self, epoch=None, dict_save_dir=None):
        """Write the dictionary tensor ``D`` (detached, on CPU) to disk.

        Args:
            epoch: when not ``None``, the file is named
                ``<model_name>_<epoch>epoch.pt`` instead of ``<model_name>.pt``.
            dict_save_dir: target directory; when falsy, the directory comes
                from ``cfg.get_path('dictionary')`` or the ``dictionary_dir``
                config key (default ``'dictionaries_s<sparsity>'``).
        """
        default_name = f"dictionaries_s{self.cfg.get('sparsity', 256)}"
        target = self._output_dir(dict_save_dir, 'dictionary', 'dictionary_dir', default_name)

        model_name = self._model_name()
        if epoch is not None:
            D_path = target / f"{model_name}_{epoch}epoch.pt"
        else:
            D_path = target / f"{model_name}.pt"

        torch.save(self.D.detach().cpu(), D_path)

    @classmethod
    def load(cls, model_path, cfg):
        """Construct a model from ``cfg`` and load weights from ``model_path``.

        The checkpoint is mapped onto ``cfg["device"]`` so that weights saved
        on one device type can be restored on another.
        """
        model = cls(cfg=cfg)
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state = torch.load(model_path, map_location=cfg["device"])
        model.load_state_dict(state)
        return model


class AutoencoderCustom(Autoencoder):
    """Autoencoder variant with exactly two dictionaries and random sparsity.

    Differences from the parent class visible here: ``L`` is fixed at 2,
    ``encode`` draws a fresh sparsity level in ``[1, cfg["sparsity"]]`` on
    every call and also returns the last selected atom index, and ``forward``
    can weight the reconstruction error by the inverse per-vector standard
    deviation when ``cfg["use_norm"]`` is truthy.
    """

    def __init__(self, cfg):
        """Initialise via the parent, then rebuild ``D`` with ``L == 2``."""
        super().__init__(cfg)
        # This variant ignores ``num_hidden_layers`` and always uses two
        # dictionaries, so the parent's parameter is replaced.
        self.L = 2

        atoms = torch.empty(self.L, self.m, self.n)
        nn.init.kaiming_uniform_(atoms)
        atoms /= atoms.norm(dim=-2, keepdim=True)
        self.D = nn.Parameter(atoms)
        self.to(cfg["device"])

    def sample_sparsity_power_law(self, min_s, max_s, power=3.0):
        """Sample an integer sparsity skewed toward ``min_s``.

        A uniform draw from ``[0, 1)`` is raised to ``power``; for
        ``power > 1`` this pushes the mass toward 0, and larger powers skew
        the result more strongly toward ``min_s``.

        Returns:
            An ``int`` in the closed range ``[min_s, max_s]``.
        """
        # Uniform sample in [0, 1), reshaped by the power transform.
        sample_01 = torch.pow(torch.rand(1), power)

        # Scale to [min_s, max_s] and round to an integer.
        sparsity = torch.round(min_s + (max_s - min_s) * sample_01).int().item()

        # Guarantee the result stays within range.
        return max(min_s, min(sparsity, max_s))

    @torch.no_grad()
    def encode(self, k):
        """Encode ``k`` with a randomly sampled sparsity level.

        Args:
            k: input tensor; the code treats dim 0 as batch and dim 1 as the
                dictionary index (expected shape ``(batch, 2, m)``).

        Returns:
            ``(codes, last_atom_indices)``: ``codes`` has shape
            ``(batch, L, n)``; ``last_atom_indices`` is the final entry of
            the index tensor returned by ``omp_v0``, transposed so that batch
            is the leading dimension.
        """
        sparsity = self.sample_sparsity_power_law(1, self.s, 2)

        gram = torch.bmm(self.D.permute(0, 2, 1), self.D)
        # Only the first two of the six outputs of omp_v0 are needed here.
        indices, values, *_ = omp_v0(self.D, k.transpose(0, 1), gram, sparsity)
        last_atom_indices = indices[:, :, -1]  # shape: [L, batch_size]

        # Allocate the buffer to match what is scattered into it, so the
        # result stays valid even if the module was moved after construction.
        codes = torch.zeros(
            (self.L, k.size(0), self.n),
            dtype=values.dtype,
            device=values.device,
        )
        codes.scatter_(-1, indices.to(torch.int64), values.squeeze(-1))

        return codes.transpose(0, 1), last_atom_indices.transpose(0, 1)

    def forward(self, k):
        """Encode, decode and compute the (optionally weighted) loss.

        Returns:
            ``((recon_loss, decorrelation_loss), k_hat, y)`` where ``y`` is
            the code tensor and ``decorrelation_loss`` is always ``0``.
        """
        # encode() returns (codes, last_atom_indices); decode() needs only
        # the code tensor, so the indices are discarded here.
        y, _ = self.encode(k)
        k_hat = self.decode(y)

        if self.cfg["use_norm"]:
            # NOTE(review): a vector with zero standard deviation produces an
            # infinite weight here - confirm inputs never hit that case.
            feat_weights = 1 / torch.std(k, dim=-1)
            recon_loss = torch.mean(feat_weights.unsqueeze(-1) * (k_hat - k) ** 2)
        else:
            recon_loss = torch.mean((k_hat - k) ** 2)

        decorrelation_loss = 0

        loss = (recon_loss, decorrelation_loss)
        return loss, k_hat, y

    def save_dictionary(self, epoch=None, dict_save_dir=None):
        """Write the dictionary tensor ``D`` (detached, on CPU) and report the path.

        Args:
            epoch: when not ``None``, the file is named
                ``<model_name>_<epoch>epoch.pt`` instead of ``<model_name>.pt``.
            dict_save_dir: target directory; when falsy, the directory comes
                from ``cfg.get_path('dictionary')`` or the ``dictionary_dir``
                config key (default ``'dictionaries_s<self.s>'``).
        """
        if dict_save_dir:
            dict_save_dir = Path(dict_save_dir)
        elif hasattr(self.cfg, 'get_path'):
            # Path() is a no-op on Path objects and fixes plain-string returns.
            dict_save_dir = Path(self.cfg.get_path('dictionary'))
        else:
            dict_save_dir = Path(self.cfg.get('dictionary_dir', f'dictionaries_s{self.s}'))

        dict_save_dir.mkdir(parents=True, exist_ok=True)

        # Model name comes from a ConfigManager when available, else the dict.
        if hasattr(self.cfg, 'get_model_name'):
            model_name = self.cfg.get_model_name()
        else:
            model_name = self.cfg.get('name', 'model')

        if epoch is not None:
            D_path = dict_save_dir / f"{model_name}_{epoch}epoch.pt"
        else:
            D_path = dict_save_dir / f"{model_name}.pt"

        torch.save(self.D.detach().cpu(), D_path)
        print(f"Dictionary saved to {D_path}")
