import argparse
import os
from pathlib import Path
import datetime
import numpy as np

import torch
from torch_geometric.data import DataLoader
from torch_geometric.utils import to_dense_batch
from tqdm import tqdm

from data import get_dataset
from model import GPT, GPTConfig
from utils import sample, trim_tokens
from data.preprocess_gds import tensor_to_cells, write_layout_file

def get_args():
    """Parse the command-line options for GDS layout generation.

    Returns:
        argparse.Namespace with model-architecture, sampling, dataset and
        output settings. ``--ckpt`` is required; everything else has a default.
    """

    def _str2bool(value):
        """Parse a boolean CLI value.

        ``type=bool`` is a trap: argparse would call ``bool("False")``, which is
        True for every non-empty string, so the flag could never be disabled.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", type=str, required=True, help="Path to the model checkpoint")
    parser.add_argument("--max_length", type=int, default=600, help="Maximum sequence length")
    parser.add_argument("--n_layer", default=6, type=int, help="Number of transformer layers")
    parser.add_argument("--n_embd", default=512, type=int, help="Embedding dimension")
    parser.add_argument("--n_head", default=8, type=int, help="Number of attention heads")
    parser.add_argument("--input_dim", default=10, type=int, help="Input dimension")
    parser.add_argument("--disc_dim", default=6, type=int, help="Discrete dimension")
    parser.add_argument("--diffloss_d", type=int, default=3, help="DiffLoss depth")
    parser.add_argument("--diffloss_w", type=int, default=256, help="DiffLoss width")
    # Kept as a string: it is forwarded unchanged to GPTConfig.
    parser.add_argument("--num_sampling_steps", type=str, default="100")
    parser.add_argument("--grad_checkpointing", type=_str2bool, default=False)
    parser.add_argument("--diffusion_batch_mul", type=int, default=4)
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size")
    parser.add_argument("--out_dir", type=str, default="output/generated_gds", help="Output directory")
    parser.add_argument("--num_context_boxes", type=int, default=1, help="Number of context boxes")
    parser.add_argument("--num_samples", type=int, default=10, help="Number of samples to generate")
    parser.add_argument("--resolution", type=int, default=40000, help="GDS resolution")
    parser.add_argument("--data_path", type=str, default="./dataset")
    parser.add_argument("--eos_alpha", type=float, default=0.1)
    return parser.parse_args()

def convert_baseline_to_model_input(data, device, input_dim=10):
    """Convert a batched graph of GDS boxes into the model's dense input sequence.

    Args:
        data: Batched graph object; ``data.x`` holds per-box features (the first
            4 are copied as coordinates), ``data.y`` the per-box class ids and
            ``data.batch`` the graph assignment of each box.
        device: Device on which the output sequence is allocated.
        input_dim: Feature width of each token. Columns ``[:4]`` hold the box
            coordinates and column ``4 + label`` the one-hot class.

    Returns:
        Tuple ``(full_sequence, label, mask)``:
        ``full_sequence`` has shape ``(B, L + 1, input_dim)``, with a BOS token
        (one-hot column 7) prepended; ``label`` and ``mask`` are the dense,
        padded label tensor and validity mask returned by ``to_dense_batch``.
        Padded positions stay all-zero (no one-hot is set for them).
    """
    label, mask = to_dense_batch(data.y, data.batch)
    bbox, _ = to_dense_batch(data.x, data.batch)

    # Create sequence tensor
    batch_size, seq_length = bbox.size(0), bbox.size(1)
    sequence = torch.zeros(batch_size, seq_length, input_dim, device=device)
    sequence[:, :, :4] = bbox  # coordinates

    # One-hot labels for the valid (unpadded) positions only, done with a single
    # advanced-indexing write instead of a per-element Python loop (which forced
    # a device sync for every element when running on GPU).
    batch_idx, pos_idx = mask.to(device).nonzero(as_tuple=True)
    classes = label.to(device)[batch_idx, pos_idx]
    sequence[batch_idx, pos_idx, 4 + classes] = 1.0

    # Create BOS token; column 7 is hard-coded (class slot 3 when the one-hot
    # block starts at column 4).
    bos_token = torch.zeros(batch_size, 1, input_dim, device=device)
    bos_token[:, :, 7] = 1.0  # BOS one-hot

    # Combine BOS with sequence
    full_sequence = torch.cat([bos_token, sequence], dim=1)

    return full_sequence, label, mask

def get_gds_header(timestamp="3/20/2025 13:28:02", struct_name="S5E5535"):
    """Return the GDS text-format header lines written before the cell records.

    With the default arguments the header is fixed (constant timestamp and
    structure name), so repeated runs produce identical files.

    Args:
        timestamp: Text placed twice on both the ``BGNLIB`` and ``BGNSTR`` lines.
        struct_name: Value written on the ``STRNAME`` line.

    Returns:
        A new list of header strings on every call, so callers may mutate it.
    """
    return [
        "HEADER 600",
        f"BGNLIB {timestamp} {timestamp}",
        "LIBNAME LIB",
        "UNITS 5e-04 5e-10",
        f"BGNSTR {timestamp} {timestamp}",
        f"STRNAME {struct_name}",
    ]

def convert_to_gds_cell_format(generated_layout, resolution=40000):
    """Convert trimmed model output boxes to GDS cell rows.

    Args:
        generated_layout: 2-D tensor with one row per box. Column 0 is read as
            the layer/class id (truncated to int); columns 1-4 are read as
            normalized ``(xc, yc, w, h)`` and clamped to ``[0, 1]``.
        resolution: Scale applied to the normalized values; resulting corner
            coordinates are clipped to ``[0, resolution]``.

    Returns:
        ``int32`` tensor of shape ``(N, 9)`` with rows
        ``[layer, x1, y1, x2, y2, x3, y3, x4, y4]``: the rectangle corners in
        the order bottom-left, top-left, top-right, bottom-right. Coordinates
        are rounded to the nearest integer. An empty input gives ``(0, 9)``.
    """
    classes = generated_layout[:, 0].int()

    # Scale normalized (xc, yc, w, h) to the output grid.
    scaled = torch.clamp(generated_layout[:, 1:5], 0.0, 1.0) * resolution
    xc, yc, w, h = scaled.unbind(dim=1)
    half_w = w / 2
    half_h = h / 2

    # Rectangle edges, clipped so boxes overhanging the canvas are cut at its border.
    left = torch.clamp(xc - half_w, 0, resolution)
    right = torch.clamp(xc + half_w, 0, resolution)
    bottom = torch.clamp(yc - half_h, 0, resolution)
    top = torch.clamp(yc + half_h, 0, resolution)

    gds_format = torch.zeros(len(generated_layout), 9)
    gds_format[:, 0] = classes
    # Corners: bottom-left, top-left, top-right, bottom-right.
    gds_format[:, 1:] = torch.stack(
        [left, bottom, left, top, right, top, right, bottom], dim=1
    )

    # Round instead of truncating: a bare ``.int()`` would floor values such as
    # 2.5999999 (float error) down to 2, biasing every coordinate downwards.
    return torch.round(gds_format).int()


def main():
    """Sample layouts conditioned on test-set context boxes and save them as GDS text.

    Each batch of the "gds" test split is converted to model input. Its first
    ``args.num_context_boxes`` boxes, plus the BOS token, are used as the prompt
    for ``sample``. Every sampled sequence is stripped of BOS/EOS/PAD tokens,
    converted to GDS cell rows and written to ``<out_dir>/generated_<k>.txt``.
    Generation stops after ``args.num_samples`` files have been written, or
    when the test split is exhausted.
    """
    args = get_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(exist_ok=True, parents=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load GDS dataset for context boxes
    dataset = get_dataset("gds", "test", data_path=args.data_path)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=4,
        # Pinned memory only speeds up host->GPU copies; request it only on CUDA.
        pin_memory=device.type == "cuda",
        shuffle=False,
    )

    mconf = GPTConfig(
        vocab_size=6,  # BOS, EOS, PAD + 3 layout types for GDS
        block_size=args.max_length + 1,
        n_layer=args.n_layer,
        n_head=args.n_head,
        n_embd=args.n_embd,
        input_dim=args.input_dim,
        disc_dim=args.disc_dim,
        diffloss_d=args.diffloss_d,
        diffloss_w=args.diffloss_w,
        num_sampling_steps=args.num_sampling_steps,
        grad_checkpointing=args.grad_checkpointing,
        diffusion_batch_mul=args.diffusion_batch_mul,
        max_length=args.max_length,
        eos_alpha=args.eos_alpha,
        length_loss_weight=0.0,  # Not used in sampling
    )

    # Load model
    model = GPT(mconf)
    # NOTE(review): torch.load unpickles the file; depending on the installed
    # torch version this can execute arbitrary code, so only load trusted checkpoints.
    checkpoint = torch.load(args.ckpt, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(device)
    model.eval()

    header = get_gds_header()
    sample_count = 0

    with torch.no_grad():
        for data in tqdm(dataloader):
            if sample_count >= args.num_samples:
                break

            data = data.to(device)
            x, _, mask = convert_baseline_to_model_input(data, device, args.input_dim)

            # Prompt = BOS token + the first num_context_boxes boxes.
            sampled_layouts = sample(
                model,
                x[:, : args.num_context_boxes + 1, :],
                steps=args.max_length,
            ).cpu()

            for i, layout in enumerate(sampled_layouts):
                if sample_count >= args.num_samples:
                    break
                if not mask[i].any():
                    continue

                # Trim special tokens (BOS, EOS, PAD)
                trimmed = trim_tokens(layout, bos=3, eos=4, pad=5)
                if len(trimmed) == 0:
                    continue

                gds_format = convert_to_gds_cell_format(trimmed, args.resolution)
                cells = tensor_to_cells(gds_format)
                if not cells:
                    continue

                output_path = out_dir / f"generated_{sample_count}.txt"
                write_layout_file(output_path, header, cells)

                print(f"Generated GDS layout saved to {output_path}")
                sample_count += 1

                # Show the layer breakdown for the first two files only.
                if sample_count <= 2:
                    print(f"Generated {len(cells)} cells with types: {[cell['Layer'] for cell in cells]}")


# Run generation only when executed as a script, not when the module is imported.
if __name__ == "__main__":
    main()