import os
import torch
import hydra
from omegaconf import DictConfig
from lightning import LightningModule
from infonet.decoder import Decoder 
from infonet.encoder import Encoder 
from infonet.encoder2 import Encoder2 
from infonet.infonet_new_encoder import InfoNet
from infonet.query import Query_Gen_transformer
from gen_data.train_data import gauss_noise_padding, softrank_preprocessing, softrank_preprocessing_new

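# LightningWrapper: rebuilds the InfoNet module layout from a config so trained checkpoints can be loaded into it
# (kept consistent with the earlier version of this wrapper and adapted to the reference implementation).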
class LightningWrapper(LightningModule):
    """Lightning container that assembles an InfoNet model from a config.

    The encoder, decoder, query generator and second encoder are built here,
    stored as attributes of this module, and then handed to ``InfoNet``.
    The attribute names used below therefore define the parameter names
    expected when a checkpoint is loaded into this class.

    NOTE(review): if ``InfoNet`` also stores the sub-modules it receives as
    attributes, each of their parameters appears twice in the state dict
    (e.g. ``encoder.*`` and ``model.encoder.*``) — verify against InfoNet.
    """

    def __init__(self, cfg: DictConfig):
        """Build all sub-modules from ``cfg``.

        Args:
            cfg: Config with attribute access. The fields read here are
                ``input_dim_x``, ``input_dim_y``, ``latent_num``,
                ``latent_dim``, ``decoder_query_dim``, ``encoder2_hiddim``,
                ``encoder2_expand_dim``, ``encoder2_block``,
                ``targetnet_hiddim`` and ``hypermlp_hiddim``.
        """
        super().__init__()
        self.cfg = cfg
        # Attention head counts and self-attention block layout are
        # hard-coded here rather than read from cfg.
        self.encoder = Encoder(
            input_dim_x=cfg.input_dim_x,
            input_dim_y=cfg.input_dim_y,
            latent_num=cfg.latent_num,
            latent_dim=cfg.latent_dim,
            cross_attn_heads=8,
            self_attn_heads=16,
            num_self_attn_per_block=8,  # Adjusted to history
            num_self_attn_blocks=2      # Adjusted to history
        )
        self.decoder = Decoder(
            q_dim=cfg.decoder_query_dim,
            latent_dim=cfg.latent_dim,
        )
        self.query_gen = Query_Gen_transformer(
            input_dim_x=cfg.input_dim_x,
            input_dim_y=cfg.input_dim_y,
            dim=cfg.decoder_query_dim
        )
        # encoder2 consumes the concatenated x/y features.
        self.encoder2 = Encoder2(
            input_dim=cfg.input_dim_x + cfg.input_dim_y,
            hidden_dim=cfg.encoder2_hiddim,
            output_dim=cfg.encoder2_expand_dim,
            num_layers=cfg.encoder2_block
        )
        # The full model combines the sub-modules built above.
        self.model = InfoNet(
            encoder=self.encoder,
            decoder=self.decoder,
            encoder2=self.encoder2,
            query_gen=self.query_gen,
            decoder_query_dim=cfg.decoder_query_dim,
            input_dim_x=cfg.input_dim_x,
            input_dim_y=cfg.input_dim_y,
            targetnet_hiddim=cfg.targetnet_hiddim,
            hypermlp_hiddim=cfg.hypermlp_hiddim,
            encoder2_expand_dim=cfg.encoder2_expand_dim,
        )
        self.encoder2_expand_dim = cfg.encoder2_expand_dim

    def forward(self, x: torch.Tensor):
        """Run the wrapped InfoNet on ``x`` with ``early_sup=False``.

        ``x.squeeze(1)`` removes dimension 1 only when it has size 1.
        NOTE(review): ``estimate_mi_xy`` in this file passes a
        ``[B, N, 2*max_dim]`` tensor, so the squeeze is a no-op unless
        N == 1 — confirm which input layout InfoNet expects.
        """
        return self.model(x.squeeze(1), early_sup=False)

def load_model_from_checkpoint(checkpoint_path: str, config_path: str, config_name: str, device: str = 'cpu'):
    """Build a ``LightningWrapper`` from a Hydra config and load checkpoint weights.

    Args:
        checkpoint_path: Path to the Lightning checkpoint file.
        config_path: Directory containing ``<config_name>.yaml``. A relative
            path is resolved against the current working directory, which is
            the same base used by the existence check.
        config_name: Config file name without the ``.yaml`` suffix.
        device: Passed as ``map_location`` when loading and used for ``model.to``.

    Returns:
        The loaded ``LightningWrapper``, in eval mode, on ``device``.

    Raises:
        FileNotFoundError: If the checkpoint or the config file does not exist.

    Side effects:
        Assigns several module-level globals (``batchsize``, ``latent_dim``,
        ``seq_len``, ``learning_rate``, ...) from the loaded config.
    """
    # Check checkpoint file
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file {checkpoint_path} does not exist.")

    # Check config path
    config_full_path = os.path.join(config_path, f"{config_name}.yaml")
    if not os.path.exists(config_full_path):
        raise FileNotFoundError(f"Config file {config_full_path} does not exist.")

    # hydra.initialize() resolves config_path relative to the *calling module*
    # and rejects absolute paths, so it could disagree with the cwd-based
    # existence check above. initialize_config_dir() takes an absolute
    # directory, which keeps loading consistent with that check.
    config_dir = os.path.abspath(config_path)
    with hydra.initialize_config_dir(config_dir=config_dir, version_base='1.1'):
        cfg = hydra.compose(config_name=config_name)

    # Publish config values as module globals (kept for compatibility with
    # code that reads them at module level).
    global batchsize, latent_dim, latent_num, input_dim_x, input_dim_y, seq_len
    global decoder_query_dim, num_per_epoch, gpu_card, learning_rate, softrank_reg
    global lr_decay_step, targetnet_hiddim, hypermlp_hiddim, encoder2_expand_dim
    global encoder2_hiddim, encoder2_block

    batchsize = cfg.batchsize
    latent_dim = cfg.latent_dim
    latent_num = cfg.latent_num
    input_dim_x = cfg.input_dim_x
    input_dim_y = cfg.input_dim_y
    seq_len = cfg.seq_len
    decoder_query_dim = cfg.decoder_query_dim
    num_per_epoch = cfg.get('num_per_epoch', 1000)  # Default if not in cfg
    gpu_card = cfg.get('gpu_card', 0)               # Default if not in cfg
    learning_rate = cfg.get('learning_rate', 1e-4)  # Default if not in cfg
    softrank_reg = cfg.softrank_reg
    lr_decay_step = cfg.get('lr_decay_step', 1000)  # Default if not in cfg
    targetnet_hiddim = cfg.targetnet_hiddim
    hypermlp_hiddim = cfg.hypermlp_hiddim
    encoder2_expand_dim = cfg.encoder2_expand_dim
    encoder2_hiddim = cfg.encoder2_hiddim
    encoder2_block = cfg.encoder2_block

    # cfg is passed explicitly because LightningWrapper.__init__ requires it.
    model = LightningWrapper.load_from_checkpoint(
        checkpoint_path,
        map_location=device,
        cfg=cfg,
        weights_only=True
    )

    # Set eval mode and move to device
    model.eval()
    model.to(device)

    return model

def estimate_mi_xy(X: torch.Tensor, Y: torch.Tensor, model: LightningModule, max_dim: int, softrank_reg: float=1e-3):
    """Estimate mutual information between paired samples ``X`` and ``Y`` with ``model``.

    Both inputs are padded to ``max_dim`` features with ``gauss_noise_padding``.
    They are then concatenated along the feature axis and soft-rank
    preprocessed before being passed to the model.

    Args:
        X: Tensor of shape ``[B, N, d]``; cast to float32.
        Y: Tensor with exactly the same shape as ``X``; cast to float32.
        model: Module applied to the preprocessed ``[B, N, 2*max_dim]`` tensor.
            Inference runs on the device of its parameters.
        max_dim: Feature width each of X and Y is padded to; ``d`` must not
            exceed it.
        softrank_reg: ``regularization_strength`` for the soft-rank step.

    Returns:
        A CPU tensor when the model output is at least 1-D and its leading
        dimension equals ``B``; otherwise a Python float (the output must
        then contain a single element).

    Raises:
        ValueError: If the shapes differ, the input is not 3-D, or ``d > max_dim``.
    """
    # Check input shapes
    if X.shape != Y.shape:
        raise ValueError(f"X and Y must have the same shape, got X: {X.shape}, Y: {Y.shape}")
    if len(X.shape) != 3:
        raise ValueError(f"X and Y must have shape [B, N, d], got {X.shape}")
    B, N, d = X.shape
    if d > max_dim:
        raise ValueError(f"Input dimension d={d} cannot be greater than max_dim={max_dim}")

    # Run on the model's own device so CPU-loaded models work as well as GPU ones.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Ensure X and Y are float
    X = X.float()
    Y = Y.float()

    # Pad each side to max_dim features: [B, N, max_dim]
    X_padded = gauss_noise_padding(X, aim_dim=max_dim, perm=False)
    Y_padded = gauss_noise_padding(Y, aim_dim=max_dim, perm=False)

    # Concatenate X and Y: [B, N, 2*max_dim]
    sample_xy = torch.cat([X_padded, Y_padded], dim=-1)

    # Soft rank preprocessing, then move to the model's device
    sample_xy = softrank_preprocessing_new(sample_xy, regularization_strength=softrank_reg).to(device)

    with torch.no_grad():
        mi_est = model(sample_xy).cpu()
        # A 0-d output has no shape[0]; treat it as a scalar instead of
        # raising IndexError.
        if mi_est.dim() > 0 and mi_est.shape[0] == B:
            return mi_est  # one estimate per batch element
        return mi_est.item()


def compute_ksmi_mean(X: torch.Tensor, Y: torch.Tensor, projection_dim: int, model: LightningModule, proj_num: int, batchsize: int, max_dim: int = 8, softrank_reg: float = 1e-3, normalize_input: bool = False):
    """Average MI estimates of ``X`` and ``Y`` over random linear projections.

    For each of ``proj_num`` projections, independent Gaussian matrices are
    drawn for X and Y. Their columns are normalised to unit L2 norm, both
    inputs are projected to ``projection_dim`` features, and MI is estimated
    with ``estimate_mi_xy``. Projections are evaluated ``batchsize`` at a time.

    Args:
        X: Tensor of shape ``[N, dx]``.
        Y: Tensor of shape ``[N, dy]`` (same N as X).
        projection_dim: Output width of each projection.
        model: Model passed to ``estimate_mi_xy``; switched to eval mode.
        proj_num: Total number of random projections (>= 1).
        batchsize: Number of projections per model call (>= 1).
        max_dim: Padding width forwarded to ``estimate_mi_xy``.
        softrank_reg: Soft-rank regularisation forwarded to ``estimate_mi_xy``.
        normalize_input: If True, standardise each column of X and Y first.

    Returns:
        0-d CPU tensor holding the mean of all collected estimates.

    Raises:
        ValueError: If X and Y have different sample counts, or if
            ``proj_num`` or ``batchsize`` is less than 1.
    """
    model.eval()
    seq_len, dx = X.shape
    _, dy = Y.shape

    if seq_len != Y.shape[0]:
        raise ValueError(f"X and Y must have the same number of samples, got X: {X.shape}, Y: {Y.shape}")
    if proj_num < 1:
        raise ValueError(f"proj_num must be >= 1, got {proj_num}")
    if batchsize < 1:
        raise ValueError(f"batchsize must be >= 1, got {batchsize}")

    # Standardise each feature to zero mean and unit variance.
    if normalize_input:
        X = (X - X.mean(dim=0, keepdim=True)) / (X.std(dim=0, keepdim=True) + 1e-6)  # epsilon avoids division by zero
        Y = (Y - Y.mean(dim=0, keepdim=True)) / (Y.std(dim=0, keepdim=True) + 1e-6)

    results = []

    # Process projections in batches.
    for start in range(0, proj_num, batchsize):
        current_batch_size = min(batchsize, proj_num - start)  # the last batch may be smaller

        X_proj_batch = []
        Y_proj_batch = []
        for _ in range(current_batch_size):
            # Independent random projection matrices, created on each input's
            # own device so the matmul below also works for GPU tensors.
            proj_matrix_x = torch.randn(dx, projection_dim, device=X.device)  # [dx, projection_dim]
            proj_matrix_y = torch.randn(dy, projection_dim, device=Y.device)  # [dy, projection_dim]

            # Normalise every column to unit L2 norm.
            proj_matrix_x = proj_matrix_x / torch.norm(proj_matrix_x, dim=0, keepdim=True)
            proj_matrix_y = proj_matrix_y / torch.norm(proj_matrix_y, dim=0, keepdim=True)

            # [N, d] @ [d, projection_dim] -> [1, N, projection_dim]
            X_proj_batch.append((X @ proj_matrix_x).unsqueeze(0))
            Y_proj_batch.append((Y @ proj_matrix_y).unsqueeze(0))

        # Stack into [current_batch_size, N, projection_dim]
        mi_est = estimate_mi_xy(
            torch.cat(X_proj_batch, dim=0),
            torch.cat(Y_proj_batch, dim=0),
            model,
            max_dim,
            softrank_reg,
        )

        # Keep every partial result on the CPU and flat so that tensor and
        # scalar results can always be concatenated together.
        if torch.is_tensor(mi_est):
            results.append(mi_est.detach().cpu().reshape(-1))
        else:
            results.append(torch.tensor([mi_est]))

    # Concatenate all results and average.
    return torch.mean(torch.cat(results)).cpu()



# Example usage
if __name__ == "__main__":
    # Micro-benchmark: time softrank_preprocessing on a [1, 5000, 10]
    # Gaussian tensor, repeated 10 times, printing seconds per run.
    import time

    for _ in range(10):
        X = torch.randn(1, 5000, 10)
        # perf_counter is monotonic and high-resolution; time.time() is wall
        # clock time and can jump, which makes it unsuitable for timing.
        start_time = time.perf_counter()
        X = softrank_preprocessing(X, 0.001)
        using_time = time.perf_counter() - start_time
        print(using_time)