import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
from collections import deque 

from collections import OrderedDict
from lib.training.hyperdict import HDict
from lib.models.pcqm.multitask import TGT_Multi
from lib.data.pcqm import data
from lib.data.pcqm.structural_transform import AddStructuralData
from ..tgt_training import TGTTraining
from ..commons import DiscreteDistLoss, DiscreteAngleLoss, MaskedL1Loss, add_coords_noise, coords2dist, coords2angle, mask_angles, calculate_angle, calculate_torsion_angle

class NaNDetectedError(Exception):
    """Signals that a tracked module produced or received a NaN value.

    Raised by the hook built in ``forward_hook`` once its layer tracker
    reports a detection.
    """

def convert_float16_to_float32(data):
    """Recursively upcast float16 tensors inside ``data`` to float32.

    Dicts and lists are rebuilt (as a plain ``dict`` / ``list``) with their
    members converted recursively. Everything else is returned untouched,
    including tensors of other dtypes and tuples.
    """
    if isinstance(data, dict):
        return {key: convert_float16_to_float32(value) for key, value in data.items()}
    if isinstance(data, list):
        return [convert_float16_to_float32(member) for member in data]

    is_half_tensor = isinstance(data, torch.Tensor) and data.dtype == torch.float16
    return data.float() if is_half_tensor else data

def contains_nan(x):
    """Return ``True`` if any tensor reachable from ``x`` holds a NaN.

    ``x`` may be a tensor, or an arbitrarily nested list / tuple / dict of
    tensors. Objects of any other type are treated as NaN-free.

    Returns:
        bool: always a plain Python ``bool``. The tensor case is collapsed
        with ``bool()`` so callers get the same type on every path instead
        of a 0-dim tensor on one of them.
    """
    if isinstance(x, torch.Tensor):
        return bool(torch.isnan(x).any())
    if isinstance(x, dict):
        return any(contains_nan(value) for value in x.values())
    if isinstance(x, (list, tuple)):
        return any(contains_nan(item) for item in x)
    return False

def print_io_info(module_name, inputs, outputs):
    """Print a human-readable summary of one module call's inputs and outputs.

    For every tensor the shape, NaN status and (when non-empty) min/max are
    printed; the dtype is printed for input tensors and for a bare output
    tensor, but not for tensors inside an output tuple. Non-tensor entries
    are reported by type only. ``inputs`` is itemised only when it is a
    tuple; ``outputs`` is itemised when it is a tensor or a tuple.
    """
    heavy_rule = "=" * 50
    light_rule = "-" * 50

    def report_tensor(header, tensor, with_dtype):
        # Shared tensor report; ``with_dtype`` selects the longer shape line.
        print(header)
        shape_line = f"    Shape: {tensor.shape}"
        if with_dtype:
            shape_line += f", dtype = {tensor.dtype}"
        print(shape_line)
        print(f"    Contains NaN: {contains_nan(tensor)}")
        if tensor.numel() > 0:
            print(f"    Min: {tensor.min().item():.4f}, Max: {tensor.max().item():.4f}")

    def report_sequence(label, items, with_dtype):
        # One entry per tuple member, tensors in detail and the rest by type.
        for idx, item in enumerate(items):
            if isinstance(item, torch.Tensor):
                report_tensor(f"  {label} {idx}:", item, with_dtype)
            else:
                print(f"  {label} {idx}: type = {type(item)}")

    print("\n" + heavy_rule)
    print(f"Layer: {module_name}")
    print(light_rule)

    print("Inputs:")
    if isinstance(inputs, tuple):
        report_sequence("Input", inputs, True)
    else:
        print(f"  Input: type = {type(inputs)}")

    print(light_rule)
    print("Outputs:")
    if isinstance(outputs, torch.Tensor):
        report_tensor("  Output:", outputs, True)
    elif isinstance(outputs, tuple):
        report_sequence("Output", outputs, False)
    else:
        print(f"  Output: type = {type(outputs)}")

    print(heavy_rule + "\n")

class LayerTracker:
    """Rolling record of recent module calls, used to locate the first NaN.

    Each call is stored as ``(module_name, inputs, outputs)`` in a deque
    bounded by ``max_layers``, so only the most recent calls are kept.

    Attributes are plain instance fields:
      * ``layers``           - the bounded deque of recorded calls.
      * ``nan_detected``     - becomes ``True`` on the first call whose inputs
                               or outputs contain a NaN and then stays ``True``.
      * ``nan_layer_index``  - position in ``layers`` of that first offending
                               call at the moment it was recorded.
    """

    def __init__(self, max_layers=50):
        self.max_layers = max_layers
        self.layers = deque(maxlen=max_layers)
        self.nan_detected = False
        self.nan_layer_index = None

    def add_layer(self, module_name, inputs, outputs):
        """Record one module call and flag the first one that contains a NaN."""
        self.layers.append((module_name, inputs, outputs))
        # Only the FIRST detection is recorded. The previous condition
        # `not detected and nan(inputs) or nan(outputs)` bound as
        # `(not detected and nan(inputs)) or nan(outputs)`, so a later NaN in
        # `outputs` silently overwrote the recorded index.
        if self.nan_detected:
            return
        if contains_nan(inputs) or contains_nan(outputs):
            self.nan_detected = True
            # NOTE(review): this index is relative to the deque at detection
            # time; it goes stale if further calls are appended after the
            # deque is full. The hook in this file raises immediately on
            # detection, so that does not happen there.
            self.nan_layer_index = len(self.layers) - 1

    def print_context(self):
        """Print every recorded call around the first NaN; no-op before detection."""
        if not self.nan_detected:
            return

        first = max(0, self.nan_layer_index - self.max_layers)
        stop = min(len(self.layers), self.nan_layer_index + self.max_layers + 1)

        print(f"Printing information for layers {first} to {stop-1}:")
        for i in range(first, stop):
            module_name, inputs, outputs = self.layers[i]
            print(f"{'*' * 20} Layer {i} {'*' * 20}")
            print_io_info(module_name, inputs, outputs)

def forward_hook(module_name, layer_tracker):
    """Build a forward hook that logs each call into ``layer_tracker``.

    The returned callable has the ``(module, inputs, outputs)`` signature.
    It records the call, and if the tracker reports a NaN it prints the
    tracker's context and raises ``NaNDetectedError``.
    """
    def hook(module, inputs, outputs):
        layer_tracker.add_layer(module_name, inputs, outputs)
        if not layer_tracker.nan_detected:
            return

        message = f"NaN detected in module: {module_name}\n"
        print(message)
        print("Printing context information:")
        layer_tracker.print_context()
        raise NaNDetectedError(message)

    return hook


class SCHEME(TGTTraining):
    def get_default_config(self):
        """Return the parent's default config extended with this scheme's defaults."""
        config_dict = super().get_default_config()

        def _save_path_prefix(c):
            # The save path depends on whether a model prefix is configured.
            if c.model_prefix is None:
                return 'models/pcqm/pretrain'
            return f'models/pcqm/{c.model_prefix}/pretrain'

        overrides = dict(
            save_path_prefix=HDict.L(_save_path_prefix),
            coords_noise=0.5,
            coords_noise_smooth=1.0,
            embed_3d_type='gaussian',
            num_dist_bins=256,
            range_dist_bins=8,
            train_split='train-3d',
            val_split='valid-3d',
            test_split='test-dev',
            dist_loss_weight=0.1,
            angle_loss_weight=0.1,
        )
        config_dict.update(**overrides)
        return config_dict
    
    def __post_init__(self):
        """Run the parent's post-init, then build the distance and angle losses."""
        super().__post_init__()

        cfg = self.config
        bin_count = cfg.num_dist_bins
        self.xent_loss_fn_dist = DiscreteDistLoss(num_bins=bin_count,
                                                  range_bins=cfg.range_dist_bins)

        # The angle loss reuses the distance bin count.
        # NOTE(review): 2*3.14159 is a truncated 2*pi; confirm whether
        # math.pi was intended before changing it.
        self.xent_loss_fn_angle = DiscreteAngleLoss(num_bins=bin_count,
                                                    max_angle=2*3.14159)
    
    def get_dataset_config(self, split):
        """Return ``(dataset_config, dataset_class)`` for ``split``.

        ``split`` must be one of ``'train'``, ``'val'`` or ``'test'``; it is
        mapped to the configured split name. Any other value raises
        ``KeyError``.
        """
        dataset_config, _ = super().get_dataset_config(split)

        split_lookup = {
            'train': self.config.train_split,
            'val': self.config.val_split,
            'test': self.config.test_split,
        }
        ds_split = split_lookup[split]

        # presumably requests the DFT coordinates column - confirm against data.Coords
        extra_columns = [data.Coords('dft')]
        pipeline = [AddStructuralData()]

        # Verbose output only on the main rank, and only for a training split.
        is_train_split = 'train' in split
        dataset_config.update(
            split=ds_split,
            dataset_path=self.config.dataset_path,
            transforms=pipeline,
            verbose=int(self.is_main_rank and is_train_split),
            additional_columns=extra_columns,
        )
        return dataset_config, data.PCQM4Mv2Dataset
    
    def get_model_config(self):
        """Return ``(model_config, TGT_Multi)`` with ``num_dist_bins`` taken from the config."""
        base_config, _ = super().get_model_config()
        base_config.update(num_dist_bins=self.config.num_dist_bins)
        return base_config, TGT_Multi
    
    def preprocess_batch(self, batch, training):
        """Add masks and geometric inputs derived from the coordinates to ``batch``.

        After the parent's preprocessing this writes ``edge_mask``,
        ``angle_mask``, ``torsion_mask``, ``dist_input``, ``angle_input`` and
        ``torsion_input``, and replaces ``angle_indices`` /
        ``torsion_indices`` with their ``long`` versions. Every written
        tensor is moved with ``.cuda()``, so a CUDA device is required.
        Coordinates are perturbed by ``add_coords_noise`` when
        ``config.coords_noise > 0``.
        """
        batch = super().preprocess_batch(batch, training)

        nodes = batch['node_mask']
        # Outer product of the node mask with itself.
        pair_mask = nodes.unsqueeze(-1) * nodes.unsqueeze(-2)
        batch['edge_mask'] = pair_mask.cuda()

        angle_idx = batch['angle_indices'].long()
        torsion_idx = batch['torsion_indices'].long()

        coords = batch['dft_coords']
        cfg = self.config
        if cfg.coords_noise > 0:
            coords = add_coords_noise(coords=coords,
                                      edge_mask=pair_mask,
                                      noise_level=cfg.coords_noise,
                                      noise_smoothing=cfg.coords_noise_smooth)

        # Geometric inputs come from the (possibly noised) coordinates.
        distances = coords2dist(coords)
        angles = calculate_angle(coords, angle_idx)
        torsions = calculate_torsion_angle(coords, torsion_idx)

        # An index row counts as valid when any of its entries is non-zero.
        # presumably all-zero rows are padding - confirm against the collate code
        angle_valid = (angle_idx != 0).any(dim=-1)
        torsion_valid = (torsion_idx != 0).any(dim=-1)

        # Written in the same key order as before.
        derived = (
            ('angle_mask', angle_valid),
            ('torsion_mask', torsion_valid),
            ('dist_input', distances),
            ('angle_input', angles),
            ('angle_indices', angle_idx),
            ('torsion_indices', torsion_idx),
            ('torsion_input', torsions),
        )
        for key, value in derived:
            batch[key] = value.cuda()

        return batch
    
    def calculate_loss(self, outputs, inputs):
        """Return the four weighted loss terms ``(gap, dist, angle, torsion)``.

        ``outputs`` is unpacked as ``(gap_pred, dist_pred, angle_pred,
        torsion_angle_pred)``. Targets are recomputed from
        ``inputs['dft_coords']`` (not the noised inputs) and moved with
        ``.cuda()``.
        """
        gap_pred, dist_pred, angle_pred, torsion_angle_pred = outputs

        prim_loss = F.l1_loss(gap_pred, inputs['target'])

        reference_coords = inputs['dft_coords']
        dist_targ = coords2dist(reference_coords).cuda()
        angle_targ = calculate_angle(reference_coords,
                                     inputs['angle_indices'].long()).cuda()
        torsion_targ = calculate_torsion_angle(reference_coords,
                                               inputs['torsion_indices'].long()).cuda()

        dist_loss = self.xent_loss_fn_dist(dist_pred, dist_targ, inputs['edge_mask'])

        # NOTE(review): the 0.1 / 0.2 factors are applied on top of
        # config.angle_loss_weight below, and only the torsion mask is cast
        # to float - confirm both are intentional.
        angle_loss = self.xent_loss_fn_angle(
            angle_pred, angle_targ, inputs['angle_mask'])*0.1
        torsion_angle_loss = self.xent_loss_fn_angle(
            torsion_angle_pred, torsion_targ, inputs['torsion_mask'].float())*0.2

        dist_weight = self.config.dist_loss_weight
        angle_weight = self.config.angle_loss_weight
        return (prim_loss,
                dist_weight * dist_loss,
                angle_weight * angle_loss,
                angle_weight * torsion_angle_loss)

    def get_verbose_logs(self):
        """Return an ordered mapping of log key -> format spec for the verbose logs."""
        log_keys = (
            'loss',
            'gap_loss',
            'dist_loss',
            'angle_loss',
            'torsion_angle_loss',
        )
        # Every entry uses the same four-decimal format.
        return OrderedDict.fromkeys(log_keys, '0.4f')

    def initialize_losses(self, logs, training):
        """Reset the running loss totals and the running sample count to zero."""
        running_totals = (
            '_total_loss',
            '_total_gap_loss',
            '_total_dist_loss',
            '_total_angle_loss',
            '_total_torsion_angle_loss',
            '_total_samples',
        )
        for attr in running_totals:
            setattr(self, attr, 0.)

    def update_losses(self, i, loss, inputs, logs, training):
        """Fold one step's losses into the running totals and refresh the logs.

        ``loss`` is the 4-tuple ``(gap, dist, angle, torsion)``. Each term is
        weighted by the step's sample count (first dimension of
        ``inputs['num_nodes']``). When distributed, the weighted terms and
        the count are combined across ranks with ``all_reduce``. Under mixed
        precision a NaN total is skipped, until ten NaN steps in a row have
        been skipped.
        """
        gap_loss, loss_dist, loss_angle, loss_torsion_angle = loss
        loss = gap_loss + loss_dist + loss_angle + loss_torsion_angle

        step_samples = float(inputs['num_nodes'].shape[0])

        # Fixed order: total, dist, angle, torsion, gap.
        terms = [loss, loss_dist, loss_angle, loss_torsion_angle, gap_loss]
        if self.is_distributed:
            step_samples = torch.tensor(step_samples, device=loss.device, dtype=loss.dtype)
            if training:
                terms = [term.detach() for term in terms]

            packed = torch.stack([term * step_samples for term in terms] + [step_samples])
            torch.distributed.all_reduce(packed)
            *step_values, step_samples = packed.tolist()
        else:
            step_values = [term.item() * step_samples for term in terms]

        (step_loss, step_loss_dist, step_loss_angle,
         step_loss_torsion_angle, step_loss_gap) = step_values

        if self.config.mixed_precision:
            # `x == x` is False only for NaN.
            accumulate = step_loss == step_loss or self._nan_loss_count >= 10
            if accumulate:
                self._nan_loss_count = 0
            else:
                self._nan_loss_count += 1
        else:
            accumulate = True

        if accumulate:
            self._total_loss += step_loss
            self._total_dist_loss += step_loss_dist
            self._total_angle_loss += step_loss_angle
            self._total_torsion_angle_loss += step_loss_torsion_angle
            self._total_gap_loss += step_loss_gap
            self._total_samples += step_samples

        # The epsilon keeps the average defined before any sample is counted.
        denom = self._total_samples + 1e-12
        self.update_logs(logs=logs, training=training,
                         loss=self._total_loss / denom,
                         dist_loss=self._total_dist_loss / denom,
                         angle_loss=self._total_angle_loss / denom,
                         torsion_angle_loss=self._total_torsion_angle_loss / denom,
                         gap_loss=self._total_gap_loss / denom)

    def training_step(self, batch, logs):
        """Run one optimisation step and return ``(outputs, loss)``.

        The model is cast to float32 and float16 tensors in ``batch`` are
        upcast on every call; forward and backward run with autocast
        disabled. When ``config.mixed_precision`` is set the gradient scaler
        is still used for backward, unscaling and the optimizer step.
        ``loss`` is the 4-tuple from ``calculate_loss``. ``logs`` is unused
        here.
        """
        self.model = self.model.float()
        batch = convert_float16_to_float32(batch)

        # Clear gradients by dropping them rather than zero-filling.
        for p in self.trainable_params:
            p.grad = None

        cfg = self.config
        use_scaler = cfg.mixed_precision
        grad_value_limit = cfg.clip_grad_value
        grad_norm_limit = cfg.clip_grad_norm

        with torch.cuda.amp.autocast(enabled=False):
            with self.fwd_pass_context():
                outputs = self.model(batch)
                loss = self.calculate_loss(outputs=outputs, inputs=batch)
                loss_sum = loss[0]+loss[1]+loss[2]+loss[3]
            # Backward runs outside the forward-pass context.
            if use_scaler:
                self.grad_scaler.scale(loss_sum).backward()
            else:
                loss_sum.backward()

        clip_by_value = grad_value_limit is not None
        clip_by_norm = grad_norm_limit is not None
        # Gradients are unscaled before clipping.
        if use_scaler and (clip_by_value or clip_by_norm):
            self.grad_scaler.unscale_(self.optimizer)

        if clip_by_value:
            nn.utils.clip_grad_value_(self.trainable_params, self.config.clip_grad_value)
        if clip_by_norm:
            nn.utils.clip_grad_norm_(self.trainable_params, self.config.clip_grad_norm)

        if use_scaler:
            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
        else:
            self.optimizer.step()
        return outputs, loss

    def prediction_step(self, batch):
        """Average several finite model predictions and return per-term losses.

        Runs up to ``2 * self.nb_draw_samples`` forward passes, skipping any
        whose outputs contain NaN or inf, and stops once
        ``self.nb_draw_samples`` valid passes have been accumulated.

        Returns:
            dict with keys ``gap_loss``, ``dist_loss`` and ``angle_loss``.

        Raises:
            ValueError: if no forward pass produced finite outputs.

        NOTE(review): this method looks out of sync with the rest of the
        class and probably cannot run as written:
          * the model output is unpacked into three values here, while
            ``calculate_loss`` unpacks four;
          * ``coords2angle_cos_sin`` is not among the names imported in the
            visible import block;
          * ``self.mse_loss_fn_dist``, ``self.cossin_loss_fn_angle`` and
            ``self.mask_angles`` are not defined in this class's
            ``__post_init__`` (they may come from the parent - confirm);
          * ``batch['distance_matrix']`` is not written by
            ``preprocess_batch`` in this class.
        """
        gap_pred = None
        dist_preds = None
        angle_preds = None
        nb_samples = self.nb_draw_samples
        valid_samples = 0
        # Allow twice as many attempts as the samples wanted, to absorb rejected draws.
        for _ in range(nb_samples*2):
            new_gap_pred, new_dist_preds, new_angle_preds = self.model(batch)
            # Reject the whole draw if any output holds a NaN or an inf.
            if torch.isnan(new_gap_pred).any() or torch.isinf(new_gap_pred).any()\
                or torch.isnan(new_dist_preds).any() or torch.isinf(new_dist_preds).any()\
                or torch.isnan(new_angle_preds).any() or torch.isinf(new_angle_preds).any():
                continue
            
            # The first valid draw becomes the accumulator; later draws are added in place.
            if gap_pred is None:
                gap_pred = new_gap_pred
            else:
                gap_pred.add_(new_gap_pred)
            
            if dist_preds is None:
                dist_preds = new_dist_preds
            else:
                dist_preds.add_(new_dist_preds)

            if angle_preds is None:
                angle_preds = new_angle_preds
            else:
                angle_preds.add_(new_angle_preds)
            
            valid_samples += 1
            if valid_samples >= nb_samples:
                break
        
        if not valid_samples:
            raise ValueError('All predictions were NaN')
        elif valid_samples < nb_samples:
            # The message says "NaN", but draws rejected for inf are counted too.
            nan_samples = nb_samples - valid_samples
            print(f'Warning: '
                  f'{nan_samples}/{nb_samples} predictions were NaN')
        
        # Turn the accumulated sum into a mean over the valid draws.
        gap_pred.div_(valid_samples)
        gap_loss = torch.abs(gap_pred - batch['target'])
        
        # Add the transpose over dims -2/-3, then divide by 2 * valid_samples:
        # this averages over draws and over both orderings at once.
        dist_preds = dist_preds + dist_preds.transpose(-2,-3)
        dist_preds.div_(valid_samples*2)
        dist_targ = coords2dist(batch['dft_coords'])
        # NOTE(review): self.mse_loss_fn_dist is not defined in this class - confirm it exists.
        dist_loss = self.mse_loss_fn_dist(dist_preds.squeeze(-1), dist_targ, batch['edge_mask'])

        angle_preds = angle_preds + angle_preds.transpose(-2,-3)
        angle_preds.div_(valid_samples*2)
        # NOTE(review): coords2angle_cos_sin is not imported in the visible import block.
        angle_targ = coords2angle_cos_sin(batch['dft_coords'])
        # NOTE(review): mask_angles is imported as a module-level function but called
        # here as a method, and 'distance_matrix' is not set by preprocess_batch - confirm.
        mask_angel = self.mask_angles(batch['distance_matrix'])
        angle_loss = self.cossin_loss_fn_angle(angle_preds, angle_targ, mask_angel)

        return dict(
            gap_loss = gap_loss,
            dist_loss = dist_loss,
            angle_loss = angle_loss
        )
    
    def prediction_loop(self, dataloader):
        """Delegate to the parent's prediction loop, forwarding ``config.predict_in_train``."""
        predict_in_train = self.config.predict_in_train
        return super().prediction_loop(dataloader, predict_in_train=predict_in_train)
    
    def evaluate_predictions(self, predictions,
                             dataset_name='validation',
                             evaluation_stage=False):
        """Reduce per-sample losses in ``predictions`` to their means.

        Returns a dict with ``gap_loss``, ``dist_loss``, ``angle_loss`` and a
        combined ``loss``. ``dataset_name`` and ``evaluation_stage`` are
        accepted but not used here.
        """
        gap_loss, dist_loss, angle_loss = (
            predictions[key].mean()
            for key in ('gap_loss', 'dist_loss', 'angle_loss')
        )
        # NOTE(review): the combined loss leaves out the angle term, unlike
        # the training loss - confirm this is intended.
        loss = gap_loss + self.config.dist_loss_weight * dist_loss
        return {
            'gap_loss': gap_loss,
            'dist_loss': dist_loss,
            'angle_loss': angle_loss,
            'loss': loss,
        }
    
    def evaluate_on(self, dataset_name, dataset, predictions):
        """Announce the evaluation on the main rank and return the evaluated metrics.

        ``dataset`` is accepted but not used here.
        """
        if self.is_main_rank:
            print(f'Evaluating on {dataset_name}')
        return self.evaluate_predictions(predictions,
                                         dataset_name=dataset_name,
                                         evaluation_stage=True)

    def get_outer_4_layers(self):
        """Return the distinct parameter-name prefixes of at most four dotted components.

        The result is a list built from a set, so its order is not defined.
        """
        prefixes = {
            '.'.join(param_name.split('.')[:4])
            for param_name, _ in self.model.named_parameters()
        }
        return list(prefixes)

    def record_gradients(self):
        """Compute the global gradient L2 norm and one L2 norm per parameter group.

        Parameters are grouped by the first four dotted components of their
        name, the same key ``get_outer_4_layers`` produces. Parameters
        without a gradient are skipped.

        Returns:
            tuple: ``(total_norm, layer_norms)``, where ``total_norm`` is a
            float and ``layer_norms`` maps each group prefix to its norm.

        Two defects of the earlier version are fixed:
          * groups were matched with ``name.startswith(prefix)``, so a
            parameter could be credited to a different group sharing a
            textual prefix (``a.b.c.1`` also prefixes ``a.b.c.10.weight``);
            the group key is now computed directly from the name;
          * per-group values were left as sums of squares while
            ``total_norm`` was square-rooted; both are now true L2 norms.
        """
        squared_by_layer = {layer: 0. for layer in self.get_outer_4_layers()}
        total_squared = 0.

        for name, param in self.model.named_parameters():
            if param.grad is None:
                continue
            squared = param.grad.detach().norm(2).item() ** 2
            total_squared += squared
            key = '.'.join(name.split('.')[:4])
            squared_by_layer[key] = squared_by_layer.get(key, 0.) + squared

        total_norm = total_squared ** 0.5
        layer_norms = {layer: squared ** 0.5
                       for layer, squared in squared_by_layer.items()}

        if self.is_distributed:
            # NOTE(review): all_reduce sums by default, so these values are
            # summed across ranks rather than averaged - confirm that is the
            # intended aggregation.
            tensor = torch.tensor([total_norm, *layer_norms.values()]).cuda()
            torch.distributed.all_reduce(tensor)
            total_norm = tensor[0].item()
            layer_norms = dict(zip(layer_norms.keys(), tensor[1:].tolist()))

        return total_norm, layer_norms

    def train_epoch(self, epoch, logs, minimal=False, train_in_eval=False):
        """Run one pass over the training dataloader.

        Args:
            epoch: epoch index, passed to the distributed sampler and used
                for the progress description.
            logs: log container updated in place.
            minimal: when true, skip loss/metric bookkeeping and the
                progress-bar description.
            train_in_eval: when true, put the model in eval mode instead of
                train mode for this epoch.

        Gradients are cleared and the progress bar is closed when the loop
        ends, including when it ends with an exception.
        """
        if train_in_eval:
            self.model.eval()
        else:
            self.model.train()

        if not minimal:
            self.initialize_losses(logs, True)
            self.initialize_metrics(logs, True)

        if self.is_distributed:
            self.train_sampler.set_epoch(epoch)

        gen = self.train_dataloader
        if self.is_main_rank:
            gen = self.progbar(gen, total=self.train_steps_per_epoch, desc='Training: ')

        # Placeholder for debugging forward hooks. Nothing registers hooks at
        # present: the TensorBoard writer and the layer-statistics tracker
        # that used to be set up here are disabled.
        hooks = []
        try:
            for i, batch in enumerate(gen):
                self.on_batch_begin(i, logs, True)
                batch = self.preprocess_batch(batch=batch, training=True)
                outputs, loss = self.training_step(batch, logs)

                self.state.global_step = self.state.global_step + 1
                logs.update(global_step=self.state.global_step)

                if not minimal:
                    self.update_losses(i, loss, batch, logs, True)
                    self.update_metrics(outputs, batch, logs, True)

                # Periodic gradient-norm logging through record_gradients()
                # used to run here every 400 steps; it is disabled.

                self.on_batch_end(i, logs, True)

                if self.is_main_rank and not minimal:
                    pieces = self.log_description(epoch * len(gen) + i, logs, True)
                    desc = 'Training: '+'; '.join(pieces)
                    gen.set_description(desc)
        finally:
            for handle in hooks:
                handle.remove()
            if self.is_main_rank:
                gen.close()
            # Drop any gradients left over from the last step.
            for p in self.trainable_params:
                p.grad = None