import os
import pickle
import yaml
from copy import deepcopy
import torch.nn.init as init
from datetime import datetime
from argparse import Namespace
import numpy as np
from torch import nn
import random
import torch
import lightning as L
import argparse
from models import get_model, get_model_cls
from utils.ver_name import get_version_name

from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from torch.utils.data import DataLoader
from arguments import get_args
from lightning.pytorch.utilities.rank_zero import rank_zero_only
from datasets.dataset import create_data_loaders
# from datasets.dataset_old_version import create_data_loaders


def get_device_map_location(target_device=0):
    """Return the ``map_location`` string to use when loading a checkpoint.

    Lets a checkpoint saved on one GPU be loaded on a different device.

    Args:
        target_device: Index of the CUDA device the checkpoint should be
            mapped onto. Defaults to 0.

    Returns:
        ``'cpu'`` when CUDA is unavailable; otherwise ``'cuda:<index>'``,
        where the index is ``target_device`` if it is below the number of
        visible CUDA devices and ``0`` if it is not.
    """
    if not torch.cuda.is_available():
        return 'cpu'

    # An out-of-range index falls back to the first GPU.
    device_index = target_device if target_device < torch.cuda.device_count() else 0
    return f'cuda:{device_index}'


# hparams keys whose current values always replace the values stored in a
# checkpoint when evaluating in test mode.
_FORCE_UPDATE_KEYS = (
    'compute_fairness', 'fairness_attributes', 'fairness_age_bins',
    'fairness_intersectional', 'save_predictions', 'use_demographics',
)


def _merge_hparams(loaded_model, hparams, force_keys=(), verbose=False):
    """Copy entries of ``hparams`` onto ``loaded_model.hparams``.

    Keys listed in ``force_keys`` are always overwritten; every other key is
    only added when the loaded model does not have it yet. With ``verbose``
    each individual change is printed.
    """
    for key, value in hparams.items():
        if key in force_keys:
            setattr(loaded_model.hparams, key, value)
            if verbose:
                print(f"Force updated {key} = {value}")
        elif not hasattr(loaded_model.hparams, key):
            setattr(loaded_model.hparams, key, value)
            if verbose:
                print(f"Added new parameter {key} = {value}")
    print(f"Updated model hparams with save_predictions={getattr(loaded_model.hparams, 'save_predictions', False)}")


def _load_first_available(model_name, candidates, hparams, map_location,
                          fallback_model, force_keys=(), verbose=False):
    """Load the first candidate checkpoint that exists on disk.

    ``candidates`` is an ordered iterable of ``(description, path)`` pairs.
    When no candidate path exists, ``fallback_model`` is returned unchanged.
    """
    for description, path in candidates:
        if path and os.path.exists(path):
            print(f"Loading {description}: {path}")
            loaded = get_model_cls(model_name).load_from_checkpoint(
                path, map_location=map_location
            )
            _merge_hparams(loaded, hparams, force_keys=force_keys, verbose=verbose)
            return loaded

    print("Using current trained model for testing")
    return fallback_model


def _report_saved_checkpoints(model, checkpoint_callback):
    """Print and return the checkpoint candidates produced by this run.

    The path the model recorded on itself (``model.best_model_path``, if any)
    takes priority over the path tracked by the ``ModelCheckpoint`` callback.
    """
    best_model_path = checkpoint_callback.best_model_path
    print(f"ModelCheckpoint best_model_path: {best_model_path}")

    manual_best_path = getattr(model, 'best_model_path', None)
    print(f"Manual best_model_path: {manual_best_path}")

    return [
        ("manually saved best model", manual_best_path),
        ("ModelCheckpoint best model", best_model_path),
    ]


def _make_test_trainer(device, loggers):
    """Build a single-GPU Trainer (no checkpointing) used only for testing."""
    return L.Trainer(
        accelerator='gpu',
        devices=[device],
        logger=loggers,
        enable_checkpointing=False,
    )


def run_model(args):
    """Train and/or test the model selected by ``args``.

    Args:
        args: An ``argparse.Namespace`` (or a plain dict, which is converted
            to one) holding the run configuration. ``args.mode`` selects the
            behaviour: ``'train'`` fits the model and then tests it,
            ``'test'`` only tests. Any other value builds the data loaders,
            model and trainer and then returns without running anything.

    Raises:
        ValueError: If ``args.task`` is not ``'mortality'``, ``'phenotype'``
            or ``'los'``.

    Side effects:
        Creates the log directory, writes TensorBoard/CSV logs, and (on global
        rank 0, when ``args.dev_run`` is falsy) writes the test results via
        ``save_test_results``. Entries added to the hyper-parameter dict are
        also set on ``args`` because both share the same underlying dict.
    """
    if isinstance(args, dict):
        args = Namespace(**args)

    # vars() returns the namespace's own __dict__, so every key written to
    # ``hparams`` below also becomes an attribute of ``args``.
    hparams = vars(args)

    # Seed every random number generator in use
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # For multiple GPUs
    # Set number of threads allowed
    torch.set_num_threads(5)
    L.seed_everything(seed, workers=True)

    # Ensure that all operations are deterministic on GPU (if used) for reproducibility
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Print data configuration
    print("Data subset configuration:")
    print(f"  Train: {'matched' if args.train_matched else 'full'} data")
    print(f"  Validation: {'matched' if args.val_matched else 'full'} data")
    print(f"  Test: {'matched' if args.test_matched else 'full'} data")

    # Data loader with flexible matched/full configuration and demographics
    train_loader, val_loader, test_loader = create_data_loaders(
        args.ehr_root, args.task,
        args.fold, args.batch_size, args.num_workers,
        matched_subset=getattr(args, 'matched', False),
        train_matched=args.train_matched,
        val_matched=args.val_matched,
        test_matched=args.test_matched,
        use_triplet=args.use_triplet,
        seed=seed,
        resized_base_path=args.resized_cxr_root,
        image_meta_path=args.image_meta_path,
        pkl_dir=args.pkl_dir,
        use_demographics=args.use_demographics,
        demographic_cols=args.demographic_cols,
        use_label_weights=getattr(args, 'use_label_weights', True),
        label_weight_method=getattr(args, 'label_weight_method', 'balanced'),
        custom_label_weights=getattr(args, 'custom_label_weights', None),
        cxr_dropout_rate=getattr(args, 'cxr_dropout_rate', 0.0),
        cxr_dropout_seed=getattr(args, 'cxr_dropout_seed', None),
        demographics_in_model_input=getattr(args, 'demographics_in_model_input', False)
    )

    # Auto-detect and update input_dim from actual data
    train_dataset = train_loader.dataset
    if hasattr(train_dataset, 'input_dim'):
        actual_input_dim = train_dataset.input_dim
        print(f"Auto-detected input dimension: {actual_input_dim}")

        # Update args and hparams with actual input dimension
        args.input_dim = actual_input_dim
        hparams['input_dim'] = actual_input_dim

        # Print dimension breakdown for debugging
        if hasattr(train_dataset, 'base_ehr_dim') and hasattr(train_dataset, 'demo_feature_dim'):
            print(f"  Base EHR dimension: {train_dataset.base_ehr_dim}")
            print(f"  Demographic dimension: {train_dataset.demo_feature_dim}")
    else:
        print(f"Using configured input dimension: {args.input_dim}")

    class_names = train_loader.dataset.CLASSES
    hparams['class_names'] = class_names
    hparams['steps_per_epoch'] = len(train_loader)

    # Get label weights from training dataset if enabled.
    # NOTE(review): the default here is False while the loader above defaults
    # ``use_label_weights`` to True when the attribute is missing -- confirm
    # which default is intended.
    if getattr(args, 'use_label_weights', False) and hasattr(train_dataset, 'get_label_weights'):
        label_weights = train_dataset.get_label_weights()
        if label_weights is not None:
            # Stored as a plain Python list for TensorBoard compatibility
            if isinstance(label_weights, torch.Tensor):
                hparams['label_weights'] = label_weights.detach().cpu().numpy().tolist()
            else:
                hparams['label_weights'] = label_weights
            print(f"Label weights loaded from dataset: {hparams['label_weights']}")

            # For the mortality task a pos_weight is exposed as well
            if args.task == 'mortality':
                # pos_weight is the weight of the positive class (binary task):
                # the second entry when there is one, otherwise the only entry.
                if isinstance(label_weights, torch.Tensor):
                    pos_weight = label_weights[1].item() if len(label_weights) > 1 else label_weights[0].item()
                else:
                    pos_weight = label_weights[1] if len(label_weights) > 1 else label_weights[0]
                hparams['mortality_pos_weight'] = pos_weight
                print(f"Mortality pos_weight: {pos_weight}")

    # Register model
    model = get_model(args.model, hparams)

    # Choose callback metric based on task type
    if args.task == 'mortality':
        callback_metric = 'overall/PRAUC'  # Binary classification
        filename_template = '{epoch:02d}-{overall/PRAUC:.2f}'
    elif args.task == 'phenotype':
        callback_metric = 'overall/PRAUC'  # Multi-label classification
        filename_template = '{epoch:02d}-{overall/PRAUC:.2f}'
    elif args.task == 'los':
        callback_metric = 'overall/ACC'  # Multi-class classification
        filename_template = '{epoch:02d}-{overall/ACC:.2f}'
    else:
        raise ValueError(f"Unknown task: {args.task}")

    # Early stop training
    early_stop_callback = EarlyStopping(monitor=callback_metric,
                                        min_delta=0.00,
                                        patience=args.patience,
                                        verbose=False,
                                        mode="max")

    # Keep only the checkpoint with the best validation metric
    checkpoint_callback = ModelCheckpoint(
        monitor=callback_metric,
        mode='max',
        save_top_k=1,
        verbose=True,
        filename=filename_template
    )

    # Get experiment name (model parameters can be stored in yaml)
    log_dir, ver_name = get_version_name(args)
    # exist_ok avoids a check-then-create race when several processes start together
    os.makedirs(log_dir, exist_ok=True)

    print(f"in the ver_name {ver_name}")

    tb_logger = pl_loggers.TensorBoardLogger(save_dir=log_dir, version=ver_name)
    csv_logger = pl_loggers.CSVLogger(save_dir=log_dir, version=ver_name)
    loggers = [tb_logger, csv_logger]

    # Accept a single GPU id as well as a list/tuple of ids
    gpu_ids = list(args.gpu) if isinstance(args.gpu, (list, tuple)) else [args.gpu]
    if len(gpu_ids) > 1:
        devices = gpu_ids
        strategy = 'ddp_find_unused_parameters_true'
    else:
        devices = [gpu_ids[0]]
        strategy = 'auto'

    # Get target device for checkpoint loading
    map_location = get_device_map_location(gpu_ids[0])

    # Lightning rejects a ModelCheckpoint callback when checkpointing is
    # disabled, so the callback is only registered when checkpoints are wanted.
    callbacks = [early_stop_callback]
    if args.save_checkpoint:
        callbacks.append(checkpoint_callback)

    # Lightning trainer
    trainer = L.Trainer(
        enable_checkpointing=args.save_checkpoint,
        accelerator='gpu',
        devices=devices,  # always a list of GPU ids
        strategy=strategy,  # DDP when more than one GPU id is given
        fast_dev_run=20 if args.dev_run else False,
        logger=loggers,
        num_sanity_val_steps=0,
        max_epochs=args.epochs,
        log_every_n_steps=1,
        min_epochs=4,
        callbacks=callbacks
    )

    # Train and test model
    if args.mode == 'train':
        trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

        # Test model
        print("Test model")
        print(f"best_model_path: {checkpoint_callback.best_model_path}")

        # Only test on main process (rank 0), and never during a dev run
        if args.dev_run or trainer.global_rank != 0:
            return

        # Create a single GPU Trainer for testing on main process
        test_trainer = _make_test_trainer(devices[0], loggers)

        # Missing hparams are filled in from the current run (for new
        # parameters like save_predictions); existing ones are kept.
        model_to_test = _load_first_available(
            args.model,
            _report_saved_checkpoints(model, checkpoint_callback),
            hparams, map_location, fallback_model=model,
        )

    elif args.mode == 'test':
        print("Test mode")

        # Only test on main process (rank 0), and never during a dev run
        if args.dev_run or trainer.global_rank != 0:
            return

        # Create a single GPU Trainer for testing on main process
        test_trainer = _make_test_trainer(devices[0], loggers)

        # An explicitly specified checkpoint path takes priority
        explicit_path = getattr(args, 'checkpoint_path', None)
        if explicit_path and os.path.exists(explicit_path):
            candidates = [("manually specified checkpoint", explicit_path)]
        else:
            if explicit_path:
                # Make the fallback visible instead of silently ignoring the path
                print(f"WARNING: checkpoint_path does not exist, ignoring it: {explicit_path}")
            candidates = _report_saved_checkpoints(model, checkpoint_callback)

        # Key test-time parameters are force-updated even when the checkpoint
        # already stores them.
        model_to_test = _load_first_available(
            args.model, candidates, hparams, map_location, fallback_model=model,
            force_keys=_FORCE_UPDATE_KEYS, verbose=True,
        )

    else:
        # Any other mode: nothing to run
        return

    # Run test
    test_trainer.test(model=model_to_test, dataloaders=test_loader)
    test_results = model_to_test.test_results
    save_test_results(csv_logger, test_results)


def save_test_results(csv_logger, test_results):
    """Write ``test_results`` as YAML into the CSV logger's log directory.

    The file is always named ``test_set_results.yaml``. This function does no
    rank check itself; callers are expected to invoke it on rank 0 only.

    Args:
        csv_logger: Logger whose ``log_dir`` attribute is the output directory.
        test_results: Object to serialise with ``yaml.dump``.
    """
    out_file = os.path.join(csv_logger.log_dir, 'test_set_results.yaml')

    with open(out_file, mode='w') as stream:
        yaml.dump(test_results, stream=stream)  # Note: test_results is in list format

    # Confirmation, the results themselves, then a final success line
    print(f"Results saved to {out_file}", test_results, "Save success!", sep="\n")


if __name__ == '__main__':
    # Script entry point: parse the run configuration and hand it to run_model
    run_model(get_args())
