import os

from pdearena import utils
from pdearena.data.datamodule import PDEDataModule
from pdearena.data.pdemoe_datapipe import NavierStokes2DDataset
from pdearena.lr_scheduler import LinearWarmupCosineAnnealingLR  # noqa: F401
from pdearena.models.pdemoe import PDEModel
# from pdearena.models.pdemodel import PDEModel
import torch 
import wandb
from dataclasses import dataclass
from typing import Optional, Tuple
import random 
from torch.utils.data import DataLoader, Dataset
import random  
import h5py
from scipy.ndimage import zoom
import numpy as np
import torch.nn as nn
from scipy.spatial import cKDTree
import torch.nn.functional as F

# Placeholder dataset root.
DATA_PATH = "/path/to/pdemoe_data"

# Component counts (presumably channels per field type) for model inputs/outputs.
# NOTE(review): outputs use 1 vector component vs. 2 on the input side — confirm intended.
N_INPUT_SCALAR_COMPONENTS, N_INPUT_VECTOR_COMPONENTS = 1, 2
N_OUTPUT_SCALAR_COMPONENTS, N_OUTPUT_VECTOR_COMPONENTS = 1, 1

# Trajectory windowing: trajectory length, number of history / future steps, and gap.
TRAJLEN, TIME_HISTORY, TIME_FUTURE, TIME_GAP = 8, 4, 1, 0


class ModelWithExtra1x1x1(nn.Module):
    """Run ``base_model`` and then a learnable pointwise (1x1x1) ``nn.Conv3d``.

    Args:
        base_model: module applied first; its output is fed to the conv, so it
            must be an ``[N, C, D, H, W]`` tensor with ``C == in_ch``.
        in_ch: input channels of the pointwise convolution.
        out_ch: output channels of the pointwise convolution.

    NOTE: the convolution's weight and bias both start at zero, so the wrapper
    returns an all-zero tensor until the convolution has been trained.
    """

    def __init__(self, base_model, in_ch=1, out_ch=1):
        super().__init__()
        self.base_model = base_model

        pointwise = nn.Conv3d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=True)
        # Zero-initialise the projection (nn.init.zeros_ runs under no_grad).
        nn.init.zeros_(pointwise.weight)
        nn.init.zeros_(pointwise.bias)
        self.conv3d = pointwise

    def forward(self, x):
        features = self.base_model(x)
        return self.conv3d(features)
 

def _knn_upsample(image, new_H, new_W, k=4):
    """
    Upsample a 2D NumPy array `image` from shape (H, W) to shape (new_H, new_W)
    using a naive k-nearest neighbor interpolation approach.

    Parameters
    ----------
    image : np.ndarray of shape (H, W)
        The original 2D array to be upsampled.
    new_H : int
        Desired output height.
    new_W : int
        Desired output width.
    k : int
        Number of neighbors to consider when interpolating.

    Returns
    -------
    upsampled : np.ndarray of shape (new_H, new_W)
        The upsampled 2D array.
    """
    H, W = image.shape
    upsampled = np.zeros((new_H, new_W), dtype=image.dtype)

    # Compute floating scale factors for both height and width
    scale_h = float(new_H) / H
    scale_w = float(new_W) / W

    # For each pixel in the upsampled image:
    for new_x in range(new_H):
        for new_y in range(new_W):
            # Map (new_x, new_y) back to continuous coords in the original image
            orig_x = new_x / scale_h
            orig_y = new_y / scale_w

            # Find integer neighbors around (orig_x, orig_y). We'll gather them
            # in a small integer window and pick the k closest in Euclidean distance.
            int_x = int(np.floor(orig_x))
            int_y = int(np.floor(orig_y))
            neighbors = []
            max_search_radius = 2  # a small radius to collect candidate neighbors
            for dx in range(-max_search_radius, max_search_radius + 1):
                for dy in range(-max_search_radius, max_search_radius + 1):
                    nx = int_x + dx
                    ny = int_y + dy
                    # Make sure (nx, ny) is within original image bounds
                    if 0 <= nx < H and 0 <= ny < W:
                        dist = np.sqrt((orig_x - nx)**2 + (orig_y - ny)**2)
                        neighbors.append((dist, image[nx, ny]))

            # Sort neighbors by distance, take the k nearest, and average
            neighbors.sort(key=lambda x: x[0])
            k_nearest = neighbors[:k]
            val = np.mean([pixval for dist, pixval in k_nearest])
            upsampled[new_x, new_y] = val

    return upsampled

def _knn_upsample_optimized(image, new_H, new_W, k=4):
    """
    An optimized version of k-nearest neighbor interpolation using cKDTree.
    """
    H, W = image.shape
    # Flatten the original image coordinates (i, j) into an array for building the KD-Tree
    coords = np.array([(i, j) for i in range(H) for j in range(W)], dtype=np.float32)
    # Build the tree
    tree = cKDTree(coords)

    # Flatten the image values in the same order as coords
    flat_vals = image.ravel()

    # Prepare query coordinates in the upsampled space
    # (map each upsampled pixel back to original space)
    scale_h = float(new_H) / H
    scale_w = float(new_W) / W
    query_coords = []
    for new_x in range(new_H):
        for new_y in range(new_W):
            orig_x = new_x / scale_h
            orig_y = new_y / scale_w
            query_coords.append([orig_x, orig_y])
    query_coords = np.array(query_coords, dtype=np.float32)

    # Query the k nearest neighbors for all upsampled points
    # distances: shape (new_H*new_W, k)
    # indices:   shape (new_H*new_W, k)
    distances, indices = tree.query(query_coords, k=k)

    # If k=1, indices is shape (N, ), so make it (N,1) for consistent handling
    if k == 1:
        indices = indices[:, np.newaxis]

    # Average the intensities of the k neighbors
    upsampled_flat = np.mean(flat_vals[indices], axis=1)
    
    # Reshape to (new_H, new_W)
    upsampled = upsampled_flat.reshape(new_H, new_W)
    return upsampled

def knn_upsample(data, new_H, new_W, k=3):
    """
    Resize the last two axes of ``data`` to ``(new_H, new_W)``, one 2-D image
    at a time, with ``_knn_upsample_optimized``.

    Parameters
    ----------
    data : np.ndarray of shape (..., H, W)
        For example (N1, N2, H, W) or (N, H, W). Any number of leading axes,
        including none, is accepted.
    new_H, new_W : int
        Output height and width.
    k : int
        Neighbours averaged per output pixel (forwarded).

    Returns
    -------
    np.ndarray of shape (..., new_H, new_W)
        Same dtype as ``data``; the float means are cast on assignment.

    Raises
    ------
    ValueError
        If ``data`` has fewer than two dimensions.
    """
    data = np.asarray(data)
    # Check the rank explicitly. A blanket ``except`` here would also swallow real
    # errors (and KeyboardInterrupt) raised while upsampling and re-report them
    # as an unrelated unpacking failure.
    if data.ndim < 2:
        raise ValueError(
            f"knn_upsample expects an array with at least 2 dimensions, got shape {data.shape}"
        )

    *lead_shape, H, W = data.shape
    n_images = int(np.prod(lead_shape))  # np.prod([]) == 1 for plain 2-D input
    images = data.reshape(n_images, H, W)

    upsampled = np.zeros((n_images, new_H, new_W), dtype=data.dtype)
    for idx in range(n_images):
        upsampled[idx] = _knn_upsample_optimized(images[idx], new_H, new_W, k=k)
    return upsampled.reshape(*lead_shape, new_H, new_W)


# Module-level logger from pdearena's utils helper; get_model() uses it for warnings.
logger = utils.get_logger(__name__) 

def setupdir(path):
    """Create ``path`` and its ``tb`` and ``ckpts`` subdirectories.

    Directories that already exist are left untouched, so calling this twice is safe.
    """
    os.makedirs(path, exist_ok=True)
    for subdir in ("tb", "ckpts"):
        os.makedirs(os.path.join(path, subdir), exist_ok=True)

 
class HybridModel(nn.Module):
    """Blend a network prediction with a precomputed solution of a chosen fidelity.

    ``forward(batch)`` returns ``alpha * base_model(x) + (1 - alpha) * y_coarse``,
    where ``batch`` is ``(x, y, fidelity_data)`` and
    ``fidelity_data[fidelity]`` is an ``(x_coarse, y_coarse)`` pair.

    Args:
        base_model: module mapping the input ``x`` to a prediction.
        fidelity: key selecting the solution in ``fidelity_data``. Must be one of
            ``"fine"``, ``"coarse"``, ``"xcoarse"``, ``"medium"``, ``"gt"``.
        alpha: fixed (not learned) weight on the network prediction.

    Raises:
        ValueError: if ``fidelity`` is not a supported level, or (from
            ``forward``) if the batch carries no entry for it.
    """

    _VALID_FIDELITIES = ("fine", "coarse", "xcoarse", "medium", "gt")

    def __init__(self, base_model: nn.Module, fidelity: str, alpha: float = 0.5):
        super().__init__()
        if fidelity not in self._VALID_FIDELITIES:
            raise ValueError(f"Invalid fidelity level: {fidelity}")
        self.base_model = base_model
        self.alpha = alpha
        self.fidelity = fidelity

    def forward(self, batch):
        x, _y, fidelity_data = batch
        # Look up the solution before running the network, so a missing key fails fast.
        coarse_data = fidelity_data.get(self.fidelity)
        if coarse_data is None:
            # Report the keys only; printing fidelity_data itself could dump whole tensors.
            raise ValueError(
                f"Fidelity not found in `fidelity_data`: {self.fidelity!r} "
                f"(available: {list(fidelity_data.keys())})"
            )
        _x_coarse, y_coarse = coarse_data
        y_pred = self.base_model(x)
        return self.alpha * y_pred + (1 - self.alpha) * y_coarse
    
class NeuralOperator(nn.Module):
    """Pure neural-operator baseline: run ``base_model`` on the input of a batch.

    ``forward`` takes the same ``(x, y, fidelity_data)`` triple as the hybrid
    wrapper. It ignores the target and the fidelity data.
    """

    def __init__(self, base_model: nn.Module):
        super().__init__()
        self.model = base_model

    def forward(self, batch):
        inputs, _target, _fidelity_data = batch
        return self.model(inputs)


def get_model(fidelity: str, hybrid: bool = False):
    """Build the base PDE model through ``utils.PDECLI`` and wrap it.

    The CLI is created with ``run=False`` (presumably it only parses the
    config and instantiates objects, without starting training). If the
    trainer has no ``default_root_dir``, it is set from ``PDEARENA_OUTPUT_DIR``
    (falling back to ``./outputs``). The output directory tree is then created.

    Args:
        fidelity: fidelity level passed to ``HybridModel``; used only when ``hybrid``.
        hybrid: wrap in ``HybridModel`` if True, otherwise in ``NeuralOperator``.

    Returns:
        ``(wrapped_model, cli.config)``.
    """
    cli = utils.PDECLI(
        PDEModel,
        datamodule_class=PDEDataModule,
        seed_everything_default=42,
        save_config_overwrite=True,
        run=False,
        parser_kwargs={"parser_mode": "omegaconf"},
    )

    trainer = cli.trainer
    if trainer.default_root_dir is None:
        logger.warning("No default root dir set, using: ")
        trainer.default_root_dir = os.environ.get("PDEARENA_OUTPUT_DIR", "./outputs")
        logger.warning(f"\t {trainer.default_root_dir}")
    setupdir(trainer.default_root_dir)

    wrapped = HybridModel(cli.model, fidelity=fidelity) if hybrid else NeuralOperator(cli.model)
    return wrapped, cli.config

def _make_ns2d_dataloader(data_path, subset, max_samples, batch_size):
    """Build a shuffled, drop-last DataLoader over one split of NavierStokes2DDataset."""
    dataset = NavierStokes2DDataset(folder_path=data_path, max_samples=max_samples, subset=subset)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=1,
        drop_last=True,
    )


def _evaluate_mse(model, dataloader, max_batches=6):
    """Return the mean per-batch MSE of ``model`` over at most ``max_batches`` batches.

    The loss is averaged over the batches actually processed, which fixes the
    former off-by-one divisor. If the loader is empty, ``float("nan")`` is
    returned instead of raising. The model is put in eval mode for the pass, so
    layers such as dropout or batch-norm (if present) use inference behaviour.
    Its previous mode is restored afterwards. ``max_batches`` should be >= 1.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    n_batches = 0
    try:
        with torch.no_grad():
            for batch in dataloader:
                _, y, _ = batch
                y_pred = model(batch)
                total_loss += torch.mean((y_pred - y) ** 2).item()
                n_batches += 1
                if n_batches >= max_batches:
                    break
    finally:
        model.train(was_training)
    return total_loss / n_batches if n_batches else float("nan")


def train(data_path: str, 
          fidelity: str, 
          batch_size: int = 32,
          learning_rate: float = 1e-4, 
          weight_decay: float = 1e-5,
          max_samples: int = 1000, 
          hybrid: bool = False): 
    """Train a HybridModel or NeuralOperator for one pass over the training split.

    Every training step's MSE is logged to W&B under ``train/loss``. When
    ``max_samples > 100``, every 50th step (starting at step 0) also does the following:
    - evaluates up to 6 validation batches, logging ``valid/loss``;
    - saves a ``state_dict`` checkpoint to ``checkpoints/``, named after the
      step and the mean validation loss;
    - evaluates up to 6 test batches, logging ``test/loss``.

    Args:
        data_path: dataset root passed to ``NavierStokes2DDataset``.
        fidelity: fidelity level for the hybrid model, also used in the run name.
        batch_size: batch size for all three splits.
        learning_rate, weight_decay: AdamW hyper-parameters.
        max_samples: sample cap for the training split (validation/test use 1000).
        hybrid: train ``HybridModel`` if True, else the ``NeuralOperator`` baseline.
    """
    model, config = get_model(fidelity, hybrid=hybrid) 
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    project_name = "pde-moe-project"
    if hybrid:
        name = f"hybrid-{config.model.name}-{fidelity}"
        ckpt_prefix = "HYBRID"
    else:
        name = f"baseline-{config.model.name}"
        ckpt_prefix = "BASELINE"
    wandb.init(entity="your-entity", project=project_name, name=name)

    train_dataloader = _make_ns2d_dataloader(data_path, "train", max_samples, batch_size)
    # Validation / test splits are capped at 1000 samples regardless of max_samples.
    valid_dataloader = _make_ns2d_dataloader(data_path, "valid", 1000, batch_size)
    test_dataloader = _make_ns2d_dataloader(data_path, "test", 1000, batch_size)

    evaluate_periodically = max_samples > 100
    for step, batch in enumerate(train_dataloader): 
        _, y, _ = batch
        y_pred = model(batch)
        loss = torch.mean((y_pred - y) ** 2) 
        wandb.log({"train/loss": loss.item()}, step=step)
        print("Step: ", step, "Loss: ", loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step() 

        if not (evaluate_periodically and step % 50 == 0):
            continue

        # Validation + checkpoint.
        val_loss = _evaluate_mse(model, valid_dataloader)
        wandb.log({"valid/loss": val_loss}, step=step)
        ckpt_file_name = f"{ckpt_prefix}_{config.model.name}_STEP_{step}_VAL_{val_loss}.pt"
        # torch.save does not create missing parent directories.
        os.makedirs("checkpoints", exist_ok=True)
        torch.save(model.state_dict(), os.path.join("checkpoints/", ckpt_file_name))

        # Test.
        test_loss = _evaluate_mse(model, test_dataloader)
        wandb.log({"test/loss": test_loss}, step=step)
    wandb.finish() 

if __name__ == "__main__":  
    # Script entry point. The model itself is configured by utils.PDECLI inside
    # get_model(); the values below configure the training run.
    # Do not change 
    # data_path = "/path/to/pdemoe_data" # for formal training
    # data_path = "/path/to/pdemoe_data_test"  # for large-scale testing
    # data_path = "/path/to/pdemoe_data_test_2"  # for large-scale testing
    data_path = "/path/to/pdemoe_data_sample" # for testing
    # data_path = "/path/to/pdemoe_data_test_20250316"
    # Cap on training samples; values > 100 also enable the periodic validation/test
    # evaluation and checkpointing in train().
    max_samples = 80000 # Expected: 80000
    batch_size = 32 # Original: 32
    # hybrid=True trains HybridModel, which blends the network output with the
    # `fidelity` solution from each batch; hybrid=False trains the pure
    # NeuralOperator baseline, and `fidelity` then only appears in get_model's call.
    hybrid = True           
    fidelity = "medium" 
    train(fidelity=fidelity, data_path=data_path, batch_size=batch_size, max_samples=max_samples, hybrid=hybrid)  
    print("Success.") 

