import torch
import numpy as np
import random
import itertools
from sklearn.metrics import f1_score   
from torch.distributions.gumbel import Gumbel
import sympy
import wandb
import json

from ucimlrepo import fetch_ucirepo


def set_seed(seed):
    """Seed every RNG used by the experiments and request deterministic kernels.

    Seeds torch (CPU and CUDA), NumPy's global RNG and Python's ``random``
    module with ``seed``, then asks torch / cuDNN for deterministic behaviour
    (non-deterministic ops only warn, they do not raise).

    Parameters:
        seed (int): the seed value
    Returns:
        int: the same ``seed``, for convenient logging
    """
    seeders = (torch.manual_seed, np.random.seed, random.seed, torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(seed)

    torch.use_deterministic_algorithms(True, warn_only=True)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    return seed


class PrecomputedNoise:
    """Precompute a buffer of random noise and serve reshaped views of it.

    One 1-D ``noise_vector`` of ``buffer_size`` samples is drawn up front. Each
    ``get_noise_tensor(shape)`` call reads ``prod(shape)`` entries from it,
    starting at ``start_idx`` and advancing by ``step_size``; indices wrap modulo
    ``buffer_size``. The (start_idx, step_size) pairs are taken in turn from the
    Cartesian product of a shuffled list of start indices and a shuffled list of
    step sizes. Once every pair has been used, the buffer is resampled and a new
    product is started.

    Parameters:
        buffer_size (int): number of precomputed samples (must be >= 1)
        noise_distribution (str): 'Gumbel', 'Uniform' or 'Bernoulli'
        device: torch device the buffer is placed on
        **kwargs: distribution options
            weight_random (float): scale applied to Gumbel(0, 1) samples (default 1.)
            uniform_noise_bounds (tuple): (low, high) for 'Uniform' (default (0.75, 1.25))
            p (float): success probability for 'Bernoulli' (required for it)
    Raises:
        ValueError: if ``buffer_size`` < 1 or ``noise_distribution`` is unsupported
    """

    def __init__(self, buffer_size, noise_distribution='Gumbel', device=None, **kwargs):
        if buffer_size < 1:
            raise ValueError(f"buffer_size must be at least 1, got {buffer_size}")
        self.buffer_size = buffer_size
        self.noise_distribution = noise_distribution
        self.device = device
        self.kwargs = kwargs
        self.initialization()

    def initialization(self):
        """(Re)sample the noise buffer and restart the (start, step) pair iterator."""
        self.sample_and_store()
        # Shuffled start indices 0..buffer_size-1.
        list1 = random.sample(range(self.buffer_size), self.buffer_size)
        # Shuffled step sizes 1..buffer_size-1. A one-element buffer has no such
        # step, which would leave the product empty and make every
        # get_noise_tensor call fail; fall back to the single step 1 there.
        step_range = range(1, max(self.buffer_size, 2))
        list2 = random.sample(step_range, len(step_range))
        self.indices_generator = itertools.product(list1, list2)

    def sample_and_store(self):
        """Draw ``buffer_size`` fresh samples into ``self.noise_vector``."""
        if self.noise_distribution == 'Gumbel':
            scaling_factor = self.kwargs.get('weight_random', 1.)
            self.noise_vector = Gumbel(0,1).sample((self.buffer_size,)).to(self.device) * scaling_factor
        elif self.noise_distribution == 'Uniform':
            bounds = self.kwargs.get('uniform_noise_bounds', (0.75, 1.25))
            self.noise_vector = bounds[0] + (bounds[1] - bounds[0]) * torch.rand((self.buffer_size,), device=self.device)
        elif self.noise_distribution == 'Bernoulli':
            p = self.kwargs['p']
            self.noise_vector = torch.bernoulli(p * torch.ones((self.buffer_size,), device=self.device)).to(torch.bool)
        else:
            # Fail here rather than with an AttributeError on the first lookup.
            raise ValueError(
                f"Unsupported noise_distribution {self.noise_distribution!r}; "
                "expected 'Gumbel', 'Uniform' or 'Bernoulli'"
            )

    def get_noise_slice(self, start_idx, step_size, length):
        """Return ``length`` buffer entries at start_idx, start_idx + step_size, ... (mod buffer_size)."""
        # NOTE(review): when gcd(step_size, buffer_size) > 1 the slice visits only
        # buffer_size / gcd distinct entries before repeating.
        indices = (torch.arange(start_idx, start_idx + step_size * length, step_size, device=self.device) % self.buffer_size)
        return self.noise_vector[indices]

    def get_noise_tensor(self, shape):
        """Return a noise tensor of the given ``shape`` taken from the precomputed buffer."""
        # int() matters for shape == (): np.prod(()) is the float 1.0, which would
        # make torch.arange produce float indices that cannot index the buffer.
        length = int(np.prod(shape))
        try:
            start_idx, step_size = next(self.indices_generator)
        except StopIteration:
            # Every (start, step) pair has been used: resample and start over.
            self.initialization()
            start_idx, step_size = next(self.indices_generator)
        return self.get_noise_slice(start_idx, step_size, length=length).reshape(shape)


def subtract_mean_of_two_along_dim(tensor, dim):
    """Shift ``tensor`` by the average of its two largest entries along ``dim``.

    For every slice along ``dim``, the two largest values are averaged and that
    average is subtracted from every element of the slice. The result has the
    same shape as ``tensor``. ``tensor.size(dim)`` must be at least 2.
    """
    largest_two = torch.topk(tensor, k=2, dim=dim).values
    return tensor - largest_two.mean(dim=dim, keepdim=True)



def non_redundant_clauses(sympy_expr_nf, dnf_or_cnf=None, starting='and'):
    """
    Simplify a normal-form sympy expression by removing redundant (absorbed) clauses.

    A compound clause is redundant when the literals of some other clause form a
    subset of its literals (absorption: ``A | (A & B) == A`` in DNF and
    ``A & (A | B) == A`` in CNF). A bare literal (a Symbol or a Not) is never
    removed.

    Parameters:
        sympy_expr_nf (sympy expression): expression in DNF or CNF
        dnf_or_cnf (str): 'dnf' or 'cnf'. If None, inferred from starting
        starting (str): 'and' for DNF, anything else for CNF.
    Returns:
        sympy expression: the simplified expression; returned unchanged when it
        consists of a single clause
    """
    if dnf_or_cnf is None:
        dnf_or_cnf = 'dnf' if starting == 'and' else 'cnf'

    # A DNF whose top level is not an Or (or a CNF whose top level is not an And)
    # is a single clause: nothing can be absorbed.
    if dnf_or_cnf == 'dnf' and not isinstance(sympy_expr_nf, sympy.logic.boolalg.Or):
        return sympy_expr_nf
    if dnf_or_cnf == 'cnf' and not isinstance(sympy_expr_nf, sympy.logic.boolalg.And):
        return sympy_expr_nf

    outer_type = type(sympy_expr_nf)
    clause_type = sympy.logic.boolalg.Or if dnf_or_cnf == 'cnf' else sympy.logic.boolalg.And
    literal_types = (sympy.core.symbol.Symbol, sympy.logic.boolalg.Not)

    def is_literal(clause):
        return isinstance(clause, literal_types)

    # Bare literals are kept as-is; compound clauses are represented by their args tuple.
    clauses = [c if is_literal(c) else c.args for c in sympy_expr_nf.args]

    def is_absorbed(clause):
        # A bare literal cannot be absorbed by another clause.
        if is_literal(clause):
            return False
        members = set(clause)
        for other in clauses:
            if other == clause:
                continue
            other_members = {other} if is_literal(other) else set(other)
            if other_members <= members:
                return True
        return False

    kept = [c for c in clauses if not is_absorbed(c)]
    rebuilt = [c if is_literal(c) else clause_type(*c) for c in kept]
    return outer_type(*rebuilt)


def predict(model, loader, device=None, binarized_forward=False, compiled=False):
    """
    Make predictions on a given loader.
    Parameters:
        model (nn.Module): the model to make predictions with
        loader (DataLoader): iterable of (features, labels) batches
        device (torch.device): the device features and labels are moved to
        binarized_forward (bool): use CRS baseline's binarized_forward method.
        compiled (bool): use DiffLogic baseline's discretized model; it is fed
            flattened boolean NumPy features on the CPU
    Returns:
        all_outputs (np.ndarray): the raw output of the model, concatenated over batches
        all_predicted (np.ndarray): the predicted labels. These are rounded outputs
            when the labels are a vector or a single column, otherwise the argmax
            over dim 1. They have the same layout as all_ground_truth.
        all_ground_truth (np.ndarray): the ground truth labels for each input in the loader
    """
    all_outputs = []
    all_predicted = []
    all_ground_truth = []
    # Pure inference: no autograd graph is needed for any batch.
    with torch.no_grad():
        for features, labels in loader:
            features = features.to(device)
            labels = labels.to(device)
            if binarized_forward:
                outputs = model.binarized_forward(features)
            elif compiled:
                features = torch.nn.Flatten()(features).bool().cpu().detach().numpy()
                outputs = model(features)
            else:
                outputs = model(features)
            # Models may return (output, extras...); only the first element is scored.
            if isinstance(outputs, tuple):
                outputs = outputs[0]
            if labels.ndim == 2 and labels.shape[1] == 1:
                # Single-column labels: round outputs (presumably probabilities) to 0/1.
                predicted = torch.round(outputs.data)
                ground_truth = labels.reshape(-1, 1)
            elif labels.ndim == 1:
                predicted = torch.round(outputs.data)
                # Match the (N,) label layout. An (N, 1) prediction would otherwise
                # broadcast against (N,) labels into an (N, N) comparison downstream.
                if predicted.ndim == 2 and predicted.shape[1] == 1:
                    predicted = predicted.reshape(-1)
                ground_truth = labels
            else:
                # One-hot / multi-column labels: compare class indices.
                predicted = outputs.argmax(dim=1)
                ground_truth = labels.argmax(dim=1)
            all_predicted.append(predicted.cpu().detach().numpy())
            all_ground_truth.append(ground_truth.cpu().detach().numpy())
            all_outputs.append(outputs.cpu().detach().numpy())
    all_predicted = np.concatenate(all_predicted)
    all_ground_truth = np.concatenate(all_ground_truth)
    all_outputs = np.concatenate(all_outputs)
    return all_outputs, all_predicted, all_ground_truth



def eval(predicted, ground_truth):
    '''Compute accuracy (in percent) and macro-averaged F1 score.

    Parameters:
        predicted (array-like): predicted labels
        ground_truth (array-like): true labels, with as many elements as ``predicted``
    Returns:
        acc: percentage of positions where predicted equals ground_truth (0-100)
        f1: ``sklearn.metrics.f1_score(ground_truth, predicted, average='macro')``
    Raises:
        ValueError: if the two inputs hold different numbers of labels
    '''
    # asarray: plain lists would otherwise be compared as whole objects (one bool).
    predicted = np.asarray(predicted)
    ground_truth = np.asarray(ground_truth)
    if predicted.shape != ground_truth.shape:
        if predicted.size != ground_truth.size:
            raise ValueError(
                f"predicted has {predicted.size} labels but ground_truth has {ground_truth.size}"
            )
        # Same labels, different layout (e.g. (N, 1) vs (N,)): compare element-wise
        # instead of letting NumPy broadcast into an (N, N) matrix.
        predicted = predicted.ravel()
        ground_truth = ground_truth.ravel()
    acc = np.mean(predicted == ground_truth) * 100
    f1 = f1_score(ground_truth, predicted, average='macro')
    return acc, f1



def set_wandb_runs(trial, args, best, repeat, id_list):
    """Start a wandb run for one repeat of an optuna trial and record its id.

    The run config holds the trial's params plus its number and the repeat
    index. Best-trial runs are grouped under job type 'Best' and named
    ``best_<repeat>``. All other runs use job type 'Trials' and are named
    ``trial_<number>_<repeat>``.

    Parameters:
        trial (optuna.trial.Trial): the trial to set up the run for
        args (argparse.Namespace): needs ``experimentname`` (project) and ``group``
        best (bool): whether this is the best trial
        repeat (int): the repeat number of the trial
        id_list (list): receives the new run's id (mutated in place)
    Returns:
        the object returned by ``wandb.init``
    """
    if best:
        job_type, name_trial = 'Best', f'best_{repeat}'
    else:
        job_type, name_trial = 'Trials', f'trial_{trial.number}_{repeat}'

    config = {**trial.params, "trial number": trial.number, "repeat": repeat}

    run = wandb.init(
        project=args.experimentname,
        config=config,
        group=args.group,
        name=name_trial,
        job_type=job_type,
        reinit=True,
    )
    id_list.append(run.id)
    return run


def write_results_to_file(results_dict, results_file):
    """Append a results dictionary to a JSON file that stores a list of results.

    A missing or unparsable file is treated as an empty list. A file holding a
    single non-list JSON value is wrapped into a one-element list first.

    Parameters:
        results_dict (dict): JSON-serialisable results to append
        results_file (str or path-like): path of the JSON file
    Raises:
        TypeError: if ``results_dict`` is not JSON-serialisable. The file on disk
            is left untouched in that case.
    """
    import os

    try:
        with open(results_file, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        data = []
    if not isinstance(data, list):  # Ensure it's a list
        data = [data]
    data.append(results_dict)

    # Serialise before touching the file. json.dump into a file opened with "w"
    # truncates it first, so a non-serialisable value would wipe every stored
    # result (and the next call would then silently start from an empty list).
    payload = json.dumps(data, indent=4)
    # Write to a sibling temp file and atomically swap it in, so a crash
    # mid-write cannot leave a truncated results file either.
    tmp_file = f"{results_file}.tmp"
    with open(tmp_file, "w", encoding="utf-8") as f:
        f.write(payload)
    os.replace(tmp_file, results_file)



def get_hyperparameters(results_file, experiment_name, group):
    """Retrieve stored hyperparameters for an experiment/group pair from a JSON file.

    The file is expected to hold a list of result dictionaries; a single
    non-list value is treated as a one-element list. The first entry whose
    'experiment' and 'group' match wins. Entries that are not dictionaries, or
    that lack either key, are skipped.

    Parameters:
        results_file (str or path-like): path of the JSON results file
        experiment_name: value to match against an entry's 'experiment'
        group: value to match against an entry's 'group'
    Returns:
        the matching entry's 'hyperparameters' value
    Raises:
        ValueError: if no entry matches (including a missing or unparsable file)
        KeyError: if the matching entry has no 'hyperparameters' key
    """
    try:
        with open(results_file, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        data = []
    if not isinstance(data, list):  # Ensure it's a list
        data = [data]
    for results_dict in data:
        # The file is a general results log: skip records that carry no
        # experiment/group identification instead of crashing on them.
        if not isinstance(results_dict, dict):
            continue
        if 'experiment' not in results_dict or 'group' not in results_dict:
            continue
        if results_dict['experiment'] == experiment_name and results_dict['group'] == group:
            return results_dict['hyperparameters']
    raise ValueError(f"No hyperparameters found for experiment {experiment_name} and group {group}.")



def get_dataset(dataset_name, print_info=False, print_more_info=False):
    """Fetch a dataset from the UCI Machine Learning Repository.

    Parameters:
        dataset_name (str): short name of the dataset (a key of the table below)
        print_info (bool): print the name, the variables table and the shapes
        print_more_info (bool): print the repository metadata
    Returns:
        X (np.ndarray): the features, converted with ``to_numpy()``
        y: the targets exactly as returned by ``fetch_ucirepo`` (not converted)
    Raises:
        KeyError: if ``dataset_name`` is not in the table; the message lists the
            accepted names
    """

    # Short name -> UCI repository id ('magic' and 'magic04' are aliases).
    datasets = {
    'adult': 2,
    'bank_marketing': 222,
    'banknote': 267,
    'chess': 23, #KRK
    'connect-4': 26,
    'letRecog': 59,
    'magic': 159,
    'tic-tac-toe': 101,
    'wine': 109,
    'mushroom': 73,
    'magic04': 159,
    'nursery': 76
    }

    if print_info:
        print(dataset_name)
    if dataset_name not in datasets:
        raise KeyError(f"Unknown dataset {dataset_name!r}; expected one of {sorted(datasets)}")
    UCI_dataset = fetch_ucirepo(id = datasets[dataset_name])
    X = UCI_dataset.data.features
    y = UCI_dataset.data.targets
    if print_more_info:
        print(UCI_dataset.metadata)
    if print_info:
        print(UCI_dataset.variables)
        print(X.shape, y.shape)
    return X.to_numpy(), y