import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tqdm
import itertools

from dataclasses import dataclass, fields, replace
from fim import fpgrowth
from functools import partial

from datasets import Dataset
from rules import TwoWeightKnapsackRule, IntegerKnapsackRule, ORRule, Rule
from sa import BaseAlgorithm

class BenchmarkRuleMiner:
    """Mine candidate rules on the training split to serve as a benchmark.

    Workflow implemented by this class:

    1. ``get_pareto_rules`` mines frequent itemsets with FP-growth, turns each
       into a ``Rule`` and keeps, per coverage bin, the rules whose metric is
       close to the best one in that bin.
    2. ``get_or_rules`` randomly ORs together a few of those rules and again
       keeps the near-best combinations per coverage bin.
    3. ``score_benchmark`` / ``get_top_rule`` rank the final OR-rules with
       caller-supplied scoring functions.
    """

    def __init__(
        self,
        dataset: 'Dataset',
    ):
        """Cache the train/test splits of ``dataset``.

        For every rule candidate whose first condition has ``value == 0`` a
        complementary ``<field>_rev`` column (``1 - field``) is added to both
        ``X_train`` and ``X_test``, so the negated feature becomes available
        as an item when transactions are built in ``get_pareto_rules``.

        NOTE(review): the new columns are written into whatever objects
        ``get_X_train``/``get_X_test`` return; if those are the dataset's own
        frames rather than copies, the dataset is mutated as well — confirm.
        """
        self.dataset = dataset

        self.X_train = self.dataset.get_X_train()
        self.y_train = self.dataset.get_y_train()

        self.X_test = self.dataset.get_X_test()
        self.y_test = self.dataset.get_y_test()

        for r in self.dataset.rule_candidates:
            if r[0].value == 0:
                f = r[0].field_name
                self.X_train[f + '_rev'] = 1 - self.X_train[f]
                self.X_test[f + '_rev'] = 1 - self.X_test[f]

    @staticmethod
    def _best_per_coverage_bin(rules, n_bins, bin_width, keep):
        """Return, for each coverage bin, the rows of ``rules`` selected by ``keep``.

        Bins are ``(0, w], (w, 2w], ..., ((n_bins - 1) * w, n_bins * w]`` with
        ``w = bin_width``. They are contiguous, so a coverage lying exactly on
        a bin edge falls into exactly one bin instead of being dropped. Rows
        with coverage ``<= 0`` or above the last edge are not returned.

        Args:
            rules: DataFrame with at least ``coverage`` and ``metric`` columns.
            n_bins: number of bins.
            bin_width: width of each bin.
            keep: callable receiving a bin's ``metric`` Series and returning a
                boolean Series that selects the rows to keep from that bin.

        Returns:
            The selected rows (original index preserved), or an empty frame
            with ``rules``' columns when no bin contains any row.
        """
        # Round the edges to strip floating-point noise from i * bin_width
        # (e.g. 3 * 0.1 == 0.30000000000000004).
        edges = np.round(np.arange(n_bins + 1) * bin_width, 12)

        selected = []
        for lo, hi in zip(edges[:-1], edges[1:]):
            in_bin = rules[(rules.coverage > lo) & (rules.coverage <= hi)]
            if len(in_bin) > 0:
                selected.append(in_bin[keep(in_bin.metric)])

        if not selected:
            return rules.iloc[0:0]
        return pd.concat(selected, axis=0)

    def get_pareto_rules(self, min_support=0.01, zmin=2, zmax=2,
                         max_coverage=0.4, bin_width=0.025, metric_tolerance=0.05):
        """Mine frequent itemsets and keep the near-best rules per coverage bin.

        Args:
            min_support: passed to ``fpgrowth`` as ``supp``, and also the
                minimum training coverage (fraction of rows) a rule must have.
            zmin, zmax: passed to ``fpgrowth`` (itemset size bounds).
            max_coverage: rules with a larger training coverage are discarded;
                the coverage bins extend up to this value.
            bin_width: width of the coverage bins.
            metric_tolerance: within a bin, every rule whose metric is at least
                ``bin max - metric_tolerance`` is kept.

        Returns:
            DataFrame with columns ``rule``, ``coverage``, ``metric`` and
            ``mask_train`` (also stored as ``self.pareto_rules``). It is empty,
            with those columns, when no rule qualifies.
        """
        self.zmin = zmin
        self.zmax = zmax

        # One transaction per training row: the names of the columns equal to 1.
        columns = self.X_train.columns
        item_matrix = [list(columns[row == 1]) for row in self.X_train.to_numpy()]

        # NOTE(review): in pyfim, ``supp`` is a percentage (10 means 10%), while
        # ``min_support`` is used below as a fraction of rows, so mining is
        # looser than the coverage filter that follows. Confirm whether
        # ``min_support * 100`` was intended.
        mined_rules = fpgrowth(item_matrix, supp=min_support, zmin=zmin, zmax=zmax)

        # Loop invariants, fetched once instead of once per mined itemset.
        y_quantile = self.dataset.get_y_train_quantile()
        lift_floor = np.mean(self.y_train) * 1.1

        self.mined_rules = []
        for itemset in mined_rules:
            r = Rule.parse_rule(itemset[0])
            msk_train = r.get_mask(self.X_train)
            cov_train = msk_train.mean()

            # Keep rules within the coverage range whose mean target exceeds
            # the overall training mean by more than 10%.
            if (min_support <= cov_train <= max_coverage
                    and self.y_train[msk_train].mean() > lift_floor):
                self.mined_rules.append({
                    'rule': r,
                    'coverage': cov_train,
                    'metric': y_quantile[msk_train].mean(),
                    'mask_train': msk_train,
                })

        # Explicit columns so an empty result still has the expected schema.
        self.mined_rules = pd.DataFrame(
            self.mined_rules, columns=['rule', 'coverage', 'metric', 'mask_train'])

        n_bins = int(np.ceil(round(max_coverage / bin_width, 9)))
        self.pareto_rules = self._best_per_coverage_bin(
            self.mined_rules, n_bins, bin_width,
            keep=lambda metric: metric >= metric.max() - metric_tolerance)

        return self.pareto_rules

    def get_or_rules(self, num: int = 1000, support: int = 3):
        """Build random disjunctions of Pareto rules and keep the near-best ones.

        ``get_pareto_rules`` must have been called first (``self.pareto_rules``
        is read).

        Args:
            num: number of random OR-rules to draw.
            support: upper bound on the number of distinct Pareto rules combined
                in one OR-rule; each draw uses between 1 and ``support`` of them
                (capped by the size of the pool).

        Returns:
            DataFrame with columns ``rule``, ``metric`` and ``coverage`` holding,
            for each coverage bin of width 0.025 up to 0.425, the OR-rules whose
            metric is at least 98% of the best one in that bin. Also stored as
            ``self.or_rules_pareto``. Empty (same columns) if nothing qualifies.
        """
        self.support = support
        self.num = num

        # Only low-coverage rules are combined.
        rules_pool = self.pareto_rules[self.pareto_rules.coverage <= 0.2]
        y_quantile = self.dataset.get_y_train_quantile()

        self.or_rules = []
        if not rules_pool.empty:
            for _ in range(num):
                # Boolean, not int: indexing with a 0/1 integer array would pick
                # rows 0 and 1 repeatedly instead of masking the covered rows.
                msk_train = np.zeros(len(self.X_train), dtype=bool)

                s = np.random.randint(1, support + 1)
                # Without replacement, so an OR-rule never repeats a sub-rule.
                picks = np.random.choice(
                    rules_pool.index, min(s, len(rules_pool)), replace=False)

                or_rule = []
                for c in picks:
                    row = rules_pool.loc[c]
                    msk_train |= np.asarray(row['mask_train'], dtype=bool)
                    or_rule.append(row['rule'])

                self.or_rules.append({
                    'rule': ORRule(or_rule),
                    'metric': y_quantile[msk_train].mean(),
                    'coverage': msk_train.mean(),
                })

        self.or_rules = pd.DataFrame(self.or_rules, columns=['rule', 'metric', 'coverage'])

        self.or_rules_pareto = self._best_per_coverage_bin(
            self.or_rules, 17, 0.025,
            keep=lambda metric: metric >= 0.98 * metric.max())
        return self.or_rules_pareto

    @staticmethod
    def pareto_frontier_indices(x, y):
        """Return indices of the Pareto-optimal points when smaller x and larger y are better.

        A point is kept when every point with a smaller or equal ``x`` has a
        strictly smaller ``y``. Among exact duplicates, only the first
        occurrence is kept. Indices are returned in order of increasing ``x``.
        """
        points = list(zip(x, y))
        # Increasing x; for equal x, largest y first, so the lower-y points that
        # share that x (and are dominated by it) are skipped below.
        order = sorted(range(len(points)), key=lambda i: (points[i][0], -points[i][1]))

        pareto_indices = []
        best_y = float('-inf')
        for i in order:
            if points[i][1] > best_y:
                best_y = points[i][1]
                pareto_indices.append(i)

        return pareto_indices

    def score_benchmark(self, train_function, test_function):
        """Score every rule in ``self.or_rules_pareto`` and rank by train score.

        Args:
            train_function: callable ``rule -> score`` on the training split.
            test_function: callable ``rule -> score`` on the test split.

        Returns:
            DataFrame with columns ``rule``, ``train`` and ``test``, sorted by
            ``train`` in descending order (empty if there are no rules).
        """
        results = pd.DataFrame(
            [
                {'rule': r, 'train': train_function(r), 'test': test_function(r)}
                for r in self.or_rules_pareto.rule
            ],
            columns=['rule', 'train', 'test'],
        )
        return results.sort_values(by='train', ascending=False)

    def get_top_rule(self, train_function, test_function, num=1):
        """Return the ``num`` rules with the highest train score (as a Series)."""
        return self.score_benchmark(train_function, test_function).iloc[0:num].rule
