import argparse
import math
import numbers
import os
import random
import time
from typing import Tuple, Union

import pytorch_lightning as pl
import torch
import torchmetrics
from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform

from datasets.motion import MotionDataset
from models.rapidash_v4 import Rapidash
from utils.utils import CosineWarmupScheduler, RandomSOd

# Process-wide setting applied at import time: permit reduced internal
# precision for float32 matrix multiplications ('medium').
# NOTE(review): presumably chosen to trade some accuracy for speed on GPUs
# with fast reduced-precision matmul paths — confirm it is acceptable for the
# reported MSE numbers.
torch.set_float32_matmul_precision('medium')


class RandomRotatePointCloud(BaseTransform):
    r"""Applies a random rotation about one coordinate axis to a point cloud.

    The angle is drawn uniformly from an interval given in degrees, turned
    into a rotation matrix, and applied to the data object through
    :class:`LinearTransformationPointCloud`.

    Args:
        degrees (tuple or float): Interval (in degrees) from which the
            rotation angle is sampled. A single number :obj:`d` is shorthand
            for the symmetric interval :math:`[-|d|, |d|]`.
        axis (int, optional): Rotation axis used for non-2D positions:
            :obj:`0`, :obj:`1`, or any other value for the last axis. It is
            not consulted when positions are 2D. (default: :obj:`0`)
    """
    def __init__(self, degrees: Union[Tuple[float, float], float],
                 axis: int = 0):
        # A bare number is expanded into a symmetric interval around zero.
        if isinstance(degrees, numbers.Number):
            degrees = (-abs(degrees), abs(degrees))
        assert isinstance(degrees, (tuple, list)) and len(degrees) == 2
        self.axis = axis
        self.degrees = degrees

    def __call__(self, data: Data) -> Data:
        """Sample an angle, build the rotation matrix and apply it to `data`."""
        # Degrees -> radians.
        angle = math.pi * random.uniform(*self.degrees) / 180.0
        s = math.sin(angle)
        c = math.cos(angle)

        if data.pos.size(-1) == 2:
            # Planar rotation; `axis` plays no role here.
            rows = [[c, s], [-s, c]]
        elif self.axis == 0:
            rows = [[1, 0, 0], [0, c, s], [0, -s, c]]
        elif self.axis == 1:
            rows = [[c, 0, -s], [0, 1, 0], [s, 0, c]]
        else:
            rows = [[c, s, 0], [-s, c, 0], [0, 0, 1]]

        transform = LinearTransformationPointCloud(torch.tensor(rows))
        return transform(data)

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f'{cls_name}({self.degrees}, axis={self.axis})'

class LinearTransformationPointCloud(BaseTransform):
    r"""Transforms the geometric attributes of a data object with a square
    transformation matrix computed offline.

    The same matrix is applied to :obj:`pos`, :obj:`vel`, :obj:`normals`,
    :obj:`y` and :obj:`rot`, so all of them must be present on the data
    object.

    Args:
        matrix (Tensor): tensor with shape :obj:`[D, D]` where :obj:`D`
            corresponds to the dimensionality of node positions.
    """
    # Attributes that are shape-checked against the matrix before being
    # transformed. `rot` is transformed as well but is not shape-checked.
    _NODE_FIELDS = ('pos', 'vel', 'normals', 'y')

    def __init__(self, matrix):
        assert matrix.dim() == 2, (
            'Transformation matrix should be two-dimensional.')
        assert matrix.size(0) == matrix.size(1), (
            'Transformation matrix should be square. Got [{} x {}] '
            'rectangular matrix.'.format(*matrix.size()))

        # Store the matrix as its transpose.
        # We do this to enable post-multiplication in `__call__`.
        self.matrix = matrix.t()

    def _apply(self, tensor):
        """Post-multiply `tensor` by the stored (transposed) matrix."""
        matrix = self.matrix.to(device=tensor.device, dtype=tensor.dtype)
        return torch.matmul(tensor, matrix)

    def __call__(self, data):
        """Apply the transformation in place to `data` and return it.

        Raises:
            AssertionError: If the last dimension of any of `pos`, `vel`,
                `normals` or `y` does not match the matrix size. All checks
                run before any attribute is modified.
        """
        # 1-D attributes are treated as a column of 1-D points.
        fields = {}
        for name in self._NODE_FIELDS:
            value = getattr(data, name)
            fields[name] = value.view(-1, 1) if value.dim() == 1 else value
        rot = data.rot

        # Validate everything first so a failure leaves `data` untouched,
        # and name the offending attribute in the message.
        expected = self.matrix.size(-2)
        for name, value in fields.items():
            assert value.size(-1) == expected, (
                f"Attribute '{name}' with last dimension {value.size(-1)} "
                f"and transformation matrix of size {expected} have "
                f"incompatible shape.")

        # We post-multiply the points by the transformation matrix instead of
        # pre-multiplying, because the attributes have shape `[N, D]`, and we
        # want to preserve this shape.
        for name, value in fields.items():
            setattr(data, name, self._apply(value))
        # NOTE(review): `rot` is post-multiplied like the per-node attributes
        # (original behaviour, kept as-is) — confirm this is the intended
        # composition for the stored rotation matrix.
        data.rot = self._apply(rot)
        return data

class TimerCallback(pl.Callback):
    """Reports total training time and test-epoch inference time in minutes.

    Timestamps are taken with ``time.time()`` in the start hooks and the
    elapsed minutes are sent to the trainer's logger in the matching end
    hooks. Reporting is skipped when the trainer has no logger.
    """

    def __init__(self):
        super().__init__()
        self.total_training_start_time = 0.0
        self.epoch_start_time = 0.0
        self.test_inference_time = 0.0

    @staticmethod
    def _report(trainer, name, minutes):
        """Send ``{name: minutes}`` to the trainer's logger, if there is one.

        The original code called ``trainer.logger.experiment.log`` directly,
        which raises when no logger is attached or when the logger's
        experiment object has no ``log`` method. That call is kept when
        available; otherwise the logger's ``log_metrics`` is used.
        """
        logger = trainer.logger
        if logger is None:
            return
        experiment_log = getattr(getattr(logger, "experiment", None), "log", None)
        if callable(experiment_log):
            experiment_log({name: minutes})
        else:
            logger.log_metrics({name: minutes})

    # Called when training begins
    def on_train_start(self, trainer, pl_module):
        self.total_training_start_time = time.time()

    # Called when training ends
    def on_train_end(self, trainer, pl_module):
        total_training_time = (time.time() - self.total_training_start_time) / 60
        # Log total training time at the end of training
        self._report(trainer, "Total Training Time (min)", total_training_time)

    # Called at the start of the test epoch
    def on_test_epoch_start(self, trainer, pl_module):
        self.epoch_start_time = time.time()

    # Called at the end of the test epoch
    def on_test_epoch_end(self, trainer, pl_module):
        # Calculate the inference time for the entire test epoch
        self.test_inference_time = (time.time() - self.epoch_start_time) / 60
        # Log the inference time for the test epoch
        self._report(trainer, "Test Inference Time (min)", self.test_inference_time)


class CMUMotionModel(pl.LightningModule):
    """Lightning module that predicts node positions on the CMU motion data.

    The network outputs a per-node displacement (scalar channels when
    ``fiber_dim == 0``, the first vector channel otherwise) that is added to
    the input position; the result is compared to the target ``y`` with MSE.

    Args:
        args: Hyperparameters, stored with ``save_hyperparameters`` and read
            back through ``self.hparams``.

    Raises:
        ValueError: If ``fiber_dim`` is negative or ``model`` is not
            ``'rapidash'``.
    """

    # Size of a predicted displacement; matches ``dim=3`` passed to the net
    # and ``out_channels_scalar = 3``.
    _SPATIAL_DIM = 3

    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters(args)
        hp = self.hparams

        # Generator of random rotations used for training augmentation
        self.rotation_generator = RandomSOd(3)
        self.n_joints = 31

        # With base space R3S2 the position is concatenated with the normals
        # in `unpack_batch`, hence 6 coordinate channels instead of 3.
        _dim = 6 if hp.base_space == "R3S2" else 3

        # Input forms (boolean flags act as 0/1 multipliers)
        in_channels_scalar = (
            _dim * hp.use_coords_as_scalars
            + 3 * hp.use_velocity_as_scalars
            + 3 * hp.use_normals_as_scalars
            + 9 * hp.use_pose_as_scalars
        )
        in_channels_vec = (
            1 * hp.use_coords_as_vectors
            + 1 * hp.use_velocity_as_vectors
            + 1 * hp.use_normals_as_vectors
            + 3 * hp.use_pose_as_vectors
        )

        # No feature selected: `unpack_batch` feeds a constant scalar channel.
        if in_channels_scalar == 0 and in_channels_vec == 0:
            in_channels_scalar = 1

        # Output forms
        if hp.fiber_dim == 0:
            out_channels_scalar = 3
            out_channels_vec = 0
        elif hp.fiber_dim > 0:
            out_channels_scalar = 0
            out_channels_vec = 1
        else:
            raise ValueError("fiber_dim must be >= 0")

        model_type = hp.model
        if model_type != 'rapidash':
            # Fail here instead of later with a missing `self.net` attribute.
            raise ValueError(
                f"Unsupported model '{model_type}'; only 'rapidash' is implemented.")
        self.net = Rapidash(input_dim         = in_channels_scalar + in_channels_vec,
                            hidden_dim        = hp.hidden_dim,
                            output_dim        = out_channels_scalar,
                            num_layers        = hp.layers,
                            edge_types        = hp.edge_types,
                            ratios            = hp.ratios,
                            output_dim_vec    = out_channels_vec,
                            dim               = 3,
                            basis_dim         = hp.basis_dim,
                            degree            = hp.degree,
                            widening_factor   = hp.widening_factor,
                            layer_scale       = hp.layer_scale,
                            task_level        = 'node',
                            multiple_readouts = hp.multiple_readouts,
                            last_feature_conditioning=False,
                            attention         = hp.attention,
                            fully_connected   = True,
                            residual_connections=hp.residual_connections,
                            global_basis      = False,
                            equivariance      = hp.equivariance,
                            base_space        = hp.base_space,
                            fiber_space       = hp.fiber_space,
                            fiber_dim         = hp.fiber_dim
                            )

        self.all_joint_metric = torch.nn.MSELoss(reduction='none')
        # Evaluation metric (validation / test).
        self.avg_metric = torchmetrics.MeanSquaredError()
        # Training gets its own metric: sharing one instance made training
        # batches accumulate into the logged "validation MSE".
        self.train_avg_metric = torchmetrics.MeanSquaredError()

        # Random rotation of up to 15 degrees about axis 1 (not the z axis)
        self.rotated_prediction = RandomRotatePointCloud(15, axis=1)
        self.train_step_loss = []
        self.val_step_loss = []

        # TODO: What does this do? (not read anywhere in this class)
        self.use_coords_as_features = False
        self.use_pose = False

        # Per-batch per-joint squared errors collected during val/test.
        self.all_joint_mse = []

    def unpack_batch(self, batch):
        """Augment the batch and assemble the network inputs.

        Returns:
            Tuple ``(x, vec, pos, edge_index, batch, y, ptr)`` where ``x`` are
            the concatenated scalar features (or ``None``), ``vec`` the
            concatenated vector features (or ``None``) and ``edge_index`` is
            always ``None``.
        """
        if self.training and self.hparams.train_augm:
            # Apply one random rotation to every geometric attribute.
            rot = self.rotation_generator().type_as(batch['pos']).contiguous()
            for key in ('pos', 'vel', 'normals', 'y'):
                batch[key] = torch.einsum('ij, bj->bi', rot, batch[key]).contiguous()
        else:
            rot = torch.eye(3, device=batch['pos'].device)

        batch['rot'] = rot

        # Optional extra random rotation; this also transforms `rot`.
        if self.hparams.rotated_prediction:
            batch = self.rotated_prediction(batch)

        # The features as provided
        pos = batch['pos']
        vel = batch['vel']
        batch_index = batch['batch']
        y = batch['y']
        ptr = batch['ptr']
        normals = batch['normals']
        rot = batch['rot']

        if self.hparams.base_space == "R3S2":
            pos = torch.cat([pos, normals], dim=-1)

        # The provided edge index is not forwarded; per the original comment
        # the model overwrites it with a fully connected edge index anyway.
        edge_index = None

        x = []
        vec = []
        # Use the coordinates as scalar / vector features
        if self.hparams.use_coords_as_scalars:
            x.append(pos)
        if self.hparams.use_coords_as_vectors:
            vec.append(pos[:, None, :])
        # Use the normals as scalar / vector features
        if self.hparams.use_normals_as_scalars:
            x.append(normals)
        if self.hparams.use_normals_as_vectors:
            vec.append(normals[:, None, :])
        # Use the velocity as scalar / vector features
        if self.hparams.use_velocity_as_scalars:
            x.append(vel)
        if self.hparams.use_velocity_as_vectors:
            vec.append(vel[:, None, :])
        # Use the (transposed) rotation, repeated for every node
        if self.hparams.use_pose_as_scalars:
            x.append(rot.transpose(-2, -1).unsqueeze(0).expand(pos.shape[0], -1, -1).flatten(-2, -1))
        if self.hparams.use_pose_as_vectors:
            vec.append(rot.transpose(-2, -1).unsqueeze(0).expand(pos.shape[0], -1, -1))

        x = torch.cat(x, dim=-1) if len(x) > 0 else None
        vec = torch.cat(vec, dim=-2) if len(vec) > 0 else None

        # No feature selected at all: feed a constant scalar channel.
        if x is None and vec is None:
            x = torch.ones(pos.size(0), 1).type_as(pos)

        return x, vec, pos, edge_index, batch_index, y, ptr

    def _predict(self, batch):
        """Run the network on `batch` and return ``(y_pred, y)``."""
        x, vec, pos, edge_index, batch_index, y, _ = self.unpack_batch(batch)
        y_pred_sc, y_pred_vec = self.net(x, pos, edge_index, batch_index, vec=vec)

        if self.hparams.fiber_dim == 0:
            delta_pos = y_pred_sc
        else:
            delta_pos = y_pred_vec[:, 0, :]
        # For base space R3S2 `pos` carries the normals in its last three
        # channels; only the leading spatial coordinates are displaced.
        # For R3 the slice is a no-op.
        return pos[:, :self._SPATIAL_DIM] + delta_pos, y

    def _per_joint_mse(self, y_pred, y):
        """Return squared errors averaged over the graphs of one batch.

        Nodes are grouped into consecutive runs of ``n_joints`` rows, so the
        number of nodes must be a multiple of ``n_joints``. The result has
        shape ``[n_joints, C]`` with ``C`` the last dimension of the inputs.
        """
        errors = self.all_joint_metric(y_pred, y)
        return errors.reshape(-1, self.n_joints, errors.shape[-1]).mean(dim=0)

    def _log_per_joint_mse(self, prefix):
        """Log the mean per-joint MSE over collected batches, then reset."""
        if not self.all_joint_mse:
            return
        mse = torch.stack(self.all_joint_mse).mean(dim=0)
        for i in range(self.n_joints):
            self.log(f"{prefix}({i})", mse[i].mean(), prog_bar=False)
        self.all_joint_mse.clear()

    def training_step(self, batch, batch_idx):
        """Return the MSE loss and accumulate the training metric."""
        y_pred, y = self._predict(batch)
        loss = torch.nn.functional.mse_loss(y_pred, y)
        self.train_avg_metric.update(y_pred.detach(), y)
        return loss

    def on_train_epoch_end(self):
        """Log the epoch-averaged training MSE and reset its accumulator."""
        self.log("train MSE", self.train_avg_metric.compute(), prog_bar=True)
        self.train_avg_metric.reset()

    def validation_step(self, batch, batch_idx):
        """Accumulate the overall and per-joint validation errors."""
        y_pred, y = self._predict(batch)
        self.avg_metric.update(y_pred, y)
        self.all_joint_mse.append(self._per_joint_mse(y_pred, y))

    def on_validation_epoch_end(self):
        """Log overall and per-joint validation MSE."""
        self.log("validation MSE", self.avg_metric, prog_bar=True)
        self._log_per_joint_mse("validation MSE for joint  ")

    def test_step(self, batch, batch_idx):
        """Accumulate the overall and per-joint test errors."""
        y_pred, y = self._predict(batch)
        self.avg_metric.update(y_pred, y)
        self.all_joint_mse.append(self._per_joint_mse(y_pred, y))

    def on_test_epoch_end(self):
        """Log overall and per-joint test MSE."""
        self.log("test MSE", self.avg_metric, prog_bar=True)
        self._log_per_joint_mse("Test MSE for joint  ")

    def configure_optimizers(self):
        """Create AdamW with a cosine warm-up schedule over ``max_epochs``."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler = CosineWarmupScheduler(optimizer, self.hparams.warmup, self.trainer.max_epochs)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}


def main(args):
    """Train and then test a :class:`CMUMotionModel`, or only test a checkpoint.

    Args:
        args: Parsed command-line namespace. When ``args.test_ckpt`` is
            ``None`` a new model is trained (optionally resuming from
            ``args.resume_ckpt``) and then tested with the best checkpoint;
            otherwise the given checkpoint is loaded and evaluated.
    """
    # Seed everything. The `--seed` flag used to be ignored in favour of a
    # hard-coded 42; 42 remains the fallback for namespaces without `seed`.
    pl.seed_everything(getattr(args, "seed", 42))

    # Load the data
    dataset = MotionDataset(batch_size=args.batch_size,
                            all_joint_normals=args.all_joint_normals,
                            num_training_samples=args.max_train_sample)
    train_loader = dataset.train_loader()
    val_loader = dataset.val_loader()
    test_loader = dataset.test_loader()

    # Hardware settings
    if args.gpus > 0 and torch.cuda.is_available():
        accelerator = "gpu"
        devices = args.gpus
    else:
        accelerator = "cpu"
        devices = "auto"
    # NOTE(review): num_workers is resolved here but not passed to the
    # dataset or trainer in this function.
    if args.num_workers == -1:
        args.num_workers = os.cpu_count()

    # Logging settings
    if args.log:
        save_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "logs")
        logger = pl.loggers.WandbLogger(project="CMU-Motion-Prediction", name=args.model,
                                        config=args, save_dir=save_dir)
    else:
        logger = None

    # Pytorch lightning call backs and trainer
    checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor='validation MSE', mode='min',
                                                       every_n_epochs=1, save_last=True)
    timer_callback = TimerCallback()
    callbacks = [checkpoint_callback, timer_callback]
    if args.log:
        callbacks.append(pl.callbacks.LearningRateMonitor(logging_interval='epoch'))
    trainer = pl.Trainer(logger=logger, max_epochs=args.epochs, callbacks=callbacks,
                         gradient_clip_val=0.5, accelerator=accelerator, devices=devices,
                         enable_progress_bar=args.enable_progress_bar)

    # Do the training or testing
    if args.test_ckpt is None:
        model = CMUMotionModel(args)
        model.hparams.S_churn = args.S_churn
        model.hparams.sigma_max = args.sigma_max
        model.hparams.num_steps = args.num_steps
        model.hparams.batch_size = args.batch_size
        trainer.fit(model, train_loader, val_loader, ckpt_path=args.resume_ckpt)
        # `best_model_path` is an empty string when no checkpoint was saved;
        # fall back to the in-memory weights instead of passing "" as a path.
        best_ckpt = checkpoint_callback.best_model_path or None
        trainer.test(model, test_loader, ckpt_path=best_ckpt)
    else:
        model = CMUMotionModel.load_from_checkpoint(args.test_ckpt)
        model.save_hyperparameters(args)
        # NOTE(review): checkpoint-only evaluation runs on the validation
        # loader, not the test loader — confirm this is intended.
        trainer.test(model, val_loader)


def build_parser():
    """Build the command-line parser for this training script.

    Returns:
        argparse.ArgumentParser: Parser carrying every option and default
        used by ``main``.

    SECURITY NOTE: many options use ``type=eval`` so that booleans, lists,
    ``None`` and dicts can be given on the command line. ``eval`` executes
    arbitrary Python, so these values must only come from trusted callers.
    Consider ``ast.literal_eval`` if expression inputs are not needed.
    """
    parser = argparse.ArgumentParser()

    # ------------------------ Input arguments

    # Run parameters
    parser.add_argument('--epochs', type=int, default=2)
    parser.add_argument('--warmup', type=int, default=20)
    parser.add_argument('--batch_size', type=int, default=5)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--weight_decay', type=float, default=1e-12)
    parser.add_argument('--log', type=eval, default=True)
    parser.add_argument('--enable_progress_bar', type=eval, default=True)
    parser.add_argument('--num_workers', type=int, default=0)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--test_ckpt', type=str, default=None)
    parser.add_argument('--resume_ckpt', type=str, default=None)

    # Train settings
    parser.add_argument('--train_augm', type=eval, default=False)

    # CMU Dataset
    parser.add_argument('--root', type=str, default="./datasets/motion")

    # Model class (CMUMotionModel only builds a network for "rapidash")
    parser.add_argument('--model', type=str, default="rapidash")

    # Network settings
    parser.add_argument('--hidden_dim', type=eval, default=256)  # hidden dimensions
    parser.add_argument('--basis_dim', type=int, default=256)
    parser.add_argument('--degree', type=int, default=2)
    parser.add_argument('--layers', type=eval, default=7)  # message passing layers

    parser.add_argument('--edge_types', type=eval, default=["fc"])
    parser.add_argument('--ratios', type=eval, default=[])  # shrinkage ratios
    parser.add_argument('--widening_factor', type=int, default=4)
    parser.add_argument('--layer_scale', type=eval, default=None)
    parser.add_argument('--multiple_readouts', type=eval, default=False)
    parser.add_argument('--attention', type=eval, default=False)
    parser.add_argument('--residual_connections', type=eval, default=False)  # Check what works better

    # CMUMotion settings
    parser.add_argument('--max_train_sample', type=int, default=100)
    parser.add_argument('--all_joint_normals', type=eval, default=False)
    parser.add_argument('--combinations', type=eval, default=None)  # for wandb sweeps...

    parser.add_argument('--base_space', type=str, default="R3")  # "R2" or "R3" or "R3S2"
    parser.add_argument('--fiber_space', type=str, default='S2')  # None, "S1", "SO2", "S2", "SO3"
    parser.add_argument('--fiber_dim', type=int, default=8)
    parser.add_argument('--equivariance', type=str, default="SE3")  # "T2", "T3", "Tn", "SE2", "SE3", "SEn"

    # Node feature types
    parser.add_argument('--use_coords_as_scalars', type=eval, default=False)
    parser.add_argument('--use_velocity_as_scalars', type=eval, default=False)
    parser.add_argument('--use_pose_as_scalars', type=eval, default=False)

    parser.add_argument('--use_coords_as_vectors', type=eval, default=False)
    parser.add_argument('--use_velocity_as_vectors', type=eval, default=True)
    parser.add_argument('--use_pose_as_vectors', type=eval, default=False)

    # Normals are always OFF for now !!
    parser.add_argument('--use_normals_as_vectors', type=eval, default=False)
    parser.add_argument('--use_normals_as_scalars', type=eval, default=False)

    # Rotated prediction
    parser.add_argument('--rotated_prediction', type=eval, default=True)

    # Diffusion model settings
    parser.add_argument('--S_churn', type=float, default=10)
    parser.add_argument('--sigma_max', type=float, default=1)
    parser.add_argument('--num_steps', type=int, default=50)
    parser.add_argument('--sigma_data', type=float, default=1)
    parser.add_argument('--normalize_x_factor', type=float, default=4.0)
    parser.add_argument('--normalize_charge_factor', type=float, default=8.0)

    # Parallel computing stuff
    parser.add_argument('-g', '--gpus', default=1, type=int)

    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()

    # Overwrite default settings with values from combinations
    # (a dict of {option_name: value}, e.g. supplied by a sweep).
    if args.combinations is not None:
        for key, value in args.combinations.items():
            setattr(args, key, value)

    main(args)
