from ogb.lsc import PCQM4MDataset
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import Dataset, Subset, DataLoader
from pytorch_lightning import LightningDataModule
from typing import Sequence, Dict, Tuple, Optional, Callable
import numpy as np
import pandas as pd
from parser import SmilesParser, get_tokenizers
import config
from utils import get_random_split, Evaluator
from functools import partial
from graphormer.data import get_dataset as get_graph_dataset_info
from graphormer.collator import Batch as BatchedGraphData
from graphormer.collator import collator

# Description of the most recently requested dataset. Kept as a module-level
# name for backward compatibility; lookups go through `_dataset_dict_cache`.
dataset_dict = None

# Dataset descriptions that were already built, keyed by dataset name.
_dataset_dict_cache = {}


def get_smiles_dataset_info(dataset_name: str):
    """Build (once per name) and return the description of a SMILES dataset.

    The description is a dict with the keys ``'output_dim'``, ``'loss_fn'``,
    ``'metric'``, ``'metric_mode'``, ``'evaluator'``, ``'dataset'``,
    ``'vocab_size'``, ``'max_len'`` and ``'split'``.

    Results are cached per ``dataset_name``. A single shared slot would hand
    back the first dataset ever requested for every later name, so the cache
    is keyed by name.

    Args:
        dataset_name: one of ``'ogb_lsc'``, ``'qm9'``, ``'qm8'``, ``'qm7'`` or
            ``'freesolv'``.

    Returns:
        The description dict; repeated calls with the same name return the
        same object.

    Raises:
        NotImplementedError: if ``dataset_name`` is not supported.
    """
    global dataset_dict

    cached = _dataset_dict_cache.get(dataset_name)
    if cached is not None:
        dataset_dict = cached
        return cached

    if dataset_name == 'ogb_lsc':
        dataset = OGBPCQM4MSmilesDataset()
        info = {
            'output_dim': 1,
            'loss_fn': F.l1_loss,
            'metric': 'mae',
            'metric_mode': 'min',
            'evaluator': Evaluator('mae'),
            'dataset': dataset,
            'vocab_size': dataset.get_vocab_size(),
            'max_len': 150,
            # This dataset provides its own split.
            'split': dataset.get_idx_split()
        }
    elif dataset_name in ('qm9', 'qm8', 'qm7'):
        dataset = QMSmilesDataset(dataset_name)
        output_dim = {'qm9': 12, 'qm8': 16, 'qm7': 1}
        metric = {'qm9': 'multi_mae', 'qm8': 'mae', 'qm7': 'mae'}
        info = {
            'output_dim': output_dim[dataset_name],
            'loss_fn': F.l1_loss,
            'metric': metric[dataset_name],
            'metric_mode': 'min',
            'evaluator': Evaluator(metric[dataset_name]),
            'dataset': dataset,
            'vocab_size': dataset.get_vocab_size(),
            'max_len': 100,
            'split': get_random_split(len(dataset))
        }
    elif dataset_name == 'freesolv':
        dataset = PhyChemSmilesDataset(dataset_name)
        info = {
            'output_dim': 1,
            'loss_fn': F.mse_loss,
            'metric': 'rmse',
            'metric_mode': 'min',
            'evaluator': Evaluator('rmse'),
            'dataset': dataset,
            'vocab_size': dataset.get_vocab_size(),
            'max_len': 400,
            'split': get_random_split(len(dataset))
        }
    else:
        raise NotImplementedError(
            'Unsupported SMILES dataset: {!r}'.format(dataset_name)
        )

    _dataset_dict_cache[dataset_name] = info
    dataset_dict = info
    return info


class SmilesDataModule(LightningDataModule):
    """Lightning data module serving train/valid/test loaders of a SMILES dataset.

    The dataset description comes from ``get_smiles_dataset_info``. When it
    has a ``'split'`` entry, the dataset is partitioned into ``Subset`` objects
    using the ``'train'``, ``'valid'`` and ``'test'`` entries of that split.
    """

    def __init__(
        self,
        dataset_name: str,
        num_workers: int = 8,
        batch_size: int = 20,
        seed: int = 8,
        offline_feature_save_path: str = '',
        *args,
        **kwargs
    ):
        """
        Args:
            dataset_name: name forwarded to ``get_smiles_dataset_info``.
            num_workers: number of worker processes of every data loader.
            batch_size: batch size of every data loader.
            seed: accepted but not used by this class.
            offline_feature_save_path: file passed to ``torch.load``; the
                loaded object is handed to the dataset's
                ``load_saved_feature``. The empty string disables loading.
            *args, **kwargs: forwarded to ``LightningDataModule``.
        """
        super().__init__(*args, **kwargs)
        info = get_smiles_dataset_info(dataset_name)
        self.dataset_dict = info

        # NOTE(review): torch.load unpickles the file, so this path should
        # only ever point at trusted files.
        saved_feature = (
            torch.load(offline_feature_save_path)
            if offline_feature_save_path != '' else None
        )

        if 'split' in info:
            self.dataset = info['dataset']
            if saved_feature is not None:
                self.dataset.load_saved_feature(saved_feature)
            self.split = info['split']
            self.train_dataset, self.valid_dataset, self.test_dataset = (
                Subset(self.dataset, self.split[part])
                for part in ('train', 'valid', 'test')
            )

        self.num_workers = num_workers
        self.batch_size = batch_size

    def set_attn_bias(self, option: bool):
        """Set the ``need_attn_bias`` flag of the underlying dataset."""
        self.dataset.need_attn_bias = option

    def _make_loader(self, subset, shuffle: bool):
        # Settings shared by the three loaders; only the subset and the
        # shuffling flag differ between them.
        return DataLoader(
            dataset=subset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            collate_fn=parsed_smiles_collate_fn,
            pin_memory=True
        )

    def train_dataloader(self):
        """Return the shuffled loader over the training subset."""
        loader = self._make_loader(self.train_dataset, True)
        print('len(train_dataloader): {}'.format(len(loader)))
        return loader

    def val_dataloader(self):
        """Return the unshuffled loader over the validation subset."""
        loader = self._make_loader(self.valid_dataset, False)
        print('len(val_dataloader): {}'.format(len(loader)))
        return loader

    def test_dataloader(self):
        """Return the unshuffled loader over the test subset."""
        loader = self._make_loader(self.test_dataset, False)
        print('len(test_dataloader): {}'.format(len(loader)))
        return loader


class DistillationDataModule(LightningDataModule):
    """Lightning data module serving paired SMILES / graph batches.

    The SMILES dataset and the graph dataset of the same name are combined
    into a ``DistillationDataset``; the split stored in the SMILES dataset
    description (when present) defines the train/valid/test subsets.
    """

    def __init__(
        self,
        dataset_name: str,
        num_workers: int = 8,
        batch_size: int = 20,
        seed: int = 8,
        multi_hop_max_dist: int = 5,
        spatial_pos_max: int = 1024,
        *args,
        **kwargs
    ):
        """
        Args:
            dataset_name: name forwarded to both dataset-info lookups.
            num_workers: number of worker processes of every data loader.
            batch_size: batch size of every data loader.
            seed: accepted but not used by this class.
            multi_hop_max_dist: forwarded to the graph ``collator``.
            spatial_pos_max: forwarded to the graph ``collator``.
            *args, **kwargs: forwarded to ``LightningDataModule``.
        """
        super().__init__(*args, **kwargs)
        self.dataset_name = dataset_name

        smiles_info = get_smiles_dataset_info(dataset_name)
        graph_info = get_graph_dataset_info(dataset_name)
        self.smiles_dataset_dict = smiles_info
        self.graph_dataset_dict = graph_info
        self.smiles_dataset = smiles_info['dataset']
        self.graph_dataset = graph_info['dataset']
        self.dataset = DistillationDataset(self.smiles_dataset, self.graph_dataset)

        self.num_workers = num_workers
        self.batch_size = batch_size
        self.multi_hop_max_dist = multi_hop_max_dist
        self.spatial_pos_max = spatial_pos_max

        if 'split' in smiles_info:
            self.split = smiles_info['split']
            self.train_dataset, self.valid_dataset, self.test_dataset = (
                Subset(self.dataset, self.split[part])
                for part in ('train', 'valid', 'test')
            )

        # One collate function per modality, then a combined one that
        # collates both halves of every (smiles, graph) pair.
        self.smiles_collate_fn = parsed_smiles_collate_fn
        self.graph_collate_fn = partial(
            collator,
            max_node=graph_info['max_node'],
            multi_hop_max_dist=self.multi_hop_max_dist,
            spatial_pos_max=self.spatial_pos_max
        )
        self.collate_fn = distillation_collate_fn(
            self.smiles_collate_fn,
            self.graph_collate_fn
        )

    def _make_loader(self, subset, shuffle: bool):
        # Settings shared by the three loaders; only the subset and the
        # shuffling flag differ between them.
        return DataLoader(
            subset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            pin_memory=True
        )

    def train_dataloader(self):
        """Return the shuffled loader over the training subset."""
        loader = self._make_loader(self.train_dataset, True)
        print('len(train_dataloader): {}'.format(len(loader)))
        return loader

    def val_dataloader(self):
        """Return the unshuffled loader over the validation subset."""
        loader = self._make_loader(self.valid_dataset, False)
        print('len(val_dataloader): {}'.format(len(loader)))
        return loader

    def test_dataloader(self):
        """Return the unshuffled loader over the test subset."""
        loader = self._make_loader(self.test_dataset, False)
        print('len(test_dataloader): {}'.format(len(loader)))
        return loader


class ParsedSmilesData():
    """One tokenized molecule: a molecule token, its atom tokens, then its bond tokens.

    Attributes set on every instance: ``idx``, ``atoms``, ``bonds``,
    ``bonds_index``, ``ids``, ``pe``, ``pad_id``, ``y``, ``feature`` and
    ``length``. ``attnb`` / ``attnb_mask`` exist only when
    ``need_target_attn_bias`` is true, and ``atom_token_attn_mask`` only when
    ``need_atom_token_attn_mask`` is true; consumers test for them with
    ``hasattr``.
    """

    def __init__(
        self, idx: int,
        atoms: Sequence[int],
        bonds: Sequence[int],
        bonds_index: Sequence[Tuple[int, int]],
        y: Tensor,
        feature: Optional[Sequence[Tensor]] = None,
        mol_token_id: int = 0, # id of molecule token
        pad_id: int = 1, # id of padding token
        need_target_attn_bias: bool = False,
        need_atom_token_attn_mask: bool = False,
    ):
        """
        Args:
            idx: index of the molecule in its dataset.
            atoms: token ids of the atoms.
            bonds: token ids of the bonds.
            bonds_index: one pair per bond; both entries are used as row /
                column positions in the attention-bias matrix (presumably the
                token positions of the two bonded atoms -- confirm against
                the SMILES parser).
            y: target tensor of the molecule.
            feature: optional per-molecule feature tensors.
            mol_token_id: id of the molecule token put in front of the atoms.
            pad_id: id of the padding token, stored for the collate function.
            need_target_attn_bias: build ``attnb`` and ``attnb_mask``.
            need_atom_token_attn_mask: build ``atom_token_attn_mask``.
        """
        self.idx, self.atoms, self.bonds, self.bonds_index = idx, atoms, bonds, bonds_index
        num_atoms = len(atoms)

        # Build fresh lists so any sequence type (list, tuple, ...) is
        # accepted, as the annotations promise; `list + tuple` would raise.
        self.ids = [mol_token_id, *atoms, *bonds]
        # Positional pairs: (i, 0) for the molecule token and each atom,
        # followed by the index pair of each bond.
        self.pe = [(i, 0) for i in range(num_atoms + 1)] + list(bonds_index)
        self.pad_id = pad_id
        self.y = y
        self.feature = feature
        self.length = len(self.ids)

        if need_target_attn_bias:
            attnb = torch.zeros((self.length, self.length))
            for cnt, e in enumerate(self.bonds_index):
                # Position of this bond's token: after the molecule token
                # and all atom tokens.
                bond_ind = 1 + num_atoms + cnt
                # Symmetric links between the two indexed positions and
                # between each of them and the bond token.
                attnb[e[0], e[1]] = 1.0
                attnb[e[1], e[0]] = 1.0
                attnb[e[0], bond_ind] = 1.0
                attnb[bond_ind, e[0]] = 1.0
                attnb[e[1], bond_ind] = 1.0
                attnb[bond_ind, e[1]] = 1.0
            # Every row except the molecule-token row (row 0) is marked.
            attnb_mask = torch.zeros((self.length, self.length))
            attnb_mask[1: self.length, 0: self.length] = 1.0

            self.attnb = attnb
            self.attnb_mask = attnb_mask

        if need_atom_token_attn_mask:
            # Marks atom rows against the molecule-token and atom columns.
            atom_token_attn_mask = torch.zeros((self.length, self.length))
            atom_token_attn_mask[1 : num_atoms + 1, 0 : num_atoms + 1] = 1.0
            self.atom_token_attn_mask = atom_token_attn_mask

    def __len__(self):
        """Return the number of tokens (molecule token + atoms + bonds)."""
        return self.length


class SmilesDataset(Dataset):
    """Map-style dataset that parses SMILES strings into ``ParsedSmilesData``.

    Every item is parsed on access with ``SmilesParser`` and its tokens are
    mapped to ids through ``tk2id``. Targets are stored as a float32 array
    with one row per molecule, optionally standardized per column.
    """

    # Molecules with more atom tokens than this are skipped: ``__getitem__``
    # returns None for them and the collate function drops the entry.
    MAX_ATOMS = 148

    def __init__(self, smiles_data: Sequence[str], target: np.ndarray,
                 standardize: bool = False, id2tk: Sequence[str] = None,
                 tk2id: Dict[str, int] = None, dataset_name: str = None):
        """
        Args:
            smiles_data: SMILES strings, one per molecule.
            target: 2-D array of targets with one row per molecule.
            standardize: standardize every target column to zero mean and
                unit standard deviation (NaN entries are ignored when the
                statistics are computed).
            id2tk: optional id -> token sequence. When omitted, both
                tokenizers are obtained from ``get_tokenizers``.
            tk2id: optional token -> id mapping; used only together with
                ``id2tk`` and derived from it when omitted.
            dataset_name: forwarded to ``get_tokenizers``.
        """
        super(SmilesDataset, self).__init__()
        assert len(smiles_data) == target.shape[0]
        target = target.astype(np.float32)
        self.parser = SmilesParser()
        self.smiles_data = smiles_data

        if id2tk is None:
            self.id2tk, self.tk2id = get_tokenizers(smiles_data, dataset_name)
        else:
            if tk2id is None:
                # Without this, tk2id would stay None and the first
                # __getitem__ call would fail on the token lookup.
                tk2id = {tk: i for i, tk in enumerate(id2tk)}
            self.id2tk, self.tk2id = id2tk, tk2id
        print('Tokenizers obtained, {} tokens'.format(len(self.id2tk)))

        num_cols = target.shape[1]
        # Regression on standardized target variables if specified.
        if standardize:
            print('Standardizing target variables...')
            self.target = np.zeros(target.shape, dtype=np.float32)
            stddev = []
            for i in range(num_cols):
                col = target[:, i]
                col_mean, col_std = np.nanmean(col), np.nanstd(col)
                stddev.append(col_std)
                self.target[:, i] = (col - col_mean) / col_std
                print('Successfully standardized column {}, mean {:.4f}, std {:.4f}.'.format(
                    i, col_mean, col_std
                ))
            # A single-target dataset exposes a scalar, otherwise a list.
            self.stddev = stddev[0] if len(stddev) == 1 else stddev
        else:
            print('No standardization on target variable.')
            self.target = target
            self.stddev = 1.0

        self.need_attn_bias = False

    def __len__(self) -> int:
        """Return the number of molecules."""
        return len(self.smiles_data)

    def get_vocab_size(self) -> int:
        """Return the number of tokens known to the tokenizer."""
        return len(self.id2tk)

    def load_saved_feature(self, mol_feature):
        """Attach precomputed per-molecule features (used for distillation).

        Args:
            mol_feature: a tensor or a sequence of tensors whose first
                dimension indexes molecules. A single tensor is wrapped in a
                list. ``requires_grad`` is switched off on every tensor.
        """
        if isinstance(mol_feature, Tensor):
            mol_feature = [mol_feature] # Sequence[Tensor]
        self.mol_feature = mol_feature
        for t in self.mol_feature:
            t.requires_grad = False

        print('length of feature list: {}, feature size: {}'.format(
            len(self.mol_feature), mol_feature[0].shape[0]
        ))
        assert mol_feature[0].shape[0] == len(self.smiles_data)

    def __getitem__(self, ind: int):
        """Parse molecule ``ind``.

        Returns:
            A ``ParsedSmilesData``, or None when the molecule has more than
            ``MAX_ATOMS`` atom tokens.
        """
        atoms, bonds, bonds_index = self.parser(self.smiles_data[ind])
        if len(atoms) > self.MAX_ATOMS:
            # Skip before building the target tensor and features.
            return None

        atoms = [self.tk2id[tk] for tk in atoms]
        bonds = [self.tk2id[tk] for tk in bonds]
        y = torch.tensor(self.target[ind, :])

        feature = None
        if hasattr(self, 'mol_feature'):
            feature = [f[ind] for f in self.mol_feature]

        return ParsedSmilesData(
            idx=ind, atoms=atoms, bonds=bonds, bonds_index=bonds_index, y=y, feature=feature,
            mol_token_id=self.tk2id['<s>'], pad_id=self.tk2id['<pad>'],
            need_target_attn_bias=self.need_attn_bias,
            need_atom_token_attn_mask=True,
        )


class BatchedSmilesData:
    """Container for one collated batch of parsed SMILES items.

    ``feature``, ``target_attnb``, ``attnb_mask`` and
    ``atom_token_attn_mask`` may be None when the batch does not carry them.
    """

    def __init__(
        self,
        x: Tensor,
        pad_mask: Tensor,
        pe_index: Tensor,
        y: Tensor,
        feature: Optional[Sequence[Tensor]],
        target_attnb: Optional[Tensor],
        attnb_mask: Optional[Tensor],
        atom_token_attn_mask: Optional[Tensor],
    ):
        self.x = x
        self.pad_mask = pad_mask
        self.pe_index = pe_index
        self.y = y
        self.feature = feature
        self.target_attnb = target_attnb
        self.attnb_mask = attnb_mask
        self.atom_token_attn_mask = atom_token_attn_mask

    def to(self, device):
        """Move every tensor held by the batch to ``device`` and return self."""
        # These four fields are always moved.
        for name in ('x', 'pad_mask', 'pe_index', 'y'):
            setattr(self, name, getattr(self, name).to(device))

        if self.feature is not None:
            self.feature = [tensor.to(device) for tensor in self.feature]

        # The bias and its mask travel together, gated on the bias alone.
        if self.target_attnb is not None:
            self.target_attnb = self.target_attnb.to(device)
            self.attnb_mask = self.attnb_mask.to(device)

        if self.atom_token_attn_mask is not None:
            self.atom_token_attn_mask = self.atom_token_attn_mask.to(device)
        return self

    def __len__(self):
        """Return the size of the first dimension of ``x``."""
        return self.x.shape[0]


def _stack_padded_square(matrices, pad_lengths):
    """Zero-pad each matrix on the right/bottom by its pad length and stack them."""
    return torch.stack([
        F.pad(matrix, (0, pad, 0, pad))
        for matrix, pad in zip(matrices, pad_lengths)
    ])


def parsed_smiles_collate_fn(items: Sequence[Optional[ParsedSmilesData]]) -> BatchedSmilesData:
    """Collate parsed SMILES items into one padded batch.

    ``None`` entries (molecules the dataset skipped) are dropped. Token ids
    are padded with the ``pad_id`` of the first item, positional pairs with
    ``(0, 0)``, and the square attention tensors with zeros, all up to the
    longest item of the batch.

    Args:
        items: parsed items; entries may be None.

    Returns:
        A ``BatchedSmilesData``. ``pad_mask`` is True at padded positions.
        The optional fields are filled only when the first item carries the
        corresponding data.

    Raises:
        ValueError: if no valid item is left after dropping None entries.
    """
    items = [item for item in items if item is not None]
    if not items:
        # Without this check the failure would be an opaque
        # "max() arg is an empty sequence".
        raise ValueError(
            'parsed_smiles_collate_fn: no valid items to collate '
            '(the batch is empty or every entry is None)'
        )

    first = items[0]
    pad_id = first.pad_id
    lengths = [len(item) for item in items]
    max_len = max(lengths)
    pad_lengths = [max_len - length for length in lengths]

    # add paddings
    x = torch.tensor([
        item.ids + [pad_id] * pad for item, pad in zip(items, pad_lengths)
    ])
    pe_index = torch.tensor([
        item.pe + [(0, 0)] * pad for item, pad in zip(items, pad_lengths)
    ])
    pad_mask = torch.tensor([
        [False] * length + [True] * pad
        for length, pad in zip(lengths, pad_lengths)
    ])
    y = torch.stack([item.y for item in items])

    # stack target feature vectors
    feature = None
    if first.feature is not None:
        feature = [
            torch.stack([item.feature[i] for item in items])
            for i in range(len(first.feature))
        ]

    # extend and stack attention weight biases
    target_attnb, attnb_mask = None, None
    if hasattr(first, 'attnb'):
        target_attnb = _stack_padded_square(
            [item.attnb for item in items], pad_lengths)
        attnb_mask = _stack_padded_square(
            [item.attnb_mask for item in items], pad_lengths)

    batched_atom_token_attn_mask = None
    if hasattr(first, 'atom_token_attn_mask'):
        batched_atom_token_attn_mask = _stack_padded_square(
            [item.atom_token_attn_mask for item in items], pad_lengths)

    return BatchedSmilesData(
        x=x, pad_mask=pad_mask, pe_index=pe_index, y=y, feature=feature,
        target_attnb=target_attnb, attnb_mask=attnb_mask,
        atom_token_attn_mask=batched_atom_token_attn_mask
    )


class DistillationDataset(Dataset):
    """Pairs a SMILES dataset with a molecule graph dataset, index by index.

    Item ``i`` is the tuple ``(smiles_dataset[i], graph_dataset[i])``, so both
    datasets are expected to list the same molecules in the same order.
    """

    def __init__(
        self,
        smiles_dataset: Dataset,
        graph_dataset: Dataset,
    ):
        """
        Args:
            smiles_dataset: dataset providing the first element of each pair.
            graph_dataset: dataset providing the second element of each pair.
        """
        self.smiles_dataset = smiles_dataset
        self.graph_dataset = graph_dataset

    def __len__(self) -> int:
        """Return the number of indices that are valid in both datasets."""
        return min(len(self.smiles_dataset), len(self.graph_dataset))

    def __getitem__(self, ind: int):
        """Return the ``(smiles_item, graph_item)`` pair at ``ind``."""
        # Indices may arrive as numpy / tensor scalars (e.g. from a split);
        # normalize to a plain int before indexing both datasets.
        ind = int(ind)
        return self.smiles_dataset[ind], self.graph_dataset[ind]


class BatchedDistillationData:
    """A collated SMILES batch together with the matching collated graph batch."""

    def __init__(
        self,
        smiles_data: BatchedSmilesData,
        graph_data: BatchedGraphData,
    ):
        # Both halves must describe the same number of molecules.
        num_smiles, num_graphs = len(smiles_data), len(graph_data)
        assert num_smiles == num_graphs
        self.smiles_data, self.graph_data = smiles_data, graph_data

    def to(self, device):
        """Move both halves of the batch to ``device`` and return self."""
        for name in ('smiles_data', 'graph_data'):
            setattr(self, name, getattr(self, name).to(device))
        return self

    def __len__(self):
        """Return the batch size, taken from the SMILES half."""
        return len(self.smiles_data)


def distillation_collate_fn(smiles_collate_fn: Callable, graph_collate_fn: Callable) -> Callable:
    """Build a collate function for ``(smiles_item, graph_item)`` pairs.

    Args:
        smiles_collate_fn: collates the list of SMILES items of a batch.
        graph_collate_fn: collates the list of graph items of a batch.

    Returns:
        A function that takes a list of pairs and returns a
        ``BatchedDistillationData`` holding both collated halves.
    """
    def fn(items):
        # The SMILES dataset yields None for molecules it skips. Drop the
        # whole pair in that case: filtering only the SMILES half would
        # leave the two batches with different sizes.
        kept = [item for item in items if item[0] is not None]
        smiles_items = [item[0] for item in kept]
        graph_items = [item[1] for item in kept]
        batched_smiles_items = smiles_collate_fn(smiles_items)
        batched_graph_items = graph_collate_fn(graph_items)
        return BatchedDistillationData(
            smiles_data=batched_smiles_items,
            graph_data=batched_graph_items
        )
    return fn


class QMSmilesDataset(SmilesDataset):
    """SMILES dataset read from the CSV file configured for a QM dataset name.

    The SMILES strings come from the ``'smiles'`` column and one target
    column is read per task listed in ``config.dataset_tasks``. Targets are
    standardized for ``'qm9'`` and ``'qm7'`` only.
    """

    def __init__(self, dataset_name: str):
        frame = pd.read_csv(config.dataset_path[dataset_name])
        task_names = config.dataset_tasks[dataset_name]
        needs_standardization = dataset_name in ('qm9', 'qm7')
        smiles_column = frame['smiles']

        # One target column per task, in the configured task order.
        labels = np.zeros((frame.shape[0], len(task_names)))
        for column, task in enumerate(task_names):
            labels[:, column] = frame[task]

        super().__init__(
            smiles_data=smiles_column,
            target=labels,
            standardize=needs_standardization,
            dataset_name=dataset_name,
        )


class PhyChemSmilesDataset(SmilesDataset):
    """Single-target SMILES dataset read from the CSV file of a PhyChem dataset.

    Accepts ``'lipop'``, ``'freesolv'`` and ``'esol'``. Only the first task
    configured for the dataset is used and targets are not standardized.
    """

    def __init__(self, dataset_name: str):
        assert dataset_name in ('lipop', 'freesolv', 'esol')
        frame = pd.read_csv(config.dataset_path[dataset_name])
        first_task = config.dataset_tasks[dataset_name][0]
        smiles_column = frame['smiles']

        # A single target column, kept 2-D (one row per molecule).
        labels = np.zeros((len(smiles_column), 1))
        labels[:, 0] = frame[first_task]

        # NOTE(review): dataset_name is not forwarded to the base class, so
        # the tokenizers are requested with dataset_name=None -- confirm that
        # this is intended.
        super().__init__(
            smiles_data=smiles_column,
            target=labels,
            standardize=False
        )


class OGBPCQM4MSmilesDataset(SmilesDataset):
    """SMILES dataset built from OGB-LSC ``PCQM4MDataset`` in SMILES-only mode.

    The OGB dataset is read from ``./dataset``; element 0 of every record is
    used as the SMILES string and element 1 as the single target. Targets are
    not standardized.
    """

    def __init__(self):
        ogb_data = PCQM4MDataset(root='./dataset', only_smiles=True)
        self.ogb_data = ogb_data
        smiles_data = []
        target = []
        for i in range(len(ogb_data)):
            # Fetch each record once instead of indexing the dataset twice.
            record = ogb_data[i]
            smiles_data.append(record[0])
            target.append([record[1]])
        target = np.array(target)
        super(OGBPCQM4MSmilesDataset, self).__init__(
            smiles_data=smiles_data,
            target=target,
            standardize=False,
            dataset_name='ogb_lsc',
        )

    def get_idx_split(self):
        """Return the index split provided by the underlying OGB dataset."""
        return self.ogb_data.get_idx_split()