import math
from tqdm import tqdm
from operator import itemgetter

import numpy as np
import pandas as pd
from scipy.optimize import root_scalar
from statistics import mean, stdev

from valuation import get_time_aware_v, get_shapley_values, get_reward_cumulation, get_scaling_factor
from utils import read_otherwise_generate

def scale_rewards(scaling_factor, rewards):
    """Return a new list with every entry of ``rewards`` multiplied by ``scaling_factor``.

    Args:
        scaling_factor: Multiplier applied to each reward.
        rewards: Iterable of reward values.

    Returns:
        A list holding ``reward * scaling_factor`` for each input reward, in
        the original order.
    """
    scaled = []
    for reward in rewards:
        scaled.append(reward * scaling_factor)
    return scaled

def get_reward_gamma(v, times, gamma):
    """Compute scaled rewards from an exponentially time-weighted valuation.

    The weights ``exp(-time * gamma)`` are handed to ``get_time_aware_v``;
    the Shapley values of the resulting valuation are then multiplied by
    ``get_scaling_factor(v)``.

    Args:
        v: Mapping whose keys are sized collections (coalitions). The largest
            key must have as many members as ``times`` has entries.
        times: One time value per party.
        gamma: Decay rate used in the exponential weight.

    Returns:
        List of scaled rewards (see ``scale_rewards``).
    """
    largest_coalition = max(len(coalition) for coalition in v.keys())
    assert largest_coalition == len(times)

    def weight_function(times):
        # Exponential decay in the time value, with rate ``gamma``.
        return [math.exp(-time * gamma) for time in times]

    shapley_rewards = get_shapley_values(get_time_aware_v(v, times, weight_function))
    return scale_rewards(get_scaling_factor(v), shapley_rewards)

def get_reward_beta(v, times, beta):
    """Compute scaled rewards via ``get_reward_cumulation`` with parameter ``beta``.

    Args:
        v: Mapping whose keys are sized collections (coalitions). The largest
            key must have as many members as ``times`` has entries.
        times: One time value per party.
        beta: Parameter forwarded to ``get_reward_cumulation``.

    Returns:
        List of rewards multiplied by ``get_scaling_factor(v)``.
    """
    largest_coalition = max(len(coalition) for coalition in v.keys())
    assert largest_coalition == len(times)
    raw_rewards = get_reward_cumulation(v, times, beta)
    return scale_rewards(get_scaling_factor(v), raw_rewards)

def get_reward_value(dataset, scenario, **kwargs):
    """Obtain the reward-value table for ``dataset``/``scenario``.

    Delegates to ``read_otherwise_generate`` with the ``'value'`` tag,
    ``generate_reward_value`` as the generator and ``"<dataset>_<scenario>"``
    as the name; ``kwargs`` are passed through unchanged.
    """
    return read_otherwise_generate(
        'value',
        generate_reward_value,
        f'{dataset}_{scenario}',
        **kwargs,
    )

def get_reward_noise(dataset, scenario, **kwargs):
    """Obtain the noise-reward table for ``dataset``/``scenario``.

    Delegates to ``read_otherwise_generate`` with the ``'noise'`` tag,
    ``generate_reward_noise`` as the generator and ``"<dataset>_<scenario>"``
    as the name; ``kwargs`` are passed through unchanged.
    """
    return read_otherwise_generate(
        'noise',
        generate_reward_noise,
        f'{dataset}_{scenario}',
        **kwargs,
    )

def get_reward_subset(dataset, scenario, **kwargs):
    """Obtain the subset-reward table for ``dataset``/``scenario``.

    Delegates to ``read_otherwise_generate`` with the ``'subset'`` tag,
    ``generate_reward_subset`` as the generator and ``"<dataset>_<scenario>"``
    as the name; ``kwargs`` are passed through unchanged.
    """
    return read_otherwise_generate(
        'subset',
        generate_reward_subset,
        f'{dataset}_{scenario}',
        **kwargs,
    )

def generate_reward_value(filename, v, gamma_list, beta_list):
    """Tabulate target rewards while the first party's time sweeps 0..4.

    Every other party keeps time 0. One row is produced per (gamma, t) pair
    using ``get_reward_gamma`` and then per (beta, t) pair using
    ``get_reward_beta``; the unused parameter column is filled with NaN.

    Args:
        filename: Not used inside this function.
        v: Valuation mapping keyed by coalitions. The column layout below
            requires exactly three rewards per row.
        gamma_list: Gamma values to sweep.
        beta_list: Beta values to sweep.

    Returns:
        DataFrame with columns ``$t$``, ``$r_1$``..``$r_3$``, ``$\\gamma$``,
        ``$\\beta$`` plus the differences ``$r_1-r_2$`` and ``$r_1-r_3$``.
    """
    party_count = max(len(coalition) for coalition in v.keys())
    times = np.zeros(party_count)
    records = []

    for gamma in gamma_list:
        for delay in range(5):
            times[0] = delay
            gamma_rewards = get_reward_gamma(v, times, gamma=gamma)
            records.append([delay, *gamma_rewards, gamma, np.nan])

    for beta in beta_list:
        for delay in range(5):
            times[0] = delay
            beta_rewards = get_reward_beta(v, times, beta=beta)
            records.append([delay, *beta_rewards, np.nan, beta])

    column_names = ['$t$', '$r_1$', '$r_2$', '$r_3$', r'$\gamma$', r'$\beta$']
    df = pd.DataFrame(records, columns=column_names)
    first_reward = df['$r_1$']
    df['$r_1-r_2$'] = first_reward - df['$r_2$']
    df['$r_1-r_3$'] = first_reward - df['$r_3$']
    return df

def generate_reward_noise(filename, gamma_list, beta_list, gp, X_splits, y_splits, noise_splits, X_test, y_test, noise_test, v):
    """Evaluate noise-based rewards while the first party's time sweeps 0..9.

    For each gamma a ``Noise_Reward`` is built with ``method="cig"``, and for
    each beta with ``method="dr"``; the per-party evaluation results are
    stacked into one table.

    Args:
        filename: Not used inside this function.
        gamma_list: Gamma values to sweep (``cig`` method).
        beta_list: Beta values to sweep (``dr`` method).
        gp: Model object forwarded to ``Noise_Reward``.
        X_splits, y_splits, noise_splits: Per-party data forwarded to
            ``Noise_Reward``.
        X_test, y_test, noise_test: Test data forwarded to
            ``get_rewarded_eval``.
        v: Valuation mapping forwarded to ``Noise_Reward``.

    Returns:
        DataFrame whose party columns 0, 1, 2 are renamed ``$m_1$``..``$m_3$``,
        with ``$t$``, ``$\\gamma$``, ``$\\beta$`` and the differences
        ``$m_1-m_2$`` and ``$m_1-m_3$``.
    """
    times = np.zeros(len(X_splits))
    frames = []

    def evaluate(delay, beta, gamma, method):
        # Only the first party's time changes between runs.
        times[0] = delay
        rewarder = Noise_Reward(gp, X_splits, y_splits, noise_splits, beta, gamma, times, v, method=method, compute_all=True)
        frame = pd.DataFrame(rewarder.get_rewarded_eval(X_test, y_test, noise_test))
        frame["$t$"] = delay
        return frame

    for gamma in tqdm(gamma_list):
        for delay in range(10):
            gamma_frame = evaluate(delay, None, gamma, "cig")
            gamma_frame["beta"] = np.nan
            gamma_frame["gamma"] = gamma
            frames.append(gamma_frame)

    for beta in tqdm(beta_list):
        for delay in range(10):
            beta_frame = evaluate(delay, beta, None, "dr")
            beta_frame["beta"] = beta
            beta_frame["gamma"] = np.nan
            frames.append(beta_frame)

    new_names = {
        0: '$m_1$',
        1: '$m_2$',
        2: '$m_3$',
        'gamma': r'$\gamma$',
        'beta': r'$\beta$',
    }
    pdf_noise_mp = pd.concat(frames).reset_index().rename(columns=new_names)
    first_party = pdf_noise_mp['$m_1$']
    pdf_noise_mp['$m_1-m_2$'] = first_party - pdf_noise_mp['$m_2$']
    pdf_noise_mp['$m_1-m_3$'] = first_party - pdf_noise_mp['$m_3$']
    return pdf_noise_mp

def generate_reward_subset(filename, gamma_list, beta_list, gp, X_splits, y_splits, noise_splits, X_test, y_test, noise_test, v):
    """Evaluate subset-based rewards while the first party's time sweeps 0..9.

    For each gamma a ``Subset_Reward`` is evaluated with ``method='cig'``,
    and for each beta with ``method='dr'``. The linear search is always used
    (``bi=False``).

    Args:
        filename: Not used inside this function.
        gamma_list: Gamma values to sweep (``cig`` method).
        beta_list: Beta values to sweep (``dr`` method).
        gp: Model object forwarded to ``Subset_Reward``.
        X_splits, y_splits, noise_splits: Per-party data forwarded to
            ``Subset_Reward``.
        X_test, y_test, noise_test: Test data forwarded to
            ``get_rewarded_eval``.
        v: Valuation mapping forwarded to ``Subset_Reward``.

    Returns:
        DataFrame with one row per metric (``mnlp``, ``mse``, ``std_mnlp``,
        ``std_mse``) and run, party columns renamed ``$m_1$``..``$m_3$``, and
        the differences ``$m_1-m_2$`` and ``$m_1-m_3$``.
    """
    metric_names = ["mnlp", "mse", "std_mnlp", "std_mse"]
    use_binary_search = False
    frames = []

    def evaluate(time_aware, beta, gamma, method):
        rewarder = Subset_Reward(gp, X_splits, y_splits, noise_splits, beta, gamma, time_aware, v)
        scores = rewarder.get_rewarded_eval(X_test, y_test, noise_test, method=method, bi=use_binary_search)
        return pd.DataFrame(scores, index=metric_names)

    for gamma in tqdm(gamma_list):
        time_aware = np.zeros(len(X_splits))
        for delay in range(10):
            time_aware[0] = delay
            gamma_frame = evaluate(time_aware, None, gamma, 'cig')
            gamma_frame[r"$\gamma$"] = gamma
            gamma_frame[r"$\beta$"] = np.nan
            gamma_frame["$t$"] = delay
            frames.append(gamma_frame)

    for beta in tqdm(beta_list):
        time_aware = np.zeros(len(X_splits))
        for delay in range(10):
            time_aware[0] = delay
            beta_frame = evaluate(time_aware, beta, None, 'dr')
            beta_frame[r"$\gamma$"] = np.nan
            beta_frame[r"$\beta$"] = beta
            beta_frame["$t$"] = delay
            frames.append(beta_frame)

    party_names = {0: '$m_1$', 1: '$m_2$', 2: '$m_3$'}
    df_subset = pd.concat(frames).reset_index().rename(columns=party_names)
    first_party = df_subset['$m_1$']
    df_subset['$m_1-m_2$'] = first_party - df_subset['$m_2$']
    df_subset['$m_1-m_3$'] = first_party - df_subset['$m_3$']
    return df_subset

class Noise_Reward:
    """Realise per-party target rewards by inflating observation noise.

    For every party a target reward is looked up (``"cig"`` targets come from
    ``get_reward_gamma``, ``"dr"`` targets from ``get_reward_beta``). A
    tempering factor is then solved for with ``root_scalar`` so that the
    difference of two ``model.mi`` values equals that target, and the
    corresponding extra noise is stored per party.

    Attributes set by ``__init__``:
        target_rewards: ``{"cig": rewards or None, "dr": rewards or None}``.
        all_X, all_Y, all_noise: Row-stacked data of all parties.
        tempers, add_noise_vectors, add_noise_i_party, nonIR: One entry per
            party; ``None`` until that party has been solved (eagerly when
            ``compute_all`` is true, otherwise on first evaluation).
    """

    def __init__(self, model, Xs, ys, noises, beta, gamma, time_aware, v, method="cig", verbose=True, compute_all=False):
        """Store the data, compute target rewards and optionally solve all parties.

        Args:
            model: Object providing ``mi(X, noise)`` and
                ``evaluation(X, Y, noise, X_test, y_test, test_noise)``.
            Xs, ys, noises: Per-party inputs, targets and noise arrays.
            beta: Parameter for the ``"dr"`` rewards, or ``None`` to skip them.
            gamma: Parameter for the ``"cig"`` rewards, or ``None`` to skip them.
            time_aware: Per-party time values passed to the reward functions.
            v: Valuation mapping indexed by ``frozenset`` of party indices.
            method: Reward kind (``"cig"`` or ``"dr"``) used when solving.
            verbose: Accepted for interface compatibility; not used.
            compute_all: Solve every party's noise immediately.

        Raises:
            ValueError: If both ``beta`` and ``gamma`` are ``None``.
        """
        if beta is None and gamma is None:
            raise ValueError('At least one of beta or gamma must be provided')
        self.model = model
        self.Xs = Xs
        self.ys = ys
        self.noises = noises
        self.beta = beta
        self.gamma = gamma
        self.time_aware = time_aware
        self.v = v
        # Compute each reward kind whose parameter was supplied. Previously
        # passing both parameters left ``target_rewards`` undefined.
        self.target_rewards = {
            "cig": None if gamma is None else get_reward_gamma(self.v, self.time_aware, gamma=gamma),
            "dr": None if beta is None else get_reward_beta(self.v, self.time_aware, beta=beta),
        }
        self.method = method
        self.all_X = np.vstack(self.Xs)
        self.all_Y = np.vstack(self.ys).reshape(-1, 1)
        self.all_noise = np.vstack(self.noises)

        # Always create the per-party slots so evaluation also works when
        # compute_all is False (entries are then filled on first use).
        n_parties = len(Xs)
        self.tempers = [None] * n_parties
        self.add_noise_vectors = [None] * n_parties
        self.add_noise_i_party = [None] * n_parties
        self.nonIR = [None] * n_parties
        if compute_all:
            for party in range(n_parties):
                self._ensure_party_computed(party)

    def _ensure_party_computed(self, party):
        """Solve and cache the noise for ``party`` if that has not happened yet."""
        # A solved party always has a float tempering factor, so None marks
        # "not computed".
        if self.tempers[party] is not None:
            return
        i_temper, add_noise_vector, add_noise_i_party, nonIR = self.get_rewarded_noise_vector(party, self.method)
        self.tempers[party] = i_temper
        self.add_noise_vectors[party] = add_noise_vector
        self.add_noise_i_party[party] = add_noise_i_party
        self.nonIR[party] = nonIR

    def get_rewarded_noise_vector(self, party_i, method="cig", seed=0):
        """Solve the tempering factor and extra noise for one party.

        Args:
            party_i: Index of the party.
            method: ``"cig"`` or ``"dr"``; selects the target reward list.
            seed: Offset added to ``party_i`` when seeding NumPy's global RNG.

        Returns:
            Tuple ``(temper, noise_for_all, noise_for_own, NonIR)``.
            ``NonIR`` is true when the target is below the party's own value
            ``v[{party_i}]``. In that case ``noise_for_all`` is ``None`` and
            ``noise_for_own`` is the extra noise for the party's own data.
            Otherwise ``noise_for_own`` is ``None`` and ``noise_for_all`` is
            the extra noise over the stacked data (zero on the party's own
            rows). If the target is within ``1e-6`` of the grand-coalition
            value, ``(1., zeros, None, NonIR)`` is returned without solving.

        Raises:
            ValueError: If ``method`` is unknown or its rewards were not
                computed (the matching ``beta``/``gamma`` was ``None``).
        """
        np.random.seed(party_i + seed)
        if method not in self.target_rewards:
            raise ValueError('Method should be one of two kind - cig or dr')
        method_rewards = self.target_rewards[method]
        if method_rewards is None:
            raise ValueError(
                'No target rewards available for method {!r}; '
                'its parameter was not provided'.format(method)
            )
        target_val = method_rewards[party_i]
        v_val = self.v[frozenset({party_i})]
        NonIR = bool(target_val < v_val)
        if NonIR:
            print("non-ir values", party_i, target_val, v_val)

        # The mask marks the rows whose noise gets inflated: the other
        # parties' rows normally, the party's own rows in the non-IR case.
        party_noise_mask = []
        for j, noise_j in enumerate(self.noises):
            inflate = (j != party_i) != NonIR
            party_noise_mask.append(np.ones_like(noise_j) if inflate else np.zeros_like(noise_j))
        party_noise_mask_all = np.vstack(party_noise_mask)

        remaining = [j for j in range(len(self.noises)) if j != party_i]
        noise_remaining = np.vstack([self.noises[j] for j in remaining])
        remaining_X = np.vstack([self.Xs[j] for j in remaining])

        # Target (almost) equal to the grand-coalition value: nothing to solve.
        n_players = len(self.Xs)
        v_n = self.v[frozenset(range(n_players))]
        if target_val >= v_n - 1e-6:
            print('Exception case: too close')
            return 1., 0. * party_noise_mask_all, None, NonIR

        # Only one of the two mi terms depends on k; evaluate the constant
        # one once instead of on every root-finder iteration.
        if NonIR:
            cig_remain = self.model.mi(remaining_X, noise_remaining)

            def solve_temper(k):
                tempered = self.all_noise + party_noise_mask_all * (self.all_noise / k - self.all_noise)
                return target_val - self.model.mi(self.all_X, tempered) + cig_remain
        else:
            cig_all = self.model.mi(self.all_X, self.all_noise)

            def solve_temper(k):
                return target_val - cig_all + self.model.mi(remaining_X, noise_remaining / k)

        sol = root_scalar(solve_temper, method='toms748', bracket=(1e-16, 1)).root
        i_temper = sol if NonIR else (1 - sol)
        i_added_variance = 1 / i_temper - 1

        if NonIR:
            return i_temper, None, i_added_variance * party_noise_mask[party_i], NonIR
        return i_temper, i_added_variance * party_noise_mask_all, None, NonIR

    def i_party_reward(self, party, X_test, y_test, test_noise, num_repeats=1):
        """Evaluate the model on the noised data that ``party`` receives.

        Args:
            party: Index of the party.
            X_test, y_test, test_noise: Forwarded to ``model.evaluation``.
            num_repeats: Accepted for interface compatibility; not used.

        Returns:
            Whatever ``model.evaluation`` returns.
        """
        self._ensure_party_computed(party)
        own_extra_noise = self.add_noise_i_party[party]
        # In the "too close" early-return case no own-data noise exists even
        # if the party is flagged non-IR; fall back to the full-data vector.
        if self.nonIR[party] and own_extra_noise is not None:
            X_ = self.Xs[party]
            Y_ = self.ys[party]
            noise_ = self.noises[party] + own_extra_noise
        else:
            X_ = self.all_X
            Y_ = self.all_Y
            noise_ = self.all_noise + self.add_noise_vectors[party]
        return self.model.evaluation(X_, Y_, noise_, X_test, y_test, test_noise)

    def get_rewarded_eval(self, X_test, y_test, test_noise, party=None, num_repeats=1):
        """Evaluate one party, or all parties when ``party`` is ``None``.

        Returns:
            The single party's evaluation, or a dict mapping every party
            index to its evaluation.
        """
        if party is not None:
            return self.i_party_reward(party, X_test, y_test, test_noise, num_repeats)
        return {
            idx: self.i_party_reward(idx, X_test, y_test, test_noise, num_repeats)
            for idx in range(len(self.Xs))
        }

class Subset_Reward:
    """Realise per-party target rewards by granting a subset of others' data.

    Each party keeps its own data and additionally receives a prefix of a
    seeded random permutation of the remaining parties' points. The prefix
    is the shortest one for which ``mi(all) - mi(unselected points)``
    reaches the party's target reward.
    """

    def __init__(self, model, Xs, ys, noises, beta, gamma, time_aware, v, verbose=True):
        """Store the data and compute the target rewards.

        Args:
            model: Object providing ``mi(X, noise)`` plus ``mnlp`` and ``mse``
                taking ``(X, y, noise, X_test, y_test, noise_test)``.
            Xs, ys, noises: Per-party inputs, targets and noise arrays.
            beta: Parameter for the ``"dr"`` rewards, or ``None`` to skip them.
            gamma: Parameter for the ``"cig"`` rewards, or ``None`` to skip them.
            time_aware: Per-party time values passed to the reward functions.
            v: Valuation mapping keyed by coalitions.
            verbose: Accepted for interface compatibility; not used.

        Raises:
            ValueError: If both ``beta`` and ``gamma`` are ``None``.
        """
        if beta is None and gamma is None:
            raise ValueError('At least one of beta or gamma must be provided')
        self.model = model
        self.Xs = Xs
        self.ys = ys
        self.noises = noises
        self.beta = beta
        self.gamma = gamma
        self.time_aware = time_aware
        self.all_X = np.vstack(self.Xs)
        self.all_Y = np.vstack(self.ys)
        self.all_noise = np.vstack(self.noises)
        self.n_players = len(Xs)
        self.v = v
        # Compute each reward kind whose parameter was supplied. Previously
        # passing both parameters left ``target_rewards`` undefined.
        self.target_rewards = {
            "cig": None if gamma is None else get_reward_gamma(self.v, self.time_aware, gamma=gamma),
            "dr": None if beta is None else get_reward_beta(self.v, self.time_aware, beta=beta),
        }

    def _target_for(self, method, party_id):
        """Return the target reward of ``party_id`` for ``method`` or raise ValueError."""
        if method not in self.target_rewards:
            raise ValueError('Method should be one of two kind - cig or dr')
        method_rewards = self.target_rewards[method]
        if method_rewards is None:
            raise ValueError(
                'No target rewards available for method {!r}; '
                'its parameter was not provided'.format(method)
            )
        return method_rewards[party_id]

    def _remaining_pool(self, party_id):
        """Return the row-stacked X, y and noise of every party except ``party_id``."""
        others = [idx for idx in range(self.n_players) if idx != party_id]
        X_pool = np.vstack([self.Xs[idx] for idx in others])
        Y_pool = np.vstack([self.ys[idx] for idx in others])
        noise_pool = np.vstack([self.noises[idx] for idx in others])
        return X_pool, Y_pool, noise_pool

    def _score_reward(self, party_id, pool, selected, X_test, y_test, noise_test):
        """Return ``(mnlp, mse)`` for the party's own data plus the selected pool rows."""
        X_pool, Y_pool, noise_pool = pool
        R_xi = np.vstack((self.Xs[party_id], X_pool[selected]))
        R_yi = np.vstack((self.ys[party_id], Y_pool[selected]))
        R_noise = np.vstack((self.noises[party_id], noise_pool[selected]))
        mnlp = self.model.mnlp(R_xi, R_yi, R_noise, X_test, y_test, noise_test)
        mse = self.model.mse(R_xi, R_yi, R_noise, X_test, y_test, noise_test)
        return mnlp, mse

    def i_party_get_reward(self, party_id, X_test, y_test, noise_test, method='cig'):
        """Find the reward subset for one party by linear search and score it.

        The global NumPy RNG is seeded with ``party_id`` before the pool is
        shuffled. Points are added one at a time until the gain reaches the
        target; if it never does, the whole pool is granted.

        Returns:
            Tuple ``(mnlp, mse, 0, 0)``.

        Raises:
            ValueError: If ``method`` is unknown or its rewards were not computed.
        """
        np.random.seed(party_id)
        target_i = self._target_for(method, party_id)
        pool = self._remaining_pool(party_id)
        X_pool, Y_pool, noise_pool = pool

        id_shuffled = np.random.permutation(len(Y_pool))
        mi_all = self.model.mi(self.all_X, self.all_noise)

        chosen = len(Y_pool)
        for upto in range(len(Y_pool) + 1):
            unselected = id_shuffled[upto:]
            cig = mi_all - self.model.mi(X_pool[unselected], noise_pool[unselected])
            if cig >= target_i:
                chosen = upto
                break

        # The stacked reward arrays are only needed for the final prefix.
        mnlp, mse = self._score_reward(party_id, pool, id_shuffled[:chosen], X_test, y_test, noise_test)
        return mnlp, mse, 0, 0

    def i_binary_search(self, party_id, X_test, y_test, noise_test, method):
        """Find the reward subset for one party by binary search and score it.

        The global NumPy RNG is seeded with ``party_id + 42`` before the pool
        is shuffled. The search looks for the smallest prefix length whose
        gain reaches the target; this assumes the gain does not decrease as
        more points are selected (TODO confirm for the model in use).

        Returns:
            Tuple ``(mnlp, mse)``.

        Raises:
            ValueError: If ``method`` is unknown or its rewards were not computed.
        """
        seed = 42
        np.random.seed(party_id + seed)
        pool = self._remaining_pool(party_id)
        X_pool, Y_pool, noise_pool = pool
        mi_all = self.model.mi(self.all_X, self.all_noise)
        target_i = self._target_for(method, party_id)

        pool_size = len(Y_pool)
        id_shuffled = np.random.permutation(pool_size)

        low, high = 0, pool_size
        while low <= high:
            mid = (low + high) // 2
            unselected = id_shuffled[mid:]
            cig = mi_all - self.model.mi(X_pool[unselected], noise_pool[unselected])
            if cig < target_i:
                low = mid + 1
            else:
                high = mid - 1

        # On exit ``low`` is the smallest prefix length reaching the target
        # (or pool_size + 1 if none does). Using ``mid + 1`` here over-selected
        # by one point whenever the last probe already met the target.
        best_ind = min(low, pool_size)
        print("Final choose :", best_ind)

        return self._score_reward(party_id, pool, id_shuffled[:best_ind], X_test, y_test, noise_test)

    def i_party_get_reward_bi(self, party_id, X_test, y_test, noise_test, method):
        """Binary-search variant returning the same 4-tuple shape as ``i_party_get_reward``."""
        mnlp, mse = self.i_binary_search(party_id, X_test, y_test, noise_test, method)
        return mnlp, mse, 0, 0

    def get_rewarded_eval(self, X_test, y_test, noise_test, method='cig', bi=False):
        """Score every party's reward subset.

        Args:
            X_test, y_test, noise_test: Test data forwarded to the model.
            method: ``'cig'`` or ``'dr'``; selects the target reward list.
            bi: Use the binary search instead of the linear search.

        Returns:
            Dict mapping party index to ``(mnlp, mse, 0, 0)``.
        """
        evaluate = self.i_party_get_reward_bi if bi else self.i_party_get_reward
        return {
            party: evaluate(party, X_test, y_test, noise_test, method)
            for party in range(self.n_players)
        }
