import os
import torch
from flowdock.dataset.pdbbind import (
    PDBBind, SameComplexPDBBind, PDBBindWithSortedBatching)
from flowdock.dataset.pdbbind_scoring import (
    PDBBindForRanking, PDBBindForScoring, PDBBindForScoringInference)


# Per-split settings: (conf attribute listing the dataset types for the split,
# add_all_atom_pos, min_lig_size).
_SPLIT_SETTINGS = {
    'train': ('train_dataset_types', False, 7),
    'val': ('val_dataset_types', False, 7),
    'test': ('test_dataset_types', True, 0),
}

# Dataset types whose conformer data is read from `<conf.data_dir_conf>/<dataset_type>100`.
_CONF_DATASET_TYPES = ('pdbbind_conf', 'dockgen_full_conf', 'posebusters_conf', 'astex_conf')


def _resolve_dataset_paths(conf, dataset_type):
    """Look up the configured paths for a single dataset type.

    Optional paths are resolved from scratch for every call, so one dataset
    type never inherits the chain mapping / inverse-crop ids of another.

    Args:
        conf: Config object supporting attribute access and ``in`` membership
            tests (e.g. an OmegaConf ``DictConfig``).
        dataset_type: Name of the dataset type (``pdbbind``, ``dockgen``, ...).

    Returns:
        Tuple ``(data_dir, esm_emb_path, sequences_path, split_paths,
        chain_mapping_path, inverse_crop_ids_path)``. ``split_paths`` maps each
        split name this dataset type supports to its split file; the last two
        entries are ``None`` when absent from ``conf``.

    Raises:
        ValueError: If ``dataset_type`` is not recognised.
    """
    chain_mapping_path = None
    inverse_crop_ids_path = None
    if dataset_type in ('pdbbind', 'pdbbind_conf'):
        data_dir = conf.pdbbind_data_dir
        esm_emb_path = conf.pdbbind_esm_embeddings_path
        sequences_path = conf.pdbbind_sequences_path
        split_paths = {
            'train': conf.pdbbind_split_train,
            'val': conf.pdbbind_split_val,
            'test': conf.pdbbind_split_test,
        }
        if 'chain_mapping_pdbbind' in conf:
            chain_mapping_path = conf.chain_mapping_pdbbind
        if 'pdbbind_inverse_crop_ids' in conf:
            inverse_crop_ids_path = conf.pdbbind_inverse_crop_ids
    elif dataset_type in ('dockgen', 'dockgen_full', 'dockgen_full_conf'):
        data_dir = conf.dockgen_data_dir
        esm_emb_path = conf.dockgen_esm_embeddings_path
        sequences_path = conf.dockgen_sequences_path
        val_split_path = conf.dockgen_split_val
        # The "full" variants evaluate on the full DockGen test split.
        if dataset_type in ('dockgen_full', 'dockgen_full_conf'):
            test_split_path = conf.dockgen_split_test_full
        else:
            test_split_path = conf.dockgen_split_test
        split_paths = {'val': val_split_path, 'test': test_split_path}
        if 'chain_mapping_dockgen' in conf:
            chain_mapping_path = conf.chain_mapping_dockgen
    elif dataset_type == 'moad':
        data_dir = conf.moad_data_dir
        esm_emb_path = conf.moad_esm_embeddings_path
        sequences_path = conf.moad_sequences_path
        split_paths = {'train': conf.moad_split_train}
        if 'chain_mapping_moad' in conf:
            chain_mapping_path = conf.chain_mapping_moad
        if 'moad_inverse_crop_ids' in conf:
            inverse_crop_ids_path = conf.moad_inverse_crop_ids
    elif dataset_type in ('posebusters', 'posebusters_conf'):
        data_dir = conf.posebusters_data_dir
        esm_emb_path = conf.posebusters_esm_embeddings_path
        sequences_path = conf.posebusters_sequences_path
        split_paths = {'test': conf.posebusters_split_test}
        if 'chain_mapping_posebusters' in conf:
            chain_mapping_path = conf.chain_mapping_posebusters
    elif dataset_type in ('astex', 'astex_conf'):
        data_dir = conf.astex_data_dir
        esm_emb_path = conf.astex_esm_embeddings_path
        sequences_path = conf.astex_sequences_path
        split_paths = {'test': conf.astex_split_test}
        if 'chain_mapping_astex' in conf:
            chain_mapping_path = conf.chain_mapping_astex
    else:
        raise ValueError(f"Unknown dataset type: {dataset_type}")
    return (data_dir, esm_emb_path, sequences_path, split_paths,
            chain_mapping_path, inverse_crop_ids_path)


def get_datasets(conf, splits, return_separately=False, is_ranking_dataset=False, 
                 is_scoring_dataset=False,
                 complex_collate_fn=None, predicted_complex_positions_path=None,
                 predicted_ligand_transforms_path=None, use_predicted_tr_only=True,
                 is_train_dataset=True, 
                 n_preds_to_use=1, use_all_chains=None, stage_num=None):
    """Build the datasets for each requested split.

    For every split, one dataset is built per entry of
    ``conf.<split>_dataset_types``. The dataset class is chosen as follows:
    ``PDBBindForScoringInference`` when ``predicted_complex_positions_path`` is
    given, otherwise ``PDBBindForScoring`` if ``is_scoring_dataset``, otherwise
    ``PDBBindForRanking`` if ``is_ranking_dataset``, otherwise ``PDBBind``.
    Each dataset is optionally wrapped in ``SameComplexPDBBind`` or
    ``PDBBindWithSortedBatching`` depending on ``conf``. Dataset sizes are
    printed to stdout.

    Args:
        conf: Config object supporting attribute access, ``.get`` and ``in``.
        splits: Iterable of split names; each must be 'train', 'val' or 'test'.
        return_separately: If True, each split maps to a dict
            ``{dataset_type: dataset}``; otherwise to a single dataset
            (a ``torch.utils.data.ConcatDataset`` when several types are used).
        is_ranking_dataset: Build ``PDBBindForRanking`` datasets.
        is_scoring_dataset: Build ``PDBBindForScoring`` datasets (takes
            precedence over ``is_ranking_dataset``).
        complex_collate_fn: Collator forwarded to the ranking dataset and to
            the batching wrappers.
        predicted_complex_positions_path: If set, build scoring-inference
            datasets over these predicted positions.
        predicted_ligand_transforms_path: Forwarded to ``PDBBind``.
        use_predicted_tr_only: Forwarded to ``PDBBind``.
        is_train_dataset: Forwarded to ``PDBBind``.
        n_preds_to_use: Forwarded to ``PDBBind`` and the scoring-inference dataset.
        use_all_chains: Defaults to ``conf.get('use_all_chains', False)`` when None.
        stage_num: Forwarded to every dataset class.

    Returns:
        Dict mapping each split name to its dataset (or dict of datasets).

    Raises:
        ValueError: For an unknown split or dataset type, a dataset type that
            has no split file for the requested split, or a split with no
            dataset types when ``return_separately`` is False.
    """
    all_datasets = {}
    sample_same_complexes_in_batch = conf.get('sample_same_complexes_in_batch', False)
    use_sorted_batching = conf.get('use_sorted_batching', False)
    randomize_bond_neighbors = conf.get('randomize_bond_neighbors', True)
    # Both batching wrappers keep the underlying dataset in `.dataset`.
    is_wrapped = sample_same_complexes_in_batch or use_sorted_batching

    if use_all_chains is None:
        use_all_chains = conf.get('use_all_chains', False)

    for split in splits:
        if split not in _SPLIT_SETTINGS:
            raise ValueError(
                f"Unknown split: {split!r} (expected one of {sorted(_SPLIT_SETTINGS)})")
        types_attr, add_all_atom_pos, min_lig_size = _SPLIT_SETTINGS[split]
        dataset_list = getattr(conf, types_attr)

        split_datasets = []
        for dataset_type in dataset_list:
            (data_dir, esm_emb_path, sequences_path, split_paths,
             chain_mapping_path, inverse_crop_ids_path) = _resolve_dataset_paths(conf, dataset_type)

            if dataset_type in _CONF_DATASET_TYPES:
                data_dir_conf = os.path.join(conf.data_dir_conf, f'{dataset_type}100')
            else:
                data_dir_conf = None

            # Fail loudly instead of reusing another dataset type's split file.
            if split not in split_paths:
                raise ValueError(
                    f"Dataset type {dataset_type!r} has no {split!r} split configured")
            split_path = split_paths[split]

            # Keyword arguments shared by every dataset class constructed below.
            common_kwargs = dict(
                data_dir=data_dir,
                split_path=split_path,
                esm_embeddings_path=esm_emb_path,
                sequences_path=sequences_path,
                max_lig_size=conf.max_lig_size,
                no_cache=conf.no_cache,
                cache_path=conf.cache_path,
                num_dataset_workers=1,
                std_protein_pos=conf.std_protein_pos,
                std_lig_pos=conf.std_lig_pos,
                ligand_mask_ratio=conf.ligand_mask_ratio,
                protein_mask_ratio=conf.protein_mask_ratio,
                dataset_type=dataset_type,
                add_all_atom_pos=add_all_atom_pos,
                min_lig_size=min_lig_size,
                use_all_chains=use_all_chains,
                stage_num=stage_num,
            )

            if predicted_complex_positions_path is not None:
                # scoring inference
                split_dataset = PDBBindForScoringInference(
                    **common_kwargs,
                    tr_std=1.,
                    limit_complexes=None,
                    predicted_complex_positions_path=predicted_complex_positions_path,
                    n_preds_to_use=n_preds_to_use,
                )
            elif is_scoring_dataset:
                split_dataset = PDBBindForScoring(
                    **common_kwargs,
                    tr_std=-1.,
                    limit_complexes=conf.limit_complexes_valid,
                    augm_ligand_transforms=False,
                )
            elif is_ranking_dataset:
                split_dataset = PDBBindForRanking(
                    **common_kwargs,
                    tr_std=-1.,
                    limit_complexes=conf.limit_complexes_valid,
                    batch_size=conf.train_batch_size,
                    data_collator=complex_collate_fn,
                    augm_ligand_transforms=False,
                )
            else:
                split_dataset = PDBBind(
                    **common_kwargs,
                    tr_std=conf.tr_std,
                    limit_complexes=conf.limit_complexes_train,
                    num_new_conformations=conf.num_new_conformations,
                    augm_ligand_transforms=conf.augm_ligand_transforms,
                    predicted_ligand_transforms_path=predicted_ligand_transforms_path,
                    chain_mapping_path=chain_mapping_path,
                    inverse_crop_ids_path=inverse_crop_ids_path,
                    use_predicted_tr_only=use_predicted_tr_only,
                    randomize_bond_neighbors=randomize_bond_neighbors,
                    data_dir_conf=data_dir_conf,
                    is_train_dataset=is_train_dataset,
                    n_preds_to_use=n_preds_to_use,
                )

            if sample_same_complexes_in_batch:
                split_dataset = SameComplexPDBBind(dataset=split_dataset, batch_size=conf.train_batch_size,
                                                   data_collator=complex_collate_fn)
            elif use_sorted_batching:
                split_dataset = PDBBindWithSortedBatching(dataset=split_dataset, batch_limit=conf.batch_limit,
                                                          data_collator=complex_collate_fn)
            split_datasets.append(split_dataset)

        for dataset in split_datasets:
            if sample_same_complexes_in_batch:
                print(split, dataset.dataset.dataset_type, len(dataset), 'same complexes in batch')
            elif use_sorted_batching:
                print(split, dataset.dataset.dataset_type, len(dataset), len(dataset.dataset.complexes), 'sorted batching')
            else:
                print(split, dataset.dataset_type, len(dataset))

        if return_separately:
            all_datasets[split] = {
                (dataset.dataset if is_wrapped else dataset).dataset_type: dataset
                for dataset in split_datasets
            }
        elif not split_datasets:
            raise ValueError(f"No dataset types configured for split {split!r}")
        elif len(split_datasets) > 1:
            combined_dataset = torch.utils.data.ConcatDataset(split_datasets)
            print(split, len(combined_dataset))
            all_datasets[split] = combined_dataset
        else:
            all_datasets[split] = split_datasets[0]

    return all_datasets
    