import torch

from typing import Tuple, Dict, Any
from functools import reduce
from tqdm import tqdm

from data.encode import num_atom_features, num_bond_features
from data.utils import set_seed, dict_cuda_copy, dict_list_cuda_copy
from data.structures import PackedMolGraph
from data.load_data import load_data, GeoMolDataset, SupportedDatasets
from data.load_multi_data import load_multi_data, MultiGeoMolDataset, SupportedMultiDatasets
from net.conformation_model import ConformationModel, MultiConformationModel
from train.config import QM7ConformationConfig, QM8ConformationConfig, QM9ConformationConfig, \
    GeomDRUGConformationConfig, GeomDRUGSmallConformationConfig, \
    GeomQM9ConformationConfig, GeomQM9SmallConformationConfig

# Directory holding saved model state dicts; checkpoints are read from
# f'{MODEL_DICT_DIR}/{dataset_name}/{token}.pkl'. This is a relative path, so it
# resolves against the current working directory of the running process.
MODEL_DICT_DIR = 'train/conf'


def recover_config(dataset_name: str, special_config: dict = None) -> Dict[str, Any]:
    """Return the conformation config for ``dataset_name``, optionally overridden.

    The base config is selected from the ``*ConformationConfig`` objects imported
    from ``train.config`` and shallow-copied before ``special_config`` is applied.
    The overrides therefore affect only the returned config. Previously they were
    written into the shared module-level config, so every later call (even one
    without overrides) silently inherited them.

    :param dataset_name: a ``SupportedDatasets`` / ``SupportedMultiDatasets`` name
        handled below.
    :param special_config: optional key/value overrides merged on top of the base config.
    :return: a new config object holding the (possibly overridden) settings.
    :raises ValueError: if ``dataset_name`` is not one of the supported datasets.
    """
    import copy

    if dataset_name == SupportedDatasets.QM7:
        base_config = QM7ConformationConfig
    elif dataset_name == SupportedDatasets.QM8:
        base_config = QM8ConformationConfig
    elif dataset_name == SupportedDatasets.QM9:
        base_config = QM9ConformationConfig
    elif dataset_name == SupportedMultiDatasets.GEOM_DRUGS_SMALL:
        base_config = GeomDRUGSmallConformationConfig
    elif dataset_name == SupportedMultiDatasets.GEOM_DRUGS:
        base_config = GeomDRUGConformationConfig
    elif dataset_name == SupportedMultiDatasets.GEOM_QM9_SMALL:
        base_config = GeomQM9SmallConformationConfig
    elif dataset_name == SupportedMultiDatasets.GEOM_QM9:
        base_config = GeomQM9ConformationConfig
    else:
        # An explicit raise survives ``python -O``, which strips ``assert``.
        # Under -O, an ``assert False`` here let an UnboundLocalError escape instead.
        raise ValueError(f'Unsupported dataset: {dataset_name}')

    # Copy so that applying overrides never mutates the shared module-level config.
    config = copy.copy(base_config)
    if special_config:
        config.update(special_config)
    return config


def recover_datasets(special_config: dict = None, dataset_name: str = SupportedDatasets.QM7,
                     dataset_token=None, seed: int = 0, force_save=False, use_cuda=False
                     ) -> Tuple[GeoMolDataset, GeoMolDataset, GeoMolDataset]:
    """Seed the RNGs and load the train / validation / test splits of ``dataset_name``.

    Packing sizes are read from the dataset's config (``N_MOL_PER_PACK`` and
    ``N_PACK_PER_BATCH``). ``use_phi_psi=True`` is always passed to ``load_data``.
    ``load_data`` also returns a fourth value (properties), which is discarded here.

    :param special_config: optional overrides passed to ``recover_config``.
    :param dataset_name: which dataset to load.
    :param dataset_token: forwarded unchanged to ``load_data``.
    :param seed: seed given to ``set_seed`` and forwarded to ``load_data``.
    :param force_save: forwarded unchanged to ``load_data``.
    :param use_cuda: forwarded unchanged to ``load_data``.
    :return: ``(train_set, validate_set, test_set)``.
    """
    set_seed(seed)
    cfg = recover_config(dataset_name, special_config)
    mols_per_pack = cfg['N_MOL_PER_PACK']
    packs_per_batch = cfg['N_PACK_PER_BATCH']

    print('Loading data...', end='\t')
    train_split, validate_split, test_split, _ = load_data(
        dataset_name=dataset_name,
        n_mol_per_pack=mols_per_pack,
        n_pack_per_batch=packs_per_batch,
        dataset_token=dataset_token,
        seed=seed,
        force_save=force_save,
        use_cuda=use_cuda,
        use_phi_psi=True,
    )
    print('Finished')

    return train_split, validate_split, test_split


def recover_multi_datasets(special_config: dict = None, dataset_name: str = SupportedMultiDatasets.GEOM_QM9_SMALL,
                           dataset_token=None, seed: int = 0, force_save=False, use_cuda=False
                           ) -> Tuple[GeoMolDataset, MultiGeoMolDataset, MultiGeoMolDataset]:
    """Seed the RNGs and load the train / validation / test splits of a multi-conformer dataset.

    Packing sizes are read from the dataset's config (``N_MOL_PER_PACK`` and
    ``N_PACK_PER_BATCH``). ``use_phi_psi=True`` is always passed to
    ``load_multi_data``.

    :param special_config: optional overrides passed to ``recover_config``.
    :param dataset_name: which dataset to load.
    :param dataset_token: forwarded unchanged to ``load_multi_data``.
    :param seed: seed given to ``set_seed`` and forwarded to ``load_multi_data``.
    :param force_save: forwarded unchanged to ``load_multi_data``.
    :param use_cuda: forwarded unchanged to ``load_multi_data``.
    :return: ``(train_set, validate_set, test_set)``.
    """
    set_seed(seed)
    cfg = recover_config(dataset_name, special_config)
    mols_per_pack = cfg['N_MOL_PER_PACK']
    packs_per_batch = cfg['N_PACK_PER_BATCH']

    print('Loading data...', end='\t')
    train_split, validate_split, test_split = load_multi_data(
        dataset_name=dataset_name,
        n_mol_per_pack=mols_per_pack,
        n_pack_per_batch=packs_per_batch,
        dataset_token=dataset_token,
        seed=seed,
        force_save=force_save,
        use_cuda=use_cuda,
        use_phi_psi=True,
    )
    print('Finished')

    return train_split, validate_split, test_split


def recover_conf(special_config: dict = None, dataset_name: str = SupportedDatasets.QM7,
                 token: str = 'default', use_cuda=False) -> ConformationModel:
    """Rebuild a trained ``ConformationModel`` from its saved state dict.

    Steps:

    - Print the resolved config.
    - Build the model.
    - Load weights from ``{MODEL_DICT_DIR}/{dataset_name}/{token}.pkl``. They are
      first mapped to CPU, then moved to the GPU when ``use_cuda`` is set.
    - Print every parameter's shape and the total parameter count.
    - Switch the model to eval mode.

    :param special_config: optional overrides passed to ``recover_config``.
    :param dataset_name: dataset whose config and checkpoint directory are used.
    :param token: checkpoint file stem inside the dataset's directory.
    :param use_cuda: build the model with CUDA enabled and move it to the GPU.
    :return: the loaded model, in eval mode.
    """
    # Resolve Config
    config = recover_config(dataset_name, special_config)
    print('Config:')
    for k, v in config.items():
        print(f'\t{k}: {v}')

    # Build Model
    print('Building Model...', end='\t')
    model = ConformationModel(
        atom_dim=num_atom_features(),
        bond_dim=num_bond_features(),
        config=config,
        dataset_name=dataset_name,
        use_cuda=use_cuda
    )
    # NOTE: torch.load unpickles the checkpoint file; only load checkpoints from a trusted source.
    state_dict = torch.load(f'{MODEL_DICT_DIR}/{dataset_name}/{token}.pkl', map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    if use_cuda:
        model.cuda()
    print('Finished')
    print('\tStructure:')
    n_param = 0
    for name, param in model.named_parameters():
        print(f'\t\t{name}: {param.shape}')
        # numel() also handles 0-dim (scalar) parameters. Reducing over their empty
        # shape without an initial value raised TypeError.
        n_param += param.numel()
    print(f'\t# Parameters: {n_param}')

    model.eval()
    return model


def recover_multi_conf(special_config: dict = None, dataset_name: str = SupportedMultiDatasets.GEOM_QM9_SMALL,
                       token: str = 'default', use_cuda=False) -> MultiConformationModel:
    """Rebuild a trained ``MultiConformationModel`` from its saved state dict.

    Steps:

    - Print the resolved config.
    - Build the model.
    - Load weights from ``{MODEL_DICT_DIR}/{dataset_name}/{token}.pkl``. They are
      first mapped to CPU, then moved to the GPU when ``use_cuda`` is set.
    - Print every parameter's shape and the total parameter count.
    - Switch the model to eval mode.

    :param special_config: optional overrides passed to ``recover_config``.
    :param dataset_name: dataset whose config and checkpoint directory are used.
    :param token: checkpoint file stem inside the dataset's directory.
    :param use_cuda: build the model with CUDA enabled and move it to the GPU.
    :return: the loaded model, in eval mode.
    """
    # Resolve Config
    config = recover_config(dataset_name, special_config)
    print('Config:')
    for k, v in config.items():
        print(f'\t{k}: {v}')

    # Build Model
    print('Building Model...', end='\t')
    model = MultiConformationModel(
        atom_dim=num_atom_features(),
        bond_dim=num_bond_features(),
        config=config,
        dataset_name=dataset_name,
        use_cuda=use_cuda
    )
    # NOTE: torch.load unpickles the checkpoint file; only load checkpoints from a trusted source.
    state_dict = torch.load(f'{MODEL_DICT_DIR}/{dataset_name}/{token}.pkl', map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    if use_cuda:
        model.cuda()
    print('Finished')
    print('\tStructure:')
    n_param = 0
    for name, param in model.named_parameters():
        print(f'\t\t{name}: {param.shape}')
        # numel() also handles 0-dim (scalar) parameters. Reducing over their empty
        # shape without an initial value raised TypeError.
        n_param += param.numel()
    print(f'\t# Parameters: {n_param}')

    model.eval()
    return model


def eval_conf(model: ConformationModel, dataset: GeoMolDataset, dataset_token: str, use_tqdm=False, use_cuda=False):
    """Evaluate ``model`` over every batch of ``dataset`` and report molecule-weighted losses.

    Each batch's loss values (from ``model.evaluate``) are weighted by that
    batch's ``n_mol`` divided by the total ``n_mol`` over all batches. The
    weighted values are summed per loss key, printed, and returned.

    :param model: model to evaluate; it is switched to eval mode here.
    :param dataset: iterable of ``(packed_mol_graphs, smiles_set, target,
        dft_geometry, rdkit_geometry, extra_dict)`` batches that supports ``len``.
    :param dataset_token: label printed before the losses.
    :param use_tqdm: wrap the iteration in a tqdm progress bar.
    :param use_cuda: copy the batch inputs to the GPU (``.cuda()`` /
        ``cuda_copy()`` / ``dict_cuda_copy``) before evaluating.
    :return: dict mapping each loss key to its weighted sum.
    """
    print(f'{dataset_token}:')
    model.eval()
    n_batch = len(dataset)
    batches = tqdm(dataset, total=n_batch) if use_tqdm else dataset

    mol_counts = []
    losses_by_key = {}
    for packed_mol_graphs, smiles_set, _target, dft_geometry, rdkit_geometry, extra_dict in batches:
        assert isinstance(packed_mol_graphs, PackedMolGraph)
        if use_cuda:
            inputs = dict(
                atom_ftr=packed_mol_graphs.atom_ftr.cuda(),
                bond_ftr=packed_mol_graphs.bond_ftr.cuda(),
                mask_matrices=packed_mol_graphs.mask_matrices.cuda_copy(),
                target_pos_ftr=dft_geometry.cuda(),
                rdkit_pos_ftr=rdkit_geometry.cuda(),
                smiles_set=smiles_set,
                extra_dict=dict_cuda_copy(extra_dict),
            )
        else:
            inputs = dict(
                atom_ftr=packed_mol_graphs.atom_ftr,
                bond_ftr=packed_mol_graphs.bond_ftr,
                mask_matrices=packed_mol_graphs.mask_matrices,
                target_pos_ftr=dft_geometry,
                rdkit_pos_ftr=rdkit_geometry,
                smiles_set=smiles_set,
                extra_dict=extra_dict,
            )
        batch_losses = model.evaluate(**inputs)

        mol_counts.append(packed_mol_graphs.n_mol)
        for key, value in batch_losses.items():
            if key not in losses_by_key:
                losses_by_key[key] = []
            losses_by_key[key].append(value)

    # Weight every batch by its share of the total molecule count.
    total_mol = sum(mol_counts)
    dict_loss = {}
    for key, per_batch in losses_by_key.items():
        dict_loss[key] = sum(per_batch[i] * mol_counts[i] / total_mol for i in range(len(mol_counts)))

    for key, value in dict_loss.items():
        print(f'\t\t\t{str.upper(key)} LOSS: {value}')

    return dict_loss


def eval_multi_conf(model: MultiConformationModel, dataset: MultiGeoMolDataset, dataset_token: str,
                    use_tqdm=False, use_cuda=False):
    """Evaluate a multi-conformer model over ``dataset`` and report molecule-weighted losses.

    Each batch's loss values (from ``model.evaluate``) are weighted by that
    batch's ``n_mol`` divided by the total ``n_mol`` over all batches. The
    weighted values are summed per loss key, printed, and returned.

    :param model: model to evaluate; it is switched to eval mode here.
    :param dataset: iterable of ``(packed_mol_graphs, smiles, list_target_conf,
        list_rdkit_conf, extra_dict)`` batches that supports ``len``.
    :param dataset_token: label printed before the losses.
    :param use_tqdm: wrap the iteration in a tqdm progress bar.
    :param use_cuda: copy the batch inputs to the GPU (``.cuda()`` /
        ``cuda_copy()`` / ``dict_list_cuda_copy``) before evaluating.
    :return: dict mapping each loss key to its weighted sum.
    """
    print(f'{dataset_token}:')
    model.eval()
    n_batch = len(dataset)
    batches = tqdm(dataset, total=n_batch) if use_tqdm else dataset

    mol_counts = []
    losses_by_key = {}
    for packed_mol_graphs, smiles, list_target_conf, list_rdkit_conf, extra_dict in batches:
        assert isinstance(packed_mol_graphs, PackedMolGraph)
        if use_cuda:
            inputs = dict(
                atom_ftr=packed_mol_graphs.atom_ftr.cuda(),
                bond_ftr=packed_mol_graphs.bond_ftr.cuda(),
                mask_matrices=packed_mol_graphs.mask_matrices.cuda_copy(),
                list_target_conf=[conf.cuda() for conf in list_target_conf],
                list_rdkit_conf=[conf.cuda() for conf in list_rdkit_conf],
                smiles=smiles,
                extra_dict=dict_list_cuda_copy(extra_dict),
            )
        else:
            inputs = dict(
                atom_ftr=packed_mol_graphs.atom_ftr,
                bond_ftr=packed_mol_graphs.bond_ftr,
                mask_matrices=packed_mol_graphs.mask_matrices,
                list_target_conf=list_target_conf,
                list_rdkit_conf=list_rdkit_conf,
                smiles=smiles,
                extra_dict=extra_dict,
            )
        batch_losses = model.evaluate(**inputs)

        mol_counts.append(packed_mol_graphs.n_mol)
        for key, value in batch_losses.items():
            if key not in losses_by_key:
                losses_by_key[key] = []
            losses_by_key[key].append(value)

    # Weight every batch by its share of the total molecule count.
    total_mol = sum(mol_counts)
    dict_loss = {}
    for key, per_batch in losses_by_key.items():
        dict_loss[key] = sum(per_batch[i] * mol_counts[i] / total_mol for i in range(len(mol_counts)))

    for key, value in dict_loss.items():
        print(f'\t\t\t{str.upper(key)} LOSS: {value}')

    return dict_loss
