import glob
import logging
import math
import os
import pickle
import time

import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.autograd as autograd
from lightning.pytorch.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.types import EVAL_DATALOADERS

#from gen_data.train_data_fixed import gen_train_dataset
from gen_data.train_data import gauss_noise_padding, softrank_preprocessing_new, gen_train_dataset_lowdim
from infonet.decoder import Decoder 
from infonet.encoder import Encoder 
from infonet.infonet_new_encoder import InfoNet
from infonet.query import Query_Gen_transformer
from scipy.stats import rankdata
from sklearn.mixture import GaussianMixture
from tensorboardX import SummaryWriter
from torch import nn, optim
from torch.optim import Adam
from functorch import vmap
import lightning
from lightning.pytorch.loggers import TensorBoardLogger
from send2trash import send2trash
import hydra
from omegaconf import DictConfig
from hydra.core.hydra_config import HydraConfig
from torch.optim.lr_scheduler import StepLR
from torch.optim.lr_scheduler import CosineAnnealingLR
import multiprocessing as mp
import random
from sklearn.metrics import roc_curve, auc
#from evaluation import evaluate_gauss_order
# Silence every warning category for the whole process.
warnings.filterwarnings("ignore")
# Use the non-interactive 'agg' matplotlib backend; figures in this file are only saved to disk.
plt.switch_backend('agg')
# Default torch device used by the evaluation helpers: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def estimate_mi_xy(X, Y, model, max_dim, softrank_reg: float=1e-3):
    """Estimate mutual information for a batch of paired samples.

    ``X`` and ``Y`` are cast to float32, padded along the feature axis with
    ``gauss_noise_padding`` (``aim_dim=max_dim``, no permutation), concatenated
    on the last axis, passed through ``softrank_preprocessing_new`` and then fed
    to ``model`` under ``torch.no_grad()``.

    Args:
        X: Tensor of shape ``[B, N, d]``.
        Y: Tensor with exactly the same shape as ``X``.
        model: Callable applied to the preprocessed tensor. When it exposes
            ``parameters()`` the input is moved to the device of its first
            parameter; otherwise CUDA is used when available, else CPU.
        max_dim: Feature dimension each input is padded to; must be ``>= d``.
        softrank_reg: Forwarded to ``softrank_preprocessing_new`` as
            ``regularization_strength``.

    Returns:
        A CPU tensor when the model output's leading dimension equals ``B``;
        otherwise the output converted to a Python scalar with ``.item()``.

    Raises:
        ValueError: If the shapes differ, the inputs are not 3-D, or
            ``d > max_dim``.
    """
    # Validate input shapes before doing any work.
    if X.shape != Y.shape:
        raise ValueError(f"X and Y must have the same shape, got X: {X.shape}, Y: {Y.shape}")
    if len(X.shape) != 3:
        raise ValueError(f"X and Y must have shape [B, N, d], got {X.shape}")
    B, N, d = X.shape
    if d > max_dim:
        raise ValueError(f"Input dimension d={d} cannot be greater than max_dim={max_dim}")

    # Follow the model's own device instead of assuming CUDA, so the function
    # also works on CPU-only hosts and with models kept on CPU.
    try:
        target_device = next(model.parameters()).device
    except (AttributeError, StopIteration, TypeError):
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    X = X.float()
    Y = Y.float()

    # Pad both inputs to max_dim features, then join them on the feature axis.
    X_padded = gauss_noise_padding(X, aim_dim=max_dim, perm=False)
    Y_padded = gauss_noise_padding(Y, aim_dim=max_dim, perm=False)
    sample_xy = torch.cat([X_padded, Y_padded], dim=-1)

    # Soft-rank preprocessing, then move to the inference device.
    sample_xy = softrank_preprocessing_new(sample_xy, regularization_strength=softrank_reg).to(target_device)

    with torch.no_grad():
        mi_est = model(sample_xy)

    mi_est = mi_est.cpu()
    # A 0-dim output has no shape[0]; treat it as the scalar case.
    if mi_est.dim() > 0 and mi_est.shape[0] == B:
        return mi_est
    return mi_est.item()

def compute_ksmi_mean(X, Y, projection_dim, model, proj_num: int, batchsize: int, max_dim: int = 5, softrank_reg: float = 1e-3, normalize_input: bool = True):
    """Average the model's MI estimate over random linear projections of X and Y.

    For each of ``proj_num`` projections an independent Gaussian matrix with
    unit-norm columns is drawn for ``X`` and for ``Y``; the projected pairs are
    grouped into batches of at most ``batchsize`` and scored with
    ``estimate_mi_xy``. The mean of all estimates is returned.

    Side effect: ``model.eval()`` is called and the model is left in eval mode.

    Args:
        X: Tensor of shape ``[N, dx]``.
        Y: Tensor of shape ``[N, dy]`` with the same ``N`` as ``X``.
        projection_dim: Number of columns of every projection matrix.
        model: Estimator passed through to ``estimate_mi_xy``.
        proj_num: Total number of random projections; must be ``>= 1``.
        batchsize: Maximum number of projections scored per model call;
            must be ``>= 1``.
        max_dim: Forwarded to ``estimate_mi_xy``.
        softrank_reg: Forwarded to ``estimate_mi_xy``.
        normalize_input: When true, standardise every column of ``X`` and ``Y``
            to zero mean and unit standard deviation first.

    Returns:
        A 0-dim CPU tensor holding the mean estimate.

    Raises:
        ValueError: If the sample counts differ, or ``proj_num`` / ``batchsize``
            is smaller than 1.
    """
    model.eval()
    seq_len, dx = X.shape
    _, dy = Y.shape

    if seq_len != Y.shape[0]:
        raise ValueError(f"X and Y must have the same number of samples, got X: {X.shape}, Y: {Y.shape}")
    # Without at least one projection there is nothing to average.
    if proj_num < 1:
        raise ValueError(f"proj_num must be a positive integer, got {proj_num}")
    if batchsize < 1:
        raise ValueError(f"batchsize must be a positive integer, got {batchsize}")

    # Standardise inputs; the small epsilon avoids division by zero for constant columns.
    if normalize_input:
        X = (X - X.mean(dim=0, keepdim=True)) / (X.std(dim=0, keepdim=True) + 1e-6)
        Y = (Y - Y.mean(dim=0, keepdim=True)) / (Y.std(dim=0, keepdim=True) + 1e-6)

    results = []

    # Process the projections batch by batch; the last batch may be smaller.
    for start in range(0, proj_num, batchsize):
        current_batch_size = min(batchsize, proj_num - start)

        X_proj_batch = []
        Y_proj_batch = []
        for _ in range(current_batch_size):
            # Draw on CPU (keeps the CPU RNG stream of the original code) ...
            proj_matrix_x = torch.randn(dx, projection_dim, device="cpu")
            proj_matrix_y = torch.randn(dy, projection_dim, device="cpu")

            # ... normalise every column to unit norm ...
            proj_matrix_x = proj_matrix_x / torch.norm(proj_matrix_x, dim=0, keepdim=True)
            proj_matrix_y = proj_matrix_y / torch.norm(proj_matrix_y, dim=0, keepdim=True)

            # ... then match the data's device/dtype so the matmul is valid
            # wherever X and Y live.
            proj_matrix_x = proj_matrix_x.to(device=X.device, dtype=X.dtype)
            proj_matrix_y = proj_matrix_y.to(device=Y.device, dtype=Y.dtype)

            # [N, dx] @ [dx, projection_dim] -> [N, projection_dim]
            X_proj_batch.append(X @ proj_matrix_x)
            Y_proj_batch.append(Y @ proj_matrix_y)

        # Stack into [current_batch_size, N, projection_dim].
        X_proj_batch = torch.stack(X_proj_batch, dim=0)
        Y_proj_batch = torch.stack(Y_proj_batch, dim=0)

        mi_est = estimate_mi_xy(X_proj_batch, Y_proj_batch, model, max_dim, softrank_reg)

        # estimate_mi_xy returns either a CPU tensor or a Python scalar; keep
        # everything on CPU so the pieces can be concatenated.
        if torch.is_tensor(mi_est):
            results.append(mi_est.detach().cpu().reshape(-1))
        else:
            results.append(torch.tensor([mi_est]))

    return torch.cat(results).mean()

def evaluate_gauss_order(module, dim, max_dim, seq_len, number_test, training_step, perm=False):
    """Measure how often ``module`` orders MI(X;Y) vs. MI(X;Z) like the ground truth.

    Ground-truth values are loaded from ``evaluate_data/dim-{dim}/mi_xy.npy`` and
    ``mi_xz.npy``; the test samples from
    ``evaluate_data/dim-{dim}/samples/seqlen-{seq_len}-No-{i}.npy``. The last
    axis of each sample is read as three consecutive ``dim``-wide slices
    (x, y, z). Every slice is padded with ``gauss_noise_padding`` to ``max_dim``,
    the pairs (x, y) and (x, z) are soft-rank preprocessed and scored by
    ``module(..., early_sup=False)[0]``.

    Besides returning the accuracy, the function appends one line to a text
    file and writes a comparison figure, both below ``logger.log_dir``.

    Relies on the module-level globals ``logger``, ``softrank_reg`` and
    ``device``.

    Args:
        module: Model called as ``module(sample, early_sup=False)``.
        dim: Feature dimension of the stored evaluation data.
        max_dim: Dimension each variable is padded to.
        seq_len: Sequence length used to select the sample files.
        number_test: Number of sample files to evaluate.
        training_step: Step number used in the log line and the figure name.
        perm: Forwarded to ``gauss_noise_padding``.

    Returns:
        Fraction of test cases whose estimated ordering matches the true one.
    """
    data_save_cat_path = os.path.join("evaluate_data", f"dim-{dim}")
    real_mi_xy = np.load(os.path.join(data_save_cat_path, "mi_xy.npy"))
    real_mi_xz = np.load(os.path.join(data_save_cat_path, "mi_xz.npy"))
    sample_save_path = os.path.join(data_save_cat_path, "samples")

    # Named so they do not shadow the module-level function estimate_mi_xy.
    est_xy_all = []
    est_xz_all = []
    accs = []

    # Pure inference: no graph is needed, and .numpy() would fail on tensors
    # that require grad.
    with torch.no_grad():
        for i in range(number_test):
            sample = np.load(os.path.join(sample_save_path, f"seqlen-{seq_len}-No-{i}.npy"))

            x = torch.from_numpy(sample[:, :, 0:dim]).float()
            y = torch.from_numpy(sample[:, :, dim:2*dim]).float()
            # Cast z like x and y so the concatenated input keeps one dtype.
            z = torch.from_numpy(sample[:, :, 2*dim:3*dim]).float()
            x = gauss_noise_padding(x, aim_dim=max_dim, perm=perm)
            y = gauss_noise_padding(y, aim_dim=max_dim, perm=perm)
            z = gauss_noise_padding(z, aim_dim=max_dim, perm=perm)
            sample_xy = torch.cat([x, y], dim=-1)
            sample_xz = torch.cat([x, z], dim=-1)
            sample_xy = softrank_preprocessing_new(sample_xy, regularization_strength=softrank_reg).to(device)
            sample_xz = softrank_preprocessing_new(sample_xz, regularization_strength=softrank_reg).to(device)

            est_mi_xy = module(sample_xy, early_sup=False)[0].detach().cpu().numpy().squeeze()
            est_mi_xz = module(sample_xz, early_sup=False)[0].detach().cpu().numpy().squeeze()
            est_xy_all.append(est_mi_xy)
            est_xz_all.append(est_mi_xz)

            # A case counts as correct when estimate and ground truth agree on
            # which of the two MI values is larger.
            if (est_mi_xy>est_mi_xz and real_mi_xy[i]> real_mi_xz[i]) or (est_mi_xy<=est_mi_xz and real_mi_xy[i]<= real_mi_xz[i]):
                accs.append(1)
            else:
                accs.append(0)

    acc_rate = np.mean(np.array(accs))

    text_path = os.path.join(logger.log_dir, "high_dim_gauss_order", f"dim-{dim}")
    os.makedirs(text_path, exist_ok=True)
    # Report the training step (the original interpolated the builtin `iter`).
    output_string = f"order test acc on {dim}-dimension gauss of model_{training_step} is {acc_rate}\n"
    with open(os.path.join(text_path, f"model-firsttry-{seq_len}.txt"), "a") as f:
        f.write(output_string)
    print(output_string)

    fig_save_cat_path = os.path.join(logger.log_dir, 'figure', f"{dim}")
    os.makedirs(fig_save_cat_path, exist_ok=True)

    plt.style.use("ggplot")
    fig = plt.figure(figsize=(14, 10))
    try:
        d_draw = np.arange(number_test)

        # Left panel: first (up to) 50 cases in file order. Slice all series to
        # the same length so fewer than 50 tests still plot.
        ax1 = fig.add_subplot(121)
        head = d_draw[:50]
        ax1.plot(head, real_mi_xy[:len(head)], color="red", lw=2, ls="-", label="real mutual information XY", markersize=10)
        ax1.plot(head, est_xy_all[:len(head)], color="blue", lw=2, ls="-", label="estimate MI XY", markersize=10)
        ax1.set_xlabel("# experiment times", fontweight="bold", fontsize=20)
        ax1.set_ylabel(" mutual information ", fontweight="bold", fontsize=20)
        ax1.legend(fontsize=20, loc="upper left")

        # Right panel: all cases sorted by the true MI value.
        zipped_lists = sorted(zip(real_mi_xy, est_xy_all), reverse=False)
        sorted_real, sorted_est = zip(*zipped_lists)
        ax2 = fig.add_subplot(122)
        ax2.plot(d_draw, sorted_real, color="red", lw=2, ls="-", label="real mutual information XY", markersize=10)
        ax2.plot(d_draw, sorted_est, color="blue", lw=2, ls="-", label="estimate MI XY", markersize=10)
        ax2.set_xlabel("# experiment times", fontweight="bold", fontsize=20)
        ax2.set_ylabel(" mutual information ", fontweight="bold", fontsize=20)
        ax2.legend(fontsize=20, loc="upper left")

        image_save_path = os.path.join(fig_save_cat_path, f"gauss-dim{dim}-step{training_step}.png")
        plt.savefig(image_save_path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

    return acc_rate

class InfoNetDataset(torch.utils.data.Dataset):
    """Dataset that synthesises a new training sample on every access.

    The index passed to ``__getitem__`` is ignored; each call delegates to
    ``gen_train_dataset_lowdim`` with ``batchsize=1``. Sequence length, input
    dimension and soft-rank regularisation are copied from the module-level
    globals ``seq_len``, ``input_dim_x`` and ``softrank_reg`` at construction.
    """

    def __init__(self, total_epoch, device='cpu'):
        """Store the nominal length, the target device and the global settings."""
        super().__init__()
        self.device = device
        self.total_epoch = total_epoch
        # Snapshot the global configuration onto the instance.
        self.seq_len = seq_len
        self.dim = input_dim_x
        self.softrank_reg = softrank_reg
        print(f'init data set with device {self.device}, and total epoch {total_epoch}')

    def __len__(self):
        """Return the nominal number of items (``total_epoch``)."""
        return self.total_epoch

    def __getitem__(self, idx):
        """Generate and return one sample with gradient tracking disabled."""
        with torch.no_grad():
            return gen_train_dataset_lowdim(
                batchsize=1,
                seq_len=self.seq_len,
                dim=self.dim,
                regularization_strength=self.softrank_reg,
                device=self.device,
            )

class LightningWrapper(lightning.LightningModule):
    """Lightning module that builds and trains an ``InfoNet`` model.

    All hyperparameters (``input_dim_x``, ``latent_dim``, ``learning_rate``,
    ``batchsize``, ...) are read from module-level globals, which ``main``
    assigns from the Hydra config before this class is instantiated.
    """

    def __init__(self):
        """Construct encoder, decoder, query generator and the InfoNet that combines them."""
        super().__init__()
        self.encoder = Encoder(
            input_dim_x=input_dim_x,
            input_dim_y=input_dim_y,
            latent_num=latent_num,
            latent_dim=latent_dim,
            cross_attn_heads=8,
            self_attn_heads=16,
            num_self_attn_per_block=8,
            num_self_attn_blocks=2
        )

        self.decoder = Decoder(
            q_dim=decoder_query_dim,
            latent_dim=latent_dim,
        )

        self.query_gen = Query_Gen_transformer(
            input_dim_x=input_dim_x,
            input_dim_y=input_dim_y,
            dim=decoder_query_dim
        )
        # The string literal below is a disabled `Encoder2` construction kept as
        # dead text; it has no effect at runtime.
        '''
        self.encoder2 = Encoder2(
            input_dim = input_dim_x+input_dim_y,                 
            hidden_dim = encoder2_hiddim,
            output_dim = encoder2_expand_dim,
            num_layers = encoder2_block     
        )
        '''
        # The sub-modules above are also assigned as attributes of this wrapper
        # and passed into InfoNet.
        self.model = InfoNet(
            encoder=self.encoder, 
            decoder=self.decoder, 
            #encoder2=self.encoder2,
            query_gen=self.query_gen,
            decoder_query_dim=decoder_query_dim, 
            input_dim_x=input_dim_x, 
            input_dim_y=input_dim_y,
            targetnet_hiddim = targetnet_hiddim,
            hypermlp_hiddim = hypermlp_hiddim,
            #encoder2_expand_dim = encoder2_expand_dim,
        )
        
        # Stored on the instance; configure_optimizers reads the global of the same name.
        self.lr_decay_step = lr_decay_step
        #self.encoder2_expand_dim = encoder2_expand_dim

    def forward(self, x):
        """Run the model on ``x`` with axis 1 squeezed and ``early_sup=False``."""
        # NOTE(review): presumably x is [batch, 1, seq_len, features] because each
        # dataset item is generated with batchsize=1 — confirm against gen_train_dataset_lowdim.
        return self.model(x.squeeze(1), early_sup=False)
        #return vmap(self.model.forward, in_dims=0, randomness='different')(x, early_sup=False)
        # return self.model(x[0])

    def training_step(self, batch, batch_idx):    
        """Return the training loss: the negated mean of the model output."""
        mi_lb = self(batch)
        # Minimising the negative mean maximises the model output.
        loss = -torch.mean(mi_lb)
        # Written straight to the TensorBoard writer, keyed by the input dimensions.
        self.logger.experiment.add_scalars('train_loss', 
                                           {f'x-dim{input_dim_x}-y-dim{input_dim_y}': loss,
                                            },
                                           global_step=self.global_step)
        self.log('loss', loss.item(), on_step=True, prog_bar=True, sync_dist=True, batch_size=batchsize)
        return loss
        # loss = torch.tensor(0.0, device=self.device)
        # loss.requires_grad = True

    def validation_step(self, batch, batch_idx):
        """Run the Gaussian ordering evaluation; ``batch`` itself is not used."""
        
        # Only dim=5 is evaluated; the accuracy is logged under 'indv_val_acc'.
        for dim in [5]:
            eval_acc = evaluate_gauss_order(self.model, dim=dim, max_dim=input_dim_x, seq_len=seq_len, number_test=250, training_step=self.global_step + 1)
            #bmi_bias = evaluate_bmi_inde(self.model, dim=dim, max_dim=input_dim_x, seq_len=seq_len, number_test=100, training_step=self.global_step + 1, softrank_reg=softrank_reg)
            self.logger.experiment.add_scalars('indv_val_acc', 
                                    {f'dim{dim}': eval_acc}, 
                                    global_step=self.global_step)
        
        return eval_acc

    def test_step(self, batch, batch_idx):
        """Placeholder test step: prints a banner and returns 0; all evaluation calls are disabled."""
        print("=============== begin validation")
        for dim in [5]:
            #eval_acc = evaluate_gauss_order(self.model, dim=dim, max_dim=input_dim_x, seq_len=seq_len, number_test=250, training_step=self.global_step + 1)
            #bmi_bias = evaluate_bmi_inde(self.model, dim=dim, max_dim=input_dim_x, seq_len=seq_len, number_test=100, training_step=self.global_step + 1, softrank_reg=cfg.softrank_reg)
            #self.logger.experiment.add_scalars('indv_val_acc', {f'dim{dim}': bmi_bias}, global_step=self.global_step)
            bmi_bias=0
        return bmi_bias

    
    def on_load_checkpoint(self, checkpoint):
        """Overwrite the learning rate stored in a checkpoint with the global ``learning_rate``."""
        # Applies to every param group of every saved optimizer state, so a
        # resumed run uses the currently configured rate.
        if "optimizer_states" in checkpoint:
            for state in checkpoint["optimizer_states"]:
                for param_group in state['param_groups']:
                    param_group['lr'] = learning_rate
    

    def configure_optimizers(self):
        """Return Adam plus a StepLR schedule (gamma 0.9) that is stepped per training step."""
        optimizer = Adam(self.parameters(), lr=learning_rate)
        scheduler = StepLR(optimizer, step_size=lr_decay_step, gamma=0.9)
        #scheduler = CosineAnnealingLR(optimizer, T_max=500000, eta_min=4e-8)
        
        # "interval": "step" makes Lightning advance the scheduler every optimizer step, not every epoch.
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "step"}}
    
    def train_dataloader(self):
        """Build the training DataLoader over an on-the-fly ``InfoNetDataset``."""
        # Nominal dataset length is num_per_epoch * gpu_card * batchsize * 1000 items.
        dataset = InfoNetDataset(total_epoch=num_per_epoch*gpu_card*batchsize*1000, device='cpu')
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batchsize, shuffle=False, num_workers=15, pin_memory=True, prefetch_factor=3, persistent_workers=False)
        #dataset = InfoNetDataset(total_epoch=480)
        #dataloader = torch.utils.data.DataLoader(dataset, batch_size=60, shuffle=True, num_workers=23, pin_memory=True, prefetch_factor=2)
        return dataloader
    
    def test_dataloader(self):
        """Return a one-item loader; ``test_step`` does not use the batch it yields."""
        return torch.utils.data.DataLoader(InfoNetDataset(total_epoch=1), batch_size=1, num_workers=15, pin_memory=True)

    def val_dataloader(self):
        """Return a one-item loader; ``validation_step`` does not use the batch it yields."""
        # NOTE(review): 15 workers are started to produce a single unused sample — consider num_workers=0.
        return torch.utils.data.DataLoader(InfoNetDataset(total_epoch=1), batch_size=1, num_workers=15, pin_memory=True)

@hydra.main(config_path='config', config_name='cfg_1-5d_small_test', version_base='1.1')
def main(cfg):
    """Hydra entry point: publish the config as module globals and start training.

    The hyperparameters are copied from ``cfg`` into module-level globals
    because ``LightningWrapper``, ``InfoNetDataset`` and
    ``evaluate_gauss_order`` read them from there. A TensorBoard logger, a
    periodic checkpoint callback and a Lightning ``Trainer`` are then created
    and ``trainer.fit`` is called.

    Args:
        cfg: Hydra/OmegaConf configuration providing the fields read below
            plus ``name`` and ``version`` for the logger.
    """
    global batchsize, latent_dim, latent_num, input_dim_x, input_dim_y, seq_len, decoder_query_dim, num_per_epoch, gpu_card, learning_rate, softrank_reg, lr_decay_step, targetnet_hiddim, hypermlp_hiddim, encoder2_expand_dim, encoder2_hiddim, encoder2_block
    global logger

    batchsize = cfg.batchsize
    latent_dim = cfg.latent_dim
    latent_num = cfg.latent_num
    input_dim_x = cfg.input_dim_x
    input_dim_y = cfg.input_dim_y
    seq_len = cfg.seq_len
    decoder_query_dim = cfg.decoder_query_dim
    num_per_epoch = cfg.num_per_epoch
    gpu_card = cfg.gpu_card
    learning_rate = cfg.learning_rate
    softrank_reg = cfg.softrank_reg
    lr_decay_step = cfg.lr_decay_step
    targetnet_hiddim = cfg.targetnet_hiddim
    hypermlp_hiddim = cfg.hypermlp_hiddim
    encoder2_expand_dim = cfg.encoder2_expand_dim
    encoder2_hiddim = cfg.encoder2_hiddim
    encoder2_block = cfg.encoder2_block

    torch.set_float32_matmul_precision('medium')
    # Must be created after the globals above are set: the constructor reads them.
    module = LightningWrapper()

    # Global because evaluate_gauss_order writes its outputs below logger.log_dir.
    logger = TensorBoardLogger('logs', name=cfg.name, version=cfg.version)
    #if os.path.isdir(logger.log_dir):
        #send2trash(logger.log_dir)

    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(logger.log_dir, 'checkpoints'),
        filename='mixture-1-5d-new-{step:08d}',
        every_n_train_steps=8000,  # save a checkpoint every 8000 training steps
        save_top_k=-1,  # keep every checkpoint
        save_last=True
    )

    trainer = lightning.Trainer(
        max_epochs=500000,
        accelerator='auto',
        devices=-1,
        num_nodes=1,
        logger=logger,
        strategy='ddp_find_unused_parameters_true',
        gradient_clip_val=1.0,
        #limit_val_batches=0,         # would disable validation entirely
        val_check_interval=8000,
        num_sanity_val_steps=1,      # run one sanity-check validation batch before training
        use_distributed_sampler=False,
        callbacks=[checkpoint_callback]
    )
    trainer.fit(module)

# Script entry point.
if __name__ == '__main__':
    # Select the 'spawn' multiprocessing start method before any worker is created.
    # NOTE(review): presumably chosen so DataLoader workers do not inherit CUDA state via fork — confirm.
    mp.set_start_method('spawn')
    main()

