# -*- coding: utf-8 -*-
"""
Created on Thu Jul  4 11:53:09 2019

@author: badat
"""
import hashlib
import os,sys
#import scipy.io as sio
import torch
import numpy as np
import h5py
import time
import pickle
import pdb
from sklearn import preprocessing
from global_setting import NFS_path
#%%
import scipy.io as sio
import pandas as pd
#%%
#import pdb
#%%

# Presumably the directory holding the CUB images; nothing in this file reads it.
# NOTE(review): the second component starts with a drive letter ('H:/'). On
# Windows os.path.join discards NFS_path in that case, while on POSIX the
# component is appended as a relative path — confirm which one is intended.
img_dir = os.path.join(NFS_path,'H:/Models/CUB/')

class CUBDataLoader():
    """Load pre-extracted CUB features, labels and class attributes from HDF5.

    On construction the loader reads
    ``<data_path>H:/Models/CUB/feature_map_ResNet_101_CUB.hdf5`` (the path is
    built by plain string concatenation), splits the stored features into
    ``train_seen`` / ``test_seen`` / ``test_unseen`` using the index arrays
    found in the file, and builds a per-class index of the training samples.

    Feature and label tensors stay on the CPU; attribute tensors and the
    batches returned by ``next_batch`` / ``next_batch_img`` are moved to
    ``device``.

    Args:
        data_path: Prefix of the data directory. It is also appended to
            ``sys.path``.
        device: Torch device used for attributes and returned batches.
        is_scale: If true, min-max scale the features. The scaler is fitted
            on the training split only and then applied to the test splits.
        is_unsupervised_attr: If true, derive class attributes from the SVD
            of the class embeddings in ``./w2v/CUB_class.pkl`` instead of
            reading them from the HDF5 file.
        is_balance: If true, ``next_batch`` samples classes first and then
            samples within each class.
    """

    def __init__(self, data_path, device, is_scale = False,is_unsupervised_attr = False,is_balance=True):

        print(data_path)
        sys.path.append(data_path)

        self.data_path = data_path
        self.device = device
        self.dataset = 'CUB'
        print('$'*30)
        print(self.dataset)
        print('$'*30)
        self.datadir = self.data_path + 'H:/Models/{}/'.format(self.dataset)
        self.index_in_epoch = 0
        self.epochs_completed = 0
        self.is_scale = is_scale
        self.is_balance = is_balance
        if self.is_balance:
            print('Balance dataloader')
        self.is_unsupervised_attr = is_unsupervised_attr
        self.read_matdataset()
        self.get_idx_classes()

    def next_batch_img(self, batch_size, class_id, is_trainset=False):
        """Return the first ``batch_size`` samples of one class.

        Seen classes are served from ``train_seen`` when ``is_trainset`` is
        true and from ``test_seen`` otherwise; unseen classes are always
        served from ``test_unseen``.

        Args:
            batch_size: Maximum number of samples to return.
            class_id: Class label to select.
            is_trainset: Use the training split for a seen class.

        Returns:
            ``(batch_label, batch_feature, batch_files, batch_att)``.
            ``batch_files`` is a slice of the split's ``'img_path'`` array,
            or ``None`` when the split has no ``'img_path'`` entry
            (``read_matdataset`` does not create one).

        Raises:
            Exception: If ``class_id`` is neither a seen nor an unseen class.
        """
        if class_id in self.seenclasses:
            split = 'train_seen' if is_trainset else 'test_seen'
        elif class_id in self.unseenclasses:
            split = 'test_unseen'
        else:
            raise Exception("Cannot find this class {}".format(class_id))

        split_data = self.data[split]
        features = split_data['resnet_features']
        labels = split_data['labels']
        # 'img_path' is not populated by read_matdataset; tolerate its absence
        # instead of failing with a KeyError.
        img_files = split_data.get('img_path')

        # Take column 0 of the (n_matches, 1) result rather than squeezing it:
        # squeeze() would collapse a single match into a 0-dim index and drop
        # the batch dimension of everything indexed with it.
        idx_c = torch.nonzero(labels == class_id)[:, 0]

        features = features[idx_c]
        labels = labels[idx_c]

        batch_label = labels[:batch_size].to(self.device)
        batch_feature = features[:batch_size].to(self.device)
        batch_files = None
        if img_files is not None:
            # img_files is expected to be a numpy array, hence the numpy index.
            batch_files = img_files[idx_c.cpu().numpy()][:batch_size]
        batch_att = self.att[batch_label].to(self.device)

        return batch_label, batch_feature, batch_files, batch_att

    def next_batch(self, batch_size):
        """Draw a random training batch.

        In balanced mode ``min(ntrain_class, batch_size)`` distinct classes
        are drawn, and ``max(batch_size // ntrain_class, 1)`` samples are
        drawn (with replacement) from each of them, so the number of returned
        samples can differ from ``batch_size``. Otherwise ``batch_size``
        samples are taken from a random permutation of the training set.

        Returns:
            ``(batch_label, batch_feature, batch_att)`` on ``self.device``.
        """
        if self.is_balance:
            n_samples_class = max(batch_size // self.ntrain_class, 1)
            sampled_idx_c = np.random.choice(
                np.arange(self.ntrain_class),
                min(self.ntrain_class, batch_size),
                replace=False,
            ).tolist()
            idx = np.concatenate([
                np.random.choice(self.idxs_list[i_c], n_samples_class)
                for i_c in sampled_idx_c
            ])
            idx = torch.from_numpy(idx)
        else:
            idx = torch.randperm(self.ntrain)[0:batch_size]

        batch_feature = self.data['train_seen']['resnet_features'][idx].to(self.device)
        batch_label = self.data['train_seen']['labels'][idx].to(self.device)
        batch_att = self.att[batch_label].to(self.device)
        return batch_label, batch_feature, batch_att

    def get_idx_classes(self):
        """Build ``self.idxs_list``: training-sample indices per seen class.

        Entry ``i`` is a 1-D numpy array with the positions of the training
        samples labelled ``self.seenclasses[i]``.

        Returns:
            The list that was stored in ``self.idxs_list``.
        """
        train_label = self.data['train_seen']['labels']
        # Keep every entry 1-D, even for a class with a single sample.
        # np.squeeze would yield a 0-d array there, which np.random.choice
        # interprets as an integer k and then samples from arange(k).
        self.idxs_list = [
            torch.nonzero(train_label == class_id)[:, 0].cpu().numpy()
            for class_id in self.seenclasses.cpu()
        ]
        return self.idxs_list

    def read_matdataset(self):
        """Read the HDF5 feature file and populate the loader's state.

        Sets ``att``, ``w2v_att``, ``normalize_att`` (plus ``original_att`` in
        expert-attribute mode), the class bookkeeping attributes
        (``seenclasses``, ``unseenclasses``, ``ntrain``, ``ntrain_class``,
        ``ntest_class``, ``train_class``, ``allclasses``) and ``self.data``.
        """
        path = self.datadir + 'feature_map_ResNet_101_{}.hdf5'.format(self.dataset)
        print('_____')
        print(path)

        # The context manager closes the file even if reading fails; every
        # dataset is copied into memory before the block ends.
        with h5py.File(path, 'r') as hf:
            features = np.array(hf.get('feature_map'))
            labels = np.array(hf.get('labels'))
            trainval_loc = np.array(hf.get('trainval_loc'))
            test_seen_loc = np.array(hf.get('test_seen_loc'))
            test_unseen_loc = np.array(hf.get('test_unseen_loc'))

            if self.is_unsupervised_attr:
                print('Unsupervised Attr')
                class_path = './w2v/{}_class.pkl'.format(self.dataset)
                # NOTE(security): pickle.load can execute arbitrary code;
                # only point this at a trusted file.
                with open(class_path, 'rb') as f:
                    w2v_class = pickle.load(f)
                temp = np.array(hf.get('att'))
                print(w2v_class.shape, temp.shape)
                w2v_class = torch.tensor(w2v_class).float()

                U, s, V = torch.svd(w2v_class)
                scaled_u = torch.mm(U, torch.diag(s))
                reconstruct = torch.mm(scaled_u, torch.transpose(V, 1, 0))
                print('sanity check: {}'.format(torch.norm(reconstruct-w2v_class).item()))

                print('shape U:{} V:{}'.format(U.size(), V.size()))
                print('s: {}'.format(s))

                self.w2v_att = torch.transpose(V, 1, 0).to(self.device)
                self.att = scaled_u.to(self.device)
                self.normalize_att = torch.mm(U, torch.diag(s)).to(self.device)
            else:
                print('Expert Attr')
                att = np.array(hf.get('att'))
                self.att = torch.from_numpy(att).float().to(self.device)

                original_att = np.array(hf.get('original_att'))
                self.original_att = torch.from_numpy(original_att).float().to(self.device)

                w2v_att = np.array(hf.get('w2v_att'))
                self.w2v_att = torch.from_numpy(w2v_att).float().to(self.device)

                self.normalize_att = self.original_att/100

        train_feature = features[trainval_loc]
        test_seen_feature = features[test_seen_loc]
        test_unseen_feature = features[test_unseen_loc]
        if self.is_scale:
            scaler = preprocessing.MinMaxScaler()
            # Fit on the training split only and reuse that scaling for the
            # test splits, so all splits share one feature range.
            train_feature = scaler.fit_transform(train_feature)
            test_seen_feature = scaler.transform(test_seen_feature)
            test_unseen_feature = scaler.transform(test_unseen_feature)

        # Features and labels are deliberately left on the CPU; batches are
        # moved to the device when they are drawn.
        train_feature = torch.from_numpy(train_feature).float()
        test_seen_feature = torch.from_numpy(test_seen_feature).float()
        test_unseen_feature = torch.from_numpy(test_unseen_feature).float()

        # Labels are used to index self.att, so keep them as int64 throughout.
        train_label = torch.from_numpy(labels[trainval_loc]).long()
        test_unseen_label = torch.from_numpy(labels[test_unseen_loc]).long()
        test_seen_label = torch.from_numpy(labels[test_seen_loc]).long()

        self.seenclasses = torch.from_numpy(np.unique(train_label.cpu().numpy())).to(self.device)
        self.unseenclasses = torch.from_numpy(np.unique(test_unseen_label.cpu().numpy())).to(self.device)
        self.ntrain = train_feature.size()[0]
        self.ntrain_class = self.seenclasses.size(0)
        self.ntest_class = self.unseenclasses.size(0)
        self.train_class = self.seenclasses.clone()
        self.allclasses = torch.arange(0, self.ntrain_class+self.ntest_class).long()

        self.data = {
            'train_seen': {
                'resnet_features': train_feature,
                'labels': train_label,
            },
            'train_unseen': {
                'resnet_features': None,
                'labels': None,
            },
            'test_seen': {
                'resnet_features': test_seen_feature,
                'labels': test_seen_label,
            },
            'test_unseen': {
                'resnet_features': test_unseen_feature,
                'labels': test_unseen_label,
            },
        }


class CUBDataLoaderCached:
    """CUB feature loader that caches the processed splits in a pickle file.

    The first construction reads the HDF5 feature file, builds the
    train/test splits and the per-class training index, and (when
    ``use_cache`` is true) pickles the result to
    ``<data_path>/cub_cache_<hash>.pkl``. The hash is derived from the
    dataset name and the three boolean options. Later constructions reuse the
    cache as long as the HDF5 file's size and mtime and the options still
    match.

    Args:
        data_path: Directory prefix of the data; also holds the cache file
            and is appended to ``sys.path``.
        device: Torch device used for attributes and returned batches.
        is_scale: If true, min-max scale features (scaler fitted on the
            training split).
        is_unsupervised_attr: If true, derive attributes from the SVD of the
            class embeddings in ``./w2v/CUB_class.pkl``.
        is_balance: If true, ``next_batch`` samples classes first and then
            samples within each class.
        use_cache: If true, try to load from / save to the pickle cache.
    """

    # Tensor attributes that are cached on the CPU and moved back to
    # ``self.device`` after a cache load.
    _REQUIRED_TENSOR_ATTRS = (
        'seenclasses', 'unseenclasses', 'train_class',
        'att', 'w2v_att', 'normalize_att',
    )
    # Only set in expert-attribute mode (is_unsupervised_attr=False).
    _OPTIONAL_TENSOR_ATTRS = ('original_att',)

    def __init__(self, data_path, device, is_scale=False, is_unsupervised_attr=False, is_balance=True, use_cache=True):
        print(data_path)
        sys.path.append(data_path)

        self.data_path = data_path
        self.device = device
        self.dataset = 'CUB'
        self.datadir = os.path.join(self.data_path, 'H:/Models/{}/'.format(self.dataset))
        self.index_in_epoch = 0
        self.epochs_completed = 0
        self.is_scale = is_scale
        self.is_balance = is_balance
        self.is_unsupervised_attr = is_unsupervised_attr
        self.use_cache = use_cache

        if self.is_balance:
            print('Balance dataloader')

        # Cache location: one file per combination of dataset and options.
        cache_key = f"{self.dataset}_{is_scale}_{is_unsupervised_attr}_{is_balance}"
        cache_hash = hashlib.md5(cache_key.encode()).hexdigest()[:8]
        self.cache_file = os.path.join(data_path, f"cub_cache_{cache_hash}.pkl")

        # Path of the source HDF5 file.
        self.hdf5_path = self.datadir + 'feature_map_ResNet_101_{}.hdf5'.format(self.dataset)

        # Load the data, preferring the cache when it is enabled and valid.
        if self.use_cache and self._try_load_cache():
            print("✓ 从缓存快速加载完成!")
        else:
            print("处理原始数据中...")
            start_time = time.time()
            self._load_original_data()
            self.get_idx_classes()

            if self.use_cache:
                self._save_cache()
                print(f"✓ 数据处理完成并已缓存 (耗时: {time.time() - start_time:.2f}s)")
            else:
                print(f"✓ 数据处理完成 (耗时: {time.time() - start_time:.2f}s)")

    def _get_file_info(self):
        """Return the fingerprint used to validate the cache.

        Returns:
            A dict with the HDF5 file's size, mtime and the loader options,
            or ``None`` when the HDF5 file does not exist.
        """
        if os.path.exists(self.hdf5_path):
            stat = os.stat(self.hdf5_path)
            return {
                'size': stat.st_size,
                'mtime': stat.st_mtime,
                'params': (self.is_scale, self.is_unsupervised_attr, self.is_balance)
            }
        return None

    def _try_load_cache(self):
        """Try to restore the loader state from the cache file.

        Returns:
            ``True`` if the cache existed, matched the current fingerprint
            and was restored; ``False`` otherwise (including on any error
            while reading it).
        """
        if not os.path.exists(self.cache_file):
            return False

        try:
            # NOTE(security): pickle.load can execute arbitrary code; the
            # cache file must come from a trusted location.
            with open(self.cache_file, 'rb') as f:
                cache_data = pickle.load(f)

            # Reject a cache built from a different source file or options.
            if cache_data.get('file_info') != self._get_file_info():
                print("缓存已过期，重新处理...")
                return False

            # Restore the cached attributes.
            for key, value in cache_data['attributes'].items():
                setattr(self, key, value)

            # Move tensors back to the target device. train_class is included
            # so the state matches a fresh load, where it is a clone of the
            # on-device seenclasses.
            for attr in self._REQUIRED_TENSOR_ATTRS + self._OPTIONAL_TENSOR_ATTRS:
                if hasattr(self, attr) and hasattr(getattr(self, attr), 'to'):
                    setattr(self, attr, getattr(self, attr).to(self.device))

            return True

        except Exception as e:
            # Best effort: an unreadable cache just triggers reprocessing.
            print(f"缓存加载失败: {e}")
            return False

    def _save_cache(self):
        """Pickle the processed state to ``self.cache_file`` (best effort).

        Failures are reported on stdout and otherwise ignored, so a cache
        problem never prevents the loader from being used.
        """
        try:
            cache_attrs = {
                'ntrain': self.ntrain,
                'ntrain_class': self.ntrain_class,
                'ntest_class': self.ntest_class,
                'allclasses': self.allclasses,
                'data': self.data,
                'idxs_list': getattr(self, 'idxs_list', None)
            }
            # Tensors are stored on the CPU so the cache is device independent.
            for name in self._REQUIRED_TENSOR_ATTRS:
                cache_attrs[name] = getattr(self, name).cpu()
            # original_att does not exist in unsupervised-attribute mode;
            # requiring it would make every save fail in that mode.
            for name in self._OPTIONAL_TENSOR_ATTRS:
                tensor = getattr(self, name, None)
                if tensor is not None:
                    cache_attrs[name] = tensor.cpu()

            cache_data = {
                'file_info': self._get_file_info(),
                'attributes': cache_attrs
            }

            with open(self.cache_file, 'wb') as f:
                pickle.dump(cache_data, f, protocol=pickle.HIGHEST_PROTOCOL)

        except Exception as e:
            print(f"缓存保存失败: {e}")

    def _load_original_data(self):
        """Read the HDF5 feature file and populate the loader's state.

        Sets the attribute tensors, the class bookkeeping attributes and
        ``self.data`` with the ``train_seen`` / ``test_seen`` /
        ``test_unseen`` splits (features and labels kept on the CPU).
        """
        print(f"从 {self.hdf5_path} 加载数据...")

        # The context manager closes the file even if reading fails; every
        # dataset is copied into memory before the block ends.
        with h5py.File(self.hdf5_path, 'r') as hf:
            features = np.array(hf.get('feature_map'))
            labels = np.array(hf.get('labels'))
            trainval_loc = np.array(hf.get('trainval_loc'))
            test_seen_loc = np.array(hf.get('test_seen_loc'))
            test_unseen_loc = np.array(hf.get('test_unseen_loc'))

            # Class attributes.
            if self.is_unsupervised_attr:
                print('使用无监督属性')
                class_path = './w2v/{}_class.pkl'.format(self.dataset)
                # NOTE(security): pickle.load can execute arbitrary code;
                # only point this at a trusted file.
                with open(class_path, 'rb') as f:
                    w2v_class = pickle.load(f)
                w2v_class = torch.tensor(w2v_class).float()

                U, s, V = torch.svd(w2v_class)
                self.w2v_att = torch.transpose(V, 1, 0).to(self.device)
                self.att = torch.mm(U, torch.diag(s)).to(self.device)
                self.normalize_att = torch.mm(U, torch.diag(s)).to(self.device)
            else:
                print('使用专家属性')
                att = np.array(hf.get('att'))
                self.att = torch.from_numpy(att).float().to(self.device)

                original_att = np.array(hf.get('original_att'))
                self.original_att = torch.from_numpy(original_att).float().to(self.device)

                w2v_att = np.array(hf.get('w2v_att'))
                self.w2v_att = torch.from_numpy(w2v_att).float().to(self.device)

                self.normalize_att = (self.original_att / 100).to(self.device)

        # Split the features.
        train_feature = features[trainval_loc]
        test_seen_feature = features[test_seen_loc]
        test_unseen_feature = features[test_unseen_loc]

        if self.is_scale:
            # Fit on the training split only; reuse the scaling for the tests.
            scaler = preprocessing.MinMaxScaler()
            train_feature = scaler.fit_transform(train_feature)
            test_seen_feature = scaler.transform(test_seen_feature)
            test_unseen_feature = scaler.transform(test_unseen_feature)

        train_feature = torch.from_numpy(train_feature).float()
        test_seen_feature = torch.from_numpy(test_seen_feature).float()
        test_unseen_feature = torch.from_numpy(test_unseen_feature).float()

        train_label = torch.from_numpy(labels[trainval_loc]).long()
        test_unseen_label = torch.from_numpy(labels[test_unseen_loc]).long()
        test_seen_label = torch.from_numpy(labels[test_seen_loc]).long()

        # Class bookkeeping.
        self.seenclasses = torch.from_numpy(np.unique(train_label.numpy())).to(self.device)
        self.unseenclasses = torch.from_numpy(np.unique(test_unseen_label.numpy())).to(self.device)
        self.ntrain = train_feature.size(0)
        self.ntrain_class = self.seenclasses.size(0)
        self.ntest_class = self.unseenclasses.size(0)
        self.train_class = self.seenclasses.clone()
        self.allclasses = torch.arange(0, self.ntrain_class + self.ntest_class).long()

        # Final data structure.
        self.data = {
            'train_seen': {
                'resnet_features': train_feature,
                'labels': train_label
            },
            'train_unseen': {
                'resnet_features': None,
                'labels': None
            },
            'test_seen': {
                'resnet_features': test_seen_feature,
                'labels': test_seen_label
            },
            'test_unseen': {
                'resnet_features': test_unseen_feature,
                'labels': test_unseen_label
            }
        }

    def next_batch(self, batch_size):
        """Draw a random training batch.

        In balanced mode ``min(ntrain_class, batch_size)`` distinct classes
        are drawn, and ``max(batch_size // ntrain_class, 1)`` samples are
        drawn (with replacement) from each of them, so the number of returned
        samples can differ from ``batch_size``. Otherwise ``batch_size``
        samples are taken from a random permutation of the training set.

        Returns:
            ``(batch_label, batch_feature, batch_att)``.
        """
        if self.is_balance:
            idx = []
            n_samples_class = max(batch_size // self.ntrain_class, 1)
            sampled_idx_c = np.random.choice(
                np.arange(self.ntrain_class),
                min(self.ntrain_class, batch_size),
                replace=False
            ).tolist()

            for i_c in sampled_idx_c:
                idxs = self.idxs_list[i_c]
                idx.append(np.random.choice(idxs, n_samples_class))
            idx = np.concatenate(idx)
            idx = torch.from_numpy(idx)
        else:
            idx = torch.randperm(self.ntrain)[:batch_size]

        batch_feature = self.data['train_seen']['resnet_features'][idx].to(self.device)
        batch_label = self.data['train_seen']['labels'][idx].to(self.device)
        batch_att = self.att[batch_label]
        return batch_label, batch_feature, batch_att

    def get_idx_classes(self):
        """Build ``self.idxs_list``: training-sample indices per seen class.

        Returns:
            The list stored in ``self.idxs_list``; entry ``i`` is a 1-D numpy
            array of the positions labelled ``self.seenclasses[i]``.
        """
        n_classes = self.seenclasses.size(0)
        self.idxs_list = []
        train_label = self.data['train_seen']['labels']

        for i in range(n_classes):
            idx_c = torch.nonzero(train_label == self.seenclasses[i].cpu()).squeeze()
            # A single match squeezes to 0-dim; restore the sample axis so
            # np.random.choice receives an array of candidates.
            if idx_c.dim() == 0:
                idx_c = idx_c.unsqueeze(0)
            self.idxs_list.append(idx_c.numpy())
        return self.idxs_list

    def clear_cache(self):
        """Delete the cache file if it exists and report the outcome."""
        if os.path.exists(self.cache_file):
            os.remove(self.cache_file)
            print("✓ 缓存已清除")
        else:
            print("无缓存文件需要清除")