# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import os
import torch
import torch.nn.functional as F
import collections
import pandas as pd
from torch.utils.data import Dataset
import logging
logger = logging.getLogger('gearnet_dataset')
import pathlib
from s3fs import S3FileSystem
import cpdb
from typing import Tuple, Optional
from torchdrug.data import Protein
from affinityenhancer.data.datasets.gearnet_data_utils \
    import pdb_to_gearnet_protein, get_edges_from_gearnet_struct
from affinityenhancer.data.utils.utils import get_noise



def select_indices_along_dim(tensor, dim, indices):
    """Select ``indices`` along dimension ``dim`` of ``tensor``, keeping all other dims.

    Equivalent to ``tensor[:, ..., indices, ..., :]`` with ``indices`` at position
    ``dim`` (negative ``dim`` counts from the end, as with Python list indexing).

    :param tensor: Tensor to index
    :param dim: Dimension to select along
    :param indices: Sequence (or index tensor) of positions along ``dim``
    :return: Advanced-indexed copy of ``tensor``
    """
    # Full slice on every dimension except the one being selected.
    index = [slice(None)] * tensor.dim()
    index[dim] = indices
    # Index with a tuple: a *list* of slices only works through a deprecated
    # "treat sequence as tuple" heuristic (already an error in NumPy).
    return tensor[tuple(index)]


def construct_2d_padded_tensor(M1, M2, M3, M4, X1, Y1, pad_value):
    """Assemble four blocks into one square matrix with padding after each group.

    With ``X = M1.shape[0]`` and ``Y = M4.shape[0]`` the result has side
    ``X + X1 + Y + Y1`` and the layout::

        [ M1   pad  M2   pad ]   X  rows
        [ pad  pad  pad  pad ]   X1 rows
        [ M3   pad  M4   pad ]   Y  rows
        [ pad  pad  pad  pad ]   Y1 rows

    :param M1: (X, X) top-left block
    :param M2: (X, Y) top-right block
    :param M3: (Y, X) bottom-left block
    :param M4: (Y, Y) bottom-right block
    :param X1: Number of padding rows/cols inserted after the X group
    :param Y1: Number of padding rows/cols inserted after the Y group
    :param pad_value: Fill value for the padding
    :return: Square tensor on ``M1``'s device. The padding uses ``torch.full``'s
        default dtype for ``pad_value``, and ``torch.cat`` type-promotes it with
        the blocks.
    """
    X, _ = M1.shape
    Y, _ = M4.shape

    def pad(rows, cols):
        # Create padding on the blocks' device; CPU-only padding breaks torch.cat for CUDA inputs.
        return torch.full((rows, cols), pad_value, device=M1.device)

    block_rows = [
        (M1, pad(X, X1), M2, pad(X, Y1)),
        (pad(X1, X), pad(X1, X1), pad(X1, Y), pad(X1, Y1)),
        (M3, pad(Y, X1), M4, pad(Y, Y1)),
        (pad(Y1, X), pad(Y1, X1), pad(Y1, Y), pad(Y1, Y1)),
    ]
    return torch.cat([torch.cat(row, dim=1) for row in block_rows], dim=0)
    

def transform_2d_to_heavy_light(chain_id, input, max_heavy_len, max_light_len,
                                value=20):
    """Re-lay out a residue-by-residue matrix into fixed-size heavy/light blocks.

    Residues with ``chain_id == 0`` are treated as the heavy chain and are assumed
    to come first; the remaining residues (up to ``len(chain_id)``) are the light
    chain. The output has side ``max_heavy_len + max_light_len``. Heavy rows and
    columns are padded to ``max_heavy_len``, light ones to ``max_light_len``.

    :param chain_id: 1-D per-residue chain index
    :param input: Square (n_res, n_res) matrix, e.g. an edge/contact matrix
    :param max_heavy_len: Padded length of the heavy chain
    :param max_light_len: Padded length of the light chain
    :param value: Fill value for padding
    :return: (max_heavy_len + max_light_len) square tensor
    """
    n_res = chain_id.shape[0]
    heavy_len = int((chain_id == 0).sum())
    light_len = n_res - heavy_len

    heavy_heavy = input[:heavy_len, :heavy_len]
    heavy_light = input[:heavy_len, heavy_len:n_res]
    # Take the real light->heavy block rather than transposing heavy_light:
    # the two only coincide when the input matrix is symmetric.
    light_heavy = input[heavy_len:n_res, :heavy_len]
    light_light = input[heavy_len:n_res, heavy_len:n_res]

    return construct_2d_padded_tensor(heavy_heavy,
                                      heavy_light,
                                      light_heavy,
                                      light_light,
                                      max_heavy_len - heavy_len,
                                      max_light_len - light_len,
                                      value
                                      )


def transform_to_heavy_light(chain_id, input, max_heavy_len, max_light_len,
                             value=20):
    """Split per-residue data into heavy/light chains and pad each to a fixed length.

    Residues with ``chain_id == 0`` are treated as the heavy chain and are assumed
    to come first; the remaining residues (up to ``len(chain_id)``) are the light
    chain. Padding is appended along dim 0 only, so any trailing feature dims
    are preserved.

    NOTE: ``F.pad`` with a negative amount crops, so a chain longer than its
    maximum is silently truncated rather than rejected.

    :param chain_id: 1-D per-residue chain index
    :param input: Tensor whose dim 0 indexes residues (1-D, 2-D or higher)
    :param max_heavy_len: Padded length of the heavy chain
    :param max_light_len: Padded length of the light chain
    :param value: Fill value for padding
    :return: Tensor with ``max_heavy_len + max_light_len`` rows: heavy, then light
    """
    n_res = chain_id.shape[0]
    heavy_len = int((chain_id == 0).sum())
    heavy = input[:heavy_len]
    light = input[heavy_len:n_res]

    def _pad_rows(block, target_len):
        # F.pad takes (left, right) pairs starting at the LAST dim; only the
        # final pair (dim 0) is non-zero, so other dims are left untouched.
        pad = [0, 0] * (block.dim() - 1) + [0, target_len - block.shape[0]]
        return torch.nn.functional.pad(block, pad, value=value)

    return torch.cat([_pad_rows(heavy, max_heavy_len),
                      _pad_rows(light, max_light_len)], dim=0)
        

class GearNetDataset(Dataset):
    """Binder-classification dataset backed by cached torchdrug ``Protein`` structures.

    Each row of ``df`` names a structure through its ``seqid`` column. The
    matching ``<seqid>.pdb`` is expected in ``data_dir``; when missing, it is
    copied from ``s3_pdb_path`` with the AWS CLI. It is converted once with
    ``pdb_to_gearnet_protein`` and cached alongside as ``<seqid>.pt``.
    Items are ``(protein, one_hot(is_binder))``.
    """

    def __init__(
        self,
        df,
        s3_pdb_path,
        data_dir: str = str(pathlib.Path(__file__).parent.parent / "data"),
        in_memory: bool = True,
        bulk: bool = False,
        add_noise: bool = False,
        noise_var: float = 0.005,
        max_files: int = 104413,
        chain_map: Optional[dict] = None,
        add_cdr: bool = False,
        ):
        """
        :param df: Dataset dataframe; must contain a ``seqid`` column
            (``__getitem__`` of this class also reads ``is_binder``)
        :param s3_pdb_path: Location of the ``<seqid>.pdb`` files; the S3 listing
            and bulk download only happen when it starts with ``s3:``
        :param data_dir: Directory to store PDBs and cached ``.pt`` files (created if missing)
        :param in_memory: Whether to load all processed structures into memory (faster)
        :param bulk: Download the whole S3 prefix recursively instead of one file per ``seqid``
        :param add_noise: Forwarded to ``pdb_to_gearnet_protein`` when a structure is
            processed. Processing happens once per cache file, so this only affects
            ``.pt`` files that do not exist yet.
        :param noise_var: Forwarded to ``pdb_to_gearnet_protein``
        :param max_files: Stored as ``self.max_files``; not used by this class
        :param chain_map: Chain letter -> chain index. Defaults to H -> 0, L -> 1 and
            every other chain -> 2
        :param add_cdr: Forwarded to ``pdb_to_gearnet_protein``
        """
        self.df = df
        # Drop one trailing slash so paths can be joined as f"{prefix}/{id}.pdb".
        self.s3_pdb_path = s3_pdb_path[:-1] if s3_pdb_path.endswith('/') else s3_pdb_path
        self.bulk = bulk
        self.add_noise = add_noise
        self.noise_var = noise_var
        self.in_memory = in_memory

        self.max_files = max_files
        self.add_cdr = add_cdr

        # N.B. some structures contain extra chains (likely multimeric antigens).
        # As a hack, every chain other than H and L is mapped to the same index (2).
        # Whether that is the right treatment is still open for discussion.
        if chain_map is None:
            self.CHAIN_MAP = collections.defaultdict(lambda: 2)
            self.CHAIN_MAP["H"] = 0
            self.CHAIN_MAP["L"] = 1
            self.CHAIN_MAP["A"] = 2
        else:
            self.CHAIN_MAP = chain_map

        self.DATA_DIR = pathlib.Path(data_dir)
        self.DATA_DIR.mkdir(parents=True, exist_ok=True)
        logger.info(f"Using data directory: {self.DATA_DIR}")
        self._setup()

    def _setup(self):
        """Make sure PDBs are available, build the ``.pt`` cache and optionally load it."""
        from_s3 = self.s3_pdb_path.startswith('s3:')
        if from_s3:
            # The listing is only used to report how many PDBs to expect.
            s3_pdb_fnames = S3FileSystem().ls(self.s3_pdb_path)
            logger.info(f"Expected PDBs: {min(len(s3_pdb_fnames), self.df.shape[0])}")
        else:
            logger.info(f"{self.s3_pdb_path} is not an s3 path. "
                        f"Looking for pdbs in self.DATA_DIR = {self.DATA_DIR}")

        logger.info(f"Dataset size: {len(self.df)}")
        present_files = os.listdir(self.DATA_DIR)
        present_pdbs = [f for f in present_files if f.endswith(".pdb")]
        present_processed_examples = [f for f in present_files if f.endswith(".pt")]

        logger.info(f"Found {len(present_pdbs)} structures in {self.DATA_DIR}")
        logger.info(f"Found {len(present_processed_examples)} processed structures in {self.DATA_DIR}")

        self.structure_ids = self.df.seqid.tolist()

        # Heuristic: compares the number of .pdb files in the directory against
        # the number of rows, not against the specific ids that are needed.
        if len(present_pdbs) < self.df.shape[0]:
            if from_s3:
                logger.info(
                    f"Expected {self.df.shape[0]} PDBs, found {len(present_pdbs)} in {self.DATA_DIR}"
                    )
                self._download_structures()
            else:
                raise FileNotFoundError(f"{self.DATA_DIR} does not have expected number of pdbs: {self.df.shape[0]}")

        self._process_structures()
        if self.in_memory:
            self.examples = self._load_examples()

    def __len__(self) -> int:
        """Returns length of dataset"""
        if self.df is None:
            raise ValueError("Please initialize self.df by passing a dataframe to the dataset constructor")
        return self.df.shape[0]

    def _copy_pdb_from_s3(self, struct_id):
        """Best-effort copy of ``<struct_id>.pdb`` from ``s3_pdb_path`` into ``DATA_DIR`` via the AWS CLI."""
        import subprocess

        # Pass an argument list (no shell) so ids taken from the dataframe cannot be
        # interpreted as shell syntax. The exit status is deliberately not checked,
        # matching the previous os.system call; callers verify the file afterwards.
        subprocess.run(
            ["aws", "s3", "cp", f"{self.s3_pdb_path}/{struct_id}.pdb", f"{self.DATA_DIR}/."],
            check=False,
        )

    def _process_structures(self, save_positions=False):
        """
        Pre-processes and caches protein datastructures in TorchProtein format.

        For every distinct ``seqid`` without a ``<seqid>.pt`` cache, this converts
        ``<seqid>.pdb`` with ``pdb_to_gearnet_protein`` and saves the result with
        ``torch.save``. A missing PDB is fetched first.

        :param save_positions: Also save the positions returned by
            ``pdb_to_gearnet_protein`` as ``xyz_<seqid>.xyz.pt``
        :raises FileNotFoundError: if a PDB is still missing after the copy attempt
        """
        # dict.fromkeys de-duplicates while keeping a deterministic, first-seen order.
        for struct_id in dict.fromkeys(self.structure_ids):
            pt_path = self.DATA_DIR / f"{struct_id}.pt"
            if os.path.exists(pt_path):
                continue

            pdb_path = self.DATA_DIR / f"{struct_id}.pdb"
            if not os.path.exists(pdb_path):
                logger.warning(f"File {struct_id}.pdb not found in {self.DATA_DIR}")
                self._copy_pdb_from_s3(struct_id)
                if not os.path.exists(pdb_path):
                    raise FileNotFoundError(
                        f"Could not fetch {struct_id}.pdb from {self.s3_pdb_path} into {self.DATA_DIR}"
                    )

            logger.info(f"Processing {struct_id} in {self.DATA_DIR}")
            protein, positions = pdb_to_gearnet_protein(str(pdb_path),
                                                        self.CHAIN_MAP,
                                                        self.add_noise,
                                                        self.noise_var,
                                                        add_cdr=self.add_cdr)
            torch.save(protein, pt_path)
            if save_positions:
                torch.save(positions, self.DATA_DIR / f"xyz_{struct_id}.xyz.pt")

    def _download_structures(self):
        """Download structure files from AWS S3 (whole prefix if ``bulk``, else one per ``seqid``)."""
        logger.info("Downloading structures from AWS S3")

        if self.bulk:
            S3FileSystem().download(
                    rpath=self.s3_pdb_path,
                    lpath=str(self.DATA_DIR),
                    recursive=True,
                    )
        else:
            unique_ids = list(dict.fromkeys(self.structure_ids))
            logger.info(f"Downloading {len(unique_ids)} files")
            for struct_id in unique_ids:
                self._copy_pdb_from_s3(struct_id)
        n_pdbs = sum(1 for f in os.listdir(self.DATA_DIR) if f.endswith(".pdb"))
        logger.info(f"Downloaded structures. Number of PDB files present {n_pdbs}")

    def _load_examples(self):
        """Load every cached ``<seqid>.pt`` into a dict keyed by ``seqid``."""
        logger.info("Loading structures into memory")
        # NOTE(review): torch.load unpickles arbitrary objects. These files are
        # written by _process_structures. torch>=2.6 defaults to weights_only=True,
        # which may reject pickled Protein objects; confirm against the pinned torch version.
        return {
            struct_id: torch.load(self.DATA_DIR / f"{struct_id}.pt")
            for struct_id in dict.fromkeys(self.structure_ids)
        }

    def __getitem__(self, idx) -> Tuple[Protein, torch.Tensor]:
        """Return ``(protein, label)``, where label is a length-2 float one-hot of ``is_binder``.

        Indices at or beyond ``len(self)`` wrap around modulo the dataset length.
        """
        if idx >= len(self.df):
            idx = int(idx % len(self.df))

        row = self.df.iloc[idx]
        seq_id = row["seqid"]

        if self.in_memory:
            struct = self.examples[seq_id]
        else:
            struct = torch.load(self.DATA_DIR / f"{seq_id}.pt")

        y = (
            F.one_hot(torch.tensor([int(row["is_binder"])]), num_classes=2)
            .reshape(-1)
            .float()
        )
        # Expose node features as dense residue features and switch to the residue view.
        with struct.residue():
            struct.residue_feature = struct.node_feature.to_dense()
        struct.view = "residue"
        return struct, y




class GearNetPairedAbStructureSequenceDataset(GearNetDataset):
    """Paired-antibody dataset returning each structure with its padded residue-type sequence."""

    def __init__(
        self,
        df,
        s3_pdb_path: str = "s3://prescient-data-dev/sandbox/mahajs17/OAS_paired",
        data_dir: str = str(pathlib.Path(__file__).parent.parent / "data"),
        in_memory: bool = True,
        add_noise: bool = True,
        noise_var: float = 0.01,
        max_seq_len: int = 301,
        add_cdr=False
        ):
        """
        :param max_seq_len: Length each returned sequence is right-padded to
        The remaining parameters are forwarded to ``GearNetDataset``, which uses its
        default chain map.
        """
        super().__init__(df=df,
                         s3_pdb_path=s3_pdb_path,
                         data_dir=data_dir,
                         in_memory=in_memory,
                         add_noise=add_noise,
                         noise_var=noise_var,
                         add_cdr=add_cdr
                         )
        self.max_seq_len = max_seq_len

    def __getitem__(self, idx) -> Tuple[Protein, torch.Tensor]:
        """Return ``(protein, sequence)``.

        ``sequence`` is ``residue_type`` right-padded with 20 to ``max_seq_len``.
        Indices at or beyond ``len(self)`` wrap around modulo the dataset length.
        """
        if idx >= len(self.df):
            idx = int(idx % len(self.df))

        seq_id = self.df.iloc[idx]["seqid"]

        if self.in_memory:
            struct = self.examples[seq_id]
        else:
            struct = torch.load(self.DATA_DIR / f"{seq_id}.pt")

        with struct.residue():
            struct.residue_feature = struct.node_feature.to_dense()
        struct.view = "residue"

        residue_types = struct.residue_type
        # 20 is the padding id. NOTE: F.pad with a negative amount crops, so
        # sequences longer than max_seq_len are silently truncated.
        sequence = torch.nn.functional.pad(residue_types,
                                           (0, self.max_seq_len - residue_types.shape[0]),
                                           value=20)
        return struct, sequence.long()


class GearNetHeavyLightStructureSequenceDataset(GearNetDataset):
    """Dataset returning each structure with a fixed-length heavy|light residue-type sequence."""

    def __init__(
        self,
        df,
        s3_pdb_path: str = "OAS_paired",
        data_dir: str = str(pathlib.Path(__file__).parent.parent / "data"),
        in_memory: bool = True,
        add_noise: bool = True,
        noise_var: float = 0.01,
        max_seq_len: int = 301,
        max_heavy_len: int = 151,
        max_light_len: int = 150,
        chain_map = {'A': 0, 'B': 1},
        add_cdr=False,
        ):
        """
        :param max_seq_len: Stored as ``self.max_seq_len``; not used by this class
        :param max_heavy_len: Padded length of the heavy chain (chain index 0)
        :param max_light_len: Padded length of the light chain
        :param chain_map: Chain letter -> chain index; default maps A -> 0, B -> 1
        The remaining parameters are forwarded to ``GearNetDataset``.
        """
        import copy

        # Copy so the shared default dict (or a caller's dict) is never aliased
        # as self.CHAIN_MAP; a mutation there would leak into other instances.
        if chain_map is not None:
            chain_map = copy.copy(chain_map)
        super().__init__(df=df,
                         s3_pdb_path=s3_pdb_path,
                         data_dir=data_dir,
                         in_memory=in_memory,
                         add_noise=add_noise,
                         noise_var=noise_var,
                         chain_map=chain_map,
                         add_cdr=add_cdr,
                         )
        self.max_seq_len = max_seq_len
        self.max_heavy_len = max_heavy_len
        self.max_light_len = max_light_len

    def __getitem__(self, idx) -> Tuple[Protein, torch.Tensor]:
        """Return ``(protein, sequence)``.

        ``sequence`` holds the heavy residues padded with 20 to ``max_heavy_len``,
        followed by the light residues padded to ``max_light_len``. Indices at or
        beyond ``len(self)`` wrap around modulo the dataset length.
        """
        if idx >= len(self.df):
            idx = int(idx % len(self.df))

        seq_id = self.df.iloc[idx]["seqid"]

        if self.in_memory:
            struct = self.examples[seq_id]
        else:
            struct = torch.load(self.DATA_DIR / f"{seq_id}.pt")

        with struct.residue():
            struct.residue_feature = struct.node_feature.to_dense()
        struct.view = "residue"

        # Heavy residues (chain_id == 0) are assumed to precede the light ones.
        sequence = transform_to_heavy_light(struct.chain_id,
                                            struct.residue_type,
                                            self.max_heavy_len,
                                            self.max_light_len,
                                            value=20)
        return struct, sequence.long()


class GearNetHeavyLightStructureSequenceEdgesDataset(GearNetDataset):
    """Heavy|light sequence dataset that can also return a padded residue-pair edge matrix."""

    def __init__(
        self,
        df,
        s3_pdb_path: str = "OAS_paired",
        data_dir: str = str(pathlib.Path(__file__).parent.parent / "data"),
        in_memory: bool = True,
        add_noise: bool = True,
        noise_var: float = 0.01,
        max_seq_len: int = 301,
        max_heavy_len: int = 151,
        max_light_len: int = 150,
        chain_map = {'A': 0, 'B': 1},
        add_cdr=False,
        edges=False
        ):
        """
        :param max_seq_len: Stored as ``self.max_seq_len``; not used by this class
        :param max_heavy_len: Padded length of the heavy chain (chain index 0)
        :param max_light_len: Padded length of the light chain
        :param chain_map: Chain letter -> chain index; default maps A -> 0, B -> 1
        :param edges: Also return the padded edge matrix from ``get_data_for_id``
        The remaining parameters are forwarded to ``GearNetDataset``.
        """
        import copy

        # Copy so the shared default dict (or a caller's dict) is never aliased as self.CHAIN_MAP.
        if chain_map is not None:
            chain_map = copy.copy(chain_map)
        super().__init__(df=df,
                         s3_pdb_path=s3_pdb_path,
                         data_dir=data_dir,
                         in_memory=in_memory,
                         add_noise=add_noise,
                         noise_var=noise_var,
                         chain_map=chain_map,
                         add_cdr=add_cdr,
                         )
        self.max_seq_len = max_seq_len
        self.max_heavy_len = max_heavy_len
        self.max_light_len = max_light_len
        self.edges = edges

    def get_data_for_id(self, id):
        """Build the sample for one ``seqid``.

        :param id: Structure id (``seqid``) whose cached ``.pt`` file is used
        :return: ``(protein, edges, sequence)`` when ``self.edges`` is set, otherwise
            ``(protein, sequence)``. ``sequence`` is heavy|light residue types padded
            with 20. ``edges`` is the matrix from ``get_edges_from_gearnet_struct``
            re-laid out into heavy/light blocks and padded with 1000.
        """
        if self.in_memory:
            struct = self.examples[id]
        else:
            struct = torch.load(self.DATA_DIR / f"{id}.pt")

        with struct.residue():
            struct.residue_feature = struct.node_feature.to_dense()
        struct.view = "residue"

        chain_id = struct.chain_id
        # Heavy residues (chain_id == 0) are assumed to precede the light ones.
        sequence = transform_to_heavy_light(chain_id,
                                            struct.residue_type,
                                            self.max_heavy_len,
                                            self.max_light_len,
                                            value=20)
        if not self.edges:
            # Skip the edge extraction entirely when the caller does not want it.
            return struct, sequence.long()

        edges_org = get_edges_from_gearnet_struct(struct)
        edges = transform_2d_to_heavy_light(chain_id,
                                            edges_org,
                                            self.max_heavy_len,
                                            self.max_light_len,
                                            value=1000)
        return struct, edges, sequence.long()

    def __getitem__(self, idx) -> tuple:
        """Return ``get_data_for_id`` for row ``idx``; indices at or beyond ``len(self)`` wrap around."""
        if idx >= len(self.df):
            idx = int(idx % len(self.df))

        return self.get_data_for_id(self.df.iloc[idx]["seqid"])
        
        


class GearNetSequenceEdgesDataset(GearNetHeavyLightStructureSequenceEdgesDataset):
    """Sequence-only view: one-hot heavy|light sequences (optionally with edges) plus integer labels."""

    def __init__(
        self,
        df,
        s3_pdb_path: str = "OAS_paired",
        data_dir: str = str(pathlib.Path(__file__).parent.parent / "data"),
        in_memory: bool = True,
        add_noise: bool = True,
        noise_var: float = 0.01,
        max_seq_len: int = 301,
        max_heavy_len: int = 151,
        max_light_len: int = 150,
        chain_map = {'A': 0, 'B': 1},
        add_cdr=False,
        edges=False,
        num_classes: int = -1,
        ):
        """
        :param num_classes: Width of the one-hot encoding. The default -1 makes
            ``one_hot`` infer it from the largest label in each sample, so the width
            can differ between samples. Pass a fixed value to get a consistent shape
            for batching (e.g. 21 if residue ids are 0..19 plus the padding id 20;
            TODO confirm against the residue vocabulary).
        The remaining parameters are forwarded to
        ``GearNetHeavyLightStructureSequenceEdgesDataset``.
        """
        import copy

        # Copy so the shared default dict (or a caller's dict) is never aliased as self.CHAIN_MAP.
        if chain_map is not None:
            chain_map = copy.copy(chain_map)
        super().__init__(df=df,
                         s3_pdb_path=s3_pdb_path,
                         data_dir=data_dir,
                         in_memory=in_memory,
                         add_noise=add_noise,
                         noise_var=noise_var,
                         max_seq_len=max_seq_len,
                         max_heavy_len=max_heavy_len,
                         max_light_len=max_light_len,
                         chain_map=chain_map,
                         add_cdr=add_cdr,
                         edges=edges,
                         )
        self.num_classes = num_classes

    def __getitem__(self, idx) -> tuple:
        """Return ``(one_hot_seq, edges, label)`` if ``self.edges`` else ``(one_hot_seq, label)``.

        ``label`` is the padded heavy|light residue-type sequence and
        ``one_hot_seq`` is its float one-hot encoding. Indices at or beyond
        ``len(self)`` wrap around modulo the dataset length.
        """
        if idx >= len(self.df):
            idx = int(idx % len(self.df))

        sample = self.get_data_for_id(self.df.iloc[idx]["seqid"])
        label = sample[-1]
        seq = torch.nn.functional.one_hot(label, num_classes=self.num_classes).float()
        if self.edges:
            return seq, sample[1], label
        return seq, label
        


def debug(coords, outfile='coords_atoms.png'):
    """Save a 3-D scatter plot of the first three columns of ``coords`` to ``outfile``.

    Debugging aid only. matplotlib is imported inside the function so that it is
    not needed unless this is called.
    """
    import matplotlib.pyplot as plt

    xs, ys, zs = (coords[:, axis].clone().detach().cpu().numpy() for axis in range(3))
    figure = plt.figure()
    axes = figure.add_subplot(projection='3d')
    axes.scatter(xs, ys, zs, c='blue', s=60)
    plt.savefig(outfile, dpi=600)
    plt.close()

class GearNetHeavyLightStructureSurfSequenceDataset(GearNetHeavyLightStructureSequenceDataset):
    """Heavy|light dataset that also returns one padded 3-D position per residue."""

    def __init__(
        self,
        df,
        s3_pdb_path: str = "OAS_paired",
        data_dir: str = str(pathlib.Path(__file__).parent.parent / "data"),
        in_memory: bool = True,
        add_noise: bool = True,
        noise_var: float = 0.01,
        max_seq_len: int = 301,
        max_heavy_len: int = 151,
        max_light_len: int = 150,
        add_noise_pos: bool = True,
        add_cdr=False
        ):
        """
        :param add_noise_pos: Add ``get_noise(positions, noise_var=0.1)`` to the returned
            positions on every ``__getitem__`` call
        The remaining parameters are forwarded to
        ``GearNetHeavyLightStructureSequenceDataset``, which uses its default chain map.
        """
        super().__init__(s3_pdb_path=s3_pdb_path,
                         df=df,
                         data_dir=data_dir,
                         in_memory=in_memory,
                         add_noise=add_noise,
                         noise_var=noise_var,
                         max_seq_len=max_seq_len,
                         max_heavy_len=max_heavy_len,
                         max_light_len=max_light_len,
                         add_cdr=add_cdr
                         )

        self.add_noise_pos = add_noise_pos

    def __getitem__(self, idx) -> Tuple[Protein, torch.Tensor, torch.Tensor]:
        """Return ``(protein, sequence, positions)``.

        ``positions`` holds one row of ``node_position`` per residue (the first atom
        listed for that residue), laid out heavy|light and zero-padded. ``sequence``
        is the heavy|light residue types padded with 20. Indices at or beyond
        ``len(self)`` wrap around modulo the dataset length.
        """
        if idx >= len(self.df):
            idx = int(idx % len(self.df))

        seq_id = self.df.iloc[idx]["seqid"]

        if self.in_memory:
            struct = self.examples[seq_id]
        else:
            struct = torch.load(self.DATA_DIR / f"{seq_id}.pt")

        with struct.residue():
            struct.residue_feature = struct.node_feature.to_dense()
        struct.view = "residue"

        chain_id = struct.chain_id
        atom2residue = struct.atom2residue
        # Index of the first atom mapped to each residue. This raises IndexError if
        # some residue has no atoms.
        first_occurrences = torch.empty_like(struct.residue_type, dtype=torch.long)
        for residue_idx in range(struct.residue_type.shape[0]):
            first_occurrences[residue_idx] = (atom2residue == residue_idx).nonzero(as_tuple=True)[0][0].item()
        node_positions = struct.node_position[first_occurrences, :]
        positions = transform_to_heavy_light(chain_id, node_positions,
                                             self.max_heavy_len,
                                             self.max_light_len,
                                             value=0)
        if self.add_noise_pos:
            # Fresh noise on every access (positions is a new tensor each call).
            positions += get_noise(positions, noise_var=0.1)
        sequence = transform_to_heavy_light(chain_id, struct.residue_type,
                                            self.max_heavy_len,
                                            self.max_light_len,
                                            value=20)
        return struct, sequence.long(), positions



class GearNetHeavyLightAnnotatedDataset(GearNetHeavyLightStructureSurfSequenceDataset):
    """Affinity-annotated heavy|light dataset that also returns a padded CDR mask."""

    def __init__(
        self,
        df,
        s3_pdb_path: str = "OAS_paired",
        data_dir: str = str(pathlib.Path(__file__).parent.parent / "data"),
        in_memory: bool = True,
        add_noise: bool = True,
        noise_var: float = 0.01,
        max_seq_len: int = 301,
        max_heavy_len: int = 151,
        max_light_len: int = 150,
        add_noise_pos: bool = True
        ):
        """
        ``df`` must also contain an ``affinity_pkd`` column. Structures are always
        processed with ``add_cdr=True``. The remaining parameters are forwarded to
        ``GearNetHeavyLightStructureSurfSequenceDataset``.
        """
        super().__init__(s3_pdb_path=s3_pdb_path,
                         df=df,
                         data_dir=data_dir,
                         in_memory=in_memory,
                         add_noise=add_noise,
                         noise_var=noise_var,
                         max_seq_len=max_seq_len,
                         max_heavy_len=max_heavy_len,
                         max_light_len=max_light_len,
                         add_noise_pos=add_noise_pos,
                         add_cdr=True
                         )

    def __getitem__(self, idx) -> tuple:
        """Return ``(protein, affinity_pkd, positions, cdr_mask, seqid)``.

        ``positions`` is produced exactly as in the parent class (including the
        optional position noise). ``cdr_mask`` is ``mol_feature`` laid out
        heavy|light and zero-padded. Indices at or beyond ``len(self)`` wrap around
        modulo the dataset length.
        """
        if idx >= len(self.df):
            idx = int(idx % len(self.df))

        row = self.df.iloc[idx]
        # Reuse the parent's loading / position / noise pipeline; its padded
        # sequence is not part of this dataset's output.
        struct, _, positions = super().__getitem__(idx)
        cdr_mask = transform_to_heavy_light(struct.chain_id, struct.mol_feature,
                                            self.max_heavy_len,
                                            self.max_light_len,
                                            value=0
                                            )
        return struct, row['affinity_pkd'], positions, cdr_mask, row["seqid"]



