import argparse
import os
import logging
from typing import Tuple, List, Optional

import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image, ImageDraw

from data import get_dataset
from model import GPT, GPTConfig
from trainer import Trainer, TrainerConfig
from utils import set_seed

# Root logging setup for this training script: INFO and above, each record
# prefixed with a timestamp, level and logger name.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
logger = logging.getLogger(__name__)

class BaselineDatasetWrapper:
    """Wrap a layout dataset as fixed-length, padded token sequences.

    Each dataset item is read through ``item.x`` (4 box values per element)
    and ``item.y`` (integer class labels). Every token is a 10-dim float
    vector:

    * columns 0-3: the element's 4 box values,
    * column ``4 + label``: one-hot class flag (the ``vocab_size`` comment
      suggests 3 classes, i.e. columns 4-6 -- TODO confirm against the data),
    * columns 7 / 8 / 9: BOS / EOS / PAD flag.

    Args:
        dataset: sized, indexable dataset whose items expose ``.x`` and ``.y``.
        max_length: maximum number of elements kept per layout; longer
            layouts are truncated.
    """

    def __init__(self, dataset, max_length: int = 200):
        self.dataset = dataset
        self.max_length = max_length
        self.vocab_size = 6  # BOS, EOS, PAD + 3 classes for GDSLayout

        # Special tokens: zero box/class part, a single flag in column 7/8/9.
        self.bos_token = torch.tensor(
            [0.0, 0.0, 0.0, 0.0, 0, 0, 0, 1, 0, 0], dtype=torch.float32
        )
        self.eos_token = torch.tensor(
            [0.0, 0.0, 0.0, 0.0, 0, 0, 0, 0, 1, 0], dtype=torch.float32
        )
        self.pad_token = torch.tensor(
            [0.0, 0.0, 0.0, 0.0, 0, 0, 0, 0, 0, 1], dtype=torch.float32
        )

        # Outline colors used by render(), cycled by class index.
        self.colors = [
            (255, 0, 0),    # Red
            (0, 255, 0),    # Green
            (0, 0, 255),    # Blue
            (255, 255, 0),  # Yellow
            (255, 0, 255),  # Magenta
        ]

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(x, y, mask)`` for the layout at ``idx``.

        The padded chunk is ``[BOS, e_1 .. e_n, EOS, PAD, ...]`` with
        ``max_length + 2`` tokens. ``x`` drops the last token and ``y`` drops
        the first (next-token targets), so both have ``max_length + 1`` rows.
        ``mask`` is True wherever the target row in ``y`` is not a PAD token.
        """
        data = self.dataset[idx]
        label = data.y
        bbox = data.x

        # Keep at most max_length elements (slicing is a no-op when shorter).
        seq_length = min(len(bbox), self.max_length)
        bbox = bbox[:seq_length]
        label = label[:seq_length]

        sequence = torch.zeros(seq_length, 10)
        sequence[:, :4] = bbox
        # NOTE(review): a label >= 3 would land on the BOS/EOS/PAD columns
        # (7-9); this assumes the dataset only yields labels 0-2.
        sequence[torch.arange(seq_length), 4 + label] = 1.0  # one-hot labels

        chunk = self.pad_token.repeat(self.max_length + 2, 1)
        chunk[0] = self.bos_token
        chunk[1 : seq_length + 1] = sequence
        chunk[seq_length + 1] = self.eos_token

        x = chunk[:-1]
        y = chunk[1:]

        # A row is padding iff it equals the PAD token exactly (broadcast).
        mask = ~torch.all(y == self.pad_token, dim=-1)

        return x, y, mask

    @staticmethod
    def _box_to_pixels(x, y, w, h, size):
        """Scale an ``(x, y, w, h)`` box by ``size`` into ordered pixel corners.

        Returns ``[left, top, right, bottom]``. Boxes with negative width or
        height (e.g. from sampled layouts) would otherwise give right < left,
        which Pillow's ``ImageDraw.rectangle`` rejects with a ValueError.
        """
        x1 = int(x * size[0])
        y1 = int(y * size[1])
        x2 = int((x + w) * size[0])
        y2 = int((y + h) * size[1])
        return [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)]

    def render(self, layout: np.ndarray, size: Tuple[int, int] = (256, 256)) -> "Image.Image":
        """Draw layout boxes as colored outlines on a white RGB canvas.

        Each row of ``layout`` is read as ``[class, x, y, w, h, ...]``. Box
        values are multiplied by ``size``, so they appear to be normalized to
        [0, 1] -- confirm with the caller. Rows whose class value is not < 5
        are skipped.
        """
        img = Image.new("RGB", size, color=(255, 255, 255))
        draw = ImageDraw.Draw(img)

        for item in layout:
            if item[0] < 5:
                cls_idx = int(item[0])
                x, y, w, h = item[1:5]
                corners = self._box_to_pixels(x, y, w, h, size)
                color = self.colors[cls_idx % len(self.colors)]
                draw.rectangle(corners, outline=color, width=2)

        return img

def _str2bool(value: str) -> bool:
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    argparse's ``type=bool`` is wrong for this because ``bool("False")`` is
    True; this converter keeps ``--flag True`` / ``--flag False`` working.
    """
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_arguments(argv: Optional[List[str]] = None):
    """Parse command line arguments.

    Args:
        argv: argument list to parse; ``None`` (the default) reads ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` with experiment, model, DiffLoss and
        training options.
    """
    parser = argparse.ArgumentParser("Layout Transformer")
    parser.add_argument("--exp", default="layout", help="experiment name")
    parser.add_argument("--log_dir", default="./logs", help="/path/to/logs/dir")

    # Layout options
    parser.add_argument("--max_length", type=int, default=600, help="max seq length")
    parser.add_argument("--input_dim", type=int, default=10, help="input dim")
    parser.add_argument("--disc_dim", type=int, default=6, help="discrete dim")

    # DiffLoss options
    parser.add_argument("--diffloss_d", type=int, default=3)
    parser.add_argument("--diffloss_w", type=int, default=256)
    parser.add_argument("--num_sampling_steps", type=str, default="100")
    # _str2bool instead of type=bool: bool("False") would silently enable it.
    parser.add_argument("--grad_checkpointing", type=_str2bool, default=False)
    parser.add_argument("--diffusion_batch_mul", type=int, default=30)
    # float so fractional weights are accepted; integer inputs still parse.
    parser.add_argument("--loss_weight", type=float, default=1)

    # Architecture/training options
    parser.add_argument("--seed", type=int, default=42, help="random seed")
    parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
    parser.add_argument("--batch_size", type=int, default=64, help="batch size")
    parser.add_argument("--lr", type=float, default=4.5e-06, help="learning rate")
    parser.add_argument("--n_layer", default=6, type=int)
    parser.add_argument("--n_embd", default=512, type=int)
    parser.add_argument("--n_head", default=8, type=int)
    parser.add_argument(
        "--lr_decay", action="store_true", help="use learning rate decay"
    )
    parser.add_argument(
        "--warmup_iters", type=int, default=0, help="linear lr warmup iters"
    )
    parser.add_argument(
        "--final_iters", type=int, default=0, help="cosine lr final iters"
    )
    parser.add_argument(
        "--sample_every", type=int, default=1, help="sample every epoch"
    )
    parser.add_argument(
        "--data_path", type=str, default="./dataset/"
    )

    parser.add_argument("--resume", type=str, default=None, help="path to checkpoint for resuming training")

    parser.add_argument(
        "--eos_alpha", type=float, default=0.1
    )
    parser.add_argument(
        "--length_loss_weight", type=float, default=0.1
    )

    return parser.parse_args(argv)

def setup_directories(args):
    """Create the experiment's ``samples`` and ``ckpt`` directories.

    The experiment root is ``<args.log_dir>/<args.exp>``. Directories that
    already exist are left as they are.

    Returns:
        ``(log_dir, samples_dir, ckpt_dir)`` as path strings.
    """
    experiment_root = os.path.join(args.log_dir, args.exp)
    samples_path, checkpoint_path = (
        os.path.join(experiment_root, subdir) for subdir in ("samples", "ckpt")
    )

    for directory in (samples_path, checkpoint_path):
        os.makedirs(directory, exist_ok=True)

    return experiment_root, samples_path, checkpoint_path

def initialize_model(args, train_dataset):
    """Build a GPT from the parsed CLI arguments and move it to a device.

    ``vocab_size`` and ``block_size`` come from ``train_dataset``. Here
    ``block_size = train_dataset.max_length + 1``, which assumes the dataset
    yields sequences of ``max_length + 1`` tokens (true for
    ``BaselineDatasetWrapper``). All other settings come from ``args``.

    Returns:
        ``(model, device)``; the device is CUDA when available, else CPU.
    """
    config_kwargs = dict(
        vocab_size=train_dataset.vocab_size,
        block_size=train_dataset.max_length + 1,
        n_layer=args.n_layer,
        n_head=args.n_head,
        n_embd=args.n_embd,
        input_dim=args.input_dim,
        disc_dim=args.disc_dim,
        diffloss_w=args.diffloss_w,
        diffloss_d=args.diffloss_d,
        num_sampling_steps=args.num_sampling_steps,
        grad_checkpointing=args.grad_checkpointing,
        diffusion_batch_mul=args.diffusion_batch_mul,
        max_length=args.max_length,
        eos_alpha=args.eos_alpha,
        length_loss_weight=args.length_loss_weight,
    )
    model = GPT(GPTConfig(**config_kwargs))

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    model.to(device)

    return model, device

def _resolve_start_epoch(trainer, checkpoint_path):
    """Return the epoch to start training from.

    With no ``checkpoint_path`` this returns 0. If the path does not point to
    a file, a warning is logged and training starts from epoch 0. Otherwise
    the checkpoint is loaded through ``trainer.load_checkpoint``.
    """
    if not checkpoint_path:
        return 0
    if not os.path.isfile(checkpoint_path):
        logger.warning(f"No checkpoint found at {checkpoint_path}")
        return 0
    resumed_epoch = trainer.load_checkpoint(checkpoint_path)
    logger.info(f"Resuming training from epoch {resumed_epoch}")
    return resumed_epoch


def main():
    """Run training end-to-end.

    Parses arguments, prepares directories, seeds RNGs, loads the ``gds``
    train/val splits, builds the model and trainer, and trains (optionally
    resuming from ``--resume``).

    Returns:
        0 on success, 1 if any exception was raised (it is logged with a
        traceback).
    """
    args = parse_arguments()
    _, samples_dir, ckpt_dir = setup_directories(args)

    set_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    try:
        datasets = {
            split: BaselineDatasetWrapper(
                get_dataset("gds", split, data_path=args.data_path),
                args.max_length,
            )
            for split in ("train", "val")
        }
        train_dataset, val_dataset = datasets["train"], datasets["val"]

        model, device = initialize_model(args, train_dataset)

        trainer_config = TrainerConfig(
            max_epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            lr_decay=args.lr_decay,
            warmup_iters=args.warmup_iters,
            final_iters=args.final_iters,
            sample_every=args.sample_every,
            ckpt_dir=ckpt_dir,
            samples_path=samples_dir,
            loss_weight=args.loss_weight,
            length_loss_weight=args.length_loss_weight,
        )
        trainer = Trainer(model, train_dataset, val_dataset, trainer_config, args)

        trainer.train(start_epoch=_resolve_start_epoch(trainer, args.resume))
    except Exception as e:
        # Top-level boundary: log with traceback and report failure via status.
        logger.error(f"Error in training: {e}", exc_info=True)
        return 1

    return 0

if __name__ == "__main__":
    # main() returns 0 on success and 1 on failure; propagate it as the
    # process exit code so shells/schedulers can detect failed runs.
    raise SystemExit(main())
