import numpy as np
from sklearn.utils import compute_class_weight
import torch.nn as nn
import xgboost as xgb
import torch
import pandas as pd
from sklearn.model_selection import train_test_split, KFold

from project.constants import HQ_AMPs_FILE
from project.synthetic_data import generate_synthetic_sequences
from .sequence_properties import calculate_physchem_prop, calculate_aa_frequency, calculate_positional_encodings
from collections import Counter
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, matthews_corrcoef, precision_score

class PeptideClassifier(nn.Module):
    """Base class wrapping an XGBoost binary classifier for peptide sequences.

    Subclasses provide the featurisation (``get_input_features``) and the
    training procedure (``train_classifier``). This class provides model
    loading/saving, thresholded prediction and k-fold evaluation.
    """

    def __init__(self, model_path):
        """Load a saved model from ``model_path``, or create an untrained one if it is ``None``."""
        super().__init__()
        if model_path is not None:
            self.model = xgb.XGBClassifier()
            self.model.load_model(model_path)
        else:
            self.model = xgb.XGBClassifier(use_label_encoder=False, eval_metric='logloss',
                                           objective='binary:logistic', early_stopping_rounds=50, n_estimators=5000)
        # NOTE(review): presumably present so the module exposes at least one
        # torch parameter - confirm against callers.
        self.dummy_param = nn.Parameter(torch.empty(0))
        # Probability at or above which a sequence is classified as positive.
        self.decision_threshold = 0.5

    def get_input_features(self, sequences):
        """To be implemented by child classes"""
        raise NotImplementedError

    def train_classifier(self, input_features, labels, weight_balancing="balanced_with_adjustment_for_high_quality", mask_high_quality_idxs=None, return_feature_importances=False, verbose=True):
        """To be implemented by child classes"""
        raise NotImplementedError

    @staticmethod
    def _take_rows(data, indices):
        """Return the rows of ``data`` at the positional ``indices``."""
        # Plain ``data[i]`` on a pandas object is label/column based, so use
        # positional indexing there; other containers keep list-of-rows output.
        if isinstance(data, (pd.DataFrame, pd.Series)):
            return data.iloc[indices]
        return [data[i] for i in indices]

    @staticmethod
    def _resolve_high_quality_mask(mask_high_quality_idxs, n_samples):
        """Return a boolean array of length ``n_samples``; ``None``/empty means no high-quality samples."""
        if mask_high_quality_idxs is None or len(mask_high_quality_idxs) == 0:
            return np.zeros(n_samples, dtype=bool)
        mask = np.asarray(mask_high_quality_idxs, dtype=bool)
        if mask.shape != (n_samples,):
            raise ValueError(
                f"mask_high_quality_idxs must have one entry per sample "
                f"(expected {n_samples}, got shape {mask.shape})"
            )
        return mask

    def _apply_threshold(self, probas):
        """Turn positive-class probabilities into 0/1 predictions using ``decision_threshold``."""
        return (probas >= self.decision_threshold).astype(int)

    def eval_with_k_fold_cross_validation(self, input_features, labels, weight_balancing="balanced_with_adjustment_for_high_quality", k=5, mask_high_quality_idxs=None, reference_file=HQ_AMPs_FILE,
                                          mutations=5, additions=5, n_synthetic_sequences=10000, n_synthetic_for_precision=1000):
        """Evaluate the classifier with shuffled k-fold cross-validation and print a report.

        For every fold the classifier is retrained via ``train_classifier`` and
        scored on the held-out samples, on the held-out high-quality subset,
        and on four sets of synthetic sequences (random, shuffled, mutated,
        added/deleted) which are treated as negatives.

        ``train_classifier`` is called once per fold on this instance, so any
        state it leaves behind reflects the last fold.

        Args:
            input_features: Per-sample features (sequence of rows or a pandas
                DataFrame).
            labels: Per-sample binary labels (0/1).
            weight_balancing: Forwarded to ``train_classifier``.
            k: Number of folds.
            mask_high_quality_idxs: Per-sample boolean mask flagging
                high-quality samples; ``None`` or empty means none are.
            reference_file: Forwarded to ``generate_synthetic_sequences``.
            mutations: Forwarded to ``generate_synthetic_sequences``.
            additions: Forwarded to ``generate_synthetic_sequences``.
            n_synthetic_sequences: Forwarded to ``generate_synthetic_sequences``.
            n_synthetic_for_precision: Number of sequences taken from each
                synthetic set when computing precision at top 100.

        Raises:
            ValueError: If a non-empty mask does not have one entry per sample.
        """
        kf = KFold(n_splits=k, shuffle=True, random_state=42)
        mask = self._resolve_high_quality_mask(mask_high_quality_idxs, len(input_features))

        accuracies, f1_scores, mcc_scores = [], [], []
        high_quality_accuracies, high_quality_f1_scores, high_quality_mcc_scores = [], [], []
        confusion_matrices, high_quality_confusion_matrices = [], []
        precision_at_100, high_quality_precision_at_100 = [], []

        random_sequences, shuffled_sequences, mutated_sequences, added_deleted_sequences = generate_synthetic_sequences(
            reference_file, n_synthetic_sequences, mutations, additions
        )
        # Featurise the synthetic sets once; they are identical for every fold.
        # Insertion order fixes the concatenation order used further down.
        synthetic_features = {
            "random": self.get_input_features(random_sequences),
            "shuffled": self.get_input_features(shuffled_sequences),
            "mutated": self.get_input_features(mutated_sequences),
            "added_deleted": self.get_input_features(added_deleted_sequences),
        }
        hit_rates = {name: [] for name in synthetic_features}

        for train_index, test_index in kf.split(input_features):
            train_features = self._take_rows(input_features, train_index)
            test_features = self._take_rows(input_features, test_index)
            train_labels = self._take_rows(labels, train_index)
            test_labels = np.asarray(self._take_rows(labels, test_index))
            test_mask = mask[test_index]

            self.train_classifier(train_features, train_labels, weight_balancing=weight_balancing,
                                  mask_high_quality_idxs=mask[train_index].tolist(), verbose=False)

            # Score once and threshold locally instead of running the model twice.
            scores = self.predict_from_features(test_features, proba=True)
            predictions = self._apply_threshold(scores)

            high_quality_test_labels = test_labels[test_mask]
            high_quality_predictions = predictions[test_mask]
            high_quality_scores = scores[test_mask]

            synthetic_scores = {name: self.predict_from_features(features, proba=True)
                                for name, features in synthetic_features.items()}
            synthetic_predictions = {name: self._apply_threshold(s) for name, s in synthetic_scores.items()}
            for name, synthetic_prediction in synthetic_predictions.items():
                hit_rates[name].append(synthetic_prediction.mean())

            accuracies.append(accuracy_score(test_labels, predictions))
            f1_scores.append(f1_score(test_labels, predictions))
            mcc_scores.append(matthews_corrcoef(test_labels, predictions))
            # labels=[0, 1] keeps every matrix 2x2 even if a fold lacks a class,
            # so the matrices can be averaged and indexed with [1, 1] below.
            confusion_matrices.append(confusion_matrix(test_labels, predictions, labels=[0, 1], normalize='true'))

            # A fold may contain no high-quality samples; skip its metrics then.
            if test_mask.any():
                high_quality_accuracies.append(accuracy_score(high_quality_test_labels, high_quality_predictions))
                high_quality_f1_scores.append(f1_score(high_quality_test_labels, high_quality_predictions))
                high_quality_mcc_scores.append(matthews_corrcoef(high_quality_test_labels, high_quality_predictions))
                high_quality_confusion_matrices.append(
                    confusion_matrix(high_quality_test_labels, high_quality_predictions, labels=[0, 1], normalize='true')
                )

            # Pool held-out high-quality samples with a slice of every synthetic
            # set; synthetic sequences are labelled negative.
            synthetic_prediction_slices = [synthetic_predictions[name][:n_synthetic_for_precision] for name in synthetic_features]
            synthetic_score_slices = [synthetic_scores[name][:n_synthetic_for_precision] for name in synthetic_features]
            all_predictions = np.concatenate([high_quality_predictions] + synthetic_prediction_slices)
            all_scores = np.concatenate([high_quality_scores] + synthetic_score_slices)
            # Size the negative labels from the actual slices so a synthetic set
            # shorter than the requested slice cannot misalign labels and scores.
            n_synthetic = len(all_scores) - len(high_quality_scores)
            all_labels = np.concatenate([high_quality_test_labels, np.zeros(n_synthetic, dtype=int)])

            top_100_idxs = np.argsort(all_scores)[-100:]
            precision_at_100.append(precision_score(all_labels[top_100_idxs], all_predictions[top_100_idxs]))

            if len(high_quality_scores) >= 100:
                top_100_high_quality_idxs = np.argsort(high_quality_scores)[-100:]
                high_quality_precision_at_100.append(precision_score(high_quality_test_labels[top_100_high_quality_idxs], high_quality_predictions[top_100_high_quality_idxs]))

        print(f"Average Accuracy: {np.mean(accuracies):.4f} (+/- {np.std(accuracies):.4f})")
        print(f"Average F1 Score: {np.mean(f1_scores):.4f} (+/- {np.std(f1_scores):.4f})")
        print(f"Average MCC Score: {np.mean(mcc_scores):.4f} (+/- {np.std(mcc_scores):.4f})")

        print("Average Confusion Matrix:")
        average_confusion_matrix = np.mean(confusion_matrices, axis=0)
        print(average_confusion_matrix)
        # Rows are normalised per true class, so [1, 1] is the true-positive
        # rate and [0, 1] the false-positive rate; epsilon avoids division by zero.
        print("Positive Likelihood Ratio:")
        print(average_confusion_matrix[1, 1] / (average_confusion_matrix[0, 1] + 1e-10))

        if high_quality_accuracies:
            print(f"Average High Quality Accuracy: {np.mean(high_quality_accuracies):.4f} (+/- {np.std(high_quality_accuracies):.4f})")
        if high_quality_f1_scores:
            print(f"Average High Quality F1 Score: {np.mean(high_quality_f1_scores):.4f} (+/- {np.std(high_quality_f1_scores):.4f})")
        if high_quality_mcc_scores:
            print(f"Average High Quality MCC Score: {np.mean(high_quality_mcc_scores):.4f} (+/- {np.std(high_quality_mcc_scores):.4f})")

        if high_quality_confusion_matrices:
            print("Average High Quality Confusion Matrix:")
            average_high_quality_confusion_matrix = np.mean(high_quality_confusion_matrices, axis=0)
            print(average_high_quality_confusion_matrix)
            print("High Quality Positive Likelihood Ratio:")
            print(average_high_quality_confusion_matrix[1, 1] / (average_high_quality_confusion_matrix[0, 1] + 1e-10))

        print(f"Probability of random sequences being AMPs: {np.mean(hit_rates['random']):.4f}")
        print(f"Probability of shuffled sequences being AMPs: {np.mean(hit_rates['shuffled']):.4f}")
        print(f"Probability of mutated sequences (mutations={mutations}) being AMPs: {np.mean(hit_rates['mutated']):.4f}")
        print(f"Probability of added-deleted sequences (added-deleted={additions}) being AMPs: {np.mean(hit_rates['added_deleted']):.4f}")

        print(f"Precision at Top 100: {np.mean(precision_at_100):.4f} (+/- {np.std(precision_at_100):.4f})")
        if high_quality_precision_at_100:
            print(f"High-Quality Precision at Top 100: {np.mean(high_quality_precision_at_100):.4f} (+/- {np.std(high_quality_precision_at_100):.4f})")

    def forward(self, sequences):
        """Return 0/1 predictions for raw ``sequences``."""
        return self._apply_threshold(self.predict_proba(sequences))

    def predict_from_features(self, input_features, proba=False):
        """Predict from precomputed features.

        Returns positive-class probabilities when ``proba`` is true, otherwise
        0/1 predictions obtained with ``decision_threshold``.
        """
        probas = self.model.predict_proba(input_features)[:, 1]
        if proba:
            return probas
        return self._apply_threshold(probas)

    def predict_proba(self, sequences):
        """Return positive-class probabilities for raw ``sequences``."""
        features = self.get_input_features(sequences)
        return self.model.predict_proba(features)[:, 1]

    def save(self, path):
        """Save the underlying XGBoost model to ``path``."""
        self.model.save_model(path)

class AMPClassifier(PeptideClassifier):
    """Peptide classifier trained with optional class-balanced sample weights."""

    def get_input_features(self, sequences):
        """Build the feature table for ``sequences``.

        Concatenates, column-wise, the physicochemical properties, amino-acid
        frequencies and positional encodings computed by the project helpers.
        """
        positional_encodings = pd.DataFrame(calculate_positional_encodings(sequences))
        properties = pd.DataFrame(calculate_physchem_prop(sequences, all_scales=True))
        frequencies = pd.DataFrame(calculate_aa_frequency(sequences))
        return pd.concat([properties, frequencies, positional_encodings], axis=1)

    def train_classifier(self, input_features, labels, weight_balancing="balanced_with_adjustment_for_high_quality", mask_high_quality_idxs=None, return_feature_importances=False, verbose=True):
        """Fit the model, holding out 3% of the data as the evaluation set.

        Args:
            input_features: Per-sample features.
            labels: Per-sample class labels.
            weight_balancing: Any value starting with ``"balanced"`` applies
                balanced class weights as sample weights. The value
                ``"balanced_with_adjustment_for_high_quality"`` additionally
                gives every high-quality sample the largest class weight. Any
                other value trains without sample weights.
            mask_high_quality_idxs: Per-sample boolean mask flagging
                high-quality samples; ``None`` or empty means none are.
            return_feature_importances: If true, return the fitted model's
                feature importances.
            verbose: Forwarded to the model's ``fit``.

        Returns:
            The feature importances when requested, otherwise ``None``.
        """
        # None/empty means "no high-quality samples". Coercing to a boolean
        # array also guarantees mask (not integer) indexing further down.
        if mask_high_quality_idxs is None or len(mask_high_quality_idxs) == 0:
            high_quality_mask = np.zeros(len(labels), dtype=bool)
        else:
            high_quality_mask = np.asarray(mask_high_quality_idxs, dtype=bool)

        train_input, val_input, train_labels, val_labels, train_high_quality_mask, _ = train_test_split(
            input_features, labels, high_quality_mask, test_size=0.03, random_state=42, stratify=labels
        )

        fit_kwargs = {"eval_set": [(val_input, val_labels)], "verbose": verbose}

        if isinstance(weight_balancing, str) and weight_balancing.startswith("balanced"):
            # return_inverse gives each sample's position in ``classes``, so the
            # lookup is correct whatever the label values are.
            classes, class_positions = np.unique(train_labels, return_inverse=True)
            class_weights = compute_class_weight(class_weight="balanced", classes=classes, y=train_labels)
            weights = class_weights[class_positions]

            if weight_balancing == "balanced_with_adjustment_for_high_quality":
                weights[train_high_quality_mask] = class_weights.max()

            fit_kwargs["sample_weight"] = weights

        self.model.fit(train_input, train_labels, **fit_kwargs)

        if return_feature_importances:
            return self.model.feature_importances_

class HemolyticClassifier(PeptideClassifier):
    """Peptide classifier using a 0.05 decision threshold and combined class/quality sample weights."""

    def __init__(self, model_path):
        """Initialise as the parent class does, then lower the decision threshold to 0.05."""
        super().__init__(model_path)
        self.decision_threshold = 0.05

    def get_input_features(self, sequences):
        """Build the feature table for ``sequences``.

        Concatenates, column-wise, the physicochemical properties, amino-acid
        frequencies and positional encodings computed by the project helpers.
        """
        positional_encodings = pd.DataFrame(calculate_positional_encodings(sequences))
        properties = pd.DataFrame(calculate_physchem_prop(sequences, all_scales=True))
        frequencies = pd.DataFrame(calculate_aa_frequency(sequences))
        return pd.concat([properties, frequencies, positional_encodings], axis=1)

    def train_classifier(self, input_features, labels, weight_balancing="balanced_with_adjustment_for_high_quality", mask_high_quality_idxs=None, return_feature_importances=False, verbose=True):
        """Fit the model, holding out 3% of the data as the evaluation set.

        Each sample's weight is the sum of its balanced class weight and the
        balanced weight of its quality group (high quality vs. not).

        Args:
            input_features: Per-sample features.
            labels: Per-sample class labels.
            weight_balancing: Accepted for signature compatibility with the
                parent class. NOTE(review): this implementation ignores it and
                always applies the weighting described above - confirm this is
                intended.
            mask_high_quality_idxs: Per-sample boolean mask flagging
                high-quality samples; ``None`` or empty means none are.
            return_feature_importances: If true, return the fitted model's
                feature importances.
            verbose: Forwarded to the model's ``fit``.

        Returns:
            The feature importances when requested, otherwise ``None``.
        """
        if mask_high_quality_idxs is None or len(mask_high_quality_idxs) == 0:
            high_quality_mask = np.zeros(len(labels), dtype=bool)
        else:
            high_quality_mask = np.asarray(mask_high_quality_idxs, dtype=bool)

        train_input, val_input, train_labels, val_labels, train_high_quality_mask, _ = train_test_split(
            input_features, labels, high_quality_mask, test_size=0.03, random_state=42, stratify=labels
        )

        # return_inverse yields each sample's position within the sorted unique
        # groups. This stays valid when only one quality group is present
        # (e.g. every sample is high quality), where indexing the weights by
        # int(flag) would run past the end of a length-1 array.
        quality_groups, quality_positions = np.unique(train_high_quality_mask, return_inverse=True)
        high_quality_weights = compute_class_weight(class_weight="balanced", classes=quality_groups, y=train_high_quality_mask)

        classes, class_positions = np.unique(train_labels, return_inverse=True)
        class_weights = compute_class_weight(class_weight="balanced", classes=classes, y=train_labels)

        weights = class_weights[class_positions] + high_quality_weights[quality_positions]

        self.model.fit(
            train_input, train_labels,
            sample_weight=weights,
            eval_set=[(val_input, val_labels)],
            verbose=verbose
        )

        if return_feature_importances:
            return self.model.feature_importances_

