from torch import nn
import numpy as np
from .base_fusion import BaseFuseTrainer
from torchvision.models import resnet50, ResNet50_Weights
import os

class UniCXRResNet50(BaseFuseTrainer):
    """Chest X-ray-only classifier: ImageNet-pretrained ResNet-50 plus a linear multi-label head.

    The torchvision ResNet-50's ``fc`` layer is replaced by a projection from its
    2048-d pooled features to ``hparams.hidden_size``. ``cxr_fc`` then maps that
    embedding to one output per entry of ``hparams['class_names']``, and
    ``forward`` turns those outputs into probabilities with a sigmoid.

    When ``hparams['save']`` is truthy, the embeddings and labels of every
    train / validation / test batch are buffered. At epoch end they are written to
    ``./features/<task>/<model_name>/seed<seed>_features/<split>_features_epoch_<epoch>.npz``.
    """

    def __init__(self, hparams):
        super().__init__()

        self.save_hyperparameters(hparams)
        self.class_names = self.hparams['class_names']
        self.num_classes = len(self.hparams['class_names'])

        self.cxr_model_spec = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        # ResNet-50's pooled feature width is 2048; project it to the shared hidden size.
        self.cxr_model_spec.fc = nn.Linear(in_features=2048, out_features=self.hparams.hidden_size)
        self.cxr_fc = nn.Linear(in_features=self.hparams.hidden_size, out_features=self.num_classes)

        # forward() applies the sigmoid itself, so the criterion receives probabilities.
        self.pred_criterion = nn.BCELoss()

        if self.hparams['save']:
            # Per-split buffers of per-batch NumPy arrays, flushed at each epoch end.
            self.train_cxr_features = []
            self.train_labels = []
            self.valid_cxr_features = []
            self.valid_labels = []
            self.test_cxr_features = []
            self.test_labels = []
            self.feature_save_dir = f"./features/{self.hparams.task}/{self.hparams['model_name']}/seed{self.hparams.seed}_features"
            os.makedirs(self.feature_save_dir, exist_ok=True)

    def forward(self, data_dict):
        """Embed the CXR images, classify them and compute the BCE loss.

        Args:
            data_dict: mapping with ``'cxr_imgs'`` (the backbone input) and
                ``'labels'``. ``nn.BCELoss`` requires the labels to have the same
                shape as the predictions.

        Returns:
            dict with ``'feat_cxr_distinct'`` (hidden-size embedding),
            ``'loss'`` (scalar BCE loss) and ``'predictions'`` (sigmoid probabilities).
        """
        img = data_dict['cxr_imgs']
        feat_cxr_distinct = self.cxr_model_spec(img)

        predictions = self.cxr_fc(feat_cxr_distinct).sigmoid()

        loss = self.pred_criterion(predictions, data_dict['labels'])
        outputs = {
            'feat_cxr_distinct': feat_cxr_distinct,
            'loss': loss,
            'predictions': predictions
        }
        return outputs

    @staticmethod
    def _to_numpy(tensor):
        """Detach ``tensor``, move it to host memory and return it as a NumPy array."""
        return tensor.detach().cpu().numpy()

    def _flush_features(self, split, features, labels):
        """Write buffered features/labels of ``split`` to an ``.npz`` file, then empty the buffers.

        Args:
            split: file-name prefix (``'train'``, ``'val'`` or ``'test'``).
            features: list of per-batch embedding arrays, stacked row-wise.
            labels: list of per-batch label arrays, stacked row-wise.

        Returns:
            True if a file was written, False if ``features`` was empty.
        """
        if not features:
            return False
        save_path = os.path.join(self.feature_save_dir,
                                 f"{split}_features_epoch_{self.current_epoch}.npz")
        np.savez(save_path,
                 cxr_features=np.vstack(features),
                 labels=np.vstack(labels),
                 hidden_size=self.hparams['hidden_size'],
                 epoch=self.current_epoch)
        # Clear in place so the next epoch (or the next test run) starts from empty buffers.
        features.clear()
        labels.clear()
        return True

    def on_validation_epoch_end(self):
        """Dump buffered validation features (if enabled), then compute and log epoch metrics."""
        if self.hparams['save']:
            self._flush_features('val', self.valid_cxr_features, self.valid_labels)
        scores = self._val_test_epoch_end(self.val_info, clear_cache=True)
        scores['step'] = float(self.current_epoch)
        # List-valued entries are not scalar metrics, so they are filtered out before logging.
        self.log_dict({k: v for k, v in scores.items() if not isinstance(v, list)}, on_epoch=True, on_step=False)

        return scores

    def on_test_epoch_end(self):
        """Dump buffered test features (if enabled) and store the final test scores."""
        if self.hparams['save']:
            # Flushing also clears the test buffers, so a repeated test run does not
            # re-save the previous run's features alongside the new ones.
            self._flush_features('test', self.test_cxr_features, self.test_labels)

        scores = self._val_test_epoch_end(self.test_info, clear_cache=True)
        self.test_results = dict(scores)

    def training_step(self, batch, batch_idx):
        """Run the shared training step, optionally buffer features, and log the loss."""
        out = self._shared_step(batch)

        if self.hparams['save']:
            self.train_cxr_features.append(self._to_numpy(out['feat_cxr_distinct']))
            self.train_labels.append(self._to_numpy(batch['labels']))

        self.log_dict({'loss/train': out['loss'].detach()},
                      on_epoch=True, on_step=True,
                      batch_size=batch['labels'].shape[0])
        return out['loss']

    def validation_step(self, batch, batch_idx):
        """Run the shared evaluation step, optionally buffer features, and log the loss."""
        out = self._val_test_shared_step(batch, self.val_info)
        if self.hparams['save']:
            self.valid_cxr_features.append(self._to_numpy(out['feat_cxr_distinct']))
            self.valid_labels.append(self._to_numpy(batch['labels']))
        self.log_dict({'loss/validation': out['loss'].detach()},
                      on_epoch=True, on_step=True,
                      batch_size=batch['labels'].shape[0])
        return out['loss']

    def test_step(self, batch, batch_idx):
        """Run the shared evaluation step and optionally buffer features."""
        out = self._val_test_shared_step(batch, self.test_info)

        if self.hparams['save']:
            self.test_cxr_features.append(self._to_numpy(out['feat_cxr_distinct']))
            self.test_labels.append(self._to_numpy(batch['labels']))

    def on_train_epoch_end(self):
        """Save the buffered training features of this epoch (if enabled) and clear the buffers."""
        if self.hparams['save'] and self._flush_features('train', self.train_cxr_features, self.train_labels):
            print(f"Save the features in epoch {self.current_epoch}")