import os
import time
import numpy as np
import torch
import torch.optim as optim

from itertools import chain
from functools import reduce
from typing import Dict, List
from tqdm import tqdm

from data.utils import set_seed, get_mean_std, dict_cuda_copy
from data.structures import MaskMatrices, PackedMolGraph
from data.load_data import load_data, GeoMolDataset, SupportedDatasets
from net.force_model import ForceModel
from train.config import ForceConfig
from train.utils.save_log import save_log
from train.utils.loss_functions import rmse_loss

# Relative directory (resolved against the current working directory) in which
# the best force-model state_dict is saved as '<token>.pkl'.
MODEL_DICT_DIR = "train/force"


def pretrain_force(dataset_name=SupportedDatasets.QM9, special_config: dict = None,
                   token: str = 'force-qm9', dataset_token='force', seed: int = 0,
                   force_save=False, use_cuda=False, use_tqdm=False):
    """Pre-train a ``ForceModel`` on disturbed geometries, keeping the best checkpoint.

    For every molecule pack, the model is run once per entry of
    ``extra_dict['disturb_geometries']`` (with ``target_pos`` set to the DFT
    geometry), and ``rmse_loss(force, distance)`` is averaged over those runs.
    Per-pack losses are summed and back-propagated once every
    ``config['N_PACK_PER_BATCH']`` packs (and after the last pack).

    After each epoch, the train/validate/test sets are evaluated. Whenever the
    validation loss improves, the model ``state_dict`` is saved to
    ``f'{MODEL_DICT_DIR}/{token}.pkl'``. The accumulated logs are written with
    ``save_log`` at the end of every epoch.

    :param dataset_name: dataset identifier forwarded to ``load_data``.
    :param special_config: overrides applied on top of a copy of ``ForceConfig``.
    :param token: name of the saved model file and tag of the log.
    :param dataset_token: token forwarded to ``load_data``.
    :param seed: seed for ``set_seed`` and ``load_data``.
    :param force_save: forwarded to ``load_data``.
    :param use_cuda: move the model and every batch input to the GPU.
    :param use_tqdm: show progress bars while iterating over packs.
    :raises ValueError: if a pack carries no disturbed geometries.
    """
    # Load Dataset
    set_seed(seed)
    # Work on a copy: updating ForceConfig itself would leak this call's
    # overrides into every later call made in the same process.
    config = dict(ForceConfig)
    if special_config:
        config.update(special_config)
    print('Config:')
    for k, v in config.items():
        print(f'\t{k}: {v}')
    print('Loading data...', end='\t')
    train_set, validate_set, test_set, _ = load_data(
        dataset_name=dataset_name,
        n_mol_per_pack=config['N_MOL_PER_PACK'],
        n_pack_per_batch=config['N_PACK_PER_BATCH'],
        dataset_token=dataset_token,
        seed=seed,
        force_save=force_save,
        use_cuda=use_cuda,
        use_disturb=True
    )
    print('Finished')

    # Build Model
    print('Building Model...', end='\t')
    model = ForceModel(use_cuda=use_cuda)
    if use_cuda:
        model.cuda()
    print('Finished')
    print('\tStructure:')
    n_param = 0
    for name, param in model.named_parameters():
        print(f'\t\t{name}: {param.shape}')
        # numel() also handles 0-dim parameters, where reducing over an empty
        # shape tuple would raise.
        n_param += param.numel()
    print(f'\t# Parameters: {n_param}')

    # Initialize Optimizer
    optimizer = optim.Adam(
        params=model.parameters(),
        lr=config['LR'],
        weight_decay=config['DECAY']
    )
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=config['GAMMA'])

    # Iterating Stage
    logs: List[Dict[str, float]] = []
    best_epoch = 0
    best_metric = float('inf')

    def pack_loss(packed_mol_graphs, dft_geometry, disturb_geometries) -> torch.Tensor:
        """Return the mean ``rmse_loss(force, distance)`` over one pack's disturbed geometries."""
        assert isinstance(packed_mol_graphs, PackedMolGraph)
        atom_ftr = packed_mol_graphs.atom_ftr
        bond_ftr = packed_mol_graphs.bond_ftr
        mask_matrices = packed_mol_graphs.mask_matrices
        if use_cuda:
            # These inputs are shared by every disturbed geometry of the pack,
            # so move them to the GPU once instead of once per forward pass.
            atom_ftr = atom_ftr.cuda()
            bond_ftr = bond_ftr.cuda()
            mask_matrices = mask_matrices.cuda_copy()
            dft_geometry = dft_geometry.cuda()
        losses = []
        for disturb_geometry in disturb_geometries:
            if use_cuda:
                disturb_geometry = disturb_geometry.cuda()
            # Call the module (not .forward) so that registered hooks run.
            _, force, distance = model(
                atom_ftr=atom_ftr,
                bond_ftr=bond_ftr,
                pos=disturb_geometry,
                mask_matrices=mask_matrices,
                target_pos=dft_geometry
            )
            losses.append(rmse_loss(force, distance))
        if not losses:
            raise ValueError("extra_dict['disturb_geometries'] is empty; cannot compute the force loss")
        return sum(losses) / len(losses)

    def train(dataset: GeoMolDataset):
        """Run one optimisation pass over ``dataset``, stepping every N_PACK_PER_BATCH packs."""
        model.train()
        optimizer.zero_grad()

        n_batch = len(dataset)
        list_loss = []
        iteration = enumerate(dataset)
        if use_tqdm:
            iteration = tqdm(iteration, total=n_batch)
        for i, (packed_mol_graphs, _, _, dft_geometry, _, extra_dict) in iteration:
            list_loss.append(pack_loss(packed_mol_graphs, dft_geometry, extra_dict['disturb_geometries']))
            if len(list_loss) >= config['N_PACK_PER_BATCH'] or i == n_batch - 1:
                sum(list_loss).backward()
                optimizer.step()
                # Clear the gradients just applied; otherwise they would be
                # added into the gradients of every following batch.
                optimizer.zero_grad()
                list_loss.clear()

    def evaluate(dataset: GeoMolDataset, batch_token: str) -> float:
        """Return the molecule-weighted mean pack loss over ``dataset`` and log it as ``<batch_token>_loss``."""
        model.eval()

        list_n_mol = []
        list_loss = []
        iteration = tqdm(dataset, total=len(dataset)) if use_tqdm else dataset
        # Evaluation needs no autograd graph; skipping it saves memory and time.
        with torch.no_grad():
            for packed_mol_graphs, _, _, dft_geometry, _, extra_dict in iteration:
                list_loss.append(float(pack_loss(packed_mol_graphs, dft_geometry,
                                                 extra_dict['disturb_geometries'])))
                list_n_mol.append(packed_mol_graphs.n_mol)

        # Weight each pack's loss by its number of molecules.
        total_mol = sum(list_n_mol)
        loss = sum(pack_value * n_mol for pack_value, n_mol in zip(list_loss, list_n_mol)) / total_mol

        print(f'\t\t\tDISTANCE RMSE: {loss}')
        logs[-1][f'{batch_token}_loss'] = loss
        return loss

    # makedirs also creates a missing parent directory and tolerates an existing one.
    os.makedirs(MODEL_DICT_DIR, exist_ok=True)

    for epoch in range(1, config['EPOCH'] + 1):
        logs.append({'epoch': epoch})
        print()
        print(f'##### IN EPOCH {epoch} #####')
        print('\tCurrent LR: {:.3e}'.format(optimizer.param_groups[0]['lr']))
        print('\t\tTraining:')
        t0 = time.time()
        for _ in range(config['TRAIN']):
            train(train_set)
        t1 = time.time()
        print('\t\tEvaluating Train:')
        evaluate(train_set, 'train')
        print('\t\tEvaluating Validate:')
        m = evaluate(validate_set, 'validate')
        print('\t\tEvaluating Test:')
        evaluate(test_set, 'test')
        t2 = time.time()
        scheduler.step()

        print('\tTraining Time: {}'.format(int(t1 - t0)))
        print('\tEvaluating Time: {}'.format(int(t2 - t1)))
        logs[-1].update({'train_time': t1 - t0})
        logs[-1].update({'eval_time': t2 - t1})

        if m < best_metric:
            best_metric = m
            best_epoch = epoch
            print('\tSaving Model...')
            torch.save(model.state_dict(), f'{MODEL_DICT_DIR}/{token}.pkl')
        logs[-1].update({'best_epoch': best_epoch})
        save_log(logs, directory='force', tag=token)
