import tqdm
import pickle
import numpy as np
import torch
import PIL.Image
import dnnlib
import argparse
from define_data import *
from torch.utils.data import DataLoader
from define_data import get_data
import os
import random
import json
from torch.nn.functional import silu

def set_seed(manualSeed=0):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU generator plus the
    current and all CUDA devices), switches cuDNN to deterministic mode with
    benchmarking disabled, and stores the seed in the ``PYTHONHASHSEED``
    environment variable.

    Args:
        manualSeed: Integer seed applied to all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(manualSeed)

    # Trade cuDNN autotuning for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(manualSeed)

class ForwardHook:
    """Record the most recent forward output of a module.

    A forward hook is registered on ``module`` at construction time; every
    forward call of that module overwrites ``outputs`` with the module's
    output.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, module, name):
        # Label used by callers to identify which module this hook watches.
        self.name = name
        # Latest captured output; ``None`` until the module runs (or after clear()).
        self.outputs = None
        # Handle returned by PyTorch, needed later to detach the hook.
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, module_in, module_out):
        """Forward-hook callback: keep only the module's output."""
        self.outputs = module_out

    def clear(self):
        """Forget the captured output."""
        self.outputs = None

    def close(self):
        """Detach the hook from the module."""
        self.hook.remove()

def parse_eval_args():
    """Build the command-line parser for this script and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``data_dir``, ``dataset``,
        ``save_name``, ``batch_size``, ``seed``, ``clean_image``,
        ``random_noise``, ``hook_pos``, ``ckpt_path``, ``n_data`` and
        ``extract_timestep``.
    """
    default_hook_positions = [
        '8x8_block0.norm0', '8x8_block0.norm1', '8x8_block1.norm0',
        '8x8_block2.norm0', '8x8_block3.norm0', '8x8_block3.norm1',
        '16x16_block0.norm0',
    ]

    # (flag, add_argument keyword arguments), in the order they are registered.
    option_table = (
        # Directory Setting
        ('--data_dir', dict(type=str, default=None)),
        ('--dataset', dict(type=str, default='cifar10')),
        ('--save_name', dict(type=str, default="mem_vp_syn")),
        ('--batch_size', dict(type=int, default=1)),
        ('--seed', dict(type=int, default=8)),
        ('--clean_image', dict(dest='clean_image', action="store_true")),
        ('--random_noise', dict(dest='random_noise', action="store_true")),
        ('--hook_pos', dict(
            type=str, nargs='+', default=default_hook_positions,
            help='List of module names to hook. Use spaces to separate them.',
        )),
        ('--ckpt_path', dict(type=str, default=None)),
        # Length of the dataset to process
        ('--n_data', dict(type=int, default=1000)),
        ('--extract_timestep', dict(type=int, default=None)),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()
#----------------------------------------------------------------------------

def image_2_latent_random(
    net, images, class_labels,
    device=torch.device('cuda'),
    num_steps=18, rho=7,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
    hooks=None,
    extract_step=None
):
    """Noise ``images`` at one schedule step, denoise once, and collect hooked features.

    A noise schedule of ``num_steps`` levels is built between the network's
    supported ``sigma_min`` and ``sigma_max`` (clamped to [0.002, 80]) and
    ordered from the smallest to the largest level. The level at index
    ``extract_step`` is used to add Gaussian noise to ``images``; ``net`` is
    then called once on the noised images. For every hook that captured an
    output during that call, the output is passed through SiLU and reduced
    with a max over dims 2 and 3 (assumes hooked outputs are laid out as
    (N, C, H, W) — TODO confirm against the hooked modules).

    Args:
        net: Callable ``net(x, sigma, class_labels)`` exposing ``sigma_min``,
            ``sigma_max`` and ``round_sigma``.
        images: Image tensor to perturb.
        class_labels: Forwarded unchanged to ``net``.
        device: Device on which the noise schedule is created.
        num_steps: Number of noise levels in the schedule.
        rho: Exponent controlling the spacing of the noise levels.
        S_churn, S_min, S_max: Control the extra noise factor ``gamma``.
        S_noise: Unused; kept for signature compatibility.
        hooks: Iterable of objects with ``name``, ``outputs`` and ``clear()``.
        extract_step: Index of the schedule transition to evaluate; must be
            in ``range(num_steps - 1)``.

    Returns:
        Tuple ``(features_hooked_dict, norm)`` where ``features_hooked_dict``
        maps hook name to the pooled features and ``norm`` is
        ``(noised_images - denoised) / t_hat``.

    Raises:
        ValueError: If ``hooks`` is ``None`` or ``extract_step`` does not
            select a valid transition (previously this surfaced as an
            ``UnboundLocalError`` on ``norm`` after a full pass over the
            schedule).
    """
    if hooks is None:
        raise ValueError("hooks must be provided")
    # The schedule has num_steps levels, hence num_steps - 1 transitions.
    if extract_step not in range(num_steps - 1):
        raise ValueError(
            f"extract_step must be an integer in [0, {num_steps - 2}], got {extract_step!r}"
        )

    features_hooked_dict = {}

    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(0.002, net.sigma_min)
    sigma_max = min(80, net.sigma_max)

    # Time step discretization.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=device)
    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
    # Reverse so levels grow with the index, then drop the leading zero level.
    t_steps = torch.flip(t_steps, dims=(0,))[1:]

    # Only the requested transition does any work, so index it directly
    # instead of iterating over every step.
    t_cur = t_steps[int(extract_step)]
    gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
    t_hat = net.round_sigma(t_cur + gamma * t_cur)

    # Noise is drawn on the CPU generator and then moved, which keeps the
    # random stream independent of the target device.
    noised_images = images + torch.randn(images.shape).to(images.device) * t_hat
    denoised = net(noised_images, t_hat, class_labels).to(torch.float64)

    for hook in hooks:
        features = hook.outputs
        if features is None:
            continue
        features = silu(features)  # Apply activation function
        features = torch.max(torch.max(features, dim=2)[0], dim=2)[0]  # Global max pooling
        features_hooked_dict[hook.name] = features
        # Reset so a stale output is never attributed to a later call.
        hook.clear()

    norm = (noised_images - denoised) / t_hat

    return features_hooked_dict, norm


def _summarize_activations(features):
    """Return scalar summary statistics of a pooled feature tensor as a dict."""
    vis = features.flatten()
    probs = torch.nn.functional.softmax(vis, dim=0)
    # Small epsilon keeps log() finite when a probability underflows to 0.
    entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item()
    return {
        "std": float(vis.std().item()),
        "max_min": float((vis.max() - vis.min()).item()),
        "entropy": float(entropy),
        # NOTE(review): the key says "l4/l2" but the value is the ratio of the
        # infinity norm to the L2 norm. Key kept so the output schema is stable.
        "l4/l2": float((vis.norm(p=float("inf")) / vis.norm(p=2)).item()),
    }


def _append_loader_stats(loader, net, hooks, hooked_positions, args, num_steps,
                         output_dir, filename, max_samples=None):
    """Run every batch of ``loader`` through the network and append one JSON line per batch.

    Stops early once ``batch_index * args.batch_size`` reaches ``max_samples``
    (no limit when ``max_samples`` is ``None``).
    """
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, filename)

    # Inference only: no autograd graph is needed for the statistics.
    with open(output_path, "a") as f, torch.no_grad():
        for i, (images, _labels) in enumerate(loader):
            if max_samples is not None and i * args.batch_size >= max_samples:
                print(f"Collected samples: {max_samples}. Stopping at batch {i}.")
                break

            # Map ToTensor's [0, 1] range to [-1, 1].
            images = images.to(args.device) * 2 - 1

            features_dict, _norm = image_2_latent_random(
                net, images, None,
                device=args.device,  # keep the schedule on the same device as the data
                num_steps=num_steps, hooks=hooks, extract_step=args.extract_timestep,
            )

            sample_stats = {
                "idx": i,
                "features": {
                    pos: _summarize_activations(features_dict[pos])
                    for pos in hooked_positions
                    if pos in features_dict
                },
            }
            f.write(json.dumps(sample_stats) + "\n")


def main():
    """Extract hooked-feature statistics for the "mem" and "gen" image folders.

    Parses the CLI arguments, seeds all RNGs, loads the pickled network given
    by ``--ckpt_path``, registers forward hooks on the requested decoder
    modules, and writes per-batch statistics to ``mem_all_stats.jsonl`` and
    ``gen_all_stats.jsonl``. Only the "mem" pass is limited by ``--n_data``.

    Raises:
        ValueError: If a required argument or a path placeholder is unset, or
            if none of the requested hook positions exist in the network.
    """
    args = parse_eval_args()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args.device = device
    num_step = 18

    # --- User configuration: fill these in before running ---
    mem_data_root = None ## Your Mem. data path
    gen_data_root = None ## Your Gen. data path
    mem_output_dir = None ## Your output path
    gen_output_dir = None ## Your output path

    # Fail early with a clear message instead of deep inside data/model loading.
    if None in (mem_data_root, gen_data_root, mem_output_dir, gen_output_dir):
        raise ValueError("Set the data and output path placeholders at the top of main().")
    if args.ckpt_path is None:
        raise ValueError("--ckpt_path is required.")
    # num_step noise levels give num_step - 1 transitions to choose from.
    if args.extract_timestep not in range(num_step - 1):
        raise ValueError(
            f"--extract_timestep must be in [0, {num_step - 2}], got {args.extract_timestep!r}"
        )

    set_seed(args.seed)

    # Dataset
    # NOTE(review): Dataset, Image and transforms are not imported explicitly in
    # this file; presumably they come from `from define_data import *` — confirm.
    class FlatImageFolder(Dataset):
        """A dataset for a folder of images without class subdirectories."""
        def __init__(self, root, transform=None):
            self.root = root
            # Get a sorted list of all image files
            self.image_paths = sorted([
                os.path.join(root, f) for f in os.listdir(root)
                if os.path.isfile(os.path.join(root, f))
            ])
            self.transform = transform

        def __len__(self):
            return len(self.image_paths)

        def __getitem__(self, idx):
            # Load and convert image to RGB to ensure consistency
            image = Image.open(self.image_paths[idx]).convert('RGB')
            if self.transform:
                image = self.transform(image)
            # Dummy label: the folder carries no class information.
            return image, 0

    def make_transform():
        return transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
        ])

    trainset = FlatImageFolder(root=mem_data_root, transform=make_transform())
    trainloader = DataLoader(trainset, batch_size=args.batch_size, num_workers=8, shuffle=False)

    testset = FlatImageFolder(root=gen_data_root, transform=make_transform())
    testloader = DataLoader(testset, batch_size=args.batch_size, num_workers=8, shuffle=False)

    # Load network.
    # NOTE(security): pickle.load executes arbitrary code; only load trusted checkpoints.
    network_pkl = args.ckpt_path
    print(f'Loading network from "{network_pkl}"...')
    with dnnlib.util.open_url(network_pkl) as f:
        net = pickle.load(f)['ema'].to(device)

    # Add hooks for multiple positions
    print(f"\nGet features from positions: {args.hook_pos}!\n")
    hooks = []
    # Names of the hooks that were successfully registered.
    hooked_positions = []
    decoder_modules = dict(net.model.dec.named_modules())
    for hook_pos in args.hook_pos:
        module = decoder_modules.get(hook_pos)
        if module is None:
            print(f"Warning: Module '{hook_pos}' not found in the network.")
            continue
        print(f"Hooking {hook_pos}...")
        hooks.append(ForwardHook(module, hook_pos))
        hooked_positions.append(hook_pos)

    if not hooks:
        raise ValueError(f"No modules found for the specified hook positions: {args.hook_pos}")

    try:
        _append_loader_stats(
            trainloader, net, hooks, hooked_positions, args, num_step,
            mem_output_dir, "mem_all_stats.jsonl", max_samples=args.n_data,
        )
        _append_loader_stats(
            testloader, net, hooks, hooked_positions, args, num_step,
            gen_output_dir, "gen_all_stats.jsonl",
        )
    finally:
        # Detach the hooks so the network is left unmodified.
        for hook in hooks:
            hook.close()


#----------------------------------------------------------------------------
if __name__ == "__main__":
    main()

#----------------------------------------------------------------------------