
import argparse
import os
import time
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchvision.utils import save_image, make_grid
from easydict import EasyDict
import numpy as np
from tqdm import tqdm
import h5py
import csv
from datetime import datetime


from models import get_flow_model
from models.constraints import (
    InitialCondition, BoundaryCondition, ChainConstraint, 
    NoneConstraint, PeriodicCondition
)
from datasets import get_dataset
from utils import load_config, seed_all
from vis_utils import draw


class SamplingConfig:
    """Mutable container of sampling and visualisation settings.

    Instances start with the defaults below; callers overwrite attributes and
    may attach further ones afterwards (``main`` in this file adds
    ``alpha_exp``, ``batch_sample`` and, when the config has it, ``vis``).
    """

    def __init__(self):
        defaults = {
            # Sampler settings.
            'n_sample': 64,
            'n_step': 200,
            'n_mix': 3,
            'resample_step': 10,
            'n_eval': 1,
            # Preview-grid settings read by save_samples.
            'n_vis_samples': 16,
            'grid_nrow': 4,
            # Not read anywhere in the visible part of this file.
            'constraint_strength': 1.0,
        }
        for attr_name, attr_value in defaults.items():
            setattr(self, attr_name, attr_value)


def create_constraint(constraint_type, test_data, device, config=None):
    """Build the constraint object named by ``constraint_type``.

    Args:
        constraint_type: One of ``'none'``, ``'ic'``, ``'bc'``, ``'icbc'``,
            ``'periodic'`` or ``'custom'``.
        test_data: Reference data the constraint values are sliced from.
            The slicing below indexes three axes, so a 3-D array-like is
            presumably expected — TODO confirm the axis meaning.
        device: Target passed to the constraint's ``.to(...)``.
        config: Accepted for interface compatibility; not read here.

    Returns:
        The constraint object (``NoneConstraint`` is returned without a
        device move; every other kind is moved with ``.to(device)``).

    Raises:
        ValueError: If ``constraint_type`` is not one of the names above.
    """
    if constraint_type == 'none':
        return NoneConstraint()

    if constraint_type == 'ic':
        # NOTE(review): this branch keeps every sample ([:, :, 0]) while
        # 'icbc' below keeps only the first ([:1, :, 0]) — confirm intent.
        first_slice = test_data[:, :, 0]
        return InitialCondition(first_slice).to(device)

    if constraint_type == 'bc':
        edge_slice = test_data[:1, 0, :]
        return BoundaryCondition(edge_slice).to(device)

    if constraint_type == 'icbc':
        chained = ChainConstraint(
            InitialCondition(test_data[:1, :, 0]),
            BoundaryCondition(test_data[:1, 0, :]),
        )
        return chained.to(device)

    if constraint_type == 'periodic':
        n_axes = len(test_data.shape) - 1
        return PeriodicCondition(n_axes, dims=(1,)).to(device)

    if constraint_type == 'custom':

        class CenterFixConstraint(torch.nn.Module):
            """Overwrite a centred window of the last two axes with a value."""

            def __init__(self, center_value, center_region):
                super().__init__()
                self.center_value = center_value
                self.center_region = center_region

            def adjust(self, x1):
                # Writes into x1 in place and returns the same object.
                region_h, region_w = self.center_region
                mid_h = x1.shape[-2] // 2
                mid_w = x1.shape[-1] // 2
                rows = slice(mid_h - region_h // 2, mid_h + region_h // 2)
                cols = slice(mid_w - region_w // 2, mid_w + region_w // 2)
                x1[:, rows, cols] = self.center_value
                return x1

            def to(self, device):
                # center_value is a plain attribute, so it is moved by hand.
                self.center_value = self.center_value.to(device)
                return self

        # Mean of the first sample over the fixed 40:60 x 40:60 window.
        window_mean = test_data[:1, 40:60, 40:60].mean()
        return CenterFixConstraint(window_mean, (20, 20)).to(device)

    raise ValueError(f"{constraint_type}")


def sample_unconstrained(model, sampling_config, dims, device):
    """Draw samples from ``model`` without any constraint.

    Args:
        model: Object exposing ``sample(n_sample, n_eval, dims, device)``.
        sampling_config: Provides ``n_sample`` and ``n_eval``.
        dims: Forwarded unchanged as the ``dims`` keyword.
        device: Forwarded unchanged as the ``device`` keyword.

    Returns:
        Whatever ``model.sample`` returns.
    """
    # Gradient tracking is disabled for the whole sampling call.
    with torch.no_grad():
        return model.sample(
            n_sample=sampling_config.n_sample,
            n_eval=sampling_config.n_eval,
            dims=dims,
            device=device,
        )

def sample_dpde_constrained(model, test_data, sampling_config, dims, device, projection_info, config):
    """Run the model's guided (``dpde``) sampler for one batch.

    Args:
        model: Object exposing ``guided_sample_rd``.
        test_data: Unused; kept so the signature stays stable for callers.
        sampling_config: Provides ``batch_sample`` and ``n_step``; optional
            ``eta`` and ``bound_tolerance`` attributes override the defaults
            of ``200`` and ``1e-13``.
        dims: Forwarded unchanged as the ``dims`` keyword.
        device: Forwarded unchanged as the ``device`` keyword.
        projection_info: Forwarded unchanged to the sampler.
        config: Unused; kept so the signature stays stable for callers.

    Returns:
        Whatever ``model.guided_sample_rd`` returns.
    """
    # NOTE(review): unlike the other samplers in this file, this call is not
    # wrapped in torch.no_grad() — presumably the guidance needs gradients;
    # confirm before changing.
    return model.guided_sample_rd(
        n_sample=sampling_config.batch_sample,
        n_step=sampling_config.n_step,
        dims=dims,
        device=device,
        projection_info=projection_info,
        eta=getattr(sampling_config, "eta", 200),
        bound_tolerance=getattr(sampling_config, "bound_tolerance", 1e-13),
    )

def sample_eci_constrained(model, test_data, sampling_config, dims, device, projection_info):
    """Run the model's ECI sampler for one batch and print the elapsed time.

    An ``InitialCondition`` constraint is built from ``test_data[:, :, 0]``
    and handed to ``model.eci_sample_rd`` together with the step, mixing and
    resampling settings from ``sampling_config``.

    Returns:
        Whatever ``model.eci_sample_rd`` returns.
    """
    t_begin = time.time()

    # The constraint is always an initial condition, whatever constraint
    # type was selected on the command line.
    ic_constraint = InitialCondition(test_data[:, :, 0]).to(device)

    sampler_kwargs = dict(
        n_sample=sampling_config.batch_sample,
        n_step=sampling_config.n_step,
        n_mix=sampling_config.n_mix,
        resample_step=sampling_config.resample_step,
        dims=dims,
        device=device,
        constraint=ic_constraint,
        projection_info=projection_info,
    )
    with torch.no_grad():
        samples = model.eci_sample_rd(**sampler_kwargs)

    elapsed = time.time() - t_begin
    print(f" {elapsed:.2f}")
    return samples

def sample_pcfm_constrained(model, test_data, sampling_config, dims, device, projection_info, config):
    """Run the model's PCFM sampler for one batch.

    Args:
        model: Object exposing ``pcfm_sample_rd``.
        test_data: Unused; kept so the signature stays stable for callers.
        sampling_config: Provides ``batch_sample``, ``n_step``, ``n_mix`` and
            ``resample_step``.
        dims: Forwarded unchanged as the ``dims`` keyword.
        device: Forwarded unchanged as the ``device`` keyword.
        projection_info: Forwarded unchanged to the sampler.
        config: Unused; kept so the signature stays stable for callers.

    Returns:
        Whatever ``model.pcfm_sample_rd`` returns.
    """
    # Gradient tracking is disabled for the whole sampling call.
    with torch.no_grad():
        return model.pcfm_sample_rd(
            n_sample=sampling_config.batch_sample,
            n_step=sampling_config.n_step,
            n_mix=sampling_config.n_mix,
            resample_step=sampling_config.resample_step,
            dims=dims,
            device=device,
            projection_info=projection_info,
        )
    
def sample_ccd_constrained(model, test_data, constraint_type, sampling_config, dims, device, projection_info, config):
    """Run the model's CCD sampler for one batch.

    Args:
        model: Object exposing ``ccd_sample_rd``.
        test_data: Unused; kept so the signature stays stable for callers.
        constraint_type: Unused; kept so the signature stays stable for
            callers.
        sampling_config: Provides ``batch_sample``, ``n_step`` and
            ``alpha_exp``.
        dims: Forwarded unchanged as the ``dims`` keyword.
        device: Forwarded unchanged as the ``device`` keyword.
        projection_info: Forwarded unchanged to the sampler.
        config: Unused; kept so the signature stays stable for callers.

    Returns:
        Whatever ``model.ccd_sample_rd`` returns.
    """
    # A constraint object used to be built here on every batch, but it was
    # never handed to ccd_sample_rd; the sampler only receives
    # projection_info, so that wasted construction has been dropped.
    with torch.no_grad():
        return model.ccd_sample_rd(
            n_sample=sampling_config.batch_sample,
            n_step=sampling_config.n_step,
            dims=dims,
            device=device,
            projection_info=projection_info,
            alpha_exp=sampling_config.alpha_exp,
        )


def save_samples(samples, output_dir, config, constraint_type="unconstrained"):
    """Write samples, a preview grid image and summary statistics to disk.

    Three files are created in ``output_dir``, all prefixed with
    ``constraint_type`` lower-cased and with spaces replaced by underscores:
    ``<prefix>_samples.pt``, ``<prefix>_samples_grid.png`` and
    ``<prefix>_stats.txt``.

    Args:
        samples: Tensor of samples; the first axis indexes samples.
        output_dir: Directory to write into; created if missing.
        config: Provides ``n_vis_samples`` and ``grid_nrow``; an optional,
            non-empty ``vis`` mapping is passed to ``draw`` as keyword
            arguments.
        constraint_type: Label used for the file prefix and the stats header.
    """
    os.makedirs(output_dir, exist_ok=True)

    filename_prefix = constraint_type.replace(" ", "_").lower()
    torch.save(samples.cpu(), os.path.join(output_dir, f'{filename_prefix}_samples.pt'))

    # Render at most n_vis_samples of the leading samples. Without a usable
    # ``vis`` entry the colour range falls back to [0, 1].
    n_vis = min(config.n_vis_samples, samples.shape[0])
    draw_kwargs = getattr(config, 'vis', None) or {'vmin': 0.0, 'vmax': 1.0}
    img_tensors = [draw(samples[i], **draw_kwargs) for i in range(n_vis)]

    grid_img = make_grid(img_tensors, nrow=config.grid_nrow, padding=2, normalize=False)
    image_path = os.path.join(output_dir, f'{filename_prefix}_samples_grid.png')
    save_image(grid_img, image_path)

    stats = {
        'mean': samples.mean().item(),
        'std': samples.std().item(),
        'min': samples.min().item(),
        'max': samples.max().item(),
        'shape': list(samples.shape)
    }

    stats_path = os.path.join(output_dir, f'{filename_prefix}_stats.txt')
    with open(stats_path, 'w', encoding='utf-8') as f:
        # The header needs its own newline; without it the "Shape:" entry
        # ended up fused onto the header line.
        f.write(f"Sampling Statistics ({constraint_type}):\n")
        f.write(f"Shape: {stats['shape']}\n")
        f.write(f"Mean: {stats['mean']:.6f}\n")
        f.write(f"Standard Deviation: {stats['std']:.6f}\n")
        f.write(f"Min: {stats['min']:.6f}\n")
        f.write(f"Max: {stats['max']:.6f}\n")

    print(f"Statistics saved to: {stats_path}")


def load_h5_test_data(config, device):
    """Load every top-level entry of the test HDF5 file as a float32 tensor.

    The file path is ``config.datasets.root`` joined with
    ``config.datasets.test['data_file']``.  Each entry is read fully, its raw
    shape is printed, and the data is converted to ``float32`` and moved to
    ``device``.

    Returns:
        The single tensor when the file holds exactly one entry, otherwise a
        dict mapping entry name to tensor.
    """
    h5_path = os.path.join(config.datasets.root, config.datasets.test['data_file'])

    tensors_by_name = {}
    with h5py.File(h5_path, 'r') as handle:
        for entry_name in list(handle.keys()):
            raw = handle[entry_name][:]
            print(f"attribute '{entry_name}' raw shape: {raw.shape}")

            if isinstance(raw, np.ndarray):
                as_tensor = torch.from_numpy(raw).float()
            else:
                as_tensor = torch.tensor(raw, dtype=torch.float32)

            tensors_by_name[entry_name] = as_tensor.to(device)

    # A file with a single entry is unwrapped to the bare tensor.
    if len(tensors_by_name) == 1:
        return next(iter(tensors_by_name.values()))
    return tensors_by_name


def calculate_physics_loss_rd(samples, ic_ref, g_L, g_R, rho, dx, dt):
    """Compute per-sample initial-condition and mass-balance errors.

    ``samples`` must be 3-D; the code treats the axes as
    ``(sample, x, t)``.  ``rho``, ``g_L`` and ``g_R`` are indexed as 1-D
    per-sample arrays.

    The mass check integrates ``rho * u * (1 - u)`` over x, adds
    ``g_L - g_R``, accumulates that rate over time with step ``dt``
    (using the value at the start of each step) and compares the predicted
    mass with the mass actually present at the later time points.

    Returns:
        ``(ic_loss_per_sample, mass_loss_per_sample)``: mean squared
        initial-condition error and mean squared mass mismatch, one value
        per sample.
    """
    # Unpacking doubles as a rank check: anything but 3-D input fails here.
    _n_sim, _n_x, _n_t = samples.shape
    u = samples

    def integrate_over_x(field):
        # Plain sum times the grid spacing, as in the mass definition.
        return field.sum(axis=1) * dx

    # Mismatch between the first time slice and the reference.
    ic_error = u[:, :, 0] - ic_ref
    ic_loss_per_sample = np.mean(ic_error ** 2, axis=1)

    total_mass = integrate_over_x(u)                                   # (nsim, nt)
    reaction_rate = integrate_over_x(rho[:, None, None] * u * (1 - u))  # (nsim, nt)
    net_boundary_flux = (g_L - g_R)[:, None]                           # (nsim, 1)

    mass_rate = reaction_rate[:, :-1] + net_boundary_flux              # (nsim, nt-1)
    accumulated = np.cumsum(mass_rate * dt, axis=1)                    # (nsim, nt-1)
    predicted_mass = total_mass[:, 0][:, None] + accumulated
    observed_mass = total_mass[:, 1:]
    mass_loss_per_sample = np.mean((observed_mass - predicted_mass) ** 2, axis=1)

    return ic_loss_per_sample, mass_loss_per_sample


def evaluate_and_print_results_rd(generated_samples, real_samples, ic_ref, g_L, g_R, rho):
    """Compare generated samples with the reference and report the metrics.

    Prints the MSE against ``real_samples`` plus the initial-condition and
    mass-balance violations (from ``calculate_physics_loss_rd``) for both the
    generated and the real samples, and returns the same numbers as a dict.

    If the last two axes of the two sample arrays disagree, both arrays and
    ``ic_ref`` are cut down to the common extent after a warning.  The grid
    spacings are taken as ``dx = 1 / nx`` and ``dt = 1 / (nt - 1)``
    (``dt = 1`` when there is a single time point).

    Returns:
        Dict of float metrics plus ``n_samples_compared``, ``nx`` and ``T``.
    """
    nsim, nx, nt = generated_samples.shape

    shapes_agree = real_samples.shape[-2:] == (nx, nt)
    if not shapes_agree:
        print(f"Warning: Shape mismatch. Gen: {(nx, nt)}, Real: {real_samples.shape[-2:]}. Truncating if needed.")
        common_nt = min(nt, real_samples.shape[-1])
        common_nx = min(nx, real_samples.shape[-2])
        generated_samples = generated_samples[:, :common_nx, :common_nt]
        real_samples = real_samples[:, :common_nx, :common_nt]
        ic_ref = ic_ref[:, :common_nx]
        nx, nt = common_nx, common_nt

    dx = 1.0 / nx
    if nt > 1:
        dt = 1.0 / (nt - 1)
    else:
        dt = 1.0

    n_compared = generated_samples.shape[0]
    print(f"Number of samples compared: {n_compared}\n")

    squared_error = (real_samples - generated_samples) ** 2
    mse_per_sample = np.mean(squared_error, axis=(1, 2))
    mean_mse = np.mean(mse_per_sample)
    std_mse = np.std(mse_per_sample)

    ic_loss_gen, mass_loss_gen = calculate_physics_loss_rd(generated_samples, ic_ref, g_L, g_R, rho, dx, dt)
    ic_loss_real, mass_loss_real = calculate_physics_loss_rd(real_samples, ic_ref, g_L, g_R, rho, dx, dt)

    # Summarise each loss vector once; the values feed both the printed
    # report and the returned dict.
    gen_ic_mean, gen_ic_std = np.mean(ic_loss_gen), np.std(ic_loss_gen)
    gen_mass_mean, gen_mass_std = np.mean(mass_loss_gen), np.std(mass_loss_gen)
    real_ic_mean, real_ic_std = np.mean(ic_loss_real), np.std(ic_loss_real)
    real_mass_mean, real_mass_std = np.mean(mass_loss_real), np.std(mass_loss_real)

    print("--- Error vs. Ground Truth ---")
    print(f"Mean of Mean Squared Error (MSE): {mean_mse:.6e}")
    print(f"Std. of Mean Squared Error (MSE): {std_mse:.6e}\n")

    print("--- Constraint Violations (Generated Samples) ---")
    print(f"Mean Initial Condition Violation: {gen_ic_mean:.6e}")
    print(f"Std. Initial Condition Violation: {gen_ic_std:.6e}")
    print(f"Mean Mass Conservation Violation: {gen_mass_mean:.6e}")
    print(f"Std. Mass Conservation Violation: {gen_mass_std:.6e}\n")

    print("--- Constraint Violations (Real Samples for Reference) ---")
    print(f"Mean Initial Condition Violation: {real_ic_mean:.6e}")
    print(f"Std. Initial Condition Violation: {real_ic_std:.6e}")
    print(f"Mean Mass Conservation Violation: {real_mass_mean:.6e}")
    print(f"Std. Mass Conservation Violation: {real_mass_std:.6e}")
    print("=" * 50)

    # Key order is kept stable because it becomes the CSV column order.
    return {
        "mean_mse": float(mean_mse),
        "std_mse": float(std_mse),
        "gen_ic_violation_mean": float(gen_ic_mean),
        "gen_ic_violation_std": float(gen_ic_std),
        "gen_mass_violation_mean": float(gen_mass_mean),
        "gen_mass_violation_std": float(gen_mass_std),
        "real_ic_violation_mean": float(real_ic_mean),
        "real_ic_violation_std": float(real_ic_std),
        "real_mass_violation_mean": float(real_mass_mean),
        "real_mass_violation_std": float(real_mass_std),
        "n_samples_compared": int(n_compared),
        "nx": int(nx),
        "T": int(nt),
    }


def _append_row_to_csv(csv_path, row_dict):
    """Append ``row_dict`` as one row of the CSV file at ``csv_path``.

    The column order is the key order of ``row_dict``.  A header row is
    written first when the file does not exist yet or exists but is empty
    (for example after an interrupted earlier run), so the file never ends
    up with data rows and no header.

    Args:
        csv_path: Path of the CSV file; created if missing.
        row_dict: Mapping of column name to value for the new row.
    """
    fieldnames = list(row_dict.keys())
    # An existing zero-length file needs a header just like a missing one.
    needs_header = not os.path.isfile(csv_path) or os.path.getsize(csv_path) == 0
    with open(csv_path, mode='a', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if needs_header:
            writer.writeheader()
        writer.writerow(row_dict)

def main(sampling_method, override_n_mix=None, override_alpha_exp=None):
    """Run one sampling experiment configured from the command line.

    Parses the process arguments with argparse, loads the checkpoint and
    config, builds the flow model and then either

    * draws unconstrained samples and saves them (``--constraint_type none``),
      or
    * runs the selected constrained sampler batch by batch over the test set,
      saves the samples, evaluates them against the reference data from the
      test HDF5 file and appends one row to ``metrics.csv`` in the run's
      output directory.

    Args:
        sampling_method: Default value for the ``--sampling_method`` option.
        override_n_mix: When not ``None``, replaces the parsed ``--n_mix``.
        override_alpha_exp: When not ``None``, replaces the parsed
            ``--alpha_exp``.

    Raises:
        ValueError: If CUDA is requested but unavailable, if the HDF5 test
            data is not a dict with the required entries, if its size
            disagrees with the test set, or if the sampling method is not
            one of the supported names.
    """
    parser = argparse.ArgumentParser(description='')

    # Paths and device.
    parser.add_argument('--checkpoint', type=str, default='/path/to/eci/rd/models/15800.pt',
                        help='')
    parser.add_argument('--config', type=str, default='configs/rd.yml',
                        help='')
    parser.add_argument('--output_dir', type=str, default='/path/to/eci/rd/results',
                        help='')
    parser.add_argument('--device', type=str, default='cuda',
                        help='')

    # Sampler selection and hyper-parameters.
    parser.add_argument('--sampling_method', type=str, default=sampling_method,
                        choices=['eci', 'pcfm', 'ccd', 'dpde'],
                        help='')
    parser.add_argument('--constraint_type', type=str, default='ic',
                        choices=['none', 'ic', 'bc', 'icbc', 'periodic', 'custom'],
                        help='')
    parser.add_argument('--n_sample', type=int, default=1000,
                        help='')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='')
    parser.add_argument('--n_step', type=int, default=200,
                        help='')
    parser.add_argument('--n_mix', type=int, default=1,
                        help='')
    parser.add_argument('--resample_step', type=int, default=0,
                        help='')
    parser.add_argument('--n_eval', type=int, default=1,
                        help='')
    parser.add_argument('--alpha_exp', type=float, default=1.0,
                        help='')

    # Preview-grid settings.
    parser.add_argument('--n_vis_samples', type=int, default=16,
                        help='')
    parser.add_argument('--grid_nrow', type=int, default=4,
                        help='')

    parser.add_argument('--seed', type=int, default=0,
                        help='')

    args = parser.parse_args()

    # Programmatic overrides win over whatever was parsed.
    if override_n_mix is not None:
        args.n_mix = override_n_mix
    if override_alpha_exp is not None:
        args.alpha_exp = override_alpha_exp

    seed_all(args.seed)

    if args.device == 'cuda' and not torch.cuda.is_available():
        # The previous code assigned 'cpu' here and then raised anyway, so
        # the fallback never took effect; keep the failure but explain it.
        raise ValueError(
            "CUDA was requested (--device cuda) but is not available; "
            "pass --device cpu to run on the CPU."
        )

    device = torch.device(args.device)

    try:
        torch.serialization.add_safe_globals([EasyDict])
    except AttributeError:
        # Presumably an older torch release without add_safe_globals;
        # registration is best-effort only.
        pass
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects — only
    # load checkpoints from a trusted source.
    try:
        ckpt = torch.load(args.checkpoint, map_location=device, weights_only=False)
    except TypeError:
        # torch.load without a weights_only keyword.
        ckpt = torch.load(args.checkpoint, map_location=device)

    if args.config:
        config = load_config(args.config)
    else:
        config = EasyDict(ckpt['config'])

    sampling_config = SamplingConfig()
    sampling_config.n_sample = args.n_sample
    sampling_config.n_step = args.n_step
    sampling_config.n_mix = args.n_mix
    # A non-positive --resample_step is passed on as None.
    sampling_config.resample_step = args.resample_step if args.resample_step > 0 else None
    sampling_config.n_eval = args.n_eval
    sampling_config.n_vis_samples = args.n_vis_samples
    sampling_config.grid_nrow = args.grid_nrow
    sampling_config.alpha_exp = args.alpha_exp

    if hasattr(config, 'vis'):
        sampling_config.vis = config.vis

    model = get_flow_model(config.model, config.encoder).to(device)
    model.load_state_dict(ckpt['model'])
    model.eval()

    if hasattr(config, 'sample_dims'):
        dims = tuple(config.sample_dims)
    else:
        # Fallback used when the config does not specify sample_dims.
        dims = (128, 100)

    if args.constraint_type == 'none':
        output_dir = os.path.join(args.output_dir, "unconstrained")
    else:
        output_dir = os.path.join(
            args.output_dir,
            f"{args.sampling_method}_sample_nmix_{args.n_mix}_alpha_exp_{args.alpha_exp}"
        )
    os.makedirs(output_dir, exist_ok=True)

    if args.constraint_type == 'none':
        samples = sample_unconstrained(model, sampling_config, dims, device)
        save_samples(samples, output_dir, sampling_config, constraint_type="unconstrained")

    else:
        # Timed from here so the reported total includes data loading,
        # sampling, saving and evaluation.
        start_time = time.time()

        _, full_test_set = get_dataset(config.datasets)

        # Use at most n_sample leading items of the test set.
        limit = min(args.n_sample, len(full_test_set))
        if limit < len(full_test_set):
            from torch.utils.data import Subset

            test_set = Subset(full_test_set, list(range(limit)))
        else:
            test_set = full_test_set

        test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False)

        projection_info_full = load_h5_test_data(config, device)

        # The sampler and the evaluation need all five entries; fail early
        # with a clear message instead of an opaque indexing error.
        required_keys = ('solution', 'initial_condition', 'bc_flux_left', 'bc_flux_right', 'rho')
        if not isinstance(projection_info_full, dict):
            raise ValueError(
                "The test HDF5 file must contain the entries "
                f"{list(required_keys)}, but only a single entry was found."
            )
        missing_keys = [key for key in required_keys if key not in projection_info_full]
        if missing_keys:
            raise ValueError(f"The test HDF5 file is missing required entries: {missing_keys}")

        for key in required_keys:
            tensor_val = projection_info_full[key]
            if hasattr(tensor_val, 'shape') and tensor_val.shape[0] >= limit:
                projection_info_full[key] = tensor_val[:limit]

        n_reference = projection_info_full['solution'].shape[0]
        if len(test_set) != n_reference:
            raise ValueError(
                f"Test set has {len(test_set)} items but the HDF5 'solution' "
                f"entry has {n_reference} rows."
            )

        projection_dataset = TensorDataset(*(projection_info_full[key] for key in required_keys))
        projection_loader = DataLoader(projection_dataset, batch_size=args.batch_size, shuffle=False)

        all_samples = []
        all_ground_truth_sols = []
        all_ground_truth_ics = []
        all_bc_left = []
        all_bc_right = []
        all_rho = []

        pbar_desc = f" ({args.sampling_method.upper()}, {args.constraint_type.upper()})"
        # Both loaders are unshuffled with the same batch size, so zipping
        # them pairs each test batch with its reference rows.
        pbar = tqdm(zip(test_loader, projection_loader), total=len(test_loader), desc=pbar_desc)

        for test_data_batch, projection_tensors_batch in pbar:
            test_data_batch = test_data_batch.to(device)

            projection_info_batch = dict(zip(required_keys, projection_tensors_batch))

            # The last batch may be smaller, so the sampler is told the
            # actual batch size every iteration.
            sampling_config.batch_sample = test_data_batch.shape[0]

            if args.sampling_method == 'eci':
                samples_batch = sample_eci_constrained(
                    model, test_data_batch, sampling_config, dims, device, projection_info_batch
                )
            elif args.sampling_method == 'pcfm':
                # sample_pcfm_constrained takes no constraint_type argument;
                # passing one raised a TypeError on every pcfm run.
                samples_batch = sample_pcfm_constrained(
                    model, test_data_batch, sampling_config, dims, device, projection_info_batch, config
                )
            elif args.sampling_method == 'ccd':
                samples_batch = sample_ccd_constrained(
                    model, test_data_batch, args.constraint_type, sampling_config, dims, device, projection_info_batch, config
                )
            elif args.sampling_method == 'dpde':
                samples_batch = sample_dpde_constrained(
                    model, test_data_batch, sampling_config, dims, device, projection_info_batch, config
                )
            else:
                # argparse does not validate defaults against ``choices``, so
                # an unsupported ``sampling_method`` argument can reach here.
                raise ValueError(f"Unsupported sampling method: {args.sampling_method!r}")
            all_samples.append(samples_batch.cpu())

            all_ground_truth_sols.append(projection_info_batch['solution'].cpu())
            all_ground_truth_ics.append(projection_info_batch['initial_condition'].cpu())
            all_bc_left.append(projection_info_batch['bc_flux_left'].cpu())
            all_bc_right.append(projection_info_batch['bc_flux_right'].cpu())
            all_rho.append(projection_info_batch['rho'].cpu())

        samples = torch.cat(all_samples, dim=0)
        save_samples(samples, output_dir, sampling_config, constraint_type=f"{args.sampling_method}_{args.constraint_type}")

        ground_truth_sols = torch.cat(all_ground_truth_sols, dim=0).numpy()
        ground_truth_ics = torch.cat(all_ground_truth_ics, dim=0).numpy()
        g_L = torch.cat(all_bc_left, dim=0).numpy()
        g_R = torch.cat(all_bc_right, dim=0).numpy()
        rho = torch.cat(all_rho, dim=0).numpy()

        metrics = evaluate_and_print_results_rd(
            generated_samples=samples.numpy(),
            real_samples=ground_truth_sols,
            ic_ref=ground_truth_ics,
            g_L=g_L,
            g_R=g_R,
            rho=rho,
        )

        total_sec = float(time.time() - start_time)
        run_params = {
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "checkpoint": str(args.checkpoint),
            "config_path": str(args.config),
            "output_dir": str(output_dir),
            "device": str(args.device),
            "sampling_method": str(args.sampling_method),
            "constraint_type": str(args.constraint_type),
            "n_sample": int(args.n_sample),
            "batch_size": int(args.batch_size),
            "n_step": int(args.n_step),
            "n_mix": int(args.n_mix),
            "resample_step": int(args.resample_step),
            "n_eval": int(args.n_eval),
            "alpha_exp": float(args.alpha_exp),
            "seed": int(args.seed),
            "dims_nx": int(dims[0]),
            "dims_T": int(dims[1]),
            "total_time_sec": total_sec,
        }
        # Run parameters first, then metrics: this fixes the CSV column order.
        row = {**run_params, **metrics}
        csv_path = os.path.join(output_dir, "metrics.csv")
        _append_row_to_csv(csv_path, row)


if __name__ == '__main__':
    # Sweep over the sampler names and alpha_exp values listed here; each
    # combination is one full call to main().
    methods_to_run = ['ccd']
    alpha_exp_values = [0.5]
    for method_name in methods_to_run:
        for alpha_exp in alpha_exp_values:
            main(method_name, override_alpha_exp=alpha_exp)