import ast
import copy
import os
from typing import Any, Callable, Dict, List, Tuple, Union

import numpy as np
from rdkit import RDLogger
from rdkit import Chem
from rdkit.Chem.rdchem import Mol as RDMol
from rdkit.Chem import Descriptors
from ruamel.yaml import YAML
import torch
from torch import Tensor
import torch.nn as nn
import torch_geometric.data as gd
from torch.distributions.dirichlet import Dirichlet

from gflownet.algo.advantage_actor_critic import A2C
from gflownet.algo.envelope_q_learning import EnvelopeQLearning
from gflownet.algo.soft_q_learning import SoftQLearning
from gflownet.algo.multiobjective_reinforce import MultiObjectiveReinforce
from gflownet.algo.trajectory_balance import TrajectoryBalance
from gflownet.data.qm9 import QM9Dataset
from gflownet.envs.graph_building_env import GraphBuildingEnv
from gflownet.envs.mol_building_env import MolBuildingEnvContext
from gflownet.models.graph_transformer import GraphTransformerEnvelopeQL
from gflownet.models.graph_transformer import GraphTransformerGFN
import gflownet.models.mxmnet as mxmnet
from gflownet.train import FlatRewards
from gflownet.train import GFNTask
from gflownet.train import GFNTrainer
from gflownet.train import RewardScalar
from gflownet.utils.transforms import thermometer
from gflownet.utils import sascore
from gflownet.tasks.seh_frag_moo import MultiObjectiveStatsHook  # TODO: refactor this


def _safe(f, x, default):
    try:
        return f(x)
    except Exception:
        return default


class QM9MOODataset(QM9Dataset):
    """QM9 dataset whose items carry one target value plus three extra objectives.

    Each item is ``(mol, [target_value, logp_reward, sa_reward, molwt_reward])``.
    ``target_value`` is read from the ``self.target`` column of ``self.df``, and the
    other three values come from :meth:`compute_other_flat_rewards`.
    """

    def compute_other_flat_rewards(self, mol):
        """Return the ``(logp, sa, molwt)`` rewards for an RDKit molecule.

        - logp:  ``exp(-(MolLogP - 2.5)**2 / 2)``, a bump centred at logP = 2.5.
        - sa:    ``(10 - SA score) / 9``, mapping the SA score onto a [0-1] reward.
        - molwt: ``exp(-(MolWt - 105)**2 / 150)``, a bump centred at 105.

        If a descriptor call raises, ``_safe`` substitutes a fallback value:
        logP=0, SA=10 and MolWt=1000.
        """
        log_p = _safe(Descriptors.MolLogP, mol, 0)
        sa_score = _safe(sascore.calculateScore, mol, 10)
        mol_weight = _safe(Descriptors.MolWt, mol, 1000)

        logp_reward = np.exp(-(log_p - 2.5)**2 / 2)
        sa_reward = (10 - sa_score) / 9  # Turn into a [0-1] reward
        molwt_reward = np.exp(-(mol_weight - 105)**2 / 150)
        return logp_reward, sa_reward, molwt_reward

    def __getitem__(self, idx):
        """Return ``(mol, [target, logp, sa, molwt])`` for the ``idx``-th selected row."""
        row = self.idcs[idx]
        mol = Chem.MolFromSmiles(self.df['SMILES'][row])
        target_value = self.df[self.target][row]
        return mol, [target_value, *self.compute_other_flat_rewards(mol)]


class QM9GapMOOTask(GFNTask):
    """This class captures conditional information generation and reward transforms.

    Four objectives are scored per molecule:
    - the QM9 ``gap`` property, predicted by a pretrained MXMNet proxy;
    - the logP, SA and molecular-weight rewards from
      ``QM9MOODataset.compute_other_flat_rewards``.

    The conditional information is an inverse temperature ``beta`` together with a
    Dirichlet-sampled preference vector over the objectives.

    NOTE(review): ``self.rng`` is used by :meth:`sample_conditional_information` but is
    not set in this class. It is presumably assigned from outside (e.g. by the
    trainer / data-loading code) before sampling; confirm.
    """
    def __init__(self, dataset: QM9MOODataset, temperature_distribution: str, temperature_parameters: Tuple[float],
                 wrap_model: Callable[[nn.Module], Tuple[nn.Module, torch.device]] = None):
        """
        Args:
            dataset: provides ``compute_other_flat_rewards`` for the non-gap objectives.
            temperature_distribution: ``'gamma'``, ``'uniform'`` or ``'beta'``. This is
                the distribution ``beta`` is sampled from.
            temperature_parameters: positional parameters forwarded to the matching
                ``numpy.random.Generator`` method.
            wrap_model: callable taking the loaded proxy model and returning a
                ``(model, device)`` pair. If ``None``, the model is used unwrapped on
                the device it was moved to.
        """
        self._wrap_model = wrap_model
        self.models = self.load_task_models()
        self.dataset = dataset
        self.temperature_sample_dist = temperature_distribution
        self.temperature_dist_params = temperature_parameters
        self.number_of_objectives = 4

    def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards:
        """Wrap ``y`` as a float32 ``FlatRewards`` tensor (no rescaling)."""
        return FlatRewards(torch.as_tensor(y).float())

    def inverse_flat_reward_transform(self, rp):
        """Identity: :meth:`flat_reward_transform` does not rescale values."""
        return rp

    def load_task_models(self):
        """Load the pretrained MXMNet gap proxy.

        Returns:
            ``{'mxmnet_gap': model}``. As a side effect, sets ``self.device`` to the
            device that model inputs should be sent to.
        """
        gap_model = mxmnet.MXMNet(mxmnet.Config(128, 6, 5.0))
        # TODO: this path should be part of the config?
        # Load to CPU first so that a checkpoint saved from a GPU can be read on any
        # machine. load_state_dict copies the weights into the model's own parameters.
        state_dict = torch.load('../data/qm9/mxmnet_gap_model.pt', map_location='cpu')
        gap_model.load_state_dict(state_dict)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        gap_model.to(device)
        if self._wrap_model is None:
            self.device = device
        else:
            gap_model, self.device = self._wrap_model(gap_model)
        return {'mxmnet_gap': gap_model}

    def sample_conditional_information(self, n):
        """Sample ``n`` conditioning vectors.

        Returns:
            A dict with the following keys:
            - ``'beta'``: float32 tensor of ``n`` inverse temperatures.
            - ``'preferences'``: ``(n, number_of_objectives)`` Dirichlet(1.5) weights.
            - ``'encoding'``: the 32-bin thermometer encoding of beta, concatenated
              with the preferences along dim 1.

        Raises:
            ValueError: if ``temperature_sample_dist`` is not 'gamma', 'uniform' or 'beta'.
        """
        if self.temperature_sample_dist == 'gamma':
            beta = self.rng.gamma(*self.temperature_dist_params, n).astype(np.float32)
        elif self.temperature_sample_dist == 'uniform':
            beta = self.rng.uniform(*self.temperature_dist_params, n).astype(np.float32)
        elif self.temperature_sample_dist == 'beta':
            beta = self.rng.beta(*self.temperature_dist_params, n).astype(np.float32)
        else:
            raise ValueError(f'Unknown temperature distribution: {self.temperature_sample_dist!r}; '
                             "expected one of 'gamma', 'uniform', 'beta'")
        beta_t = torch.tensor(beta)
        beta_enc = thermometer(beta_t, 32, 0, 32)  # TODO: hyperparameters
        m = Dirichlet(torch.FloatTensor([1.5] * self.number_of_objectives))
        preferences = m.sample([n])
        encoding = torch.cat([beta_enc, preferences], 1)
        return {'beta': beta_t, 'encoding': encoding, 'preferences': preferences}

    def cond_info_to_reward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar:
        """Scalarize flat rewards into a single reward per row.

        The rewards are summed with ``cond_info['preferences']`` as weights along
        dim 1, and the result is raised to the power ``cond_info['beta']``.
        ``flat_reward`` may also be given as a list of tensors (stacked) or a list of
        numbers (converted with ``torch.tensor``).
        """
        if isinstance(flat_reward, list):
            if isinstance(flat_reward[0], Tensor):
                flat_reward = torch.stack(flat_reward)
            else:
                flat_reward = torch.tensor(flat_reward)
        scalar_reward = (flat_reward * cond_info['preferences']).sum(1)
        return scalar_reward**cond_info['beta']

    def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]:
        """Score molecules on all objectives.

        Returns:
            A pair ``(flat_rewards, is_valid)``.
            - ``is_valid`` is a bool tensor with one entry per input molecule. An entry
              is False where ``mxmnet.mol2graph`` returned ``None``.
            - ``flat_rewards`` has one row per *valid* molecule, with columns
              ``[gap, logp, sa, molwt]``.
        """
        graphs = [mxmnet.mol2graph(i) for i in mols]  # type: ignore[attr-defined]
        is_valid = torch.tensor([i is not None for i in graphs]).bool()
        if not is_valid.any():
            return FlatRewards(torch.zeros((0, self.number_of_objectives))), is_valid
        batch = gd.Batch.from_data_list([i for i in graphs if i is not None])
        batch.to(self.device)
        # NOTE(review): dividing by HAR2EV presumably converts the proxy's eV output to Hartree; confirm.
        preds = self.models['mxmnet_gap'](batch).reshape((-1,)).data.cpu() / mxmnet.HAR2EV  # type: ignore[attr-defined]
        # Replace NaN proxy outputs with 1 so that NaN does not propagate into the rewards.
        preds[preds.isnan()] = 1
        preds = self.flat_reward_transform(preds).clip(1e-4, 2).reshape((-1, 1))

        # Explicit float32 so the concatenation below does not rely on dtype promotion.
        other_flats = torch.as_tensor(
            [self.dataset.compute_other_flat_rewards(m) for m, v in zip(mols, is_valid) if v.item()],
            dtype=torch.float32)
        flat_rewards = torch.cat([preds, other_flats], 1).float()
        return FlatRewards(flat_rewards), is_valid


class QM9GapMOOTrainer(GFNTrainer):
    """Trainer for the multi-objective QM9 gap task.

    The algorithm is selected by ``hps['algo']``, which must be one of 'TB', 'SQL',
    'A2C', 'MOREINFORCE' or 'MOQL'.
    """
    def default_hps(self) -> Dict[str, Any]:
        """Return the default hyperparameters for this trainer."""
        return {
            'algo': 'TB',
            'bootstrap_own_reward': False,
            'learning_rate': 1e-4,
            'global_batch_size': 64,
            'num_emb': 128,
            'num_layers': 4,
            'tb_epsilon': None,
            'illegal_action_logreward': -50,
            'reward_loss_multiplier': 1,
            'temperature_sample_dist': 'uniform',
            'temperature_dist_params': '(.5, 32)',
            'weight_decay': 1e-8,
            'num_data_loader_workers': 8,
            'momentum': 0.9,
            'adam_eps': 1e-8,
            'lr_decay': 20000,
            'Z_lr_decay': 20000,
            'clip_grad_type': 'norm',
            'clip_grad_param': 10,
            'random_action_prob': .001,
            'sampling_tau': 0.,
        }

    def setup(self):
        """Build the env, datasets, algorithm, model, optimizers, task and hooks from ``self.hps``.

        Raises:
            ValueError: if ``hps['algo']`` is not a supported algorithm name.
        """
        hps = self.hps
        RDLogger.DisableLog('rdApp.*')
        self.rng = np.random.default_rng(142857)
        self.env = GraphBuildingEnv()
        # Conditioning size: 32 thermometer bins for beta + one preference weight per objective (4).
        self.ctx = MolBuildingEnvContext(['H', 'C', 'N', 'F', 'O'], num_cond_dim=32 + 4)
        self.training_data = QM9MOODataset(hps['qm9_h5_path'], train=True, target='gap')
        self.test_data = QM9MOODataset(hps['qm9_h5_path'], train=False, target='gap')

        # Parse tb_epsilon (which may arrive as a string from config) *before* building the
        # algorithm. The algorithm receives ``hps`` and may read this value in its constructor.
        eps = hps['tb_epsilon']
        hps['tb_epsilon'] = ast.literal_eval(eps) if isinstance(eps, str) else eps

        algo_classes = {
            'TB': TrajectoryBalance,
            'SQL': SoftQLearning,
            'A2C': A2C,
            'MOREINFORCE': MultiObjectiveReinforce,
            'MOQL': EnvelopeQLearning,
        }
        if hps['algo'] not in algo_classes:
            raise ValueError(f"Unknown algo {hps['algo']!r}; expected one of {sorted(algo_classes)}")
        self.algo = algo_classes[hps['algo']](self.env, self.ctx, self.rng, hps, max_nodes=9)

        if hps['algo'] == 'MOQL':
            model = GraphTransformerEnvelopeQL(self.ctx, num_emb=hps['num_emb'], num_layers=hps['num_layers'],
                                               num_objectives=4)
        else:
            model = GraphTransformerGFN(self.ctx, num_emb=hps['num_emb'], num_layers=hps['num_layers'])

        self.model = model
        # Separate Z parameters from non-Z to allow for LR decay on the former
        Z_params = list(model.logZ.parameters())
        Z_param_ids = {id(p) for p in Z_params}
        non_Z_params = [p for p in self.model.parameters() if id(p) not in Z_param_ids]
        self.opt = torch.optim.Adam(non_Z_params, hps['learning_rate'], (hps['momentum'], 0.999),
                                    weight_decay=hps['weight_decay'], eps=hps['adam_eps'])
        self.opt_Z = torch.optim.Adam(Z_params, hps['learning_rate'], (0.9, 0.999))
        # The LR halves every `lr_decay` (resp. `Z_lr_decay`) steps.
        self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lambda steps: 2**(-steps / hps['lr_decay']))
        self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR(self.opt_Z, lambda steps: 2**(-steps / hps['Z_lr_decay']))

        self.sampling_tau = hps['sampling_tau']
        if self.sampling_tau > 0:
            # Separate copy that `step` updates as a moving average of `self.model`.
            self.sampling_model = copy.deepcopy(model)
        else:
            self.sampling_model = self.model
        self.offline_ratio = hps.get('offline_ratio', 0.5)

        self.task = QM9GapMOOTask(self.training_data, hps['temperature_sample_dist'],
                                  ast.literal_eval(hps['temperature_dist_params']), wrap_model=self._wrap_model_mp)
        self.mb_size = hps['global_batch_size']
        self.clip_grad_param = hps['clip_grad_param']
        self.clip_grad_callback = {
            'value': (lambda params: torch.nn.utils.clip_grad_value_(params, self.clip_grad_param)),
            'norm': (lambda params: torch.nn.utils.clip_grad_norm_(params, self.clip_grad_param)),
            'none': (lambda x: None)
        }[hps['clip_grad_type']]
        shook = MultiObjectiveStatsHook(512, self.hps['log_dir'])
        # Optional: set `shook.compute_hsri = True` here to enable the hook's HSRI computation.
        self.sampling_hooks.append(shook)
        self.algo.task = self.task

    def step(self, loss: Tensor):
        """Backpropagate ``loss`` and apply one optimizer / scheduler step.

        Gradient clipping is applied to each parameter tensor separately, so 'norm'
        clipping bounds each tensor's gradient norm rather than the global norm.
        When ``sampling_tau > 0``, the sampling model is then updated as
        ``b <- tau * b + (1 - tau) * a`` from the trained model's parameters.
        """
        loss.backward()
        for i in self.model.parameters():
            self.clip_grad_callback(i)
        self.opt.step()
        self.opt.zero_grad()
        self.opt_Z.step()
        self.opt_Z.zero_grad()
        self.lr_sched.step()
        self.lr_sched_Z.step()
        if self.sampling_tau > 0:
            for a, b in zip(self.model.parameters(), self.sampling_model.parameters()):
                b.data.mul_(self.sampling_tau).add_(a.data * (1 - self.sampling_tau))


def main():
    """Example run: multi-objective QM9 with Soft Q-Learning on CUDA.

    Base hyperparameters are read from ``qm9_moo.yaml`` next to this file. The log
    directory, algorithm and SQL alpha are then overridden before the trainer is
    built and run.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(here, 'qm9_moo.yaml')
    loader = YAML(typ="safe", pure=True)
    with open(config_path, 'r') as stream:
        base_hps = loader.load(stream)

    overrides = {
        'log_dir': '/scratch/logs/qm9_gap_mxmnet_moo/run_sql_0/',
        'algo': 'SQL',
        'sql_alpha': 0.01,
    }
    trainer = QM9GapMOOTrainer({**base_hps, **overrides}, torch.device('cuda'))
    trainer.verbose = True
    trainer.run()


if __name__ == '__main__':
    main()
