import torch
import torch.nn as nn
from torch.utils.data import random_split, Dataset, Subset
from typing import List, Tuple, Union, Any
from ..utils.data import MultiLabeledDataset
from ..utils.simple_models import LinearModel
from tqdm import tqdm

class SelectiveHalfspaceLearner(nn.Module):
    """
    Conditional Classification for Any Finite Classes

    Starting from the given observations, one linear selector is optimized per
    (observation, predictor) pair by repeated gradient updates, and the result
    is then reduced over the predictor axis on a held-out subset.
    """
    def __init__(
            self,
            prev_header: str,
            subset_fracs: List[float],
            num_iter: int,
            lr: float,
            device: torch.device = torch.device('cpu')
    ):
        """
        Initialize the conditional learner for finite class classification.

        The constructor only stores the hyper-parameters; ``lr`` is used as-is
        as the step size of every gradient update.

        Parameters:
        prev_header (str):              The header of the previous module; used as prefix of this module's messages.
        subset_fracs (List[float]):     The fractions of the dataset passed to ``data_split``: one entry
                                        ([train]) or two entries ([train, PGD validation]). The rest of
                                        the dataset is kept for predictor-selector selection.
        num_iter (int):                 The number of iterations for optimizer.
        lr (float):                     The learning rate.
        device (torch.device):          The device on which the best-so-far trackers are allocated.
        """
        super().__init__()
        self.header = " ".join([prev_header, "learning conditional predictors", "-"])
        self.subset_fracs = subset_fracs
        self.num_iter = num_iter
        self.lr = lr
        self.device = device

    def data_split(
            self,
            dataset: MultiLabeledDataset,
            subset_fracs: List[float]
    ) -> Tuple[Subset, Union[Subset, None], Subset]:
        """
        Split dataset into subsets for PGD training, PGD evaluation (optional), and predictor evaluation.

        The split is random and unseeded, so it follows the global torch RNG state.

        Parameters:
        dataset (MultiLabeledDataset):  For parallel processing, the dataset with mapped labels. Note that the dataset is
                                        multi-labelled that each label is mapped to multiple errors, each of which is
                                        generated by a corresponding predictor.
                                        labels:   [num train sample, num predictors]
                                        features: [num train sample, num features]
        subset_fracs (List[float]):     The fractions that each subset should occupy in the dataset.
                                        One entry:   [train]
                                        Two entries: [train, val]
                                        Whatever is left of the dataset becomes ``dataset_sel``.

        Return:
        dataset_train:                  Subset used to update selector weights in PGD algorithm.
        dataset_val:                    Subset used to select the best selector over all iterations in PGD algorithm.
                                        This subset is optional (None when only one fraction is given) since we can
                                        simply choose the selector of the last iteration.
        dataset_sel:                    Subset used to select the best predictor-selector pair.

        Raises:
        ValueError:                     If the fractions sum to more than 1, if their number is not 1 or 2,
                                        or if any fraction is negative.
        """
        if sum(subset_fracs) > 1:
            raise ValueError(f"{self.header} sum of fractions of subsets exceed 1.")
        if len(subset_fracs) not in (1, 2):
            raise ValueError(f"{self.header} Invalid number of subset sizes.")
        # a negative fraction would yield a negative length that random_split
        # slices silently into nonsensical subsets
        if any(frac < 0 for frac in subset_fracs):
            raise ValueError(f"{self.header} fractions of subsets must be non-negative.")

        # compute subset lengths using the given fractions; the last subset
        # absorbs the remainder so that the lengths always sum to len(dataset)
        subset_sizes = [int(len(dataset) * frac) for frac in subset_fracs]
        subset_sizes.append(len(dataset) - sum(subset_sizes))

        subsets = random_split(dataset, subset_sizes)

        if len(subsets) == 2:
            dataset_train, dataset_sel = subsets
            dataset_val = None
        else:
            dataset_train, dataset_val, dataset_sel = subsets

        return dataset_train, dataset_val, dataset_sel

    def forward(
            self,
            dataset: MultiLabeledDataset,
            observations: torch.Tensor          # [num observations, num features]
    ) -> Tuple[torch.Tensor, torch.Tensor, LinearModel]:
        """
        Call optimizer for the sparse predictors and select the best pair.

        The dataset is first split according to ``self.subset_fracs``. The
        selectors are initialized with the observations (one copy per predictor),
        optimized by ``PGDOptim`` for all predictors at once, and finally reduced
        along the predictor axis with ``LinearModel.model_selection_by_one`` on
        the held-out selection subset.

        As a side effect, ``self.observations`` is set to the observations
        expanded along the predictor axis (used for contractive projection).

        Parameters:
        dataset (MultiLabeledDataset):  For parallel processing, the dataset with mapped labels. Note that the dataset is
                                        multi-labelled that each label is mapped to multiple errors, each of which is
                                        generated by a corresponding predictor.
                                        labels:   [num train sample, num predictors]
                                        features: [num train sample, num features]
        observations (torch.Tensor):    The observations for initializing the optimizer.
                                        [num observations, num features]

        Returns:
        The result of ``model_selection_by_one(dim=1, ...)``; presumably
        (TODO confirm against LinearModel):
        min_val (torch.Tensor):         The minimum error rate.         [num observations]
        min_ids (torch.Tensor):         The minimum error rate indices. [num observations]
        selectors (LinearModel):        The selector model.
        """
        # split dataset
        dataset_train, dataset_sel_pgd, dataset_sel = self.data_split(
            dataset=dataset,
            subset_fracs=self.subset_fracs
        )

        # initialize selectors using observations
        selectors: LinearModel = LinearModel(
            weights=observations.unsqueeze(1).repeat(   # [num observations, 1, num features]
                1,
                dataset.num_labels(),
                1
            )                                           # [num observations, num predictors, num features]
        )

        # for contractive projection
        self.observations = observations.unsqueeze(-2).expand(
            -1,
            dataset.num_labels(),
            -1
        )

        # run gradient descent algorithm
        selectors: LinearModel = self.PGDOptim(
            lin_model=selectors,                        # [num observations, num predictors, num features]
            dataset_train=dataset_train,
            dataset_sel_pgd=dataset_sel_pgd
        )                                               # [num observations, num predictors, num features]

        # reduce the model to the best classifier-selector pair
        return selectors.model_selection_by_one(
            dim=1,
            dataset=dataset_sel
        )

    def PGDOptim(
            self,
            lin_model: LinearModel,
            dataset_train: Union[Subset, Dataset],
            dataset_sel_pgd: Union[Subset, None] = None
    ) -> LinearModel:
        """
        Perform the projected gradient descent optimization.

        Dispatches to ``pgd_with_model_selection`` when a validation subset is
        given, and to plain ``pgd`` (last iterate) otherwise.

        Parameters:
        lin_model (LinearModel):                The sparse predictors.
        dataset_train (Union[Subset, Dataset]): The training dataset.
        dataset_sel_pgd (Subset):               The validation dataset, if necessary.

        Returns:
        selector (LinearModel):                 The selector model.
        """
        if dataset_sel_pgd is not None:
            return self.pgd_with_model_selection(
                lin_model=lin_model,
                dataset_train=dataset_train,
                dataset_sel=dataset_sel_pgd
            )

        return self.pgd(
            lin_model=lin_model,
            dataset_train=dataset_train
        )

    def pgd(
            self,
            lin_model: LinearModel,
            dataset_train: MultiLabeledDataset
    ) -> LinearModel:
        """
        Run ``self.num_iter`` gradient updates on the full training subset.

        Parameters:
        lin_model (LinearModel):                The model to optimize; updated through ``grad_update``.
        dataset_train (MultiLabeledDataset):    The training data; indexed with ``[:]`` to get
                                                (labels, features).

        Returns:
        lin_model (LinearModel):                The same model object after the last update.
        """
        # labels:   [num predictors, num train sample]  (TODO confirm layout)
        # features: [num train sample, num features]
        labels_train, features_train = dataset_train[:]

        for _ in range(self.num_iter):
            # update weights
            self.grad_update(
                lin_model=lin_model,        # [num observations, num predictors, num features]
                labels=labels_train,
                features=features_train
            )

        return lin_model

    def pgd_with_model_selection(
            self,
            lin_model: LinearModel,
            dataset_train: MultiLabeledDataset,
            dataset_sel: MultiLabeledDataset
    ) -> LinearModel:
        """
        Run ``self.num_iter`` gradient updates and keep the best iterate.

        After every update the conditional error rate of the current weights is
        evaluated on ``dataset_sel``; for each (observation, predictor) entry
        the weights with the lowest error seen so far are kept.

        Parameters:
        lin_model (LinearModel):                The model to optimize; updated through ``grad_update``.
        dataset_train (MultiLabeledDataset):    The training data; indexed with ``[:]`` to get
                                                (labels, features).
        dataset_sel (MultiLabeledDataset):      The validation data used to compare iterates.

        Returns:
        selector (LinearModel):                 A new model built from the best weights. Entries whose
                                                error never dropped below the initial value of 1 keep
                                                the initial all-zero weights.
        """
        # labels:   [num predictors, num train sample]  (TODO confirm layout)
        # features: [num train sample, num features]
        labels_train, features_train = dataset_train[:]

        # the validation batch does not change across iterations, so it is
        # materialized once instead of being re-indexed in every step
        sel_batch = dataset_sel[:]

        min_weights, min_errors = self.error_tracker(
            weight_shape=lin_model.size(),
            device=self.device
        )

        for _ in range(self.num_iter):
            # update weights
            self.grad_update(
                lin_model=lin_model,        # [num observations, num predictors, num features]
                labels=labels_train,
                features=features_train
            )

            # compute the conditional error rate
            conditional_error_rates = lin_model.conditional_one_rate(
                *sel_batch
            )                               # [num observations, num predictors]

            # select the best selector between the current and the previous best
            min_weights, min_errors = self.pairwise_select(
                curr_error=conditional_error_rates,
                min_error=min_errors,       # [num observations, num predictors]
                curr_weight=lin_model.weights,
                min_weight=min_weights      # [num observations, num predictors, num features]
            )

        return LinearModel(min_weights)

    def error_tracker(
            self,
            weight_shape: Union[List[int], torch.Size],
            device: torch.device
        ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Initialize the error tracker.

        The trackers keep the full requested shape, including axes of size 1,
        so that they always line up with the weights and errors they are
        compared against.

        Parameters:
        weight_shape (Union[List[int], torch.Size]):   The shape of the weights.
        device (torch.device):                         The device to be used.

        Returns:
        weights (torch.Tensor):     All-zero best-weight buffer of shape ``weight_shape``.
        error (torch.Tensor):       All-one best-error buffer of shape ``weight_shape[:-1]``.
        """
        # store the weights of the best linear models, allocated directly on
        # the target device
        weights = torch.zeros(weight_shape, device=device)      # [cluster_size, ..., dim_sample]

        # record the conditional error of the corresponding best linear models
        error = torch.ones(weight_shape[:-1], device=device)    # [cluster_size, ...]

        return weights, error

    def grad_update(
            self,
            lin_model: LinearModel,
            labels: torch.Tensor,
            features: torch.Tensor      # features: [num train sample, num features]
    ) -> None:
        """
        Perform the gradient step for weights.

        The model is modified through ``lin_model.update``; nothing is returned.

        Parameters:
        lin_model (LinearModel):         The linear model to be updated.
        labels (torch.Tensor):           The labels to be used.
        features (torch.Tensor):         The features to be used.
        """
        # compute projected gradients
        proj_grads = lin_model.projected_gradient(
            y=labels,
            X=features
        )

        # gradient step
        lin_model.update(
            weights= - self.lr * proj_grads
        )

    def pairwise_select(
            self,
            curr_error: torch.Tensor,
            min_error: torch.Tensor,
            curr_weight: torch.Tensor,
            min_weight: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update the weights based on the current error and the minimum error.

        An entry is replaced only where the current error is strictly smaller
        than the recorded minimum. The choice is made element-wise with
        ``torch.where``, so a NaN or infinite value in the rejected candidate
        cannot leak into the result (a NaN current error never compares as
        smaller and is therefore never selected).

        Parameters:
        curr_error (torch.Tensor):      The current error.              [num observations, num predictors]
        min_error (torch.Tensor):       The minimum error.              [num observations, num predictors]
        curr_weight (torch.Tensor):     The current weight.             [num observations, num predictors, num features]
        min_weight (torch.Tensor):      The minimum weight.             [num observations, num predictors, num features]

        Returns:
        min_weight (torch.Tensor):      The updated minimum weight.     [num observations, num predictors, num features]
        min_error (torch.Tensor):       The updated minimum error.      [num observations, num predictors]
        """
        improved = curr_error < min_error   # [num observations, num predictors]

        # promote explicitly so that mixed precisions behave the same on every
        # torch version
        error_dtype = torch.result_type(min_error, curr_error)
        min_error = torch.where(
            improved,
            curr_error.to(error_dtype),
            min_error.to(error_dtype)
        )                                   # [num observations, num predictors]

        weight_dtype = torch.result_type(min_weight, curr_weight)
        min_weight = torch.where(
            improved.unsqueeze(-1),
            curr_weight.to(weight_dtype),
            min_weight.to(weight_dtype)
        )                                   # [num observations, num predictors, num features]

        return min_weight, min_error
    

class ReferenceClassLearner(SelectiveHalfspaceLearner):
    """
    Conditional Classification for Any Finite Classes

    Variant of the selective halfspace learner whose update additionally
    projects the model onto the stored observations after every gradient step.
    """
    def __init__(
            self,
            prev_header: str,
            subset_fracs: List[float],
            num_iter: int, 
            lr: float,
            device: torch.device = torch.device('cpu')
    ):
        """
        Initialize the reference class learner.

        All hyper-parameters are forwarded unchanged to the parent constructor;
        only the message header is replaced afterwards.

        Parameters:
        prev_header (str):              The header of the previous module.
        subset_fracs (List[float]):     The fractions of the dataset used for the data split.
        num_iter (int):                 The number of iterations for optimizer.
        lr (float):                     The learning rate.
        device (torch.device):          The device to be used.
        """
        super().__init__(
            prev_header=prev_header,
            subset_fracs=subset_fracs,
            num_iter=num_iter,
            lr=lr,
            device=device
        )
        header_parts = (prev_header, "learning reference class", "-")
        self.header = " ".join(header_parts)

    def grad_update(
            self,
            lin_model: LinearModel,
            labels: torch.Tensor,
            features: torch.Tensor      # features: [num train sample, num features]
    ) -> None:
        """
        Perform the gradient step for weights, followed by a projection.

        The gradient step itself is the one of the parent class; afterwards the
        model is projected onto ``self.observations``, which is set by
        ``forward``.

        Parameters:
        lin_model (LinearModel):         The linear model to be updated.
        labels (torch.Tensor):           The labels to be used.
        features (torch.Tensor):         The features to be used.
        """
        # projected gradient + gradient step (shared with the parent class)
        super().grad_update(
            lin_model=lin_model,
            labels=labels,
            features=features
        )

        # contractive projection
        lin_model.project_onto(X=self.observations)
