
from typing import List, Optional, Tuple
import hydra
import numpy as np
import pandas as pd
from ggs.models.predictors import BaseCNN
from random import sample
from scipy.sparse.csgraph import laplacian
from scipy.sparse import csr_matrix
from omegaconf import DictConfig
from omegaconf import OmegaConf
import pyrootutils
import torch
from copy import deepcopy
import logging
import time
import os
from datetime import datetime
from pykeops.torch import Vi, Vj
from scipy.sparse.linalg import cg
from scipy.sparse import identity, csr_matrix
from ggs.data.utils.tokenize import Encoder
from tqdm import tqdm
import sys
import pybktree




# Module-wide logging: basicConfig() attaches a default stream handler to the
# root logger, and setting the root level to NOTSET lets every record through
# (including DEBUG output from third-party libraries).
logging.basicConfig()
logging.root.setLevel(logging.NOTSET)
logger = logging.getLogger('Graph-based Smoothing')
# Locate the project root via the `.project-root` marker and (pythonpath=True)
# put it on sys.path.
# NOTE(review): this runs after the `ggs.*` imports at the top of the file, so
# it presumably only helps imports performed later — confirm it is still needed.
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

# The 20 standard amino-acid one-letter codes; list order defines the integer
# token ids produced by `ALPHABET.index` in `to_seq_tensor`.
ALPHABET = list("ARNDCQEGHILKMFPSTWYV")
# Nucleotide alphabet used for random-walk substitutions in the 'integrase' task.
DNA_ALPHABET = list("ATCG")


use_cuda = torch.cuda.is_available()


def tensor(*x):
    """Build a float32 tensor through the legacy typed constructor.

    Dispatches to ``torch.cuda.FloatTensor`` when CUDA was detected at import
    time (``use_cuda``) and to ``torch.FloatTensor`` otherwise. ``*x`` is
    forwarded unchanged, so both ``tensor(data)`` and ``tensor(*sizes)`` work.
    """
    constructor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
    return constructor(*x)

def KNN_KeOps(K, metric="euclidean", **kwargs):
    """Return a KeOps-backed brute-force K-nearest-neighbour builder.

    Usage is two-staged: ``KNN_KeOps(K, metric)(x_train)`` returns a query
    function ``f(x_test) -> (vals, indices)``. For every row of ``x_test``,
    ``vals`` holds the ``K`` smallest distances to rows of ``x_train`` and
    ``indices`` the matching row indices of ``x_train``. Both are returned as
    NumPy arrays with one row per query point.

    Supported metrics:
      * ``"euclidean"``  - squared Euclidean distance.
      * ``"manhattan"``  - L1 distance.
      * ``"hamming"`` / ``"levenshtein"`` - number of coordinates that differ.
        NOTE: despite its name, ``"levenshtein"`` computes a Hamming distance
        (substitutions only, no insertions/deletions). It is kept as an alias
        for backward compatibility.
      * ``"angular"``    - negative dot product.
      * ``"hyperbolic"`` - squared Euclidean distance divided by the product
        of the first coordinates.

    Args:
        K: number of neighbours returned per query point.
        metric: one of the names above.
        **kwargs: accepted for interface compatibility; unused.

    Raises:
        NotImplementedError: if ``metric`` is not supported. This is raised
            immediately, not when the returned builder is first fitted.
    """
    supported = ("euclidean", "manhattan", "hamming", "levenshtein", "angular", "hyperbolic")
    if metric not in supported:
        raise NotImplementedError(f"The '{metric}' distance is not supported.")

    def fit(x_train):
        x_train = tensor(x_train)

        # Encoding as KeOps LazyTensors: purely symbolic, no data bound yet.
        D = x_train.shape[1]
        X_i = Vi(0, D)  # symbolic "i" (query) variable
        X_j = Vj(1, D)  # symbolic "j" (reference) variable

        # Symbolic distance matrix.
        if metric == "euclidean":
            D_ij = ((X_i - X_j) ** 2).sum(-1)
        elif metric == "manhattan":
            D_ij = ((X_i - X_j).abs()).sum(-1)
        elif metric in ("hamming", "levenshtein"):
            # -|x_i - x_j| >= 0 only where coordinates are equal, so ifelse
            # maps equal -> 0 and different -> 1; the sum counts mismatches.
            D_ij = (-((X_i - X_j).abs())).ifelse(0, 1).sum(-1)
        elif metric == "angular":
            D_ij = -(X_i | X_j)  # `|` is the KeOps scalar product
        else:  # "hyperbolic" (the only remaining supported metric)
            D_ij = ((X_i - X_j) ** 2).sum(-1) / (X_i[0] * X_j[0])

        # K-NN query operator: K smallest values (and their indices) over j.
        KNN_fun = D_ij.Kmin_argKmin(K, dim=1)

        def f(x_test):
            x_test = tensor(x_test)
            vals, indices = KNN_fun(x_test, x_train)
            return vals.cpu().numpy(), indices.cpu().numpy()

        return f

    return fit


def run_BERT_predictor(seqs, batch_size, predictor):
    """Score paired sequences with ``predictor`` in mini-batches.

    Args:
        seqs: pair ``[first, second]`` of sequences; element ``k`` of each list
            together forms one input. The lists are expected to have the same
            length (iteration is driven by ``len(seqs[0])``).
        batch_size: maximum number of pairs passed to ``predictor`` per call.
        predictor: callable taking ``[tuple(first_chunk), tuple(second_chunk)]``
            and returning a tensor, presumably one row per pair.

    Returns:
        The detached per-batch outputs concatenated along dim 0.
    """
    first, second = seqs[0], seqs[1]
    scores = []
    # Inference only: no_grad avoids recording an autograd graph per batch,
    # which would be discarded by .detach() anyway.
    with torch.no_grad():
        for start in range(0, len(first), batch_size):
            stop = start + batch_size
            batch = [tuple(first[start:stop]), tuple(second[start:stop])]
            scores.append(predictor(batch).detach())
    return torch.cat(scores, dim=0)

def run_predictor(seqs, batch_size, predictor):
    """Apply ``predictor`` to ``seqs`` in chunks along dim 0.

    Args:
        seqs: tensor whose first dimension indexes the inputs.
        batch_size: maximum rows per predictor call (the last chunk may be
            smaller).
        predictor: callable mapping a chunk of ``seqs`` to a tensor of outputs.

    Returns:
        The detached predictor outputs concatenated along dim 0.
    """
    # torch.split never yields None, so no filtering is needed. no_grad skips
    # building an autograd graph because the outputs are detached anyway.
    with torch.no_grad():
        scores = [predictor(chunk).detach() for chunk in torch.split(seqs, batch_size, 0)]
    return torch.cat(scores, dim=0)



def to_seq_tensor(seq, alphabet=None):
    """Encode ``seq`` as a 1-D int64 tensor of token indices.

    Args:
        seq: iterable of single-character tokens.
        alphabet: list defining each token's index; defaults to the
            module-level amino-acid ``ALPHABET``.

    Returns:
        ``torch.long`` tensor of shape ``(len(seq),)``.

    Raises:
        ValueError: if a token does not occur in ``alphabet``.
    """
    if alphabet is None:
        alphabet = ALPHABET
    # Explicit dtype: torch.tensor([]) would default to float32, so an empty
    # sequence would otherwise not match the int64 dtype of non-empty ones.
    return torch.tensor([alphabet.index(token) for token in seq], dtype=torch.long)


def get_next_state(seq, task, num=1):
    """Return a copy of ``seq`` with one randomly chosen position resampled.

    A uniformly random index is overwritten with a token drawn uniformly from
    ``DNA_ALPHABET`` when ``task == 'integrase'`` and from ``ALPHABET``
    otherwise. The drawn token may equal the current one, in which case the
    result equals ``seq``. ``num`` is currently ignored.
    """
    tokens = list(seq)
    pos = np.random.randint(0, len(seq))
    pool = DNA_ALPHABET if task == 'integrase' else ALPHABET
    tokens[pos] = np.random.choice(pool)
    return ''.join(tokens)


def to_batch_tensor(seq_list, task, subset=None, device='cpu'):
    """Stack sequences into one batch tensor on ``device``.

    When ``subset`` is given, only the first ``subset`` items are used. For
    the 'integrase' task the items are stacked as-is, so they must already be
    tensors. For every other task each item is encoded with ``to_seq_tensor``
    first.
    """
    if subset is not None:
        seq_list = seq_list[:subset]
    if task == 'integrase':
        rows = list(seq_list)
    else:
        rows = [to_seq_tensor(s) for s in seq_list]
    return torch.stack(rows).to(device)


def maximum(A, B):
    """Return the element-wise maximum of two SciPy sparse matrices.

    Unstored entries count as 0, so ``max(-1, <unstored>) == 0``. This is
    used to symmetrise a kNN graph: ``maximum(G, G.T)`` keeps an edge present
    in either direction.

    Args:
        A: sparse matrix.
        B: sparse matrix (or sparse-compatible) of the same shape as ``A``.

    Returns:
        Sparse matrix holding ``max(A[i, j], B[i, j])`` for every entry.
    """
    # SciPy's built-in computes the same result as masking A - B and
    # recombining, without building the intermediate matrices.
    return A.maximum(B)


def _infer_task(predictor_dir):
    """Map a predictor checkpoint directory to its task name.

    Raises:
        ValueError: if no known task keyword appears in ``predictor_dir``.
    """
    # Checked in this order; the first keyword found in the path wins.
    for keyword, task in (('GFP', 'GFP'), ('AAV', 'AAV'),
                          ('recombination', 'integrase'), ('Diagonal', 'Diagonal')):
        if keyword in predictor_dir:
            return task
    raise ValueError(f'Task not found in predictor path: {predictor_dir}')


def _load_predictor(predictor_dir, ckpt_file, device):
    """Build ``BaseCNN`` from ``predictor_dir/config.yaml``, load the weights in
    ``predictor_dir/ckpt_file`` and return it on ``device`` in eval mode."""
    predictor_path = os.path.join(predictor_dir, ckpt_file)
    ckpt_cfg = OmegaConf.load(os.path.join(predictor_dir, 'config.yaml'))
    predictor = BaseCNN(**ckpt_cfg.model.predictor)
    # map_location follows the runtime device; a hard-coded 'cuda:0' crashes
    # on CPU-only machines.
    predictor_info = torch.load(predictor_path, map_location=device)
    # Strip the 'predictor.' prefix so checkpoint keys match BaseCNN's own names.
    state_dict = {k.replace('predictor.', ''): v
                  for k, v in predictor_info['state_dict'].items()}
    predictor.load_state_dict(state_dict, strict=True)
    predictor.to(device).eval()
    logger.info(f'Loading base predictor {predictor_path}')
    return predictor


def _random_walk(original_seqs, task, max_n_seqs, exploration_method):
    """Grow ``original_seqs`` by random single-site mutations until the pool
    holds ``max_n_seqs`` unique sequences; return the pool as a list.

    A pointer walks forward through the growing list, and the sequence it
    points at is the one mutated. It only advances when a new, unseen sequence
    is produced. With ``exploration_method == 'single_mut'`` the pointer wraps
    around the starting pool, so every new sequence is a single mutant of a
    starting sequence. Otherwise it moves on into generated sequences, so
    mutations accumulate.

    NOTE(review): because the pointer only advances on success, the loop may
    never terminate if ``max_n_seqs`` exceeds the number of reachable unique
    sequences.
    """
    originals = set(original_seqs)  # O(1) membership instead of a list scan
    num_starting_seqs = len(original_seqs)
    generated = list(original_seqs)
    seen = set(original_seqs)
    pointer = 0
    pbar = tqdm(total=max_n_seqs, file=sys.stdout)
    while len(generated) < max_n_seqs:
        source = generated[pointer]
        if exploration_method == 'single_mut' and source not in originals:
            logger.warning('next_seq not in original_seqs')
            break
        candidate = get_next_state(source, task)
        if candidate in seen:
            continue
        generated.append(candidate)
        seen.add(candidate)
        pointer += 1
        pbar.update(1)
        if exploration_method == 'single_mut':
            pointer %= num_starting_seqs
    pbar.close()
    return generated


def _build_knn_graph(encoded_seqs, num_neighbors):
    """Build a symmetric kNN graph over the rows of ``encoded_seqs``.

    Uses the KeOps 'levenshtein' metric, which counts differing positions.
    Returns an ``(n, n)`` CSR matrix whose stored values are the distances
    between neighbouring sequences.
    """
    n = encoded_seqs.shape[0]
    # KeOps cannot return more neighbours than there are points.
    k = min(num_neighbors, n)
    knn_query = KNN_KeOps(K=k, metric='levenshtein')(encoded_seqs)
    dists, neighbors = knn_query(encoded_seqs)
    # Assumes column 0 is each point's zero-distance match with itself
    # (sequences are unique), so it is dropped.
    weights = 1 / dists[:, 1:]
    neighbors = neighbors[:, 1:]
    per_row = weights.shape[1]
    # Explicit shape: otherwise SciPy infers the column count from the largest
    # neighbour index. That count is < n whenever a trailing node is nobody's
    # neighbour, and the transpose then no longer matches.
    directed = csr_matrix(
        (weights.ravel(), neighbors.ravel(), np.arange(n + 1) * per_row),
        shape=(n, n))
    symmetric = maximum(directed, directed.T)
    # NOTE(review): inverting again means the Laplacian sees distances rather
    # than similarities as edge weights — confirm this is intended.
    return csr_matrix((1 / symmetric.data, symmetric.indices, symmetric.indptr),
                      shape=(n, n))


@hydra.main(version_base="1.3", config_path="../configs", config_name="GS.yaml")
def main(cfg: DictConfig) -> Optional[float]:
    """Smooth predictor scores over a kNN graph of random-walk sequences.

    Steps:
      1. Locate the data and the predictor from ``cfg.experiment.predictor_dir``.
      2. Expand the base sequence pool by random single-site mutations up to
         ``cfg.max_n_seqs`` sequences.
      3. Score every sequence with the predictor.
      4. Build a kNN graph over the sequences.
      5. Solve ``(I + gamma * L) y = s`` by conjugate gradients, where ``L`` is
         the normalized graph Laplacian and ``s`` the raw scores.
      6. Write the smoothed scores to CSV and the config to YAML.

    Returns:
        None.
    """
    # Data path components (mutation count, percentile range) are parsed out
    # of predictor_dir.
    predictor_dir = cfg.experiment.predictor_dir
    path_parts = predictor_dir.split('/')
    num_mutations = [x for x in path_parts if 'mutations' in x][0]
    starting_range = [x for x in path_parts if 'percentile' in x][0]
    task = _infer_task(predictor_dir)
    data_dir = os.path.join(cfg.paths.data_dir, task, num_mutations, starting_range)
    base_pool_path = os.path.join(data_dir, 'base_seqs.csv')
    df_base = pd.read_csv(base_pool_path)
    if 'augmented' in df_base.columns:
        df_base = df_base[df_base.augmented == 0]
    logger.info(f'Loaded base sequences {base_pool_path}')

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    predictor = _load_predictor(predictor_dir, cfg.ckpt_file, device)

    # Random walk
    logger.info('Generating sequences by random walk from the base sequence pool..')
    start_time = time.time()
    max_n_seqs = cfg.max_n_seqs
    exploration_method = cfg.exploration_method
    # sorted() makes the walk's starting order independent of string-hash
    # randomization across processes.
    original_seqs = sorted(set(df_base[cfg.sequence_columns].values))
    all_seqs_generated = _random_walk(original_seqs, task, max_n_seqs, exploration_method)
    logger.info("Finished generating sequences by random walk from the base sequence pool..")

    logger.info("Running predictor on generated sequences..")
    all_seqs_list = sorted(set(all_seqs_generated))
    split_point = None  # an int here selects the two-segment (BERT-style) predictor path
    if split_point is None:
        all_seqs_pt = to_batch_tensor(all_seqs_list, task, subset=None, device=device)
        node_scores_init = run_predictor(all_seqs_pt, batch_size=256, predictor=predictor).cpu().numpy()
    else:
        all_seqs_1 = [x[:split_point] for x in all_seqs_list]
        all_seqs_2 = [x[split_point:] for x in all_seqs_list]
        node_scores_init = run_BERT_predictor(
            [all_seqs_1, all_seqs_2], batch_size=256, predictor=predictor).cpu().numpy()
    elapsed_time = time.time() - start_time
    logger.info(f'Finished generation + evaluation in {elapsed_time:.2f} seconds')
    logger.info(f'Total of {len(all_seqs_list)} sequences generated')

    # Keep an untouched copy for the output in case the encoder mutates its input.
    all_seqs_list_orig = list(all_seqs_list)
    encoder = Encoder(alphabet=ALPHABET)
    encoded = tensor(encoder.encode(all_seqs_list).to(torch.float).to(device))
    logger.info('Creating KNN graph..')
    start_time = time.time()
    knn_graph = _build_knn_graph(encoded, num_neighbors=501)
    elapsed_time = time.time() - start_time
    logger.info(f'Finished kNN construction in {elapsed_time:.2f} seconds')

    logger.info('Computing Laplacian..')
    start_time = time.time()
    gamma = cfg.gamma
    laplacian_mat = laplacian(knn_graph, normed=True).tocsr()
    n = laplacian_mat.shape[0]
    system = identity(n, format='csr') + gamma * laplacian_mat
    # Minimizer of ||y - s||^2 + gamma * y^T L y (L symmetric).
    Y_opt, info = cg(system, node_scores_init)
    if info != 0:
        logger.warning(f'Conjugate gradient did not converge cleanly (info={info}); '
                       'results use the last iterate')
    logger.info(f'Finished smoothing in {time.time() - start_time:.2f} seconds')

    logger.info('storing results..')
    df_smoothed = pd.DataFrame({'sequence': all_seqs_list_orig, 'target': Y_opt})

    params = f'n-{max_n_seqs}_g-{gamma}'
    os.makedirs(os.path.join(data_dir, exploration_method), exist_ok=True)
    if cfg.results_file is not None:
        results_file = f'{exploration_method}/{cfg.results_file}'
    else:
        results_file = f'{exploration_method}/smoothed'
    results_path = os.path.join(data_dir, f'{results_file}-{params}' + '.csv')

    logger.info(f'Writing results to {results_path}')
    df_smoothed.to_csv(results_path, index=False)
    # NOTE(review): unlike the CSV, the config filename omits `params`, so runs
    # that differ only in n/gamma overwrite each other's YAML.
    cfg_write_path = os.path.join(data_dir, results_file + '.yaml')
    with open(cfg_write_path, 'w') as f:
        OmegaConf.save(config=cfg, f=f)

if __name__ == '__main__':
    # Silence TensorFlow's native (C++) log output should any dependency load
    # TensorFlow. Hydra then parses the command line and invokes main(cfg).
    os.environ.update(TF_CPP_MIN_LOG_LEVEL='3')
    main()
