import os
from pdearena import utils
from pdearena.data.datamodule import PDEDataModule
from pdearena.data.pdemoe_datapipe import NavierStokes2DDataset
from pdearena.lr_scheduler import LinearWarmupCosineAnnealingLR  # noqa: F401
from pdearena.models.pdemoe import PDEModel
from pdearena.models.pdemoe_models import NeuralOperator, HybridModel, MoEModel
# from pdearena.models.pdemodel import PDEModel
import torch 
import wandb 
from torch.utils.data import DataLoader 
import numpy as np
import torch.nn as nn
from scipy.spatial import cKDTree 
import pandas as pd
import time
import uuid 

# Default dataset location (placeholder path).
DATA_PATH = "/path/to/pdemoe_data"

# Channel-count and time-window settings. They are not referenced in the
# visible part of this file; presumably they mirror the dataset/model
# configuration -- TODO confirm against the data pipeline.
N_INPUT_SCALAR_COMPONENTS, N_INPUT_VECTOR_COMPONENTS = 1, 2
N_OUTPUT_SCALAR_COMPONENTS, N_OUTPUT_VECTOR_COMPONENTS = 1, 1
TRAJLEN = 8
TIME_HISTORY, TIME_FUTURE, TIME_GAP = 4, 1, 0

# Cost weight attached to each fidelity level; `train` uses these values to
# penalise routing probability mass. Units are not stated in this file
# (presumably a relative wall-clock cost -- TODO confirm).
TIME_COST = dict(gt=0.2, fine=38, medium=20, coarse=12, xcoarse=8)


def time_cost(fidelity):
    """Return the ``TIME_COST`` entry for ``fidelity`` as a 0-d tensor.

    ``None`` and unknown fidelity names both map to ``torch.tensor(0.0)``.
    Known names are wrapped as-is, so the tensor dtype follows the stored
    Python value (float for ``'gt'``, integer for the other levels).
    """
    cost = None if fidelity is None else TIME_COST.get(fidelity)
    return torch.tensor(0.0 if cost is None else cost)


def kl_divergence_to_uniform(gate_probs):
    """Mean KL divergence KL(P || U) between gate rows and the uniform law.

    gate_probs: 2-D tensor [batch_size, num_experts]; each row is expected
                to be a probability distribution over the experts.
    Returns:
        A scalar tensor: the batch mean of
        sum_i p_i * (log(p_i + eps) - log(1 / num_experts)).
    """
    # Unpacking also enforces that the input is exactly 2-D.
    _, n_experts = gate_probs.shape

    # log of the uniform probability u_i = 1 / num_experts
    log_u = -np.log(n_experts)

    # eps keeps log() finite where a gate probability is exactly zero
    log_p = torch.log(gate_probs + 1e-8)

    # per-sample KL: sum over the expert axis -> shape [batch_size]
    per_sample = (gate_probs * (log_p - log_u)).sum(dim=1)

    return per_sample.mean()


class ModelWithExtra1x1x1(nn.Module):
    """Wrap ``base_model`` and append a zero-initialised pointwise ``nn.Conv3d``.

    Because both the weight and the bias of the extra convolution start at
    zero, the wrapper's output is all zeros until that layer is trained.
    """

    def __init__(self, base_model, in_ch=1, out_ch=1):
        super().__init__()
        self.base_model = base_model

        # 1x1x1 convolution: mixes channels only, leaves D/H/W untouched.
        head = nn.Conv3d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=True)
        nn.init.zeros_(head.weight)
        nn.init.zeros_(head.bias)
        self.conv3d = head

    def forward(self, x):
        """Run ``base_model`` and then the pointwise convolution.

        The base model's output is fed straight to ``nn.Conv3d``, so it is
        expected to be laid out as [N, C, D, H, W] with C == ``in_ch``.
        """
        return self.conv3d(self.base_model(x))
 

def _knn_upsample(image, new_H, new_W, k=4):
    """
    Upsample a 2D NumPy array `image` from shape (H, W) to shape (new_H, new_W)
    using a naive k-nearest neighbor interpolation approach.

    Parameters
    ----------
    image : np.ndarray of shape (H, W)
        The original 2D array to be upsampled.
    new_H : int
        Desired output height.
    new_W : int
        Desired output width.
    k : int
        Number of neighbors to consider when interpolating.

    Returns
    -------
    upsampled : np.ndarray of shape (new_H, new_W)
        The upsampled 2D array.
    """
    H, W = image.shape
    upsampled = np.zeros((new_H, new_W), dtype=image.dtype)

    # Compute floating scale factors for both height and width
    scale_h = float(new_H) / H
    scale_w = float(new_W) / W

    # For each pixel in the upsampled image:
    for new_x in range(new_H):
        for new_y in range(new_W):
            # Map (new_x, new_y) back to continuous coords in the original image
            orig_x = new_x / scale_h
            orig_y = new_y / scale_w

            # Find integer neighbors around (orig_x, orig_y). We'll gather them
            # in a small integer window and pick the k closest in Euclidean distance.
            int_x = int(np.floor(orig_x))
            int_y = int(np.floor(orig_y))
            neighbors = []
            max_search_radius = 2  # a small radius to collect candidate neighbors
            for dx in range(-max_search_radius, max_search_radius + 1):
                for dy in range(-max_search_radius, max_search_radius + 1):
                    nx = int_x + dx
                    ny = int_y + dy
                    # Make sure (nx, ny) is within original image bounds
                    if 0 <= nx < H and 0 <= ny < W:
                        dist = np.sqrt((orig_x - nx)**2 + (orig_y - ny)**2)
                        neighbors.append((dist, image[nx, ny]))

            # Sort neighbors by distance, take the k nearest, and average
            neighbors.sort(key=lambda x: x[0])
            k_nearest = neighbors[:k]
            val = np.mean([pixval for dist, pixval in k_nearest])
            upsampled[new_x, new_y] = val

    return upsampled

def _knn_upsample_optimized(image, new_H, new_W, k=4):
    """
    An optimized version of k-nearest neighbor interpolation using cKDTree.
    """
    H, W = image.shape
    # Flatten the original image coordinates (i, j) into an array for building the KD-Tree
    coords = np.array([(i, j) for i in range(H) for j in range(W)], dtype=np.float32)
    # Build the tree
    tree = cKDTree(coords)

    # Flatten the image values in the same order as coords
    flat_vals = image.ravel()

    # Prepare query coordinates in the upsampled space
    # (map each upsampled pixel back to original space)
    scale_h = float(new_H) / H
    scale_w = float(new_W) / W
    query_coords = []
    for new_x in range(new_H):
        for new_y in range(new_W):
            orig_x = new_x / scale_h
            orig_y = new_y / scale_w
            query_coords.append([orig_x, orig_y])
    query_coords = np.array(query_coords, dtype=np.float32)

    # Query the k nearest neighbors for all upsampled points
    # distances: shape (new_H*new_W, k)
    # indices:   shape (new_H*new_W, k)
    distances, indices = tree.query(query_coords, k=k)

    # If k=1, indices is shape (N, ), so make it (N,1) for consistent handling
    if k == 1:
        indices = indices[:, np.newaxis]

    # Average the intensities of the k neighbors
    upsampled_flat = np.mean(flat_vals[indices], axis=1)
    
    # Reshape to (new_H, new_W)
    upsampled = upsampled_flat.reshape(new_H, new_W)
    return upsampled

def knn_upsample(data, new_H, new_W, k=3):
    """
    Upsample the last two dimensions of `data` to (new_H, new_W) by applying
    _knn_upsample_optimized() to every 2D slice.

    data: np.ndarray of shape (N1, N2, H, W) or (N, H, W)
    k:    number of neighbours averaged per output pixel

    Returns upsampled_data: shape (N1, N2, new_H, new_W) or (N, new_H, new_W),
    with the same dtype as `data`.

    Raises ValueError if `data` is neither 3-D nor 4-D.
    """
    # Dispatch on rank explicitly. The previous bare `except:` fallback hid
    # genuine failures from the 4-D path behind an unrelated unpacking error.
    if data.ndim not in (3, 4):
        raise ValueError(
            f"knn_upsample expects a 3-D or 4-D array, got shape {tuple(data.shape)}"
        )

    lead_shape = tuple(data.shape[:-2])
    upsampled_data = np.zeros(lead_shape + (new_H, new_W), dtype=data.dtype)

    # np.ndindex walks the leading axes in row-major order, matching the
    # nested-loop order for both the 3-D and the 4-D layout.
    for idx in np.ndindex(*lead_shape):
        upsampled_data[idx] = _knn_upsample_optimized(data[idx], new_H, new_W, k=k)
    return upsampled_data


# Module-level logger from the project's `utils.get_logger` helper; get_model()
# uses it to warn when no default root dir is configured.
logger = utils.get_logger(__name__) 

def setupdir(path):
    """Create `path` and its `tb` and `ckpts` subdirectories.

    Directories that already exist are left untouched.
    """
    os.makedirs(path, exist_ok=True)
    for leaf in ("tb", "ckpts"):
        os.makedirs(os.path.join(path, leaf), exist_ok=True)
 
def get_model(fidelity: str, hybrid: bool = False):
    """Build one expert model through the PDEArena CLI.

    A fresh ``utils.PDECLI`` is instantiated (``run=False``) to obtain the
    configured base ``PDEModel``. If the trainer has no default root dir, it
    is set from the ``PDEARENA_OUTPUT_DIR`` environment variable (falling back
    to ``./outputs``), and the output directory tree is created.

    Args:
        fidelity: forwarded to ``HybridModel``; ignored when ``hybrid`` is False.
        hybrid: wrap the base model in ``HybridModel`` when True, otherwise
            in a pure ``NeuralOperator``.

    Returns:
        ``(model, config)`` -- the wrapped model and the CLI's config.
    """
    cli_options = dict(
        datamodule_class=PDEDataModule,
        seed_everything_default=42,
        save_config_overwrite=True,
        run=False,
        parser_kwargs={"parser_mode": "omegaconf"},
    )
    cli = utils.PDECLI(PDEModel, **cli_options)

    # Fall back to an environment-provided (or local) output directory.
    if cli.trainer.default_root_dir is None:
        logger.warning("No default root dir set, using: ")
        cli.trainer.default_root_dir = os.environ.get("PDEARENA_OUTPUT_DIR", "./outputs")
        logger.warning(f"\t {cli.trainer.default_root_dir}")
    setupdir(cli.trainer.default_root_dir)

    backbone = cli.model
    wrapped = HybridModel(backbone, fidelity=fidelity) if hybrid else NeuralOperator(backbone)
    return wrapped, cli.config

# Column layout of the per-sample gating records written to the CSV log.
_GATING_COLUMNS = ['Expert_1', 'Expert_2', 'Expert_3', 'Expert_4', 'Viscosity', 'Step']


def _make_dataloader(data_path, subset, max_samples, batch_size):
    """Build a shuffled, drop-last, single-process DataLoader for one subset."""
    dataset = NavierStokes2DDataset(folder_path=data_path, max_samples=max_samples, subset=subset)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        drop_last=True,
    )


def _gating_frame(gate_probs_2d, viscosity, step):
    """Return one batch of [expert probs..., viscosity, step] rows as a DataFrame."""
    step_column = torch.full(
        (gate_probs_2d.shape[0], 1), step, dtype=gate_probs_2d.dtype, device=gate_probs_2d.device
    )
    # torch.cat along dim=1 requires `viscosity` to be 2-D with one row per sample.
    gating_data = torch.cat([gate_probs_2d.detach(), viscosity.detach(), step_column], dim=1)
    return pd.DataFrame(gating_data.numpy(), columns=_GATING_COLUMNS)


def _evaluate_split(model, dataloader, cost_vector, step, max_batches=6):
    """Evaluate `model` on at most `max_batches` batches without gradients.

    Returns `(mean_mse, mean_penalty, frames)`. Both means are taken over the
    number of batches actually evaluated (NaN if there were none), and
    `frames` holds one gating-record DataFrame per evaluated batch.
    """
    total_loss = 0.0
    total_penalty = 0.0
    n_batches = 0
    frames = []
    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= max_batches:
                break
            _, y, fidelity_data = batch
            y_pred, gate_probs, _ = model(batch)
            gate_probs = gate_probs.view(gate_probs.shape[0], gate_probs.shape[-1])
            mean_probs = torch.mean(gate_probs, dim=0)  # shape [num_experts]
            total_penalty += torch.dot(mean_probs, cost_vector).item()
            frames.append(_gating_frame(gate_probs, fidelity_data["viscosity"], step))
            total_loss += torch.mean((y_pred - y) ** 2).item()
            n_batches += 1
    if n_batches == 0:
        return float("nan"), float("nan"), frames
    return total_loss / n_batches, total_penalty / n_batches, frames


def train(data_path: str,
          kl_weight: float = 1.0,
          delta: float = 0.001,
          batch_size: int = 32,
          learning_rate: float = 2e-4,
          weight_decay: float = 1e-5,
          max_samples: int = 1000,
          hybrid: bool = False):
    """Train the four-expert mixture-of-experts model for one pass over the data.

    Per step the loss is
        MSE + (0.2 / (step + 1)) * KL(gate || uniform) + lambda * (penalty - threshold)
    where `penalty` is the batch-mean gate probabilities dotted with the
    per-expert costs from TIME_COST. Metrics go to Weights & Biases, and the
    per-sample gate probabilities are appended to a DataFrame that is
    rewritten to CSV after every step.

    Every 50 steps (for step > 1, and only when max_samples > 100) up to six
    validation and six test batches are evaluated, and a checkpoint is written
    to `checkpoints/`.

    Args:
        data_path: dataset folder passed to NavierStokes2DDataset.
        kl_weight: only used in the CSV file name; the KL weight actually
            applied is the hard-coded 0.2 / (step + 1) schedule.
        delta: only used in the CSV file name.
        batch_size: batch size for all three dataloaders.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
        max_samples: cap on training samples (validation/test use 1000);
            also gates periodic evaluation.
        hybrid: only selects the checkpoint file-name prefix
            (HYBRID_ vs BASELINE_); the expert line-up is fixed.
    """
    # Expert line-up: one pure neural operator plus three hybrid models at
    # decreasing fidelity. `config` from the last call is the one kept.
    expert_1, config = get_model("fine", hybrid=False)
    expert_2, config = get_model("fine", hybrid=True)
    expert_3, config = get_model("medium", hybrid=True)
    expert_4, config = get_model("coarse", hybrid=True)
    experts_list = nn.ModuleList([expert_1, expert_2, expert_3, expert_4])
    model = MoEModel(experts_list)
    # TODO: switch to Adam?
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    project_name = "pde-moe-project"
    name = f"MoE-{config.model.name}"
    wandb.init(entity="your-entity", project=project_name, name=name)

    df = pd.DataFrame(columns=_GATING_COLUMNS)

    train_dataloader = _make_dataloader(data_path, "train", max_samples, batch_size)
    valid_dataloader = _make_dataloader(data_path, "valid", 1000, batch_size)
    test_dataloader = _make_dataloader(data_path, "test", 1000, batch_size)

    # Per-expert cost in expert order; the pure neural operator (expert 1) is
    # charged the 'gt' cost. Loop-invariant, so it is built once.
    cost_vector = torch.tensor(
        [TIME_COST['gt'], TIME_COST['fine'], TIME_COST['medium'], TIME_COST['coarse']]
    )

    _lambda = 0.2
    # threshold = 10.0
    threshold = 100.0
    # NOTE(review): model.train()/model.eval() are never toggled in this
    # function -- confirm the experts have no mode-dependent layers.
    for step, batch in enumerate(train_dataloader):
        start_time = time.time()

        _, y, fidelity_data = batch  # y is the ground truth
        y_pred, gate_probs, _ = model(batch)

        mse_loss = torch.mean((y_pred - y) ** 2)  # TODO: add support for an unrolling loss

        ####################################################################
        # KL regularisation term and cost-penalty term of the gate
        ####################################################################
        out_probs = gate_probs.view(gate_probs.shape[0], gate_probs.shape[-1])  # [batch_size, num_experts]
        mean_probs = torch.mean(out_probs, dim=0)  # shape [num_experts]
        penalty = torch.dot(mean_probs, cost_vector)
        kl_reg = kl_divergence_to_uniform(out_probs)  # KL divergence to uniform distribution

        # Record this batch's gate probabilities alongside viscosity and step.
        df = pd.concat(
            [df, _gating_frame(out_probs, fidelity_data["viscosity"], step)],
            ignore_index=True,
        )
        ####################################################################

        _kl_weight = 0.2 / (step + 1.0)

        train_metrics = {f"train/prob_expert_{i + 1}": mean_probs[i].item() for i in range(4)}
        train_metrics["train/penalty"] = penalty.item()
        train_metrics["train/mse_loss"] = mse_loss.item()
        train_metrics["train/kl_loss"] = kl_reg.item()
        train_metrics["train/kl_weight"] = _kl_weight
        wandb.log(train_metrics, step=step)
        print("Step: ", step, "Loss: ", mse_loss.item())

        loss = mse_loss + _kl_weight * kl_reg + _lambda * (penalty - threshold)

        optimizer.zero_grad()
        loss.backward()

        # Multiplier update with a 1/sqrt(step + 1) step size, clipped at zero.
        # NOTE(review): as written, lambda *grows* while penalty < threshold,
        # the opposite sign of standard dual ascent -- confirm this is intended.
        _lambda -= (penalty.item() - threshold) / (step + 1.0) ** 0.5
        if _lambda < 0.0:
            _lambda = 0.0

        # L2 norm of the gating network's gradients (before the optimizer step).
        gate_grad_norm = sum(
            param.grad.norm().item() ** 2
            for param in model.gate.parameters()
            if param.grad is not None
        ) ** 0.5
        wandb.log({"train/lambda": _lambda, "train/gate_grad": gate_grad_norm}, step=step)
        optimizer.step()

        end_time = time.time()
        print(f"Forward pass time: {end_time - start_time:.6f} seconds")

        # Periodic validation, checkpointing and test evaluation.
        if (max_samples > 100) and (step % 50 == 0) and (step > 1):
            val_loss, val_penalty, val_frames = _evaluate_split(
                model, valid_dataloader, cost_vector, step
            )
            df = pd.concat([df, *val_frames], ignore_index=True)
            wandb.log({"valid/loss": val_loss, "valid/val_penalty": val_penalty}, step=step)

            # Save model checkpoint; torch.save does not create the directory.
            ckpt_prefix = "HYBRID" if hybrid else "BASELINE"
            ckpt_file_name = f"{ckpt_prefix}_{config.model.name}_STEP_{step}_VAL_{val_loss}.pt"
            os.makedirs("checkpoints/", exist_ok=True)
            torch.save(model.state_dict(), os.path.join("checkpoints/", ckpt_file_name))

            test_loss, test_penalty, test_frames = _evaluate_split(
                model, test_dataloader, cost_vector, step
            )
            df = pd.concat([df, *test_frames], ignore_index=True)
            wandb.log({"test/loss": test_loss, "test/test_penalty": test_penalty}, step=step)

        print("DF shape: ", df.shape)
        df.to_csv(f'sample_data_delta_{delta}_kl_{kl_weight}.csv', index=True)
    wandb.finish()

if __name__ == "__main__":  
    # Script entry point: run one training pass with the settings below.
    # Do not change 
    # Alternative dataset locations, kept for reference:
    # data_path = "/path/to/pdemoe_data"
    # data_path = "/path/to/pdemoe_data_test"
    data_path = "/path/to/pdemoe_data_sample"  # sample dataset, for testing
    max_samples = 80000
    batch_size = 32
    
    hybrid = True  # only selects the checkpoint file-name prefix in train(); it does not change the model
    # kl_weight and delta only end up in the name of the CSV file written by train().
    train(kl_weight=10, delta=0.1, data_path=data_path, batch_size=batch_size, max_samples=max_samples, hybrid=hybrid)  
    print("Success.") 

