import pytorch_lightning as pl
import os, torch, json
import pandas as pd
import csv
from collections import Counter, defaultdict
from functools import partial
from data.io import read_fasta, read_meta
from data.vocab import load_esm_alphabet
from data.datasets import MultiTaskClassificationDataset, PairwiseClassificationDataset, PairwiseAlnClassificationDataset, StructsPairwiseClassificationDataset
# from data.utils import default_lm_collate_func, build_fake_pairwise_dataset, build_fake_multisource_batch_dataset, multiprot_lm_collate_func
# from torch.utils.data import DataLoader, random_split
from data.utils import structure_pairs_collate_func
import numpy as np
from data.data_modules.base_dm import ProteinGISAIDDataModule
from data.data_modules import register_dm
from utils.args import str2bool
from esm.data import BatchConverter, CoordBatchConverter
from esm.inverse_folding.multichain_util import load_structure
from biotite.structure.residues import get_residues
from biotite.sequence import ProteinSequence
from tqdm import tqdm

@register_dm("protein_classifier")
class ProteinClassifierDataModule(ProteinGISAIDDataModule):
    """Data module for multi-task classification of protein sequences.

    Sequences are read from FASTA files with ``read_fasta`` and combined with
    the metadata returned by ``read_meta(args.meta_data_path)``.  The
    classification tasks are the names listed in ``args.labels``.
    """

    def __init__(self, args, vocab=None):
        """Initialise the module.

        Args:
            args: Argument namespace, forwarded to the base class.
            vocab: Optional ready-made alphabet.  When ``None`` the alphabet
                named by ``self.args.vocab`` is loaded with
                ``load_esm_alphabet``.
        """
        super().__init__(args)
        if vocab is not None:
            self.vocab = vocab
        else:
            self.vocab = load_esm_alphabet(self.args.vocab)

    @classmethod
    def add_argparse_args(cls, parent_parser):
        """Add this module's options to *parent_parser* and return the parser."""
        parent_parser = super(ProteinClassifierDataModule, cls).add_argparse_args(parent_parser)
        parent_parser.add_argument('--meta_data_path', type=str, default="", help="Load location information.")
        parent_parser.add_argument('--labels', nargs="+", type=str, default=None)
        parent_parser.add_argument('--binary', type=str2bool, default="true")

        return parent_parser

    def setup_properties(self, ):
        """Return the property names, i.e. the task names in ``args.labels``."""
        return self.args.labels

    def _make_dataset(self, fasta_path, **dataset_kwargs):
        """Build a ``MultiTaskClassificationDataset`` from the FASTA file at *fasta_path*.

        Extra keyword arguments are passed through to the dataset constructor.
        """
        return MultiTaskClassificationDataset(
            read_fasta(fasta_path),
            self.vocab,
            meta_data=read_meta(self.args.meta_data_path),
            classification_tasks=self.args.labels,
            binary=self.args.binary,
            **dataset_kwargs
        )

    def build_training_set(self, properties, properties_dict):
        """Build the dataset from ``args.data_path`` and register its label vocabularies.

        Args:
            properties: Iterable of label names to index.
            properties_dict: Mapping filled in place with ``label -> {value: index}``.

        Returns:
            The full ``MultiTaskClassificationDataset``.
        """
        full_dataset = self._make_dataset(self.args.data_path)
        self.setup_properties_dict(full_dataset, properties, properties_dict)
        return full_dataset

    def build_predicting_set(self, pred_data_path, properties, properties_dict):
        """Build the prediction dataset (constructed with ``predict=True``)."""
        return self._make_dataset(pred_data_path, predict=True)

    def build_testing_set(self, test_data_path, properties, properties_dict):
        """Build the test set, keeping only samples whose ``"Host"`` value is known.

        Returns:
            A plain list of the retained dataset items.
        """
        dataset = self._make_dataset(test_data_path)
        # NOTE(review): the filter is hard-wired to the "Host" task, so it
        # raises KeyError when "Host" is not among the configured labels and
        # ignores every other label -- confirm whether it should follow
        # ``args.labels`` instead.
        return [x for x in dataset if x["Host"] in properties_dict["Host"]]

    def load_properties_from_config(self, model_config, properties_dict):
        """Copy each ``<label>_dict`` attribute of *model_config* into *properties_dict*."""
        for label in self.args.labels:
            properties_dict[label] = getattr(model_config, "%s_dict" % label)
        return properties_dict

    def setup_properties_dict(self, full_dataset, properties, properties_dict):
        """Index the distinct values of every label found in *full_dataset*.

        For each label the value list is stored as ``args.<label>_vocab`` and
        the value-to-index mapping as ``args.<label>_dict`` and in
        ``properties_dict[label]``.  A label whose only value is ``None`` is
        skipped.
        """
        for label in properties:
            # De-duplicate in first-seen order.  A plain ``set`` yields an
            # order that depends on string hash randomization, so the label
            # indices would differ between runs and between separately
            # launched processes.
            vocab = list(dict.fromkeys(x[label] for x in full_dataset))
            if len(vocab) == 1 and vocab[0] is None:
                continue
            label2index = {value: idx for idx, value in enumerate(vocab)}
            properties_dict[label] = label2index
            setattr(self.args, "%s_vocab" % label, vocab)
            setattr(self.args, "%s_dict" % label, label2index)

@register_dm("hi_regression")
class PairwiseRegressionDataModule(ProteinGISAIDDataModule):
    """Data module for pairwise (virus, vaccine/reference) prediction tasks.

    Pairs are listed in CSV index files; the datasets are
    ``PairwiseClassificationDataset`` objects built from those pairs and the
    sequences read from ``args.data_path``.
    """

    def __init__(self, args, vocab=None):
        """Initialise the module.

        Args:
            args: Argument namespace, forwarded to the base class.
            vocab: Optional ready-made alphabet.  When ``None`` the alphabet
                named by ``self.args.vocab`` is loaded with
                ``load_esm_alphabet``.
        """
        super().__init__(args)
        if vocab is not None:
            self.vocab = vocab
        else:
            self.vocab = load_esm_alphabet(self.args.vocab)

    @classmethod
    def add_argparse_args(cls, parent_parser):
        """Add this module's options to *parent_parser* and return the parser."""
        parent_parser = super(PairwiseRegressionDataModule, cls).add_argparse_args(parent_parser)

        # Optional file locations, all defaulting to None.
        for option in (
            '--train_index_path',
            '--train_loss_weight_path',
            '--valid_index_path',
            '--valid_loss_weight_path',
            '--test_index_path',
            '--test_loss_weight_path',
            '--predict_index_path',
        ):
            parent_parser.add_argument(option, type=str, default=None)

        parent_parser.add_argument('--category', type=str2bool, default="true")
        parent_parser.add_argument('--numerical', type=str2bool, default="false")
        parent_parser.add_argument('--numerical_interval', type=float, default=1.0)

        # Column names looked up by ``read_csv`` in the index files.
        parent_parser.add_argument('--virus_id_col_name', type=str, default="virus")
        parent_parser.add_argument('--vaccine_id_col_name', type=str, default="reference")
        parent_parser.add_argument('--virus_seq_col_name', type=str, default="virus_seq")
        parent_parser.add_argument('--vaccine_seq_col_name', type=str, default="reference_seq")
        parent_parser.add_argument('--value_col_name', type=str, default="hi")

        parent_parser.add_argument('--prepend_special_token_for_vaccine', type=str, default=None, help="prepend special tokens for vaccine strains.")
        parent_parser.add_argument('--prepend_special_token_for_virus', type=str, default=None, help="prepend special tokens for virus strains.")
        return parent_parser

    def setup_properties(self, ):
        """Return the property name used by this module (the string ``"label"``)."""
        # NOTE(review): this returns a bare string, whereas
        # ProteinClassifierDataModule.setup_properties returns a list of
        # names -- confirm the base class handles both.
        return "label"

    def read_csv(self, path):
        """Read an index CSV into a list of 5-tuples.

        Args:
            path: Location of the CSV file, passed to ``pandas.read_csv``.

        Returns:
            A list of ``(virus_id, vaccine_id, virus_seq, vaccine_seq, value)``
            tuples, one per row, taken from the columns named by the
            ``args.*_col_name`` options.
        """
        df = pd.read_csv(path)
        columns = (
            self.args.virus_id_col_name,
            self.args.vaccine_id_col_name,
            self.args.virus_seq_col_name,
            self.args.vaccine_seq_col_name,
            self.args.value_col_name,
        )
        return list(zip(*(df[name] for name in columns)))

    def _make_pairwise_dataset(self, sequences, index_path):
        """Build a ``PairwiseClassificationDataset`` for the pairs listed in *index_path*."""
        return PairwiseClassificationDataset(
            sequences,
            self.vocab,
            index_data=self.read_csv(index_path),
            category=self.args.category,
        )

    def build_training_set(self, properties, properties_dict):
        """Build the training and validation datasets.

        The label vocabulary is derived from the training set only.

        Returns:
            ``(train_dataset, valid_dataset)``.
        """
        sequences = read_fasta(self.args.data_path)
        train_dataset = self._make_pairwise_dataset(sequences, self.args.train_index_path)
        valid_dataset = self._make_pairwise_dataset(sequences, self.args.valid_index_path)
        self.setup_properties_dict(train_dataset, properties_dict)
        return train_dataset, valid_dataset

    def build_predict_datasets(self, properties, properties_dict, *args, **argv):
        """Build the prediction dataset and store it in ``self.predict_datasets``.

        Extra positional/keyword arguments are accepted (and ignored) so the
        signature matches ``build_test_datasets``.
        """
        sequences = read_fasta(self.args.data_path)
        self.predict_datasets = [
            self._make_pairwise_dataset(sequences, self.args.predict_index_path)
        ]

    def build_test_datasets(self, properties, properties_dict, *args, **argv):
        """Build the test dataset and store it in ``self.test_datasets``."""
        sequences = read_fasta(self.args.data_path)
        self.test_datasets = [
            self._make_pairwise_dataset(sequences, self.args.test_index_path)
        ]

    def load_properties_from_config(self, model_config, properties_dict):
        """Restore the label mapping from ``model_config.label_dict`` when ``args.category`` is set."""
        if self.args.category:
            properties_dict["label"] = model_config.label_dict
        return properties_dict

    def setup_properties_dict(self, train_dataset, properties_dict):
        """Index the distinct ``"label"`` values of *train_dataset*.

        Only done when ``args.category`` or ``args.numerical`` is set.  The
        sorted value list is stored as ``args.label_vocab`` and the
        value-to-index mapping as ``args.label_dict`` and in
        ``properties_dict["label"]``; a vocabulary consisting solely of
        ``None`` is skipped.  ``args.labels`` is always set to ``["label"]``.
        """
        if self.args.category or self.args.numerical:
            # One pass over the dataset: the Counter keys are exactly the
            # distinct labels, so no second scan is needed for the vocabulary.
            label_counts = Counter(x["label"] for x in train_dataset)
            print(label_counts.most_common())
            vocab = sorted(label_counts)
            if not (len(vocab) == 1 and vocab[0] is None):
                label2index = {x: idx for idx, x in enumerate(vocab)}
                properties_dict["label"] = label2index
                setattr(self.args, "label_vocab", vocab)
                setattr(self.args, "label_dict", label2index)
        setattr(self.args, "labels", ["label"])

@register_dm("hi_regression_aln")
class PairwiseRegressionAlnDataModule(PairwiseRegressionDataModule):
    """Variant of the pairwise module whose datasets are built from the CSV rows alone.

    Every dataset is a ``PairwiseAlnClassificationDataset`` constructed from
    the tuples returned by ``read_csv``; optional per-sample loss weights are
    read from the ``loss_weight`` column of a separate CSV file.
    """

    def __init__(self, args, vocab=None):
        """Delegate to the parent initialiser."""
        super().__init__(args, vocab)

    @staticmethod
    def _read_loss_weights(path):
        """Return the ``loss_weight`` column of the CSV at *path*, or ``None`` when *path* is ``None``."""
        if path is None:
            return None
        return pd.read_csv(path)["loss_weight"]

    def _build_aln_dataset(self, index_path, **dataset_kwargs):
        """Build a ``PairwiseAlnClassificationDataset`` for the rows of *index_path*.

        Extra keyword arguments (e.g. ``loss_weights``) are passed through to
        the dataset constructor.
        """
        return PairwiseAlnClassificationDataset(
            self.read_csv(index_path),
            self.vocab,
            category=self.args.category,
            prepend_special_token_for_seq1=self.args.prepend_special_token_for_virus,
            prepend_special_token_for_seq2=self.args.prepend_special_token_for_vaccine,
            numerical_interval=self.args.numerical_interval,
            numerical=self.args.numerical,
            **dataset_kwargs
        )

    def build_training_set(self, properties, properties_dict):
        """Build the training and validation datasets.

        The label vocabulary is derived from the training set only.

        Returns:
            ``(train_dataset, valid_dataset)``.
        """
        train_dataset = self._build_aln_dataset(
            self.args.train_index_path,
            loss_weights=self._read_loss_weights(self.args.train_loss_weight_path),
        )
        valid_dataset = self._build_aln_dataset(
            self.args.valid_index_path,
            loss_weights=self._read_loss_weights(self.args.valid_loss_weight_path),
        )
        self.setup_properties_dict(train_dataset, properties_dict)
        return train_dataset, valid_dataset

    def build_test_datasets(self, properties, properties_dict, *args, **argv):
        """Build the test dataset and store it in ``self.test_datasets``."""
        test_dataset = self._build_aln_dataset(
            self.args.test_index_path,
            loss_weights=self._read_loss_weights(self.args.test_loss_weight_path),
        )
        self.test_datasets = [test_dataset]

    def build_predict_datasets(self, properties, properties_dict, *args, **argv):
        """Build the prediction dataset (no loss weights are passed).

        The result is stored under ``self.predict_datasets``, the attribute
        name used by the parent class, and also under the historical
        ``self.pred_datasets`` name so existing readers of either keep working.
        """
        pred_dataset = self._build_aln_dataset(self.args.predict_index_path)
        self.predict_datasets = [pred_dataset]
        # NOTE(review): ``pred_datasets`` looks like a typo of
        # ``predict_datasets``; kept as an alias until callers are confirmed.
        self.pred_datasets = self.predict_datasets


@register_dm("hi_regression_structs")
class StructsPairwiseRegressionDataModule(PairwiseRegressionDataModule):
    """Pairwise module whose samples reference PDB structure files.

    The two ids of every index row are mapped to sequences (via the FASTA
    file at ``args.data_path``) and from there to ``.pdb`` files found under
    ``args.pdb_dir``.
    """

    def __init__(self, args, vocab=None):
        """Delegate to the parent initialiser."""
        super().__init__(args, vocab)

    @classmethod
    def add_argparse_args(cls, parent_parser):
        """Add the ``--pdb_dir`` option on top of the parent's options."""
        parent_parser = super(StructsPairwiseRegressionDataModule, cls).add_argparse_args(parent_parser)
        parent_parser.add_argument('--pdb_dir', type=str, default=None, help="where to find my pdb files. xxx.pdb")
        return parent_parser

    def build_seq2pdb(self, ):
        """Map sequences to PDB file paths.

        Returns:
            ``(seq2pdb, epiid_to_seq)`` where ``epiid_to_seq`` maps an id
            taken from the FASTA description to its sequence and ``seq2pdb``
            maps a sequence to the path of a ``.pdb`` file.

        Raises:
            KeyError: If a ``.pdb`` file name yields an id that is not in the
                FASTA file.
        """
        epiid_to_seq = {}
        for _, seq, desc in read_fasta(self.args.data_path):
            # The id is the last "|"-separated field that contains "EPI" but
            # not "EPI_ISL".
            epi_id = None
            for field in desc.split("|"):
                if "EPI" in field and "EPI_ISL" not in field:
                    epi_id = field
            if epi_id is not None:
                epiid_to_seq[epi_id] = seq

        seq2pdb = {}
        for root, _, files in os.walk(self.args.pdb_dir, topdown=False):
            for file in files:
                if not file.endswith(".pdb"):
                    continue
                # File names look like "<id>.pdb" or "<prefix>#<id>...pdb";
                # when a "#" is present the id is the second "#"-separated part.
                stem_parts = os.path.split(file)[-1].split(".pdb")[0].split("#")
                epi_id = stem_parts[1] if len(stem_parts) > 1 else stem_parts[0]
                seq2pdb[epiid_to_seq[epi_id]] = os.path.join(root, file)

        return seq2pdb, epiid_to_seq

    def set_collate_func(self, properties_dict, *args, **kwargs):
        """Set ``self.collate_fn`` to ``structure_pairs_collate_func`` bound to this vocab."""
        self.collate_fn = partial(
            structure_pairs_collate_func,
            batch_converter=CoordBatchConverter(self.vocab),
            padding_idx=self.vocab.pad(),
            properties_dict=properties_dict,
        )

    def setup_properties_dict(self, train_dataset, properties_dict):
        """Index the labels reported by ``train_dataset.count_labels()``.

        The label list is stored as ``args.label_vocab`` and the
        label-to-index mapping as ``args.label_dict`` and in
        ``properties_dict["label"]``; a vocabulary consisting solely of
        ``None`` is skipped.  ``args.labels`` is always set to ``["label"]``.
        """
        counter = train_dataset.count_labels()  # dict: label name -> number of samples
        vocab = list(counter.keys())
        if not (len(vocab) == 1 and vocab[0] is None):
            label2index = {x: idx for idx, x in enumerate(vocab)}
            properties_dict["label"] = label2index
            setattr(self.args, "label_vocab", vocab)
            setattr(self.args, "label_dict", label2index)
        setattr(self.args, "labels", ["label"])

    def get_pdbs(self, dataset, seq2pdb, epiid_to_seq):
        """Replace the two leading ids of every row by PDB paths.

        Args:
            dataset: Iterable of rows whose first two entries are ids.
            seq2pdb: Mapping from sequence to PDB path.
            epiid_to_seq: Mapping from id to sequence.

        Returns:
            A list of ``[pdb1, pdb2, *rest_of_row]`` lists for the rows where
            both structures were found; other rows are dropped and the ids
            without a structure are printed.

        Raises:
            KeyError: If an id is not present in *epiid_to_seq*.
        """
        new_dataset = []
        fail_to_find_pdb = set()
        for data in dataset:
            id1, id2 = data[0], data[1]
            pdb1 = seq2pdb.get(epiid_to_seq[id1], None)
            pdb2 = seq2pdb.get(epiid_to_seq[id2], None)
            if pdb1 is not None and pdb2 is not None:
                # ``read_csv`` yields tuples; convert the tail to a list so the
                # concatenation does not raise TypeError (list + tuple).
                new_dataset.append([pdb1, pdb2] + list(data[2:]))
                continue
            if pdb1 is None:
                fail_to_find_pdb.add(id1)
            if pdb2 is None:
                fail_to_find_pdb.add(id2)
        # Sorted so the report is stable from run to run.
        print("Cannot find the pdb structures for:", ",".join(sorted(fail_to_find_pdb)))
        return new_dataset

    def build_test_datasets(self, properties, properties_dict, *args, **argv):
        """Build the structure-based test dataset and store it in ``self.test_datasets``."""
        seq2pdb, epiid_to_seq = self.build_seq2pdb()
        test_data = self.read_csv(self.args.test_index_path)
        print(len(test_data))
        test_data = self.get_pdbs(test_data, seq2pdb, epiid_to_seq)
        print(len(test_data), len(seq2pdb), len(epiid_to_seq))

        test_dataset = StructsPairwiseClassificationDataset(test_data, self.vocab)
        self.test_datasets = [test_dataset]

    def build_training_set(self, properties, properties_dict):
        """Build the structure-based training and validation datasets.

        The label vocabulary is derived from the training set only.

        Returns:
            ``(train_dataset, valid_dataset)``.
        """
        seq2pdb, epiid_to_seq = self.build_seq2pdb()

        train_data = self.read_csv(self.args.train_index_path)
        print(len(train_data))
        train_data = self.get_pdbs(train_data, seq2pdb, epiid_to_seq)
        print(len(train_data))
        train_dataset = StructsPairwiseClassificationDataset(train_data, self.vocab)

        valid_data = self.read_csv(self.args.valid_index_path)
        valid_data = self.get_pdbs(valid_data, seq2pdb, epiid_to_seq)
        valid_dataset = StructsPairwiseClassificationDataset(valid_data, self.vocab)

        self.setup_properties_dict(train_dataset, properties_dict)
        return train_dataset, valid_dataset

    