import logging
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

from tqdm import tqdm
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.random_projection import SparseRandomProjection
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

from inference_rlhf.code.coreset.coreset import Coreset
from inference_rlhf.code.helpers.utils import timing, construct_sparse_matrix

log = logging.getLogger(__name__)

class EllipticalCoreset(Coreset):
    """Coreset selection driven by covariance-based scores over response features.

    For each example, every response is represented by a feature matrix of shape
    ``(d, k)``. :meth:`coreset` always appends a trailing singleton dimension, so
    ``k == 1`` there. Responses are then picked one at a time, greedily
    (``argmax``) or by sampling from ``softmax(scores)``, using one of:

    * :meth:`det_sampling`: log-determinant gain of the regularised covariance
      (default path).
    * :meth:`weird_sampling`: decrease of the trace of the inverse covariance.
    * :meth:`weird_sampling2`: smallest increase of ``tr(C @ C)`` for the
      unregularised covariance ``C``.

    ``self.ks`` is read here but never assigned in this class; presumably it is
    initialised by the ``Coreset`` base class (TODO confirm).
    """

    def __init__(self, cfg):
        super().__init__(cfg)

        # All tunables live under cfg.coreset.elliptical.
        self.perform_sparse_projection = cfg.coreset.elliptical.perform_sparse_projection
        self.sparse_dim = cfg.coreset.elliptical.sparse_dim
        self.reward_percent_to_filter = cfg.coreset.elliptical.reward_percent_to_filter
        self.log_probs_percent_to_filter = cfg.coreset.elliptical.log_probs_percent_to_filter
        self.perform_pca = cfg.coreset.elliptical.perform_pca
        self.pca_dim = cfg.coreset.elliptical.pca_dim
        self.scale_features_with_log_probs = cfg.coreset.elliptical.scale_features_with_log_probs
        self.use_gradients = cfg.coreset.elliptical.use_gradients
        self.argmax = cfg.coreset.elliptical.argmax
        self.lamb = cfg.coreset.elliptical.lamb
        self.num_trials = cfg.coreset.elliptical.num_trials
        self.alpha = cfg.coreset.elliptical.alpha
        self.use_weird_sampling = cfg.coreset.elliptical.use_weird_sampling
        self.temp = cfg.coreset.elliptical.temp
        self.center_features = cfg.coreset.elliptical.center_features
        self.use_weird_sampling2 = cfg.coreset.elliptical.use_weird_sampling2

        # example_id -> trial number -> list of selected response records.
        self.selected_responses = defaultdict(dict)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _woodbury_update(cov_inv: torch.Tensor, samp: torch.Tensor, eye_k: torch.Tensor) -> torch.Tensor:
        """Return ``(A + samp @ samp.T)^{-1}`` given ``cov_inv == A^{-1}`` (Woodbury identity)."""
        middle_part = torch.inverse(eye_k + samp.t() @ cov_inv @ samp)
        return cov_inv - cov_inv @ samp @ middle_part @ samp.t() @ cov_inv

    @staticmethod
    def _random_valid_index(n: int, ignore_idxs: List[int]) -> int:
        """Draw a uniformly random index from ``range(n)`` that is not in ``ignore_idxs``."""
        ignored = set(ignore_idxs)
        valid_indices = [idx for idx in range(n) if idx not in ignored]
        return int(np.random.choice(valid_indices))

    @staticmethod
    def _mean_ignoring_none(values: List[Optional[float]]) -> Optional[float]:
        """Return the mean of the non-``None`` entries of ``values``, or ``None`` if there are none."""
        present = [v for v in values if v is not None]
        return np.mean(present) if present else None

    def _k_to_pass_flags(self, coreset_results: Dict[int, List[float]]) -> Dict[int, List[bool]]:
        """Map each k in ``self.ks`` to per-example flags: is any of the first k picks correct?"""
        k_to_coreset_results = defaultdict(list)
        for k in tqdm(self.ks, desc="Computing pass@k ..."):
            for results in coreset_results.values():
                k_to_coreset_results[k].append(any(results[:k]))
        return k_to_coreset_results

    def _k_to_unique_answer_counts(self) -> Dict[int, List[int]]:
        """Map each k in ``self.ks`` to per-example counts of distinct answers among the first k picks."""
        k_to_num_unique_answers = defaultdict(list)
        for k in tqdm(self.ks, desc="Computing unique answers at k ..."):
            for answers in self.coreset_answers.values():
                k_to_num_unique_answers[k].append(len(set(answers[:k])))
        return k_to_num_unique_answers

    # ----------------------------------------------------------------- samplers

    @timing
    def weird_sampling2(self, X, n_choose: int = -1, argmax: bool = True):
        """Select rows of ``X`` that add the least to ``tr(C @ C)``.

        ``C`` is the unregularised covariance. It starts at zero and accumulates
        ``x @ x.T`` for every chosen sample. For a ``(d, 1)`` sample ``x``, adding
        ``x @ x.T`` increases ``tr(C @ C)`` by ``2 x^T C x + (x^T x)^2``. The
        negation of that increase is the score, so ``argmax`` picks the candidate
        with the smallest increase.

        Args:
            X: Tensor of shape ``(n, d, k)``. The argmax/sampling logic treats the
                score tensor as flat, which matches row indices only when
                ``k == 1`` (as produced by :meth:`coreset`). Batches are moved to
                CUDA, so ``X`` itself may stay on the CPU.
            n_choose: Number of rows to select; ``-1`` selects all ``n``.
            argmax: Greedy selection if True, otherwise sample from ``softmax(scores)``.

        Returns:
            Selected row indices, in selection order.
        """
        n, d, k = X.shape
        batch_size = 10000
        num_batches = int(np.ceil(n / batch_size))
        cov = torch.zeros(d, d).cuda()
        chosen = []
        if n_choose == -1:
            n_choose = n

        for _ in range(n_choose):
            # Negated increase of tr(C @ C) for every candidate, computed in batches.
            traces = []
            for j in range(num_batches):
                samps = X[j * batch_size:(j + 1) * batch_size].cuda()
                val = 2 * samps.transpose(1, 2) @ cov @ samps + (samps.transpose(1, 2) @ samps) ** 2
                traces.append(-1 * val)
            traces = torch.cat(traces)

            # Never re-pick an already selected row.
            traces[chosen] = -1 * torch.inf
            if argmax:
                chosen_ind = torch.argmax(traces).item()
            else:
                probs = torch.softmax(traces, dim=0)
                chosen_ind = torch.multinomial(probs.flatten(), 1, replacement=False).item()

            # Plain covariance accumulation (no inverse is maintained here).
            # Move the sample to cov's device because X may live on the CPU.
            chosen_samp = X[chosen_ind].to(cov.device)
            cov = cov + chosen_samp @ chosen_samp.transpose(0, 1)
            chosen.append(chosen_ind)

        return chosen

    @timing
    def weird_sampling(
        self,
        X,
        n_choose: int = -1,
        lamb: float = 0.1,
        argmax: bool = True,
        ignore_idxs: Optional[List[int]] = None,
    ):
        """Select rows of ``X`` that most reduce the trace of the inverse covariance.

        Maintains ``A^{-1}`` for ``A = lamb * I + sum_{chosen} x x^T``. The score of
        a candidate ``x`` is ``-(x^T A^{-2} x) / (1 + x^T A^{-1} x)``. By
        Sherman-Morrison this equals ``tr((A + x x^T)^{-1}) - tr(A^{-1})`` for a
        ``(d, 1)`` sample.

        Args:
            X: Tensor of shape ``(n, d, k)``. The ``.view(-1)`` on the score terms
                assumes ``k == 1``. Internal matrices are created on ``X.device``.
            n_choose: Number of rows to select; ``-1`` selects all ``n``.
            lamb: Ridge regulariser; ``A`` starts at ``lamb * I``.
            argmax: Greedy selection if True, otherwise sample from ``softmax(scores)``.
            ignore_idxs: Row indices that must never be selected.

        Returns:
            Selected row indices, in selection order.

        When ``self.alpha == 0.0``, the first pick is replaced by a uniformly
        random valid index. The sampler's own choice is still computed first.
        """
        ignore_idxs = [] if ignore_idxs is None else list(ignore_idxs)
        n, d, k = X.shape
        cov_inv = torch.eye(d, device=X.device) * lamb ** -1  # A = lamb * I  =>  A^{-1} = I / lamb
        eye_k = torch.eye(k, device=X.device)
        batch_size = 20000
        num_batches = int(np.ceil(n / batch_size))
        chosen = []
        if n_choose == -1:
            n_choose = n

        for i in range(n_choose):
            # Change of tr(A^{-1}) each candidate would cause, computed in batches.
            traces = []
            for j in range(num_batches):
                samps = X[j * batch_size:(j + 1) * batch_size]
                batch_cov_inv = cov_inv.unsqueeze(0).expand(len(samps), -1, -1)
                numerator = -1 * (samps.transpose(1, 2) @ batch_cov_inv @ batch_cov_inv @ samps).view(-1)
                denominator = (1 + samps.transpose(1, 2) @ batch_cov_inv @ samps).view(-1)
                traces.append(numerator / denominator)
            traces = torch.cat(traces)

            traces[chosen] = -1 * torch.inf
            traces[ignore_idxs] = -1 * torch.inf
            if argmax:
                chosen_ind = torch.argmax(traces).item()
            else:
                probs = torch.softmax(traces, dim=0)
                chosen_ind = torch.multinomial(probs, 1, replacement=False).item()

            if i == 0 and self.alpha == 0.0:
                chosen_ind = self._random_valid_index(n, ignore_idxs)

            cov_inv = self._woodbury_update(cov_inv, X[chosen_ind], eye_k)
            chosen.append(chosen_ind)

        return chosen

    def det_sampling(
        self,
        X: torch.Tensor,
        n_choose: int = -1,
        lamb: float = 0.1,
        argmax: bool = True,
        ignore_idxs: Optional[List[int]] = None,
        results: Optional[List[float]] = None,
        terminate_when_correct: bool = False
    ) -> List[int]:
        """Select rows of ``X`` by greedy (or softmax-sampled) log-determinant gain.

        Maintains ``A^{-1}`` for ``A = lamb * I + sum_{chosen} x x^T``. The score of
        a candidate ``x`` is ``logdet(I_k + x^T A^{-1} x)``. By the matrix
        determinant lemma this equals ``log det(A + x x^T) - log det(A)``.

        Args:
            X: Tensor of shape ``(n, d, k)``. Internal matrices are created on ``X.device``.
            n_choose: Number of rows to select; ``-1`` selects all ``n``.
            lamb: Ridge regulariser; ``A`` starts at ``lamb * I``.
            argmax: Greedy selection if True, otherwise sample from ``softmax(scores)``.
            ignore_idxs: Row indices that must never be selected.
            results: Per-row correctness values; only read when
                ``terminate_when_correct`` is set.
            terminate_when_correct: Stop right after selecting a row whose
                ``results`` entry is truthy.

        Returns:
            Selected row indices, in selection order.

        Raises:
            ValueError: If ``terminate_when_correct`` is set but ``results`` is None.

        When ``self.alpha == 0.0`` and features are not log-prob scaled, the first
        pick is replaced by a uniformly random valid index.
        """
        if terminate_when_correct and results is None:
            raise ValueError("terminate_when_correct=True requires `results`")
        ignore_idxs = [] if ignore_idxs is None else list(ignore_idxs)
        n, d, k = X.shape
        cov_inv = torch.eye(d, device=X.device) * lamb ** -1  # A = lamb * I  =>  A^{-1} = I / lamb
        eye_k = torch.eye(k, device=X.device)
        batch_size = 20000
        num_batches = int(np.ceil(n / batch_size))
        chosen = []
        if n_choose == -1:
            n_choose = n

        for i in range(n_choose):
            # Log-determinant gain of every candidate, computed in batches.
            dets = []
            for j in range(num_batches):
                samps = X[j * batch_size:(j + 1) * batch_size]
                batch_cov_inv = cov_inv.unsqueeze(0).expand(len(samps), -1, -1)
                dets.append(torch.logdet(eye_k + samps.transpose(1, 2) @ batch_cov_inv @ samps))
            dets = torch.cat(dets)

            dets[chosen] = -1 * torch.inf
            dets[ignore_idxs] = -1 * torch.inf
            if argmax:
                chosen_ind = torch.argmax(dets).item()
            else:
                probs = torch.softmax(dets, dim=0)
                chosen_ind = torch.multinomial(probs, 1, replacement=False).item()

            if i == 0 and self.alpha == 0.0 and not self.scale_features_with_log_probs:
                chosen_ind = self._random_valid_index(n, ignore_idxs)

            cov_inv = self._woodbury_update(cov_inv, X[chosen_ind], eye_k)
            chosen.append(chosen_ind)

            if terminate_when_correct and results[chosen_ind]:
                break

        return chosen

    def pca(self, model_features: torch.Tensor) -> np.ndarray:
        """Fit a fresh ``PCA(n_components=self.pca_dim)`` on ``model_features`` and return the projection.

        The tensor is moved to the CPU and converted to numpy first. The result
        is a numpy array, not a tensor.
        """
        features = model_features.cpu().numpy()
        # features = StandardScaler().fit_transform(features) # TODO: keep this or not?
        pca = PCA(n_components=self.pca_dim)
        pca.fit(features)
        return pca.transform(features)

    # ------------------------------------------------------------------ driver

    @timing
    def coreset(
        self,
        model_answers: Dict[int, List[int]],
        model_results: Dict[int, List[float]],
        model_features: Dict[int, torch.Tensor],
        terminate_when_correct: bool = False
    ) -> Dict[int, List[float]]:
        """Select a coreset of responses for every example.

        For each example, features are prepared in three steps:

        1. Either sparse-projected, using one matrix built from the first
           example's features, or PCA-reduced.
        2. Optionally centred.
        3. Given a trailing singleton dim and moved to CUDA.

        The features are then passed to the sampler chosen by the config flags.
        At most ``min(max(self.ks), n)`` responses are chosen. On the
        determinantal path, examples with no correct response get an empty
        selection.

        ``terminate_when_correct`` is forwarded only to :meth:`det_sampling`.

        Populates ``self.coreset_answers``, ``self.coreset_results``,
        ``self.coreset_indices`` and ``self.coreset_samples_to_get_correct``.
        The last maps each example to the number of selected responses, or
        ``None`` when nothing was selected. ``self.coreset_responses`` is reset
        but not filled here.

        Returns:
            ``self.coreset_results``: example id -> results of the selected
            responses, in selection order.
        """
        elliptical_features = model_features

        if self.perform_sparse_projection:
            # One projection matrix shared by every example.
            first_features = elliptical_features[list(elliptical_features.keys())[0]]
            sparse_matrix = construct_sparse_matrix(first_features, self.sparse_dim).cuda()

        self.coreset_answers = {}
        self.coreset_results = {}
        self.coreset_indices = {}
        self.coreset_responses = {}
        self.coreset_samples_to_get_correct = {}
        for example_id, answers in tqdm(model_answers.items(), desc="Constructing coreset ..."):
            if self.perform_sparse_projection:
                cur_model_features = (elliptical_features[example_id].cuda() @ sparse_matrix).unsqueeze(-1)
            elif self.perform_pca:
                reduced = self.pca(elliptical_features[example_id])
                cur_model_features = torch.from_numpy(reduced).float().unsqueeze(-1).cuda()
            else:
                cur_model_features = elliptical_features[example_id].unsqueeze(-1).cuda()

            if self.center_features:
                cur_model_features = cur_model_features - cur_model_features.mean(dim=0, keepdim=True)

            num_points = cur_model_features.shape[0]
            ignore_idxs = []
            n_choose = min(max(self.ks), num_points - len(ignore_idxs))

            if self.use_weird_sampling:
                log.info(f"Weird sampling example {example_id} with {num_points} points")
                indices = self.weird_sampling(
                    cur_model_features,
                    n_choose=n_choose,
                    argmax=self.argmax,
                    lamb=self.lamb,
                    ignore_idxs=ignore_idxs,
                )
            elif self.use_weird_sampling2:
                log.info(f"Weird sampling 2 example {example_id} with {num_points} points")
                indices = self.weird_sampling2(
                    cur_model_features,
                    n_choose=n_choose,
                    argmax=self.argmax
                )
            else:
                log.info(f"Determinantal sampling example {example_id} with {num_points} points")
                if not any(model_results[example_id]):
                    log.info(f"No correct responses for example {example_id} ...")
                    indices = []
                else:
                    indices = self.det_sampling(
                        cur_model_features,
                        n_choose=n_choose,
                        argmax=self.argmax,
                        ignore_idxs=ignore_idxs,
                        lamb=self.lamb,
                        results=model_results[example_id],
                        terminate_when_correct=terminate_when_correct
                    )

            self.coreset_answers[example_id] = [answers[i] for i in indices]
            self.coreset_results[example_id] = [model_results[example_id][i] for i in indices]
            self.coreset_indices[example_id] = indices
            # NOTE(review): with early termination this is the number of picks made,
            # which equals the budget (not "None") if the budget ran out before a correct pick.
            self.coreset_samples_to_get_correct[example_id] = len(indices) if indices else None

        return self.coreset_results

    # ----------------------------------------------------------------- metrics

    @timing
    def pass_at_k(
        self,
        model_answers: Dict[int, List[int]],
        model_results: Dict[int, List[float]],
        model_features: Dict[int, torch.Tensor],
    ) -> Dict[int, float]:
        """Compute pass@k for every k in ``self.ks`` over the coreset ordering.

        With ``self.argmax``, a single coreset is built with early termination.
        Stopping at the first correct pick cannot change ``any(results[:k])``. In
        this case the return value maps each k to the per-example list of pass
        flags, not their mean.

        Otherwise, ``self.num_trials`` coresets are sampled. Each k maps to the
        pass rate averaged over examples and then over trials.
        """
        if self.argmax:
            coreset_results = self.coreset(
                model_answers,
                model_results,
                model_features,
                terminate_when_correct=True
            )
            # NOTE(review): unlike the sampling branch, this returns per-example flags; kept for callers.
            return self._k_to_pass_flags(coreset_results)

        all_results = []
        for _ in range(self.num_trials):
            coreset_results = self.coreset(model_answers, model_results, model_features)
            k_to_coreset_results = self._k_to_pass_flags(coreset_results)
            all_results.append({k: np.mean(v) for k, v in k_to_coreset_results.items()})

        return {k: np.mean([result[k] for result in all_results]) for k in self.ks}

    def samples_to_get_correct(
        self,
        model_answers: Dict[int, List[int]],
        model_results: Dict[int, List[float]],
        model_features: Dict[int, torch.Tensor],
        model_responses: Dict[int, List[str]],
        model_prompts: Dict[int, List[str]]
    ) -> Tuple[Dict[int, Optional[float]], Dict[int, List[Optional[int]]]]:
        """Measure how many picks the coreset needs before reaching a correct response.

        Runs :meth:`coreset` with early termination ``self.num_trials`` times. For
        every trial it records the selected responses in
        ``self.selected_responses[example_id][trial_num]``.

        Returns:
            A pair ``(means, per_trial)``:

            * ``per_trial`` maps each example id to its per-trial counts, where
              ``None`` means nothing was selected.
            * ``means`` maps each example id to the mean over the non-``None``
              counts, or ``None`` if every trial was ``None``.

            Both are empty when ``self.num_trials`` is 0. ``model_prompts`` is
            unused.
        """
        all_results = []
        for trial_num in range(self.num_trials):
            self.coreset(
                model_answers,
                model_results,
                model_features,
                terminate_when_correct=True
            )
            all_results.append(self.coreset_samples_to_get_correct)

            for example_id, idxs in self.coreset_indices.items():
                self.selected_responses[example_id][trial_num] = [
                    {
                        "response": model_responses[example_id][idx],
                        "answer": model_answers[example_id][idx],
                        "result": model_results[example_id][idx],
                    }
                    for idx in idxs
                ]

        if not all_results:
            return {}, {}

        per_trial = {k: [d[k] for d in all_results] for k in all_results[0].keys()}
        # Skip None (no response selected); a plain np.mean over None would raise TypeError.
        means = {k: self._mean_ignoring_none(v) for k, v in per_trial.items()}
        return means, per_trial

    @timing
    def avg_reward_at_k(self, model_rewards: Dict[int, List[float]]) -> Dict[int, float]:
        """Average reward of the first k coreset picks, averaged over examples, for every k in ``self.ks``.

        Reads ``self.coreset_indices`` from the most recent :meth:`coreset` call.
        Examples with an empty selection contribute ``nan`` (``np.mean([])``).
        """
        picked_by_example = {
            example_id: [rewards[i] for i in self.coreset_indices[example_id]]
            for example_id, rewards in model_rewards.items()
        }
        k_to_avg_reward = defaultdict(list)
        for k in tqdm(self.ks, desc="Computing avg reward at k ..."):
            for picked_rewards in picked_by_example.values():
                k_to_avg_reward[k].append(np.mean(picked_rewards[:k]))

        return {k: np.mean(v) for k, v in k_to_avg_reward.items()}

    @timing
    def unique_answers_at_k(
        self,
        model_answers: Dict[int, List[int]],
        model_results: Dict[int, List[float]],
        model_features: Dict[int, torch.Tensor],
        model_gradients: Dict[int, torch.Tensor],
        model_rewards: Dict[int, List[float]],
        model_log_probs: Dict[int, List[float]],
        model_responses: Dict[int, List[str]],
        model_prompts: Dict[int, List[str]],
        gt_answers: Dict[int, int]
    ) -> Dict[int, float]:
        """Mean number of distinct answers among the first k coreset picks, for every k in ``self.ks``.

        With ``self.argmax`` a single coreset is used; otherwise the mean is also
        averaged over ``self.num_trials`` sampled coresets. Only
        ``model_answers``, ``model_results`` and ``model_features`` are used. The
        remaining parameters are accepted for interface compatibility.
        """
        # Bug fix: coreset() only takes (answers, results, features[, terminate_when_correct]);
        # the former 8-argument call always raised TypeError.
        if self.argmax:
            self.coreset(model_answers, model_results, model_features)
            return {k: np.mean(v) for k, v in self._k_to_unique_answer_counts().items()}

        all_results = []
        for _ in range(self.num_trials):
            self.coreset(model_answers, model_results, model_features)
            all_results.append({k: np.mean(v) for k, v in self._k_to_unique_answer_counts().items()})

        return {k: np.mean([result[k] for result in all_results]) for k in self.ks}
