import copy
from functools import partial
import json
import logging
import os
import pickle
from typing import Optional, Sequence, List, Any

import ml_collections as mlc
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import RandomSampler

from openfold.data import (
    data_pipeline,
    feature_pipeline,
    mmcif_parsing,
    templates,
)
from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap


class OpenFoldSingleDataset(torch.utils.data.Dataset):
    """Map-style dataset that yields the features of one chain per index.

    The list of chain IDs comes from, in order of precedence, the keys of
    ``_alignment_index``, the lines of ``mapping_path``, or the entries of
    ``alignment_dir``. Each item is produced by the OpenFold data pipeline
    and, unless ``_output_raw`` is set, by the feature pipeline as well.
    """

    def __init__(
        self,
        data_dir: str,
        alignment_dir: str,
        template_mmcif_dir: str,
        max_template_date: str,
        config: mlc.ConfigDict,
        kalign_binary_path: str = "/usr/bin/kalign",
        max_template_hits: int = 4,
        obsolete_pdbs_file_path: Optional[str] = None,
        template_release_dates_cache_path: Optional[str] = None,
        shuffle_top_k_prefiltered: Optional[int] = None,
        treat_pdb_as_distillation: bool = True,
        mapping_path: Optional[str] = None,
        mode: str = "train",
        _output_raw: bool = False,
        _alignment_index: Optional[Any] = None,
    ):
        """
        Args:
            data_dir:
                Directory searched for ``<file_id>.cif``, ``<file_id>.core``
                or ``<file_id>.pdb`` (in that order) in train/eval mode.
            alignment_dir:
                Directory with one sub-directory per chain ID. When
                ``_alignment_index`` is given it is passed to the data
                pipeline unchanged instead.
            template_mmcif_dir, max_template_date, kalign_binary_path,
            max_template_hits, obsolete_pdbs_file_path,
            template_release_dates_cache_path, shuffle_top_k_prefiltered:
                Forwarded to ``templates.TemplateHitFeaturizer``.
            config:
                Configuration handed to ``feature_pipeline.FeaturePipeline``.
            treat_pdb_as_distillation:
                Forwarded as ``is_distillation`` when a ``.pdb`` file is
                processed.
            mapping_path:
                Optional text file with one chain ID per line. Blank lines
                are ignored.
            mode:
                One of "train", "eval" or "predict".
            _output_raw:
                If True, return the data pipeline output without running
                the feature pipeline.
            _alignment_index:
                Optional mapping keyed by chain ID; the entry for the
                requested chain is forwarded to the data pipeline.

        Raises:
            ValueError: If ``mode`` is not one of the valid modes.
        """
        super(OpenFoldSingleDataset, self).__init__()

        # Reject bad configuration before doing any I/O.
        valid_modes = ["train", "eval", "predict"]
        if mode not in valid_modes:
            raise ValueError(f"mode must be one of {valid_modes}")

        self.data_dir = data_dir
        self.alignment_dir = alignment_dir
        self.config = config
        self.treat_pdb_as_distillation = treat_pdb_as_distillation
        self.mode = mode
        self._output_raw = _output_raw
        self._alignment_index = _alignment_index

        if template_release_dates_cache_path is None:
            logging.warning(
                "Template release dates cache does not exist. Remember to run "
                "scripts/generate_mmcif_cache.py before running OpenFold"
            )

        if _alignment_index is not None:
            self._chain_ids = list(_alignment_index.keys())
        elif mapping_path is None:
            self._chain_ids = list(os.listdir(alignment_dir))
        else:
            with open(mapping_path, "r") as f:
                stripped_lines = (line.strip() for line in f)
                # A blank line would otherwise become an empty chain ID that
                # only fails later, when the corresponding item is requested.
                self._chain_ids = [chain for chain in stripped_lines if chain]

        self._chain_id_to_idx_dict = {
            chain: i for i, chain in enumerate(self._chain_ids)
        }

        template_featurizer = templates.TemplateHitFeaturizer(
            mmcif_dir=template_mmcif_dir,
            max_template_date=max_template_date,
            max_hits=max_template_hits,
            kalign_binary_path=kalign_binary_path,
            release_dates_path=template_release_dates_cache_path,
            obsolete_pdbs_path=obsolete_pdbs_file_path,
            _shuffle_top_k_prefiltered=shuffle_top_k_prefiltered,
        )

        self.data_pipeline = data_pipeline.DataPipeline(
            template_featurizer=template_featurizer,
        )

        # The feature pipeline is only needed when items are post-processed
        # inside the dataset itself.
        if not self._output_raw:
            self.feature_pipeline = feature_pipeline.FeaturePipeline(config)

    def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, _alignment_index):
        """Parse the mmCIF file at ``path`` and run the data pipeline on it.

        Raises:
            The first error recorded by ``mmcif_parsing.parse`` when parsing
            did not produce an mmCIF object.
        """
        with open(path, "r") as f:
            mmcif_string = f.read()

        parsing_result = mmcif_parsing.parse(
            file_id=file_id, mmcif_string=mmcif_string
        )

        if parsing_result.mmcif_object is None:
            raise list(parsing_result.errors.values())[0]

        return self.data_pipeline.process_mmcif(
            mmcif=parsing_result.mmcif_object,
            alignment_dir=alignment_dir,
            chain_id=chain_id,
            _alignment_index=_alignment_index,
        )

    def chain_id_to_idx(self, chain_id):
        """Return the dataset index of ``chain_id``."""
        return self._chain_id_to_idx_dict[chain_id]

    def idx_to_chain_id(self, idx):
        """Return the chain ID stored at dataset index ``idx``."""
        return self._chain_ids[idx]

    def __getitem__(self, idx):
        """Load the chain at ``idx``.

        Returns the raw data pipeline output when ``_output_raw`` is set,
        otherwise the output of the feature pipeline for the current mode.

        Raises:
            ValueError: In train/eval mode, if no ``.cif``, ``.core`` or
                ``.pdb`` file exists for the chain under ``data_dir``.
        """
        name = self.idx_to_chain_id(idx)

        if self._alignment_index is not None:
            # Indexed alignments: hand over the top-level directory together
            # with this chain's index entry.
            alignment_dir = self.alignment_dir
            _alignment_index = self._alignment_index[name]
        else:
            alignment_dir = os.path.join(self.alignment_dir, name)
            _alignment_index = None

        if self.mode in ("train", "eval"):
            # Names look like "<file_id>_<chain_id>"; a name without an
            # underscore carries no chain ID.
            spl = name.rsplit("_", 1)
            if len(spl) == 2:
                file_id, chain_id = spl
            else:
                (file_id,) = spl
                chain_id = None

            path = os.path.join(self.data_dir, file_id)
            if os.path.exists(path + ".cif"):
                data = self._parse_mmcif(
                    path + ".cif",
                    file_id,
                    chain_id,
                    alignment_dir,
                    _alignment_index,
                )
            elif os.path.exists(path + ".core"):
                data = self.data_pipeline.process_core(
                    path + ".core",
                    alignment_dir,
                    _alignment_index,
                )
            elif os.path.exists(path + ".pdb"):
                data = self.data_pipeline.process_pdb(
                    pdb_path=path + ".pdb",
                    alignment_dir=alignment_dir,
                    is_distillation=self.treat_pdb_as_distillation,
                    chain_id=chain_id,
                    _alignment_index=_alignment_index,
                )
            else:
                raise ValueError("Invalid file type")
        else:
            # NOTE(review): this path is relative to the working directory
            # and does not involve self.data_dir -- confirm that is intended.
            path = os.path.join(name, name + ".fasta")
            data = self.data_pipeline.process_fasta(
                fasta_path=path,
                alignment_dir=alignment_dir,
                _alignment_index=_alignment_index,
            )

        if self._output_raw:
            return data

        return self.feature_pipeline.process_features(data, self.mode)

    def __len__(self):
        """Return the number of chains in the dataset."""
        return len(self._chain_ids)


def deterministic_train_filter(
    chain_data_cache_entry: Any,
    max_resolution: float = 9.0,
    max_single_aa_prop: float = 0.8,
) -> bool:
    """Decide whether a chain passes the deterministic training filters.

    Args:
        chain_data_cache_entry:
            Mapping with a "seq" entry and, optionally, a "resolution"
            entry.
        max_resolution:
            Chains whose resolution is greater than this value are rejected.
            Entries without a resolution (missing or None) skip this check.
        max_single_aa_prop:
            Chains in which one residue type makes up more than this
            fraction of the sequence are rejected.

    Returns:
        True if the chain may be used, False otherwise. Chains with an
        empty sequence are rejected.
    """
    resolution = chain_data_cache_entry.get("resolution", None)
    if resolution is not None and resolution > max_resolution:
        return False

    seq = chain_data_cache_entry["seq"]
    if not seq:
        # Nothing to train on; also avoids max() of an empty sequence and a
        # division by zero below.
        return False

    largest_aa_count = max(seq.count(aa) for aa in set(seq))
    largest_single_aa_prop = largest_aa_count / len(seq)
    if largest_single_aa_prop > max_single_aa_prop:
        return False

    return True


def get_stochastic_train_filter_prob(
    chain_data_cache_entry: Any,
) -> float:
    """Return the probability with which a chain is kept by the stochastic filter.

    The result is the product of two factors:

    * ``1 / cluster_size`` when the entry has a positive "cluster_size"
      (missing, None or non-positive values contribute no factor), and
    * the sequence length clamped to ``[256, 512]`` and divided by 512.

    Args:
        chain_data_cache_entry:
            Mapping with a "seq" entry and, optionally, a "cluster_size"
            entry.

    Returns:
        The acceptance probability as a single float.
    """
    prob = 1.0

    cluster_size = chain_data_cache_entry.get("cluster_size", None)
    if cluster_size is not None and cluster_size > 0:
        prob *= 1 / cluster_size

    chain_length = len(chain_data_cache_entry["seq"])
    prob *= (1 / 512) * (max(min(chain_length, 512), 256))

    return prob


class OpenFoldDataset(torch.utils.data.Dataset):
    """Fixed-length "epoch" view over one or more ``OpenFoldSingleDataset``s.

    For every member dataset an endless stream of datapoint indices is
    maintained. The stream walks through shuffled permutations of the
    dataset, drops chains rejected by ``deterministic_train_filter`` and
    keeps each remaining chain with the probability returned by
    ``get_stochastic_train_filter_prob``. ``reroll`` assembles one epoch of
    ``epoch_len`` datapoints by choosing a dataset per datapoint according
    to ``probabilities`` and taking the next index from its stream.
    """

    def __init__(
        self,
        datasets: Sequence[OpenFoldSingleDataset],
        probabilities: Sequence[float],
        epoch_len: int,
        chain_data_cache_paths: List[str],
        generator: Optional[torch.Generator] = None,
        _roll_at_init: bool = True,
    ):
        """
        Args:
            datasets:
                The datasets to sample from.
            probabilities:
                Sampling weight of each dataset, in the same order.
            epoch_len:
                Number of datapoints in one epoch.
            chain_data_cache_paths:
                One JSON file per dataset, mapping chain IDs to the entries
                consumed by the training filters.
            generator:
                Optional generator used for every random draw.
            _roll_at_init:
                If True, draw the first epoch immediately. Otherwise
                ``reroll`` must be called before items are requested.
        """
        self.datasets = datasets
        self.probabilities = probabilities
        self.epoch_len = epoch_len
        self.generator = generator

        self.chain_data_caches = []
        for path in chain_data_cache_paths:
            with open(path, "r") as fp:
                self.chain_data_caches.append(json.load(fp))

        def looped_shuffled_dataset_idx(dataset_len):
            """Yield the indices of endless, freshly shuffled permutations."""
            uniform_weights = torch.ones(dataset_len)
            while True:
                # Sampling every index once without replacement is a shuffle.
                shuf = torch.multinomial(
                    uniform_weights,
                    num_samples=dataset_len,
                    replacement=False,
                    generator=self.generator,
                )
                # Plain ints keep 0-d tensors out of the stored datapoints.
                yield from shuf.tolist()

        def looped_samples(dataset_idx):
            """Yield indices of ``datasets[dataset_idx]`` that pass both filters."""
            # int() truncates, so small epoch_len * probability products
            # would otherwise yield windows without any candidate.
            max_cache_len = max(1, int(epoch_len * probabilities[dataset_idx]))
            dataset = self.datasets[dataset_idx]
            idx_iter = looped_shuffled_dataset_idx(len(dataset))
            chain_data_cache = self.chain_data_caches[dataset_idx]

            # Two permutations' worth of consecutive rejections always
            # contains one complete permutation, i.e. every chain was
            # rejected and waiting longer cannot help.
            max_consecutive_rejections = 2 * len(dataset)
            consecutive_rejections = 0

            while True:
                weights = []
                idx = []
                for _ in range(max_cache_len):
                    candidate_idx = next(idx_iter)
                    chain_id = dataset.idx_to_chain_id(candidate_idx)
                    chain_data_cache_entry = chain_data_cache[chain_id]
                    if not deterministic_train_filter(chain_data_cache_entry):
                        consecutive_rejections += 1
                        continue
                    consecutive_rejections = 0

                    p = get_stochastic_train_filter_prob(
                        chain_data_cache_entry,
                    )
                    weights.append([1.0 - p, p])
                    idx.append(candidate_idx)

                if not idx:
                    # torch.multinomial cannot sample from an empty tensor.
                    if consecutive_rejections >= max_consecutive_rejections:
                        raise RuntimeError(
                            f"No chain of dataset {dataset_idx} passes the "
                            "deterministic training filter"
                        )
                    continue

                # One Bernoulli draw per candidate: column 1 means "keep".
                samples = torch.multinomial(
                    torch.tensor(weights),
                    num_samples=1,
                    generator=self.generator,
                )
                # Only drop the sample dimension; a bare squeeze() would turn
                # a single-candidate window into a non-iterable 0-d tensor.
                keep_flags = samples.squeeze(-1).tolist()

                yield from (i for i, keep in zip(idx, keep_flags) if keep)

        self._samples = [looped_samples(i) for i in range(len(self.datasets))]

        if _roll_at_init:
            self.reroll()

    def __getitem__(self, idx):
        """Return the item selected for position ``idx`` of the current epoch."""
        dataset_idx, datapoint_idx = self.datapoints[idx]
        return self.datasets[dataset_idx][datapoint_idx]

    def __len__(self):
        """Return the epoch length."""
        return self.epoch_len

    def reroll(self):
        """Draw a new epoch of ``(dataset_idx, datapoint_idx)`` pairs."""
        dataset_choices = torch.multinomial(
            torch.tensor(self.probabilities),
            num_samples=self.epoch_len,
            replacement=True,
            generator=self.generator,
        )

        self.datapoints = []
        for dataset_idx in dataset_choices.tolist():
            datapoint_idx = next(self._samples[dataset_idx])
            self.datapoints.append((dataset_idx, datapoint_idx))


class OpenFoldBatchCollator:
    """Collate function that featurizes raw proteins and stacks them into a batch."""

    def __init__(self, config, stage="train"):
        """Build the feature pipeline from ``config`` and remember ``stage``."""
        self.stage = stage
        self.feature_pipeline = feature_pipeline.FeaturePipeline(config)

    def __call__(self, raw_prots):
        """Run every protein through the feature pipeline and stack the results.

        Corresponding entries of the per-protein feature dicts are combined
        with ``torch.stack`` along a new leading dimension.
        """
        batch_members = [
            self.feature_pipeline.process_features(raw_prot, self.stage)
            for raw_prot in raw_prots
        ]
        return dict_multimap(partial(torch.stack, dim=0), batch_members)


class OpenFoldDataLoader(torch.utils.data.DataLoader):
    """DataLoader that attaches randomly drawn per-batch properties to every batch.

    The properties are "use_clamped_fape" (supervised stages only) and
    "no_recycling_iters". After they are added, every tensor in the batch is
    cut along its last dimension to ``no_recycling_iters + 1`` entries.
    """

    def __init__(self, *args, config, stage="train", generator=None, **kwargs):
        """
        Args:
            *args, **kwargs:
                Forwarded to ``torch.utils.data.DataLoader``.
            config:
                Configuration providing ``config[stage]``,
                ``config.common.max_recycling_iters`` and, for supervised
                stages, ``config.supervised.clamp_prob``.
            stage:
                Key of the stage section inside ``config``.
            generator:
                Generator used to draw the batch properties. A fresh one is
                created when omitted.
        """
        super().__init__(*args, **kwargs)
        self.config = config
        self.stage = stage
        self.generator = torch.Generator() if generator is None else generator
        self._prep_batch_properties_probs()

    def _prep_batch_properties_probs(self):
        """Build ``prop_keys`` and the matching table of sampling probabilities.

        Each row of ``prop_probs_tensor`` holds the distribution of one
        property, right-padded with zeros to the length of the longest row.
        """
        cfg = self.config
        stage_cfg = cfg[self.stage]
        n_recycling_options = cfg.common.max_recycling_iters + 1

        table = []
        if stage_cfg.supervised:
            clamp_prob = cfg.supervised.clamp_prob
            table.append(("use_clamped_fape", [1 - clamp_prob, clamp_prob]))

        if stage_cfg.uniform_recycling:
            # Every iteration count is equally likely.
            recycling_probs = [1.0 / n_recycling_options] * n_recycling_options
        else:
            # Always use the maximum number of iterations.
            recycling_probs = [0.0] * n_recycling_options
            recycling_probs[-1] = 1.0
        table.append(("no_recycling_iters", recycling_probs))

        rows = [row for _, row in table]
        width = max(len(row) for row in rows)

        self.prop_keys = tuple(key for key, _ in table)
        self.prop_probs_tensor = torch.tensor(
            [row + [0.0] * (width - len(row)) for row in rows],
            dtype=torch.float32,
        )

    def _add_batch_properties(self, batch):
        """Sample one value per property, store it in ``batch`` and truncate.

        Each sampled value is stored under its key as a tensor expanded to
        ``aatype.shape[:-2] + (aatype.shape[-1],)``.
        """
        drawn = torch.multinomial(
            self.prop_probs_tensor,
            num_samples=1,
            replacement=True,
            generator=self.generator,
        )

        reference = batch["aatype"]
        lead_dims = reference.shape[:-2]
        n_recycle = reference.shape[-1]

        # Without a sampled iteration count the slice below keeps everything.
        keep_upto = n_recycle
        for key, row in zip(self.prop_keys, drawn):
            value = int(row[0])
            scalar = torch.tensor(
                value, device=reference.device, requires_grad=False
            )
            scalar_shape = scalar.shape
            scalar = scalar.view((1,) * len(lead_dims) + scalar_shape + (1,))
            batch[key] = scalar.expand(lead_dims + scalar_shape + (n_recycle,))

            if key == "no_recycling_iters":
                keep_upto = value

        def truncate(t):
            return t[..., : keep_upto + 1]

        return tensor_tree_map(truncate, batch)

    def __iter__(self):
        """Iterate over the parent loader's batches with properties attached."""
        base_iter = super().__iter__()
        return (self._add_batch_properties(batch) for batch in base_iter)


class OpenFoldDataModule(pl.LightningDataModule):
    """Lightning data module that builds the OpenFold datasets and data loaders.

    The module runs in training mode when ``train_data_dir`` is given and in
    inference mode otherwise.
    """

    def __init__(
        self,
        config: mlc.ConfigDict,
        template_mmcif_dir: str,
        max_template_date: str,
        train_data_dir: Optional[str] = None,
        train_alignment_dir: Optional[str] = None,
        train_chain_data_cache_path: Optional[str] = None,
        distillation_data_dir: Optional[str] = None,
        distillation_alignment_dir: Optional[str] = None,
        distillation_chain_data_cache_path: Optional[str] = None,
        val_data_dir: Optional[str] = None,
        val_alignment_dir: Optional[str] = None,
        predict_data_dir: Optional[str] = None,
        predict_alignment_dir: Optional[str] = None,
        kalign_binary_path: str = "/usr/bin/kalign",
        train_mapping_path: Optional[str] = None,
        distillation_mapping_path: Optional[str] = None,
        obsolete_pdbs_file_path: Optional[str] = None,
        template_release_dates_cache_path: Optional[str] = None,
        batch_seed: Optional[int] = None,
        train_epoch_len: int = 50000,
        _alignment_index_path: Optional[str] = None,
        **kwargs,
    ):
        """Store the configuration and validate the directory arguments.

        Most arguments are stored as-is and later forwarded to
        ``OpenFoldSingleDataset`` / ``OpenFoldDataset`` by ``setup``.
        ``batch_seed`` seeds the generator handed to each data loader, and
        ``_alignment_index_path`` is a JSON file loaded into
        ``self._alignment_index``. Extra keyword arguments are accepted and
        ignored.

        Raises:
            ValueError: If neither ``train_data_dir`` nor
                ``predict_data_dir`` is given, or if a data directory is
                given without its alignment directory.
        """
        super(OpenFoldDataModule, self).__init__()

        self.config = config
        self.template_mmcif_dir = template_mmcif_dir
        self.max_template_date = max_template_date
        self.train_data_dir = train_data_dir
        self.train_alignment_dir = train_alignment_dir
        self.train_chain_data_cache_path = train_chain_data_cache_path
        self.distillation_data_dir = distillation_data_dir
        self.distillation_alignment_dir = distillation_alignment_dir
        self.distillation_chain_data_cache_path = distillation_chain_data_cache_path
        self.val_data_dir = val_data_dir
        self.val_alignment_dir = val_alignment_dir
        self.predict_data_dir = predict_data_dir
        self.predict_alignment_dir = predict_alignment_dir
        self.kalign_binary_path = kalign_binary_path
        self.train_mapping_path = train_mapping_path
        self.distillation_mapping_path = distillation_mapping_path
        self.template_release_dates_cache_path = template_release_dates_cache_path
        self.obsolete_pdbs_file_path = obsolete_pdbs_file_path
        self.batch_seed = batch_seed
        self.train_epoch_len = train_epoch_len

        if self.train_data_dir is None and self.predict_data_dir is None:
            raise ValueError(
                "At least one of train_data_dir or predict_data_dir must be "
                "specified"
            )

        # Training takes precedence when both directories are given.
        self.training_mode = self.train_data_dir is not None

        if self.training_mode and train_alignment_dir is None:
            raise ValueError("In training mode, train_alignment_dir must be specified")
        elif not self.training_mode and predict_alignment_dir is None:
            raise ValueError(
                "In inference mode, predict_alignment_dir must be specified"
            )
        elif val_data_dir is not None and val_alignment_dir is None:
            raise ValueError(
                "If val_data_dir is specified, val_alignment_dir must "
                "be specified as well"
            )

        self._alignment_index = None
        if _alignment_index_path is not None:
            with open(_alignment_index_path, "r") as fp:
                self._alignment_index = json.load(fp)

    def setup(self, stage: Optional[str] = None):
        """Create the datasets for the current mode.

        In training mode this sets ``train_dataset`` and ``eval_dataset``
        (``None`` without ``val_data_dir``); in inference mode it sets
        ``predict_dataset``.

        Args:
            stage:
                Accepted so that the hook can be invoked with a stage
                argument; it is not used.
        """
        # Arguments shared by every dataset created below.
        dataset_gen = partial(
            OpenFoldSingleDataset,
            template_mmcif_dir=self.template_mmcif_dir,
            max_template_date=self.max_template_date,
            config=self.config,
            kalign_binary_path=self.kalign_binary_path,
            template_release_dates_cache_path=self.template_release_dates_cache_path,
            obsolete_pdbs_file_path=self.obsolete_pdbs_file_path,
        )

        if self.training_mode:
            train_dataset = dataset_gen(
                data_dir=self.train_data_dir,
                alignment_dir=self.train_alignment_dir,
                mapping_path=self.train_mapping_path,
                max_template_hits=self.config.train.max_template_hits,
                shuffle_top_k_prefiltered=self.config.train.shuffle_top_k_prefiltered,
                treat_pdb_as_distillation=False,
                mode="train",
                _output_raw=True,
                _alignment_index=self._alignment_index,
            )

            distillation_dataset = None
            if self.distillation_data_dir is not None:
                distillation_dataset = dataset_gen(
                    data_dir=self.distillation_data_dir,
                    alignment_dir=self.distillation_alignment_dir,
                    mapping_path=self.distillation_mapping_path,
                    max_template_hits=self.config.train.max_template_hits,
                    treat_pdb_as_distillation=True,
                    mode="train",
                    _output_raw=True,
                )

            if distillation_dataset is not None:
                datasets = [train_dataset, distillation_dataset]
                d_prob = self.config.train.distillation_prob
                probabilities = [1 - d_prob, d_prob]
                chain_data_cache_paths = [
                    self.train_chain_data_cache_path,
                    self.distillation_chain_data_cache_path,
                ]
            else:
                datasets = [train_dataset]
                probabilities = [1.0]
                chain_data_cache_paths = [
                    self.train_chain_data_cache_path,
                ]

            # The first epoch is drawn when the train data loader is built.
            self.train_dataset = OpenFoldDataset(
                datasets=datasets,
                probabilities=probabilities,
                epoch_len=self.train_epoch_len,
                chain_data_cache_paths=chain_data_cache_paths,
                _roll_at_init=False,
            )

            if self.val_data_dir is not None:
                self.eval_dataset = dataset_gen(
                    data_dir=self.val_data_dir,
                    alignment_dir=self.val_alignment_dir,
                    mapping_path=None,
                    max_template_hits=self.config.eval.max_template_hits,
                    mode="eval",
                    _output_raw=True,
                )
            else:
                self.eval_dataset = None
        else:
            # NOTE(review): unlike the train/eval datasets this one is not
            # created with _output_raw=True, while the collator used by
            # _gen_dataloader runs the feature pipeline as well -- confirm
            # that featurizing twice is intended.
            self.predict_dataset = dataset_gen(
                data_dir=self.predict_data_dir,
                alignment_dir=self.predict_alignment_dir,
                mapping_path=None,
                max_template_hits=self.config.predict.max_template_hits,
                mode="predict",
            )

    def _gen_dataloader(self, stage):
        """Build the ``OpenFoldDataLoader`` for "train", "eval" or "predict".

        Raises:
            ValueError: If ``stage`` is not one of the three stages.
        """
        generator = torch.Generator()
        if self.batch_seed is not None:
            generator = generator.manual_seed(self.batch_seed)

        dataset = None
        if stage == "train":
            dataset = self.train_dataset

            # Draw a fresh epoch of datapoints for this loader.
            dataset.reroll()
        elif stage == "eval":
            dataset = self.eval_dataset
        elif stage == "predict":
            dataset = self.predict_dataset
        else:
            raise ValueError("Invalid stage")

        batch_collator = OpenFoldBatchCollator(self.config, stage)

        dl = OpenFoldDataLoader(
            dataset,
            config=self.config,
            stage=stage,
            generator=generator,
            batch_size=self.config.data_module.data_loaders.batch_size,
            num_workers=self.config.data_module.data_loaders.num_workers,
            collate_fn=batch_collator,
        )

        return dl

    def train_dataloader(self):
        """Return the training data loader."""
        return self._gen_dataloader("train")

    def val_dataloader(self):
        """Return the validation data loader, or None without an eval dataset."""
        if self.eval_dataset is not None:
            return self._gen_dataloader("eval")
        return None

    def predict_dataloader(self):
        """Return the prediction data loader."""
        return self._gen_dataloader("predict")


class DummyDataset(torch.utils.data.Dataset):
    """Dataset of nominal length 1000 that returns copies of one pickled batch."""

    def __init__(self, batch_path):
        """Load the pickled batch stored at ``batch_path``."""
        # NOTE: unpickling can execute arbitrary code; only point this at
        # trusted files.
        with open(batch_path, "rb") as handle:
            payload = handle.read()
        self.batch = pickle.loads(payload)

    def __len__(self):
        """Return the fixed dataset length."""
        return 1000

    def __getitem__(self, idx):
        """Return a deep copy of the stored batch; ``idx`` is ignored."""
        clone = copy.deepcopy(self.batch)
        return clone


class DummyDataLoader(pl.LightningDataModule):
    """Data module whose training loader serves a ``DummyDataset``."""

    def __init__(self, batch_path):
        """Create the underlying ``DummyDataset`` from ``batch_path``."""
        super(DummyDataLoader, self).__init__()
        self.dataset = DummyDataset(batch_path)

    def train_dataloader(self):
        """Return a default-configured DataLoader over the dummy dataset."""
        loader = torch.utils.data.DataLoader(self.dataset)
        return loader
