import os
import ast
import os.path as osp
import pandas as pd

import torch
from torch_geometric.data import InMemoryDataset, Data
from sklearn.model_selection import train_test_split

from .mol import smiles2graph
from .dataset_produce import SmilesRepeat
from .data_aug import gsplit


class PygPolymerDataset(InMemoryDataset):
    """In-memory PyG dataset of polymer graphs built from SMILES strings.

    Each row of ``<set_name>.csv`` is converted with
    ``smiles2graph(smiles, add_fp=True)`` into a ``Data`` object carrying
    ``edge_index``, optional ``edge_attr`` / ``x``, a fingerprint ``fp`` and
    a target row ``y`` with one column per entry of ``task_properties``.
    """

    def __init__(
        self,
        root="data_pyg",
        polymer_type="homopolymer",
        repeat_times=1,
        task_name='mt',
        set_name='train',
        _use_concat_train=False,
        transform=None,
        pre_transform=None,
    ):
        """
        - root (str): root directory to store the dataset folder
        - polymer_type (str): name of the dataset == homopolymer or copolymer
        - repeat_times (int): forwarded to ``SmilesRepeat`` and used as a
          path component of the dataset directory
        - task_name (str): name of the CSV column read as the target; also
          used as a path component
        - set_name (str): name of the subset; ``process`` reads
          ``<set_name>.csv`` from the resolved dataset directory
        - _use_concat_train (bool): if true, point the dataset at
          ``<root>/<polymer_type>/<task_name>/concat/<repeat_times>/<set_name>``
          and skip data generation / splitting
        - transform, pre_transform (optional): transform/pre-transform graph objects
        """
        self.root = root
        self.polymer_type = polymer_type
        self.repeat_times = repeat_times
        self.task_properties = [task_name]
        self.set_name = set_name
        self.task_type = self.task_properties[0]

        if _use_concat_train:
            self.root = osp.join(
                root, polymer_type, self.task_type, 'concat', str(self.repeat_times), self.set_name
            )
        else:
            raw_root = osp.join(root, polymer_type, self.task_type, str(self.repeat_times))
            if not osp.exists(raw_root):
                # Raw data for this repeat count has not been produced yet.
                dataproduce = SmilesRepeat(self.repeat_times, self.task_type, root=f'{root}/{self.polymer_type}/')
                dataproduce.repeat()
            self.root = osp.join(raw_root, self.set_name)
            if not osp.exists(self.root):
                split_dir = f"{root}/{self.polymer_type}/{self.task_type}/1/split"
                test_idx_path = f"{split_dir}/test.csv.gz"
                valid_idx_path = f"{split_dir}/valid.csv.gz"
                train_idx_path = f"{split_dir}/train.csv.gz"
                # obtaining original data splitting
                # NOTE(review): get_idx_split writes under <raw_root>/split,
                # which coincides with the paths above only when
                # repeat_times == 1 -- confirm that case is always built first.
                split_paths = (test_idx_path, valid_idx_path, train_idx_path)
                if not all(osp.exists(p) for p in split_paths):
                    self.get_idx_split(raw_root)
                    print("****** Data Splitted!******")
                gsplit(raw_root, test_idx_path, valid_idx_path, train_idx_path)

        # NOTE(review): task_properties is always a one-element list here, so
        # the ``None`` branch is currently unreachable from this constructor.
        if self.task_properties is None:
            self.num_tasks = None
            self.eval_metric = "jaccard"
        else:
            self.num_tasks = len(self.task_properties)
            self.eval_metric = "wmae"

        super().__init__(self.root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    def get_idx_split(self, raw_root):
        """Return train/valid/test row indices, creating the split if needed.

        The split is persisted as ``train.csv.gz`` / ``valid.csv.gz`` /
        ``test.csv.gz`` under ``<raw_root>/split``. If those files can be
        read they are reused; otherwise a 60/10/30 split of the rows of
        ``<raw_root>/raw.csv`` is drawn with ``random_state=42`` and written
        there.

        Args:
            raw_root (str): directory containing ``raw.csv``.

        Returns:
            dict: keys ``'train'``, ``'valid'``, ``'test'`` mapping to
            ``torch.long`` index tensors.
        """
        subsets = ("train", "valid", "test")
        split_dir = osp.join(raw_root, "split")
        os.makedirs(split_dir, exist_ok=True)
        split_files = {name: osp.join(split_dir, f"{name}.csv.gz") for name in subsets}

        try:
            # Read back from the directory this method writes to. The lookup
            # used to be built from ``self.root``, which already points at
            # the subset directory when the constructor calls this method,
            # so a saved split was never found.
            train_idx, valid_idx, test_idx = [
                pd.read_csv(split_files[name], compression='gzip', header=None).values.T[0]
                for name in subsets
            ]
        except Exception:
            # Best effort: any missing or unreadable split file triggers a
            # full regeneration. KeyboardInterrupt/SystemExit still propagate.
            data_df = pd.read_csv(osp.join(raw_root, "raw.csv"))
            full_idx = list(range(len(data_df)))
            train_ratio, valid_ratio, test_ratio = 0.6, 0.1, 0.3
            train_idx, test_idx = train_test_split(
                full_idx, test_size=test_ratio, random_state=42
            )
            # Carve the validation part out of the remaining 70 %.
            train_idx, valid_idx = train_test_split(
                train_idx,
                test_size=valid_ratio / (valid_ratio + train_ratio),
                random_state=42,
            )
            for name, indices in zip(subsets, (train_idx, valid_idx, test_idx)):
                pd.DataFrame({name: indices}).to_csv(
                    split_files[name], index=False, header=False, compression="gzip"
                )

        return {
            'train': torch.tensor(train_idx, dtype=torch.long),
            'valid': torch.tensor(valid_idx, dtype=torch.long),
            'test': torch.tensor(test_idx, dtype=torch.long),
        }

    def get_task_weight(self, ids):
        """Compute per-task weights from the number of non-NaN labels.

        For the samples selected by ``ids`` each task gets the weight
        ``sqrt(1 / n_valid)``, where ``n_valid`` counts its non-NaN labels;
        the weights are then rescaled so that they sum to the number of
        tasks.

        Args:
            ids: sample indices used to index ``self._data.y``.

        Returns:
            torch.Tensor | None: float32 weights, one per task. ``None`` if
            ``task_properties`` is ``None`` or if the computation fails (the
            error is printed, not raised).
        """
        if self.task_properties is None:
            return None

        try:
            labels = self._data.y[torch.LongTensor(ids)]
            num_tasks = labels.shape[1]
            # NaN is the only value not equal to itself, so this counts the
            # labelled entries of every task column.
            valid_counts = labels.eq(labels).sum(dim=0).to(torch.float32)
            task_weight = torch.sqrt(1 / valid_counts)
            return task_weight / task_weight.sum() * num_tasks
        except Exception as e:
            print(f"An error occurred: {e}")
            return None

    @property
    def processed_file_names(self):
        """File names expected in the processed directory."""
        return ["data_dev_processed.pt"]

    def process(self):
        """Build graphs from ``<root>/<set_name>.csv`` and save them collated."""
        csv_file = osp.join(self.root, f'{self.set_name}.csv')
        data_df = pd.read_csv(csv_file)

        pyg_graph_list = []
        for _, row in data_df.iterrows():
            graph = smiles2graph(row["SMILES"], add_fp=True)

            g = Data()
            g.num_nodes = graph["num_nodes"]
            g.edge_index = torch.from_numpy(graph["edge_index"])

            if graph["edge_feat"] is not None:
                g.edge_attr = torch.from_numpy(graph["edge_feat"])

            if graph["node_feat"] is not None:
                g.x = torch.from_numpy(graph["node_feat"])

            # Fingerprint stored as a single row so graphs collate along dim 0.
            g.fp = torch.tensor(graph["fp"], dtype=torch.int8).view(1, -1)

            if self.task_properties is not None:
                y = [float(row[task]) for task in self.task_properties]
                g.y = torch.tensor(y, dtype=torch.float32).view(1, -1)
            else:
                g.y = torch.tensor(
                    ast.literal_eval(row["labels"]), dtype=torch.float32
                ).view(1, -1)
            pyg_graph_list.append(g)

        # NOTE(review): pre_transform receives the whole list here, not one
        # graph at a time -- confirm callers pass a list-level callable.
        if self.pre_transform is not None:
            pyg_graph_list = self.pre_transform(pyg_graph_list)
        torch.save(self.collate(pyg_graph_list), self.processed_paths[0])

    def __repr__(self):
        return f"{self.__class__.__name__}()"


# Placeholder entry point: running this module directly performs no action.
if __name__ == "__main__":
    pass