import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
import time

from collections import OrderedDict
from lib.training.hyperdict import HDict
from lib.models.pcqm.distance_predictor_novn import TGT_Distance
from lib.data.pcqm import data
from lib.data.pcqm.structural_transform import AddStructuralData
from lib.data.pcqm import bin_ops as pbins
from ..tgt_training import TGTTraining
from ..commons import DiscreteDistLoss, DiscreteAngleLoss, MaskedL1Loss, add_coords_noise, coords2dist, coords2angle, mask_angles, calculate_angle, calculate_torsion_angle

from torch.utils.tensorboard import SummaryWriter
class LayerStatisticsTracker:
    """Forward-hook callable that captures layer outputs and logs statistics.

    An instance is meant to be registered with ``register_forward_hook`` on
    the modules whose qualified names appear in ``target_layers``. Every hook
    call stores the module's most recent output; ``record_statistics`` then
    writes summary values for the stored outputs to ``writer``.

    Only ``min`` and ``max`` are written for non-boolean tensors; for boolean
    tensors the fraction of ``True`` entries is written instead.
    """

    def __init__(self, writer, model):
        """Create a tracker.

        Args:
            writer: Object exposing ``add_scalar`` and ``add_scalars``.
            model: Module whose ``named_modules()`` is used to translate a
                hooked module back to its qualified name.
        """
        self.writer = writer
        self.layer_outputs = {}
        self.model = model
        self.target_layers = [f"module.encoder.TGT_layers.{i}" for i in range(24)]
        self.target_outputs = ['h', 'e', 'p', 't']
        self.statistics = ['min', 'max', 'mean', 'median', 'std', 'p1', 'p99']
        # module -> qualified name, built lazily on the first lookup so that
        # constructing the tracker never touches the model.
        self._module_names = None

    def __call__(self, module, input, output):
        """Forward-hook entry point: remember ``output`` of tracked layers."""
        name = self.get_module_name(module)
        if name in self.target_layers:
            self.layer_outputs[name] = output

    def _index_modules(self):
        """(Re)build the module -> name lookup from ``model.named_modules()``."""
        names = {}
        for name, mod in self.model.named_modules():
            # Keep the first name seen for a module, matching a linear scan.
            names.setdefault(mod, name)
        self._module_names = names

    def get_module_name(self, module):
        """Return the qualified name of ``module`` inside the model, or None.

        The lookup is cached because this runs on every hooked forward call;
        a miss triggers one rescan so modules added after the cache was built
        are still resolved.
        """
        if self._module_names is None or module not in self._module_names:
            self._index_modules()
        return self._module_names.get(module)

    def record_statistics(self, global_step):
        """Write statistics for every captured output at ``global_step``.

        Dict outputs are filtered by ``target_outputs`` keys, tuple outputs
        are labelled positionally with ``target_outputs``, and a bare tensor
        is labelled ``'output'``. Non-tensor entries are ignored.
        """
        for name, output in self.layer_outputs.items():
            if isinstance(output, dict):
                for key, value in output.items():
                    if key in self.target_outputs and isinstance(value, torch.Tensor):
                        self._record_tensor_statistics(name, key, value, global_step)
            elif isinstance(output, tuple):
                # zip stops at the shorter sequence, so surplus tuple items
                # beyond the known output names are skipped.
                for key, item in zip(self.target_outputs, output):
                    if isinstance(item, torch.Tensor):
                        self._record_tensor_statistics(name, key, item, global_step)
            elif isinstance(output, torch.Tensor):
                self._record_tensor_statistics(name, 'output', output, global_step)

    def _record_tensor_statistics(self, layer_name, output_name, tensor, global_step):
        """Write statistics of one tensor; only tensors on device index 0 are logged."""
        if tensor.device.index != 0:
            return

        tensor = tensor.detach()
        if tensor.numel() == 0:
            # min/max of an empty array is undefined and would raise.
            return
        if tensor.dtype == torch.bfloat16:
            # NumPy has no bfloat16; widen before converting.
            tensor = tensor.float()
        tensor_np = tensor.cpu().numpy()

        if tensor_np.dtype == bool:
            true_ratio = np.mean(tensor_np)
            # NOTE(review): this tag does not include layer_name, so every
            # tracked layer writes to the same scalar series — confirm intent.
            self.writer.add_scalar(f'{output_name}/true_ratio', true_ratio, global_step)
            return

        tensor_np = tensor_np.flatten()
        stats = {
            'min': np.min(tensor_np),
            'max': np.max(tensor_np),
        }
        self.writer.add_scalars(f'{output_name}', {
            f'{stat_name}/{layer_name}': stat_value
            for stat_name, stat_value in stats.items()
        }, global_step)

class SCHEME(TGTTraining):
    def get_default_config(self):
        """Return the parent's default config extended with this scheme's defaults.

        Several defaults depend on whether the current command is ``'predict'``:
        prediction runs over the plain train/valid splits, every other command
        uses the ``-3d`` splits.
        """
        predicting = self.executing_command == 'predict'

        def _save_path_prefix(c):
            # Nest the output under the model prefix when one is configured.
            if c.model_prefix is None:
                return 'models/pcqm/dist_pred'
            return f'models/pcqm/{c.model_prefix}/dist_pred'

        def _embed_3d_type(c):
            # No 3D embedding when there are no input coordinates.
            return 'gaussian' if c.coords_input != 'none' else 'none'

        def _save_pred_dir(c):
            return f'bins{c.prediction_samples}'

        overrides = dict(
            save_path_prefix    = HDict.L(_save_path_prefix),
            predict_on          = ['train', 'val'] if predicting else ['val'],
            coords_noise        = 0,
            coords_noise_smooth = 0,
            coords_input        = 'rdkit',
            coords_target       = 'dft',
            embed_3d_type       = HDict.L(_embed_3d_type),
            num_dist_bins       = 512,
            num_angle_bins      = 256,
            num_torsion_bins    = 256,
            range_dist_bins     = 8,
            train_split         = 'train' if predicting else 'train-3d',
            val_split           = 'valid' if predicting else 'valid-3d',
            test_split          = 'test-dev',
            save_pred_dir       = HDict.L(_save_pred_dir),
            save_bins_smooth    = 0,
            coords_target_noise = 0,
        )

        defaults = super().get_default_config()
        defaults.update(overrides)
        return defaults
    
    def __post_init__(self):
        """Build the loss functions and derive which coordinate columns are needed."""
        super().__post_init__()
        cfg = self.config

        self.xent_loss_fn_dist = DiscreteDistLoss(num_bins=cfg.num_dist_bins,
                                                  range_bins=cfg.range_dist_bins)

        # NOTE(review): the angle loss is built with num_dist_bins, not
        # num_angle_bins / num_torsion_bins — confirm this matches the model's
        # angle and torsion heads.
        self.xent_loss_fn_angle = DiscreteAngleLoss(num_bins=cfg.num_dist_bins,
                                                    max_angle=2*3.14159)

        self._uses_3d = cfg.coords_input != 'none'

        # Target coordinates are only required when a loss will be computed.
        needs_targets = self.executing_command in ('train', 'evaluate')

        def _column_needed(kind):
            is_input = cfg.coords_input == kind
            is_target = cfg.coords_target == kind
            return is_input or (is_target and needs_targets)

        self._use_rdkit_coords = _column_needed('rdkit')
        self._use_dft_coords = _column_needed('dft')


    
    def get_dataset_config(self, split):
        """Return ``(dataset_config, dataset_class)`` for ``split``.

        ``split`` must be one of ``'train'``, ``'val'`` or ``'test'``; it is
        mapped to the dataset split name configured for it. Coordinate columns
        are requested only for the coordinate kinds this scheme uses.
        """
        dataset_config, _ = super().get_dataset_config(split)

        split_names = {
            'train': self.config.train_split,
            'val': self.config.val_split,
            'test': self.config.test_split,
        }
        ds_split = split_names[split]

        wanted_coords = (
            ('rdkit', self._use_rdkit_coords),
            ('dft', self._use_dft_coords),
        )
        columns = [data.Coords(kind) for kind, enabled in wanted_coords if enabled]

        # Only the main rank reports progress, and only for the training split.
        verbosity = int(self.is_main_rank and 'train' in split)

        dataset_config.update(
            split               = ds_split,
            dataset_path        = self.config.dataset_path,
            return_idx          = True,
            transforms          = [AddStructuralData()],
            verbose             = verbosity,
            additional_columns  = columns,
        )
        return dataset_config, data.PCQM4Mv2Dataset
    
    def get_model_config(self):
        """Return ``(model_config, model_class)`` with the distance bin count added."""
        base_config, _ = super().get_model_config()
        dist_bins = self.config.num_dist_bins
        base_config.update(num_dist_bins=dist_bins)
        return base_config, TGT_Distance
    
    def preprocess_batch(self, batch, training):
        """Add masks, index tensors and (optionally) 3D-derived inputs to ``batch``.

        Always written: ``edge_mask``, ``angle_mask``, ``torsion_mask`` and the
        ``long`` versions of ``angle_indices`` / ``torsion_indices``. These are
        read by ``calculate_loss`` regardless of whether input coordinates are
        used, so they must not depend on ``self._uses_3d``.

        Written only when input coordinates are used: ``dist_input``,
        ``angle_input`` and ``torsion_input``, computed from the (optionally
        noised) input coordinates.

        Args:
            batch: Mapping produced by the parent's ``preprocess_batch``; must
                contain ``node_mask``, ``angle_indices`` and ``torsion_indices``.
            training: Forwarded to the parent and to ``get_coords_input``.

        Returns:
            The same ``batch`` mapping with the extra entries added.
        """
        batch = super().preprocess_batch(batch, training)

        node_mask = batch['node_mask']
        # Outer product of the node mask marks valid node pairs.
        edge_mask = node_mask.unsqueeze(-1) * node_mask.unsqueeze(-2)
        batch['edge_mask'] = edge_mask.cuda()

        angle_indices = batch['angle_indices'].long()
        torsion_indices = batch['torsion_indices'].long()

        # An index tuple that is entirely zero is treated as padding.
        batch['angle_mask'] = (angle_indices != 0).any(dim=-1).cuda()
        batch['torsion_mask'] = (torsion_indices != 0).any(dim=-1).cuda()
        batch['angle_indices'] = angle_indices.cuda()
        batch['torsion_indices'] = torsion_indices.cuda()

        if self._uses_3d:
            coords = self.get_coords_input(batch, training)
            if self.config.coords_noise > 0:
                coords = add_coords_noise(coords=coords,
                                          edge_mask=edge_mask,
                                          noise_level=self.config.coords_noise,
                                          noise_smoothing=self.config.coords_noise_smooth)

            # The geometric helpers receive the pre-transfer index tensors,
            # exactly as before, so device placement relative to ``coords``
            # is unchanged.
            batch['dist_input'] = coords2dist(coords).cuda()
            batch['angle_input'] = calculate_angle(coords, angle_indices).cuda()
            batch['torsion_input'] = calculate_torsion_angle(coords, torsion_indices).cuda()

        return batch


    def get_coords_input(self, batch, training):
        """Return the batch coordinates selected by ``config.coords_input``.

        Raises:
            ValueError: If ``coords_input`` is neither ``'rdkit'`` nor ``'dft'``.
        """
        source = self.config.coords_input
        if source == 'rdkit':
            return batch['rdkit_coords']
        if source == 'dft':
            return batch['dft_coords']
        raise ValueError('Invalid coords_input')
    
    def get_coords_target(self, batch):
        """Return the batch coordinates selected by ``config.coords_target``.

        Raises:
            ValueError: If ``coords_target`` is neither ``'rdkit'`` nor ``'dft'``.
        """
        source = self.config.coords_target
        if source == 'rdkit':
            return batch['rdkit_coords']
        if source == 'dft':
            return batch['dft_coords']
        raise ValueError('Invalid coords_target')
    
    def get_dist_target(self, batch, training=False):
        """Return pairwise-distance targets computed from the target coordinates.

        When ``training`` is true and ``config.coords_target_noise`` is
        positive, zero-mean Gaussian noise with that standard deviation is
        added to the coordinates first.
        """
        target_coords = self.get_coords_target(batch)
        sigma = self.config.coords_target_noise
        if training and sigma > 0:
            jitter = target_coords.new(target_coords.size()).normal_(0, sigma)
            target_coords = target_coords + jitter
        return coords2dist(target_coords)

    def get_angle_target(self, batch, training=False):
        """Return angle targets for ``batch['angle_indices']`` from the target coordinates.

        When ``training`` is true and ``config.coords_target_noise`` is
        positive, zero-mean Gaussian noise with that standard deviation is
        added to the coordinates first.
        """
        target_coords = self.get_coords_target(batch)
        sigma = self.config.coords_target_noise
        if training and sigma > 0:
            jitter = target_coords.new(target_coords.size()).normal_(0, sigma)
            target_coords = target_coords + jitter
        indices = batch['angle_indices'].long()
        return calculate_angle(target_coords, indices)

    def get_torsion_target(self, batch, training=False):
        """Return torsion targets for ``batch['torsion_indices']`` from the target coordinates.

        When ``training`` is true and ``config.coords_target_noise`` is
        positive, zero-mean Gaussian noise with that standard deviation is
        added to the coordinates first.
        """
        target_coords = self.get_coords_target(batch)
        sigma = self.config.coords_target_noise
        if training and sigma > 0:
            jitter = target_coords.new(target_coords.size()).normal_(0, sigma)
            target_coords = target_coords + jitter
        indices = batch['torsion_indices'].long()
        return calculate_torsion_angle(target_coords, indices)

    
    def calculate_loss(self, outputs, inputs):
        """Return ``(dist_loss, angle_loss, torsion_angle_loss)``.

        ``outputs[0]``, ``outputs[1]`` and ``outputs[2]`` are scored against
        the distance, angle and torsion targets respectively. The angle loss
        is scaled by 0.1 and the torsion loss by 0.2.
        """
        # NOTE(review): targets are always requested with training=True, so a
        # positive coords_target_noise also perturbs targets when this is
        # called from validation, and each target draws its own noise.
        target_dist = self.get_dist_target(inputs, training=True).cuda()
        target_angle = self.get_angle_target(inputs, training=True).cuda()
        target_torsion = self.get_torsion_target(inputs, training=True).cuda()

        pair_mask = inputs['edge_mask']
        angle_mask = inputs['angle_mask']
        # Only the torsion mask is cast to float before being passed on.
        torsion_mask = inputs['torsion_mask'].float()

        dist_loss = self.xent_loss_fn_dist(outputs[0], target_dist, pair_mask)
        angle_loss = self.xent_loss_fn_angle(outputs[1], target_angle, angle_mask)*0.1
        torsion_angle_loss = self.xent_loss_fn_angle(outputs[2], target_torsion, torsion_mask)*0.2
        return dist_loss, angle_loss, torsion_angle_loss
    
    def prediction_step4eval(self, batch):
        """Compute distance and angle losses for one batch during evaluation.

        Returns a dict with ``dist_loss`` and ``angle_loss``.

        NOTE(review): this method relies on ``self.predict_values``,
        ``self.mask_angles``, ``self.mse_loss_fn_dist`` and
        ``self.cossin_loss_fn_angle``. None of them is defined in the visible
        part of this file (``predict_values`` only exists as commented-out
        code and ``mask_angles`` is imported as a module-level function), so
        unless the parent class provides them this will raise
        ``AttributeError`` — confirm before relying on the evaluate path.
        """
        # Expects exactly two return values: predicted distances and angles.
        dist_values, angle_values = self.predict_values(batch)
        # Targets are taken without noise (training defaults to False).
        dist_targ = self.get_dist_target(batch)
        angle_targ = self.get_angle_target(batch)
        mask = batch['edge_mask']
        # NOTE(review): 'distance_matrix' is not written by preprocess_batch
        # in this file — presumably supplied by the dataset; verify.
        mask_angel = self.mask_angles(batch['distance_matrix'])
        dist_loss = self.mse_loss_fn_dist(dist_values.squeeze(-1), dist_targ, mask)
        angle_loss = self.cossin_loss_fn_angle(angle_values, angle_targ, mask_angel)
        return dict(
            dist_loss = dist_loss,
            angle_loss = angle_loss,
        )
    

    # Logging
    def get_verbose_logs(self):
        """Return an ordered mapping of logged loss names to their format spec."""
        number_format = '0.4f'
        logged_names = ('loss', 'dist_loss', 'angle_loss', 'torsion_angle_loss')
        return OrderedDict((name, number_format) for name in logged_names)

    def initialize_losses(self, logs, training):
        """Reset the running loss totals and the running sample count to zero."""
        accumulators = (
            '_total_loss',
            '_total_dist_loss',
            '_total_angle_loss',
            '_total_torsion_angle_loss',
            '_total_samples',
        )
        for attr in accumulators:
            setattr(self, attr, 0.)

    def update_losses(self, i, loss, inputs, logs, training):
        """Accumulate sample-weighted losses and publish running averages.

        Args:
            i: Batch index (unused here).
            loss: Tuple ``(dist_loss, angle_loss, torsion_angle_loss)`` of
                scalar tensors.
            inputs: Batch mapping; the length of ``inputs['num_nodes']`` along
                its first axis is used as the number of samples in the step.
            logs: Log mapping forwarded to ``update_logs``.
            training: Forwarded to ``update_logs``.

        In distributed runs the weighted sums and the sample count are summed
        across ranks with a single ``all_reduce``. With mixed precision, a
        step whose total loss is NaN is skipped, unless 10 such steps have
        already been skipped in a row.
        """
        loss_dist, loss_angle, loss_torsion_angle = loss
        loss = loss_dist + loss_angle + loss_torsion_angle
        terms = (loss, loss_dist, loss_angle, loss_torsion_angle)

        step_samples = float(inputs['num_nodes'].shape[0])
        if not self.is_distributed:
            step_sums = [term.item() * step_samples for term in terms]
        else:
            weight = torch.tensor(step_samples, device=loss.device, dtype=loss.dtype)
            # Pack the four weighted sums and the sample count into one
            # tensor so that a single collective replaces five.
            weighted = [term.detach().to(loss.dtype) * weight for term in terms]
            packed = torch.stack(weighted + [weight])
            torch.distributed.all_reduce(packed)
            *step_sums, step_samples = packed.tolist()

        step_loss = step_sums[0]

        accumulate = True
        if self.config.mixed_precision:
            is_nan = step_loss != step_loss  # NaN is the only value unequal to itself
            if is_nan and self._nan_loss_count < 10:
                self._nan_loss_count += 1
                accumulate = False
            else:
                self._nan_loss_count = 0

        if accumulate:
            self._total_loss += step_sums[0]
            self._total_dist_loss += step_sums[1]
            self._total_angle_loss += step_sums[2]
            self._total_torsion_angle_loss += step_sums[3]
            self._total_samples += step_samples

        # The small epsilon avoids division by zero before any step is counted.
        denom = self._total_samples + 1e-12
        self.update_logs(logs=logs, training=training,
                         loss=self._total_loss / denom,
                         dist_loss=self._total_dist_loss / denom,
                         angle_loss=self._total_angle_loss / denom,
                         torsion_angle_loss=self._total_torsion_angle_loss / denom
                         )

    # def training_step(self, batch, logs):
    #     for param in self.trainable_params:
    #         param.grad = None
        
    #     mixed_precision = self.config.mixed_precision
    #     clip_grad_value = self.config.clip_grad_value
    #     clip_grad_norm = self.config.clip_grad_norm
        
    #     with self.fwd_pass_context():
    #         outputs = self.model(batch)
    #         loss = self.calculate_loss(outputs=outputs, inputs=batch)
    #         loss_sum = loss[0]+loss[1]
    #     if not mixed_precision:
    #         loss_sum.backward()
    #     else:
    #         self.grad_scaler.scale(loss_sum).backward()
        
    #     should_clip_grad_value = clip_grad_value is not None
    #     should_clip_grad_norm = clip_grad_norm is not None
    #     if mixed_precision and (should_clip_grad_value or should_clip_grad_norm):
    #         self.grad_scaler.unscale_(self.optimizer)
        
    #     if should_clip_grad_value:
    #         nn.utils.clip_grad_value_(self.trainable_params, self.config.clip_grad_value)
    #     if should_clip_grad_norm:
    #         nn.utils.clip_grad_norm_(self.trainable_params, self.config.clip_grad_norm)
        
    #     if not mixed_precision:
    #         self.optimizer.step()
    #     else:
    #         self.grad_scaler.step(self.optimizer)
    #         self.grad_scaler.update()
    #     return outputs, loss

    def training_step(self, batch, logs):
        """Run one optimisation step on ``batch``.

        Performs a forward pass, sums the three loss terms, back-propagates
        (through the gradient scaler under mixed precision), applies the
        configured gradient clipping and steps the optimizer.

        The model is called with ``e_repr``, ``p_repr`` and ``t_repr`` set to
        ``None`` and returns six values; only the first three (the
        predictions) are used. The representation outputs are not fed back —
        exactly one optimisation step is taken per batch.

        Returns:
            ``((e, p, t), loss)`` where ``loss`` is the tuple returned by
            ``calculate_loss``.
        """
        mixed_precision = self.config.mixed_precision
        clip_grad_value = self.config.clip_grad_value
        clip_grad_norm = self.config.clip_grad_norm

        for param in self.trainable_params:
            param.grad = None

        with self.fwd_pass_context():
            e, p, t, _e_repr, _p_repr, _t_repr = self.model(
                batch, e_repr=None, p_repr=None, t_repr=None)
            loss = self.calculate_loss(outputs=(e, p, t), inputs=batch)
            loss_sum = loss[0] + loss[1] + loss[2]

        if not mixed_precision:
            loss_sum.backward()
        else:
            self.grad_scaler.scale(loss_sum).backward()

        should_clip_grad_value = clip_grad_value is not None
        should_clip_grad_norm = clip_grad_norm is not None
        # Gradients must be unscaled before clipping so the thresholds apply
        # to their true magnitudes.
        if mixed_precision and (should_clip_grad_value or should_clip_grad_norm):
            self.grad_scaler.unscale_(self.optimizer)

        if should_clip_grad_value:
            nn.utils.clip_grad_value_(self.trainable_params, clip_grad_value)
        if should_clip_grad_norm:
            nn.utils.clip_grad_norm_(self.trainable_params, clip_grad_norm)

        if not mixed_precision:
            self.optimizer.step()
        else:
            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()

        return (e, p, t), loss

    def validation_step(self, batch, logs):
        """Run a forward pass on ``batch`` and return ``(model_outputs, losses)``.

        The raw model output is returned as-is; ``calculate_loss`` reads its
        first three entries.
        """
        with self.fwd_pass_context():
            model_outputs = self.model(batch)
            losses = self.calculate_loss(outputs=model_outputs, inputs=batch)
        return model_outputs, losses
    
    def predict_bins(self, batch):
        """Draw ``self.nb_draw_samples`` predictions and return their arg-max bins.

        The model is run up to twice the requested number of times; any run
        whose distance, angle or torsion logits contain a non-finite value is
        discarded.

        Returns:
            ``(dist_bins, angle_bins, torsion_bins)``, each stacked over the
            accepted samples along dimension 1.

        Raises:
            ValueError: If fewer than the requested number of runs were finite.
        """
        wanted = self.nb_draw_samples
        dist_samples, angle_samples, torsion_samples = [], [], []

        for _ in range(wanted * 2):
            model_out = self.model(batch)
            dist_logits = model_out[0]
            angle_logits = model_out[1]
            torsion_logits = model_out[2]

            # Reject the whole draw if any head produced NaN or inf.
            heads = (dist_logits, angle_logits, torsion_logits)
            if not all(bool(torch.isfinite(head).all()) for head in heads):
                continue

            # Distances: symmetrise the probabilities over the two node axes
            # before picking the most likely bin.
            dist_prob = torch.softmax(dist_logits, dim=-1)
            dist_prob = dist_prob + dist_prob.transpose(-2, -3)
            dist_samples.append(dist_prob.argmax(dim=-1))

            # Angles and torsions: most likely bin per entry.
            angle_samples.append(torch.softmax(angle_logits, dim=-1).argmax(dim=-1))
            torsion_samples.append(torch.softmax(torsion_logits, dim=-1).argmax(dim=-1))

            if len(dist_samples) >= wanted:
                break

        accepted = len(dist_samples)
        if accepted < wanted:
            nan_samples = wanted - accepted
            raise ValueError(f'{nan_samples}/{wanted} predictions were NaN')

        return (
            torch.stack(dist_samples, dim=1),
            torch.stack(angle_samples, dim=1),
            torch.stack(torsion_samples, dim=1),
        )

        
    
    def prediction_step4savebins(self, batch):
        """Predict bins for ``batch`` and pack them into per-sample flat arrays.

        Returns:
            A dict with ``idx`` (array of sample indices) and five lists with
            one flat array per sample:

            * ``dist_bins`` — ``pbins.pack_bins_multi`` applied to the bins of
              the first ``num_nodes`` rows and columns, flattened.
            * ``angle_bins`` / ``torsion_bins`` — bins of the non-padding
              entries, laid out entry-major (all draws of the first entry,
              then all draws of the second, ...).
            * ``angle_indices`` / ``torsion_indices`` — the matching index
              tuples, flattened.

        An index tuple that is entirely zero is treated as padding and is
        dropped together with its bins.
        """
        def narrow_uint(bins, num_bins):
            # Store bins in the smallest unsigned type that holds num_bins
            # distinct values; left unchanged above 65536.
            # NOTE(review): this trusts the configured bin count; if a head
            # actually emits more bins than configured the cast wraps around
            # — confirm num_angle_bins / num_torsion_bins match the model.
            if num_bins <= 256:
                return bins.astype('uint8')
            if num_bins <= 65536:
                return bins.astype('uint16')
            return bins

        def pack_valid(bins, indices):
            # bins: one (draws, entries) array per sample;
            # indices: one (entries, k) array per sample.
            packed_bins, packed_indices = [], []
            for sample_bins, sample_indices in zip(bins, indices):
                valid = (sample_indices != 0).any(axis=-1)
                # Transpose so the flat layout is entry-major.
                packed_bins.append(sample_bins[:, valid].T.reshape(-1))
                packed_indices.append(sample_indices[valid].reshape(-1))
            return packed_bins, packed_indices

        dist_bins, angle_bins, torsion_bins = self.predict_bins(batch)

        idx = batch['idx'].cpu().numpy()
        num_nodes = batch['num_nodes'].cpu().numpy()
        angle_indices = batch['angle_indices'].cpu().numpy()
        torsion_indices = batch['torsion_indices'].cpu().numpy()

        dist_bins = narrow_uint(dist_bins.cpu().numpy(), self.config.num_dist_bins)
        angle_bins = narrow_uint(angle_bins.cpu().numpy(), self.config.num_angle_bins)
        torsion_bins = narrow_uint(torsion_bins.cpu().numpy(), self.config.num_torsion_bins)

        packed_dist_bins = [
            pbins.pack_bins_multi(dist_bins[i, :, :n, :n]).reshape(-1)
            for i, n in enumerate(num_nodes)
        ]
        packed_angle_bins, packed_angle_indices = pack_valid(angle_bins, angle_indices)
        packed_torsion_bins, packed_torsion_indices = pack_valid(torsion_bins, torsion_indices)

        return dict(
            idx=idx,
            dist_bins=packed_dist_bins,
            angle_bins=packed_angle_bins,
            angle_indices=packed_angle_indices,
            torsion_bins=packed_torsion_bins,
            torsion_indices=packed_torsion_indices
        )


    # def predict_values(self, batch, angle=False):
    #     dist_values, angle_values = [], []
    #     nb_samples = self.nb_draw_samples
    #     valid_samples = 0
    #     for _ in range(nb_samples * 2):
    #         if angle:
    #             dists, angles = self.model(batch)
    #             if torch.isnan(dists).any() or torch.isinf(dists).any() or torch.isnan(angles).any() or torch.isinf(angles).any():
    #                 continue
    #             new_dists = dists + dists.transpose(-2, -3)
    #             new_angles = angles + angles.transpose(-2, -3)
    #             dist_values.append(new_dists)
    #             angle_values.append(new_angles)
    #             valid_samples += 1
    #         else:
    #             dists = self.model(batch)[0]
    #             if torch.isnan(dists).any() or torch.isinf(dists).any():
    #                 continue
    #             new_dists = dists + dists.transpose(-2, -3)
    #             dist_values.append(new_dists)
    #             valid_samples += 1

    #         if valid_samples >= nb_samples:
    #             break

    #     if valid_samples < nb_samples:
    #         nan_samples = nb_samples - valid_samples
    #         raise ValueError(f'{nan_samples}/{nb_samples} predictions were NaN')

    #     dist_values = torch.stack(dist_values, dim=1)
    #     if angle_values:  # Check if angle_values is not empty
    #         angle_values = torch.stack(angle_values, dim=1)
    #     else:
    #         angle_values = None  # Or handle it differently depending on the context

    #     return dist_values, angle_values
        
    
    # def prediction_step4savevalues(self, batch):
    #     # print(batch['num_nodes'].max())
    #     dist_values, angle_values, torsion_angle_values = self.predict_values(batch)
        
    #     idx = batch['idx'].cpu().numpy()
    #     num_nodes = batch['num_nodes'].cpu().numpy()
    #     dist_values = dist_values.cpu().numpy()
    #     angle_values = angle_values.cpu().numpy()
        
    #     packed_dist_values = []
    #     for i,n in enumerate(num_nodes):
    #         packed_values_i = pbins.pack_bins_multi(dist_values[i,:,:n,:n].squeeze(-1))
    #         packed_values_i = packed_values_i.reshape(-1)
    #         packed_dist_values.append(packed_values_i)
        
    #     return dict(
    #         idx = idx,
    #         dists = packed_dist_values,
    #     )
        
    def prediction_step(self, batch):
        """Dispatch to bin saving under the ``'predict'`` command, otherwise to evaluation."""
        saving_bins = self.executing_command == 'predict'
        step_fn = self.prediction_step4savebins if saving_bins else self.prediction_step4eval
        return step_fn(batch)
    
    def prediction_loop(self, dataloader):
        """Run the parent's prediction loop with ``config.predict_in_train`` applied."""
        in_train_mode = self.config.predict_in_train
        return super().prediction_loop(dataloader, predict_in_train=in_train_mode)
    
    def evaluate_predictions(self, predictions,
                             dataset_name='validation',
                             evaluation_stage=False):
        """Return the mean ``dist_loss`` and ``angle_loss`` found in ``predictions``.

        ``dataset_name`` and ``evaluation_stage`` are accepted but not used.
        """
        metric_names = ('dist_loss', 'angle_loss')
        return {name: predictions[name].mean() for name in metric_names}
    
    def evaluate_on(self, dataset_name, dataset, predictions):
        """Announce the evaluation on the main rank and return the evaluated metrics."""
        if self.is_main_rank:
            print(f'Evaluating on {dataset_name}')
        return self.evaluate_predictions(predictions,
                                         dataset_name=dataset_name,
                                         evaluation_stage=True)
    
    def predict_and_save(self):
        """Predict bins for every dataset in ``config.predict_on`` and save them.

        Outside the ``'predict'`` command this defers to the parent
        implementation. Otherwise each rank writes one Parquet file per
        dataset under ``<predictions_path>/<save_pred_dir>/data`` and the main
        rank writes ``meta.json`` next to that directory.
        """
        if self.executing_command != 'predict':
            return super().predict_and_save()

        import json
        import os
        import pyarrow as pa
        import pyarrow.parquet as pq

        out_root = os.path.join(self.config.predictions_path,
                                self.config.save_pred_dir)
        os.makedirs(out_root, exist_ok=True)
        data_dir = os.path.join(out_root, 'data')
        os.makedirs(data_dir, exist_ok=True)
        meta_path = os.path.join(out_root, 'meta.json')

        list_columns = ['dist_bins', 'angle_bins', 'angle_indices',
                        'torsion_bins', 'torsion_indices']

        for dataset_name in self.config.predict_on:
            if self.is_main_rank:
                print(f'Predicting on {dataset_name} dataset...')
            dataloader = self.test_dataloader_for_dataset(dataset_name)
            print("Total data items in DataLoader:", len(dataloader.dataset))
            outputs = self.prediction_loop(dataloader)

            table_dict = {'idx': np.concatenate([o['idx'] for o in outputs])}
            # Concatenate the per-batch lists of each column, keeping batch order.
            for col in list_columns:
                table_dict[col] = [item for o in outputs for item in o[col]]

            if self.is_main_rank:
                # The metadata only describes the distance bins.
                meta = dict(
                    num_bins = self.config.num_dist_bins,
                    range_bins = self.config.range_dist_bins,
                    num_samples = self.nb_draw_samples,
                )
                with open(meta_path, 'w') as f:
                    json.dump(meta, f)

            table = pa.Table.from_pydict(table_dict)
            save_path = os.path.join(data_dir,
                                     f'{dataset_name}_{self.ddp_rank:03d}.parquet')
            pq.write_table(table, save_path)
            print(f'Rank {self.ddp_rank} saved {dataset_name} to {save_path}')
            
    def get_outer_4_layers(self):
        """Return the distinct parameter-name prefixes of at most four dotted parts.

        The result is a list in no particular order.
        """
        prefixes = {
            '.'.join(param_name.split('.')[:4])
            for param_name, _ in self.model.named_parameters()
        }
        return list(prefixes)
    
    def train_epoch(self, epoch, logs,
                    minimal=False,
                    train_in_eval=False):
        """Train for one epoch while logging layer-output statistics.

        Args:
            epoch: Epoch index; passed to the distributed sampler and used for
                the progress description.
            logs: Log mapping updated in place during the epoch.
            minimal: When true, loss/metric bookkeeping and progress
                descriptions are skipped.
            train_in_eval: When true, the model is put in eval mode instead of
                train mode for the epoch.

        Forward hooks are attached to the tracked layers for the duration of
        the epoch and their statistics are written every 200 batches. Hooks,
        the progress bar, gradients and the summary writer are cleaned up even
        if the epoch is interrupted.
        """
        if train_in_eval:
            self.model.eval()
        else:
            self.model.train()

        bookkeeping = not minimal
        if bookkeeping:
            self.initialize_losses(logs, True)
            self.initialize_metrics(logs, True)

        if self.is_distributed:
            self.train_sampler.set_epoch(epoch)

        gen = self.train_dataloader
        if self.is_main_rank:
            gen = self.progbar(gen, total=self.train_steps_per_epoch,
                               desc='Training: ')

        registered_hooks = []
        writer = SummaryWriter(self.config.log_path)
        layer_tracker = LayerStatisticsTracker(writer, self.model)

        try:
            # Registration happens inside the try block so that hooks attached
            # before a failure are still removed.
            for name, module in self.model.named_modules():
                if name not in layer_tracker.target_layers:
                    continue
                registered_hooks.append(module.register_forward_hook(layer_tracker))

            for step, batch in enumerate(gen):
                self.on_batch_begin(step, logs, True)
                batch = self.preprocess_batch(batch=batch, training=True)
                outputs, loss = self.training_step(batch, logs)

                self.state.global_step = self.state.global_step + 1
                logs.update(global_step=self.state.global_step)

                if bookkeeping:
                    self.update_losses(step, loss, batch, logs, True)
                    self.update_metrics(outputs, batch, logs, True)

                if step % 200 == 0:
                    layer_tracker.record_statistics(self.state.global_step)

                self.on_batch_end(step, logs, True)

                if self.is_main_rank and bookkeeping:
                    parts = self.log_description(epoch * len(gen) + step, logs, True)
                    gen.set_description('Training: ' + '; '.join(parts))
        finally:
            for handle in registered_hooks:
                handle.remove()
            if self.is_main_rank:
                gen.close()
            for param in self.trainable_params:
                param.grad = None
            writer.close()
