import os
import time
import numpy as np
import torch
import torch.optim as optim

from itertools import chain
from functools import reduce
from typing import Dict, List
from tqdm import tqdm

from data.encode import num_atom_features, num_bond_features
from data.utils import set_seed, get_mean_std
from data.structures import MaskMatrices, PackedMolGraph
from data.load_data import load_data, GeoMolDataset, SupportedDatasets, dataset_is_geometrical, dataset_is_regressive
from net.utils.components import MLP
from net.property_model import PropertyModel
from train.config import QM7PropertyConfig, QM8PropertyConfig
from train.utils.save_log import save_log
from train.utils.loss_functions import multi_mae_loss, multi_mse_loss

# Root directory for saved model weights. train_qm writes the state dicts to
# '<MODEL_DICT_DIR>/<dataset_name>/<token>-model.pkl' and '<token>-mlp.pkl'.
MODEL_DICT_DIR = 'train/prop'


def train_qm(special_config: dict = None, dataset_name: str = SupportedDatasets.QM7, token: str = 'default',
             seed: int = 0,
             force_save=False, use_cuda=False, use_tqdm=False):
    """Train a PropertyModel + MLP head on QM7 or QM8 and log per-epoch metrics.

    Each epoch trains on the train split, evaluates on the train, validate and
    test splits, steps the LR scheduler, and saves the model/MLP state dicts
    whenever the validation metric improves. Logs are written through
    ``save_log`` after every epoch.

    Args:
        special_config: Optional overrides merged on top of the dataset's
            default config. The default config object itself is not modified.
        dataset_name: ``SupportedDatasets.QM7`` or ``SupportedDatasets.QM8``.
        token: Passed to ``load_data`` as ``dataset_token``; also used as the
            file-name prefix of the saved state dicts and as the log tag.
        seed: Passed to ``set_seed`` and ``load_data``.
        force_save: Forwarded to ``load_data``.
        use_cuda: Forwarded to ``load_data`` and the model constructors; when
            true the models and the per-batch inputs are moved with ``.cuda()``.
        use_tqdm: Wrap the batch iterations in ``tqdm`` progress bars.

    Raises:
        ValueError: If ``dataset_name`` is neither QM7 nor QM8.
    """
    # Load Dataset
    set_seed(seed)
    if dataset_name == SupportedDatasets.QM7:
        base_config = QM7PropertyConfig
    elif dataset_name == SupportedDatasets.QM8:
        base_config = QM8PropertyConfig
    else:
        # Explicit error instead of `assert False`, which is stripped under -O.
        raise ValueError(f'Unsupported dataset for train_qm: {dataset_name!r}')
    # Work on a copy so overrides do not leak into the shared module-level config.
    config = dict(base_config)
    if special_config:
        config.update(special_config)
    print('Config:')
    for k, v in config.items():
        print(f'\t{k}: {v}')
    print('Loading data...', end='\t')
    train_set, validate_set, test_set, properties = load_data(
        dataset_name=dataset_name,
        n_mol_per_pack=config['N_MOL_PER_PACK'],
        n_pack_per_batch=config['N_PACK_PER_BATCH'],
        dataset_token=token,
        seed=seed,
        force_save=force_save,
        use_cuda=use_cuda
    )
    # Only the standard deviation is used: it rescales the per-property MAE in evaluate().
    _, stddev_p = get_mean_std(properties)
    stddev_p = stddev_p.numpy()
    print('Finished')

    # Build Model
    print('Building Model...', end='\t')
    model = PropertyModel(
        atom_dim=num_atom_features(),
        bond_dim=num_bond_features(),
        config=config,
        use_cuda=use_cuda
    )
    mlp = MLP(
        in_dim=config['HM_DIM'],
        out_dim=properties.shape[1],
        hidden_dims=config['CLASSIFIER_HIDDENS'],
        use_cuda=use_cuda
    )
    if use_cuda:
        model.cuda()
        mlp.cuda()
    print('Finished')
    print('\tStructure:')
    n_param = 0
    for name, param in chain(model.named_parameters(), mlp.named_parameters()):
        print(f'\t\t{name}: {param.shape}')
        n_param += reduce(lambda x, y: x * y, param.shape)
    print(f'\t# Parameters: {n_param}')

    # Initialize Optimizer
    optimizer = optim.Adam(
        params=chain(model.parameters(), mlp.parameters()),
        lr=config['LR'],
        weight_decay=config['DECAY']
    )
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=config['GAMMA'])

    # Iterating Stage
    logs: List[Dict[str, float]] = []
    best_epoch = 0
    # Infinity guarantees the first epoch is always saved, whatever the metric scale.
    best_metric = float('inf')
    # Loop-invariant: decides which geometry is fed to the model as positions.
    is_geometrical = dataset_is_geometrical(dataset_name)

    def forward_batch(packed_mol_graphs, dft_geometry, rdkit_geometry):
        """Run model + MLP on one pack and return the prediction."""
        assert isinstance(packed_mol_graphs, PackedMolGraph)
        atom_ftr = packed_mol_graphs.atom_ftr
        bond_ftr = packed_mol_graphs.bond_ftr
        mask_matrices = packed_mol_graphs.mask_matrices
        pos_ftr = dft_geometry if is_geometrical else rdkit_geometry
        if use_cuda:
            atom_ftr = atom_ftr.cuda()
            bond_ftr = bond_ftr.cuda()
            mask_matrices = mask_matrices.cuda_copy()
            pos_ftr = pos_ftr.cuda()
        fp, *_ = model.forward(
            atom_ftr=atom_ftr,
            bond_ftr=bond_ftr,
            mask_matrices=mask_matrices,
            pos_ftr=pos_ftr
        )
        return mlp.forward(fp)

    def train(dataset: GeoMolDataset):
        """Run one training epoch, stepping once per N_PACK_PER_BATCH packs."""
        model.train()
        mlp.train()
        optimizer.zero_grad()

        n_batch = len(dataset)
        list_loss = []
        iteration = enumerate(dataset)
        if use_tqdm:
            iteration = tqdm(iteration, total=n_batch)
        for i, (packed_mol_graphs, smiles_set, target, dft_geometry, rdkit_geometry, extra_dict) in iteration:
            predict = forward_batch(packed_mol_graphs, dft_geometry, rdkit_geometry)
            list_loss.append(multi_mse_loss(predict, target))
            # Step when enough packs are collected, or on the last (possibly partial) group.
            if len(list_loss) >= config['N_PACK_PER_BATCH'] or i == n_batch - 1:
                sum(list_loss).backward()
                optimizer.step()
                # Reset gradients so the next step does not include this group's gradients.
                optimizer.zero_grad()
                list_loss.clear()

    def evaluate(dataset: GeoMolDataset, dataset_token: str) -> float:
        """Evaluate on `dataset`, record metrics in logs[-1], return the total MAE."""
        model.eval()
        mlp.eval()

        n_batch = len(dataset)
        list_n_mol = []
        list_loss = []
        list_multi_mae = []
        list_total_mae = []
        iteration = tqdm(dataset, total=n_batch) if use_tqdm else dataset
        # No gradients are needed here; skipping graph construction saves memory and time.
        with torch.no_grad():
            for packed_mol_graphs, smiles_set, target, dft_geometry, rdkit_geometry, extra_dict in iteration:
                predict = forward_batch(packed_mol_graphs, dft_geometry, rdkit_geometry)
                list_n_mol.append(packed_mol_graphs.n_mol)
                list_loss.append(multi_mse_loss(predict, target).cpu().item())
                multi_mae = multi_mae_loss(predict, target, explicit=True)
                total_mae = multi_mae.sum()
                list_multi_mae.append(multi_mae.cpu().detach().numpy())
                list_total_mae.append(total_mae.cpu().item())

        # Per-pack values are combined with weights based on the pack's molecule count.
        total_mol = sum(list_n_mol)
        total_2_mol = sum(n * n for n in list_n_mol)
        loss = sum([pack_loss * n ** 2 / total_2_mol for pack_loss, n in zip(list_loss, list_n_mol)])
        multi_mae = sum([pack_mae * n / total_mol for pack_mae, n in zip(list_multi_mae, list_n_mol)]) * stddev_p
        if dataset_name == SupportedDatasets.QM8:
            total_mae = np.mean(multi_mae)
            total_mae = float(total_mae)
        else:
            # NOTE(review): unlike multi_mae, this value is not rescaled by stddev_p -- confirm intended.
            total_mae = sum([pack_total * n / total_mol for pack_total, n in zip(list_total_mae, list_n_mol)])
        print(f'\t\t\tLOSS: {loss}')
        print(f'\t\t\tMULTI-MAE: {multi_mae[0]}')
        print(f'\t\t\tTOTAL MAE: {total_mae}')
        logs[-1].update({
            f'{dataset_token}_loss': loss,
            f'{dataset_token}_metric': total_mae,
            f'{dataset_token}_multi_metric': multi_mae.tolist()[0],
        })
        return float(total_mae)

    # Creates missing parents too, and tolerates the directory already existing.
    os.makedirs(f'{MODEL_DICT_DIR}/{dataset_name}', exist_ok=True)

    for epoch in range(1, config['EPOCH'] + 1):
        logs.append({'epoch': epoch})
        print()
        print(f'##### IN EPOCH {epoch} #####')
        print('\tCurrent LR: {:.3e}'.format(optimizer.param_groups[0]['lr']))
        print('\t\tTraining:')
        t0 = time.time()
        train(train_set)
        t1 = time.time()
        print('\t\tEvaluating Train:')
        evaluate(train_set, 'train')
        print('\t\tEvaluating Validate:')
        m = evaluate(validate_set, 'validate')
        print('\t\tEvaluating Test:')
        evaluate(test_set, 'test')
        t2 = time.time()
        scheduler.step()

        print('\tTraining Time: {}'.format(int(t1 - t0)))
        print('\tEvaluating Time: {}'.format(int(t2 - t1)))
        logs[-1].update({'train_time': t1 - t0})
        logs[-1].update({'eval_time': t2 - t1})

        # Model selection is driven by the validation metric only.
        if m < best_metric:
            best_metric = m
            best_epoch = epoch
            print('\tSaving Model...')
            torch.save(model.state_dict(), f'{MODEL_DICT_DIR}/{dataset_name}/{token}-model.pkl')
            torch.save(mlp.state_dict(), f'{MODEL_DICT_DIR}/{dataset_name}/{token}-mlp.pkl')
        logs[-1].update({'best_epoch': best_epoch})
        save_log(logs, directory=f'{dataset_name}-prop', tag=token)
