import argparse
import os
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple

import numpy as np
import pandas as pd
import torch

from constants import MISSING_PROTEINS


class GetDataFiles:
    """Collect sequences, metadata and feature file paths for a dataset split.

    The constructor reads the metadata CSV named by ``config.dataset.metadata``
    (relative to ``config.common.data_path``), keeps the columns ``name``,
    ``seq``, ``length``, ``<ontology>_one_hot`` and the split / threshold /
    cluster columns for the configured similarity metric and seed, and drops
    rows whose ``<ontology>_one_hot`` value is missing.

    Calling the instance returns a dict of parallel lists (plus a few scalar
    paths) for one split; see ``__call__``.
    """

    def __init__(self, config):
        """Resolve feature directories from ``config`` and load the metadata table.

        ``config`` must support attribute access (``config.common.data_path``)
        and ``config.features`` must additionally support ``.get(...)`` and
        item access, as used below.
        """
        self.config = config
        self.data_dir = Path(self.config.common.data_path)
        self.ontology = self.config.dataset.ontology
        self.features = self.config.features
        self.use_lm = self.features.get('use_lm')
        self.use_af = self.features.get('use_af')

        if self.use_lm:
            self.lm_dir = self.data_dir / Path(self.features.lm.dir)
        if self.use_af:
            af_cfg = self.features.af
            self.use_single = af_cfg.use_single_rep
            self.use_states = af_cfg.use_states_rep
            self.use_pair = af_cfg.use_pair_rep

            # "states" takes precedence when both single flags are set.
            # NOTE: `single_rep` stays undefined when neither flag is set
            # (pair-only); `__call__` then fails with AttributeError if
            # `features.standardize` is also enabled.
            if self.use_states:
                self.single_rep = "states"
            elif self.use_single:
                self.single_rep = "single"
            if self.use_single or self.use_states:
                self.af_single_dir = self.data_dir / Path(af_cfg.dir) / Path(self.single_rep)
            if self.use_pair:
                self.pair_rep = "pair"
                self.af_pair_dir = self.data_dir / Path(af_cfg.dir) / Path(self.pair_rep)
            self.af_iter = af_cfg.recycling_iter

        self.labels_dir = os.path.join(self.data_dir, f"{self.ontology}_{self.config.common.labels_path}")
        self.scaling_dir = self.data_dir / Path(self.config.common.scaling_dir)
        self.sim_metric = self.config.dataset.similarity_metric
        self.seed = self.config.dataset.seed

        label_col = f"{self.ontology}_one_hot"
        metadata = pd.read_csv(os.path.join(self.data_dir, self.config.dataset.metadata))
        metadata = metadata[[
            'name', 'seq', 'length', label_col,
            self._column("split"), self._column("thr"), self._column("clust"),
        ]]
        self.df = metadata[metadata[label_col].notna()]

    def _column(self, kind: str) -> str:
        """Return the metadata column name ``<sim_metric>_<kind>_<seed>``."""
        return f"{self.sim_metric}_{kind}_{self.seed}"

    @staticmethod
    def _parse_indices(encoded: str) -> List[int]:
        """Parse a string such as ``"[1, 5, 9]"`` into a list of ints.

        Empty lists (``"[]"``) yield ``[]`` and whitespace around the commas
        is optional.
        """
        return [int(token) for token in encoded.strip('][').split(',') if token.strip()]

    def get_labels(self) -> Dict[str, Set[int]]:
        """Return, per split, the set of integers found in ``<ontology>_one_hot``.

        Returns:
            Dict with keys ``"train"``, ``"val"`` and ``"test"``; each value is
            the union of the parsed integer lists of all rows in that split.
        """
        split_col = self._column("split")
        label_col = f"{self.ontology}_one_hot"
        labels = {}
        for split in ("train", "val", "test"):
            encoded_rows = self.df.loc[self.df[split_col] == split, label_col]
            labels[split] = {
                idx for encoded in encoded_rows for idx in self._parse_indices(encoded)
            }
        return labels

    def __call__(self,
                 split: str,
                 max_length: int,
                 missing_proteins: Optional[List[str]]=MISSING_PROTEINS
                 ):
        """Build the file listing for one split.

        Args:
            split: Value matched against the ``<sim_metric>_split_<seed>`` column.
            max_length: Rows whose ``length`` exceeds this value are skipped.
            missing_proteins: Protein names to exclude; a falsy value disables
                the exclusion.

        Returns:
            Dict with parallel lists ``seqs``, ``labels`` (``.npy`` paths),
            ``lengths``, ``one_hots``, ``thresholds`` and ``clusters``.
            Depending on the feature configuration it also contains
            ``lm_embeds``, ``lm_mean``/``lm_std``, ``af_single``, ``af_pair``,
            ``af_mean``/``af_std`` and the coordinate path lists ``ca``,
            ``cb``, ``c``, ``n``.
        """
        df = self.df
        if missing_proteins:
            df = df[~df["name"].isin(missing_proteins)]
        df = df[df[self._column("split")] == split]
        df = df[df["length"] <= max_length]

        names = df["name"].tolist()
        outputs = {
            'seqs': df["seq"].tolist(),
            'labels': [os.path.join(self.labels_dir, name + '.npy') for name in names],
            'lengths': df["length"].tolist(),
            'one_hots': df[f"{self.ontology}_one_hot"].tolist(),
            'thresholds': df[self._column("thr")].tolist(),
            'clusters': df[self._column("clust")].tolist(),
        }

        if self.use_lm:
            outputs['lm_embeds'] = [os.path.join(self.lm_dir, name + '.pt') for name in names]
            if self.features.standardize:
                outputs['lm_mean'] = self.scaling_dir / Path(f"{self.ontology}_mean_{self.features['lm']['dir']}.pt")
                outputs['lm_std'] = self.scaling_dir / Path(f"{self.ontology}_std_{self.features['lm']['dir']}.pt")

        if self.use_af:
            if self.use_single or self.use_states:
                outputs['af_single'] = [
                    self.af_single_dir / Path(name) / Path(f"{name}_{self.single_rep}_{self.af_iter}.pt")
                    for name in names
                ]
            if self.use_pair:
                outputs['af_pair'] = [
                    self.af_pair_dir / Path(name) / Path(f"{name}_{self.pair_rep}_{self.af_iter}.pt")
                    for name in names
                ]
            if self.features.standardize:
                # Scaling statistics are keyed by the single/states rep name.
                outputs['af_mean'] = self.scaling_dir / Path(f"{self.ontology}_mean_af_{self.single_rep}_{self.af_iter}.pt")
                outputs['af_std'] = self.scaling_dir / Path(f"{self.ontology}_std_af_{self.single_rep}_{self.af_iter}.pt")

        if 'structure' in self.features.inputs:
            # NOTE(review): the CB directory is read from `coords_b_path`
            # (not `coords_cb_path`) — kept as in the original; confirm the key.
            coord_dirs = {
                'ca': self.config.coords_ca_path,
                'cb': self.config.coords_b_path,
                'c': self.config.coords_c_path,
                'n': self.config.coords_n_path,
            }
            for key, rel_dir in coord_dirs.items():
                outputs[key] = [os.path.join(self.data_dir, rel_dir, name + '.npy') for name in names]
        return outputs


class EmbeddingDataset(torch.utils.data.Dataset):
    """Map-style torch dataset for loading embeddings.

    ``data`` is a dict of parallel lists: ``labels`` (``.npy`` paths),
    ``one_hots``, ``lengths``, ``thresholds``, ``clusters`` and, depending on
    ``config.features``, ``lm_embeds``, ``af_single``, ``af_pair`` plus the
    scaling-statistics paths ``lm_mean``/``lm_std`` and ``af_mean``/``af_std``.

    Each item is a tuple
    ``(name, one_hot, length, threshold, cluster, y, x[, x_mean, x_std][, x_pair])``.
    In pair-only mode (AF pair rep enabled, no LM / single / states rep) the
    item is ``(name, one_hot, length, threshold, cluster, y, x_pair)``.
    """

    def __init__(self, config, data: Dict):
        """Store the file lists from ``data`` that the configured features need."""
        super().__init__()
        self.config = config
        self.features = self.config.features
        self.label_files = data['labels']
        self.one_hots = data['one_hots']
        self.lengths = data['lengths']
        self.thresholds = data['thresholds']
        self.clusters = data['clusters']
        # Key into the loaded LM file: pooled or per-residue representations.
        self.lm_rep = 'mean_representations' if self.features.embedding_type == 'mean' else 'representations'
        # Sample names are the label file names without extension.
        self.names = [Path(file).stem for file in self.label_files]

        if self.features.get('use_lm'):
            self.lm_embeds = data['lm_embeds']
            self.lm_layer = self.features.lm.layer
            assert len(self.lm_embeds) == len(self.label_files)
            if self.features.standardize:
                self.lm_mean = data['lm_mean']
                self.lm_std = data['lm_std']

        if self.features.get('use_af'):
            self.use_single = self.features.af.use_single_rep
            self.use_states = self.features.af.use_states_rep
            self.use_pair = self.features.af.use_pair_rep

            self.struct_nlayer = self.features.af.struct_nlayer
            self.start_nlayer = self.features.af.start_concat_layer
            if self.use_single or self.use_states:
                self.af_single_reps = data['af_single']
                assert len(self.af_single_reps) == len(self.label_files)
            if self.use_pair:
                self.af_pair_reps = data['af_pair']
                assert len(self.af_pair_reps) == len(self.label_files)

            if self.features.standardize:
                self.af_mean = data['af_mean']
                self.af_std = data['af_std']

    def __len__(self):
        """Return the total number of samples."""
        return len(self.label_files)

    @staticmethod
    def _layer_file(path, layer) -> str:
        """Return ``<path without extension>_l_<layer>.pt``.

        Only the final suffix is removed, so dots elsewhere in the path
        (e.g. ``./data`` or ``v1.0/``) are left intact.
        """
        return f"{Path(path).with_suffix('')}_l_{layer}.pt"

    def _load_states(self, single_file, standardize):
        """Load the per-layer "states" representation for one sample.

        Returns ``(x, x_mean, x_std)``; the statistics are ``None`` unless
        ``standardize`` is true. With ``features.af.concat_states_reps`` the
        layers ``start_nlayer..struct_nlayer`` (inclusive) are concatenated
        along the last axis, otherwise only layer ``struct_nlayer`` is loaded.
        """
        x_mean = x_std = None
        if self.features.af.concat_states_reps:
            layers = range(self.start_nlayer, self.struct_nlayer + 1)
            x = np.concatenate(
                [torch.load(self._layer_file(single_file, i)) for i in layers], axis=-1
            )
            if standardize:
                x_mean = np.concatenate(
                    [torch.load(self._layer_file(self.af_mean, i)) for i in layers], axis=-1
                )
                x_std = np.concatenate(
                    [torch.load(self._layer_file(self.af_std, i)) for i in layers], axis=-1
                )
        else:
            x = torch.load(self._layer_file(single_file, self.struct_nlayer))
            if standardize:
                x_mean = torch.load(self._layer_file(self.af_mean, self.struct_nlayer))
                x_std = torch.load(self._layer_file(self.af_std, self.struct_nlayer))
        return x, x_mean, x_std

    def __getitem__(self, index):
        """Load and return one sample as a tuple (layout in the class docstring).

        Raises:
            ValueError: if the configuration enables no feature source.
        """
        standardize = self.features.standardize
        use_af = self.features.get('use_af')

        y = np.load(self.label_files[index])
        meta = (
            self.names[index],
            self.one_hots[index],
            self.lengths[index],
            self.thresholds[index],
            self.clusters[index],
            y,
        )

        x = x_mean = x_std = None
        if self.features.get('use_lm'):
            x = torch.load(self.lm_embeds[index])[self.lm_rep][self.lm_layer]
            if standardize:
                x_mean = torch.load(self.lm_mean)
                x_std = torch.load(self.lm_std)

        # An AF single/states representation replaces the LM embedding as `x`.
        if use_af and self.use_single:
            x = torch.load(self.af_single_reps[index])
            if standardize:
                x_mean = torch.load(self.af_mean)
                x_std = torch.load(self.af_std)
        elif use_af and self.use_states:
            x, x_mean, x_std = self._load_states(self.af_single_reps[index], standardize)

        outputs = None
        if x is not None:
            outputs = meta + (x,)
            if standardize:
                outputs += (x_mean, x_std,)

        if use_af and self.use_pair:
            x_pair = torch.load(self.af_pair_reps[index])
            # Pair-only mode: the pair rep takes the place of `x`.
            outputs = meta + (x_pair,) if outputs is None else outputs + (x_pair,)

        if outputs is None:
            raise ValueError(
                "No feature source enabled: set features.use_lm or features.use_af "
                "(with a single, states or pair representation)."
            )
        return outputs



def _pad_pair_reps(pair_reps, max_len: int) -> np.ndarray:
    """Zero-pad each [N, N, D] pair representation to [max_len, max_len, D] and stack."""
    return np.stack([
        np.pad(rep, ((0, max_len - rep.shape[0]), (0, max_len - rep.shape[0]), (0, 0)))
        for rep in pair_reps
    ])


def padding_collator(
        batch: List[Tuple],
        standardize: bool,
        use_lm: bool,
        use_af: bool,
        use_single: bool,
        use_pair: bool
    ) -> Dict[str, torch.Tensor]:
    """Zero-pad variable-length samples and assemble them into one batch dict.

    Each element of ``batch`` is a tuple
    ``(name, one_hot, length, threshold, cluster, label, x, ...)`` where the
    trailing entries depend on the flags:

    * ``standardize``: indices 7 and 8 hold the feature mean and std.
    * ``use_af and use_pair``: a pair representation follows at index 9
      (with ``standardize``) or 7 (without). When neither ``use_single`` nor
      ``use_lm`` is set, the pair representation itself sits at index 6.

    Args:
        batch: Samples to collate.
        standardize: Whether samples carry mean/std entries.
        use_lm: Whether index 6 holds a language-model embedding.
        use_af: Whether AF representations are in use.
        use_single: Whether index 6 holds an AF single representation.
        use_pair: Whether samples carry an AF pair representation.

    Returns:
        Dict with tensors ``embeds`` and ``labels`` (float), the numpy arrays
        ``names``, ``one_hots``, ``lengths``, ``thresholds``, ``clusters``, and
        optionally tensors ``feats_mean``, ``feats_std`` and ``af_pairs``.

    Raises:
        ValueError: if the item at index 6 is not 1-, 2- or 3-dimensional.
    """
    with_pairs = use_af and use_pair
    pair_only = with_pairs and not (use_single or use_lm)

    primary = [item[6] for item in batch]
    ndim = primary[0].ndim

    if pair_only:
        # Index 6 is the [N_res, N_res, D] pair rep: both residue axes must be
        # padded, otherwise stacking fails for samples of different length.
        max_len = max(rep.shape[0] for rep in primary)
        embs = _pad_pair_reps(primary, max_len)
    elif ndim == 1:    # sequence  [N_res]
        max_len = max(rep.shape[0] for rep in primary)
        embs = np.stack([np.pad(rep, (0, max_len - len(rep))) for rep in primary])
    elif ndim == 2:    # embedding  [N_res, embed_dim]
        max_len = max(rep.shape[0] for rep in primary)
        embs = np.stack([np.pad(rep, ((0, max_len - len(rep)), (0, 0))) for rep in primary])
    elif ndim == 3:    # embedding [M, N_res, embed_dim] -> pad along the N_res axis
        max_len = max(rep.shape[1] for rep in primary)
        embs = np.stack([
            np.pad(rep, ((0, 0), (0, max_len - rep.shape[1]), (0, 0))) for rep in primary
        ])
    else:
        raise ValueError(f"Unsupported embedding rank {ndim}; expected 1, 2 or 3 dimensions.")

    names = np.stack([item[0] for item in batch])
    one_hots = np.stack([item[1] for item in batch])
    lengths = np.stack([item[2] for item in batch])
    thresholds = np.stack([item[3] for item in batch])
    clusters = np.stack([item[4] for item in batch])
    labels = np.stack([item[5] for item in batch]).astype(float)

    outputs = {
        "embeds": torch.tensor(embs),
        "labels": torch.tensor(labels),
        "names": names,
        "one_hots": one_hots,
        "lengths": lengths,
        "thresholds": thresholds,
        "clusters": clusters
    }

    if standardize:
        means = np.stack([item[7] for item in batch])
        stds = np.stack([item[8] for item in batch])
        outputs["feats_mean"] = torch.tensor(means)
        outputs["feats_std"] = torch.tensor(stds)

    if with_pairs:
        if pair_only:
            pair_index = 6
        else:
            # af_single_rep + pair rep, or lm_rep + pair rep: the pair follows
            # the optional mean/std entries.
            pair_index = 9 if standardize else 7
        pairs = _pad_pair_reps([item[pair_index] for item in batch], max_len)
        outputs["af_pairs"] = torch.tensor(pairs)
    return outputs

