import os
import csv
import lmdb
import logging
import numpy as np
import pickle as pkl
import networkx as nx
import itertools
from scipy import linalg
from typing import Sequence, Dict
from dataclasses import dataclass, field
from omegaconf import MISSING
import torch
from fairseq import utils
from fairseq.data import FairseqDataset
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import FairseqTask, register_task

from ..utils import Alphabet, set_cpu_num
from ._data_process import PromptConvert

# Module-level logger, named after this module so its output can be
# configured through the standard logging hierarchy.
logger = logging.getLogger(__name__)

def load_GO_annot(filename):
    """Load the GO annotation table used by the GO function-prediction task.

    Reads ``<filename>/go/nrPDB-GO_2019.06.18_annot.tsv`` (tab-separated).
    For each ontology, in the order ``mf``, ``bp``, ``cc``, the file begins
    with a header line, a row of GO-term IDs, another header line and a row
    of GO-term names. After one more header line, every remaining non-blank
    row is ``<protein>\\t<mf terms>\\t<bp terms>\\t<cc terms>``, where each
    terms cell is a comma-separated (possibly empty) list of GO-term IDs.

    Args:
        filename: directory that contains the ``go/`` sub-directory.

    Returns:
        ``(prot2annot, goterms, gonames, counts)``:

        * ``prot2annot[prot][ont]`` -- float 0/1 indicator vector aligned
          with ``goterms[ont]``;
        * ``goterms[ont]`` / ``gonames[ont]`` -- term IDs / names read from
          the header rows;
        * ``counts[ont]`` -- per-term number of protein rows annotated with
          that term.

    Raises:
        ValueError: if a protein row references a term that is not listed
            in the corresponding ontology's header row.
    """
    onts = ['mf', 'bp', 'cc']
    goterms = {}
    gonames = {}
    path = os.path.join(filename, 'go/nrPDB-GO_2019.06.18_annot.tsv')
    with open(path, mode='r', newline='') as tsvfile:
        reader = csv.reader(tsvfile, delimiter='\t')

        # Header section: (header, term IDs, header, term names) per ontology.
        for ont in onts:
            next(reader, None)  # skip the header line
            goterms[ont] = next(reader)
            next(reader, None)  # skip the header line
            gonames[ont] = next(reader)

        # term -> column lookup per ontology, replacing O(n) list.index calls.
        # setdefault keeps the first column of a duplicated term, exactly
        # like list.index would.
        term_to_col = {}
        for ont in onts:
            lookup = {}
            for col, term in enumerate(goterms[ont]):
                lookup.setdefault(term, col)
            term_to_col[ont] = lookup

        next(reader, None)  # skip the per-protein table header
        counts = {ont: np.zeros(len(goterms[ont]), dtype=float) for ont in onts}
        prot2annot = {}
        for row in reader:
            if not row:
                continue  # tolerate blank lines (e.g. a trailing newline)
            prot, prot_goterms = row[0], row[1:]
            annot = {}
            for i, ont in enumerate(onts):
                lookup = term_to_col[ont]
                indices = []
                for goterm in prot_goterms[i].split(','):
                    if goterm == '':
                        continue
                    if goterm not in lookup:
                        raise ValueError(
                            f"unknown {ont} GO term {goterm!r} for protein {prot!r}"
                        )
                    indices.append(lookup[goterm])
                vec = np.zeros(len(goterms[ont]))
                vec[indices] = 1.0
                # Fancy-index '+=' is buffered: a term repeated within one
                # row is still counted only once for that row.
                counts[ont][indices] += 1.0
                annot[ont] = vec
            prot2annot[prot] = annot
    return prot2annot, goterms, gonames, counts


def load_name(filename, split):
    """Return the protein-chain IDs listed for a dataset split.

    Reads ``<filename>/go/nrPDB-GO_2019.06.18_<split>.txt``, skips its first
    row as a header, and collects the first comma-separated field of every
    remaining non-blank row.

    Args:
        filename: directory that contains the ``go/`` sub-directory.
        split: split name; ``'valid'`` is redirected to the ``'test'`` file.

    Returns:
        List of chain IDs in file order (empty for an empty file).
    """
    # NOTE(review): 'valid' is served from the test split, so any validation
    # metric is really a test metric -- confirm this is intentional.
    if split == 'valid':
        split = 'test'
    path = os.path.join(filename, f'go/nrPDB-GO_2019.06.18_{split}.txt')
    with open(path, 'r', newline='') as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=',')
        next(csv_reader, None)  # skip the header; tolerate an empty file
        # csv yields [] for blank lines, which have no ID to report.
        return [row[0] for row in csv_reader if row]


def load_name2sequence(filename):
    """Map protein-chain IDs to sequences read from the GO FASTA file.

    Reads ``<filename>/go/nrPDB-GO_2019.06.18_sequences.fasta``. The ID of a
    record is the first whitespace-separated token of its ``>`` header line
    (with every ``>`` character removed); its sequence is the concatenation
    of the stripped lines that follow, up to the next header. A repeated ID
    keeps only its last record.

    Args:
        filename: directory that contains the ``go/`` sub-directory.

    Returns:
        Dict mapping chain ID -> sequence string, in first-seen order.

    Raises:
        ValueError: if non-blank sequence data appears before the first
            header line.
    """
    # Collect fragments per ID and join once at the end; repeated string
    # '+=' per line is quadratic in the sequence length.
    fragments = {}
    current = None
    path = os.path.join(filename, 'go/nrPDB-GO_2019.06.18_sequences.fasta')
    with open(path, 'r') as f:
        for line in f:
            if line.startswith('>'):
                current = line.replace('>', '').split()[0]
                fragments[current] = []  # a repeated ID restarts its record
                continue
            chunk = line.strip()
            if current is None:
                if chunk:
                    raise ValueError(f"{path}: sequence data before the first '>' header")
                continue
            fragments[current].append(chunk)
    return {name: ''.join(parts) for name, parts in fragments.items()}


class GOFunctionDataset(FairseqDataset):
    """In-memory dataset of ``[sequence, label_vector]`` items for GO prediction.

    ``data[i][0]`` is the sequence; its ``len`` is used as the item size.
    ``data[i][1]`` is the label vector that :meth:`collater` stacks into a
    tensor.
    """

    # Prompt tokens handed to the batch converter when none are supplied.
    DEFAULT_PROMPT_TOKS = ('<seq>', '<crd>', '<ppi>')

    def __init__(self, data, alphabet, prompt_toks=None) -> None:
        """
        Args:
            data: indexable collection of ``(sequence, label_vector)`` pairs.
            alphabet: vocabulary passed to ``PromptConvert``.
            prompt_toks: optional prompt tokens; defaults to
                ``DEFAULT_PROMPT_TOKS``.
        """
        super().__init__()
        self.data = data
        self.alphabet = alphabet
        # Backward-compatible alias for the original misspelled attribute.
        self.alphebet = alphabet
        self.batch_converter = PromptConvert(alphabet)
        self.prompt_toks = list(
            self.DEFAULT_PROMPT_TOKS if prompt_toks is None else prompt_toks
        )
        # Use the module logger rather than print so output follows the
        # logging configuration.
        logger.info("Prompt tokens: %s", self.prompt_toks)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

    def collater(self, raw_batch: Sequence[Sequence]):
        """Collate items into ``(tokens, labels, prompt_toks)``.

        ``tokens`` is whatever ``PromptConvert`` returns for the batch's
        sequences; ``labels`` is an int64 tensor built from the stacked
        label vectors; ``prompt_toks`` is this dataset's prompt-token list.
        """
        sequences, labels = zip(*raw_batch)
        tokens = self.batch_converter(sequences, prompt_toks=self.prompt_toks)
        return tokens, torch.tensor(np.array(labels), dtype=torch.int64), self.prompt_toks

    def size(self, index):
        """Return the sequence length of item ``index``."""
        return len(self.data[index][0])

    def num_tokens(self, index):
        """Return the sequence length of item ``index``."""
        return len(self.data[index][0])

    def num_tokens_vec(self, indices):
        """Return :meth:`num_tokens` for every index in ``indices`` as an array."""
        return np.array([self.num_tokens(index) for index in indices])


@dataclass
class GOFunctionTaskConfig(FairseqDataclass):
    """Configuration for the ``sequence_multiclass_go`` task."""

    data: str = field(
        default=MISSING,
        metadata={"help": "dataset root directory; the task reads from its 'go' sub-directory"},
    )
    task_type: str = field(
        default='bp',
        metadata={"help": "GO ontology used as labels: 'mf', 'bp' or 'cc'"},
    )


@register_task("sequence_multiclass_go", dataclass=GOFunctionTaskConfig)
class GOFunctionTask(FairseqTask):
    """Multi-label GO-term prediction from protein sequences.

    ``cfg.task_type`` selects the ontology ('mf', 'bp' or 'cc') whose
    per-protein indicator vectors are used as labels; ``class_num`` is the
    number of GO terms in that ontology.
    """

    cfg: GOFunctionTaskConfig

    def __init__(self, cfg: GOFunctionTaskConfig, data_path, name2annot, name2sequence, class_num, alphabet):
        super().__init__(cfg)
        self.data_path = data_path          # directory passed to the module loaders
        self.alphabet = alphabet
        self.name2annot = name2annot        # chain ID -> {ontology: label vector}
        self.name2sequence = name2sequence  # chain ID -> sequence string
        self.class_num = class_num          # number of GO terms in cfg.task_type
        # NOTE(review): never updated in this class; presumably written by
        # validation/metric code elsewhere -- confirm.
        self.best_fmax = 0
        self.micro_aupr = 0

    @classmethod
    def setup_task(cls, cfg: GOFunctionTaskConfig, **kwargs):
        """Build the alphabet, load annotations and sequences, and create the task.

        Raises:
            ValueError: if ``cfg.data`` is empty or ``cfg.task_type`` is not
                an ontology returned by ``load_GO_annot``.
        """
        set_cpu_num(4)
        paths = utils.split_paths(cfg.data)
        # Explicit check instead of assert (asserts are stripped under -O).
        if not paths or not paths[0]:
            raise ValueError("cfg.data must name the dataset root directory")
        alphabet = Alphabet.build_alphabet()
        logger.info("Alphabet: %d types", len(alphabet))
        # Use the first entry of a path-separated data list (fairseq convention).
        # NOTE(review): the loaders append 'go/...' themselves, so files are
        # read from <data>/go/go/ -- confirm this matches the dataset layout.
        data_path = os.path.join(paths[0], 'go')
        name2annot, goterms, _gonames, _counts = load_GO_annot(data_path)
        # Fail with a clear message (instead of a bare KeyError) and before
        # the FASTA file is parsed.
        if cfg.task_type not in goterms:
            raise ValueError(
                f"unknown task_type {cfg.task_type!r}; expected one of {sorted(goterms)}"
            )
        name2sequence = load_name2sequence(data_path)
        class_num = len(goterms[cfg.task_type])

        return cls(cfg, data_path, name2annot, name2sequence, class_num, alphabet)

    def load_dataset(self, split: str, epoch=1, combine=False, **kwargs):
        """Load ``split`` into ``self.datasets[split]``.

        Each item is ``[sequence, label_vector]`` for one chain listed in the
        split file, using the ``cfg.task_type`` annotation as the label.
        ``epoch`` and ``combine`` are accepted for the FairseqTask interface
        and are unused.
        """
        names = load_name(self.data_path, split)
        data = [
            [self.name2sequence[name], self.name2annot[name][self.cfg.task_type]]
            for name in names
        ]
        self.datasets[split] = GOFunctionDataset(data, self.alphabet)

    @property
    def source_dictionary(self):
        """The sequence alphabet."""
        return self.alphabet

    @property
    def target_dictionary(self):
        """Also the sequence alphabet (labels are vectors, not tokens)."""
        return self.alphabet
