"""Defines the Adalead explorer class."""
import random
from typing import Optional, Tuple

import numpy as np
import pandas as pd
from collections import defaultdict
# from explorer import Explorer

import sys 
sys.path.append("../")
import sequence_utils as s_utils
from scipy.special import softmax

# sys.path.append("/home/**/generative_model_work/oracle_models/tf_HESX1_R160C_R1_8mers")
sys.path.append("/home/**/generative_model_work/datasets/tf_HESX1_R160C_R1_8mers/")
sys.path.append("/home/**/generative_model_work/importance_models/tf_HESX1_R160C_R1_8mers")
# from produce_output import test
# from complor import onehotseq
from generate_property import output_property_oracle, onehotseq
from generate_property import onehotseq
from contributions_score_generative import contribution_score

# Dataset directory that holds the categorical vocabulary file loaded below.
path = "/home/**/generative_model_work/datasets/tf_HESX1_R160C_R1_8mers/"

# Category vocabulary as a plain Python list. Its length sets the last
# dimension of the one-hot arrays built in `Adalead.convert_to_ohe`.
# allow_pickle=True permits object arrays; only load trusted files this way.
alphabets_of_color = np.load(
    f'{path}/categorical_variables.npy', allow_pickle=True
).tolist()

# NOTE: this rebinds the imported `output_property_oracle` name to an
# instance. The explorer calls `.output_property(...)` on it as its oracle.
output_property_oracle = output_property_oracle()

class Adalead:
    """
    Adalead explorer.

    Algorithm works as follows:
        Initialize set of top sequences whose fitnesses are at least
            (1 - threshold) of the maximum fitness so far
        While we can still make model queries in this batch
            Recombine top sequences and append to parents
            Rollout from parents and append to mutants.

    Fitness values come from the module-level ``output_property_oracle``.
    Mutations are either random or guided by contribution ("importance")
    scores, at single-residue or motif level.
    """

    def __init__(
        self,
        model,
        model_args,
        rounds: int,
        sequences_batch_size: int,
        model_queries_per_batch: int,
        starting_sequence: str,
        alphabet: str,
        mu: int = 1,
        recomb_rate: float = 0,
        threshold: float = 0.05,
        rho: int = 0,
        eval_batch_size: int = 20,
        model_contri: Optional[type] = None,
        criterion_contri: Optional[type] = None,
        optimizer_contri: Optional[type] = None,
        log_file: Optional[str] = None,
        motif_size: int = 1,
        motif_based: bool = False,
    ):
        """
        Args:
            model: Stored as ``self.model``; not used by the methods of this class
                (fitness comes from ``output_property_oracle``).
            model_args: Config object. ``max_len`` sizes the one-hot encoding and
                ``device`` is forwarded to ``contribution_score``.
            rounds: Accepted for interface compatibility; not used here.
            sequences_batch_size: Number of parents drawn (cyclically) from the
                top sequences in each ``propose_sequences`` call.
            model_queries_per_batch: Oracle-query budget for one
                ``propose_sequences`` call.
            starting_sequence: Accepted for interface compatibility; not used here.
            alphabet: Characters a sequence position may take.
            mu: Expected number of mutations to the full sequence (mu/L per position).
            recomb_rate: The probability of a crossover at any position in a sequence.
            threshold: At each round only sequences with fitness above
                (1-threshold)*f_max are retained as parents for generating next set of
                sequences.
            rho: Number of recombination passes applied to the parents in each
                outer iteration of ``propose_sequences`` (0 disables recombination).
            eval_batch_size: Number of roots scored and expanded together per
                oracle call.
            model_contri, criterion_contri, optimizer_contri: Forwarded to
                ``contribution_score`` when importance-based mutation is used.
            log_file: Accepted for interface compatibility; not used here.
            motif_size: Motif length for motif-level importance and mutation.
            motif_based: If True, importance-guided mutation works on motifs
                instead of single residues.
        """
        name = f"Adalead_mu={mu}_threshold={threshold}"
        print(name)
        self.name = name

        self.threshold = threshold
        self.recomb_rate = recomb_rate
        self.alphabet = alphabet
        self.mu = mu  # number of mutations per *sequence*.
        self.rho = rho
        self.eval_batch_size = eval_batch_size
        self.model_args = model_args
        self.model = model
        self.sequences_batch_size = sequences_batch_size
        self.model_queries_per_batch = model_queries_per_batch
        self.model_contri = model_contri
        self.criterion_contri = criterion_contri
        self.optimizer_contri = optimizer_contri
        self.motif_size = motif_size
        self.motif_based = motif_based

    def _recombine_population(self, gen):
        """Recombine a population pairwise by crossover.

        The population is shuffled (on a copy, so the caller's sequence is not
        modified) and consecutive pairs are recombined. Walking along the
        positions, the assignment "which parent feeds which child" flips with
        probability ``self.recomb_rate`` at every position.

        Args:
            gen: Sequence of strings, assumed to have equal lengths.

        Returns:
            List of recombinant strings with the same size as ``gen``. With an
            odd population the unpaired member is carried over unchanged rather
            than being dropped.
        """
        pool = list(gen)
        # A single member has no partner to recombine with.
        if len(pool) == 1:
            return pool

        random.shuffle(pool)
        ret = []
        for i in range(0, len(pool) - 1, 2):
            first, second = pool[i], pool[i + 1]
            strA = []
            strB = []
            switch = False
            for ind in range(len(first)):
                if random.random() < self.recomb_rate:
                    switch = not switch

                # putting together recombinants
                if switch:
                    strA.append(first[ind])
                    strB.append(second[ind])
                else:
                    strB.append(first[ind])
                    strA.append(second[ind])

            ret.append("".join(strA))
            ret.append("".join(strB))

        # Keep the population size stable: the member left without a partner
        # survives as-is.
        if len(pool) % 2:
            ret.append(pool[-1])
        return ret

    def convert_to_ohe(self, data):
        """One-hot encode a batch of sequences into a zero-padded array.

        Args:
            data: Sequence of sequences accepted by ``onehotseq``.

        Returns:
            ``(ohe, seq_lengths)``. ``ohe`` has shape
            ``(len(data), self.model_args.max_len, len(alphabets_of_color))``,
            with ``onehotseq(data[i])`` in the first ``len(data[i])`` rows of
            slice ``i`` and zeros elsewhere. ``seq_lengths`` is a float array
            of the sequence lengths.
        """
        n_seqs = len(data)
        ohe = np.zeros((n_seqs, self.model_args.max_len, len(alphabets_of_color)))
        seq_lengths = np.zeros((n_seqs,))

        for i in range(n_seqs):
            length = len(data[i])
            ohe[i, 0:length, :] = onehotseq(data[i])
            seq_lengths[i] = length

        return ohe, seq_lengths

    def calcualate_imp_aa(self, seq, seq_len, imp):
        """Average importance score of each alphabet character over a batch.

        Args:
            seq: Sequence of sequences (strings).
            seq_len: Length of each sequence; only the first ``seq_len[i]``
                positions of ``seq[i]`` and ``imp[i]`` are used.
            imp: 2-D per-position importance scores indexed ``imp[i, pos]``.

        Returns:
            List aligned with ``self.alphabet``. Each entry is the mean
            importance over all occurrences of that character, or 0 when the
            character does not occur (the tiny epsilon avoids division by zero).
        """
        num_seq = len(seq)
        max_len = int(max(seq_len))
        residues = np.zeros((num_seq, max_len), dtype=object)
        scores = np.zeros((num_seq, max_len))
        for i in range(num_seq):
            length = int(seq_len[i])
            residues[i, 0:length] = list(seq[i])
            scores[i, 0:length] = imp[i, 0:length]

        aa_imp = []
        for letter in self.alphabet:
            # Padding cells hold the int 0, so they never match a character.
            mask = residues == letter
            aa_imp.append(np.sum(scores[mask]) / (np.sum(mask) + 1e-18))
        return aa_imp

    def pad_seq(self, all_seq):
        """Split sequences into characters and right-pad them to a common length.

        Returns:
            2-D NumPy array with one row per sequence. Shorter rows are padded
            with ``0``; when the array holds strings, NumPy stores that padding
            as the string ``'0'``.
        """
        rows = [list(map(str, group)) for group in all_seq]
        width = max(len(row) for row in rows)
        return np.array([row + [0] * (width - len(row)) for row in rows])

    def normalize_impscore(self, imp_score):
        """Divide each row of a 2-D importance array by that row's maximum.

        Rows whose maximum is exactly 0 are returned unchanged (instead of
        becoming NaN/inf), so later means and samplers stay finite.
        """
        scores = np.asarray(imp_score)
        max_vals = np.max(scores, axis=1, keepdims=True)
        safe_max = np.where(max_vals == 0, 1.0, max_vals)
        return scores / safe_max

    def make_motif_chunks(self):
        """Mean importance per motif over non-overlapping windows.

        Reads ``self.padded_seq`` and ``self.normalized_score`` (both set by
        ``motif_level_assessment``). Each row is cut into consecutive windows of
        ``self.motif_size`` columns (the last window may be shorter). A window's
        score is the sum of its normalized importances.

        Returns:
            ``{motif: mean window score over all occurrences}``.
        """
        seq_len = self.padded_seq.shape[1]
        motif_scores = defaultdict(list)
        for start in range(0, seq_len, self.motif_size):
            stop = start + self.motif_size
            window_scores = np.sum(self.normalized_score[:, start:stop], axis=1)
            for row, score in zip(self.padded_seq[:, start:stop], window_scores):
                motif_scores["".join(row)].append(score)

        return {motif: np.mean(scores) for motif, scores in motif_scores.items()}

    def motif_level_assessment(self, roots, root_importance):
        """Score all motifs of length ``motif_size`` found in ``roots``.

        Pads ``roots`` and row-normalizes ``root_importance`` (caching both on
        ``self``), then returns the ``{motif: mean importance}`` mapping built
        by ``make_motif_chunks``.
        """
        self.padded_seq = self.pad_seq(roots)
        self.normalized_score = self.normalize_impscore(root_importance)
        return self.make_motif_chunks()

    def individual_protein_chunks(self, p, imp):
        """Mean importance of every length-``motif_size`` window of one sequence.

        Args:
            p: A single sequence.
            imp: Per-position importance scores of ``p``.

        Returns:
            ``{start_index: mean importance of imp[start:start + motif_size]}``
            for every window start (stride 1).
        """
        k = self.motif_size
        return {start: np.mean(imp[start:start + k]) for start in range(len(p) - k + 1)}

    def propose_sequences(
        self, measured_sequences: pd.DataFrame, is_imp_based=False, temp=1.0
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Generate new sequences by greedy rollouts from the best measured ones.

        Args:
            measured_sequences: Frame with at least ``"sequence"`` and
                ``"true_score"`` columns.
            is_imp_based: If True, mutations are guided by ``contribution_score``
                importances (per residue, or per motif when ``self.motif_based``).
                Otherwise random mutations are used.
            temp: Temperature forwarded to the importance-based mutation helpers.

        Returns:
            ``(sequences, scores)``: every distinct sequence generated during
            this call with its oracle score, sorted by score in ascending order.
            No top-``sequences_batch_size`` truncation is applied.

        Raises:
            ValueError: If no parents can be selected or no sequences were
                generated.
        """
        scores = measured_sequences["true_score"]
        top_fitness = max(scores)
        # Parents are all sequences within `threshold` of the best score. The
        # np.sign factor keeps the cut-off at or below top_fitness when it is
        # negative.
        top_mask = (
            scores >= top_fitness * (1 - np.sign(top_fitness) * self.threshold)
        ).tolist()
        # np.resize repeats (or truncates) the selected sequences cyclically to
        # exactly `sequences_batch_size` parents.
        parents = np.resize(
            np.array(measured_sequences["sequence"])[top_mask],
            self.sequences_batch_size,
        )
        if len(parents) == 0:
            # With no parents the loop below could never spend the query budget.
            raise ValueError("`sequences_batch_size` must be positive")

        sequences = {}
        track_queries = 0
        while track_queries < self.model_queries_per_batch:
            # generate recombinant parents
            for _ in range(self.rho):
                parents = self._recombine_population(parents)

            for start in range(0, len(parents), self.eval_batch_size):
                # Here we do rollouts from each parent (root of rollout tree)
                roots = parents[start : start + self.eval_batch_size]
                root_fitnesses = output_property_oracle.output_property(roots)

                root_importance = root_len = root_imp_aa = motif_mean_dict = None
                if is_imp_based:
                    root_ohe, root_len = self.convert_to_ohe(roots)
                    root_importance = contribution_score(
                        self.model_contri,
                        self.criterion_contri,
                        self.optimizer_contri,
                        root_ohe,
                        root_len,
                        self.model_args.device,
                    )
                    if self.motif_based:
                        # Rank the motifs present in this batch of roots.
                        motif_mean_dict = self.motif_level_assessment(
                            roots, root_importance
                        )
                    else:
                        # Per-character importance turned into a probability
                        # vector; scipy's softmax is shift-stable (no overflow).
                        root_imp_aa = softmax(
                            np.asarray(
                                self.calcualate_imp_aa(roots, root_len, root_importance)
                            )
                        ).tolist()

                # Root evaluations are only charged to the budget when
                # recombination is on (presumably because with rho == 0 the
                # roots are already-measured sequences).
                if self.rho > 0:
                    track_queries += len(root_fitnesses)

                nodes = list(enumerate(roots))
                while nodes and track_queries < self.model_queries_per_batch:
                    child_idxs = []
                    children = []
                    # Visit the last node first, then the rest in order. This
                    # matches the historical `nodes[len(children) - 1]` indexing
                    # and keeps seeded runs reproducible.
                    for idx, node in nodes[-1:] + nodes[:-1]:
                        if not is_imp_based:
                            child = s_utils.generate_random_multiple_mutant(
                                node,
                                self.mu / len(node),
                                self.alphabet,
                                1,
                            )
                        else:
                            node_imp = root_importance[idx, 0 : int(root_len[idx])]
                            if self.motif_based:
                                node_motifs = self.individual_protein_chunks(
                                    node, node_imp
                                )
                                child = s_utils.motif_level_mutation(
                                    node,
                                    motif_mean_dict,
                                    node_motifs,
                                    self.motif_size,
                                    temp,
                                )
                            else:
                                child = s_utils.generate_importance_based_mutant(
                                    node,
                                    node_imp,
                                    temp,
                                    self.alphabet,
                                    root_imp_aa,
                                )
                        # Children are not de-duplicated against measured or
                        # already-generated sequences; `sequences` keeps the
                        # latest score for a repeated child.
                        child_idxs.append(idx)
                        children.append(child)

                    fitnesses = output_property_oracle.output_property(children)
                    track_queries += len(fitnesses)
                    sequences.update(zip(children, fitnesses))

                    # Continue the rollout only from children at least as fit
                    # as their root.
                    nodes = [
                        (idx, child)
                        for idx, child, fitness in zip(child_idxs, children, fitnesses)
                        if fitness >= root_fitnesses[idx]
                    ]

        if not sequences:
            raise ValueError(
                "No sequences generated. If `model_queries_per_batch` is small, try "
                "making `eval_batch_size` smaller"
            )

        new_seqs = np.array(list(sequences.keys()))
        preds = np.array(list(sequences.values()))
        sorted_order = np.argsort(preds)
        return new_seqs[sorted_order], preds[sorted_order]
