import os
import argparse

import torch
import torchmetrics
import pytorch_lightning as pl
import torchvision
from torch_geometric.data import Data, Batch
# from torch_geometric.loader import DataLoader
from torch.utils.data import DataLoader # Plain torch loader: load_data supplies its own collate_fn that returns a PyG Batch

# Import both models and necessary utilities
from models.rapidash.rapidash import Rapidash
from models.platoformer.platoformer import PlatonicTransformer
from models.platoformer.groups import PLATONIC_GROUPS
from utils import CosineWarmupScheduler, RandomSOd, TimerCallback

# In order to be able to download cifar on the server
# NOTE: this swaps the default HTTPS context factory for the unverified one, so
# certificate verification is skipped by anything in this process that relies on
# the default context, not only by the CIFAR download.
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

# Performance optimization
# NOTE(review): 'medium' presumably trades float32 matmul precision for speed
# (see the PyTorch docs for the exact numerics) -- confirm this is acceptable.
torch.set_float32_matmul_precision('medium')


class CIFAR10Model(pl.LightningModule):
    """Lightning module classifying CIFAR10 images given as 2D patch point clouds.

    Each image is expected to arrive as a set of patches: ``data.x`` holds the
    flattened patch pixels, ``data.pos`` the 2D patch coordinates, ``data.y``
    one class label per image and ``data.batch`` the batching vector produced
    by the data loader.

    Args:
        args: Parsed hyperparameters (e.g. an ``argparse.Namespace``). They are
            stored via ``save_hyperparameters`` and read back through
            ``self.hparams``.

    Raises:
        ValueError: If ``model_type`` or (for the Platonic Transformer)
            ``equivariance`` is unsupported, or if ``head_dim`` and
            ``num_heads`` are both given but disagree.
    """

    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters(args)

        # Random 2D rotation generator used for training-time augmentation
        self.rotation_generator = RandomSOd(2)

        # Every point is one flattened RGB patch (patch_size * patch_size * 3
        # scalar features); there are no vector features.
        in_channels_scalar = args.patch_size * args.patch_size * 3
        in_channels_vector = 0

        # The number of "points" is the number of patches (e.g., (32/8)^2 = 16)
        self.avg_num_nodes = (32 // args.patch_size)**2

        # --- Model Selection and Initialization ---

        if self.hparams.model_type == "platoformer":
            # Set solid name based on equivariance for PlatonicTransformer
            if self.hparams.equivariance == "Tn":
                solid_name = "trivial_2"
            # SE(2) equivariance with 3D Platonic solids is not supported
            elif self.hparams.equivariance == "SEn":
                solid_name = "cyclic_"+str(self.hparams.num_ori) if self.hparams.num_ori > 1 else "trivial"
            else:
                raise ValueError(f"Unsupported equivariance for PlatonicTransformer: {self.hparams.equivariance}")

            # Derive the number of heads from head_dim when it is specified,
            # and reject an explicitly given num_heads that contradicts it.
            if self.hparams.head_dim is not None:
                num_heads = self.hparams.hidden_dim // (self.hparams.head_dim * PLATONIC_GROUPS[solid_name.lower()].G)
                if (self.hparams.num_heads is not None) and (num_heads != self.hparams.num_heads):
                    raise ValueError(f"head_dim {self.hparams.head_dim} does not match num_heads {self.hparams.num_heads} ")
                self.hparams.num_heads = num_heads

            self.net = PlatonicTransformer(
                input_dim=in_channels_scalar,
                input_dim_vec=in_channels_vector,
                hidden_dim=self.hparams.hidden_dim,
                output_dim=10, # 10 classes for CIFAR10
                output_dim_vec=0,
                nhead=self.hparams.num_heads,
                num_layers=self.hparams.layers,
                solid_name=solid_name,
                ffn_dim_factor=0.25,
                task_level="graph",
                dropout=self.hparams.dropout,
                norm_first=self.hparams.norm_first,
                freq_sigma=self.hparams.freq_sigma,
                learned_freqs=self.hparams.learned_freqs,
                spatial_dim=2, # CIFAR10 is treated as a 2D point cloud
                dense_mode=self.hparams.dense_mode,
                mean_aggregation=self.hparams.mean_aggregation,
                attention=self.hparams.attention,
                post_pool_readout=self.hparams.post_pool_readout,
                ffn_readout=self.hparams.ffn_readout,
            )
            self.net = torch.compile(self.net)

        elif self.hparams.model_type == "rapidash":
            self.net = Rapidash(
                input_dim=in_channels_scalar,
                hidden_dim=self.hparams.hidden_dim,
                output_dim=10,
                num_layers=self.hparams.layers,
                edge_types=self.hparams.edge_types,
                equivariance=self.hparams.equivariance,
                ratios=self.hparams.ratios,
                output_dim_vec=0,
                dim=2,
                num_ori=self.hparams.num_ori,
                degree=self.hparams.degree,
                widening_factor=self.hparams.widening_factor,
                layer_scale=self.hparams.layer_scale,
                task_level='graph',
                last_feature_conditioning=False,
                skip_connections=self.hparams.skip_connections,
                basis_dim=self.hparams.basis_dim,
                basis_hidden_dim=self.hparams.basis_hidden_dim
            )

        else:
            # Fail at construction time instead of leaving self.net undefined.
            raise ValueError(f"Unsupported model_type: {self.hparams.model_type}")

        # Setup metrics (one accuracy tracker per stage)
        self.train_metric = torchmetrics.Accuracy(task="multiclass", num_classes=10)
        self.valid_metric = torchmetrics.Accuracy(task="multiclass", num_classes=10)
        self.test_metric = torchmetrics.Accuracy(task="multiclass", num_classes=10)

    def forward(self, data):
        """Return the class predictions for a batch of patch point clouds.

        In training mode with ``train_augm`` enabled, one random rotation is
        drawn and applied to all positions; ``data.pos`` is overwritten with
        the rotated coordinates.

        Raises:
            ValueError: If ``model_type`` is not a supported model.
        """
        # Apply rotation augmentation during training if enabled
        if self.training and self.hparams.train_augm:
            rot = self.rotation_generator().type_as(data.pos)
            data.pos = torch.einsum('ij,bj->bi', rot, data.pos)

        # Forward pass through the selected network; only the first element
        # of the returned pair is used as the prediction.
        if self.hparams.model_type == "platoformer":
            pred, _ = self.net(data.x, data.pos, data.batch, vec=None, avg_num_nodes=self.avg_num_nodes)
        elif self.hparams.model_type == "rapidash":
            # Pass data.edge_index (which will be None). The model should build the graph internally.
            pred, _ = self.net(data.x, data.pos, data.edge_index, data.batch, vec=None)
        else:
            raise ValueError(f"Unsupported model_type: {self.hparams.model_type}")
        return pred

    def training_step(self, data, batch_idx):
        """Compute the cross-entropy loss, update train accuracy and log the loss."""
        pred = self(data)
        loss = torch.nn.functional.cross_entropy(pred, data.y)
        self.train_metric(pred, data.y)
        # data.y has one label per image, so its length is the batch size;
        # passing it explicitly avoids Lightning guessing it from the batch object.
        self.log("train_loss", loss, prog_bar=True, batch_size=data.y.size(0))
        return loss

    def validation_step(self, data, batch_idx):
        """Compute the validation loss, update validation accuracy and log the loss."""
        pred = self(data)
        loss = torch.nn.functional.cross_entropy(pred, data.y)
        self.valid_metric(pred, data.y)
        # Explicit batch size, see training_step.
        self.log("valid_loss", loss, batch_size=data.y.size(0))

    def test_step(self, data, batch_idx):
        """Update the test accuracy with the predictions for one batch."""
        pred = self(data)
        self.test_metric(pred, data.y)

    def on_train_epoch_end(self):
        """Log the accumulated training accuracy."""
        self.log("train_acc", self.train_metric, prog_bar=True)

    def on_validation_epoch_end(self):
        """Log the accumulated validation accuracy."""
        self.log("valid_acc", self.valid_metric, prog_bar=True)

    def on_test_epoch_end(self):
        """Log the accumulated test accuracy."""
        self.log("test_acc", self.test_metric, prog_bar=True)

    def configure_optimizers(self):
        """Create the AdamW optimizer and the cosine-warmup learning-rate scheduler."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler = CosineWarmupScheduler(optimizer, self.hparams.warmup, self.trainer.max_epochs)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

class _PatchCollator:
    """Picklable collate callable turning (image, label) samples into a batch of patch point clouds.

    It is a module-level class rather than a closure so that DataLoader worker
    processes that have to pickle ``collate_fn`` can do so.
    """

    def __init__(self, patch_size, patch_pos):
        # Side length of the square patches
        self.patch_size = patch_size
        # Patch coordinates shared by all images, shape [num_patches, 2]
        self.patch_pos = patch_pos

    def __call__(self, batch):
        p = self.patch_size
        data_list = []

        for image_tensor, label in batch:
            # image_tensor shape: [C, H, W] = [3, 32, 32]
            # 1. Unfold image into patches: [C, H, W] -> [C, num_patches_h, num_patches_w, p, p]
            patches = image_tensor.unfold(1, p, p).unfold(2, p, p)
            # 2. Move the patch grid to the front and flatten each patch:
            #    -> [num_patches_total, C*p*p], patches ordered row by row
            patches = patches.permute(1, 2, 0, 3, 4).contiguous()
            x = patches.view(-1, 3 * p * p)

            # One PyG Data object per image; y holds a single label per image
            data_list.append(Data(x=x, pos=self.patch_pos.clone(), y=torch.tensor([label])))

        # Batch the list of Data objects
        return Batch.from_data_list(data_list)


def load_data(args):
    """Load CIFAR10 and return loaders yielding batches of patch point clouds.

    The official training set is split 90/10 into train/validation with a
    generator seeded by ``args.seed``. Training samples use random crop and
    horizontal flip; validation and test samples are only normalized.

    Args:
        args: Namespace providing ``patch_size``, ``data_dir``, ``seed``,
            ``batch_size`` and ``num_workers``.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.

    Raises:
        ValueError: If ``patch_size`` is not positive or does not divide 32.
    """

    # --- Assert that patch size is valid ---
    if args.patch_size <= 0:
        raise ValueError("patch_size must be a positive integer.")
    if 32 % args.patch_size != 0:
        raise ValueError("Image dimension (32) must be divisible by patch_size.")

    # --- Standard image augmentations and normalization ---
    transform_train = torchvision.transforms.Compose([
        torchvision.transforms.RandomCrop(32, padding=4),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    ])

    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    ])

    # --- Create a shared grid of PATCH coordinates, scaled to [0, 1] ---
    num_patches_1d = 32 // args.patch_size
    grid = torch.linspace(0.0, 1.0, num_patches_1d)
    # 'xy' indexing makes x vary along columns and y along rows, which matches
    # the row-by-row patch order produced by the collator.
    grid_x, grid_y = torch.meshgrid(grid, grid, indexing='xy')
    # pos shape will be [num_patches, 2]
    patch_pos = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=1)

    collate_fn = _PatchCollator(args.patch_size, patch_pos)

    # --- Create datasets ---
    full_train_dataset = torchvision.datasets.CIFAR10(root=args.data_dir, train=True, transform=transform_train, download=True)
    # Second view of the same training images WITHOUT augmentation, used for validation.
    # The files were fetched by the call above, so no download is requested here.
    full_eval_dataset = torchvision.datasets.CIFAR10(root=args.data_dir, train=True, transform=transform, download=False)
    test_dataset = torchvision.datasets.CIFAR10(root=args.data_dir, train=False, transform=transform, download=True)

    train_size = int(0.9 * len(full_train_dataset))
    val_size = len(full_train_dataset) - train_size
    train_dataset, val_split = torch.utils.data.random_split(
        full_train_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(args.seed))
    # Reuse the held-out indices of the split, but read them through the
    # non-augmented view so validation is not randomly cropped/flipped.
    val_dataset = torch.utils.data.Subset(full_eval_dataset, val_split.indices)

    # --- Create DataLoaders ---
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, collate_fn=collate_fn)

    return train_loader, val_loader, test_loader

def main(args):
    """Train and test, or only test, the CIFAR10 model according to ``args``.

    Without ``args.test_ckpt`` the model is trained (optionally resuming from
    ``args.resume_ckpt``) and then tested with the best checkpoint according to
    ``valid_acc``. With ``args.test_ckpt`` the model is loaded from that
    checkpoint and only tested.

    Args:
        args: Parsed command-line namespace.
    """
    pl.seed_everything(args.seed)
    train_loader, val_loader, test_loader = load_data(args)

    if args.gpus > 0:
        accelerator, devices = "gpu", args.gpus
    else:
        accelerator, devices = "cpu", "auto"

    if args.log:
        save_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "logs")
        logger = pl.loggers.WandbLogger(project="Platonic-CIFAR10", config=args, save_dir=save_dir)
    else:
        logger = None

    # Keep the checkpoint with the highest validation accuracy plus the last one
    callbacks = [pl.callbacks.ModelCheckpoint(monitor='valid_acc', mode='max', save_last=True), TimerCallback()]
    if args.log:
        callbacks.append(pl.callbacks.LearningRateMonitor(logging_interval='epoch'))

    trainer = pl.Trainer(logger=logger, max_epochs=args.epochs, callbacks=callbacks, gradient_clip_val=0.5,
                         accelerator=accelerator, devices=devices, enable_progress_bar=args.enable_progress_bar)

    if args.test_ckpt is None:
        model = CIFAR10Model(args)
        # Honor --resume_ckpt; None (the default) starts training from scratch.
        # getattr keeps callers working that build args without this field.
        resume_ckpt = getattr(args, "resume_ckpt", None)
        trainer.fit(model, train_loader, val_loader, ckpt_path=resume_ckpt)
        trainer.test(model, test_loader, ckpt_path='best')
    else:
        model = CIFAR10Model.load_from_checkpoint(args.test_ckpt)
        trainer.test(model, test_loader)


if __name__ == "__main__":
    # Command-line entry point: declare all options, parse them, apply an
    # optional sweep-config override and start the run.
    parser = argparse.ArgumentParser(description='CIFAR10 Point Cloud Classification Training')

    # Each entry is (flag, keyword arguments for add_argument); registration order is kept.
    # NOTE: options declared with type=eval evaluate the given string as a Python
    # expression (so booleans, lists and None can be passed) -- only pass trusted values.
    cli_options = [
        # --- General Training Parameters ---
        ("--epochs", dict(type=int, default=100, help="Number of training epochs")),
        ("--warmup", dict(type=int, default=10, help="Number of warmup epochs")),
        ("--batch_size", dict(type=int, default=128, help="Batch size")),
        ("--lr", dict(type=float, default=5e-4, help="Learning rate")),
        ("--weight_decay", dict(type=float, default=1e-10, help="Weight decay")),
        ("--seed", dict(type=int, default=0, help="Random seed")),

        # --- Model Architecture ---
        ("--model_type", dict(type=str, default="platoformer", choices=["rapidash", "platoformer"], help="Model to use")),
        ("--hidden_dim", dict(type=eval, default=768, help="Hidden dimension(s)")),
        ("--layers", dict(type=eval, default=5, help="Number of layers or layers per scale")),
        ("--equivariance", dict(type=str, default="SEn", help="Type of equivariance (Tn, SEn)")),
        ("--num_ori", dict(type=int, default=12, help="Number of orientations")),

        # --- Rapidash Specific Parameters ---
        ("--basis_dim", dict(type=int, default=256, help="Basis dimension")),
        ("--basis_hidden_dim", dict(type=int, default=128, help="Hidden dimension of the basis function MLP")),
        ("--degree", dict(type=int, default=3, help="Polynomial degree")),
        ("--edge_types", dict(type=eval, default='["fc"]', help="Edge types for each layer")),
        ("--ratios", dict(type=eval, default="[]", help="Pooling ratios for U-Net architecture (empty for no pooling)")),
        ("--widening_factor", dict(type=int, default=4, help="Network widening factor")),
        ("--layer_scale", dict(type=eval, default=None, help="Layer scaling factor")),
        ("--skip_connections", dict(type=eval, default=False, help="Use U-Net style skip connections")),

        # --- Platonic Transformer Specific Parameters ---
        ("--num_heads", dict(type=int, default=None, help="Number of attention heads")),
        ("--head_dim", dict(type=int, default=16, help="Implicitly defines number of heads")),
        ("--freq_sigma", dict(type=float, default=1.0, help="Sigma for RFF positional encoding")),
        ("--dropout", dict(type=float, default=0.0, help="Dropout rate")),
        ("--norm_first", dict(type=eval, default=True, help="Use LayerNorm before attention")),
        ("--learned_freqs", dict(type=eval, default=True, help="Learnable frequencies for RFF")),
        ("--dense_mode", dict(type=eval, default=True, help="Use dense attention")),
        ("--mean_aggregation", dict(type=eval, default=False, help="Use mean aggregation instead of sum")),
        ("--attention", dict(type=eval, default=False, help="Use attention in the model")),
        ("--post_pool_readout", dict(type=eval, default=True, help="Use post-pooling readout")),
        ("--ffn_readout", dict(type=eval, default=False, help="Use FFN readout after pooling")),

        # --- Data and Augmentation ---
        ("--train_augm", dict(type=eval, default=True, help="Use rotation augmentation during training")),
        ("--patch_size", dict(type=int, default=8, help="Side length of the square image patches")),
        ("--data_dir", dict(type=str, default="./datasets/cifar10", help="Data directory")),

        # --- System and Checkpointing ---
        ("--gpus", dict(type=int, default=1, help="Number of GPUs")),
        ("--num_workers", dict(type=int, default=4, help="Number of data loading workers")),
        ("--log", dict(type=eval, default=True, help="Enable logging")),
        ("--enable_progress_bar", dict(type=eval, default=True, help="Show progress bar")),
        ("--test_ckpt", dict(type=str, default=None, help="Checkpoint for testing")),
        ("--resume_ckpt", dict(type=str, default=None, help="Checkpoint to resume from")),
        ("--config", dict(type=eval, default=None, help="Sweep configuration dictionary")),
    ]
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # Entries of the sweep configuration take precedence over the parsed values.
    sweep_config = args.config
    if sweep_config is not None:
        for option_name, option_value in sweep_config.items():
            setattr(args, option_name, option_value)

    main(args)