import numpy as np; import torch; import bz2; import _pickle as c; import json; import gc

class mouse_data_struct():
    """Fixed-size, per-field storage for values recorded in batches.

    Every name in ``fields`` becomes an attribute holding a 1-D object
    ``np.ndarray`` of length ``size``.  Writes happen at a cursor advanced by
    :meth:`increment`: ``start_ind`` moves by ``batch_len`` (used for scalar and
    batch-length data) while ``special_ind`` moves by one (used for "special"
    entries whose length differs from ``batch_len``).
    """

    def __init__(self, fields, size, batch_len):
        """
        Args:
            fields: names of the fields (attributes) to allocate.
            size: length of every per-field storage array.
            batch_len: number of slots written per scalar/array push.
        """
        self.size = size
        self.start_ind = self.special_ind = 0
        self.fields = fields
        self.batch_len = batch_len
        for f in fields:
            setattr(self, f, np.empty(self.size, dtype=object))

    def push_multiple(self, fields, datas):
        """Push ``datas[k]`` into ``fields[k]`` for each pair (extra items of the longer input are ignored)."""
        for field, data in zip(fields, datas):
            self.push(field, data)

    def push(self, field_name, new_data):
        """Store ``new_data`` in field ``field_name`` at the current cursor.

        * A plain number (Python or NumPy int/float, bools included) is
          broadcast over the current ``batch_len`` slots starting at ``start_ind``.
        * A sized object whose length equals ``batch_len`` is copied element by
          element into those same slots.
        * Any other sized object is stored whole at index ``special_ind``.
        """
        data = getattr(self, field_name)
        # isinstance (rather than ``type(x) == int``) so NumPy scalars such as
        # np.float32 / np.int64 and bools are treated as numbers instead of
        # failing on len().
        if isinstance(new_data, (int, float, np.integer, np.floating)):
            self.add_single(data, new_data)
        elif len(new_data) != self.batch_len:
            self.add_special(data, new_data)
        else:
            self.add_array(data, new_data)
        setattr(self, field_name, data)

    def add_single(self, data, new_data):
        """Broadcast the scalar ``new_data`` over the current batch slots of ``data``."""
        data[self.start_ind : self.start_ind + self.batch_len] = new_data

    def add_special(self, data, new_data):
        """Store ``new_data`` as a single object at index ``special_ind``."""
        data[self.special_ind] = new_data

    def add_array(self, data, new_data):
        """Copy the first ``batch_len`` items of ``new_data`` into the current batch slots."""
        # Element-wise assignment keeps each item as its own object; a slice
        # assignment could make NumPy try to broadcast nested arrays.
        for i in range(self.batch_len):
            data[self.start_ind + i] = new_data[i]

    def increment(self):
        """Advance both write cursors to the next batch."""
        self.start_ind += self.batch_len
        self.special_ind += 1

    def to_dictionary(self, d=None):
        """Return a dict mapping each field name to its current storage.

        Args:
            d: optional dict to fill and return; a new dict is created when
                omitted.
        """
        # A ``d=dict()`` default would be one dict shared by every call, so the
        # dictionaries returned for different instances/samples would alias.
        if d is None:
            d = {}
        for f in self.fields:
            d[f] = getattr(self, f)
        return d

    def from_dictionary(self, d):
        """Set one attribute per key of ``d`` (``self.fields`` is not updated)."""
        for f in d.keys():
            setattr(self, f, d[f])

    def remove_tensors(self):
        """Replace every field whose first element is a ``torch.Tensor`` with ``None``.

        Fields that are already ``None`` or empty are left untouched.
        """
        for f in self.fields:
            data = getattr(self, f)
            if data is None or len(data) == 0:
                continue
            if isinstance(data[0], torch.Tensor):
                setattr(self, f, None)

    def postprocess_specials(self):
        """Truncate not-fully-filled fields to their first ``special_ind`` entries.

        A field whose last slot is still ``None`` is cut down to
        ``data[:special_ind]`` — presumably meant for fields filled through the
        special path of :meth:`push` (confirm).  Fields that are ``None`` (e.g.
        after :meth:`remove_tensors`) or empty are skipped.
        """
        for f in self.fields:
            data = getattr(self, f)
            if data is None or len(data) == 0:
                continue
            if data[-1] is None:
                setattr(self, f, data[:self.special_ind])


class cage_data_struct():
    def __init__(self, location, names = None, load = False, suffix = None):
        """Create a new cage or load an existing one from disk.

        Args:
            location: path of the cage file, without ``suffix``.
            names: condition names to register when creating a new cage;
                ``None`` means none.
            load: when exactly ``False`` a new cage is built; any other value
                loads ``self.location`` via ``load_struct``.
            suffix: optional file suffix appended to ``location``
                (e.g. '.json', '.pkbz2', '.pickle'); ``None`` means no suffix.
        """
        # ``suffix or ''``: the default of None previously raised on str + None.
        self.location = location + (suffix or '')
        if load is False:
            self.new_struct(() if names is None else names)
        else:
            self.load_struct()

    def new_struct(self, names):
        """Reset ``self.d`` to an empty cage holding the given condition names.

        Args:
            names: iterable of condition names to register via ``add_name``, or
                ``None`` for a cage with no conditions.
        """
        self.d = dict()
        self.d["names"] = dict()
        # A local loop variable: iterating into ``self.name`` leaked the last
        # name into instance state as a side effect.
        for name in (names if names is not None else ()):
            self.add_name(name = name)
    
    def add_name(self, name):
        """Register condition ``name`` with no samples.

        Creates ``self.d[name]`` (sample storage) and the index entry
        ``self.d["names"][name] = {"sample_num": 0}``.  Re-adding an existing
        name resets it.

        Raises:
            ValueError: if ``name`` is ``"names"``, which would overwrite the
                index dictionary itself.
        """
        if name == "names":
            raise ValueError('"names" is reserved for the condition index')
        self.d[name] = dict()
        self.d["names"][name] = {"sample_num": 0}
        
    def add_data(self, name, manager):
        """Append ``manager.data`` as the next sample of condition ``name`` and save.

        The new sample's key is the condition's current ``sample_num``, which is
        then incremented.  ``self.name`` / ``self.sample`` are left pointing at
        the new entry.

        Args:
            name: a condition already registered with ``add_name``.
            manager: object whose ``data`` attribute provides
                ``to_dictionary(d)``.
        """
        self.name = name
        self.sample = self.d["names"][name]["sample_num"]
        self.d["names"][self.name]["sample_num"] += 1
        self.d[self.name][self.sample] = dict()
        # self.d[self.name][self.sample]['data'] = self.compress_mouse(manager.data.to_dictionary())
        # Pass a fresh dict explicitly: if manager.data is a mouse_data_struct
        # whose to_dictionary() has a shared ``d=dict()`` default, every sample
        # would otherwise reference the same dict and be overwritten by later
        # samples.
        self.d[self.name][self.sample]['data'] = manager.data.to_dictionary(dict())
        # self.d[self.name][self.sample]['network'] = manager.agent.state_dict()
        self.save()

    def change_name(self, old, new):
        """Rename condition ``old`` to ``new`` (index entry and samples) and save.

        Raises:
            KeyError: if ``old`` is not a registered condition; nothing is changed.
            ValueError: if ``new`` is ``"names"`` or another existing condition,
                either of which would silently overwrite data.
        """
        if old not in self.d["names"] or old not in self.d:
            raise KeyError(old)
        if new == "names":
            raise ValueError('"names" is reserved for the condition index')
        if new != old and new in self.d["names"]:
            raise ValueError(f"condition {new!r} already exists")
        self.d["names"][new] = self.d["names"].pop(old)
        self.d[new] = self.d.pop(old)
        self.save()
        
    def save(self):
        """Write the whole cage dictionary ``self.d`` to ``self.location``.

        The format follows the file suffix: ``.pkbz2`` -> bz2-compressed pickle
        (compresslevel 1), ``.pickle`` / ``.pkl`` -> plain pickle, anything else
        -> JSON (the previous unconditional behaviour).  For JSON, NumPy arrays,
        NumPy scalars and torch tensors are converted to Python lists/scalars.

        The data is written to ``<location>.tmp`` and then moved over the target
        with ``os.replace``, so a failure during serialisation leaves any
        previous save intact instead of a truncated file.
        """
        import os

        def _to_jsonable(obj):
            # json cannot encode NumPy/torch objects; convert them and reject
            # anything else with the same error json itself would raise.
            if isinstance(obj, (np.ndarray, torch.Tensor)):
                return obj.tolist()
            if isinstance(obj, np.generic):
                return obj.item()
            raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

        print("saving cage")
        gc.collect()

        tmp_location = self.location + ".tmp"
        try:
            if self.location.endswith(".pkbz2"):
                with bz2.BZ2File(tmp_location, 'w', compresslevel = 1) as f:
                    c.dump(self.d, f)
            elif self.location.endswith((".pickle", ".pkl")):
                with open(tmp_location, 'wb') as f:
                    c.dump(self.d, f)
            else:
                with open(tmp_location, 'w') as f:
                    json.dump(self.d, f, default = _to_jsonable)
            os.replace(tmp_location, self.location)
        except BaseException:
            # Keep the previous save untouched and don't leave a partial temp file.
            if os.path.exists(tmp_location):
                os.remove(tmp_location)
            raise
            
    def load_struct(self):
        """Load ``self.d`` from ``self.location``, choosing the reader by suffix.

        ``.json`` -> JSON, ``.pkbz2`` -> bz2-compressed pickle, anything else ->
        plain pickle (the previous unconditional behaviour).

        Security note: unpickling can execute arbitrary code — only load
        ``.pickle`` / ``.pkbz2`` cage files from trusted sources.
        """
        if self.location.endswith(".json"):
            with open(self.location, 'r') as f:
                self.d = json.load(f)
            # JSON object keys are always strings, so the integer sample numbers
            # created by add_data come back as "0", "1", ...; restore them so
            # lookups like get_data(name, 0) keep working.
            for name in self.d.get("names", {}):
                samples = self.d.get(name)
                if isinstance(samples, dict):
                    self.d[name] = {
                        (int(k) if isinstance(k, str) and k.isdecimal() else k): v
                        for k, v in samples.items()
                    }
        elif self.location.endswith(".pkbz2"):
            with bz2.BZ2File(self.location, 'rb') as f:
                self.d = c.load(f)
        else:
            with open(self.location, 'rb') as f:
                self.d = c.load(f)

        print("cage loaded")
          
    def structure(self):
        """Print how ``self.d`` is indexed, as a quick reminder for users."""
        layout = "[condition][sample number]['data'][field]"
        print(layout)
        
    def names(self):
        """Return the registered condition names as a new list, in insertion order."""
        return [condition for condition in self.d["names"]]

    def info(self):
        """Return the live per-condition index dict (e.g. ``sample_num``), not a copy."""
        condition_index = self.d["names"]
        return condition_index

    def get_data(self, name, sample):
        """Return the stored ``'data'`` dict of sample ``sample`` of condition ``name``."""
        entry = self.d[name][sample]
        return entry['data']
    
    def compress(self):
        """Compress every stored sample of every condition in place.

        ``self.name`` and ``self.sample`` are the cursor read by
        ``compress_mouse`` and its helpers, so they are set before each sample
        is processed and keep the last visited values afterwards.
        """
        for condition in self.d["names"]:
            self.name = condition
            for sample_id in self.d[condition]:
                self.sample = sample_id
                print(f"mouse {sample_id}")
                self.compress_mouse()
    
    def compress_mouse(self):
        """Shrink the current sample: drop unwanted fields, then compact the rest."""
        for stage in (self.purge_extra_fields, self.compress_fields):
            stage()
                    
    def purge_extra_fields(self):
        """Remove a fixed list of fields from the current sample's data, where present."""
        sample_data = self.d[self.name][self.sample]['data']
        unwanted = (
            'backbone', 'stim', 'stim_end', 'W4L_end', 'gos', 'nogos',
            'plant_inds', 'plant_PGO', 'plant_ID', "lick_prob", "Qs",
            "net_input", "net_output", "LTM", "f_gate", "i_gate", "c_gate",
            "o_gate",
        )
        for key in unwanted:
            sample_data.pop(key, None)
                
    def compress_fields(self):
        """Compress the current sample's fields and convert them to lists.

        ``E = data['episode'][-1] + 1`` and ``T = data['trial'][-1] + 1``;
        fields whose length equals ``E * T`` (presumably one entry per
        episode x trial — confirm) are down-cast via ``compress_array``.  Every
        field that supports ``.tolist()`` (e.g. ``np.ndarray``) is then stored
        as a list.  Fields that are ``None`` (such as those cleared by
        ``mouse_data_struct.remove_tensors``) are left as ``None``.

        Uses ``self.field`` / ``self.OG_data`` / ``self.vec_of_vec`` as the state
        shared with ``compress_array`` and its helpers.
        """
        data = self.d[self.name][self.sample]['data']
        if not data:
            return
        # Loop-invariant: converting 'episode'/'trial' to lists below does not
        # change their last values.
        E = data['episode'][-1] + 1
        T = data['trial'][-1] + 1
        n_rows = E * T
        for self.field in list(data):
            value = data[self.field]
            if value is None:
                continue
            self.OG_data = value
            if len(self.OG_data) == n_rows:                                     # ignore elements that are not episode x trial length
                print(f"compressing {self.field}")
                self.vec_of_vec = False
                self.compress_array()
            compressed = data[self.field]
            if hasattr(compressed, 'tolist'):
                data[self.field] = compressed.tolist()
            del(self.OG_data)
                
                
    def compress_array(self):
        """Compress ``self.OG_data``, the current field's original values.

        If its first element is itself array-shaped, each sub-array is handled
        by ``handle_sub_array``; otherwise the whole field is cast at once by
        ``compress_elements``.
        """
        has_sub_arrays, leaf = self.get_child(self.OG_data)
        if not has_sub_arrays:
            self.compress_elements(leaf, self.OG_data)
            return
        self.handle_sub_array()
    
    def handle_sub_array(self, found_first = False):
        """Compress each sub-array of ``self.OG_data`` and store it as a list.

        The leaf element used to choose the dtype is taken from the first
        non-empty sub-array and reused for all later ones.  ``self.sub_i`` holds
        the index being processed (read by ``compress_elements``).  Every
        sub-array, empty or not, is finally replaced by its ``.tolist()``.
        """
        for index, sub in enumerate(self.OG_data):
            self.sub_i = index
            if len(sub) > 0:
                if not found_first:
                    found_first = True
                    leaf = self.get_child(sub)[1]
                self.compress_elements(leaf, sub)
            field_values = self.d[self.name][self.sample]['data'][self.field]
            field_values[self.sub_i] = field_values[self.sub_i].tolist()
        
    def get_child(self, field_data, i = 0):
        """Inspect ``field_data[i]`` to decide whether a field holds sub-arrays.

        Args:
            field_data: indexable field values.
            i: index of the element to inspect (default 0).

        Returns:
            ``(True, field_data[i])`` if that element has a non-scalar shape
            (e.g. a NumPy sub-array), in which case ``self.vec_of_vec`` is also
            set to True; otherwise ``(False, field_data[i])``.
        """
        # The original also ran a ``while len(np.shape(...))`` loop here, but it
        # could never execute (non-scalar elements already returned), and the
        # sub-array branch returned field_data[0] even when i != 0.
        element = field_data[i]
        if len(np.shape(element)) > 0:                                           # has numpy sub-array
            self.vec_of_vec = True
            return True, element
        return False, element
        
    def compress_elements(self, first_element, field_data):
        T = self.get_compressed_type(first_element, field_data)
        if self.vec_of_vec:
            self.d[self.name][self.sample]['data'][self.field][self.sub_i] = (self.OG_data[self.sub_i].astype(T))
        else:
            self.d[self.name][self.sample]['data'][self.field] = self.OG_data.astype(T)
        # if not torch.is_tensor(first_element):
        #     T = self.get_compressed_type(first_element, field_data)
        #     if self.vec_of_vec:
        #         self.d[self.name][self.sample]['data'][self.field][self.sub_i] = (self.OG_data[self.sub_i].astype(T))
        #     else:
        #         self.d[self.name][self.sample]['data'][self.field] = self.OG_data.astype(T)
        # else: 
        #     self.d[self.name][self.sample]['data'][self.field][self.sub_i] = self.OG_data[self.sub_i].to(torch.float16) 

    def get_compressed_type(self, first_element, field_data):
        if (np.any(field_data == None) or np.any(np.isnan(field_data.astype(float)))):
            T = 'float16'
        else: 
            is_float = isinstance(first_element, np.float16) or isinstance(first_element, np.float32) or isinstance(first_element, np.float64) or isinstance(first_element, float)
            is_int = isinstance(first_element, int) or first_element.is_integer()
            if is_float:
                T = 'float32'
            if is_int:
                data_min = field_data[field_data != None].min()
                data_max = field_data[field_data != None].max()
                T = 'int8' if (data_max < 100 and data_min > -100) else 'int16'
                if data_min >=0:
                    T = 'u'+T   