import os
import logging
import lmdb
import numpy as np
import pickle as pkl
import networkx as nx
import itertools
from scipy import linalg
from typing import Sequence, Dict
from dataclasses import dataclass, field
from omegaconf import MISSING
import torch
from fairseq import utils
from fairseq.data import FairseqDataset
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
import tensorboardX

from ..utils import Alphabet, set_cpu_num
from ._data_process import PromptConvert, MaskedConverter

logger = logging.getLogger(__name__)

class FusionDataset(FairseqDataset):
    """Round-robin multi-task dataset backed by LMDB stores.

    The dataset cycles through the tasks listed in ``self.current``
    ('mlm', 'crd', 'ppi'). :meth:`__getitem__` reads from the store of the
    task at ``self.current_idx``. :meth:`collater` builds a task-specific
    sample and then advances ``current_idx``, so consecutive batches come
    from consecutive tasks.

    NOTE(review): ``current_idx`` is plain instance state. With
    multi-process data loading each worker mutates its own copy, so the
    task cycle is per-worker. Confirm this is intended.

    Args:
        data_dir: Root directory containing ``string/physical-link-1024-new``,
            ``string/sequence``, ``af2db`` and ``uniref50`` LMDB directories.
        split: Split name. ``'valid'`` caps :meth:`__len__` at 100.
        alphabet: Alphabet handed to the per-task converters.
    """

    def __init__(self, data_dir, split, alphabet) -> None:
        super().__init__()
        self.data_dir = data_dir
        self.split = split
        self.alphabet = alphabet
        # Legacy misspelled attribute, kept so existing external readers keep working.
        self.alphebet = alphabet

        self.link_path = os.path.join(data_dir, 'string/physical-link-1024-new')
        self.sequence_path = os.path.join(data_dir, 'string/sequence')
        self.crd_path = os.path.join(data_dir, 'af2db')
        self.mlm_path = os.path.join(data_dir, 'uniref50')

        # NOTE(review): loss / position / prob are not read inside this class;
        # presumably consumed by external scheduling code. Confirm before removing.
        self.loss = np.array([1000, 1000, 1000])
        self.position = np.array([0, 0])
        self.prob = [1/3, 1/3, 1/3]

        # Task order of the round-robin. Entry i of prompt_toks / converter
        # belongs to task current[i].
        self.current = ['mlm', 'crd', 'ppi']
        # Prompt tokens are currently disabled (all empty). To enable them use
        # [['<seq>'], ['<crd>'], ['<ppi>']].
        self.prompt_toks = [[], [], []]
        self.converter = [MaskedConverter(alphabet), PromptConvert(alphabet), PromptConvert(alphabet)]
        self.current_idx = 0
        logger.info("FusionDataset prompt tokens: %s", self.prompt_toks)

    @staticmethod
    def _open_env(path):
        """Open an existing LMDB directory read-only and start a read transaction."""
        env = lmdb.open(path, create=False, subdir=True, readonly=True, lock=False)
        return env, env.begin(write=False)

    @staticmethod
    def _read_meta(txn):
        """Return ``(data_size, data_lens)`` read from the store's metadata keys."""
        size = int(txn.get('data_size'.encode()).decode())
        # NOTE: pickle is only safe because these stores are produced by our own
        # preprocessing; never point this at untrusted LMDB files.
        lens = pkl.loads(txn.get('data_lens'.encode()))
        return size, lens

    def open_lmdb(self):
        """Open all LMDB environments read-only and cache their metadata.

        Sets ``link_env``/``link_txn``, ``sequence_env``/``sequence_txn``,
        ``crd_env``/``crd_txn``, ``mlm_env``/``mlm_txn`` and the corresponding
        ``*_data_size`` / ``*_data_lens`` attributes. ``mlm_data_size`` is
        assigned last because its presence is used as the "already opened"
        flag.
        """
        self.link_env, self.link_txn = self._open_env(self.link_path)
        self.sequence_env, self.sequence_txn = self._open_env(self.sequence_path)
        self.link_data_size, self.link_data_lens = self._read_meta(self.link_txn)

        self.crd_env, self.crd_txn = self._open_env(self.crd_path)
        self.crd_data_size, self.crd_data_lens = self._read_meta(self.crd_txn)

        self.mlm_env, self.mlm_txn = self._open_env(self.mlm_path)
        self.mlm_data_size, self.mlm_data_lens = self._read_meta(self.mlm_txn)

    def _ensure_open(self):
        """Open the LMDB stores on first use; __init__ does not open them."""
        # NOTE(review): if this first runs in the parent process, forked loader
        # workers inherit the open LMDB handles. The LMDB docs advise against
        # using an environment across fork. Verify this is safe with lock=False.
        if not hasattr(self, 'mlm_data_size'):
            self.open_lmdb()

    def __getitem__(self, index):
        """Fetch one raw item for the currently active task.

        ``index`` is wrapped modulo the active store's size, because
        :meth:`__len__` reports the size of the largest store.

        Returns:
            'ppi': ``(graph, sequences)``, where ``sequences`` follows
            ``graph.nodes`` order.
            'crd': ``(sequence, coord)``.
            'mlm': the sequence string.

        Raises:
            ValueError: if the active task name is unknown.
        """
        self._ensure_open()
        task = self.current[self.current_idx]
        if task == 'ppi':
            index = index % self.link_data_size
            graph = pkl.loads(self.link_txn.get(str(index).encode()))
            sequences = [self.sequence_txn.get(node.encode()).decode() for node in graph.nodes]
            return graph, sequences
        if task == 'crd':
            index = index % self.crd_data_size
            data = pkl.loads(self.crd_txn.get(str(index).encode()))
            return data['sequence'], data['coord']
        if task == 'mlm':
            index = index % self.mlm_data_size
            return self.mlm_txn.get(str(index).encode()).decode()
        raise ValueError(f"Unknown task {task!r}")

    def __len__(self):
        """Return 100 for the 'valid' split, else the size of the largest store."""
        self._ensure_open()
        if self.split == 'valid':
            return 100
        return max(self.link_data_size, self.crd_data_size, self.mlm_data_size)

    def collater(self, raw_batch: Sequence[Dict]):
        """Collate items of the active task, then advance to the next task.

        Returns:
            ``(payload, task_name)``, where ``payload`` is:
            'ppi': ``(link_tokens, link_targets, prompt_toks)``.
            'crd': ``(tokens, coords, prompt_toks)``.
            'mlm': ``(origin_tokens, masked_tokens, masked_targets, prompt_toks)``.

        Raises:
            ValueError: if the active task name is unknown.
        """
        task = self.current[self.current_idx]
        prompt_toks = self.prompt_toks[self.current_idx]
        convert = self.converter[self.current_idx]
        if task == 'ppi':
            graphs, sequences = zip(*raw_batch)
            sequences = list(itertools.chain(*sequences))
            link_targets = [np.triu(nx.adjacency_matrix(graph).todense()) for graph in graphs]
            num_seqs = len(sequences)
            # Block-diagonal upper-triangular adjacency over all graphs in the
            # batch. Subtracting tril(ones) shifts every entry on/below the
            # diagonal by -1 (presumably an ignore marker for the loss; confirm).
            link_targets = torch.tensor(
                linalg.block_diag(*link_targets) - np.tril(np.ones((num_seqs, num_seqs), dtype=np.int8))
            )
            link_tokens = convert(sequences, prompt_toks=prompt_toks)
            sample = (link_tokens, link_targets, prompt_toks), task
        elif task == 'crd':
            sequences, coords = zip(*raw_batch)
            tokens = convert(sequences, prompt_toks=prompt_toks)
            sample = (tokens, coords, prompt_toks), task
        elif task == 'mlm':
            origin_tokens, masked_tokens, masked_targets = convert(raw_batch, prompt_toks=prompt_toks)
            sample = (origin_tokens, masked_tokens, masked_targets, prompt_toks), task
        else:
            raise ValueError(f"Unknown task {task!r}")
        # Cycle over however many tasks are configured, not a hard-coded 3.
        self.current_idx = (self.current_idx + 1) % len(self.current)
        return sample

    def size(self, index):
        """Return a constant size of 1024 for every index."""
        return 1024

    def num_tokens(self, index):
        """Return a constant token count of 1024 for every index."""
        return 1024

    def num_tokens_vec(self, indices):
        """Return :meth:`num_tokens` for each index as a numpy array."""
        return np.array([self.num_tokens(index) for index in indices])


@dataclass
class FusionTaskConfig(FairseqDataclass):
    """Configuration for the ``fusion`` task."""

    # Data root handed to the dataset; omegaconf's MISSING marks it as required.
    data: str = MISSING


@register_task("fusion", dataclass=FusionTaskConfig)
class FusionTask(FairseqTask):
    """Fairseq task serving :class:`FusionDataset` splits.

    Batches from the dataset cycle through the 'mlm', 'crd' and 'ppi' data
    sources; this task shares a single alphabet for source and target.
    """

    cfg: FusionTaskConfig

    def __init__(self, cfg: FusionTaskConfig, alphabet):
        super().__init__(cfg)
        self.alphabet = alphabet

    @classmethod
    def setup_task(cls, cfg: FusionTaskConfig, **kwargs):
        """Validate ``cfg.data``, build the alphabet and instantiate the task.

        Raises:
            ValueError: if ``cfg.data`` contains no non-empty data path.
        """
        paths = utils.split_paths(cfg.data)
        # Explicit check instead of `assert`, which is stripped under `python -O`
        # and could never fail anyway (split_paths always returns >= 1 element).
        if not any(paths):
            raise ValueError(f"No data path given in cfg.data={cfg.data!r}")
        alphabet = Alphabet.build_alphabet()
        logger.info("Alphabet: %d types", len(alphabet))
        return cls(cfg, alphabet)

    def load_dataset(self, split: str, epoch=1, combine=False, **kwargs):
        """Create the :class:`FusionDataset` for ``split``.

        If ``cfg.data`` lists several paths, one is chosen per epoch in
        round-robin order (identical to before for a single path).
        ``combine`` and extra kwargs are accepted for API compatibility and
        ignored.
        """
        paths = utils.split_paths(self.cfg.data)
        data_path = paths[(epoch - 1) % len(paths)]
        self.datasets[split] = FusionDataset(data_path, split, self.alphabet)

    @property
    def source_dictionary(self):
        """The shared alphabet used for inputs."""
        return self.alphabet

    @property
    def target_dictionary(self):
        """The shared alphabet used for targets."""
        return self.alphabet


