import os
import time
import numpy as np
import torch
import torch.optim as optim

from itertools import chain
from functools import reduce
from typing import Dict, List, Union
from tqdm import tqdm

from data.encode import num_atom_features, num_bond_features
from data.utils import set_seed, dict_cuda_copy, dict_list_cuda_copy
from data.structures import MaskMatrices, PackedMolGraph
from data.load_multi_data import load_multi_data, GeoMolDataset, MultiGeoMolDataset, SupportedMultiDatasets
from net.conformation_model import MultiConformationModel
from train.config import GeomQM9SmallConformationConfig, GeomQM9ConformationConfig, \
    GeomDRUGSmallConformationConfig, GeomDRUGConformationConfig
from train.utils.save_log import save_log

# Root directory for saved model state dicts. `train_multi_conf` creates one
# sub-directory per dataset name beneath it and writes `<token>.pkl` there.
MODEL_DICT_DIR = 'train/conf'


def train_multi_conf(special_config: dict = None, dataset_name: str = SupportedMultiDatasets.GEOM_QM9,
                     token: str = 'default', dataset_token=None, seed: int = 0,
                     force_save=False, use_cuda=False, use_tqdm=False):
    """Train a ``MultiConformationModel`` and log validation/test losses per epoch.

    The function selects the default config for ``dataset_name``, loads the
    train/validate/test splits, builds the model and an Adam optimizer with a
    per-pass ``StepLR`` decay, then iterates over ``config['EPOCH'] + 1``
    epochs. Epoch 0 only evaluates the untrained model; every later epoch runs
    ``config['TRAIN']`` training passes before evaluating. Whenever the
    validation ``'mat'`` loss improves, the model state dict is written to
    ``{MODEL_DICT_DIR}/{dataset_name}/{token}.pkl``. Logs are saved through
    ``save_log`` after every epoch.

    Args:
        special_config: Optional overrides applied on top of the dataset's
            default config. The defaults themselves are left untouched.
        dataset_name: One of the ``SupportedMultiDatasets`` GEOM variants
            handled below.
        token: Name used for the saved state-dict file and as the log tag.
        dataset_token: Forwarded unchanged to ``load_multi_data``.
        seed: Seed passed to ``set_seed`` and to ``load_multi_data``.
        force_save: Forwarded unchanged to ``load_multi_data``.
        use_cuda: If true, the model and every batch are moved to CUDA.
        use_tqdm: If true, wrap the batch loops in ``tqdm`` progress bars.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported datasets.
    """
    # Load Dataset
    set_seed(seed)
    if dataset_name == SupportedMultiDatasets.GEOM_DRUGS_SMALL:
        base_config = GeomDRUGSmallConformationConfig
    elif dataset_name == SupportedMultiDatasets.GEOM_DRUGS:
        base_config = GeomDRUGConformationConfig
    elif dataset_name == SupportedMultiDatasets.GEOM_QM9_SMALL:
        base_config = GeomQM9SmallConformationConfig
    elif dataset_name == SupportedMultiDatasets.GEOM_QM9:
        base_config = GeomQM9ConformationConfig
    else:
        # An explicit exception (instead of `assert False`) survives `python -O`.
        raise ValueError(f'Unsupported dataset: {dataset_name!r}')
    # Work on a copy so per-run overrides never leak into the shared default
    # config object (assumes the config objects are plain dict-like mappings).
    config = dict(base_config)
    if special_config:
        config.update(special_config)
    print('Config:')
    for k, v in config.items():
        print(f'\t{k}: {v}')
    print('Loading data...', end='\t')
    train_set, validate_set, test_set = load_multi_data(
        dataset_name=dataset_name,
        n_mol_per_pack=config['N_MOL_PER_PACK'],
        n_pack_per_batch=config['N_PACK_PER_BATCH'],
        dataset_token=dataset_token,
        seed=seed,
        force_save=force_save,
        use_cuda=use_cuda,
        use_tqdm=use_tqdm,
        use_phi_psi=True
    )
    print('Finished')

    # Build Model
    print('Building Model...', end='\t')
    model = MultiConformationModel(
        atom_dim=num_atom_features(),
        bond_dim=num_bond_features(),
        config=config,
        dataset_name=dataset_name,
        use_cuda=use_cuda
    )
    if use_cuda:
        model.cuda()
    print('Finished')
    print('\tStructure:')
    n_param = 0
    for name, param in model.named_parameters():
        print(f'\t\t{name}: {param.shape}')
        # numel() also handles 0-dim parameters, where reducing an empty shape would fail.
        n_param += param.numel()
    print(f'\t# Parameters: {n_param}')

    # Initialize Optimizer
    optimizer = optim.Adam(
        params=model.parameters(),
        lr=config['LR'],
        weight_decay=config['DECAY']
    )
    # The scheduler is stepped once per training pass (see the epoch loop below).
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=config['GAMMA'])

    # Iterating Stage
    logs: List[Dict[str, Union[float, List[float]]]] = []
    best_epoch = 0
    # Start from +inf so the first finite validation metric always becomes the best one.
    best_metric = float('inf')

    def train(dataset: GeoMolDataset) -> float:
        """Run one pass over ``dataset`` and return the mean non-NaN loss.

        Losses are accumulated over ``config['N_PACK_PER_BATCH']`` packs, then
        summed and back-propagated in a single optimizer step. Packs whose loss
        is NaN are skipped. Returns NaN if no pack produced a usable loss.
        """
        model.train()
        optimizer.zero_grad()

        n_batch = len(dataset)
        pending_losses = []  # loss tensors waiting for the next optimizer step
        epoch_losses = []  # float copies of every usable loss in this pass
        iteration = enumerate(dataset)
        if use_tqdm:
            iteration = tqdm(iteration, total=n_batch)
        for i, (packed_mol_graphs, _, _, target_conf, rdkit_conf, extra_dict) in iteration:
            assert isinstance(packed_mol_graphs, PackedMolGraph)
            if use_cuda:
                loss = model.forward(
                    atom_ftr=packed_mol_graphs.atom_ftr.cuda(),
                    bond_ftr=packed_mol_graphs.bond_ftr.cuda(),
                    mask_matrices=packed_mol_graphs.mask_matrices.cuda_copy(),
                    target_conf=target_conf.cuda(),
                    rdkit_conf=rdkit_conf.cuda(),
                    extra_dict=dict_cuda_copy(extra_dict)
                )
            else:
                loss = model.forward(
                    atom_ftr=packed_mol_graphs.atom_ftr,
                    bond_ftr=packed_mol_graphs.bond_ftr,
                    mask_matrices=packed_mol_graphs.mask_matrices,
                    target_conf=target_conf,
                    rdkit_conf=rdkit_conf,
                    extra_dict=extra_dict
                )
            if not torch.isnan(loss):
                pending_losses.append(loss)
                epoch_losses.append(float(loss))
            # Step when enough packs are accumulated, or flush the remainder on the last pack.
            is_last = i == n_batch - 1
            if pending_losses and (len(pending_losses) >= config['N_PACK_PER_BATCH'] or is_last):
                sum(pending_losses).backward()
                optimizer.step()
                # Gradients accumulate across backward() calls; reset them so the
                # next step only sees the gradients of its own packs.
                optimizer.zero_grad()
                pending_losses.clear()
        if epoch_losses:
            loss_value = sum(epoch_losses) / len(epoch_losses)
        else:
            # Empty dataset or every loss was NaN: report NaN instead of dividing by zero.
            print('\t\t\tWARNING: no valid training loss in this pass')
            loss_value = float('nan')
        print(f'\t\t\tTRAINING LOSS: {loss_value}')
        return loss_value

    def evaluate(dataset: MultiGeoMolDataset, dataset_token: str, simple=False) -> float:
        """Evaluate the model on ``dataset`` and record the losses in ``logs[-1]``.

        Each loss returned by ``model.evaluate`` is averaged over batches,
        weighted by the number of molecules in the batch. The averages are
        stored under ``f'{dataset_token}_{loss_key}_loss'``. ``simple`` is
        accepted for the callers below but is not used here.

        Returns:
            The weighted average of the ``'mat'`` loss.
        """
        model.eval()

        n_batch = len(dataset)
        list_n_mol = []
        dict_list_loss = {}
        iteration = tqdm(dataset, total=n_batch) if use_tqdm else dataset
        for packed_mol_graphs, smiles, list_target_conf, list_rdkit_conf, extra_dict in iteration:
            assert isinstance(packed_mol_graphs, PackedMolGraph)
            if use_cuda:
                loss_dict = model.evaluate(
                    atom_ftr=packed_mol_graphs.atom_ftr.cuda(),
                    bond_ftr=packed_mol_graphs.bond_ftr.cuda(),
                    mask_matrices=packed_mol_graphs.mask_matrices.cuda_copy(),
                    list_target_conf=[target_conf.cuda() for target_conf in list_target_conf],
                    list_rdkit_conf=[rdkit_conf.cuda() for rdkit_conf in list_rdkit_conf],
                    smiles=smiles,
                    extra_dict=dict_list_cuda_copy(extra_dict)
                )
            else:
                loss_dict = model.evaluate(
                    atom_ftr=packed_mol_graphs.atom_ftr,
                    bond_ftr=packed_mol_graphs.bond_ftr,
                    mask_matrices=packed_mol_graphs.mask_matrices,
                    list_target_conf=list_target_conf,
                    list_rdkit_conf=list_rdkit_conf,
                    smiles=smiles,
                    extra_dict=extra_dict
                )
            list_n_mol.append(packed_mol_graphs.n_mol)
            for loss_key, loss_value in loss_dict.items():
                dict_list_loss.setdefault(loss_key, []).append(loss_value)

        # Molecule-count-weighted mean per loss key. Indexing by batch position keeps
        # a loud IndexError if some batch did not report a given loss key.
        total_mol = sum(list_n_mol)
        dict_loss = {
            loss_key: sum(list_loss[i] * n_mol / total_mol for i, n_mol in enumerate(list_n_mol))
            for loss_key, list_loss in dict_list_loss.items()
        }

        for loss_key, loss_value in dict_loss.items():
            print(f'\t\t\t{str.upper(loss_key)} LOSS: {loss_value}')
        logs[-1].update({
            f'{dataset_token}_{loss_key}_loss': float(loss_value) for loss_key, loss_value in dict_loss.items()
        })
        return dict_loss['mat']

    # Creates missing parent directories too and tolerates existing ones.
    os.makedirs(f'{MODEL_DICT_DIR}/{dataset_name}', exist_ok=True)

    for epoch in range(0, config['EPOCH'] + 1):
        model.adapt(epoch=epoch)
        logs.append({'epoch': epoch})
        on_training_losses = []
        print()
        print(f'##### IN EPOCH {epoch} #####')
        print('\tCurrent LR: {:.3e}'.format(optimizer.param_groups[0]['lr']))
        print('\t\tTraining:')
        t0 = time.time()
        # Epoch 0 is evaluation-only: it records the metrics of the untrained model.
        if epoch:
            for _ in range(config['TRAIN']):
                on_training_losses.append(train(train_set))
                scheduler.step()
        logs[-1].update({'on_training_losses': on_training_losses})
        t1 = time.time()
        print('\t\tEvaluating Validate:')
        m = evaluate(validate_set, 'validate', simple=epoch < config['EPOCH'])
        if m < best_metric:
            best_metric = m
            best_epoch = epoch
            print('\tSaving Model...')
            torch.save(model.state_dict(), f'{MODEL_DICT_DIR}/{dataset_name}/{token}.pkl')
        print('\t\tEvaluating Test:')
        evaluate(test_set, 'test', simple=epoch < config['EPOCH'])
        t2 = time.time()

        print('\tTraining Time: {}'.format(int(t1 - t0)))
        print('\tEvaluating Time: {}'.format(int(t2 - t1)))
        logs[-1].update({
            'train_time': t1 - t0,
            'eval_time': t2 - t1,
            'best_epoch': best_epoch,
        })
        # Logs are rewritten after every epoch so partial runs still leave a record.
        save_log(logs, directory=f'{dataset_name}-conf', tag=token)
