import os
import time
import numpy as np
import torch
import torch.optim as optim

from itertools import chain
from functools import reduce
from typing import Dict, List, Union
from tqdm import tqdm

from data.encode import num_atom_features, num_bond_features
from data.utils import set_seed, get_mean_std, dict_cuda_copy
from data.structures import MaskMatrices, PackedMolGraph
from data.load_data import load_data, GeoMolDataset, SupportedDatasets
from net.conformation_model import ConformationModel
from train.config import QM7ConformationConfig, QM8ConformationConfig, QM9ConformationConfig
from train.utils.save_log import save_log

# Root directory for saved conformation-model state dicts; train_conf writes
# checkpoints to MODEL_DICT_DIR/<dataset_name>/<token>.pkl.
MODEL_DICT_DIR = "train/conf"


def _base_config(dataset_name: str) -> dict:
    """Return a private copy of the default conformation config for ``dataset_name``.

    The module-level config objects are shared by every caller, so they are copied
    here; ``train_conf`` then applies ``special_config`` to the copy instead of
    permanently mutating the defaults for later calls in the same process.

    :param dataset_name: one of ``SupportedDatasets.QM7``/``QM8``/``QM9``.
    :return: a shallow copy of the matching default config.
    :raises ValueError: if ``dataset_name`` is not QM7, QM8 or QM9.
    """
    import copy

    if dataset_name == SupportedDatasets.QM7:
        base = QM7ConformationConfig
    elif dataset_name == SupportedDatasets.QM8:
        base = QM8ConformationConfig
    elif dataset_name == SupportedDatasets.QM9:
        base = QM9ConformationConfig
    else:
        # An explicit exception (unlike ``assert``) is not stripped under ``python -O``.
        raise ValueError(f'Unsupported dataset for conformation training: {dataset_name!r}')
    return copy.copy(base)


def _batch_inputs(packed_mol_graphs, dft_geometry, rdkit_geometry, extra_dict, use_cuda: bool) -> dict:
    """Collect the keyword arguments shared by ``ConformationModel`` training and evaluation.

    :param packed_mol_graphs: the ``PackedMolGraph`` of one pack.
    :param dft_geometry: target positions, passed as ``target_pos_ftr``.
    :param rdkit_geometry: RDKit positions, passed as ``rdkit_pos_ftr``.
    :param extra_dict: extra model inputs, passed as ``extra_dict``.
    :param use_cuda: copy every input to the GPU when set; otherwise pass them through unchanged.
    :return: a dict of keyword arguments for ``model(...)`` / ``model.evaluate(...)``.
    """
    if not use_cuda:
        return dict(
            atom_ftr=packed_mol_graphs.atom_ftr,
            bond_ftr=packed_mol_graphs.bond_ftr,
            mask_matrices=packed_mol_graphs.mask_matrices,
            target_pos_ftr=dft_geometry,
            rdkit_pos_ftr=rdkit_geometry,
            extra_dict=extra_dict,
        )
    return dict(
        atom_ftr=packed_mol_graphs.atom_ftr.cuda(),
        bond_ftr=packed_mol_graphs.bond_ftr.cuda(),
        mask_matrices=packed_mol_graphs.mask_matrices.cuda_copy(),
        target_pos_ftr=dft_geometry.cuda(),
        rdkit_pos_ftr=rdkit_geometry.cuda(),
        extra_dict=dict_cuda_copy(extra_dict),
    )


def train_conf(special_config: Union[dict, None] = None, dataset_name: str = SupportedDatasets.QM7,
               token: str = 'default', dataset_token=None, seed: int = 0,
               force_save=False, use_cuda=False, use_tqdm=False):
    """Train a ``ConformationModel`` on ``dataset_name`` and keep the best checkpoint.

    Epoch 0 only evaluates the freshly built model (a baseline); each later epoch runs
    ``config['TRAIN']`` training passes (stepping the LR scheduler after each) followed by
    validation and test evaluation. Whenever the validation RMSD improves, the model state
    dict is written to ``{MODEL_DICT_DIR}/{dataset_name}/{token}.pkl``. The accumulated
    per-epoch logs are written with ``save_log`` at the end of every epoch.

    :param special_config: entries overriding the dataset's default config.
    :param dataset_name: one of ``SupportedDatasets.QM7``/``QM8``/``QM9``.
    :param token: file name of the saved checkpoint and tag of the saved log.
    :param dataset_token: forwarded to ``load_data``.
    :param seed: random seed passed to ``set_seed`` and ``load_data``.
    :param force_save: forwarded to ``load_data``.
    :param use_cuda: forwarded to ``load_data``/``ConformationModel``; also moves the model
        and every batch to the GPU.
    :param use_tqdm: show progress bars while iterating batches.
    :raises ValueError: if ``dataset_name`` is not supported.
    """
    # Load Dataset
    set_seed(seed)
    config = _base_config(dataset_name)
    if special_config:
        config.update(special_config)
    print('Config:')
    for k, v in config.items():
        print(f'\t{k}: {v}')
    print('Loading data...', end='\t')
    train_set, validate_set, test_set, properties = load_data(
        dataset_name=dataset_name,
        n_mol_per_pack=config['N_MOL_PER_PACK'],
        n_pack_per_batch=config['N_PACK_PER_BATCH'],
        dataset_token=dataset_token,
        seed=seed,
        force_save=force_save,
        use_cuda=use_cuda,
        use_phi_psi=True,
        use_tqdm=use_tqdm
    )
    print('Finished')

    # Build Model
    print('Building Model...', end='\t')
    model = ConformationModel(
        atom_dim=num_atom_features(),
        bond_dim=num_bond_features(),
        config=config,
        dataset_name=dataset_name,
        use_cuda=use_cuda
    )
    if use_cuda:
        model.cuda()
    print('Finished')
    print('\tStructure:')
    n_param = 0
    for name, param in model.named_parameters():
        print(f'\t\t{name}: {param.shape}')
        # numel() also handles 0-dim parameters, where reducing over an empty shape fails.
        n_param += param.numel()
    print(f'\t# Parameters: {n_param}')

    # Initialize Optimizer
    optimizer = optim.Adam(
        params=model.parameters(),
        lr=config['LR'],
        weight_decay=config['DECAY']
    )
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=config['GAMMA'])

    # Iterating Stage
    logs: List[Dict[str, Union[float, List[float]]]] = []
    best_epoch = 0
    best_metric = float('inf')

    def train(dataset: GeoMolDataset) -> float:
        """Run one training pass, stepping the optimizer every ``N_PACK_PER_BATCH`` packs.

        Packs whose loss is NaN are skipped. Returns the mean of the remaining per-pack
        losses, or NaN if no pack produced a usable loss.
        """
        model.train()
        optimizer.zero_grad()

        n_batch = len(dataset)
        pending_losses = []
        float_losses = []
        iteration = enumerate(dataset)
        if use_tqdm:
            iteration = tqdm(iteration, total=n_batch)
        for i, (packed_mol_graphs, _smiles_set, _target, dft_geometry, rdkit_geometry, extra_dict) in iteration:
            assert isinstance(packed_mol_graphs, PackedMolGraph)
            # Call the module (not .forward) so registered nn.Module hooks still run.
            loss = model(**_batch_inputs(packed_mol_graphs, dft_geometry, rdkit_geometry, extra_dict, use_cuda))
            if not torch.isnan(loss):
                pending_losses.append(loss)
                float_losses.append(float(loss))
            is_last = i == n_batch - 1
            if pending_losses and (len(pending_losses) >= config['N_PACK_PER_BATCH'] or is_last):
                sum(pending_losses).backward()
                optimizer.step()
                # Reset gradients after every step; otherwise they keep accumulating
                # across all optimizer steps within the pass.
                optimizer.zero_grad()
                pending_losses.clear()
        loss_value = sum(float_losses) / len(float_losses) if float_losses else float('nan')
        print(f'\t\t\tTRAINING LOSS: {loss_value}')
        return loss_value

    def evaluate(dataset: GeoMolDataset, split: str) -> float:
        """Evaluate on ``dataset``, log every loss as ``{split}_{key}_loss``, return the RMSD.

        Each loss reported by ``model.evaluate`` is averaged over packs, weighted by the
        number of molecules in each pack that reported it.
        """
        model.eval()

        iteration = tqdm(dataset, total=len(dataset)) if use_tqdm else dataset
        # loss key -> list of (per-pack loss, molecules in that pack)
        weighted_losses = {}
        for packed_mol_graphs, smiles_set, _target, dft_geometry, rdkit_geometry, extra_dict in iteration:
            assert isinstance(packed_mol_graphs, PackedMolGraph)
            loss_dict = model.evaluate(
                smiles_set=smiles_set,
                **_batch_inputs(packed_mol_graphs, dft_geometry, rdkit_geometry, extra_dict, use_cuda)
            )
            n_mol = packed_mol_graphs.n_mol
            for loss_key, loss_value in loss_dict.items():
                weighted_losses.setdefault(loss_key, []).append((loss_value, n_mol))

        dict_loss = {}
        for loss_key, pairs in weighted_losses.items():
            key_total = sum(n for _, n in pairs)
            dict_loss[loss_key] = sum(loss * n / key_total for loss, n in pairs)

        for loss_key, loss_value in dict_loss.items():
            print(f'\t\t\t{str.upper(loss_key)} LOSS: {loss_value}')
        logs[-1].update({
            f'{split}_{loss_key}_loss': float(loss_value) for loss_key, loss_value in dict_loss.items()
        })
        return dict_loss['rmsd']

    # makedirs also creates MODEL_DICT_DIR itself and tolerates existing directories.
    os.makedirs(f'{MODEL_DICT_DIR}/{dataset_name}', exist_ok=True)

    for epoch in range(config['EPOCH'] + 1):
        model.adapt(epoch=epoch)
        logs.append({'epoch': epoch})
        on_training_losses = []
        print()
        print(f'##### IN EPOCH {epoch} #####')
        print('\tCurrent LR: {:.3e}'.format(optimizer.param_groups[0]['lr']))
        print('\t\tTraining:')
        t0 = time.perf_counter()
        # Epoch 0 skips training so the untrained model is evaluated as a baseline.
        if epoch:
            for _ in range(config['TRAIN']):
                on_training_losses.append(train(train_set))
                scheduler.step()
        logs[-1].update({'on_training_losses': on_training_losses})
        t1 = time.perf_counter()
        print('\t\tEvaluating Validate:')
        m = evaluate(validate_set, 'validate')
        if m < best_metric:
            best_metric = m
            best_epoch = epoch
            print('\tSaving Model...')
            torch.save(model.state_dict(), f'{MODEL_DICT_DIR}/{dataset_name}/{token}.pkl')
        print('\t\tEvaluating Test:')
        evaluate(test_set, 'test')
        t2 = time.perf_counter()

        print('\tTraining Time: {}'.format(int(t1 - t0)))
        print('\tEvaluating Time: {}'.format(int(t2 - t1)))
        logs[-1].update({'train_time': t1 - t0})
        logs[-1].update({'eval_time': t2 - t1})
        logs[-1].update({'best_epoch': best_epoch})
        save_log(logs, directory=f'{dataset_name}-conf', tag=token)
