from collections import namedtuple
from typing import Dict, List, NamedTuple, Optional, Tuple, Union

import numpy as np
import pandas as pd
import torch
from sklearn.decomposition import PCA
class Sample:
    """One candidate point: its features, observed/predicted values and category labels.

    Reading the ``observed_value`` property marks the sample as observed
    (``is_observed`` becomes True). The private ``_observed_value`` attribute
    gives raw access without that side effect.
    """

    def __init__(self, feat: Optional[torch.Tensor], observed_value: Optional[torch.Tensor],
                 category_value: Optional[Dict], predict_value: Optional[torch.Tensor] = None,
                 additional_feat: Optional[Dict[str, torch.Tensor]] = None) -> None:
        """
        Args:
            feat: feature vector of the sample.
            observed_value: objective value(s) of the sample.
            category_value: mapping from category column name to this sample's value.
            predict_value: optional model prediction.
            additional_feat: extra named tensors (``embedding`` reads the
                ``'embedding'`` key). When omitted, a new empty dict is created
                for this sample.
        """
        self._feat: Optional[torch.Tensor] = feat
        # Build a new dict per instance so that entries added to one sample's
        # additional_feat never appear on another sample (a `{}` default
        # would be a single dict shared by every call).
        self.additional_feat: Dict[str, torch.Tensor] = {} if additional_feat is None else additional_feat
        self.predict_value: Optional[torch.Tensor] = predict_value
        self._observed_value: Optional[torch.Tensor] = observed_value
        self.category_value: Optional[Dict] = category_value
        self._is_observed: bool = False

    def __str__(self) -> str:
        # Use the private field: the `observed_value` property would flip
        # `_is_observed` merely because the sample was printed or logged.
        return f"{self.category_value} -> observed_value:{self._observed_value},predict_value:{self.predict_value}"

    @property
    def is_observed(self) -> bool:
        """True once ``observed_value`` has been read (or the flag was set directly)."""
        return self._is_observed

    @property
    def embedding(self) -> Optional[torch.Tensor]:
        """The ``'embedding'`` entry of ``additional_feat``, or None if absent."""
        return self.additional_feat.get('embedding')

    @property
    def feat(self) -> torch.Tensor:
        """Feature vector of the sample."""
        return self._feat

    @feat.setter
    def feat(self, value: torch.Tensor) -> None:
        self._feat = value

    @property
    def observed_value(self) -> torch.Tensor:
        """Return the observed value and mark the sample as observed."""
        self._is_observed = True
        return self._observed_value

    
class BoData():
    """Dictionary-backed collection of ``Sample`` objects, plus dataset loaders.

    The ``read_*`` static methods turn a CSV into ``{row_position: Sample}``,
    which is then passed to the constructor. Instance methods attach model
    outputs (predictions, embeddings) to the stored samples or filter them
    in place.
    """

    def __init__(self, samples: Optional[Dict[int, Sample]]) -> None:
        self._samples: Dict[int, Sample] = samples

    # ------------------------------------------------------------------ #
    # Loaders
    # ------------------------------------------------------------------ #
    @staticmethod
    def _samples_from_frame(frame: pd.DataFrame, category: List[str],
                            objective: List[str]) -> Dict[int, Sample]:
        """Build one ``Sample`` per row of *frame*, keyed by row position.

        Features are ``pd.get_dummies(frame[category])`` as float64. The
        observed value is a ``(1, len(objective))`` float64 tensor built from
        the *objective* columns of the row.
        """
        # get_dummies may return bool columns and passes non-object columns
        # through unchanged; a bool/numeric mix makes `.values` an object array
        # that torch.tensor cannot convert, so cast everything to float first.
        one_hot = pd.get_dummies(frame[category]).astype(float)
        feats = torch.tensor(one_hot.values, dtype=torch.float64)
        samples: Dict[int, Sample] = {}
        for i, (_, row) in enumerate(frame.iterrows()):
            category_value = {col: row[col] for col in category}
            observed = torch.tensor([row[col] for col in objective],
                                    dtype=torch.float64).view(-1, len(objective))
            samples[i] = Sample(feat=feats[i], observed_value=observed,
                                category_value=category_value)
        return samples

    @staticmethod
    def _read_csv_samples(path, category: List[str], objective: List[str]) -> Dict[int, Sample]:
        """Read the CSV at *path* and convert every row into a ``Sample``."""
        raw_df = pd.read_csv(path, na_values=[])
        return BoData._samples_from_frame(raw_df, category, objective)

    @staticmethod
    def read_sci_data(path, x_y_split) -> Dict[int, Sample]:
        """Load a CSV using caller-supplied column names.

        Args:
            path: CSV path (anything ``pd.read_csv`` accepts).
            x_y_split: pair ``(category_columns, objective_columns)``; each is
                a list/tuple of column names.
        """
        category = list(x_y_split[0])
        objective = list(x_y_split[1])
        return BoData._read_csv_samples(path, category, objective)

    @staticmethod
    def read_buchwald_data(path) -> Dict[int, Sample]:
        """Load the 'buchwald' CSV (objective column ``Yield``)."""
        return BoData._read_csv_samples(path, ['Reactant2','Ligand','Base','Additive'], ['Yield'])

    @staticmethod
    def read_suzuki_exp_data(path) -> Dict[int, Sample]:
        """Load the 'suzuki' CSV (objective column ``yield``)."""
        return BoData._read_csv_samples(
            path, ['electrophile','nucleophile','catalyst','ligand','base','solvent'], ['yield'])

    @staticmethod
    def read_arylation_exp_data(path) -> Dict[int, Sample]:
        """Load the 'arylation' CSV (SMILES category columns, objective ``yield``)."""
        return BoData._read_csv_samples(
            path, ["Ligand_SMILES", "Base_SMILES", "Additive_SMILES", "Aryl_halide_SMILES"], ['yield'])

    @staticmethod
    def read_tandem_exp_data(path) -> Dict[int, Sample]:
        """Load the 'tandem' CSV (objective column ``Yield``)."""
        return BoData._read_csv_samples(
            path, ['Reactant1','Reactant2','Ligand','Base','Solvent','Additive'], ['Yield'])

    @staticmethod
    def read_cpa_exp_data(path) -> Dict[int, Sample]:
        """Load the 'cpa' CSV (objective column ``objective``)."""
        return BoData._read_csv_samples(
            path, ['Reactant1','Reactant2','Catalyst','Solvent'], ['objective'])

    @staticmethod
    def read_tongji_exp_data(path, search_space_path,
                             uncompleted) -> Tuple[Dict[int, Sample], List[int]]:
        """Build samples for the whole search space, filling in known yields.

        Every row of the *search_space_path* CSV becomes a sample. Rows that
        match an experiment in *path* get that ``Yield`` and are marked
        observed; all other rows get the sentinel Yield ``-1``.

        Args:
            path: CSV of experiment results (category columns + ``Yield``).
            search_space_path: CSV listing all candidate combinations.
            uncompleted: CSV of combinations flagged as uncompleted.

        Returns:
            ``(samples, mask_uncompleted)``. ``mask_uncompleted`` lists the
            sample keys whose category combination appears in *uncompleted*.
        """
        # Earlier variants used ['Catalyst','Solvent','Base','Ligand'] or
        # ['Water', 'Temperature'] as the category set.
        category = ['Catalyst','Solvent','Base','Ligand', 'Water', 'Temperature']
        objective = ['Yield']
        experiment_result = pd.read_csv(path,na_values=[])
        search_space = pd.read_csv(search_space_path,na_values=[])
        uncompleted_df = pd.read_csv(uncompleted,na_values=[])
        print(f'exp_results: {experiment_result[category+objective]}')

        merged_data = pd.merge(
            search_space[category],                    # keep every search-space combination
            experiment_result[category + objective],   # experiment rows carrying Yield
            on=category,
            how='left',                                # unmatched combinations get NaN Yield
        )
        merged_data['Yield'] = merged_data['Yield'].fillna(-1)

        # Match uncompleted combinations BEFORE Water/Temperature are turned into
        # strings below: stringified values never compare equal to the numeric
        # values read_csv produces for the uncompleted CSV.
        uncompleted_keys = set(uncompleted_df[category].itertuples(index=False, name=None))
        mask_uncompleted = [
            i for i, key in enumerate(merged_data[category].itertuples(index=False, name=None))
            if key in uncompleted_keys
        ]

        # As strings, Water/Temperature are one-hot encoded by get_dummies
        # instead of being passed through as raw numeric columns.
        merged_data['Water'] = merged_data['Water'].astype(str)
        merged_data['Temperature'] = merged_data['Temperature'].astype(str)
        samples = BoData._samples_from_frame(merged_data, category, objective)
        for sample in samples.values():
            if sample._observed_value[0] != -1:
                print(f"observed value:{sample._observed_value}")
                sample._is_observed = True
        print(f'mask_uncompleted:{mask_uncompleted}')
        return samples, mask_uncompleted

    # ------------------------------------------------------------------ #
    # Attaching model outputs
    # ------------------------------------------------------------------ #
    def load_data_prediction(self, path, group_index=None) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Attach predictions from ``['pred_yields_by_rxn']`` of a ``torch.save`` file.

        If *group_index* ``(lo, hi)`` is given, predictions are first sliced to
        ``[lo, hi]`` inclusive. Entry ``j`` of the (sliced) predictions is
        stored as ``self._samples[j].predict_value``, reshaped to ``(-1, 1)``
        float64.

        Returns:
            The (possibly sliced) loaded predictions.
        """
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load trusted files.
        prediction = torch.load(path,weights_only=False,map_location=torch.device('cpu'))['pred_yields_by_rxn']
        if group_index is not None:
            prediction = prediction[int(group_index[0]):int(group_index[1])+1]
        # NOTE(review): after slicing, prediction[j] goes to sample key j, while
        # use_test_set keys samples by the un-shifted index in [lo, hi] --
        # confirm which indexing the caller's sample dict uses.
        for i, pred in enumerate(prediction):
            self._samples[i].predict_value = pred.detach().clone().view(-1,1).to(torch.float64)
        return prediction

    def load_data_embedding(self, path, n_pca: Optional[int]) -> None:
        """Attach embeddings from ``['cls_embs']`` of a ``torch.save`` file.

        If *n_pca* is given, the embeddings are first reduced to *n_pca*
        components with sklearn PCA. Row ``i`` is stored as float64 under
        ``self._samples[i].additional_feat['embedding']``.
        """
        embedding = torch.load(path,weights_only=False,map_location=torch.device('cpu'))['cls_embs']
        if n_pca is not None:
            pca = PCA(n_components=n_pca)
            embedding = torch.tensor(pca.fit_transform(embedding), dtype=torch.float64)
        for i, emb in enumerate(embedding):
            sample = self._samples[i]
            # Rebind rather than mutate in place: several Samples can share one
            # additional_feat dict (e.g. a shared default), and an in-place
            # write would give all of them the same embedding.
            sample.additional_feat = {**sample.additional_feat,
                                      'embedding': emb.clone().detach().to(torch.float64)}

    # ------------------------------------------------------------------ #
    # Filtering
    # ------------------------------------------------------------------ #
    def make_data_harder_from_npy(self, npy_path: str) -> None:
        """Keep only the samples whose keys are listed in a saved ``.npy`` array.

        Used to reload a previously generated long-tail subset. Indices that
        are not current sample keys are ignored; the kept samples follow the
        order of the array.
        """
        kept_keys = (int(idx) for idx in np.load(npy_path))  # 1-D integer array of keys
        self._samples = {key: self._samples[key] for key in kept_keys if key in self._samples}

    def make_data_harder(self, factor, keep_idx = True, random_seed=None) -> None:
        """Drop samples to create a long-tail distribution of observed values.

        Samples are ranked by ascending ``_observed_value``. The sample at
        rank ``i`` (of ``n``) is kept with probability
        ``exp(-0.1 * i * factor / n)``, so for ``factor > 0`` higher-valued
        samples are dropped more often.

        Args:
            factor: decay strength; larger values drop more samples.
            keep_idx: if True keep the original keys; otherwise renumber the
                kept samples ``0..k-1`` in rank order.
            random_seed: if given, draws come from a dedicated
                ``np.random.RandomState(random_seed)`` so the subset is
                reproducible. Otherwise the global NumPy RNG is used.
        """
        rng = np.random if random_seed is None else np.random.RandomState(random_seed)
        ranked = sorted(self._samples.keys(), key=lambda k: self._samples[k]._observed_value)
        n_samples = len(ranked)
        kept_by_key = {}
        kept_renumbered = {}
        for rank, key in enumerate(ranked):
            keep_prob = np.exp(-0.1 * rank * factor / n_samples)
            if rng.random_sample() < keep_prob:
                kept_renumbered[len(kept_renumbered)] = self._samples[key]
                kept_by_key[key] = self._samples[key]
        self._samples = kept_by_key if keep_idx else kept_renumbered

    def use_test_set(self, path, group_index=None) -> None:
        """Restrict the samples to the ``'val_idx'`` indices stored in *path*.

        With *group_index* ``(lo, hi)``, only indices in ``[lo, hi]`` are kept.
        """
        test_idx = torch.load(path, weights_only=False,map_location=torch.device('cpu'))['val_idx']
        # int(): elements may be 0-d tensors, which hash by identity and would
        # never match an integer dict key.
        test_idx = [int(i) for i in test_idx]
        if group_index is not None:
            lo, hi = int(group_index[0]), int(group_index[1])
            test_idx = [i for i in test_idx if lo <= i <= hi]
        self._samples = {i: self._samples[i] for i in test_idx}

    # ------------------------------------------------------------------ #
    # Container protocol
    # ------------------------------------------------------------------ #
    def __len__(self):
        return len(self._samples)

    def __getitem__(self, index: Union[int, List[int]]) -> Union[Optional[Sample], List[Optional[Sample]]]:
        """Return the sample for *index* (None if absent), or a list for a list of indices."""
        if isinstance(index, list):
            return [self._samples.get(i) for i in index]
        return self._samples.get(index)

    def __iter__(self):
        """Iterate over a snapshot of the sample values in key-insertion order.

        Each call returns an independent iterator, so nested loops over the
        same object work. The snapshot is also kept on the instance so the
        ``__next__`` protocol continues to work after ``iter(data)``.
        """
        self._iter_samples = list(self._samples.values())
        self._iter_index = 0
        return iter(self._iter_samples)

    def __next__(self):
        if self._iter_index >= len(self._iter_samples):
            raise StopIteration
        sample = self._iter_samples[self._iter_index]
        self._iter_index += 1
        return sample

    @property
    def all_idxes(self) -> List[int]:
        """Current sample keys, in insertion order."""
        return list(self._samples.keys())

    @property
    def max_value(self) -> torch.Tensor:
        """Largest ``_observed_value`` among the samples (ValueError if empty)."""
        return max(sample._observed_value for sample in self._samples.values())
    