from torch.utils.data import Dataset
from utils import *

import os

from tdc.multi_pred import DTI
from tdc.generation import MolGen
from tdc.single_pred import ADME
from tdc.single_pred import Tox
from tdc.utils import retrieve_label_name_list
import pandas as pd
import numpy as np
from collections import Counter

from transformers import AutoTokenizer, AutoModel, T5EncoderModel, BertModel, BertTokenizer
import torch
from tqdm import tqdm

def get_splits(data):
    """Return the train, valid and test partitions of *data* as one frame.

    Args:
        data: object whose ``get_split()`` returns a mapping with the keys
            ``'train'``, ``'valid'`` and ``'test'``.

    Returns:
        The three partitions concatenated with ``pd.concat`` in the order
        train, valid, test; the original row indices are kept.
    """
    partitions = data.get_split()
    ordered = [partitions[part] for part in ('train', 'valid', 'test')]
    return pd.concat(ordered)

def get_drug_embeddings(smiles_list, tokenizer, model, batch_size=64):
    """Embed drug strings in batches and return one vector per string.

    Args:
        smiles_list: sliceable sequence of strings; it is fed to the
            tokenizer ``batch_size`` items at a time.
        tokenizer: callable returning a mapping with ``"input_ids"`` and
            ``"attention_mask"`` tensors when called with
            ``return_tensors="pt", padding=True, truncation=True``.
        model: module called as ``model(input_ids, attention_mask=...)``
            whose output exposes ``last_hidden_state``.
        batch_size: number of strings per forward pass.

    Returns:
        A list holding one NumPy array per input string, in input order.
    """
    # Run on whatever device the model already lives on instead of assuming
    # CUDA, so the function also works for CPU (or other) placements.
    device = next(model.parameters()).device

    embeddings = []
    for start in tqdm(range(0, len(smiles_list), batch_size)):
        batch = smiles_list[start:start + batch_size]
        encoding = tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)

        with torch.no_grad():
            output = model(input_ids, attention_mask=attention_mask)
            # Mean over the token axis of the last hidden state.
            # NOTE(review): this mean runs over every position and does not
            # use attention_mask, so padded positions contribute to it --
            # confirm that this is intended.
            embedding = output.last_hidden_state.mean(1)
            embeddings.extend(embedding.cpu().numpy())

    return embeddings

def get_prot_embeddings(prot_list, tokenizer, model, batch_size=64):
    """Embed protein strings in batches and return one vector per string.

    Args:
        prot_list: sliceable sequence of strings; it is fed to the tokenizer
            ``batch_size`` items at a time.
        tokenizer: callable returning an encoding that supports ``.to(device)``
            and can be unpacked into the model call when called with
            ``return_tensors='pt', padding=True``.
        model: module called as ``model(**encoding)`` whose output exposes
            ``last_hidden_state``.
        batch_size: number of strings per forward pass.

    Returns:
        A list holding one NumPy array per input string, in input order.
    """
    # Run on whatever device the model already lives on instead of assuming
    # CUDA, so the function also works for CPU (or other) placements.
    device = next(model.parameters()).device

    embeddings = []
    for start in tqdm(range(0, len(prot_list), batch_size)):
        batch = prot_list[start:start + batch_size]
        encoded_input = tokenizer(batch, return_tensors='pt', padding=True).to(device)
        with torch.no_grad():
            output = model(**encoded_input)
            # Mean over the token axis of the last hidden state.
            # NOTE(review): padded positions are included in this mean --
            # confirm that this is intended.
            embedding = output.last_hidden_state.mean(1)
            embeddings.extend(embedding.cpu().numpy())

    return embeddings

class DTIDataset(Dataset):
    """Drug-target interaction samples built from a CSV table and two embedding arrays.

    Every sample is a dict with the keys

        "sm"  - drug string from the ``Drug`` column
        "tg"  - target string from the ``Target`` column
        "de"  - row of the drug embedding array
        "pe"  - row of the protein embedding array
        "ma"  - boolean vector, True where the label was not NaN in the CSV
        "dti" - label vector, one entry per label column
        "dd"  - ``drug_dki[drug]``, only when ``drug_dki`` is given
        "td"  - ``target_dki[target]``, only when ``target_dki`` is given

    ``remove_outliers``, ``norm`` and ``sample_local_gaussian`` are not defined
    in this file (presumably they come from ``utils`` via the star import).
    """

    def __init__(
        self,
        path,
        drug_embeddings_path,
        prot_embedding_path,
        dti_transform = None,
        drug_dki = None,
        target_dki = None
    ) -> None:
        """Load the table and embeddings and materialise every sample in memory.

        Args:
            path: CSV file with ``Drug`` and ``Target`` columns; every other
                column is treated as a label column.
            drug_embeddings_path: ``.npy`` file indexed by CSV row.
            prot_embedding_path: ``.npy`` file indexed by CSV row.
            dti_transform: optional object with ``fit_transform``; applied to
                the label matrix when given.
            drug_dki: optional mapping indexed by drug string. A drug missing
                from it raises ``KeyError``.
            target_dki: optional mapping indexed by target string. Samples
                whose target is missing from it are skipped.
        """
        super().__init__()

        f = pd.read_csv(path)
        drug_embeddings = np.load(drug_embeddings_path)
        prot_embeddings = np.load(prot_embedding_path)
        smiles = f['Drug'].values
        proteins = f['Target'].values
        cols = f.drop(labels=['Drug', 'Target'], axis=1).columns

        vlists = {col: f[col].values for col in cols}
        inmask = remove_outliers(list(vlists.values()))
        print(sum(inmask))

        # Apply the same row selection to EVERY per-row array. Filtering only
        # the drugs and labels would pair sample i with the target and
        # embeddings of a different CSV row.
        smiles = smiles[inmask]
        proteins = proteins[inmask]
        drug_embeddings = drug_embeddings[inmask]
        prot_embeddings = prot_embeddings[inmask]
        vlists = {k: v[inmask] for k, v in vlists.items()}

        # True where a label is present; taken before any transformation.
        nullmask = np.stack([~np.isnan(v) for v in vlists.values()], axis=-1)

        data = f[cols].values[inmask]
        if dti_transform is not None:
            data = dti_transform.fit_transform(data)
            # Per-column bounds of the transformed labels, ignoring NaNs.
            self.clip = [
                torch.tensor(np.nanmin(data.T, axis=1)),
                torch.tensor(np.nanmax(data.T, axis=1))
            ]
            vlists = dict(zip(cols, data.T))
        else:
            # presumably min-max normalisation -- ``norm`` is defined elsewhere
            vlists = {k: norm(v) for k, v in vlists.items()}
            self.clip = [
                torch.tensor([-1] * len(vlists)),
                torch.tensor([1] * len(vlists))
            ]

        # ``norm`` is applied to every column in both branches
        # (presumably mapping the values into [-1, 1]).
        vlists = {k: norm(v) for k, v in vlists.items()}

        # ``sample_local_gaussian`` returns the replacement column plus a
        # second value that is kept per column in ``self.dmss``.
        self.dmss = []
        for k in list(vlists):
            vlists[k], dms = sample_local_gaussian(vlists[k])
            self.dmss.append(dms)

        # TODO: Train models here to infill values better!

        self.dataset = []
        for i, gt in enumerate(zip(*vlists.values())):
            entry = {
                "sm": smiles[i],
                "tg": proteins[i],
                "de": drug_embeddings[i],
                "pe": prot_embeddings[i],
                "ma": nullmask[i],
                "dti": np.array(gt),
            }

            if drug_dki is not None:
                entry["dd"] = drug_dki[smiles[i]]

            if target_dki is not None:
                if proteins[i] not in target_dki:
                    continue
                entry["td"] = target_dki[proteins[i]]

            self.dataset.append(entry)
        print(len(self.dataset))

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the sample dict stored at position ``idx``."""
        return self.dataset[idx]

    def update(self, idx, delta):
        """Add ``delta`` to the label vector of sample ``idx``.

        The labels live under the ``"dti"`` key of each sample dict.
        """
        entry = self.dataset[idx]
        entry["dti"] = entry["dti"] + delta

class DTIDataloader:
    """Prepare the merged BindingDB table and its embeddings, then build a DTIDataset.

    The class only prepares files under ``data_dir`` and exposes the result
    as ``self.dataset``; it does no batching itself. Files used:

        xtended_dti_data_all.csv                  merged drug/target/label table
        xtended_dti_drug_emb_all_<model>.npy      drug embeddings, one row per CSV row
        xtended_dti_prot_emb_all_prot_bert.npy    protein embeddings, one row per CSV row

    Files that already exist are reused; only missing ones are created.
    """

    # Drug / target strings longer than these limits are left out of the table.
    MAX_DRUG_LEN = 700
    MAX_TARGET_LEN = 2000

    def __init__(
        self,
        embed_model_name,
        data_dir,
        dti_transform = None,
        drug_dki = None,
        target_dki = None
    ):
        """Create any missing files and construct ``self.dataset``.

        Args:
            embed_model_name: one of ``"t5"``, ``"deberta"``,
                ``"chemberta_zinc"``, ``"chemberta_10m"``.
            data_dir: directory holding the downloaded data and cached files.
            dti_transform: forwarded to ``DTIDataset``.
            drug_dki: forwarded to ``DTIDataset``.
            target_dki: forwarded to ``DTIDataset``.

        Raises:
            KeyError: if ``embed_model_name`` is not one of the names above.
        """
        self.embed_model = {
            "t5":'sagawa/PubChem-10m-t5-v2',
            "deberta":'sagawa/PubChem-10m-deberta',
            'chemberta_zinc': 'seyonec/ChemBERTa-zinc-base-v1',
            'chemberta_10m': 'DeepChem/ChemBERTa-10M-MLM'

        }[embed_model_name]
        # Kept so the matching model class can be chosen when embedding drugs.
        self.embed_model_name = embed_model_name

        self.data_dir = data_dir
        self.data_dfs = {}

        self.save_data_file_name = os.path.join(data_dir, "xtended_dti_data_all.csv")
        self.drug_embed_file_name = os.path.join(data_dir, f"xtended_dti_drug_emb_all_{embed_model_name}.npy")
        self.prot_embed_file_name = os.path.join(data_dir, f"xtended_dti_prot_emb_all_prot_bert.npy")

        if not os.path.exists(self.save_data_file_name):
            self._download()
            self._build_dataset()

        # Only recompute the embedding file(s) that are actually missing.
        need_drugs = not os.path.exists(self.drug_embed_file_name)
        need_prots = not os.path.exists(self.prot_embed_file_name)
        if need_drugs or need_prots:
            self.dti_df = pd.read_csv(self.save_data_file_name)
            self._generate_embeddings(drugs=need_drugs, proteins=need_prots)

        self.dataset = DTIDataset(
            self.save_data_file_name,
            self.drug_embed_file_name,
            self.prot_embed_file_name,
            dti_transform,
            drug_dki,
            target_dki
        )

    def _download(
        self
    ):
        """Fetch the three BindingDB tables into ``self.data_dfs``.

        Each table is loaded with ``DTI``, converted with
        ``convert_to_log(form='binding')`` and its train/valid/test splits
        are merged back into one frame.
        """
        dti_names = [
            'BindingDB_Kd',
            'BindingDB_IC50',
            'BindingDB_Ki'
        ]

        for name in dti_names:
            data = DTI(name=name, path=self.data_dir)
            data.convert_to_log(form = 'binding')
            self.data_dfs[name] = get_splits(data)

    def _build_dataset(
        self
    ):
        """Merge ``self.data_dfs`` into one table and write it to CSV.

        The table has one row per distinct (Drug, Target) pair and one label
        column per source table, holding that table's ``Y`` value for the
        pair or nothing when the pair is absent there. When a pair occurs
        several times in one table the last value wins. The result is stored
        in ``self.dti_df`` and saved to ``self.save_data_file_name``.
        """
        names = list(self.data_dfs)

        # Keyed by the (drug, target) tuple so neither string has to be
        # re-parsed later; insertion order fixes the row order.
        dti_pairs = {}
        for name, df in self.data_dfs.items():
            for drug, target, y in zip(df["Drug"].values, df["Target"].values, df["Y"].values):
                if len(drug) > self.MAX_DRUG_LEN:
                    continue
                if len(target) > self.MAX_TARGET_LEN:
                    continue
                key = (drug, target)
                labels = dti_pairs.get(key)
                if labels is None:
                    labels = dti_pairs[key] = dict.fromkeys(names)
                labels[name] = y

        table = {
            "Drug": [drug for drug, _ in dti_pairs],
            "Target": [target for _, target in dti_pairs]
        }
        for name in names:
            table[name] = [labels[name] for labels in dti_pairs.values()]

        self.dti_df = pd.DataFrame(data = table)
        self.dti_df.to_csv(self.save_data_file_name, index=False)

    def _generate_embeddings(
        self,
        drugs = True,
        proteins = True
    ):
        """Compute and save embeddings for the rows of ``self.dti_df``.

        Args:
            drugs: compute drug embeddings and write ``self.drug_embed_file_name``.
            proteins: compute protein embeddings and write
                ``self.prot_embed_file_name``.

        The computed lists are also kept on ``self.drug_embeddings`` /
        ``self.prot_embeddings`` for whichever part was requested.
        """
        if drugs:
            self._generate_drug_embeddings()
        if proteins:
            self._generate_prot_embeddings()

    def _generate_drug_embeddings(self):
        """Embed the ``Drug`` column with the selected checkpoint and save it."""
        drug_tokenizer = AutoTokenizer.from_pretrained(self.embed_model)
        # Only the "t5" checkpoint is a T5 model; for the others let AutoModel
        # pick the class that matches the checkpoint's own configuration.
        model_cls = T5EncoderModel if self.embed_model_name == "t5" else AutoModel
        drug_model = model_cls.from_pretrained(self.embed_model)
        drug_model.to('cuda')

        self.drug_embeddings = get_drug_embeddings(list(self.dti_df['Drug']), drug_tokenizer, drug_model)
        np.save(self.drug_embed_file_name, np.array(self.drug_embeddings))

    def _generate_prot_embeddings(self):
        """Embed the ``Target`` column with Rostlab/prot_bert and save it."""
        prot_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False )
        prot_model = BertModel.from_pretrained("Rostlab/prot_bert")
        prot_model.to('cuda')

        # NOTE(review): the Rostlab/prot_bert model card tokenises sequences
        # whose residues are separated by spaces; the strings here are passed
        # exactly as stored in the 'Target' column -- confirm that they
        # tokenise as intended before relying on these embeddings.
        self.prot_embeddings = get_prot_embeddings(list(self.dti_df['Target']), prot_tokenizer, prot_model)
        np.save(self.prot_embed_file_name, np.array(self.prot_embeddings))
        
        
        