from collections import OrderedDict

import numpy as np
import torch
import torch.nn.functional as F

from lib.training.hyperdict import HDict
from lib.models.pcqm.multitask import TGT_Multi
from lib.data.pcqm import data
from lib.data.pcqm.structural_transform import AddStructuralData
from ..commons import DiscreteDistLoss, coords2dist, BinsProcessor
from ..tgt_training import TGTTraining
from ..commons import DiscreteDistLoss, DiscreteAngleLoss, MaskedL1Loss, add_coords_noise, coords2dist, coords2angle, mask_angles, calculate_angle, calculate_torsion_angle

class SCHEME(TGTTraining):
    """PCQM4Mv2 fine-tuning scheme with auxiliary geometric losses.

    The primary objective is an L1 loss between the model's first output
    (``gap_pred``) and ``batch['target']``. During training, auxiliary losses on
    inter-atomic distances, angles and torsion angles are computed against DFT
    coordinates. Distance *inputs* to the model are decoded from pre-sampled
    distance bins through ``BinsProcessor``, so ``config.bins_input_path`` must
    be set.
    """

    def get_default_config(self):
        """Extend the base config with split names, bin settings and loss weights."""
        config_dict = super().get_default_config()
        config_dict.update(
            save_path_prefix    = HDict.L(lambda c: 'models/pcqm/finetune' if c.model_prefix is None else f'models/pcqm/{c.model_prefix}/finetune'),
            embed_3d_type       = 'gaussian',
            num_dist_bins       = 256,
            range_dist_bins     = 8,
            train_split         = 'train',
            val_split           = 'valid',
            test_split          = 'test-dev',
            dist_loss_weight    = 0.1,
            angle_loss_weight   = 0.1,
            bins_input_path     = None,
            bins_shift_half     = True,
            bins_zero_diag      = True,
        )
        return config_dict

    def __post_init__(self):
        """Build the loss functions and, when configured, the bins processor."""
        super().__post_init__()
        if self.executing_command == 'evaluate':
            # Number of distance samples averaged per prediction (prediction_step).
            self.nb_draw_samples = self.config.prediction_samples
        self.dist_loss_fn = DiscreteDistLoss(num_bins=self.config.num_dist_bins,
                                             range_bins=self.config.range_dist_bins)
        # Angle loss reuses num_dist_bins as its bin count; max_angle is ~2*pi.
        self.xent_loss_fn_angle = DiscreteAngleLoss(num_bins=self.config.num_dist_bins,
                                                    max_angle=2*3.14159)
        if self.config.bins_input_path is not None:
            self.bins_proc = BinsProcessor(self.config.bins_input_path,
                                           shift_half=self.config.bins_shift_half,
                                           zero_diag=self.config.bins_zero_diag)

    def _require_bins_proc(self):
        """Return the configured ``BinsProcessor`` or raise a descriptive error."""
        bins_proc = getattr(self, 'bins_proc', None)
        if bins_proc is None:
            raise ValueError('config.bins_input_path must be set: this scheme '
                             'builds its distance inputs from precomputed bins.')
        return bins_proc

    def get_dataset_config(self, split):
        """Return the dataset config and class for ``split`` ('train'/'val'/'test').

        Every split loads the distance-bin samples. The DFT coordinates column
        is added only for the training split while running the 'train' command.
        """
        dataset_config, _ = super().get_dataset_config(split)
        bins_proc = self._require_bins_proc()
        ds_split = {'train': self.config.train_split,
                    'val': self.config.val_split,
                    'test': self.config.test_split}[split]

        columns = [data.Bins(bins_proc.data_path,
                             bins_proc.num_samples)]

        is_train_split = 'train' in split
        if is_train_split and self.executing_command == 'train':
            # Coordinates feed the auxiliary-loss targets; presumably exposed as
            # batch['dft_coords'], which calculate_loss reads — TODO confirm.
            columns.append(data.Coords('dft'))

        dataset_config.update(
            split               = ds_split,
            dataset_path        = self.config.dataset_path,
            transforms          = [AddStructuralData()],
            verbose             = int(self.is_main_rank and is_train_split),
            additional_columns  = columns,
        )
        return dataset_config, data.PCQM4Mv2Dataset

    def get_model_config(self):
        """Return the base model config extended with ``num_dist_bins``, and ``TGT_Multi``."""
        model_config, _ = super().get_model_config()
        model_config.update(
            num_dist_bins = self.config.num_dist_bins,
        )
        return model_config, TGT_Multi

    def preprocess_batch(self, batch, training):
        """Add masks and decoded distance inputs to ``batch`` and move them to CUDA.

        Adds ``edge_mask``, ``angle_mask``, ``torsion_mask`` and ``dist_input``.
        It also rewrites ``angle_indices`` and ``torsion_indices`` as long tensors.
        In training a single distance sample is chosen per epoch. Otherwise all
        samples are kept so that ``prediction_step`` can average over them.
        """
        batch = super().preprocess_batch(batch, training)
        bins_proc = self._require_bins_proc()

        node_mask = batch['node_mask']
        # A pair is valid only when both nodes are valid.
        edge_mask = node_mask.unsqueeze(-1) * node_mask.unsqueeze(-2)

        angle_indices = batch['angle_indices'].long()
        torsion_indices = batch['torsion_indices'].long()
        # Index rows that are entirely zero are treated as padding (masked out).
        angle_mask = (angle_indices != 0).any(dim=-1)
        torsion_mask = (torsion_indices != 0).any(dim=-1)

        all_bins = batch['dist_bins']
        if training:
            # Cycle through the pre-drawn distance samples, one per epoch.
            cur_sample = self.state.current_epoch % all_bins.size(1)
            bins = all_bins[:, cur_sample]
        else:
            bins = all_bins
        dist_input = bins_proc.bins2dist(bins)

        batch['edge_mask'] = edge_mask.cuda()
        batch['angle_mask'] = angle_mask.cuda()
        batch['torsion_mask'] = torsion_mask.cuda()
        batch['dist_input'] = dist_input.cuda()
        batch['angle_indices'] = angle_indices.cuda()
        batch['torsion_indices'] = torsion_indices.cuda()
        # TODO(review): no 'angle_input' / 'torsion_input' is produced here (no
        # source for them exists in this method). If TGT_Multi requires them,
        # compute them from the available inputs.
        return batch

    def calculate_loss(self, outputs, inputs):
        """Compute the primary L1 loss and the weighted auxiliary losses.

        Args:
            outputs: ``(gap_pred, dist_pred, angle_pred, torsion_angle_pred)``
                from the model.
            inputs: preprocessed batch containing ``target``, ``dft_coords``,
                the angle/torsion index tensors and the masks.

        Returns:
            ``(gap_loss, dist_loss, angle_loss, torsion_angle_loss)``. The last
            three are already multiplied by their config weights.
        """
        gap_pred, dist_pred, angle_pred, torsion_angle_pred = outputs
        coords = inputs['dft_coords']

        prim_loss = F.l1_loss(gap_pred, inputs['target'])

        dist_targ = coords2dist(coords).cuda()
        angle_targ = calculate_angle(coords, inputs['angle_indices'].long()).cuda()
        torsion_targ = calculate_torsion_angle(coords, inputs['torsion_indices'].long()).cuda()

        dist_loss = self.dist_loss_fn(dist_pred, dist_targ, inputs['edge_mask'])
        # NOTE: the fixed 0.1 / 0.2 factors are applied in addition to
        # config.angle_loss_weight; kept to preserve the existing loss scale.
        angle_loss = self.xent_loss_fn_angle(angle_pred, angle_targ,
                                             inputs['angle_mask']) * 0.1
        torsion_angle_loss = self.xent_loss_fn_angle(torsion_angle_pred, torsion_targ,
                                                     inputs['torsion_mask'].float()) * 0.2

        dist_weight = self.config.dist_loss_weight
        angle_weight = self.config.angle_loss_weight
        return (prim_loss,
                dist_weight * dist_loss,
                angle_weight * angle_loss,
                angle_weight * torsion_angle_loss)

    def get_verbose_logs(self):
        """Return the log keys shown in progress output and their format specs."""
        logs = OrderedDict([
            ('loss', '0.4f'),
            ('gap_loss', '0.4f'),
            ('dist_loss', '0.4f'),
            ('angle_loss', '0.4f'),
            ('torsion_angle_loss', '0.4f'),
        ])
        return logs

    def initialize_losses(self, logs, training):
        """Reset the running, sample-weighted loss accumulators."""
        self._total_loss = 0.
        self._total_gap_loss = 0.
        self._total_dist_loss = 0.
        self._total_angle_loss = 0.
        self._total_torsion_angle_loss = 0.
        self._total_samples = 0.

    def update_losses(self, i, loss, inputs, logs, training):
        """Accumulate this step's losses, weighted by batch size, and refresh logs.

        In distributed mode, per-rank sums are all-reduced before accumulation.
        With mixed precision, up to 10 consecutive NaN steps are skipped. After
        that, a NaN step is accumulated anyway.
        """
        gap_loss, loss_dist, loss_angle, loss_torsion_angle = loss
        total = gap_loss + loss_dist + loss_angle + loss_torsion_angle
        # Order: total, dist, angle, torsion, gap.
        components = [total, loss_dist, loss_angle, loss_torsion_angle, gap_loss]

        step_samples = float(inputs['num_nodes'].shape[0])
        if not self.is_distributed:
            step_values = [c.item() * step_samples for c in components]
        else:
            if training:
                components = [c.detach() for c in components]
            samples_t = torch.tensor(step_samples, device=total.device, dtype=total.dtype)
            reduced = torch.stack([c * samples_t for c in components] + [samples_t])
            torch.distributed.all_reduce(reduced)
            *step_values, step_samples = [t.item() for t in reduced]

        (step_loss, step_loss_dist, step_loss_angle,
         step_loss_torsion_angle, step_loss_gap) = step_values

        mixed_precision = self.config.mixed_precision
        nan_count = getattr(self, '_nan_loss_count', 0)
        is_nan = step_loss != step_loss  # True only for NaN
        if mixed_precision and is_nan and nan_count < 10:
            self._nan_loss_count = nan_count + 1
        else:
            if mixed_precision:
                self._nan_loss_count = 0
            self._total_loss += step_loss
            self._total_dist_loss += step_loss_dist
            self._total_angle_loss += step_loss_angle
            self._total_torsion_angle_loss += step_loss_torsion_angle
            self._total_gap_loss += step_loss_gap
            self._total_samples += step_samples

        denom = self._total_samples + 1e-12
        self.update_logs(logs=logs, training=training,
                         loss=self._total_loss / denom,
                         dist_loss=self._total_dist_loss / denom,
                         angle_loss=self._total_angle_loss / denom,
                         torsion_angle_loss=self._total_torsion_angle_loss / denom,
                         gap_loss=self._total_gap_loss / denom)

    @staticmethod
    def _convert_float16_to_float32(obj):
        """Recursively cast float16 tensors to float32.

        Dicts, lists and tuples are traversed. Any other value is returned
        unchanged.
        """
        if torch.is_tensor(obj):
            return obj.float() if obj.dtype == torch.float16 else obj
        if isinstance(obj, dict):
            return {k: SCHEME._convert_float16_to_float32(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return type(obj)(SCHEME._convert_float16_to_float32(v) for v in obj)
        return obj

    def training_step(self, batch, logs):
        """Run one optimizer step in float32 and return ``(outputs, loss_tuple)``.

        The model and any float16 inputs are cast to float32, and autocast is
        disabled. The grad scaler is still used when ``mixed_precision`` is set.
        """
        self.model = self.model.float()
        batch = self._convert_float16_to_float32(batch)

        for param in self.trainable_params:
            param.grad = None

        mixed_precision = self.config.mixed_precision
        clip_grad_value = self.config.clip_grad_value
        clip_grad_norm = self.config.clip_grad_norm

        with torch.cuda.amp.autocast(enabled=False):
            with self.fwd_pass_context():
                outputs = self.model(batch)
                loss = self.calculate_loss(outputs=outputs, inputs=batch)
                loss_sum = sum(loss)
            if mixed_precision:
                self.grad_scaler.scale(loss_sum).backward()
            else:
                loss_sum.backward()

        should_clip_grad_value = clip_grad_value is not None
        should_clip_grad_norm = clip_grad_norm is not None
        if mixed_precision and (should_clip_grad_value or should_clip_grad_norm):
            # Gradients must be unscaled before they can be clipped.
            self.grad_scaler.unscale_(self.optimizer)

        if should_clip_grad_value:
            torch.nn.utils.clip_grad_value_(self.trainable_params, clip_grad_value)
        if should_clip_grad_norm:
            torch.nn.utils.clip_grad_norm_(self.trainable_params, clip_grad_norm)

        if mixed_precision:
            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
        else:
            self.optimizer.step()
        return outputs, loss

    def prediction_step(self, batch):
        """Average the gap prediction over several sampled distance inputs.

        ``batch['dist_input']`` must be 4-D with the samples along dim 1. Up to
        ``2 * nb_draw_samples`` forward passes are tried. A pass with a NaN/Inf
        prediction is discarded, and the next distance sample is used for the
        following attempt. ``batch['dist_input']`` is restored afterwards.

        Returns:
            dict with ``gap_loss``: absolute error of the averaged prediction
            against ``batch['target']``.

        Raises:
            ValueError: if ``dist_input`` is not 4-D or every pass was non-finite.
        """
        nb_samples = self.nb_draw_samples
        all_dist_inputs = batch['dist_input']
        if all_dist_inputs.ndim != 4:
            raise ValueError(f'Expected 4-D dist_input, got {all_dist_inputs.ndim}-D')
        num_dist_inputs = all_dist_inputs.size(1)

        gap_pred = None
        valid_samples = 0
        try:
            for attempt in range(nb_samples * 2):
                # Index by attempt so a failed sample is not retried with the same input.
                batch['dist_input'] = all_dist_inputs[:, attempt % num_dist_inputs]
                outputs = self.model(batch)
                # The model also returns auxiliary heads; only the gap prediction is used.
                new_gap_pred = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
                if not torch.isfinite(new_gap_pred).all():
                    continue

                if gap_pred is None:
                    gap_pred = new_gap_pred
                else:
                    gap_pred.add_(new_gap_pred)

                valid_samples += 1
                if valid_samples >= nb_samples:
                    break
        finally:
            batch['dist_input'] = all_dist_inputs

        if not valid_samples:
            raise ValueError('All predictions were NaN')
        elif valid_samples < nb_samples:
            nan_samples = nb_samples - valid_samples
            print(f'Warning: '
                  f'{nan_samples}/{nb_samples} predictions were NaN')

        gap_pred.div_(valid_samples)
        gap_loss = torch.abs(gap_pred - batch['target'])
        return dict(
            gap_loss = gap_loss,
        )

    def prediction_loop(self, dataloader):
        """Delegate to the base loop, forwarding ``config.predict_in_train``."""
        return super().prediction_loop(dataloader,
                                       predict_in_train=self.config.predict_in_train)

    def evaluate_predictions(self, predictions,
                             dataset_name='validation',
                             evaluation_stage=False):
        """Return ``{'loss': mean absolute gap error}`` over all predictions."""
        loss = predictions['gap_loss'].mean()
        return dict(
            loss = loss,
        )

    def evaluate_on(self, dataset_name, dataset, predictions):
        """Evaluate precomputed predictions for ``dataset_name``."""
        if self.is_main_rank:
            print(f'Evaluating on {dataset_name}')
        results = self.evaluate_predictions(predictions,
                                            dataset_name=dataset_name,
                                            evaluation_stage=True)
        return results

    def make_predictions(self):
        """Run the base prediction pass, then evaluate and save the results."""
        super().make_predictions()
        self.evaluate_and_save()

