import os
import argparse

import torch
import torchmetrics
import pytorch_lightning as pl
from datasets.scanobjectnn import ScanObjectNN
from torch_geometric.loader import DataLoader
import torch_geometric.nn as tgnn

from models.platoformer.platoformer import PlatonicTransformer
from models.platoformer.groups import PLATONIC_GROUPS

from utils import (CosineWarmupScheduler, NormalizeCoord, RandomJitter,
                   RandomRotatePerturbation, RandomShift, RandomSOd,
                   RandomSO2AroundAxis, SamplePoints, TimerCallback)

# Performance optimization (process-wide, runs at import time): allow float32
# matmuls to use faster reduced-precision internal kernels ('medium'); see the
# torch.set_float32_matmul_precision documentation for the exact trade-off.
torch.set_float32_matmul_precision('medium')

# Some augmentation functions
import torch_geometric.transforms as T
from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform, RandomScale


def patchify(data, sampling_ratio=0.125, num_neighbours=16):
    """Group a batched point cloud into local patches around FPS-sampled centres.

    Centres are picked per cloud (as given by ``data.batch``) with farthest
    point sampling; each centre is then described by the offsets from the
    centre to its ``num_neighbours`` nearest points in the same cloud.

    Args:
        data: batched point cloud with ``pos`` and ``batch`` attributes
            (assumed ``[N, D]`` and ``[N]`` — TODO confirm against callers).
        sampling_ratio: fraction of each cloud's points kept as centres.
        num_neighbours: number of nearest neighbours gathered per centre.

    Returns:
        centers: ``[M, D]`` positions of the sampled centres.
        vectors: ``[M, num_neighbours, D]`` neighbour positions relative to
            their centre.
        batch_centers: ``[M]`` cloud index of every centre.

    Raises:
        ValueError: if some centre received fewer than ``num_neighbours``
            neighbours (e.g. a cloud with fewer points than ``num_neighbours``),
            since the neighbours could then not be reshaped per centre.
    """
    center_ids = tgnn.fps(data.pos, data.batch, sampling_ratio)
    centers = data.pos[center_ids]  # [nbr centers, D]
    batch_centers = data.batch[center_ids]
    knn_ids = tgnn.knn(
        x=data.pos,
        y=centers,
        batch_x=data.batch,
        batch_y=batch_centers,
        k=num_neighbours,
    )
    # ``knn`` returns a [2, E] assignment: row 0 indexes the queries ``y`` (the
    # centres), row 1 indexes the matched points in ``x`` (``data.pos``).  The
    # neighbour coordinates must therefore be gathered with row 1; gathering
    # with row 0 would read the first M points of the whole batch instead.
    center_idx, point_idx = knn_ids[0], knn_ids[1]

    num_centers = centers.size(0)
    if center_idx.numel() != num_centers * num_neighbours:
        raise ValueError(
            f"Expected {num_neighbours} neighbours for each of the {num_centers} "
            f"centres, but knn returned {center_idx.numel()} pairs."
        )

    # Group the (centre, point) pairs by centre explicitly with a stable sort, so
    # the reshape below is valid whatever order knn emits the pairs in.
    _, order = torch.sort(center_idx, stable=True)
    neighbours = data.pos[point_idx[order]].reshape(num_centers, num_neighbours, -1)
    vectors = neighbours - centers[:, None]  # [nbr centers, nbr neighbours, D]
    return centers, vectors, batch_centers


class ScanObjectNNModel(pl.LightningModule):
    """LightningModule classifying ScanObjectNN point clouds with a PlatonicTransformer.

    ``args`` is stored with ``save_hyperparameters`` and read back through
    ``self.hparams``. The network receives no scalar node features; its vector
    input per node is assembled in ``forward`` in this order:

    * ``patchify_num_neighbours`` neighbour offsets (only if ``patchify``),
    * 3 vectors: the transpose of the rotation matrix applied to the cloud,
    * the node position (only if ``use_positions_as_vector_input``).
    """

    # Maps the ``--equivariance`` option to the solid name handed to PlatonicTransformer.
    _SOLID_BY_EQUIVARIANCE = {
        "Tn": "trivial",
        "SEn": "tetrahedron",
        "TnFlip": "flip_3d_axis0",
    }

    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters(args)

        # Rotation augmentation: random rotation about axis 1 (in-plane) or a random 3D rotation.
        if args.inplane_augm:
            self.rotation_generator = RandomSO2AroundAxis(axis=1, degrees=180)
        else:
            self.rotation_generator = RandomSOd(3)
        # Nodes per cloud seen by the network: patch centres when patchifying, raw points otherwise.
        self.avg_num_nodes = args.patchify_num_centers if args.patchify else args.num_points

        # Input channel counts; these must match the features built in ``forward``.
        in_channels_scalar = 0
        in_channels_vector = 3  # reference frame (3 rotation-matrix vectors), always added
        if self.hparams.use_positions_as_vector_input:
            in_channels_vector += 1
        if self.hparams.patchify:
            in_channels_vector += self.hparams.patchify_num_neighbours

        # Defensive fallback mirroring ``forward``: use a constant scalar channel if there is no input.
        if in_channels_scalar + in_channels_vector == 0:
            in_channels_scalar = 1

        # Select the symmetry group of the model.
        equivariance = self.hparams.equivariance
        if equivariance not in self._SOLID_BY_EQUIVARIANCE:
            supported = ", ".join(
                f"'{name}' ({solid})" for name, solid in self._SOLID_BY_EQUIVARIANCE.items()
            )
            raise ValueError(f"Unsupported equivariance type: {equivariance}. "
                             f"Supported types are {supported}.")
        solid_name = self._SOLID_BY_EQUIVARIANCE[equivariance]

        # Derive num_heads from head_dim when head_dim is given; hidden_dim is split over
        # head_dim * PLATONIC_GROUPS[...].G per head (``.G`` presumably the group size).
        if self.hparams.head_dim is not None:
            num_heads = self.hparams.hidden_dim // (self.hparams.head_dim * PLATONIC_GROUPS[solid_name.lower()].G)
            if (self.hparams.num_heads is not None) and (num_heads != self.hparams.num_heads):
                raise ValueError(f"head_dim {self.hparams.head_dim} does not match num_heads {self.hparams.num_heads} ")
            self.hparams.num_heads = num_heads

        num_classes = 15

        # A missing (None) or non-positive ape_sigma is passed to the network as None.
        raw_ape_sigma = self.hparams.ape_sigma
        ape_sigma = raw_ape_sigma if (raw_ape_sigma is not None and raw_ape_sigma > 0.0) else None
        self.net = PlatonicTransformer(
            input_dim=in_channels_scalar,
            input_dim_vec=in_channels_vector,
            hidden_dim=self.hparams.hidden_dim,
            output_dim=num_classes,
            output_dim_vec=0,
            nhead=self.hparams.num_heads,
            num_layers=self.hparams.layers,
            solid_name=solid_name,
            ffn_dim_factor=4,
            task_level="graph",
            dropout=self.hparams.dropout,
            norm_first=self.hparams.norm_first,
            freq_sigma=self.hparams.freq_sigma,
            ape_sigma=ape_sigma,
            learned_freqs=self.hparams.learned_freqs,
            spatial_dim=3,
            dense_mode=self.hparams.dense_mode,
            mean_aggregation=self.hparams.mean_aggregation,
            use_key=self.hparams.use_key,
        )

        # One accuracy metric per stage so their states never mix.
        self.train_metric = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.valid_metric = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.test_metric = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, data):
        """Return class logits for a batch of point clouds.

        NOTE: ``data`` is modified in place. ``pos`` (and ``normal`` if present)
        is rotated when augmentation is active. ``pos``/``batch`` are replaced
        by the patch centres when patchifying. ``pos`` is zeroed when
        ``use_positions`` is off.
        """
        # Rotation augmentation: train_augm controls training, test_augm controls val/test.
        augment = self.hparams.train_augm if self.training else self.hparams.test_augm
        if augment:
            # One random rotation shared by the whole batch.
            rot = self.rotation_generator().type_as(data.pos)
            data.pos = torch.einsum('ij,bj->bi', rot, data.pos)
            if hasattr(data, 'normal'):
                data.normal = torch.einsum('ij,bj->bi', rot, data.normal)
        else:
            # Match the augmented branch's dtype/device so the frame concatenates cleanly.
            rot = torch.eye(3, device=data.pos.device, dtype=data.pos.dtype)

        x = []  # scalar features
        vec = []  # vector features, concatenated along dim 1

        if self.hparams.patchify:
            centers, vectors, batch_centers = patchify(
                data,
                sampling_ratio=self.hparams.patchify_num_centers/self.hparams.num_points,
                num_neighbours=self.hparams.patchify_num_neighbours,
            )
            data.pos = centers
            data.batch = batch_centers
            vec.append(vectors)

        # Reference frame: the transposed rotation, repeated for every node as 3 vector channels.
        vec.append(rot.transpose(-2, -1).unsqueeze(0).expand(data.pos.shape[0], -1, -1))

        if self.hparams.use_positions_as_vector_input:
            vec.append(data.pos[:, None, :])

        if not self.hparams.use_positions:
            data.pos = torch.zeros_like(data.pos)

        # Only fall back to constant ones if there are neither scalar nor vector features.
        if not x and not vec:
            x = torch.ones(data.pos.size(0), 1).type_as(data.pos)
        else:
            x = torch.cat(x, dim=-1) if x else None
        vec = torch.cat(vec, dim=1) if vec else None

        pred, _ = self.net(x, data.pos, data.batch, vec=vec, avg_num_nodes=self.avg_num_nodes)
        return pred

    def training_step(self, data, batch_idx):
        pred = self(data)
        loss = torch.nn.functional.cross_entropy(pred, data.y)
        self.train_metric(pred, data.y)
        # Lightning cannot reliably infer the batch size from a PyG batch, so pass it explicitly.
        self.log("train_loss", loss, prog_bar=True, batch_size=pred.size(0))
        return loss

    def validation_step(self, data, batch_idx):
        pred = self(data)
        self.valid_metric(pred, data.y)

    def test_step(self, data, batch_idx):
        pred = self(data)
        self.test_metric(pred, data.y)

    def on_train_epoch_end(self):
        self.log("train_acc", self.train_metric, prog_bar=True)

    def on_validation_epoch_end(self):
        self.log("valid_acc", self.valid_metric, prog_bar=True)

    def on_test_epoch_end(self):
        # Keep rotated and unrotated test accuracies under separate keys.
        suffix = "_rotated" if self.hparams.test_augm else ""
        self.log(f"test_acc{suffix}", self.test_metric, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler = CosineWarmupScheduler(optimizer, self.hparams.warmup, self.trainer.max_epochs)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

def load_data(args):
    """Build the ScanObjectNN train and test DataLoaders (PyG).

    Training samples are randomly scaled and jittered; test samples are left
    untransformed. Coordinate normalization (``NormalizeCoord``) is currently
    disabled for both splits.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    # Per-split transform pipelines.
    jitter = RandomJitter(
        sigma=args.noise_strength,
        clip=5 * args.noise_strength,
        relative=True,
    )
    transforms = {
        'train': T.Compose([RandomScale((args.min_scale, args.max_scale)), jitter]),
        'test': T.Compose([]),
    }

    # Build both datasets first, with identical settings apart from subset/transform.
    datasets = {}
    for subset in ('train', 'test'):
        datasets[subset] = ScanObjectNN(
            args.data_dir,
            subset=subset,
            split=args.data_split,
            version=args.data_version,
            transform=transforms[subset],
            nbr_pts=args.num_points,
        )

    def make_loader(dataset, shuffle):
        return DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=shuffle,
            num_workers=args.num_workers,
        )

    return make_loader(datasets['train'], True), make_loader(datasets['test'], False)

def main(args):
    """Train and/or evaluate ``ScanObjectNNModel`` as configured by ``args``.

    Without ``args.test_ckpt``, the model is trained, resuming from
    ``args.resume_ckpt`` when it is set, and the test split also serves as the
    validation set. The last checkpoint is then evaluated twice: without and
    with test-time rotation augmentation.

    With ``args.test_ckpt``, that checkpoint is loaded and only the same two
    evaluations are run.
    """
    pl.seed_everything(args.seed)

    train_loader, test_loader = load_data(args)

    # Hardware configuration.
    if args.gpus > 0:
        accelerator, devices = "gpu", args.gpus
    else:
        accelerator, devices = "cpu", "auto"

    # Logging (Weights & Biases) is optional.
    if args.log:
        save_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "logs")
        logger = pl.loggers.WandbLogger(
            project="Platoformer-ScanObjectNN",
            config=args,
            save_dir=save_dir
        )
    else:
        logger = None

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        monitor='valid_acc',
        mode='max',
        save_last=True
    )
    callbacks = [checkpoint_callback, TimerCallback()]
    if args.log:
        # LearningRateMonitor needs a logger, so only add it when logging is on.
        callbacks.append(pl.callbacks.LearningRateMonitor(logging_interval='epoch'))

    trainer = pl.Trainer(
        logger=logger,
        max_epochs=args.epochs,
        callbacks=callbacks,
        gradient_clip_val=10,
        accelerator=accelerator,
        devices=devices,
        enable_progress_bar=args.enable_progress_bar
    )

    if args.test_ckpt is None:
        model = ScanObjectNNModel(args)
        # ckpt_path=None starts a fresh run; a path resumes from that checkpoint.
        trainer.fit(model, train_loader, test_loader, ckpt_path=getattr(args, "resume_ckpt", None))
        # Validation ran on the test split, so evaluate the *last* checkpoint rather
        # than the best-by-valid_acc one (presumably deliberate: selecting on the test set would bias it).
        test_ckpt_path = checkpoint_callback.last_model_path
    else:
        model = ScanObjectNNModel.load_from_checkpoint(args.test_ckpt)
        test_ckpt_path = None  # test the weights just loaded into ``model``

    # Test without, then with, rotation augmentation. Set the flag explicitly both
    # times, so the first run is unrotated even if the run/checkpoint had test_augm=True.
    for test_augm in (False, True):
        model.hparams.test_augm = test_augm
        trainer.test(model, test_loader, ckpt_path=test_ckpt_path)


## NOTES
# - Mamba3D uses 2048 points per cloud (the same as the --num_points default).
# - A surprisingly large freq_sigma seems to work well.

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='ScanObjectNN Classification Training')

    # NOTE(review): many options use ``type=eval`` so Python literals (True/False/None,
    # lists, dicts) can be passed on the command line. ``eval`` executes arbitrary code,
    # so only run this script with trusted arguments and --config values.

    # Training parameters
    parser.add_argument('--epochs', type=int, default=300, help='Number of training epochs')
    parser.add_argument('--warmup', type=int, default=10, help='Number of warmup epochs')
    parser.add_argument('--batch_size', type=int, default=64, help='Batch size')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=5e-2, help='Weight decay')
    parser.add_argument('--seed', type=int, default=0, help='Random seed')

    # Model architecture
    parser.add_argument('--model_type', type=str, default='platoformer', choices=['platoformer'], help='Type of model to use: currently only transformer.')
    parser.add_argument('--hidden_dim', type=eval, default=576, help='Hidden dimension(s), for rapidash [256,256,256,256]')
    parser.add_argument('--layers', type=eval, default=4, help='Layers per scale, for rapidash [0, 1, 1, 1]')
    # Accepted by ScanObjectNNModel: 'Tn', 'SEn', 'TnFlip'.
    parser.add_argument('--equivariance', type=str, default="Tn", help='Type of equivariance')

    # Platonic Transformer specific parameters
    parser.add_argument('--num_heads', type=int, default=12, help='Number of attention heads (Transformer only).')
    parser.add_argument('--head_dim', type=int, default=None, help='Dimension of attention heads (Transformer only).')
    parser.add_argument('--freq_sigma', type=float, default=10.0, help='Sigma for RFF positional encoding (Transformer only).')
    parser.add_argument('--ape_sigma', type=eval, default=10.0, help='Sigma for RFF positional encoding')
    parser.add_argument('--learned_freqs', type=eval, default=True, help='Use Rotary Position Embedding (RoPE) in Transformer.')
    parser.add_argument('--use_key', type=eval, default=False, help='Use key projection when using RoPE')
    parser.add_argument('--freq_init', type=str, default='spiral', choices=['random', 'spiral'], help='Frequency initialization method for RoPE')
    parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate (Transformer only).')
    parser.add_argument('--norm_first', type=eval, default=True, help='Use LayerNorm before attention in Transformer.')
    parser.add_argument('--use_positions', type=eval, default=True, help='Use positions (otherwise they are zeroed out).')
    parser.add_argument('--use_positions_as_vector_input', type=eval, default=False, help='Use positions as vector input.')
    parser.add_argument('--patchify', type=eval, default=True, help='Patchify pointcloud using FPS and KNN.')
    parser.add_argument('--patchify_num_centers', type=int, default=128, help='Number of centers to sample using FPS if patchifying.')
    parser.add_argument('--patchify_num_neighbours', type=int, default=32, help='Number of neighbours to calculate per center if patchifying.')
    parser.add_argument('--dense_mode', type=eval, default=True, help='Use dense mode.')
    parser.add_argument('--mean_aggregation', type=eval, default=False, help='Use mean aggregation.')
    parser.add_argument('--attention', type=eval, default=False, help='Use attention in PlatonicConv layers.')

    # Training features
    parser.add_argument('--train_augm', type=eval, default=True, help='Use rotation augmentation during training')
    parser.add_argument('--inplane_augm', type=eval, default=True, help='Use only inplane rotation augmentation')
    parser.add_argument('--test_augm', type=eval, default=False, help='Use rotation augmentation during testing')
    parser.add_argument('--num_points', type=int, default=2048, help='Number of points to sample')
    parser.add_argument('--noise_strength', type=float, default=0.1, help='Strength of jitter noise')
    parser.add_argument('--min_scale', type=float, default=0.8, help='Min scale aug')
    parser.add_argument('--max_scale', type=float, default=1.25, help='Max scale aug')

    # Sweep configuration
    parser.add_argument('--config', type=eval, default=None, help='Sweep configuration dictionary')
    parser.add_argument('--model_id', type=int, default=None, help='Model ID in case you would want to label the configuration')

    # Data and logging
    parser.add_argument('--data_dir', type=str, default="./datasets/scanobjectnn/h5_files", help='Data directory')
    parser.add_argument('--data_split', type=str, default="main_split")
    parser.add_argument('--data_version', type=str, default="_augmentedrot_scale75")
    parser.add_argument('--log', type=eval, default=True, help='Enable logging')

    # System and checkpointing
    parser.add_argument('--gpus', type=int, default=1, help='Number of GPUs')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of data loading workers')
    parser.add_argument('--enable_progress_bar', type=eval, default=True, help='Show progress bar')
    parser.add_argument('--test_ckpt', type=str, default=None, help='Checkpoint for testing')
    parser.add_argument('--resume_ckpt', type=str, default=None, help='Checkpoint to resume from')

    args = parser.parse_args()

    # Values from a --config dictionary override the parsed arguments (any key is accepted).
    if args.config is not None:
        for key, value in args.config.items():
            setattr(args, key, value)

    # In-plane augmentation is only allowed together with training augmentation.
    if args.inplane_augm and not args.train_augm:
        raise ValueError(
            "inplane_augm=True requires train_augm=True; "
            "pass --inplane_augm False to train without rotation augmentation."
        )

    main(args)