
import os
import numpy as np
from typing import Union
import torch
import torch.nn.functional as F
import collections
import pandas as pd
from torch.utils.data import Dataset

import logging
logger = logging.getLogger('gearnet_dataset')

import pathlib
from typing import Tuple

from torchdrug.data import Protein

from affinityenhancer.data.datasets.gearnet_dataset \
    import GearNetHeavyLightStructureSequenceDataset, \
    GearNetHeavyLightStructureSequenceEdgesDataset,\
        transform_to_heavy_light

class PairedDataset(GearNetHeavyLightStructureSequenceEdgesDataset):
    """Dataset yielding one-hot encoded label pairs listed in ``df_paired``.

    Each item corresponds to one row of ``df_paired``. The row's
    ``first_seqid`` and ``second_seqid`` are resolved through the parent's
    ``get_data_for_id``. The item contains float one-hot encodings of both
    label tensors and the label tensors themselves. When ``edges=True`` it
    also contains the edges of the first entry.
    """

    def __init__(
        self,
        df,
        df_paired,
        s3_pdb_path: str = "OAS_paired",
        data_dir: str = str(pathlib.Path(__file__).parent.parent / "data"),
        in_memory: bool = True,
        add_noise: bool = True,
        noise_var: float = 0.01,
        max_seq_len: int = 301,
        max_heavy_len: int = 151,
        max_light_len: int = 150,
        chain_map: Union[dict, None] = None,
        add_cdr: bool = False,
        edges: bool = False
        ):
        """Initialise the parent dataset from ``df`` and store ``df_paired``.

        Args:
            df: Frame forwarded to the parent dataset.
            df_paired: Frame whose rows hold ``first_seqid`` / ``second_seqid``
                pairs; its length defines the length of this dataset.
            chain_map: Forwarded to the parent. ``None`` (the default) means
                ``{'A': 0, 'B': 1}``.
            Remaining arguments are forwarded unchanged to the parent.
        """
        # Build the default per instance: a dict literal as the default value
        # would be a single object shared (and mutable) across all instances.
        if chain_map is None:
            chain_map = {'A': 0, 'B': 1}
        super().__init__(df=df,
                         s3_pdb_path=s3_pdb_path,
                         data_dir=data_dir,
                         in_memory=in_memory,
                         add_noise=add_noise,
                         noise_var=noise_var,
                         max_seq_len=max_seq_len,
                         max_heavy_len=max_heavy_len,
                         max_light_len=max_light_len,
                         chain_map=chain_map,
                         add_cdr=add_cdr,
                         edges=edges
                         )
        self.df_paired = df_paired

    def get_heavy_light_tensor_for_id(self, id):
        """Load the structure for ``id`` and build its padded residue sequence.

        The structure comes from ``self.examples[id]`` when ``in_memory`` is
        set. Otherwise it is read with ``torch.load`` from
        ``self.DATA_DIR / f"{id}.pt"``. The structure is modified in place:
        its dense node features become residue features and its view is set
        to ``"residue"``. In-memory mode therefore mutates the cached object.

        Returns:
            ``(struct, sequence)``. ``sequence`` is the residue types split
            into a heavy part and a light part. The heavy part is padded with
            20 to ``max_heavy_len`` and the light part to ``max_light_len``.
            The two parts are concatenated and cast to ``long``.
        """
        if self.in_memory:
            struct = self.examples[id]
        else:
            struct = torch.load(self.DATA_DIR / f"{id}.pt")

        with struct.residue():
            struct.residue_feature = struct.node_feature.to_dense()
        struct.view = "residue"

        sequence = struct.residue_type
        chain_id = struct.chain_id
        # The heavy chain is taken as the leading residues with chain_id 0.
        # NOTE(review): this assumes chain-0 residues come first in
        # residue_type; confirm against how structures are built.
        heavy_len = chain_id[chain_id == 0].shape[0]
        heavy, light = sequence[:heavy_len], sequence[heavy_len:]

        # NOTE(review): a chain longer than its max length gives a negative
        # pad amount, which torch's constant pad treats as cropping; presumably
        # inputs never exceed the limits — confirm.
        heavy = torch.nn.functional.pad(heavy, (0, self.max_heavy_len - heavy.shape[0]), value=20)
        light = torch.nn.functional.pad(light, (0, self.max_light_len - light.shape[0]), value=20)
        sequence = torch.cat([heavy, light], dim=0)

        return struct, sequence.long()

    def __len__(self) -> int:
        """Returns length of dataset (number of rows in ``df_paired``)."""
        if self.df_paired is None:
            raise ValueError("Please initialize self.df by running dataset.setup(df) first")
        return self.df_paired.shape[0]

    def __getitem__(self, idx) -> tuple:
        """Return one-hot encodings and labels for the pair in row ``idx``.

        Indices at or beyond ``len(df_paired)`` wrap around modulo the length.

        Returns:
            ``(seq_1, seq_2, edges_1, label_1, label_2)`` when ``self.edges``
            is true, otherwise ``(seq_1, seq_2, label_1, label_2)``. Here
            ``seq_*`` is ``one_hot(label_*)`` cast to float.
        """
        n_rows = len(self.df_paired)
        # `>=` rather than `>`: idx == n_rows is already out of range for iloc.
        # An empty frame is not wrapped, so iloc raises IndexError instead of
        # the modulo raising ZeroDivisionError.
        if n_rows and idx >= n_rows:
            idx = int(idx % n_rows)

        row = self.df_paired.iloc[idx]

        if self.edges:
            _, edges_1, label_1 = self.get_data_for_id(row["first_seqid"])
            _, _, label_2 = self.get_data_for_id(row["second_seqid"])
        else:
            _, label_1 = self.get_data_for_id(row["first_seqid"])
            _, label_2 = self.get_data_for_id(row["second_seqid"])

        # NOTE(review): without num_classes, one_hot infers the width from each
        # label's own maximum. seq_1 and seq_2 can therefore differ in width if
        # one label lacks the highest class. Consider an explicit num_classes.
        seq_1 = torch.nn.functional.one_hot(label_1)
        seq_2 = torch.nn.functional.one_hot(label_2)

        if self.edges:
            return seq_1.float(), seq_2.float(), edges_1, label_1, label_2

        return seq_1.float(), seq_2.float(), label_1, label_2


class GearNetPairedDataset(GearNetHeavyLightStructureSequenceEdgesDataset):
    """Dataset yielding structure/label pairs listed in ``df_paired``.

    Each item corresponds to one row of ``df_paired``. The row's
    ``first_seqid`` and ``second_seqid`` are resolved through the parent's
    ``get_data_for_id``. The item contains both structures and both label
    tensors. When ``edges=True`` it also contains the edges of the first
    entry.
    """

    def __init__(
        self,
        df,
        df_paired,
        s3_pdb_path: str = "OAS_paired",
        data_dir: str = str(pathlib.Path(__file__).parent.parent / "data"),
        in_memory: bool = True,
        add_noise: bool = True,
        noise_var: float = 0.01,
        max_seq_len: int = 301,
        max_heavy_len: int = 151,
        max_light_len: int = 150,
        chain_map: Union[dict, None] = None,
        add_cdr: bool = False,
        edges: bool = False
        ):
        """Initialise the parent dataset from ``df`` and store ``df_paired``.

        Args:
            df: Frame forwarded to the parent dataset.
            df_paired: Frame whose rows hold ``first_seqid`` / ``second_seqid``
                pairs; its length defines the length of this dataset.
            chain_map: Forwarded to the parent. ``None`` (the default) means
                ``{'A': 0, 'B': 1}``.
            Remaining arguments are forwarded unchanged to the parent.
        """
        # Build the default per instance: a dict literal as the default value
        # would be a single object shared (and mutable) across all instances.
        if chain_map is None:
            chain_map = {'A': 0, 'B': 1}
        super().__init__(df=df,
                         s3_pdb_path=s3_pdb_path,
                         data_dir=data_dir,
                         in_memory=in_memory,
                         add_noise=add_noise,
                         noise_var=noise_var,
                         max_seq_len=max_seq_len,
                         max_heavy_len=max_heavy_len,
                         max_light_len=max_light_len,
                         chain_map=chain_map,
                         add_cdr=add_cdr,
                         edges=edges
                         )
        self.df_paired = df_paired

    def __len__(self) -> int:
        """Returns length of dataset (number of rows in ``df_paired``)."""
        if self.df_paired is None:
            raise ValueError("Please initialize self.df by running dataset.setup(df) first")
        return self.df_paired.shape[0]

    def __getitem__(self, idx) -> tuple:
        """Return structures and labels for the pair in row ``idx``.

        Indices at or beyond ``len(df_paired)`` wrap around modulo the length.

        Returns:
            ``(struct_1, struct_2, edges_1, label_1, label_2)`` when
            ``self.edges`` is true, otherwise
            ``(struct_1, struct_2, label_1, label_2)``.
        """
        n_rows = len(self.df_paired)
        # `>=` rather than `>`: idx == n_rows is already out of range for iloc.
        # An empty frame is not wrapped, so iloc raises IndexError instead of
        # the modulo raising ZeroDivisionError.
        if n_rows and idx >= n_rows:
            idx = int(idx % n_rows)

        row = self.df_paired.iloc[idx]

        if self.edges:
            struct_1, edges_1, label_1 = self.get_data_for_id(row["first_seqid"])
            # Only the first entry's edges are returned.
            struct_2, _, label_2 = self.get_data_for_id(row["second_seqid"])
            return struct_1, struct_2, edges_1, label_1, label_2

        struct_1, label_1 = self.get_data_for_id(row["first_seqid"])
        struct_2, label_2 = self.get_data_for_id(row["second_seqid"])
        return struct_1, struct_2, label_1, label_2

