import olga
import olga.generation_probability as pgen
import olga.load_model as load_model
import torch
import torch.nn as nn

import numpy as np
import pandas as pd
from collections import Counter
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import re
import math

import logomaker
import seaborn as sns
import subprocess


####################
# Helper functions #
####################


## Load OLGA Model
# Default OLGA human TCR-beta (VDJ) model files shipped inside the olga package.
_olga_model_dir = f'{olga.__path__[0]}/default_models/human_T_beta'
params_file_name = f'{_olga_model_dir}/model_params.txt'
marginals_file_name = f'{_olga_model_dir}/model_marginals.txt'
V_anchor_pos_file = f'{_olga_model_dir}/V_gene_CDR3_anchors.csv'
J_anchor_pos_file = f'{_olga_model_dir}/J_gene_CDR3_anchors.csv'

# Genomic templates + CDR3 anchor positions.
genomic_data = load_model.GenomicDataVDJ()
genomic_data.load_igor_genomic_data(params_file_name, V_anchor_pos_file, J_anchor_pos_file)

# Recombination marginals.
generative_model = load_model.GenerativeModelVDJ()
generative_model.load_and_process_igor_model(marginals_file_name)

# Module-level pGen calculator (used by plot_olga_distributions). Example usage:
#   pgen_model.compute_aa_CDR3_pgen('CASSLGGGGGEEFF')
#   pgen_model.compute_aa_CDR3_pgen('CAWSVAPDRGGYTF', 'TRBV30*01', 'TRBJ1-2*01')
pgen_model = pgen.GenerationProbabilityVDJ(generative_model, genomic_data)


def logo_plot_from_sequences(sequences, title_prefix, top_k=4, xlim=(-1, 25)):
    """
    Create a sequence logo plot for a list of variable-length strings using logomaker.

    At each position, the frequency of every character is computed relative to
    the total number of sequences. Positions past the end of shorter sequences
    therefore have a reduced total height. Only the ``top_k`` most frequent
    characters at each position are drawn.

    Parameters:
    sequences (list): A list of strings of variable sequence length.
    title_prefix (str): Text placed before " Logo Plot" in the plot title.
    top_k (int): The number of top characters to display at each position.
    xlim (tuple or None): x-axis limits passed to ``plt.xlim``. Defaults to
        (-1, 25). Pass None to let matplotlib choose the limits.

    Raises:
    ValueError: If ``sequences`` is empty.
    """
    sequences = list(sequences)
    if not sequences:
        raise ValueError('logo_plot_from_sequences requires at least one sequence')

    n_seqs = len(sequences)
    max_length = max(len(seq) for seq in sequences)

    # Count character occurrences at each position. Shorter sequences simply
    # contribute nothing past their end, so no padding sentinel is needed.
    # A padding sentinel would also compete for one of the top_k slots.
    position_counters = [Counter() for _ in range(max_length)]
    for seq in sequences:
        for i, char in enumerate(seq):
            position_counters[i][char] += 1

    alphabet = ["A", "C", "D", "E", "F", "G", "H", "I", "K", "L",
                "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y", "X"]
    # Initialise with floats so the matrix is numeric. An empty frame
    # filled via fillna would produce object-dtype columns.
    logo_df = pd.DataFrame(0.0, index=range(max_length), columns=alphabet)
    for index, counter in enumerate(position_counters):
        # Keep only the top_k most frequent characters at this position.
        for char, count in counter.most_common(top_k):
            logo_df.at[index, char] = count / n_seqs
    # Setting a character outside `alphabet` via `.at` adds a new column with
    # NaN in the other rows. Zero those cells so logomaker gets a clean matrix.
    logo_df = logo_df.fillna(0.0)

    # Create the sequence logo plot using logomaker
    logomaker.Logo(logo_df, color_scheme='weblogo_protein')
    plt.xlabel('Position')
    if xlim is not None:
        plt.xlim(*xlim)
    plt.ylabel('Frequency')
    plt.title(f'{title_prefix} Logo Plot')
    plt.show()

def plot_olga_distributions(pmhcs, model_adapter, sampling_depth=200, n_rows=2, n_cols=5, **kwargs):
    '''
    Plot the distributions of the model generated vs the true TCRs based
    on pGen.

    For each pMHC, this draws overlaid density histograms of log10 OLGA pGen.
    One histogram covers the pMHC's true CDR3b sequences (``pmhc.tcrs[*].cdr3b``).
    The other covers sequences sampled from ``model_adapter.sample_translations``.
    One subplot is drawn per pMHC, up to ``n_rows * n_cols``. Unused subplots
    are hidden, and extra pMHCs beyond the grid are ignored.

    Sequences with pGen == 0 have no finite log10. They are excluded from the
    histograms, and their count is shown in the legend label.

    Parameters:
    pmhcs (sequence): pMHC objects. Each must expose ``.tcrs`` (items with a
        ``.cdr3b`` attribute); ``str(pmhc)`` is used as the subplot title.
    model_adapter: Object providing ``sample_translations(pmhc, max_len=...,
        n=..., top_k=..., temperature=..., num_beams=...)``. Presumably it
        returns CDR3b amino-acid strings (confirm against the adapter).
    sampling_depth (int): Number of model samples drawn per pMHC.
    n_rows, n_cols (int): Shape of the subplot grid.
    **kwargs: Optional settings:
        top_k (default 10), temperature (default 1.0), num_beams (default 0),
        max_len (default 25) are forwarded to the sampler.
        bins (default 10) gives the number of histogram bins, shared by both
        histograms in a subplot.
    '''
    # Get the sampling / plotting kwargs
    top_k = kwargs.get('top_k', 10)
    temperature = kwargs.get('temperature', 1.0)
    num_beams = kwargs.get('num_beams', 0)
    max_len = kwargs.get('max_len', 25)
    bins = kwargs.get('bins', 10)

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid. The figure
    # size scales with the grid and gives (15, 6) for the default 2x5 layout.
    fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols,
                             figsize=(3 * n_cols, 3 * n_rows), squeeze=False)
    fig.suptitle('Distribution of OLGA Predicted P_Gen', fontsize=16)

    flat_axes = axes.ravel()
    n_plotted = 0
    # zip stops at the shorter of the two, so fewer pMHCs than subplots is safe.
    for ax, pmhc in zip(flat_axes, pmhcs):
        n_plotted += 1
        model_preds = model_adapter.sample_translations(
            pmhc, max_len=max_len, n=sampling_depth, top_k=top_k,
            temperature=temperature, num_beams=num_beams)
        true_probs = np.array([pgen_model.compute_aa_CDR3_pgen(x.cdr3b) for x in pmhc.tcrs], dtype=float)
        model_probs = np.array([pgen_model.compute_aa_CDR3_pgen(x) for x in model_preds], dtype=float)

        # Drop pGen == 0. Mapping it to any finite value would misplace those
        # sequences on the log scale.
        true_log = np.log10(true_probs[true_probs > 0])
        model_log = np.log10(model_probs[model_probs > 0])

        # Shared bin edges make the two histograms directly comparable.
        all_log = np.concatenate([true_log, model_log])
        edges = np.histogram_bin_edges(all_log, bins=bins) if all_log.size else bins

        for values, probs, label in ((true_log, true_probs, 'True'),
                                     (model_log, model_probs, 'Model')):
            n_zero = probs.size - values.size
            if n_zero:
                label = f'{label} ({n_zero} with pGen=0 omitted)'
            if values.size:
                ax.hist(values, bins=edges, density=True, alpha=0.65, label=label)
        ax.set_title(f'{pmhc}')
        if ax.get_legend_handles_labels()[0]:
            ax.legend(loc='best')

    # Hide grid cells that received no pMHC.
    for ax in flat_axes[n_plotted:]:
        ax.set_visible(False)

    # Axis labels on the outer edge of the grid only.
    for ax in axes[-1, :]:
        ax.set_xlabel('Log10 Pgen')
    for ax in axes[:, 0]:
        ax.set_ylabel('Density')

    # Adjust spacing between subplots
    plt.tight_layout()

    # Show the plot
    plt.show()
    

