import argparse
import os
import time
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchvision.utils import save_image, make_grid
from easydict import EasyDict
import numpy as np
from tqdm import tqdm
import h5py

from models import get_flow_model
from models.constraints import (
    InitialCondition, BoundaryCondition, ChainConstraint, 
    NoneConstraint, PeriodicCondition
)
from datasets import get_dataset
from utils import load_config, seed_all
from vis_utils import draw
import matplotlib.pyplot as plt # Import matplotlib
import matplotlib as mpl
import csv
from datetime import datetime

def calculate_physics_loss(samples, ic_ref, dx, dy):
    """Compute per-sample initial-condition and mass-conservation errors.

    Args:
        samples: Array of shape ``(N, s, s, T)``: N samples on a square
            ``s x s`` spatial grid, with time on the last axis.
        ic_ref: Reference initial condition, broadcastable against
            ``samples[..., 0]`` (e.g. ``(N, s, s)`` or ``(s, s)``).
        dx, dy: Grid spacings; ``dx * dy`` is the area element used when
            integrating each frame to a total mass.

    Returns:
        Tuple ``(ic_mse_per_sample, mass_mse_per_sample)``, both of shape ``(N,)``:
        - mean squared error between frame 0 and ``ic_ref``;
        - mean squared drift of the integrated mass over frames ``1..T-1``
          relative to frame 0 (all zeros when ``T == 1``).

    Raises:
        ValueError: if ``samples`` is not 4-D or its spatial grid is not square.
    """
    samples = np.asarray(samples)
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if samples.ndim != 4:
        raise ValueError(f"samples must be 4-D (N, s, s, T); got shape {samples.shape}")
    N, s1, s2, T = samples.shape
    if s1 != s2:
        raise ValueError(f"samples must have a square spatial grid; got {s1}x{s2}")
    s = s1

    # 1) Initial-condition MSE.
    diff_ic = samples[..., 0] - ic_ref                    # (N, s, s)
    ic_mse_per_sample = np.mean(diff_ic**2, axis=(1, 2))  # (N,)

    # 2) Total mass per frame and its drift relative to frame 0.
    area_elem = dx * dy
    mass_t = samples.reshape(N, s * s, T).sum(axis=1) * area_elem  # (N, T)
    mass_err = mass_t - mass_t[:, [0]]                             # (N, T)

    # Frame 0 has zero drift by construction, so only later frames are averaged.
    if T > 1:
        mass_mse_per_sample = np.mean(mass_err[:, 1:]**2, axis=1)  # (N,)
    else:
        mass_mse_per_sample = np.zeros(N, dtype=samples.dtype)

    return ic_mse_per_sample, mass_mse_per_sample


def evaluate_and_print_results(generated_samples, real_samples, ic_ref):
    """
    Calculates and prints evaluation metrics.
    Adapted from cal_res_ns.py.

    Args:
        generated_samples: Array of shape ``(N, nx, nx, T)``.
        real_samples: Ground-truth array with the same leading dimensions.
            If its time axis has a different length, both arrays are
            truncated to the shorter one before comparison.
        ic_ref: Reference initial condition passed to
            ``calculate_physics_loss`` for both sample sets.

    Returns:
        Dict of plain-Python scalars: mean/std MSE vs. ground truth, mean/std
        of both ``calculate_physics_loss`` outputs for generated and real
        samples, plus ``n_samples_compared``, ``nx`` and ``T``.

    NOTE(review): the "PDE Residual" labels and ``*_pde_violation_*`` keys
    report the second value returned by ``calculate_physics_loss`` (a
    mass-conservation MSE); the names are kept for CSV compatibility.
    """
    print("\n" + "="*50)
    print(" " * 15 + "EVALUATION RESULTS")
    print("="*50)

    # --- Define Grids & Ensure Consistency ---
    _, nx, ny, T_gen = generated_samples.shape
    T_real = real_samples.shape[-1]
    if T_real != T_gen:
        # Compare only the time steps present in both arrays; truncating just
        # one side cannot fix a mismatch where the real data is shorter.
        truncated = "real" if T_real > T_gen else "generated"
        print(f"Warning: Time dimension mismatch. Generated T={T_gen}, Real T={T_real}. Truncating {truncated} samples.")
        T_gen = min(T_gen, T_real)
        real_samples = real_samples[..., :T_gen]
        generated_samples = generated_samples[..., :T_gen]

    # Grid spacing assumes a domain of unit length in each spatial direction.
    dx = 1.0 / nx
    dy = 1.0 / ny

    print(f"Number of samples compared: {generated_samples.shape[0]}\n")

    # --- Calculate Metrics ---
    # 1. MSE vs. Real Samples
    mse_per_sample = np.mean((real_samples - generated_samples)**2, axis=(1, 2, 3))
    mean_mse = float(np.mean(mse_per_sample))
    std_mse = float(np.std(mse_per_sample))

    # 2. Constraint Violations for Generated Samples
    ic_loss_gen, pde_loss_gen = calculate_physics_loss(generated_samples, ic_ref, dx, dy)

    # 3. Constraint Violations for Real Samples (for reference)
    ic_loss_real, pde_loss_real = calculate_physics_loss(real_samples, ic_ref, dx, dy)

    # --- Print Results ---
    print("--- Error vs. Ground Truth ---")
    print(f"Mean of Mean Squared Error (MSE): {mean_mse:.6e}")
    print(f"Std. of Mean Squared Error (MSE): {std_mse:.6e}\n")

    print("--- Constraint Violations (Generated Samples) ---")
    print(f"Mean Initial Condition Violation: {np.mean(ic_loss_gen):.6e}")
    print(f"Std. Initial Condition Violation: {np.std(ic_loss_gen):.6e}")
    print(f"Mean PDE Residual Violation:      {np.mean(pde_loss_gen):.6e}")
    print(f"Std. PDE Residual Violation:      {np.std(pde_loss_gen):.6e}\n")

    print("--- Constraint Violations (Real Samples for Reference) ---")
    print(f"Mean Initial Condition Violation: {np.mean(ic_loss_real):.6e}")
    print(f"Std. Initial Condition Violation: {np.std(ic_loss_real):.6e}")
    print(f"Mean PDE Residual Violation:      {np.mean(pde_loss_real):.6e}")
    print(f"Std. PDE Residual Violation:      {np.std(pde_loss_real):.6e}")
    print("="*50)

    # --- Collect metrics for saving (all plain Python scalars for the CSV) ---
    metrics = {
        "mean_mse": mean_mse,
        "std_mse": std_mse,
        "gen_ic_violation_mean": float(np.mean(ic_loss_gen)),
        "gen_ic_violation_std": float(np.std(ic_loss_gen)),
        "gen_pde_violation_mean": float(np.mean(pde_loss_gen)),
        "gen_pde_violation_std": float(np.std(pde_loss_gen)),
        "real_ic_violation_mean": float(np.mean(ic_loss_real)),
        "real_ic_violation_std": float(np.std(ic_loss_real)),
        "real_pde_violation_mean": float(np.mean(pde_loss_real)),
        "real_pde_violation_std": float(np.std(pde_loss_real)),
        "n_samples_compared": int(generated_samples.shape[0]),
        "nx": int(nx),
        "T": int(T_gen),
    }
    return metrics

def _append_row_to_csv(csv_path, row_dict):
    """
    Append a single row (dict) to a CSV file.
    If file does not exist, write header first.
    """

    fieldnames = list(row_dict.keys())
    file_exists = os.path.isfile(csv_path)
    with open(csv_path, mode='a', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if not file_exists:
            writer.writeheader()
        writer.writerow(row_dict)

class SamplingConfig:
    """Mutable container of sampling hyper-parameters, pre-filled with defaults.

    Callers overwrite these attributes after construction and may attach
    further ones (for example ``batch_sample``, ``alpha_exp`` or ``vis``).
    """

    # (attribute name, default value) pairs, applied in order by __init__.
    _DEFAULTS = (
        # Sampler settings.
        ("n_sample", 64),
        ("n_step", 200),
        ("n_mix", 3),
        ("resample_step", 10),
        ("n_eval", 1),
        # Preview-grid settings.
        ("n_vis_samples", 16),
        ("grid_nrow", 4),
        # Not read anywhere in the visible code.
        ("constraint_strength", 1.0),
    )

    def __init__(self):
        for attr_name, default_value in self._DEFAULTS:
            setattr(self, attr_name, default_value)


def create_constraint(constraint_type, test_data, device, config=None):
    """Build the constraint object selected by ``constraint_type``.

    Args:
        constraint_type: One of ``'none'``, ``'ic'``, ``'bc'``, ``'icbc'``,
            ``'periodic'`` or ``'custom'``.
        test_data: Reference tensor from which constraint values are sliced.
            NOTE(review): the initial condition is taken as index 0 along
            axis 2 and the boundary as index 0 along axis 1. Confirm this
            matches the dataset layout.
        device: Device the constraint is moved to (not applied for ``'none'``).
        config: Unused; accepted for interface compatibility.

    Returns:
        The constraint object.

    Raises:
        ValueError: if ``constraint_type`` is not recognised.
    """
    if constraint_type == 'none':
        return NoneConstraint()

    elif constraint_type == 'ic':
        return InitialCondition(test_data[:, :, 0]).to(device)

    elif constraint_type == 'bc':
        # Only the first reference sample supplies the boundary values.
        return BoundaryCondition(test_data[:1, 0, :]).to(device)

    elif constraint_type == 'icbc':
        # Both constraints come from the first reference sample only.
        ic_constraint = InitialCondition(test_data[:1, :, 0])
        bc_constraint = BoundaryCondition(test_data[:1, 0, :])
        return ChainConstraint(ic_constraint, bc_constraint).to(device)

    elif constraint_type == 'periodic':
        # Number of non-leading axes of test_data.
        ndim = len(test_data.shape) - 1
        return PeriodicCondition(ndim, dims=(1,)).to(device)

    elif constraint_type == 'custom':

        class CenterFixConstraint(torch.nn.Module):
            """Overwrite a centred ``h x w`` window of the last two axes with a fixed value."""

            def __init__(self, center_value, center_region):
                super().__init__()
                self.center_value = center_value
                self.center_region = center_region

            def adjust(self, x1):
                # Modifies x1 in place and also returns it.
                # NOTE(review): indexes axes 1 and 2 but centres on the last
                # two axes; these coincide only for 3-D inputs. Confirm the
                # intended input rank.
                h, w = self.center_region
                center_h, center_w = x1.shape[-2] // 2, x1.shape[-1] // 2
                x1[:, center_h-h//2:center_h+h//2, center_w-w//2:center_w+w//2] = self.center_value
                return x1

            def to(self, device):
                # center_value is a plain attribute (not a registered buffer),
                # so it has to be moved explicitly.
                self.center_value = self.center_value.to(device)
                return self

        # Fixed 20x20 window: reference mean over indices 40:60 of axes 1 and 2.
        center_value = test_data[:1, 40:60, 40:60].mean()
        return CenterFixConstraint(center_value, (20, 20)).to(device)

    else:
        raise ValueError(f"Unknown constraint type: {constraint_type!r}")


def sample_unconstrained(model, sampling_config, dims, device):
    """Draw unconstrained samples via ``model.sample`` without gradient tracking.

    ``n_sample`` and ``n_eval`` are read from ``sampling_config``. The
    elapsed wall-clock time is printed, and the model's output is returned
    unchanged.
    """
    t0 = time.time()
    sample_kwargs = dict(
        n_sample=sampling_config.n_sample,
        n_eval=sampling_config.n_eval,
        dims=dims,
        device=device,
    )
    with torch.no_grad():
        samples = model.sample(**sample_kwargs)
    print(f" {time.time() - t0:.2f}")
    return samples

def sample_eci_constrained(model, test_data, sampling_config, dims, device, projection_info):
    """Run ECI sampling for one batch under an initial-condition constraint.

    The constraint is ``InitialCondition(test_data[:, :, 0])`` moved to
    ``device``. The batch size is read from ``sampling_config.batch_sample``.
    Sampling runs without gradient tracking. The elapsed wall-clock time is
    printed, and ``model.eci_sample_ns``'s output is returned.
    """
    t0 = time.time()

    ic_constraint = InitialCondition(test_data[:, :, 0]).to(device)
    eci_kwargs = {
        "n_sample": sampling_config.batch_sample,
        "n_step": sampling_config.n_step,
        "n_mix": sampling_config.n_mix,
        "resample_step": sampling_config.resample_step,
        "dims": dims,
        "device": device,
        "constraint": ic_constraint,
        "projection_info": projection_info,
    }
    with torch.no_grad():
        samples = model.eci_sample_ns(**eci_kwargs)

    print(f"{time.time() - t0:.2f}s")
    return samples

def sample_pcfm_constrained(model, test_data, sampling_config, dims, device, projection_info, config):
    """Run PCFM sampling for one batch without gradient tracking.

    The batch size is read from ``sampling_config.batch_sample``.
    ``test_data`` and ``config`` are accepted for signature parity with the
    other samplers but are not used. The elapsed wall-clock time is printed.
    """
    t0 = time.time()
    pcfm_kwargs = {
        "n_sample": sampling_config.batch_sample,
        "n_step": sampling_config.n_step,
        "n_mix": sampling_config.n_mix,
        "resample_step": sampling_config.resample_step,
        "dims": dims,
        "device": device,
        "projection_info": projection_info,
    }
    with torch.no_grad():
        samples = model.pcfm_sample_ns(**pcfm_kwargs)
    print(f"{time.time() - t0:.2f}")
    return samples

def sample_dpde_constrained(model, test_data, sampling_config, dims, device, projection_info, config):
    """Run guided (DPDE) sampling for one batch and print the elapsed time.

    Unlike the other samplers in this file, this call is not wrapped in
    ``torch.no_grad()``.
    NOTE(review): presumably ``guided_sample_ns`` needs autograd for its
    guidance step. Confirm before adding ``no_grad``.

    ``eta`` and ``bound_tolerance`` are read from ``sampling_config`` when
    present; otherwise they default to ``200`` and ``1e-13``. The batch size
    is read from ``sampling_config.batch_sample``. ``test_data`` and
    ``config`` are accepted for signature parity but are not used.
    """
    start_time = time.time()

    samples = model.guided_sample_ns(
        n_sample=sampling_config.batch_sample,
        n_step=sampling_config.n_step,
        dims=dims,
        device=device,
        projection_info=projection_info,
        eta=getattr(sampling_config, "eta", 200),
        bound_tolerance=getattr(sampling_config, "bound_tolerance", 1e-13),
    )

    # Report timing in the same format as the ECI/CCD samplers.
    elapsed = time.time() - start_time
    print(f"{elapsed:.2f}s")
    return samples

def sample_ccd_constrained(model, test_data, sampling_config, dims, device, projection_info, config):
    """Run CCD sampling for one batch without gradient tracking.

    The batch size and ``alpha_exp`` are read from ``sampling_config``.
    ``test_data`` and ``config`` are accepted for signature parity but are
    not used. The elapsed wall-clock time is printed.
    """
    t0 = time.time()
    ccd_kwargs = dict(
        n_sample=sampling_config.batch_sample,
        n_step=sampling_config.n_step,
        dims=dims,
        device=device,
        projection_info=projection_info,
        alpha_exp=sampling_config.alpha_exp,
    )
    with torch.no_grad():
        samples = model.ccd_sample_ns(**ccd_kwargs)
    elapsed = time.time() - t0
    print(f"{elapsed:.2f}s")
    return samples


def save_samples(samples, output_dir, config, constraint_type="unconstrained"):
    """Save samples, a preview image grid, and summary statistics.

    Files are written into ``output_dir`` (created if needed) with a prefix
    derived from ``constraint_type`` (spaces replaced by underscores,
    lower-cased):

    - ``<prefix>_samples.pt``: the full ``samples`` tensor, moved to CPU;
    - ``<prefix>_samples_grid.png``: the first ``config.n_vis_samples``
      samples rendered with ``draw`` and tiled ``config.grid_nrow`` per row
      (skipped when there is nothing to draw);
    - ``<prefix>_stats.txt``: shape, mean, std, min and max of ``samples``.

    Args:
        samples: Tensor whose first axis indexes samples.
        output_dir: Destination directory.
        config: Object with ``n_vis_samples``, ``grid_nrow`` and, optionally,
            a truthy ``vis`` mapping of keyword arguments forwarded to ``draw``
            (otherwise ``vmin=0.0, vmax=1.0`` are used).
        constraint_type: Label used for the filenames and the stats header.
    """
    os.makedirs(output_dir, exist_ok=True)

    filename_prefix = constraint_type.replace(" ", "_").lower()
    torch.save(samples.cpu(), os.path.join(output_dir, f'{filename_prefix}_samples.pt'))

    # Rendering options are the same for every sample, so resolve them once.
    if getattr(config, 'vis', None):
        draw_kwargs = config.vis
    else:
        draw_kwargs = {'vmin': 0.0, 'vmax': 1.0}

    n_vis = min(config.n_vis_samples, samples.shape[0])
    # make_grid needs at least one image; skip the preview when there is none.
    if n_vis > 0:
        img_tensors = [draw(samples[i], **draw_kwargs) for i in range(n_vis)]
        grid_img = make_grid(img_tensors, nrow=config.grid_nrow, padding=2, normalize=False)
        image_path = os.path.join(output_dir, f'{filename_prefix}_samples_grid.png')
        save_image(grid_img, image_path)

    stats = {
        'mean': samples.mean().item(),
        'std': samples.std().item(),
        'min': samples.min().item(),
        'max': samples.max().item(),
        'shape': list(samples.shape)
    }

    stats_path = os.path.join(output_dir, f'{filename_prefix}_stats.txt')
    with open(stats_path, 'w', encoding='utf-8') as f:
        # Every entry, including the header, is on its own line.
        f.write(f"Sampling Statistics ({constraint_type}):\n")
        f.write(f"Shape: {stats['shape']}\n")
        f.write(f"Mean: {stats['mean']:.6f}\n")
        f.write(f"Standard Deviation: {stats['std']:.6f}\n")
        f.write(f"Min: {stats['min']:.6f}\n")
        f.write(f"Max: {stats['max']:.6f}\n")

    print(f"Statistics saved to: {stats_path}")


def load_h5_test_data(config, device):
    """Load every non-scalar top-level dataset of the test HDF5 file as float32 tensors.

    The file path is ``os.path.join(config.datasets.root,
    config.datasets.test['data_file'])``. Each top-level dataset with at
    least one dimension is read in full, converted to ``float32`` and moved
    to ``device``. Scalar (0-d) datasets and non-dataset entries such as
    groups are skipped.

    Returns:
        The tensor itself if exactly one dataset was loaded; otherwise a dict
        mapping dataset name -> tensor (possibly empty).
    """
    h5_file_path = os.path.join(config.datasets.root, config.datasets.test['data_file'])

    all_data = {}
    with h5py.File(h5_file_path, 'r') as h5_file:
        for attr_name, node in h5_file.items():
            # Groups have no ``ndim``; 0-d datasets hold scalars, not per-sample arrays.
            if not isinstance(node, h5py.Dataset) or node.ndim == 0:
                continue
            # Converting in numpy also normalises non-native byte order,
            # which torch.from_numpy rejects.
            attr_data = np.asarray(node[:], dtype=np.float32)
            print(f"'{attr_name}': {attr_data.shape}")
            all_data[attr_name] = torch.from_numpy(attr_data).to(device)

    # The data is fully copied into memory above, so the file can be closed first.
    if len(all_data) == 1:
        return next(iter(all_data.values()))
    return all_data

def main(sampling_method, override_n_mix=None, override_alpha_exp=None):
    """Sample from a trained flow model, save the results, and log a CSV row.

    Command-line arguments are parsed from ``sys.argv``. ``sampling_method``
    only supplies the default for ``--sampling_method``. The ``override_*``
    arguments, when not ``None``, replace the parsed ``--n_mix`` /
    ``--alpha_exp`` values so the script can sweep them.

    With ``--constraint_type none``, unconstrained samples are drawn and
    saved. Otherwise the test set (capped at ``--n_sample`` items) is
    sampled batch by batch with the chosen constrained sampler. The
    ``solution`` and ``initial_condition`` datasets of the HDF5 test file
    serve as projection info; the samples are saved and evaluated against
    those solutions. In both cases one row of run parameters (plus metrics,
    when computed) is appended to ``<output_dir>/metrics.csv``.

    NOTE(review): every constraint type other than ``'none'`` takes the same
    code path; ``--constraint_type`` only changes labels and filenames here.

    Raises:
        ValueError: if CUDA is requested but unavailable, if the HDF5 test
            data lacks the required datasets or does not match the test-set
            size, or if ``sampling_method`` is not a known sampler.
        KeyError: if the checkpoint contains no recognisable model weights.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--checkpoint', type=str, default='/path/to/eci/ns/2025_09_12/epoch=1871-step=196500.ckpt')
    parser.add_argument('--config', type=str, default='configs/ns.yml')
    parser.add_argument('--output_dir', type=str, default='/path/to/eci/ns/results')
    parser.add_argument('--device', type=str, default='cuda',
                        help='')

    # Sampling options.
    parser.add_argument('--sampling_method', type=str, default=sampling_method,
                        choices=['eci', 'pcfm', 'ccd', 'dpde'],
                        help='')
    parser.add_argument('--constraint_type', type=str, default='ic',
                        choices=['none', 'ic', 'bc', 'icbc', 'periodic', 'custom'],
                        help='')
    parser.add_argument('--n_sample', type=int, default=1000,
                        help='')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='')
    parser.add_argument('--n_step', type=int, default=200,
                        help='')
    parser.add_argument('--n_mix', type=int, default=1,
                        help='')
    parser.add_argument('--resample_step', type=int, default=0,
                        help='')
    parser.add_argument('--n_eval', type=int, default=1,
                        help='')
    # float (not int) so fractional exponents can also be given on the command line.
    parser.add_argument('--alpha_exp', type=float, default=1,
                        help='')

    # Visualisation options.
    parser.add_argument('--n_vis_samples', type=int, default=16,
                        help='')
    parser.add_argument('--grid_nrow', type=int, default=4,
                        help='')

    parser.add_argument('--seed', type=int, default=0,
                        help='')

    args = parser.parse_args()

    if override_n_mix is not None:
        args.n_mix = override_n_mix
    if override_alpha_exp is not None:
        args.alpha_exp = override_alpha_exp

    seed_all(args.seed)

    if args.device == 'cuda' and not torch.cuda.is_available():
        raise ValueError("CUDA was requested (--device cuda) but is not available; use --device cpu.")
    device = torch.device(args.device)

    ckpt = torch.load(args.checkpoint, map_location=device)

    # An explicit config file takes precedence over the config stored in the checkpoint.
    if args.config:
        config = load_config(args.config)
    else:
        config = EasyDict(ckpt['config'])

    print(config)

    sampling_config = SamplingConfig()
    sampling_config.n_sample = args.n_sample
    sampling_config.n_step = args.n_step
    sampling_config.n_mix = args.n_mix
    # 0 on the command line means "no resampling".
    sampling_config.resample_step = args.resample_step if args.resample_step > 0 else None
    sampling_config.n_eval = args.n_eval
    sampling_config.n_vis_samples = args.n_vis_samples
    sampling_config.grid_nrow = args.grid_nrow
    sampling_config.alpha_exp = args.alpha_exp

    if hasattr(config, 'vis'):
        sampling_config.vis = config.vis

    model = get_flow_model(config.model, config.encoder).to(device)

    # Accept either a 'state_dict' (keys possibly prefixed with 'model.') or a 'model' dict.
    if isinstance(ckpt, dict) and 'state_dict' in ckpt:
        state_dict = {
            (k[6:] if k.startswith('model.') else k): v
            for k, v in ckpt['state_dict'].items()
        }
    elif isinstance(ckpt, dict) and isinstance(ckpt.get('model'), dict):
        state_dict = ckpt['model']
    else:
        available = list(ckpt.keys()) if isinstance(ckpt, dict) else type(ckpt)
        raise KeyError(f"Checkpoint has no 'state_dict' or 'model' weights; available: {available}")
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print(f"[load_state_dict] missing keys: {missing}")
    if unexpected:
        print(f"[load_state_dict] unexpected keys: {unexpected}")
    model.eval()

    dims = tuple(config.sample_dims) if hasattr(config, 'sample_dims') else (64, 64, 50)

    if args.constraint_type == 'none':
        output_dir = os.path.join(args.output_dir, "unconstrained")
    else:
        output_dir = os.path.join(
            args.output_dir,
            f"{args.sampling_method}_sample_nmix_{args.n_mix}_alpha_exp_{args.alpha_exp}"
        )
    os.makedirs(output_dir, exist_ok=True)

    # Metrics are only computed for constrained runs; an empty dict keeps the
    # CSV row below valid for unconstrained runs.
    metrics = {}

    start_time = time.time()
    if args.constraint_type == 'none':
        samples = sample_unconstrained(model, sampling_config, dims, device)
        save_samples(samples, output_dir, sampling_config, constraint_type="unconstrained")

    else:
        # argparse does not check defaults against ``choices``, so validate here
        # before any expensive work.
        if args.sampling_method not in ('eci', 'pcfm', 'ccd', 'dpde'):
            raise ValueError(f"Unknown sampling method: {args.sampling_method!r}")

        _, full_test_set = get_dataset(config.datasets)

        # Evaluate on at most --n_sample test items, in dataset order.
        limit = min(args.n_sample, len(full_test_set))
        if limit < len(full_test_set):
            from torch.utils.data import Subset
            test_set = Subset(full_test_set, list(range(limit)))
        else:
            test_set = full_test_set

        test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False)

        projection_info_full = load_h5_test_data(config, device)
        required_keys = ('solution', 'initial_condition')
        if not isinstance(projection_info_full, dict) or not all(k in projection_info_full for k in required_keys):
            found = (list(projection_info_full.keys()) if isinstance(projection_info_full, dict)
                     else type(projection_info_full).__name__)
            raise ValueError(f"HDF5 test data must contain datasets {list(required_keys)}; found {found}")

        # Truncate per-sample arrays to the same `limit` as the test set.
        for key, tensor_val in list(projection_info_full.items()):
            if hasattr(tensor_val, 'shape') and tensor_val.shape[0] >= limit:
                projection_info_full[key] = tensor_val[:limit]

        n_solutions = projection_info_full['solution'].shape[0]
        if len(test_set) != n_solutions:
            raise ValueError(
                f"Test set has {len(test_set)} samples but the HDF5 'solution' dataset has {n_solutions}."
            )

        projection_dataset = TensorDataset(
            projection_info_full['solution'],
            projection_info_full['initial_condition'],
        )
        projection_loader = DataLoader(projection_dataset, batch_size=args.batch_size, shuffle=False)

        all_samples = []
        all_ground_truth_sols = []
        all_ground_truth_ics = []

        # Both loaders are unshuffled, have the same length and the same batch
        # size, so zip pairs each test batch with its projection-info batch.
        pbar_desc = f"{args.sampling_method.upper()}, {args.constraint_type.upper()}"
        pbar = tqdm(zip(test_loader, projection_loader), total=len(test_loader), desc=pbar_desc)

        for test_data_batch, (solution_batch, ic_batch) in pbar:
            test_data_batch = test_data_batch.to(device)

            projection_info_batch = {
                'solution': solution_batch,
                'initial_condition': ic_batch,
            }

            # The last batch may be smaller than --batch_size.
            sampling_config.batch_sample = test_data_batch.shape[0]

            if args.sampling_method == 'eci':
                samples_batch = sample_eci_constrained(
                    model, test_data_batch, sampling_config, dims, device, projection_info_batch
                )
            elif args.sampling_method == 'pcfm':
                samples_batch = sample_pcfm_constrained(
                    model, test_data_batch, sampling_config, dims, device, projection_info_batch, config
                )
            elif args.sampling_method == 'ccd':
                samples_batch = sample_ccd_constrained(
                    model, test_data_batch, sampling_config, dims, device, projection_info_batch, config
                )
            else:  # 'dpde' (the method was validated above)
                samples_batch = sample_dpde_constrained(
                    model, test_data_batch, sampling_config, dims, device, projection_info_batch, config
                )

            all_samples.append(samples_batch.cpu())
            all_ground_truth_sols.append(projection_info_batch['solution'].cpu())
            all_ground_truth_ics.append(projection_info_batch['initial_condition'].cpu())

        samples = torch.cat(all_samples, dim=0)
        save_samples(samples, output_dir, sampling_config, constraint_type=f"{args.sampling_method}_{args.constraint_type}")

        ground_truth_sols = torch.cat(all_ground_truth_sols, dim=0)
        ground_truth_ics = torch.cat(all_ground_truth_ics, dim=0)

        metrics = evaluate_and_print_results(
            generated_samples=samples.numpy(),
            real_samples=ground_truth_sols.numpy(),
            ic_ref=ground_truth_ics.numpy()
        )

    end_time = time.time()
    total_sec = float(end_time - start_time)

    run_params = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "checkpoint": str(args.checkpoint),
        "config_path": str(args.config),
        "output_dir": str(output_dir),
        "device": str(args.device),
        "sampling_method": str(args.sampling_method),
        "constraint_type": str(args.constraint_type),
        "n_sample": int(args.n_sample),
        "batch_size": int(args.batch_size),
        "n_step": int(args.n_step),
        "n_mix": int(args.n_mix),
        "resample_step": int(args.resample_step),
        "n_eval": int(args.n_eval),
        "alpha_exp": float(args.alpha_exp),
        "seed": int(args.seed),
        "dims_h": int(dims[0]),
        "dims_w": int(dims[1]),
        "dims_T": int(dims[2]),
        "total_time_sec": total_sec,
    }

    row = {**run_params, **metrics}
    csv_path = os.path.join(output_dir, "metrics.csv")
    _append_row_to_csv(csv_path, row)



if __name__ == '__main__':
    # Sweep over every (sampler, CCD exponent) pair, one full run per pair.
    # Other samplers accepted by main(): 'eci', 'pcfm', 'dpde'.
    sampling_methods = ['ccd']
    alpha_exp_list = [0.5, 0.7, 0.9]
    sweep = [(method, alpha) for method in sampling_methods for alpha in alpha_exp_list]
    for sampling_method, alpha_exp in sweep:
        main(sampling_method, override_n_mix=None, override_alpha_exp=alpha_exp)
