import copy
from functools import partial
import json
import logging
import os
import pickle
from typing import Optional, Sequence, List, Any

import ml_collections as mlc
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import RandomSampler

from myopenfold.data import (
    data_pipeline,
    feature_pipeline,
    mmcif_parsing,
    templates,
)
from myopenfold.utils.tensor_utils import tensor_tree_map, dict_multimap

import logging

logger = logging.getLogger("myopenfold.data.data_modules")


class OpenFoldSingleDataset(torch.utils.data.Dataset):
    def __init__(self,
        data_dir: str,
        alignment_dir: str, 
        config: mlc.ConfigDict,
        template_mmcif_dir: Optional[str] = None,
        max_template_date: Optional[str] = None,
        chain_data_cache_path: Optional[str] = None,
        kalign_binary_path: str = '/usr/bin/kalign',
        max_template_hits: int = 4,
        obsolete_pdbs_file_path: Optional[str] = None,
        template_release_dates_cache_path: Optional[str] = None,
        shuffle_top_k_prefiltered: Optional[int] = None,
        treat_pdb_as_distillation: bool = True,
        filter_path: Optional[str] = None,
        mode: str = "train", 
        alignment_index: Optional[Any] = None,
        _output_raw: bool = False,
        _structure_index: Optional[Any] = None,
        rank = None,
        epoch = 0,
        base_seed = 280421310721,
        subsample_index = None,
        exclude_index = None,
        **kwargs,
    ):
        """
            Args:
                data_dir:
                    A path to a directory containing mmCIF files (in train
                    mode) or FASTA files (in inference mode).
                alignment_dir:
                    A path to a directory containing only data in the format 
                    output by an AlignmentRunner 
                    (defined in openfold.features.alignment_runner).
                    I.e. a directory of directories named {PDB_ID}_{CHAIN_ID}
                    or simply {PDB_ID}, each containing .a3m, .sto, and .hhr
                    files.
                template_mmcif_dir:
                    Path to a directory containing template mmCIF files.
                config:
                    A dataset config object. See openfold.config
                chain_data_cache_path:
                    Path to cache of data_dir generated by
                    scripts/generate_chain_data_cache.py
                kalign_binary_path:
                    Path to kalign binary.
                max_template_hits:
                    An upper bound on how many templates are considered. During
                    training, the templates ultimately used are subsampled
                    from this total quantity.
                template_release_dates_cache_path:
                    Path to the output of scripts/generate_mmcif_cache.
                obsolete_pdbs_file_path:
                    Path to the file containing replacements for obsolete PDBs.
                shuffle_top_k_prefiltered:
                    Whether to uniformly shuffle the top k template hits before
                    parsing max_template_hits of them. Can be used to
                    approximate DeepMind's training-time template subsampling
                    scheme much more performantly.
                treat_pdb_as_distillation:
                    Whether to assume that .pdb files in the data_dir are from
                    the self-distillation set (and should be subjected to
                    special distillation set preprocessing steps).
                mode:
                    "train", "val", or "predict"
        """
        super(OpenFoldSingleDataset, self).__init__()
        self.data_dir = data_dir

        self.chain_data_cache = None
        if chain_data_cache_path is not None:
            with open(chain_data_cache_path, "r") as fp:
                self.chain_data_cache = json.load(fp)
            assert isinstance(self.chain_data_cache, dict)

        self.alignment_dir = alignment_dir
        self.config = config
        self.treat_pdb_as_distillation = treat_pdb_as_distillation
        self.mode = mode
        self.alignment_index = alignment_index
        self._output_raw = _output_raw
        self._structure_index = _structure_index

        self.supported_exts = [".cif", ".core", ".pdb"]

        self.rank = rank
        self.epoch = epoch
        self.base_seed = base_seed

        valid_modes = ["train", "eval", "predict"]
        if(mode not in valid_modes):
            raise ValueError(f'mode must be one of {valid_modes}')

        # if(template_release_dates_cache_path is None):
        #     logging.warning(
        #         "Template release dates cache does not exist. Remember to run "
        #         "scripts/generate_mmcif_cache.py before running OpenFold"
        #     )

        #. get chain ids from alignment_index or alignment_dir
        if(alignment_index is not None):
            self._chain_ids = list(alignment_index.keys())
        else:
            self._chain_ids = list(os.listdir(alignment_dir))

        #. if filter_path is not None, filter chain ids accordingly
        if(filter_path is not None):
            with open(filter_path, "r") as f:
                chains_to_include = set([l.strip() for l in f.readlines()])

            self._chain_ids = [
                c for c in self._chain_ids if c in chains_to_include
            ]

        #. if chain_data_cache is not None
        if self.chain_data_cache is not None:
            # Filter to include only chains where we have structure data
            # (entries in chain_data_cache)
            original_chain_ids = self._chain_ids
            self._chain_ids = [
                c for c in self._chain_ids if c in self.chain_data_cache or c.split("_")[0] in self.chain_data_cache
            ]
            if len(self._chain_ids) < len(original_chain_ids):
                missing = [
                    c for c in original_chain_ids
                    if c not in self.chain_data_cache
                ]
                max_to_print = 100
                missing_examples = ", ".join(missing[:max_to_print])
                if len(missing) > max_to_print:
                    missing_examples += ", ..."
                logging.warning(
                    "Removing %d alignment entries (%s) with no corresponding "
                    "entries in chain_data_cache (%s).",
                    len(missing),
                    missing_examples,
                    chain_data_cache_path)

        # NOTE add this to allow subsampling for dataset
        self.subsample_index = subsample_index
        self.exclude_index = exclude_index
        if subsample_index is not None and exclude_index is not None:
            raise ValueError('Cannot specify both `subsample_index` and `exclude_index`')
        if subsample_index is not None:
            logger.info('Subsampling chain index according to proveded `subsample_index`..')
            self._original_chain_ids = self._chain_ids
            self._chain_ids = [self._chain_ids[i] for i in subsample_index]
            self.subsample_index_dict = {i: subsample_index[i] for i in range(len(subsample_index))}
        elif exclude_index is not None:
            logger.info('Subsampling chain index according to proveded `exclude_index`..')
            self._original_chain_ids = self._chain_ids
            subsample_index = [i for i in range(len(self._original_chain_ids)) if i not in exclude_index]
            self._chain_ids = [self._chain_ids[i] for i in subsample_index]
            self.subsample_index_dict = {i: subsample_index[i] for i in range(len(subsample_index))}
        else:
            self.subsample_index_dict = None

        #. create a dictionary mapping chain ids to indices
        self._chain_id_to_idx_dict = {
            chain: i for i, chain in enumerate(self._chain_ids)
        }

        #. initialize a template featurizer
        #. add a condition to bypass template featurizer initialization
        if template_mmcif_dir is None:
            template_featurizer = None
        else:
            template_featurizer = templates.TemplateHitFeaturizer(
                mmcif_dir=template_mmcif_dir,
                max_template_date=max_template_date,
                max_hits=max_template_hits,
                kalign_binary_path=kalign_binary_path,
                release_dates_path=template_release_dates_cache_path,
                obsolete_pdbs_path=obsolete_pdbs_file_path,
                _shuffle_top_k_prefiltered=shuffle_top_k_prefiltered,
            )

        #. initialize a data pipeline, the template featurizer is passed to it
        self.data_pipeline = data_pipeline.DataPipeline(
            template_featurizer=template_featurizer,
        )

        #. initialize a feature pipeline
        if(not self._output_raw):
            self.feature_pipeline = feature_pipeline.FeaturePipeline(config)

        self.generator = torch.Generator()

        self.shuffle_seed = 280421310721


    def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, alignment_index):
        with open(path, 'r') as f:
            mmcif_string = f.read()

        #. parse mmcif info
        mmcif_object = mmcif_parsing.parse(
            file_id=file_id, mmcif_string=mmcif_string
        )

        # Crash if an error is encountered. Any parsing errors should have
        # been dealt with at the alignment stage.
        if(mmcif_object.mmcif_object is None):
            raise list(mmcif_object.errors.values())[0]

        mmcif_object = mmcif_object.mmcif_object

        #. process by data pipeline
        data = self.data_pipeline.process_mmcif(
            mmcif=mmcif_object,
            alignment_dir=alignment_dir,
            chain_id=chain_id,
            alignment_index=alignment_index
        )

        return data

    def chain_id_to_idx(self, chain_id):
        return self._chain_id_to_idx_dict[chain_id]

    def idx_to_chain_id(self, idx):
        return self._chain_ids[idx]

    def set_shuffle_seed(self, seed):
        self.shuffle_seed = seed

    def _shuffle_msa(self, data):
        if 'msa' in data:
            num_seq = data["msa"].shape[0]
            g = np.random.default_rng(self.shuffle_seed)

            shuffled = g.permutation(np.arange(num_seq - 1)) + 1
            index_order = np.concatenate([[0], shuffled])
            data['msa'] = data['msa'][index_order]
            return data
        else:
            logger.warning("No msa found in data, skip shuffling")
            return data

    def __getitem__(self, idx):
        worker_info = torch.utils.data.get_worker_info()
        # logger.info("Rank %s, Worker %s : Getting item %d and set seed to %d.." % (str(self.rank), str(worker_info.id) if worker_info is not None else 'None', idx, self.epoch + idx))  # DEBUG
        #. name : pdb id and chain id
        
        name = self.idx_to_chain_id(idx)

        # NOTE 230425 add this to ensure same seeded for each data, even when subsampling is enbaled
        if self.subsample_index_dict is not None:
            idx = self.subsample_index_dict[idx]
        torch.manual_seed(self.base_seed + self.epoch + idx)

        #. get alignment dir, which is the parent_dir + name
        alignment_dir = os.path.join(self.alignment_dir, name)

        #. this is some contradictory code. alignment_index is always None.
        alignment_index = None
        if(self.alignment_index is not None):
            alignment_dir = self.alignment_dir
            alignment_index = self.alignment_index[name]

        #. get pdb id and chain id
        if(self.mode == 'train' or self.mode == 'eval'):
            spl = name.rsplit('_', 1)
            if(len(spl) == 2):
                file_id, chain_id = spl
            else:
                file_id, = spl
                chain_id = None

            #. get file path by pdb id
            path = os.path.join(self.data_dir, file_id)
            structure_index_entry = None
            #. get file extension
            if(self._structure_index is not None):
                structure_index_entry = self._structure_index[name]
                assert(len(structure_index_entry["files"]) == 1)
                filename, _, _ = structure_index_entry["files"][0]
                ext = os.path.splitext(filename)[1]
            else:
                ext = None
                for e in self.supported_exts:
                    if(os.path.exists(path + e)):
                        ext = e
                        break

                if(ext is None):
                    raise ValueError("Invalid file type")

            #. parse file
            path += ext
            if(ext == ".cif"):
                data = self._parse_mmcif(
                    path, file_id, chain_id, alignment_dir, alignment_index,
                )
            elif(ext == ".core"):
                data = self.data_pipeline.process_core(
                    path, alignment_dir, alignment_index,
                )
            elif(ext == ".pdb"):
                structure_index = None
                if(self._structure_index is not None):
                    structure_index = self._structure_index[name]
                data = self.data_pipeline.process_pdb(
                    pdb_path=path,
                    alignment_dir=alignment_dir,
                    is_distillation=self.treat_pdb_as_distillation,
                    chain_id=chain_id,
                    alignment_index=alignment_index,
                    _structure_index=structure_index,
                )
            else:
                raise ValueError("Extension branch missing") 
        else:
            path = os.path.join(name, name + ".fasta")
            data = self.data_pipeline.process_fasta(
                fasta_path=path,
                alignment_dir=alignment_dir,
                alignment_index=alignment_index,
            )

        if self.config.common.sample_msa.pre_shuffled:
            # print("Pre-shuffling MSA")
            data = self._shuffle_msa(data)

        #. if output raw, return data
        if(self._output_raw):
            return data
        #. else, process by feature pipeline
        feats = self.feature_pipeline.process_features(
            data, self.mode 
        )

        #. [batch_idx] * N_res
        # if self.subsample_index is not None:
        #     feats['batch_idx'] = torch.tensor([self.subsample_index_dict[idx] for _ in range(feats["aatype"].shape[-1])], dtype=torch.int64, device=feats["aatype"].device)
        # else:
        feats["batch_idx"] = torch.tensor([idx for _ in range(feats["aatype"].shape[-1])], dtype=torch.int64, device=feats["aatype"].device)

        return feats

    def __len__(self):
        return len(self._chain_ids) 


class OpenFoldBatchCollator:
    """Callable used as a DataLoader ``collate_fn`` for feature dicts."""

    def __call__(self, prots):
        """Combine the per-sample dicts in ``prots`` into one batch dict.

        ``dict_multimap`` is given ``torch.stack(..., dim=0)`` as the merge
        function, so stacked tensors gain a new leading axis.
        NOTE(review): presumably dict_multimap applies the function to the
        values collected per key across ``prots`` - confirm in tensor_utils.
        """
        return dict_multimap(partial(torch.stack, dim=0), prots)


class OpenFoldDataLoader(torch.utils.data.DataLoader):
    def __init__(self, *args, config, stage="train", seed=None, **kwargs):
        """Create a DataLoader that also samples per-batch properties.

        Args:
            *args, **kwargs:
                Forwarded unchanged to ``torch.utils.data.DataLoader``.
            config:
                Config object; ``config[stage]`` and
                ``config.common.max_recycling_iters`` are read when the
                sampling probabilities are prepared.
            stage:
                Key selecting the stage section of ``config``.
            seed:
                Optional integer seed for the generator that draws the
                per-batch properties. When None, the generator is left at
                torch's default seed instead of being seeded explicitly.
        """
        super().__init__(*args, **kwargs)
        self.config = config
        self.stage = stage
        self.seed = seed

        # NOTE: this assignment replaces the `generator` attribute set by
        # DataLoader.__init__ (from a `generator` keyword, if one was passed).
        generator = torch.Generator()
        if seed is not None:
            # manual_seed() rejects None, so only seed when a value is given.
            generator.manual_seed(seed)

        self.generator = generator
        self._prep_batch_properties_probs()

    # NOTE: the keyed/padded layout below supports several sampled properties, but only "no_recycling_iters" is registered (legacy generality).
    def _prep_batch_properties_probs(self):
        """Build the probability table used to sample per-batch properties.

        Sets ``self.prop_keys`` (tuple of property names) and
        ``self.prop_probs_tensor`` (float32 tensor with one row per property,
        rows right-padded with zeros to a common length).

        The only property is ``"no_recycling_iters"``, with
        ``max_recycling_iters + 1`` outcomes: equal probabilities when the
        stage config sets ``uniform_recycling``, otherwise all mass on the
        last outcome.
        """
        stage_cfg = self.config[self.stage]
        n_outcomes = self.config.common.max_recycling_iters + 1

        if stage_cfg.uniform_recycling:
            recycling_probs = [1. / n_outcomes] * n_outcomes
        else:
            recycling_probs = [0.] * n_outcomes
            recycling_probs[-1] = 1.

        prob_table = {"no_recycling_iters": recycling_probs}

        # Pad every row to the widest one so the rows form a rectangle.
        widest = max(len(row) for row in prob_table.values())
        padded_rows = [
            row + [0.] * (widest - len(row)) for row in prob_table.values()
        ]

        self.prop_keys = tuple(prob_table)
        self.prop_probs_tensor = torch.tensor(
            padded_rows,
            dtype=torch.float32,
        )

    def _add_batch_properties(self, batch):
        """Sample the per-batch properties and attach them to ``batch``.

        One value per row of ``self.prop_probs_tensor`` is drawn with
        ``self.generator``. Each drawn value is stored in ``batch`` under its
        property key as a tensor expanded to
        ``aatype.shape[:-2] + (aatype.shape[-1],)``. Afterwards
        ``tensor_tree_map`` is used to slice the last axis of the batch
        entries down to the first ``no_recycling_iters + 1`` positions.

        NOTE(review): presumably tensor_tree_map applies the slice to every
        tensor in ``batch`` - confirm in tensor_utils.
        """
        draws = torch.multinomial(
            self.prop_probs_tensor,
            num_samples=1,  # one draw per property row
            replacement=True,
            generator=self.generator
        )

        aatype = batch["aatype"]
        lead_dims = aatype.shape[:-2]
        recycling_dim = aatype.shape[-1]

        # Without a "no_recycling_iters" property nothing is truncated.
        keep_iters = recycling_dim
        for row, key in enumerate(self.prop_keys):
            drawn = int(draws[row][0])
            scalar = torch.tensor(
                drawn,
                device=aatype.device,
                requires_grad=False
            )
            scalar_shape = scalar.shape

            # Broadcast the scalar over the leading dims and the last axis.
            reshaped = scalar.view(
                (1,) * len(lead_dims) + scalar_shape + (1,)
            )
            batch[key] = reshaped.expand(
                lead_dims + scalar_shape + (recycling_dim,)
            )

            if key == "no_recycling_iters":
                keep_iters = drawn

        def _truncate_last_axis(t):
            return t[..., :keep_iters + 1]

        return tensor_tree_map(_truncate_last_axis, batch)

    def __iter__(self):
        """Iterate over batches, adding sampled batch properties to each.

        The underlying DataLoader iterator is created immediately; every
        batch it yields is passed through ``_add_batch_properties`` lazily.
        """
        base_iter = super().__iter__()
        return (self._add_batch_properties(batch) for batch in base_iter)


class OpenFoldDataModule(pl.LightningDataModule):
    def __init__(self,
        config: mlc.ConfigDict,
        template_mmcif_dir: str,
        max_template_date: str,
        train_data_dir: Optional[str] = None,
        train_alignment_dir: Optional[str] = None,
        train_chain_data_cache_path: Optional[str] = None,
        distillation_data_dir: Optional[str] = None,
        distillation_alignment_dir: Optional[str] = None,
        distillation_chain_data_cache_path: Optional[str] = None,
        val_data_dir: Optional[str] = None,
        val_alignment_dir: Optional[str] = None,
        predict_data_dir: Optional[str] = None,
        predict_alignment_dir: Optional[str] = None,
        kalign_binary_path: str = '/usr/bin/kalign',
        train_filter_path: Optional[str] = None,
        distillation_filter_path: Optional[str] = None,
        obsolete_pdbs_file_path: Optional[str] = None,
        template_release_dates_cache_path: Optional[str] = None,
        batch_seed: Optional[int] = None,
        train_epoch_len: int = 50000, 
        _distillation_structure_index_path: Optional[str] = None,
        alignment_index_path: Optional[str] = None,
        distillation_alignment_index_path: Optional[str] = None,
        **kwargs
    ):
        super(OpenFoldDataModule, self).__init__()

        self.config = config
        self.template_mmcif_dir = template_mmcif_dir
        self.max_template_date = max_template_date
        self.train_data_dir = train_data_dir
        self.train_alignment_dir = train_alignment_dir
        self.train_chain_data_cache_path = train_chain_data_cache_path
        self.distillation_data_dir = distillation_data_dir
        self.distillation_alignment_dir = distillation_alignment_dir
        self.distillation_chain_data_cache_path = (
            distillation_chain_data_cache_path
        )
        self.val_data_dir = val_data_dir
        self.val_alignment_dir = val_alignment_dir
        self.predict_data_dir = predict_data_dir
        self.predict_alignment_dir = predict_alignment_dir
        self.kalign_binary_path = kalign_binary_path
        self.train_filter_path = train_filter_path
        self.distillation_filter_path = distillation_filter_path
        self.template_release_dates_cache_path = (
            template_release_dates_cache_path
        )
        self.obsolete_pdbs_file_path = obsolete_pdbs_file_path
        self.batch_seed = batch_seed
        self.train_epoch_len = train_epoch_len

        if(self.train_data_dir is None and self.predict_data_dir is None):
            raise ValueError(
                'At least one of train_data_dir or predict_data_dir must be '
                'specified'
            )

        self.training_mode = self.train_data_dir is not None

        if(self.training_mode and train_alignment_dir is None):
            raise ValueError(
                'In training mode, train_alignment_dir must be specified'
            )
        elif(not self.training_mode and predict_alignment_dir is None):
            raise ValueError(
                'In inference mode, predict_alignment_dir must be specified'
            )      
        elif(val_data_dir is not None and val_alignment_dir is None):
            raise ValueError(
                'If val_data_dir is specified, val_alignment_dir must '
                'be specified as well'
        )

        # An ad-hoc measure for our particular filesystem restrictions
        self._distillation_structure_index = None
        if(_distillation_structure_index_path is not None):
            with open(_distillation_structure_index_path, "r") as fp:
                self._distillation_structure_index = json.load(fp)
        
        self.alignment_index = None
        if(alignment_index_path is not None):
            with open(alignment_index_path, "r") as fp:
                self.alignment_index = json.load(fp)

        self.distillation_alignment_index = None
        if(distillation_alignment_index_path is not None):
            with open(distillation_alignment_index_path, "r") as fp:
                self.distillation_alignment_index = json.load(fp)

    def setup(self):
        # Most of the arguments are the same for the three datasets 
        dataset_gen = partial(OpenFoldSingleDataset,
            template_mmcif_dir=self.template_mmcif_dir,
            max_template_date=self.max_template_date,
            config=self.config,
            kalign_binary_path=self.kalign_binary_path,
            template_release_dates_cache_path=
                self.template_release_dates_cache_path,
            obsolete_pdbs_file_path=
                self.obsolete_pdbs_file_path,
        )

        if(self.training_mode):
            train_dataset = dataset_gen(
                data_dir=self.train_data_dir,
                chain_data_cache_path=self.train_chain_data_cache_path,
                alignment_dir=self.train_alignment_dir,
                filter_path=self.train_filter_path,
                max_template_hits=self.config.train.max_template_hits,
                shuffle_top_k_prefiltered=
                    self.config.train.shuffle_top_k_prefiltered,
                treat_pdb_as_distillation=False,
                mode="train",
                alignment_index=self.alignment_index,
            )

            distillation_dataset = None
            if(self.distillation_data_dir is not None):
                distillation_dataset = dataset_gen(
                    data_dir=self.distillation_data_dir,
                    chain_data_cache_path=self.distillation_chain_data_cache_path,
                    alignment_dir=self.distillation_alignment_dir,
                    filter_path=self.distillation_filter_path,
                    max_template_hits=self.config.train.max_template_hits,
                    treat_pdb_as_distillation=True,
                    mode="train",
                    alignment_index=self.distillation_alignment_index,
                    _structure_index=self._distillation_structure_index,
                )

                d_prob = self.config.train.distillation_prob
           
            if(distillation_dataset is not None):
                datasets = [train_dataset, distillation_dataset]
                d_prob = self.config.train.distillation_prob
                probabilities = [1. - d_prob, d_prob]
            else:
                datasets = [train_dataset]
                probabilities = [1.]

            generator = None
            if(self.batch_seed is not None):
                generator = torch.Generator()
                generator = generator.manual_seed(self.batch_seed + 1)
            
            self.train_dataset = OpenFoldDataset(
                datasets=datasets,
                probabilities=probabilities,
                epoch_len=self.train_epoch_len,
                generator=generator,
                _roll_at_init=False,
            )
    
            if(self.val_data_dir is not None):
                self.eval_dataset = dataset_gen(
                    data_dir=self.val_data_dir,
                    alignment_dir=self.val_alignment_dir,
                    filter_path=None,
                    max_template_hits=self.config.eval.max_template_hits,
                    mode="eval",
                )
            else:
                self.eval_dataset = None
        else:           
            self.predict_dataset = dataset_gen(
                data_dir=self.predict_data_dir,
                alignment_dir=self.predict_alignment_dir,
                filter_path=None,
                max_template_hits=self.config.predict.max_template_hits,
                mode="predict",
            )

    def _gen_dataloader(self, stage):
        generator = torch.Generator()
        if(self.batch_seed is not None):
            generator = generator.manual_seed(self.batch_seed)

        dataset = None
        if(stage == "train"):
            dataset = self.train_dataset
            # Filter the dataset, if necessary
            dataset.reroll()
        elif(stage == "eval"):
            dataset = self.eval_dataset
        elif(stage == "predict"):
            dataset = self.predict_dataset
        else:
            raise ValueError("Invalid stage")

        batch_collator = OpenFoldBatchCollator()

        dl = OpenFoldDataLoader(
            dataset,
            config=self.config,
            stage=stage,
            generator=generator,
            batch_size=self.config.data_module.data_loaders.batch_size,
            num_workers=self.config.data_module.data_loaders.num_workers,
            collate_fn=batch_collator,
        )

        return dl

    def train_dataloader(self):
        return self._gen_dataloader("train") 

    def val_dataloader(self):
        if(self.eval_dataset is not None):
            return self._gen_dataloader("eval")
        return None

    def predict_dataloader(self):
        return self._gen_dataloader("predict") 


class DummyDataset(torch.utils.data.Dataset):
    """Dataset that serves copies of one pickled batch, 1000 times over."""

    def __init__(self, batch_path):
        """Load the pickled batch stored at ``batch_path``.

        NOTE(review): unpickling can execute arbitrary code - only point
        this at trusted files.
        """
        with open(batch_path, "rb") as handle:
            loaded = pickle.load(handle)
        self.batch = loaded

    def __getitem__(self, idx):
        """Return an independent deep copy of the batch; ``idx`` is ignored."""
        clone = copy.deepcopy(self.batch)
        return clone

    def __len__(self):
        """Return the fixed nominal length of the dataset."""
        return 1000


class DummyDataLoader(pl.LightningDataModule):
    """Data module whose training loader serves a single pickled batch."""

    def __init__(self, batch_path):
        """Wrap the pickled batch at ``batch_path`` in a DummyDataset."""
        super().__init__()
        self.dataset = DummyDataset(batch_path)

    def train_dataloader(self):
        """Return a default-configured DataLoader over the dummy dataset."""
        loader = torch.utils.data.DataLoader(self.dataset)
        return loader
