
from .sam_decomp_closure import SAMDecompClosure
import math
import csv
import torch
import os
from copy import deepcopy
import seaborn as sns
from torch import nn
from torch.nn import functional as F
import numpy  as  np
import random
import pandas as pd
import torch.nn.init as init
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models import swin_s, Swin_S_Weights
from torchvision.models import swin_t, Swin_T_Weights
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
from .fusion_method import SumFusion, ConcatFusion, FiLM, GatedFusion
from .ehr_transformer import DisentangledEHRTransformer
from .base_fusion import BaseFuseTrainer


from tools.calculate_trace import Hessian
from tools.log import CSVLogger, AverageMeter

class SAM_method(BaseFuseTrainer):
   
    def __init__(self, hparams):
        """Build the encoders, the fusion head and all bookkeeping state.

        Args:
            hparams: Hyper-parameter container passed to
                ``save_hyperparameters``; afterwards it is read through
                ``self.hparams`` both by key and by attribute.

        Raises:
            NotImplementedError: If ``hparams.cxr_model`` is not one of
                ``'resnet50'``, ``'swin_s'`` or ``'swin_t'``.
        """
        super().__init__()
        self.save_hyperparameters(hparams)

        self.class_names = self.hparams['class_names']
        # Element-wise BCE; the reduction is done in _compute_masked_pred_loss.
        self.pred_criterion = nn.BCELoss(reduction='none')

        self.num_classes = 1 if self.hparams.task == 'mortality' else 25
        self._step_counter = 0

        # Per-label loss buffers, one list per class name.
        self.loss_history = {name: [] for name in self.class_names}
        self.loss_history_ehr = {name: [] for name in self.class_names}
        self.loss_history_cxr = {name: [] for name in self.class_names}
        self.losses_per_label = {name: [] for name in self.class_names}
        self.losses_per_label_ehr = {name: [] for name in self.class_names}
        self.losses_per_label_cxr = {name: [] for name in self.class_names}

        # Only one CSV logger is created; save_trace takes precedence.
        if self.hparams.save_trace:
            self.csv_logger = CSVLogger(self.hparams, ['step','ehr_trace', 'cxr_trace'], f'./loggs/trace_step_log_seed{self.hparams.seed}.csv')
        elif self.hparams.save_eigenvalue:
            self.csv_logger = CSVLogger(self.hparams, ['step','ehr_eigenvalue', 'cxr_eigenvalue'], f'./loggs/eigenvalue_step_log_seed{self.hparams.seed}.csv')

        # Buffers for dumping intermediate features to disk.
        if self.hparams['save']:
            self.train_ehr_features = []
            self.train_cxr_features = []
            self.train_labels = []
            self.valid_ehr_features = []
            self.valid_cxr_features = []
            self.valid_labels = []
            self.test_ehr_features = []
            self.test_cxr_features = []
            self.test_labels = []
            self.feature_save_dir = f"./features/{self.hparams.task}/{self.hparams['model_name']}/seed{self.hparams.seed}_features"
            os.makedirs(self.feature_save_dir, exist_ok=True)

        # The SAM-decomposition optimizer is stepped by hand in training_step.
        if self.hparams['sam_decomp']:
            self.automatic_optimization = False
            print("Automatic optimization is set to False")

        # Fusion module.
        # NOTE(review): any other fusion_method value leaves `fusion_module`
        # unset, which only surfaces later in forward().
        fusion_method = self.hparams['fusion_method']
        hidden_size = self.hparams.hidden_size
        if fusion_method == 'sum':
            self.fusion_module = SumFusion(input_dim=hidden_size, output_dim=self.num_classes)
        elif fusion_method == 'concate' or fusion_method == 'single':
            self.fusion_module = ConcatFusion(input_dim=hidden_size if self.hparams['gs_flag'] else hidden_size * 2, output_dim=self.num_classes,sam=self.hparams.sam)
        elif fusion_method == 'film':
            self.fusion_module = FiLM(input_dim=hidden_size, dim=hidden_size, output_dim=self.num_classes, x_film=True)

        # EHR encoder.
        ehr_input_size = 24
        self.ehr_model = DisentangledEHRTransformer(input_size=ehr_input_size, num_classes=self.num_classes,
                                        d_model=self.hparams.hidden_size, n_head=self.hparams.ehr_n_head,
                                        n_layers_feat=1, n_layers_shared=1,
                                        n_layers_distinct=self.hparams.ehr_n_layers_distinct,
                                        dropout=self.hparams.ehr_dropout,simple=True)

        # CXR encoder: pretrained backbone whose classification layer is
        # replaced by a projection to hidden_size.
        if self.hparams.cxr_model == 'resnet50':
            self.cxr_model_spec = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
            self.cxr_model_spec.fc = nn.Linear(in_features=2048, out_features=self.hparams.hidden_size)
        elif self.hparams.cxr_model == 'swin_s':
            self.cxr_model_spec = swin_s(weights=Swin_S_Weights.DEFAULT)
            # Take the width from the head being replaced; the previous code
            # read it from `cxr_model_shared`, which this class never defines.
            self.cxr_model_spec.head = nn.Linear(in_features=self.cxr_model_spec.head.in_features,
                                                out_features=self.hparams.hidden_size)
        elif self.hparams.cxr_model == 'swin_t':
            self.cxr_model_spec = swin_t(weights=Swin_T_Weights.DEFAULT)
            self.cxr_model_spec.head = nn.Linear(in_features=self.cxr_model_spec.head.in_features,
                                                out_features=self.hparams.hidden_size)
        else:
            raise NotImplementedError(f'specified CXR model "{self.hparams.cxr_model}" is not supported.')

        # Separate unimodal classifier heads, used by forward() under uniloss.
        if self.hparams['uniloss']:
            print(f"inter method uniloss")
            self.ehr_model_linear = nn.Linear(in_features=self.hparams.hidden_size, out_features=self.num_classes)
            self.cxr_model_linear = nn.Linear(in_features=self.hparams.hidden_size, out_features=self.num_classes)

        self.score_ehr = 0
        self.score_cxr = 0
        # Batches collected by on_train_batch_end for the per-epoch trace.
        self.save_batch = []
        self.trace_ehr = []
        self.trace_cxr = []
        self.feat_ehr_encoder_list = []
        self.feat_ehr_distinct_list = []
        self.feat_ehr_shared_list = []
        self.feat_cxr_distinct_list = []
        self.feat_cxr_shared_list =[]
        self.eigenvalue_ehr = []
        self.eigenvalue_cxr = []

        # Per-modality running statistics used by the dynamic mode of
        # training_step to choose which modality to perturb.
        if self.hparams['sam_decomp']:
            self.ma_loss = {'ehr': None, 'cxr': None}
            self.last_loss = {'ehr': None, 'cxr': None}
            self.loss_decay = {'ehr': None, 'cxr': None}
            self.grad_similarity = {'ehr': None, 'cxr': None}

            self.score_weight = self.hparams.get('score_weight',0.5)
            self.momentum = self.hparams.get('momentum',0.9)


    def calculate_score(self,labels,outputs,pairs=None):

        if self.hparams.task == 'mortality':
            labels = labels.float()
            pred_ehr = outputs['pred_ehr']
            pred_cxr = outputs['pred_cxr']

            match_score_ehr = labels * pred_ehr + (1 - labels) * (1 - pred_ehr)
            match_score_cxr = labels * pred_cxr + (1 - labels) * (1 - pred_cxr)

            score_ehr = match_score_ehr.sum()
            score_cxr = match_score_cxr.sum()
           

            return score_ehr,score_cxr
        elif self.hparams.task == 'phenotype':
            # base on probability
            labels = labels.float()
            pred_ehr = outputs['pred_ehr']
            pred_cxr = outputs['pred_cxr']

            match_score_ehr = labels * pred_ehr + (1 - labels) * (1 - pred_ehr)
            match_score_cxr = labels * pred_cxr + (1 - labels) * (1 - pred_cxr)

            match_score_ehr = match_score_ehr.sum(dim=1)
            match_score_cxr = match_score_cxr.sum(dim=1)

            score_ehr = match_score_ehr.sum()
            score_cxr = match_score_cxr.sum()

            return score_ehr, score_cxr

    def training_step(self,batch,batch_idx):
        current_step = self._step_counter

        if self.hparams['sam_decomp']:
            current_step = self._step_counter
            
            def loss_fn(predictions, targets):
              
                pairs = torch.ones_like(predictions['predictions'][:, 0])
                loss_pred_multi = self._compute_masked_pred_loss(predictions['predictions'], targets, pairs)
                loss_pred_ehr = self._compute_masked_pred_loss(predictions['pred_ehr'], targets, pairs)
                loss_pred_cxr = self._compute_masked_pred_loss(predictions['pred_cxr'], targets, pairs)
      
                return self.hparams.loss_multi*loss_pred_multi, self.hparams.loss_ehr*loss_pred_ehr, self.hparams.loss_cxr*loss_pred_cxr
                

           
            self.SAM_decomp_optimizer.set_closure(loss_fn, batch, batch['labels'])

           
            if self.hparams.dynamic_mode:
                loss_multi, loss_ehr, loss_cxr = self.SAM_decomp_optimizer.forward_backward_func()
                
                for modality, loss_value in zip(['ehr','cxr'],[loss_ehr,loss_cxr]):
                    if self.ma_loss[modality] is None:
                        self.ma_loss[modality] = loss_value
                        self.loss_decay[modality] = 0.0
                    else:
                        previous_loss = self.ma_loss[modality]
                        self.ma_loss[modality] = self.momentum * previous_loss + (1 - self.momentum) * loss_value
                        self.loss_decay[modality] = max(0.0, self.last_loss[modality] - self.ma_loss[modality]) # get loss decay

                    self.last_loss[modality] = loss_value
                
                grad_similarity = self.SAM_decomp_optimizer._compute_gradient_similarity()
                self.grad_similarity.update(grad_similarity)

                # score for each modality
                score_ehr = self.score_weight * self.loss_decay['ehr'] + (1 - self.score_weight) * self.grad_similarity['ehr']
                score_cxr = self.score_weight * self.loss_decay['cxr'] + (1 - self.score_weight) * self.grad_similarity['cxr']

                select_modality = 'cxr' if score_cxr > score_ehr else 'ehr'

                # update optimizer
                print(f"select modality is {select_modality}")
                self.SAM_decomp_optimizer.set_perturb_mode(select_modality)

            self._step_counter += 1
            loss_multi = self.SAM_decomp_optimizer.step()
            
            if self.hparams.dynamic_mode:

                epoch_log = {
                    'loss/train': loss_multi,
                    'epoch_num': float(self.current_epoch),
                    'similarity_ehr': grad_similarity['ehr'],
                    'similarity_cxr': grad_similarity['cxr'],
                    'score_ehr': score_ehr,
                    'score_cxr': score_cxr,
                }
                self.log_dict(epoch_log, on_epoch=True, on_step=True, batch_size=batch['labels'].shape[0])
            else:
                epoch_log = {
                    'loss/train': loss_multi,
                    'epoch_num': float(self.current_epoch),
                }
                self.log_dict(epoch_log, on_epoch=True, on_step=True, batch_size=batch['labels'].shape[0])
    
            return None
        else:
            current_step = self.global_step
          
            out = self._shared_step(batch) 

            pairs = torch.ones_like(out['predictions'][:, 0])
           
            if self.hparams['save']:
                self.train_ehr_features.append(out['feat_ehr_distinct'].detach().cpu().numpy())
                self.train_cxr_features.append(out['feat_cxr_distinct'].detach().cpu().numpy())
                self.train_labels.append(batch['labels'].detach().cpu())



            pairs = torch.ones_like(out['predictions'][:, 0])
            loss_total =  self._compute_and_log_loss(out, batch['labels'], pairs,mode='train', where='training_step', name="original", step=current_step)
            
          


            epoch_log = {}
            epoch_log.update({
                'loss/train': loss_total.detach(),
                'epoch_num': float(self.current_epoch),  # 将'step'改名为'epoch_num'
            })
            self.log_dict(epoch_log, on_epoch=True, on_step=False , batch_size=batch['labels'].shape[0])
            return loss_total

    def validation_step(self, batch, batch_idx):
        self.eval()
        device = next(self.parameters()).device
        batch['labels'] = batch['labels'].to(device)

      
        out = self._val_test_shared_step(batch, self.val_info)

       

        
        pairs = torch.ones_like(out['feat_ehr_distinct'][:, 0], device=device)
        loss_pred_multi = self._compute_masked_pred_loss(out['predictions'], batch['labels'], pairs)
        loss_pred_ehr = self._compute_masked_pred_loss(out['pred_ehr'], batch['labels'], pairs)
        loss_pred_cxr = self._compute_masked_pred_loss(out['pred_cxr'], batch['labels'], pairs)
        loss_total = loss_pred_multi + loss_pred_ehr + loss_pred_cxr

        
        epoch_log = {
            'loss/validation': loss_total.detach(),
            'loss/valid_multi': loss_pred_multi.detach(),
            'loss/valid_ehr': loss_pred_ehr.detach(),
            'loss/valid_cxr': loss_pred_cxr.detach(),
            'step': float(self.current_epoch),
        }
        self.log_dict(epoch_log, on_epoch=True, on_step=False, batch_size=batch['labels'].shape[0])

        return loss_total

    def test_step(self, batch, batch_idx):
       
        out = self._val_test_shared_step(batch, self.test_info)
        
        if self.hparams['save']:
            self.test_ehr_features.append(out['feat_ehr_distinct'].detach().cpu().numpy())
            self.test_cxr_features.append(out['feat_cxr_distinct'].detach().cpu().numpy())
            self.test_labels.append(batch['labels'].detach().cpu().numpy())
        


    def _compute_masked_pred_loss(self, input, target, mask):
       
        if self.hparams['matched']:
            return self.pred_criterion(input, target).mean()
        else:
            return (self.pred_criterion(input, target).mean(dim=1) * mask).sum() / max(mask.sum(), 1e-6)

    def _losslandscpe_loss(self, model_output, y_gt, pairs):
        loss_multi = self._compute_masked_pred_loss(model_output['predictions'], y_gt, pairs)
        return loss_multi

    def _compute_prediction_losses(self, model_output, y_gt, pairs, log=True, mode='train',where="None input", name = "defaul", step=0):
        ehr_mask = torch.ones_like(model_output['predictions'][:, 0]) # [batch] all be ones
        loss_pred_multi = self._compute_masked_pred_loss(model_output['predictions'], y_gt, ehr_mask)
        loss_pred_ehr = self._compute_masked_pred_loss(model_output['pred_ehr'], y_gt, ehr_mask)
        loss_pred_cxr = self._compute_masked_pred_loss(model_output['pred_cxr'], y_gt, pairs) # cxr maybe lack
        
        if self.hparams['uniloss']:
            loss_pred_final = loss_pred_multi + loss_pred_ehr + loss_pred_cxr
        else:
           
            loss_pred_final = loss_pred_multi


        if log:
            self.log_dict({
                f'{mode}_loss/pred_final': loss_pred_multi.detach(),
                f'{mode}_loss/pred_ehr': loss_pred_ehr.detach(),
                f'{mode}_loss/pred_cxr': loss_pred_cxr.detach(),
                'step': float(self.current_epoch)
            }, on_epoch=True, on_step=False, batch_size=y_gt.shape[0])
        return loss_pred_final


    def _compute_and_log_loss(self, model_output, y_gt, pairs, log=True, mode='train',where="None Input", name = "defaul", step=0):
        prediction_losses = self._compute_prediction_losses(model_output, y_gt, pairs, log, mode,where, name, step)
        return prediction_losses


    def forward(self, data_dict):
        x = data_dict['ehr_ts']
        img = data_dict['cxr_imgs'] 


        seq_lengths = data_dict['seq_len']
        pairs = data_dict['has_cxr'] 

        feat_ehr_distinct,_ = self.ehr_model(x, seq_lengths)
        feat_cxr_distinct = self.cxr_model_spec(img)




        _, _, pred_final = self.fusion_module(feat_ehr_distinct, feat_cxr_distinct)
        weight_size = self.fusion_module.fc_out.weight.size(1)
        pred_cxr_inner = (torch.mm(feat_cxr_distinct, torch.transpose(self.fusion_module.fc_out.weight[:, weight_size // 2 :], 0, 1)) +
                self.fusion_module.fc_out.bias / 2)
        pred_ehr_inner = (torch.mm(feat_ehr_distinct, torch.transpose(self.fusion_module.fc_out.weight[:, :weight_size // 2], 0, 1)) +
                self.fusion_module.fc_out.bias / 2)
        if self.hparams['uniloss']:
            feat_ehr_distinct_pred = self.ehr_model_linear(feat_ehr_distinct)
            feat_cxr_distinct_pred = self.cxr_model_linear(feat_cxr_distinct)
            pred_ehr = feat_ehr_distinct_pred.sigmoid()
            pred_cxr = feat_cxr_distinct_pred.sigmoid()
        else:
            pred_ehr = pred_ehr_inner.sigmoid()
            pred_cxr = pred_cxr_inner.sigmoid()
    

        pred_final = pred_final.sigmoid()


   
        outputs = {
        'feat_ehr_distinct': feat_ehr_distinct,
        'feat_cxr_distinct': feat_cxr_distinct,
        'predictions':  pred_final,
        'pred_ehr':pred_ehr,
        'pred_cxr': pred_cxr,
        }

        return outputs
    
    
    def on_train_epoch_end(self):
        if self.hparams['save'] and len(self.train_ehr_features) > 0:
            ehr_features = np.vstack(self.train_ehr_features)
            cxr_features = np.vstack(self.train_cxr_features)
            labels = np.vstack(self.train_labels)
            save_path = os.path.join(self.feature_save_dir, 
                                f"train_features_epoch_{self.current_epoch}.npz")
            
            np.savez(save_path, 
                ehr_features=ehr_features, 
                cxr_features=cxr_features, 
                labels=labels, 
                hidden_size=self.hparams['hidden_size'],
                epoch=self.current_epoch)
            print(f"Save the features in epoch {self.current_epoch}")
            
            self.train_ehr_features = []
            self.train_cxr_features = []
            self.train_labels = []
        
        if self.hparams['save_trace_epoch']:
            hessian_comp = Hessian(
                model = self,
                data = None,
                dataloader = self.save_batch,
                cuda = True
            )
            trace_ehr, trace_cxr = hessian_comp.trace_modal(maxIter=100)
            self.log_dict({
                'trace/ehr': np.mean(trace_ehr),
                'trace/cxr': np.mean(trace_cxr),
                'epoch_num': float(self.current_epoch)
            }, on_epoch=True, on_step=False)
            self.save_batch = {}
        if self.hparams['save_trace']:
            trace_ehr = np.mean(self.trace_ehr)
            trace_cxr = np.mean(self.trace_cxr)
            self.trace_ehr = []
            self.trace_cxr = []
            self.log_dict({
                'trace/ehr': trace_ehr,
                'trace/cxr': trace_cxr,
                'epoch_num': float(self.current_epoch)
            }, on_epoch=True, on_step=False)
        if self.hparams['save_eigenvalue']:
            eigenvalue_ehr = np.mean(self.eigenvalue_ehr)
            eigenvalue_cxr = np.mean(self.eigenvalue_cxr)
            self.eigenvalue_ehr = []
            self.eigenvalue_cxr = []
            self.log_dict({
                'eigenvalue/ehr': eigenvalue_ehr,
                'eigenvalue/cxr': eigenvalue_cxr,
                'epoch_num': float(self.current_epoch)
            }, on_epoch=True, on_step=False)
                

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Periodically compute Hessian diagnostics on the batch just trained.

        When ``save_trace`` or ``save_eigenvalue`` is enabled and
        ``self._step_counter`` is a multiple of 20, a deep copy of the module
        in eval mode is handed to ``Hessian`` together with ``batch``. The
        results are written through ``self.csv_logger`` and appended to the
        per-epoch buffers. Model and optimizer state are snapshotted before
        and restored after.

        Otherwise, when ``save_trace_epoch`` is enabled, the batch is appended
        to ``self.save_batch`` for the per-epoch trace.

        Args:
            outputs: Unused.
            batch: The batch just processed.
            batch_idx: Unused.
        """
       
        if (self.hparams['save_trace'] or self.hparams['save_eigenvalue']) and self._step_counter % 20 == 0:
            # Snapshot parameters/buffers and optimizer state so they can be
            # restored once the diagnostics are done.
            model_state = {k: v.clone() for k, v in self.state_dict().items()}
            optimizer_state = deepcopy(self.optimizers().state_dict())

           
            # The diagnostics run on a deep copy switched to eval mode.
            model_copy = deepcopy(self)
            model_copy.load_state_dict(model_state)
            model_copy.eval()
           
            hessian_comp = Hessian(
                model = model_copy,
                data = batch,
                dataloader = None,
                cuda = True
            )
            if self.hparams['save_trace']:
                trace_ehr, trace_cxr = hessian_comp.trace_modal(maxIter=200) # original note: "returns the estimate from 100 samples"; the call passes maxIter=200 — TODO confirm
                self.csv_logger.save_values(self._step_counter, trace_ehr, trace_cxr)
                self.trace_ehr.append(trace_ehr)
                self.trace_cxr.append(trace_cxr)
                print(f"trace_ehr is {trace_ehr}, trace_cxr is {trace_cxr}")
            
            if self.hparams['save_eigenvalue']:
                
                eigenvalue_ehr = hessian_comp.eigenvalues_uni(model_name='ehr', maxIter=200, top_n=10)
                eigenvalue_cxr = hessian_comp.eigenvalues_uni(model_name='cxr', maxIter=200, top_n=10)
                # NOTE(review): __init__ creates a single csv_logger, with the
                # trace headers when both flags are set — confirm the columns
                # still make sense for these rows.
                self.csv_logger.save_values(self._step_counter, eigenvalue_ehr, eigenvalue_cxr)
                
                self.eigenvalue_ehr.append(eigenvalue_ehr)
                self.eigenvalue_cxr.append(eigenvalue_cxr)
            
          
            # Restore the snapshot, clear gradients and return to train mode.
            self.load_state_dict(model_state)
            self.optimizers().load_state_dict(optimizer_state)
            self.zero_grad()
            self.train()
           
            
        elif self.hparams['save_trace_epoch']:
            # Collect the batch for the per-epoch trace in on_train_epoch_end.
            self.save_batch.append(batch)


    def on_validation_epoch_end(self):

        if self.hparams['save'] and len(self.valid_ehr_features) > 0:
            ehr_features = np.vstack(self.valid_ehr_features)
            cxr_features = np.vstack(self.valid_cxr_features)
            labels = np.vstack(self.valid_labels)
            save_path = os.path.join(self.feature_save_dir, 
                                f"val_features_epoch_{self.current_epoch}.npz")
            np.savez(save_path, 
                ehr_features=ehr_features, 
                cxr_features=cxr_features, 
                labels=labels, 
                hidden_size=self.hparams['hidden_size'],
                epoch=self.current_epoch)
            self.valid_ehr_features = []
            self.valid_cxr_features = []
            self.valid_labels = []    


        scores_ehr,scores_cxr = self._get_ehr_cxr_scores(self.val_info,clear_cache=False)
        scores = self._val_test_epoch_end(self.val_info,clear_cache=True)
        scores_ehr_prefixed = {f"ehr_{k}": v for k, v in scores_ehr.items()}
        scores_cxr_prefixed = {f"cxr_{k}": v for k, v in scores_cxr.items()}
        combined_scores = {**scores, **scores_ehr_prefixed, **scores_cxr_prefixed}
        combined_scores['step'] = float(self.current_epoch)
        self.log_dict({k: v for k, v in combined_scores.items() if not isinstance(v, list)}, on_epoch=True, on_step=False)
        
        return scores

    def on_test_epoch_end(self):
        
        if self.hparams['save'] and len(self.test_ehr_features) > 0:
            ehr_features = np.vstack(self.test_ehr_features)
            cxr_features = np.vstack(self.test_cxr_features)
            labels = np.vstack(self.test_labels)
            save_path = os.path.join(self.feature_save_dir, 
                                f"test_features_epoch_{self.current_epoch}.npz")
            np.savez(save_path, 
                ehr_features=ehr_features, 
                cxr_features=cxr_features, 
                labels=labels, 
                hidden_size=self.hparams['hidden_size'],
                epoch=self.current_epoch)
        
        scores = self._val_test_epoch_end(self.test_info,clear_cache=False)
        scores_ehr,scores_cxr = self._get_ehr_cxr_scores(self.test_info,clear_cache=True)
        scores_ehr_prefixed = {f"ehr_{k}": v for k, v in scores_ehr.items()}
        scores_cxr_prefixed = {f"cxr_{k}": v for k, v in scores_cxr.items()}
        combined_scores = {**scores, **scores_ehr_prefixed, **scores_cxr_prefixed}
        self.test_results = {x: combined_scores[x] for x in combined_scores}

    def _val_test_shared_step(self, batch, cache):
        out = self._shared_step(batch)
        cache['predictions'].append(out['predictions'].detach())
        cache['pred_ehr'].append(out['pred_ehr'].detach())
        cache['pred_cxr'].append(out['pred_cxr'].detach())
        cache['labels'].append(batch['labels'].detach())
        return out

    def configure_optimizers(self):
        if self.hparams['sam_decomp']:
            print("into sam-decomp")
            base_optimizer = torch.optim.AdamW
            model_params = list(self.parameters())
            ehr_params = list(self.ehr_model.parameters())
            cxr_params = list(self.cxr_model_spec.parameters())
            other_params = [p for p in model_params if p not in set(ehr_params) and p not in set(cxr_params)]

            # set sam modality
            param_groups = [
                {"params": ehr_params, "name": "ehr", "adaptive": False, "rho": self.hparams.get('rho', 0.1), "apply_sam": False, "sagm_alpha": self.hparams.get('sagm_alpha', 0.5)},
                {"params": cxr_params, "name": "cxr", "adaptive": False, "rho": self.hparams.get('rho', 0.1), "apply_sam": False, "sagm_alpha": self.hparams.get('sagm_alpha', 0.5)},
                {"params": other_params, "name": "other", "adaptive": False, "rho": self.hparams.get('rho', 0.1), "apply_sam": False, "sagm_alpha": self.hparams.get('sagm_alpha', 0.5)},
            ]

           
            self.SAM_decomp_optimizer = SAMDecompClosure(
                params=param_groups,
                base_optimizer=base_optimizer,
                model=self,
                lr=self.hparams["lr"],
                weight_decay=self.hparams.weight_decay,
                alpha=self.hparams.get('scale_alpha', 0.01),
                pareto=False,
                dynamic=self.hparams.dynamic_mode
            )
            self.SAM_decomp_optimizer.set_perturb_mode('all')
           
            return self.SAM_decomp_optimizer            
      
        else:
            optimizer = torch.optim.AdamW(self.parameters(), lr=0.0001, weight_decay=0.0)
            return optimizer
