from typing import Optional

import numpy as np
import torch
import tqdm
from numpy.random import RandomState
from scipy.stats import zscore
from sklearn.linear_model import LassoCV, RidgeCV
from sklearn.utils import check_random_state
from torch.utils.data import Subset

from opendataval.dataval.api import DataEvaluator, ModelMixin



class CustomRidge(RidgeCV):
    """``RidgeCV`` whose fitted coefficients are shifted by a neighbor penalty.

    After the ordinary ``RidgeCV`` fit, a scalar is computed from the squared
    differences between each coefficient and the coefficients of its
    neighbors. Each term is weighted by the neighbor similarity score and by
    label agreement (weight 2 when the labels compare equal, 0 otherwise), and
    the total is scaled by ``custom_regularization_strength``. That scalar is
    then added to every entry of ``coef_``.

    Parameters
    ----------
    custom_regularization_strength : float
        Multiplier applied to the neighbor penalty.
    label : torch.Tensor
        Labels of the points whose coefficients are fitted; indexed with
        ``neighbor_index``.
    neighbor_value : torch.Tensor
        Neighbor similarity scores, as produced by ``torch.topk(..., dim=0)``
        in GLOC (presumably shaped ``(k, num_points)`` -- confirm with caller).
    neighbor_index : torch.Tensor
        Neighbor indices matching ``neighbor_value``.
    """

    def __init__(self, custom_regularization_strength, label, neighbor_value, neighbor_index):
        self.custom_regularization_strength = custom_regularization_strength
        self.neighbor_value = neighbor_value
        self.neighbor_index = neighbor_index
        self.label = label
        # RidgeCV hyper-parameters (alphas, ...) keep their library defaults.
        super().__init__()

    def fit(self, X, y, sample_weight=None):
        """Fit ``RidgeCV`` and then shift ``coef_`` by the neighbor penalty.

        Returns
        -------
        CustomRidge
            ``self``, following the scikit-learn ``fit`` convention.
        """
        super().fit(X, y, sample_weight=sample_weight)

        # Place the coefficients on the neighbor tensors' device so that
        # indexing with ``neighbor_index`` works for GPU-resident neighbors.
        coef = torch.as_tensor(self.coef_, device=self.neighbor_value.device)

        # 2 where a point's label equals its neighbor's label, 0 otherwise.
        same_label = 2 * (
            self.label.unsqueeze(0) == self.label[self.neighbor_index]
        ).to(torch.int)
        # NOTE(review): broadcasting ``neighbor_value.T`` against ``same_label``
        # (and later against ``sq_diff.T``) only lines up for particular label
        # shapes (e.g. 2-column one-hot labels with k=2); 1-D labels raise a
        # shape error unless num_points == 2. Confirm the intended label format.
        weights = self.neighbor_value.transpose(0, 1) * same_label
        sq_diff = (coef[self.neighbor_index] - coef.unsqueeze(0)) ** 2
        custom_reg = (
            self.custom_regularization_strength * weights * sq_diff.transpose(0, 1)
        ).sum()

        # ``.item()`` works regardless of device or autograd state, unlike ``.numpy()``.
        self.coef_ = self.coef_ + custom_reg.item()
        return self
            

class GLOC(DataEvaluator, ModelMixin):
    """Implementation of GLOC.

    Random subsets of the training data are drawn at several inclusion
    proportions, a model is trained on each, and its validation performance is
    recorded. Data values are the coefficients of a :class:`CustomRidge`
    regression of the centered performance onto the z-scored subset-membership
    indicators. The regression's penalty uses each training point's nearest
    neighbors under cosine similarity.

    Parameters
    ----------
    num_models : int, optional
        Number of models to bag/aggregate, by default 1000
    random_state : RandomState, optional
        Random initial state, by default None
    """

    def __init__(
        self, num_models: int = 1000, random_state: Optional[RandomState] = None
    ):
        self.num_models = num_models
        self.random_state = check_random_state(random_state)

    @staticmethod
    def _nearest_neighbors(x_train, k: int = 2):
        """Return ``torch.topk`` over the pairwise cosine similarity of samples.

        Each sample is flattened to a vector. The result is ``(values, indices)``,
        each of shape ``(k, num_points)``. Column ``j`` holds the ``k`` points
        most similar to point ``j``. For a non-zero sample, the sample itself
        has similarity 1 and is normally among them.
        """
        points = torch.stack(list(x_train))
        points = points.reshape(len(points), -1)
        if not points.is_floating_point():
            points = points.float()
        # ``normalize`` clamps tiny norms to eps, so all-zero rows yield 0
        # similarity instead of NaN.
        unit = torch.nn.functional.normalize(points, dim=1)
        similarity = unit.mm(unit.t())
        return torch.topk(similarity, k, dim=0)

    def train_data_values(self, *args, **kwargs):
        """Collect subset/performance pairs and the training-point neighbors.

        ``*args`` and ``**kwargs`` are forwarded to
        ``BaggingEvaluator.train_data_values`` (and from there to the
        prediction model's ``fit``).

        Returns
        -------
        GLOC
            ``self``, with ``subsets``, ``performance``, ``neighbor_value`` and
            ``neighbor_index`` populated.
        """
        subsets, performance = [], []
        for proportion in [0.2, 0.4, 0.6, 0.8]:
            sub, perf = (
                BaggingEvaluator(self.num_models, proportion, self.random_state)
                .input_model_metric(self.pred_model, self.metric)
                .input_data(self.x_train, self.y_train, self.x_valid, self.y_valid)
                .train_data_values(*args, **kwargs)
                .get_subset_perf()
            )

            subsets.append(sub)
            performance.append(perf)

        # Normalize by the norms of the samples themselves. The previous
        # version normalized rows by the norms of Gram-matrix rows, so the
        # result was neither a cosine similarity nor symmetric.
        self.neighbor_value, self.neighbor_index = self._nearest_neighbors(
            self.x_train, k=2
        )
        self.subsets = np.vstack(subsets)
        self.performance = np.vstack(performance).reshape(-1)

        return self

    def evaluate_data_values(self) -> np.ndarray:
        """Return one data value per training point.

        Returns
        -------
        np.ndarray
            ``coef_`` of a :class:`CustomRidge` fit (strength 1) of centered
            performance on z-scored subset indicators.
        """
        norm_subsets = zscore(self.subsets, axis=1)
        norm_subsets[np.isnan(norm_subsets)] = 0  # For when all elements are the same
        centered_perf = self.performance - np.mean(self.performance)
        dv_gloc = CustomRidge(
            custom_regularization_strength=1,
            label=self.y_train,
            neighbor_value=self.neighbor_value,
            neighbor_index=self.neighbor_index,
        )
        dv_gloc.fit(X=norm_subsets, y=centered_perf)
        return dv_gloc.coef_


class BaggingEvaluator(DataEvaluator, ModelMixin):
    """Train models on random Bernoulli subsets of the training data.

    GLOC uses this class (via :meth:`get_subset_perf`) to collect subset
    membership masks and the validation performance of a model trained on
    each subset.

    Parameters
    ----------
    num_models : int, optional
        Number of subset models to train, by default 1000
    proportion : float, optional
        Probability that each training point is included in a subset,
        by default 1.0
    random_state : RandomState, optional
        Random initial state, by default None
    """

    def __init__(
        self,
        num_models: int = 1000,
        proportion: float = 1.0,
        random_state: Optional[RandomState] = None,
    ):
        self.num_models = num_models
        self.proportion = proportion
        self.random_state = check_random_state(random_state)

    def input_data(
        self,
        x_train: torch.Tensor,
        y_train: torch.Tensor,
        x_valid: torch.Tensor,
        y_valid: torch.Tensor,
    ):
        """Store and transform input data for Bagging Evaluator.

        Parameters
        ----------
        x_train : torch.Tensor
            Data covariates
        y_train : torch.Tensor
            Data labels
        x_valid : torch.Tensor
            Test+Held-out covariates
        y_valid : torch.Tensor
            Test+Held-out labels
        """
        self.x_train = x_train
        self.y_train = y_train
        self.x_valid = x_valid
        self.y_valid = y_valid

        self.num_points = len(x_train)
        return self

    def train_data_values(self, *args, **kwargs):
        """Fit one clone of ``pred_model`` per random subset and score it.

        ``*args`` and ``**kwargs`` are forwarded to the model's ``fit``. Empty
        subsets are skipped and keep a recorded performance of 0.

        Returns
        -------
        BaggingEvaluator
            ``self``, with ``subsets`` (0/1 matrix of shape
            ``(num_models, num_points)``) and ``performance`` populated.
        """
        sample_dim = (self.num_models, self.num_points)
        self.subsets = self.random_state.binomial(1, self.proportion, size=sample_dim)
        self.performance = np.zeros((self.num_models,))

        for i in tqdm.tqdm(range(self.num_models)):
            subset = self.subsets[i].nonzero()[0]
            # Check the size, not ``subset.any()``: the index array ``[0]``
            # (only point 0 selected) is non-empty but falsy.
            if subset.size == 0:
                continue

            curr_model = self.pred_model.clone()
            curr_model.fit(
                Subset(self.x_train, indices=subset),
                Subset(self.y_train, indices=subset),
                *args,
                **kwargs,
            )
            y_valid_hat = curr_model.predict(self.x_valid)

            curr_perf = self.evaluate(self.y_valid, y_valid_hat)
            self.performance[i] = curr_perf

        return self

    @staticmethod
    def _nearest_neighbors(x_train, k: int = 2):
        """Return ``torch.topk`` over the pairwise cosine similarity of samples.

        Each sample is flattened to a vector. The result is ``(values, indices)``,
        each of shape ``(k, num_points)``.
        """
        points = torch.stack(list(x_train))
        points = points.reshape(len(points), -1)
        if not points.is_floating_point():
            points = points.float()
        # ``normalize`` clamps tiny norms to eps, avoiding NaN for zero rows.
        unit = torch.nn.functional.normalize(points, dim=1)
        return torch.topk(unit.mm(unit.t()), k, dim=0)

    def evaluate_data_values(self):
        """Return one data value per training point.

        If ``neighbor_value`` / ``neighbor_index`` have been assigned on this
        instance they are used. Otherwise they are computed from ``x_train``
        (top-2 cosine similarity).

        Returns
        -------
        np.ndarray
            ``coef_`` of a :class:`CustomRidge` fit (strength 0.001).
        """
        norm_subsets = zscore(self.subsets, axis=1)
        norm_subsets[np.isnan(norm_subsets)] = 0  # constant rows give NaN z-scores
        centered_perf = self.performance - np.mean(self.performance)

        # These attributes are never set by this class itself, so fall back to
        # computing them instead of raising AttributeError.
        neighbor_value = getattr(self, "neighbor_value", None)
        neighbor_index = getattr(self, "neighbor_index", None)
        if neighbor_value is None or neighbor_index is None:
            neighbor_value, neighbor_index = self._nearest_neighbors(self.x_train)

        # ``label`` is a required CustomRidge argument; omitting it raised TypeError.
        dv_gloc = CustomRidge(
            custom_regularization_strength=0.001,
            label=self.y_train,
            neighbor_value=neighbor_value,
            neighbor_index=neighbor_index,
        )
        dv_gloc.fit(X=norm_subsets, y=centered_perf)
        return dv_gloc.coef_

    def get_subset_perf(self):
        """Return the subsets and performance, used by GLOC DataEvaluator."""
        return self.subsets, self.performance
