import os
import os.path as osp
import json
import numpy as np
import pickle as cPickle
import torch.utils.data as data
from src.datasets.utils import cached_property


class AlphaFoldDataset(data.Dataset):
    """Map-style dataset over pickled records stored under ``path``.

    Each record is a dict that carries at least a ``'seq'`` entry. Records are
    read from ``<path>/<upid>/data_<upid>.pkl``, and the ``'res_score'`` value
    of the matching entry in ``data_<upid>_score.pkl`` is attached to each
    record under ``'score'``. A JSON split file assigns record indices to the
    ``'train'``, ``'valid'`` and ``'test'`` folds.

    Args:
        path: Root directory holding one sub-directory per dataset.
        upid: Name of the dataset sub-directory (ignored when ``joint_data``
            is truthy).
        mode: Fold to expose: ``'train'``, ``'valid'``, ``'test'`` or
            ``'all'`` (the three folds concatenated in that order).
        max_length: Exclusive upper bound on sequence length used by
            ``_data_info`` when ``limit_length`` is truthy.
        limit_length: When truthy, the split file ``split_clu_l.json`` is
            used instead of ``split_clu.json`` and ``_data_info`` filters by
            sequence length.
        joint_data: When truthy, every sub-directory of ``path`` whose name
            contains ``'_v2'`` is loaded and merged, using the split file
            under ``<path>/full``.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
        FileNotFoundError: If ``path`` or a required split file is missing.
    """

    _SPLITS = ('train', 'valid', 'test')

    def __init__(self, path='./', upid='', mode='train', max_length=500, limit_length=1, joint_data=0):
        self.path = path
        self.upid = upid
        self.max_length = max_length
        self.limit_length = limit_length
        self.joint_data = joint_data

        self._set_data(self._select(mode))

    def _select(self, mode):
        """Return the list of records for ``mode``.

        The mode is validated before ``cache_data`` is touched so that a typo
        fails immediately instead of after loading every file.
        """
        if mode != 'all' and mode not in self._SPLITS:
            raise ValueError(
                "unknown mode {!r}; expected one of {}".format(
                    mode, list(self._SPLITS) + ['all']))
        splits = self.cache_data
        if mode == 'all':
            return splits['train'] + splits['valid'] + splits['test']
        return splits[mode]

    def _set_data(self, samples):
        """Install ``samples`` as the active data and refresh length stats."""
        self.data = samples
        self.lengths = np.array([len(sample['seq']) for sample in samples], dtype=int)
        # np.max / np.min raise on empty input; an empty fold reports 0.
        has_samples = self.lengths.size > 0
        self.max_len = np.max(self.lengths) if has_samples else 0
        self.min_len = np.min(self.lengths) if has_samples else 0

    def _split_file_name(self):
        """Return the split file name matching the ``limit_length`` setting."""
        return 'split_clu_l.json' if self.limit_length else 'split_clu.json'

    def _load_split(self, folder):
        """Load and return the JSON split description stored in ``folder``."""
        split_path = osp.join(folder, self._split_file_name())
        if not os.path.exists(split_path):
            raise FileNotFoundError("no such file:{} !!!".format(split_path))
        with open(split_path, 'r') as handle:
            return json.load(handle)

    def _raw_data(self, path, upid):
        """Load the records of dataset ``upid`` and attach their scores.

        Args:
            path: Root directory containing the ``upid`` sub-directory.
            upid: Dataset name; also embedded in the pickle file names.

        Returns:
            The unpickled record container, where record ``i`` has gained a
            ``'score'`` entry equal to ``score[i]['res_score']``.

        Raises:
            FileNotFoundError: If ``path`` does not exist or a pickle file is
                missing.
        """
        if not os.path.exists(path):
            raise FileNotFoundError("no such file:{} !!!".format(path))

        folder = osp.join(path, upid)
        # NOTE(security): pickle can execute arbitrary code while loading;
        # only point this class at trusted data files.
        with open(folder + '/data_{}.pkl'.format(upid), 'rb') as handle:
            data_ = cPickle.load(handle)
        with open(folder + '/data_{}_score.pkl'.format(upid), 'rb') as handle:
            score_ = cPickle.load(handle)

        # Index-based pairing on purpose: a score container shorter than the
        # record container must fail loudly rather than be silently truncated.
        for i in range(len(data_)):
            data_[i]['score'] = score_[i]['res_score']
        return data_

    def _data_info(self, data):
        """Return indices of usable records and a sequence-to-index mapping.

        When ``limit_length`` is truthy only records whose sequence length is
        strictly between 30 and ``max_length`` are kept; otherwise every
        record is kept.

        Returns:
            ``(len_inds, seq2ind)``: the kept indices in input order, and a
            dict mapping each kept sequence to its index (for duplicate
            sequences the last index wins).
        """
        len_inds = []
        seq2ind = {}
        for ind, record in enumerate(data):
            seq = record['seq']
            if self.limit_length and not (30 < len(seq) < self.max_length):
                continue
            len_inds.append(ind)
            seq2ind[seq] = ind
        return len_inds, seq2ind

    def get_data(self, path, upid, **kwargs):
        """Load dataset ``upid`` and partition it into its three folds.

        Returns:
            Dict with keys ``'train'``, ``'valid'`` and ``'test'``, each a list
            of records picked by the indices in the dataset's split file.

        Raises:
            FileNotFoundError: If the data or split file is missing.
        """
        data_ = self._raw_data(path, upid)
        split = self._load_split(osp.join(path, upid))
        return {fold: [data_[i] for i in split[fold]] for fold in self._SPLITS}

    def get_full_data(self, path, **kwargs):
        """Return the parsed split file stored under ``<path>/full``.

        ``cache_data`` indexes the result as ``split[dataname][fold]``.

        Raises:
            FileNotFoundError: If the split file is missing.
        """
        return self._load_split(osp.join(path, 'full'))

    @cached_property
    def cache_data(self):
        """Dict of fold name to list of records, built on first access.

        Every record in the ``'test'`` fold is tagged with a ``'category'``
        entry naming the dataset it came from.
        """
        path = self.path
        upid = self.upid

        if not self.joint_data:
            data_dict = self.get_data(path, upid)
            for item in data_dict['test']:
                item['category'] = upid
            return data_dict

        datanames = [dataname for dataname in os.listdir(path) if '_v2' in dataname]
        data_dict = {fold: [] for fold in self._SPLITS}
        full_inds = self.get_full_data(path)

        for dataname in datanames:
            records = self._raw_data(path, dataname)
            folds = full_inds[dataname]
            data_dict['train'] += [records[i] for i in folds['train']]
            data_dict['valid'] += [records[i] for i in folds['valid']]

            for i in folds['test']:
                item = records[i]
                item['category'] = dataname
                data_dict['test'].append(item)

        return data_dict

    def change_mode(self, mode):
        """Switch the exposed fold and refresh ``lengths``/``max_len``/``min_len``.

        Accepts the same ``mode`` values as the constructor.

        Raises:
            ValueError: If ``mode`` is not a supported value.
        """
        self._set_data(self._select(mode))

    def __len__(self):
        """Return the number of records in the active fold."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the record at ``index`` in the active fold."""
        return self.data[index]