import os
import pickle
import yaml
from copy import deepcopy
import torch.nn.init as init
from datetime import datetime
from argparse import Namespace
import numpy as np
from torch import nn
import random
import torch
import lightning as L


from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from torch.utils.data import DataLoader
from arguments import get_arg_parser

from datasets.dataset_new import create_data_loaders

from models import (
    SAM_method
)


def get_model_specific_params(args):
    """Build the hyper-parameter tag embedded in experiment version names.

    The tag always starts with batch size, learning rate, fusion method and
    hidden size. Models listed in the dispatch table below append their own
    extra hyper-parameters; any other model gets only the base tag.

    Args:
        args: Namespace-like object. It needs ``batch_size``, ``lr``,
            ``fusion_method``, ``hidden_size`` and ``model``. It additionally
            needs the model-specific fields only when ``args.model`` selects
            them (e.g. ``'masam'``).

    Returns:
        The tag string, e.g. ``"bs16-lr0.001-fusionconcat-hs256"``.
    """
    base_params = f"bs{args.batch_size}-lr{args.lr}-fusion{args.fusion_method}-hs{args.hidden_size}"

    # Builders are evaluated lazily so that args for other models need not
    # carry every model's extra fields (the previous eager dict raised
    # AttributeError for them).
    model_specific = {
        'masam': lambda: (
            f"{base_params}-eh{args.ehr_n_head}-el{args.ehr_n_layers}-ed{args.ehr_dropout}"
            f"-rho{args.rho}-wd{args.wd}-sw{args.score_weight}-m{args.momentum}"
        ),
    }

    builder = model_specific.get(args.model)
    return builder() if builder is not None else base_params

def get_log_info(args):
    """Return the logging directory and version name for this experiment.

    Args:
        args: Namespace-like object providing ``model``, ``task`` and
            ``seed``, plus whatever ``get_model_specific_params`` reads.

    Returns:
        ``(base_dir, ver_name)``. ``base_dir`` is
        ``/root/autodl-tmp/experiments/for<model>/<task>``. ``ver_name`` is
        ``<MODEL>-<param tag>-seed<seed>``.
    """
    experiment_dir = f"/root/autodl-tmp/experiments/for{args.model}/{args.task}"
    param_tag = get_model_specific_params(args)
    version = "-".join([args.model.upper(), param_tag, f"seed{args.seed}"])
    return experiment_dir, version



def load_config(config_path):
    """Parse the YAML file at *config_path* with ``yaml.safe_load`` and return the result."""
    with open(config_path, 'r') as stream:
        return yaml.safe_load(stream)

def load_model_config(model_name, args):
    """Load ``./configs/<model_name>.yaml`` and override its hparams from *args*.

    Only keys already present in the YAML ``hparams`` mapping are overridden.
    Other attributes of *args* are ignored.

    Args:
        model_name: Config file stem, resolved relative to the current
            working directory.
        args: Namespace whose attributes (via ``vars``) supply override values.

    Returns:
        The ``hparams`` dict with overrides applied.

    Raises:
        KeyError: if the YAML has no top-level ``hparams`` key.
    """
    with open(f'./configs/{model_name}.yaml', 'r') as stream:
        hparams = yaml.safe_load(stream)['hparams']
    # Printed before overriding, so this shows the YAML defaults.
    print(f"model_params: {hparams}")
    overrides = {name: val for name, val in vars(args).items() if name in hparams}
    hparams.update(overrides)
    return hparams

def _seed_everything(seed):
    """Seed all RNGs used by the pipeline and force deterministic cuDNN kernels.

    Seeds Python ``random``, NumPy, torch (CPU, and every CUDA device when
    available) and Lightning (including dataloader workers). It also caps
    torch's intra-op CPU threads at 5 and disables cuDNN benchmarking.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.set_num_threads(5)
    L.seed_everything(seed, workers=True)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _build_trainer(args, callback_metric):
    """Create the callbacks, loggers and Lightning Trainer for a run.

    Args:
        args: Run configuration. Uses ``patience``, ``gpu`` and ``dev_run``,
            plus whatever ``get_log_info`` reads.
        callback_metric: Logged metric monitored (maximised) by early
            stopping and checkpointing.

    Returns:
        ``(trainer, checkpoint_callback, csv_logger)``.
    """
    early_stop_callback = EarlyStopping(monitor=callback_metric,
                                        min_delta=0.00,
                                        patience=args.patience,
                                        verbose=False,
                                        mode="max")

    # NOTE(review): the metric name contains '/', which with Lightning's
    # default auto_insert_metric_name=True adds a sub-directory level to the
    # checkpoint filename — confirm that layout is intended.
    checkpoint_callback = ModelCheckpoint(
        monitor=callback_metric,
        mode='max',
        save_top_k=1,
        verbose=True,
        filename='{epoch:02d}-{overall/PRAUC:.2f}'
    )

    log_dir, ver_name = get_log_info(args)
    print(f"log_dir: {log_dir}")
    print(f"in the ver_name {ver_name}")
    tb_logger = pl_loggers.TensorBoardLogger(save_dir=log_dir, version=ver_name)
    csv_logger = pl_loggers.CSVLogger(save_dir=log_dir, version=ver_name)

    trainer = L.Trainer(enable_checkpointing=True,
                        accelerator='gpu',
                        devices=[args.gpu],
                        fast_dev_run=20 if args.dev_run else False,
                        logger=[tb_logger, csv_logger],
                        num_sanity_val_steps=0,
                        max_epochs=100,
                        log_every_n_steps=1,
                        min_epochs=4,
                        callbacks=[early_stop_callback, checkpoint_callback])
    return trainer, checkpoint_callback, csv_logger


def _dump_yaml(obj, path):
    """Write *obj* to *path* as YAML and return *path*."""
    with open(path, 'w') as f:
        yaml.dump(obj, f)
    return path


def _dump_tree_info(model, log_dir):
    """Pickle ``model.tree_info`` to ``<log_dir>/tree_info.pkl`` when it exists and is not None."""
    if getattr(model, 'tree_info', None) is not None:
        with open(os.path.join(log_dir, 'tree_info.pkl'), 'wb') as f:
            pickle.dump(model.tree_info, f)


def run_model(args):
    """Train and/or evaluate a ``SAM_method`` model as configured by *args*.

    Args:
        args: ``argparse.Namespace`` (or a dict convertible to one). ``mode``
            selects the behaviour:

            * ``'train'``: fit, reload the best checkpoint, then test it.
            * ``'test'``: load the checkpoint at ``args.ckpt_path`` and test
              it. ``ckpt_path`` is an assumed attribute — add it to the
              argument parser if it is not already there.

    Returns:
        The tested model's ``test_results`` (``None`` if the model has no such
        attribute), or ``None`` when ``mode`` is neither ``'train'`` nor
        ``'test'``.

    Raises:
        ValueError: if ``mode == 'test'`` and no checkpoint path is given.
    """
    if isinstance(args, dict):
        args = Namespace(**args)

    # Fail fast, before the expensive data loading, when test mode has
    # nothing to load. Previously this branch always hit an unbound
    # ``best_model_path``.
    if args.mode == 'test' and not getattr(args, 'ckpt_path', None):
        raise ValueError("mode='test' requires args.ckpt_path to point at a trained checkpoint")

    seed = args.seed
    _seed_everything(seed)
    print(f"Now matched_subset is {args.matched}")

    train_loader, val_loader, test_loader = create_data_loaders(
        args.ehr_root, args.cxr_root, args.task,
        args.fold, args.batch_size, args.num_workers,
        matched_subset=args.matched, index=args.index, seed=seed, one_hot=args.mortality2,
        resized_base_path=args.resized_cxr_root)

    print(f"train_loader: {len(train_loader)}")
    model_class = SAM_method
    train_data_num = len(train_loader.dataset)
    model_params = load_model_config(args.model, args)
    model_params.update({
        'new_data': args.new_data,
        'task': args.task,
        'seed': args.seed,
        'save': args.save,
        'class_names': train_loader.dataset.CLASSES,
        'train_data_num': train_data_num,
    })

    model = model_class(model_params)
    trainer, checkpoint_callback, csv_logger = _build_trainer(args, 'overall/PRAUC')

    if args.mode == 'train':
        trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)
        print("Test model")
        best_model_path = checkpoint_callback.best_model_path
        print(f"best_model_path: {best_model_path}")

        if best_model_path:
            best_model = model_class.load_from_checkpoint(best_model_path, strict=False)
        else:
            # No checkpoint was written (e.g. fast_dev_run disables
            # checkpointing), so fall back to the in-memory model.
            best_model = model

        _dump_tree_info(model, csv_logger.log_dir)

        results_path = None
        if not args.dev_run:
            trainer.test(model=best_model, dataloaders=test_loader)
            results_path = _dump_yaml(best_model.test_results,
                                      os.path.join(csv_logger.log_dir, 'test_set_results.yaml'))
        print(f"save in the {results_path}")
        print("save success!")
        test_results = getattr(best_model, 'test_results', None)
        print(test_results)
        return test_results

    elif args.mode == 'test':
        print("Test model by dynamic")

        best_model = model_class.load_from_checkpoint(args.ckpt_path, strict=False)
        best_model.hparams.update(model_params)
        best_model.eval()

        # NOTE(review): this inspects the freshly built ``model`` rather than
        # ``best_model`` — confirm which one should supply tree_info here.
        _dump_tree_info(model, csv_logger.log_dir)

        results_path = None
        if not args.dev_run:
            trainer.test(model=best_model, dataloaders=test_loader)
            if args.dynamic:
                dynamic_path = f"./logs/MLA_Files/dynamic_result/index-{args.index}"
                os.makedirs(dynamic_path, exist_ok=True)
                results_path = _dump_yaml(
                    best_model.test_results,
                    os.path.join(dynamic_path, f'index{args.index}-seed{args.seed}.yaml'))
            else:
                results_path = _dump_yaml(best_model.test_results,
                                          os.path.join(csv_logger.log_dir, 'test_set_results.yaml'))
        print(f"save in the {results_path}")
        return getattr(best_model, 'test_results', None)



if __name__ == '__main__':
    # Parse CLI arguments and run the configured train/test pipeline.
    cli_args = get_arg_parser().parse_args()
    test_results = run_model(cli_args)

