import random
import pickle

import torch

from .util import get_deltas_gpu



class DeltaDistributionTransducer:
    """Pick a training "anchor" whose delta to a query best matches a bank of training deltas.

    ``self.train_deltas`` is a float32 tensor on ``self.device`` holding the reference
    deltas. They come from the ``deltas`` constructor argument (optionally subsampled).
    When none are available, they are approximated from random pairs of training
    representations via ``get_deltas_gpu``.
    """

    def __init__(self, args, cfg, dataset, deltas):
        """Load the training split and build the delta bank.

        Args:
            args: must provide ``device``, the torch device used for all tensors.
            cfg: must provide ``cfg.model.similarity_type``, ``sample_train`` and
                ``sample_deltas``. When deltas have to be approximated, it must also
                provide ``n_approx_deltas`` and ``skew_direction``.
            dataset: mapping whose ``'train'`` entry has ``'reps'``, ``'targets'``
                and ``'smiles'``.
            deltas: precomputed training deltas. May be empty, in which case they are
                approximated. With ``sample_deltas`` it must support list
                fancy-indexing (e.g. a NumPy array).
        """
        self.device = args.device
        self.train_X = torch.tensor(dataset['train']['reps'], device=self.device)
        self.train_Y = torch.tensor(dataset['train']['targets'], device=self.device)
        self.smiles_X = dataset['train']['smiles']
        self.similarity_type = cfg.model.similarity_type
        self.sample_train = cfg.model.sample_train

        if cfg.model.sample_deltas:
            print('sampling train deltas transducer')
            # Keep a random 10% of the provided deltas, sampled without replacement.
            indices = random.sample(range(len(deltas)), len(deltas) // 10)
            self.train_deltas = torch.tensor(deltas[indices], device=self.device, dtype=torch.float32)
        else:
            self.train_deltas = torch.tensor(deltas, device=self.device, dtype=torch.float32) if len(deltas) > 0 else deltas

        if len(self.train_deltas) == 0:  # deltas were not stored during training
            print('approximating deltas')
            self.train_deltas = self._approximate_deltas(cfg)

    def _approximate_deltas(self, cfg):
        """Build a delta bank from ``cfg.model.n_approx_deltas`` random training pairs."""
        n_train = self.train_X.size(0)
        n_pairs = cfg.model.n_approx_deltas
        t1_idx = torch.randint(n_train, (n_pairs,), device=self.device)
        t2_idx = torch.randint(n_train, (n_pairs,), device=self.device)
        t1_idx, t2_idx = self._orient_pairs(t1_idx, t2_idx, cfg.model.skew_direction)
        deltas = get_deltas_gpu(self.train_X[t1_idx], self.train_X[t2_idx], cfg.model.similarity_type)
        # as_tensor avoids the copy-construct warning when get_deltas_gpu already returns a tensor.
        return torch.as_tensor(deltas, device=self.device, dtype=torch.float32)

    def _orient_pairs(self, t1_idx, t2_idx, skew_direction):
        """Swap pair members so their targets are ordered according to ``skew_direction``.

        'right' makes ``train_Y[t2] >= train_Y[t1]``, and 'left' makes
        ``train_Y[t2] <= train_Y[t1]``. Any other value leaves the pairs unchanged.
        """
        y1 = self.train_Y[t1_idx]
        y2 = self.train_Y[t2_idx]
        if skew_direction == 'right':
            swap = (y1 > y2).flatten()
        elif skew_direction == 'left':
            swap = (y1 < y2).flatten()
        else:
            return t1_idx, t2_idx
        return torch.where(swap, t2_idx, t1_idx), torch.where(swap, t1_idx, t2_idx)

    def choose_anchor(self, curr_obs, use_dom_know_eval, return_anchor=False, exhaustive_search=True,
                      eps_percentile=10, top_k=10):
        """Return the training sample whose delta to ``curr_obs`` best matches the delta bank.

        Args:
            curr_obs: query representation, compared against candidate rows of ``train_X``.
            use_dom_know_eval: if True, only training rows whose argmax feature equals
                the argmax of ``curr_obs`` are candidates.
            return_anchor: if True, also return the anchor position(s) and SMILES.
            exhaustive_search: if True, compare every (candidate, train-delta) pair.
                Otherwise, sample candidates at random until one lies within an
                epsilon distance threshold.
            eps_percentile: percentile (0-100) of probed nearest-delta distances that
                is used as the epsilon threshold in non-exhaustive mode.
            top_k: number of nearest (candidate, delta) pairs to consider (default 10).

        Returns:
            ``closest_obs`` as float32. With ``return_anchor``, the tuple
            ``(closest_obs, anchor_idx, smiles)``.

            NOTE(review): ``anchor_idx`` is a position in the candidate subset, not a
            ``train_X`` index. In exhaustive mode it is a list (one entry per top-k
            pair); otherwise it is a single int. Callers may rely on this, so it is
            kept.

        Raises:
            RuntimeError: if there are no candidate training samples or no training deltas.
        """
        sample_idxs = self._candidate_indices(curr_obs, use_dom_know_eval)
        if sample_idxs.numel() == 0 or len(self.train_deltas) == 0:
            raise RuntimeError('choose_anchor: no candidate training samples or no training deltas to match')

        # cdist requires both operands to share a dtype, so cast in every mode.
        curr_deltas = get_deltas_gpu(self.train_X[sample_idxs], curr_obs, self.similarity_type).float()
        self.train_deltas = self.train_deltas.float()

        if exhaustive_search:
            closest_obs, anchor_idx, top_smiles = self._exhaustive_anchor(curr_deltas, sample_idxs, top_k)
        else:
            closest_obs, anchor_idx, top_smiles = self._epsilon_anchor(curr_deltas, sample_idxs,
                                                                      eps_percentile, top_k)

        closest_obs = closest_obs.to(dtype=torch.float32)
        if return_anchor:
            return closest_obs, anchor_idx, top_smiles
        return closest_obs

    def _candidate_indices(self, curr_obs, use_dom_know_eval):
        """Return indices into ``train_X`` of the rows considered as possible anchors."""
        n_train = self.train_X.size(0)
        if use_dom_know_eval:
            train_types = torch.argmax(self.train_X, dim=1)
            curr_type = torch.argmax(curr_obs)
            return torch.where(train_types == curr_type)[0]
        if self.sample_train:
            # Random 20% of the training set; keep at least one row for tiny sets.
            return torch.randperm(n_train, device=self.device)[:max(1, n_train // 5)]
        return torch.arange(n_train, device=self.device)

    def _exhaustive_anchor(self, curr_deltas, sample_idxs, top_k):
        """Rank all (candidate, train-delta) pairs by L2 distance and anchor on the closest."""
        distances = torch.cdist(curr_deltas, self.train_deltas, p=2)  # [n_candidates, n_train_deltas]
        k = min(top_k, distances.numel())
        _, flat_idxs = torch.topk(distances.reshape(-1), k=k, largest=False)
        # The row of each flattened pair is the candidate's position within sample_idxs.
        anchor_positions = torch.div(flat_idxs, distances.size(1), rounding_mode='floor').tolist()
        train_idxs = sample_idxs[anchor_positions].tolist()
        closest_obs = self.train_X[train_idxs[0]]
        smiles = [self.smiles_X[i] for i in train_idxs]
        return closest_obs, anchor_positions, smiles

    def _epsilon_anchor(self, curr_deltas, sample_idxs, eps_percentile, top_k):
        """Sample candidates at random until one has a train delta within an epsilon threshold.

        Epsilon is the ``eps_percentile`` percentile of nearest-train-delta distances
        over up to 100 randomly probed candidates. The quantile is never below the
        smallest probed distance, so (assuming finite distances) some candidate always
        qualifies and the loop terminates with probability 1.
        """
        n_candidates = curr_deltas.size(0)
        probe_idxs = torch.randperm(n_candidates, device=curr_deltas.device)[:min(100, n_candidates)]
        probe_dists = torch.cdist(curr_deltas[probe_idxs], self.train_deltas, p=2)
        delta_eps = torch.quantile(probe_dists.min(dim=1).values, eps_percentile / 100.0).item()

        k = min(top_k, len(self.train_deltas))
        while True:
            anchor_idx = torch.randint(n_candidates, (1,), device=self.device).item()
            distances = torch.cdist(curr_deltas[anchor_idx:anchor_idx + 1], self.train_deltas, p=2)
            topk_dists, _ = torch.topk(distances, k=k, largest=False, dim=1)
            # topk is sorted ascending, so the valid matches form a prefix.
            n_valid = int((topk_dists[0] <= delta_eps).sum().item())
            if n_valid > 0:
                break

        # The anchor is the sampled candidate. Indices into train_deltas do not
        # address train_X or smiles_X, so both are looked up through sample_idxs.
        train_idx = sample_idxs[anchor_idx].item()
        # Mirror the exhaustive mode: one anchor SMILES per matching pair (all share this anchor).
        smiles = [self.smiles_X[train_idx]] * n_valid
        return self.train_X[train_idx], anchor_idx, smiles




def define_transducer(args, cfg, dataset, deltas):
    """Construct a ``DeltaDistributionTransducer`` and persist its deltas when subsampling.

    When ``cfg.model.sample_deltas`` is set, the transducer's resulting
    ``train_deltas`` tensor is pickled to ``args.train_deltas_path`` so the same
    subset can be reused later. NOTE(review): the pickled tensor keeps its device
    (``args.device``), so loading it may require a matching device or ``map_location``.

    Returns:
        The constructed transducer.
    """
    transducer = DeltaDistributionTransducer(args, cfg, dataset, deltas)
    if cfg.model.sample_deltas:
        # Context manager guarantees the file is flushed and closed even if pickling fails.
        with open(args.train_deltas_path, 'wb') as f:
            pickle.dump(transducer.train_deltas, f)
    return transducer

