import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
import utils
from data.dataloader import data_loader,add_window
from loss.coral import CORAL
import metrics



class Buffer(nn.Module):
    """Replay memory of windowed samples kept across training stages.

    ``self.memory`` maps string keys to stored tensors (``None`` until the first
    write).  It holds:

    * raw and normalized train/val windows;
    * the per-window stage index for train and val;
    * a ``'logit'`` slot that no method in this class writes.

    ``self.memory_size`` counts the train windows stored so far.
    """

    # Window keys shared by ``window_set`` dicts and ``self.memory``.
    _TRAIN_KEYS = ('train_X', 'train_Y', 'train_normalized_X', 'train_normalized_Y')
    _VAL_KEYS = ('val_X', 'val_Y', 'val_normalized_X', 'val_normalized_Y')

    def __init__(self, args):
        # nn.Module bookkeeping (_parameters, _modules, training, ...) must exist
        # before repr(), .to(), state_dict() or any other Module API is usable.
        super().__init__()

        self.args = args
        self.memory_size = 0
        self.memory = {key: None for key in self._TRAIN_KEYS + self._VAL_KEYS}
        self.memory['train_stage_index'] = None
        self.memory['val_stage_index'] = None
        self.memory['logit'] = None

    def _store(self, key, value):
        """Put ``value`` into ``self.memory[key]``, concatenating on dim 0 if the slot is already filled."""
        # Checked per key (not via memory_size) so a slot filled while the train
        # count was still 0 is appended to later instead of being overwritten.
        if self.memory[key] is None:
            self.memory[key] = value
        else:
            self.memory[key] = torch.cat((self.memory[key], value), dim=0)

    def add_reservoir(self, args, window_set, stage_ind, train_num, val_num, model, stage, raw_data, adj_input):
        """Add samples from the current stage to memory according to ``args.selection_method``.

        * ``'none...'``: memory is left unchanged.
        * ``'joint...'``: the first ``train_num`` train windows and the first
          ``val_num`` val windows of ``window_set`` are appended, together with
          their ``stage_ind`` entries.
        * ``'tdc_hidden...'``: the raw training series is split with
          ``utils.TDC``.  A CORAL-guided subset of each segment's windows is
          appended (train keys only).
        * Any other value: memory is left unchanged.

        Args:
            args: config namespace.  Reads ``selection_method``.  For
                'tdc_hidden' it also reads ``ratio``, ``lag``, ``seg``,
                ``horizon`` and ``data_name``.
            window_set: dict of ``train_*`` / ``val_*`` window tensors ('joint').
            stage_ind: dict with ``'train'`` and ``'val'`` stage-index tensors.
            train_num: number of train windows to keep for 'joint'.  For
                'tdc_hidden' it is multiplied by ``args.ratio`` to give the
                sample budget.
            val_num: number of val windows to keep for 'joint'.
            model, stage, adj_input: accepted for interface compatibility; unused.
            raw_data: dict with ``'train'`` and ``'train_normalized'`` series,
                used by 'tdc_hidden' only.  ``'train_normalized'`` must be a
                numpy array (it goes through ``torch.from_numpy``).
        """
        method = args.selection_method
        if method.startswith('none'):
            return
        if method.startswith('joint'):
            self._add_joint(window_set, stage_ind, train_num, val_num)
        elif method.startswith('tdc_hidden'):
            self._add_tdc_hidden(args, stage_ind, train_num, raw_data)
        # NOTE(review): an unrecognised selection_method is a silent no-op (kept
        # for compatibility), so a typo in the config quietly disables replay.

    def _add_joint(self, window_set, stage_ind, train_num, val_num):
        """Append the leading train/val windows of ``window_set`` and their stage indices."""
        for key in self._TRAIN_KEYS:
            self._store(key, window_set[key][:train_num])
        for key in self._VAL_KEYS:
            self._store(key, window_set[key][:val_num])
        self._store('train_stage_index', stage_ind['train'][:train_num])
        self._store('val_stage_index', stage_ind['val'][:val_num])
        self.memory_size += len(window_set['train_X'][:train_num])

    def _add_tdc_hidden(self, args, stage_ind, train_num, raw_data):
        """Split the raw train series with ``utils.TDC`` and add CORAL-selected windows from each segment."""
        raw_data_normalized = raw_data['train_normalized']
        raw_data_unnormalized = raw_data['train']

        sample_num = int(train_num * args.ratio)
        # Time runs along axis 1 of raw_data (see the [:, start:end] slices
        # below); the transpose moves it to axis 0 before calling utils.TDC.
        raw_data_trans = torch.transpose(torch.from_numpy(raw_data_normalized).float(), 0, 1)
        seg_num = math.floor(raw_data_trans.shape[0] / args.lag)
        selected_seg = utils.TDC(raw_data_trans.cuda(), seg_num, args.lag, num_domain=args.seg)
        print('selected segments:', selected_seg)

        for m in range(1, len(selected_seg)):
            # NOTE(review): boundaries are divided by 10, which assumes utils.TDC
            # returns split points on a 0-10 scale — TODO confirm.
            start = math.floor((selected_seg[m - 1] / 10) * seg_num) * args.lag
            end = math.floor((selected_seg[m] / 10) * seg_num) * args.lag
            mode_data = raw_data_normalized[:, start:end]
            mode_data_t = raw_data_unnormalized[:, start:end]

            features, target = add_window(mode_data, args.data_name, lag=args.lag, horizon=args.horizon)
            features_t, target_t = add_window(mode_data_t, args.data_name, lag=args.lag, horizon=args.horizon)
            if features.shape[0] == 0:
                continue  # this segment produced no windows; nothing to select

            # One flattened row per window, used only for the CORAL comparison.
            features_sq = features.reshape(features.shape[0], -1).float().cuda()

            mode_sample_num = int(sample_num * (selected_seg[m] - selected_seg[m - 1]) / 10)
            print(m, 'mode_sample_num:', mode_sample_num)

            selected = self._coral_greedy_select(features_sq, mode_sample_num)

            self._store('train_X', features_t[selected])
            self._store('train_Y', target_t[selected])
            self._store('train_normalized_X', features[selected])
            self._store('train_normalized_Y', target[selected])
            # NOTE(review): ``selected`` indexes windows of this segment, not of
            # the full training set, yet it is used to index stage_ind['train'].
            # That is only right if stage_ind['train'] is constant over those
            # positions — confirm against the caller.
            self._store('train_stage_index', stage_ind['train'][selected])
            self.memory_size += len(selected)

    def _coral_greedy_select(self, features_sq, num_samples):
        """Greedily pick rows of ``features_sq`` whose subset best matches the full set under CORAL.

        The first and last rows seed the selection.  Each step then adds the
        unselected row that minimises
        ``CORAL(features_sq[selected + [row]], features_sq)``.  Ties go to the
        lowest row index.  Selection stops early once every row has been chosen.

        Returns:
            list of selected row indices, seeds first, then in selection order.
        """
        total = features_sq.shape[0]
        if total == 0:
            return []
        selected = [0, total - 1] if total > 1 else [0]
        chosen = set(selected)  # O(1) membership instead of scanning the list

        for _ in range(num_samples):
            best_row, best_loss = None, None
            for row in range(total):
                if row in chosen:
                    continue
                loss = CORAL(features_sq[selected + [row]], features_sq)
                # Strict '<' keeps the first minimum, matching min() semantics.
                if best_loss is None or loss < best_loss:
                    best_row, best_loss = row, loss
            if best_row is None:
                break  # budget exceeds the rows available
            selected.append(best_row)
            chosen.add(best_row)
        return selected

    def sample(self, batch_size):
        """Draw a stage-balanced batch from the train memory.

        Memory rows are grouped by their distinct ``train_stage_index`` value.
        Up to ``batch_size // num_groups`` rows are drawn at random (via
        ``np.random``) from each group, visiting groups in ascending label order.

        Returns:
            tuple ``(stage_index, X, Y, normalized_X, normalized_Y)`` holding the
            selected train-memory rows.

        Raises:
            ValueError: if no train samples are stored.
        """
        labels = self.memory['train_stage_index']
        if labels is None or len(labels) == 0:
            raise ValueError('Buffer.sample called on an empty buffer')

        # Iterate the labels actually present rather than assuming 1..K, so that
        # 0-based or non-contiguous stage indices are sampled as well.
        groups = labels.unique()
        per_group = batch_size // len(groups)

        picked = []
        for label in groups:
            # torch.nonzero + .cpu() also works when labels live on the GPU.
            rows = torch.nonzero(labels == label, as_tuple=True)[0].cpu().numpy()
            np.random.shuffle(rows)
            picked.append(rows[:per_group])
        index = np.concatenate(picked)

        return (
            self.memory['train_stage_index'][index],
            self.memory['train_X'][index],
            self.memory['train_Y'][index],
            self.memory['train_normalized_X'][index],
            self.memory['train_normalized_Y'][index],
        )

    def combine(self, stage_ind, window_set):
        """Append the stored memory to ``window_set`` and ``stage_ind`` in place.

        Memory slots that are still ``None`` are skipped, leaving the matching
        entry untouched.  This happens, for example, for the val keys when only
        'tdc_hidden' selection has run, or when nothing has been stored.

        Returns:
            ``(stage_ind, window_set)`` — the same dict objects that were passed in.
        """
        for key in self._TRAIN_KEYS + self._VAL_KEYS:
            stored = self.memory[key]
            if stored is not None:
                window_set[key] = torch.cat((window_set[key], stored), dim=0)

        for split in ('train', 'val'):
            stored = self.memory[split + '_stage_index']
            if stored is not None:
                stage_ind[split] = torch.cat((stage_ind[split], stored), dim=0)

        return stage_ind, window_set


        


    
    
        

    