import copy
import itertools
from abc import ABC
from functools import cached_property

import numpy as np
import pandas as pd
import tqdm
import random

from rules import Operator, Rule, Condition
from datasets import Dataset
from dataclasses import dataclass, fields

@dataclass
class BaseAlgorithmParams:
    """Configuration values shared by the search algorithms in this module."""

    # Required integer parameter; it is not read by the code in this module,
    # so its meaning is defined by whichever subclass consumes it.
    N: int
    # Number of iterations of the search loop (1000 unless overridden).
    num_iter: int = 1000

class BaseAlgorithm:
    """Base class for a stochastic local search over rules.

    Subclasses define the search space by overriding ``get_pool``,
    ``get_start`` and ``get_neighbor``.  This class provides the scoring
    helpers, the acceptance rule and the main loop in ``run``.
    """

    # Width of the band below the current score inside which a worse
    # neighbor can still be accepted.
    SCORE_TOLERANCE = 0.005
    # Multiplier in the exponent of the acceptance probability for worse
    # neighbors; larger values make such moves less likely.
    DECAY_RATE = 30

    def __init__(
        self,
        dataset: Dataset,
        params: BaseAlgorithmParams,
        show_progress_bar: bool = True,
        skeleton: bool = False,
    ):
        """Store the configuration and build the pool and starting rule.

        Args:
            dataset: Source of the train/test features and quantile targets.
            params: Search parameters; ``num_iter`` is read by ``run``.
            show_progress_bar: Whether ``run`` wraps its loop in a tqdm bar.
            skeleton: Accepted for interface compatibility; not used by this
                base class.

        Raises:
            NotImplementedError: If ``get_pool`` or ``get_start`` is not
                overridden by the subclass.
        """
        self.dataset = dataset
        self.params = params
        self.show_progress_bar = show_progress_bar
        self.run_information = []
        self.target_support = None

        # Both hooks run during construction, so a subclass must have
        # everything they need set up before calling ``super().__init__``.
        self.pool = self.get_pool()
        self.starting_rule = self.get_start()

    def get_pool(self) -> list[Rule]:
        """Return the pool of rules; must be overridden by subclasses."""
        raise NotImplementedError

    def get_start(self):
        """Return the rule the search starts from; must be overridden."""
        raise NotImplementedError

    def get_neighbor(self, rule: Rule):
        """Return a candidate rule derived from ``rule``; must be overridden."""
        raise NotImplementedError

    def evaluate_rule(
        self,
        rule: Rule,
        X: pd.DataFrame,
        y: pd.Series,
    ) -> float:
        """Score ``rule`` on ``(X, y)``.

        The output of ``rule.get_mask(X)`` is converted to rank quantiles and
        the score is the mean of ``y`` over the rows whose quantile is at
        least ``1 - self.target_support``.

        ``self.target_support`` must be set (``run`` does this) before this is
        called.  If no row is selected, ``np.mean`` of the empty selection is
        returned (NaN).
        """
        y_preds = self.convert_to_quantile(rule.get_mask(X))
        selected = y_preds >= 1 - self.target_support
        return np.mean(y[selected])

    def evaluate_rule_train(self, rule: Rule) -> float:
        """Score ``rule`` on the dataset's training split."""
        return self.evaluate_rule(
            rule,
            self.dataset.get_X_train(),
            self.dataset.get_y_train_quantile(),
        )

    def evaluate_rule_test(self, rule: Rule) -> float:
        """Score ``rule`` on the dataset's test split."""
        return self.evaluate_rule(
            rule,
            self.dataset.get_X_test(),
            self.dataset.get_y_test_quantile(),
        )

    def get_acceptance_probability(
        self,
        current_score: float,
        proposed_score: float,
        t: int,
        T: int
    ) -> float:
        """Return the probability of moving to the proposed rule.

        Args:
            current_score: Score of the rule currently held.
            proposed_score: Score of the candidate rule.
            t: Current iteration index.
            T: Total number of iterations.

        Returns:
            ``1`` for a strictly better score, ``0`` for a score that is
            ``SCORE_TOLERANCE`` or more below the current one, and otherwise
            ``exp(-DECAY_RATE * diff * t / T)`` where ``diff`` is the score
            deficit expressed as a fraction of ``SCORE_TOLERANCE``.
        """
        if proposed_score > current_score:
            return 1
        if proposed_score > current_score - self.SCORE_TOLERANCE:
            diff = (current_score - proposed_score) / self.SCORE_TOLERANCE
            # The t / T factor makes worse moves less likely as the run
            # progresses.
            return np.exp(-self.DECAY_RATE * diff * t / T)
        return 0

    def update(self, t: int, new_rule: Rule, new_score):
        """Append a log entry for ``new_rule`` to ``self.run_information``.

        The test score is computed here, so every logged step triggers one
        evaluation on the test split.
        """
        self.run_information.append({
            'Iteration': t,
            'Operation': str(new_rule.name),
            'Train Score': new_score,
            'Test Score': self.evaluate_rule_test(new_rule)
        })

    def run(self, target_support: float) -> Rule:
        """Run the search and return the rule held after the last iteration.

        Resets ``run_information``, stores ``target_support`` for the scoring
        helpers, and records every accepted rule in ``candidate_rules``.

        Args:
            target_support: Fraction used by ``evaluate_rule`` to choose the
                top-ranked rows that are scored.

        Returns:
            The current rule when the loop ends (the last accepted rule, not
            necessarily the best one seen).
        """
        self.run_information = []
        self.target_support = target_support

        self.candidate_rules = [self.starting_rule]
        num_iter = self.params.num_iter
        iterator = tqdm.trange(num_iter) if self.show_progress_bar else range(num_iter)

        current = self.starting_rule
        current_score = self.evaluate_rule_train(current)

        # NOTE: the starting rule is logged as iteration 0, and a neighbor
        # accepted on the first pass of the loop is logged as iteration 0 too.
        self.update(0, current, current_score)

        for t in iterator:
            neighbor = self.get_neighbor(current)

            neighbor_score = self.evaluate_rule_train(neighbor)
            accept_prob = self.get_acceptance_probability(current_score, neighbor_score, t, num_iter)

            # One uniform draw per iteration, even when the outcome is
            # already determined, so the random stream stays unchanged.
            u = np.random.uniform(0, 1)
            # Equal-score neighbors are never accepted.
            if u <= accept_prob and neighbor_score != current_score:
                self.update(t, neighbor, neighbor_score)
                self.candidate_rules.append(neighbor)
                current = neighbor
                current_score = neighbor_score

        return current

    def convert_to_quantile(self, predictions: np.ndarray) -> np.ndarray:
        """Map each prediction to its rank divided by the number of inputs.

        The smallest value maps to ``0`` and the largest to ``(n - 1) / n``.
        Tied values receive distinct ranks, in whatever order ``np.argsort``
        with its default sort places them.

        Args:
            predictions: One-dimensional array-like of sortable values.

        Returns:
            A float array of the same length with values in ``[0, 1)``.
        """
        order = np.argsort(np.asarray(predictions))
        n = order.shape[0]
        if n == 0:
            return np.empty(0, dtype=float)
        # Invert the sorting permutation: the element at position order[i]
        # has rank i.
        ranks = np.empty(n, dtype=np.intp)
        ranks[order] = np.arange(n)
        return ranks / n
