import copy
import time
from typing import Dict, List, Any

from tqdm import tqdm

from data.encode import num_atom_features, num_bond_features
from data.utils import set_seed, dict_cuda_copy, dict_list_cuda_copy
from data.structures import PackedMolGraph
from data.load_data import load_data, GeoMolDataset, SupportedDatasets
from data.load_multi_data import load_multi_data, MultiGeoMolDataset, SupportedMultiDatasets
from net.conformation_model import ConformationModel, MultiConformationModel
from train.config import QM7ConformationConfig, QM8ConformationConfig, QM9ConformationConfig, \
    GeomDRUGConformationConfig, GeomDRUGSmallConformationConfig, \
    GeomQM9ConformationConfig, GeomQM9SmallConformationConfig
from train.utils.save_log import save_log


def rdkit_conf(special_config: dict = None, dataset_name: str = SupportedDatasets.QM7,
               token: str = 'rdkit', dataset_token=None, seed: int = 0,
               force_save=False, use_cuda=False, use_tqdm=False):
    """Evaluate ``ConformationModel.directly_compare`` on the test split of a single-conformer dataset.

    Loads the dataset selected by ``dataset_name`` and builds a ``ConformationModel``.
    It then runs ``directly_compare`` on every test batch, comparing the DFT geometry
    with the RDKit geometry. Each returned loss is averaged with the number of
    molecules in each batch as the weight. The averages are printed and saved via
    ``save_log``.

    :param special_config: entries overriding the dataset's default config. The
        module-level default config object itself is left untouched.
    :param dataset_name: one of ``SupportedDatasets.QM7`` / ``QM8`` / ``QM9``.
    :param token: tag passed to ``save_log``.
    :param dataset_token: forwarded to ``load_data``.
    :param seed: seed set before loading; also forwarded to ``load_data``.
    :param force_save: forwarded to ``load_data``.
    :param use_cuda: move the model and every batch to the GPU.
    :param use_tqdm: show a progress bar while evaluating.
    :raises ValueError: if ``dataset_name`` is not one of the supported datasets.
    """
    # Load Dataset
    set_seed(seed)
    if dataset_name == SupportedDatasets.QM7:
        base_config = QM7ConformationConfig
    elif dataset_name == SupportedDatasets.QM8:
        base_config = QM8ConformationConfig
    elif dataset_name == SupportedDatasets.QM9:
        base_config = QM9ConformationConfig
    else:
        # Unlike `assert False`, this is not stripped under `python -O`. With the
        # assert stripped, `config` would be unbound below.
        raise ValueError(f'Unsupported dataset for rdkit_conf: {dataset_name}')
    # Copy before applying overrides. Otherwise the module-level default config
    # is mutated and the overrides leak into later calls in the same process.
    config = copy.copy(base_config)
    if special_config:
        config.update(special_config)
    print('Config:')
    for k, v in config.items():
        print(f'\t{k}: {v}')
    print('Loading data...', end='\t')
    train_set, validate_set, test_set, properties = load_data(
        dataset_name=dataset_name,
        n_mol_per_pack=config['N_MOL_PER_PACK'],
        n_pack_per_batch=config['N_PACK_PER_BATCH'],
        dataset_token=dataset_token,
        seed=seed,
        force_save=force_save,
        use_cuda=use_cuda,
        use_phi_psi=True
    )
    print('Finished')

    # Build Model
    print('Building Model...', end='\t')
    model = ConformationModel(
        atom_dim=num_atom_features(),
        bond_dim=num_bond_features(),
        config=config,
        dataset_name=dataset_name,
        use_cuda=use_cuda
    )
    if use_cuda:
        model.cuda()
    print('Finished')

    logs: List[Dict[str, Any]] = []

    def evaluate(dataset: GeoMolDataset, dataset_token: str) -> float:
        """Run ``directly_compare`` over ``dataset`` and record molecule-weighted mean losses.

        The mean for each loss key is printed and stored in ``logs[-1]`` as
        ``'{dataset_token}_{loss_key}_loss'``. Returns the mean ``'adj-1'`` loss.
        """
        model.eval()

        total_mol = 0
        # Running sum of loss * n_mol per key. Accumulating keeps per-key weights
        # aligned and avoids holding every batch's loss until the end.
        dict_weighted_loss = {}
        iteration = tqdm(dataset, total=len(dataset)) if use_tqdm else dataset
        for packed_mol_graphs, smiles_set, target, dft_geometry, rdkit_geometry, extra_dict in iteration:
            assert isinstance(packed_mol_graphs, PackedMolGraph)
            atom_ftr = packed_mol_graphs.atom_ftr
            bond_ftr = packed_mol_graphs.bond_ftr
            mask_matrices = packed_mol_graphs.mask_matrices
            if use_cuda:
                atom_ftr = atom_ftr.cuda()
                bond_ftr = bond_ftr.cuda()
                mask_matrices = mask_matrices.cuda_copy()
                dft_geometry = dft_geometry.cuda()
                rdkit_geometry = rdkit_geometry.cuda()
                extra_dict = dict_cuda_copy(extra_dict)
            loss_dict = model.directly_compare(
                atom_ftr=atom_ftr,
                bond_ftr=bond_ftr,
                mask_matrices=mask_matrices,
                target_pos_ftr=dft_geometry,
                rdkit_pos_ftr=rdkit_geometry,
                smiles_set=smiles_set,
                extra_dict=extra_dict
            )
            n_mol = packed_mol_graphs.n_mol
            total_mol += n_mol
            for loss_key, loss_value in loss_dict.items():
                dict_weighted_loss[loss_key] = dict_weighted_loss.get(loss_key, 0) + loss_value * n_mol

        # Each batch contributes in proportion to its number of molecules.
        dict_loss = {loss_key: weighted_sum / total_mol
                     for loss_key, weighted_sum in dict_weighted_loss.items()}

        for loss_key, loss_value in dict_loss.items():
            print(f'\t\t\t{loss_key.upper()} LOSS: {loss_value}')
        logs[-1].update({
            f'{dataset_token}_{loss_key}_loss': float(loss_value) for loss_key, loss_value in dict_loss.items()
        })
        return dict_loss['adj-1']

    epoch = 0
    model.adapt(epoch=epoch)
    logs.append({'epoch': epoch})
    t1 = time.perf_counter()
    # print('\t\tEvaluating Train:')
    # evaluate(train_set, 'train')
    # print('\t\tEvaluating Validate:')
    # m = evaluate(validate_set, 'validate')
    print('\t\tEvaluating Test:')
    evaluate(test_set, 'test')
    t2 = time.perf_counter()

    print('\tEvaluating Time: {}'.format(int(t2 - t1)))
    logs[-1].update({'eval_time': t2 - t1})
    save_log(logs, directory=f'{dataset_name}-conf', tag=token)


def rdkit_multi_conf(special_config: dict = None, dataset_name: str = SupportedMultiDatasets.GEOM_QM9,
                     token: str = 'rdkit', dataset_token=None, seed: int = 0,
                     force_save=False, use_cuda=False, use_tqdm=False):
    """Evaluate ``MultiConformationModel.directly_compare`` on the test split of a multi-conformer dataset.

    Loads the GEOM dataset selected by ``dataset_name`` and builds a
    ``MultiConformationModel``. It then runs ``directly_compare`` on every test
    batch, passing both the list of target conformations and the list of RDKit
    conformations. Each returned loss is averaged with the number of molecules in
    each batch as the weight. The averages are printed and saved via ``save_log``.

    :param special_config: entries overriding the dataset's default config. The
        module-level default config object itself is left untouched.
    :param dataset_name: one of ``SupportedMultiDatasets.GEOM_DRUGS_SMALL`` /
        ``GEOM_DRUGS`` / ``GEOM_QM9_SMALL`` / ``GEOM_QM9``.
    :param token: tag passed to ``save_log``.
    :param dataset_token: forwarded to ``load_multi_data``.
    :param seed: seed set before loading; also forwarded to ``load_multi_data``.
    :param force_save: forwarded to ``load_multi_data``.
    :param use_cuda: move the model and every batch to the GPU.
    :param use_tqdm: show progress bars (also forwarded to ``load_multi_data``).
    :raises ValueError: if ``dataset_name`` is not one of the supported datasets.
    """
    # Load Dataset
    set_seed(seed)
    if dataset_name == SupportedMultiDatasets.GEOM_DRUGS_SMALL:
        base_config = GeomDRUGSmallConformationConfig
    elif dataset_name == SupportedMultiDatasets.GEOM_DRUGS:
        base_config = GeomDRUGConformationConfig
    elif dataset_name == SupportedMultiDatasets.GEOM_QM9_SMALL:
        base_config = GeomQM9SmallConformationConfig
    elif dataset_name == SupportedMultiDatasets.GEOM_QM9:
        base_config = GeomQM9ConformationConfig
    else:
        # Unlike `assert False`, this is not stripped under `python -O`. With the
        # assert stripped, `config` would be unbound below.
        raise ValueError(f'Unsupported dataset for rdkit_multi_conf: {dataset_name}')
    # Copy before applying overrides. Otherwise the module-level default config
    # is mutated and the overrides leak into later calls in the same process.
    config = copy.copy(base_config)
    if special_config:
        config.update(special_config)
    print('Config:')
    for k, v in config.items():
        print(f'\t{k}: {v}')
    print('Loading data...', end='\t')
    train_set, validate_set, test_set = load_multi_data(
        dataset_name=dataset_name,
        n_mol_per_pack=config['N_MOL_PER_PACK'],
        n_pack_per_batch=config['N_PACK_PER_BATCH'],
        dataset_token=dataset_token,
        seed=seed,
        force_save=force_save,
        use_cuda=use_cuda,
        use_tqdm=use_tqdm,
        use_phi_psi=True
    )
    print('Finished')

    # Build Model
    print('Building Model...', end='\t')
    model = MultiConformationModel(
        atom_dim=num_atom_features(),
        bond_dim=num_bond_features(),
        config=config,
        dataset_name=dataset_name,
        use_cuda=use_cuda
    )
    if use_cuda:
        model.cuda()
    print('Finished')

    # Iterating Stage
    logs: List[Dict[str, Any]] = []

    def evaluate(dataset: MultiGeoMolDataset, dataset_token: str, simple=False) -> float:
        """Run ``directly_compare`` over ``dataset`` and record molecule-weighted mean losses.

        The mean for each loss key is printed and stored in ``logs[-1]`` as
        ``'{dataset_token}_{loss_key}_loss'``. Returns the mean ``'adj3'`` loss.
        ``simple`` is currently unused; it was meant for evaluating only a subset
        of batches.
        """
        model.eval()

        total_mol = 0
        # Running sum of loss * n_mol per key. Accumulating keeps per-key weights
        # aligned and avoids holding every batch's loss until the end.
        dict_weighted_loss = {}
        iteration = tqdm(dataset, total=len(dataset)) if use_tqdm else dataset
        for packed_mol_graphs, smiles, list_target_conf, list_rdkit_conf, extra_dict in iteration:
            assert isinstance(packed_mol_graphs, PackedMolGraph)
            atom_ftr = packed_mol_graphs.atom_ftr
            bond_ftr = packed_mol_graphs.bond_ftr
            mask_matrices = packed_mol_graphs.mask_matrices
            if use_cuda:
                atom_ftr = atom_ftr.cuda()
                bond_ftr = bond_ftr.cuda()
                mask_matrices = mask_matrices.cuda_copy()
                # `conf` rather than `rdkit_conf`: avoid shadowing the module-level function.
                list_target_conf = [conf.cuda() for conf in list_target_conf]
                list_rdkit_conf = [conf.cuda() for conf in list_rdkit_conf]
                extra_dict = dict_list_cuda_copy(extra_dict)
            loss_dict = model.directly_compare(
                atom_ftr=atom_ftr,
                bond_ftr=bond_ftr,
                mask_matrices=mask_matrices,
                list_target_conf=list_target_conf,
                list_rdkit_conf=list_rdkit_conf,
                smiles=smiles,
                extra_dict=extra_dict
            )
            n_mol = packed_mol_graphs.n_mol
            total_mol += n_mol
            for loss_key, loss_value in loss_dict.items():
                dict_weighted_loss[loss_key] = dict_weighted_loss.get(loss_key, 0) + loss_value * n_mol

        # Each batch contributes in proportion to its number of molecules.
        dict_loss = {loss_key: weighted_sum / total_mol
                     for loss_key, weighted_sum in dict_weighted_loss.items()}

        for loss_key, loss_value in dict_loss.items():
            print(f'\t\t\t{loss_key.upper()} LOSS: {loss_value}')
        logs[-1].update({
            f'{dataset_token}_{loss_key}_loss': float(loss_value) for loss_key, loss_value in dict_loss.items()
        })
        return dict_loss['adj3']

    epoch = 0
    logs.append({'epoch': epoch})
    t1 = time.perf_counter()
    # print('\t\tEvaluating Validate:')
    # m = evaluate(validate_set, 'validate', simple=epoch < config['EPOCH'])
    print('\t\tEvaluating Test:')
    evaluate(test_set, 'test', simple=epoch < config['EPOCH'])
    t2 = time.perf_counter()

    print('\tEvaluating Time: {}'.format(int(t2 - t1)))
    logs[-1].update({'eval_time': t2 - t1})
    save_log(logs, directory=f'{dataset_name}-conf', tag=token)
