# Standard packages
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os
from pathlib import Path
import numpy as np
import pandas as pd
from collections import defaultdict

# Importing from files

from util import get_best_subdir, make_parent_dir
#from models import BetaProcess, SequentialPredictor

#################### Prediction Evaluation and Plotting #####################

def get_prediction_metric_values(pred, click_rate, metric):
    '''
    Elementwise loss between predictions and the true click rate of each row.

    pred: tensor whose rows align with click_rate; column j holds a prediction
        for that row, and every column is compared against the same click rate.
    click_rate: 1-D tensor of targets, one per row of pred.
    metric: 'mse' (squared error) or 'joint_loss' (binary cross-entropy).

    Returns an unreduced loss tensor of shape (len(click_rate), pred.shape[1]).
    Raises AssertionError if click_rate is not 1-D, ValueError for an unknown metric.
    '''
    assert click_rate.ndim == 1
    # Broadcast each row's click rate across the prediction columns.
    targets = click_rate.unsqueeze(-1).expand(-1, pred.shape[1])
    loss_by_name = (
        ('mse', F.mse_loss),
        ('joint_loss', F.binary_cross_entropy),
    )
    for name, loss_fn in loss_by_name:
        if metric == name:
            return loss_fn(pred, targets, reduction='none')
    raise ValueError('metric must be mse, or joint_loss')


def make_plot_from_predictions(pred_dict, metric, click_rate, timesteps=20, ax=None,
                               linestyle_dict=None, dict_key=None, ylim=None, title="", skip_val=1,
                               ylogscale=False, plotkeys=None):
    '''
    Plot the mean per-timestep prediction error, with standard-error bars, for
    each entry of pred_dict.

    pred_dict: dict mapping a label to predictions. Each value is a 2-D tensor
        (rows aligned with click_rate, columns indexed by number of observations),
        or, when dict_key is given, a mapping holding such a tensor under dict_key.
    metric: 'mse' or 'joint_loss'; forwarded to get_prediction_metric_values.
    click_rate: 1-D tensor of true click rates.
    timesteps: max number of t (num obs) for plotting; None plots every column.
    ax: axes to draw on; a new figure/axes is created when None.
    linestyle_dict: label -> linestyle; defaults to '-' for every label.
    dict_key: if given, predictions are read as pred_dict[label][dict_key].
    ylim: optional (low, high) y-axis limits.
    title: axes title.
    skip_val: plot only every skip_val-th timestep.
    ylogscale: if True, use a log-scaled y axis.
    plotkeys: if given, only labels contained in it are plotted.
    '''
    assert metric in ['mse', 'joint_loss']
    if linestyle_dict is None:
        linestyle_dict = defaultdict(lambda: '-')
    K = timesteps

    if ax is None:
        _, ax = plt.subplots(1)

    for label, preds in pred_dict.items():
        if plotkeys is not None and label not in plotkeys:
            continue
        if dict_key is not None:
            preds = preds[dict_key]
        metric_values = get_prediction_metric_values(preds, click_rate, metric)

        if K is not None:
            metric_values = metric_values[:, :K]
        # Derive x from the columns actually kept, so x and y have equal length
        # when timesteps is None or a tensor has fewer than K columns.
        n_steps = metric_values.shape[1]
        metric_values = metric_values[:, ::skip_val]

        metric_means = metric_values.mean(dim=0)
        # Standard error of the mean across rows.
        metric_sd = metric_values.std(dim=0) / metric_values.shape[0]**0.5

        # Detach and move to CPU so matplotlib can convert tensors that
        # require grad or live on an accelerator.
        ax.errorbar(x=np.arange(n_steps)[::skip_val],
                    y=metric_means.detach().cpu().numpy(),
                    yerr=metric_sd.detach().cpu().numpy(),
                    label=label,
                    linestyle=linestyle_dict[label])
    ax.set_ylabel(metric.upper())
    ax.set_xlabel('Number of observations')
    ax.legend()
    if ylim is not None:
        ax.set_ylim(ylim[0], ylim[1])
    if ylogscale:
        ax.set_yscale('log')

    ax.set_title(title)

##################### General plotting ##########################

def plot_scatter(x_vals, y_vals, ax=None, alpha=1, buffer=None, s=5, title=None, xlabel=None, ylabel=None):
    '''
    Scatter y_vals against x_vals on identical x/y limits, with a faint red
    y = x reference line and equal aspect ratio.

    x_vals, y_vals: array-likes of the same length (joined with np.concatenate
        to find the shared axis range).
    ax: axes to draw on; a new figure/axes is created when None.
    alpha, s: marker transparency and size passed to ax.scatter.
    buffer: padding added below the min and above the max of all values;
        defaults to 1/80 of the combined data span.
    title, xlabel, ylabel: optional axes labels; skipped when None.
    '''
    if ax is None:
        _, ax = plt.subplots(1, 1)
    ax.scatter(x_vals, y_vals, alpha=alpha, s=s)
    all_vals = np.concatenate([x_vals, y_vals])
    lo, hi = np.min(all_vals), np.max(all_vals)
    if buffer is None:
        buffer = (hi - lo) / 80
    valrange = (lo - buffer, hi + buffer)
    ax.set_xlim(valrange)
    ax.set_ylim(valrange)
    # Two endpoints suffice for a straight line. A unit-step np.arange would
    # allocate one point per unit of data span (millions for large values).
    ax.plot(valrange, valrange, color='r', alpha=0.3)
    ax.set_aspect('equal')
    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)


#################### Custom Prediction Plotting ########################


################### Regret loading ##########################

def _parse_env_index(filename):
    '''Return the int between the first and second '=' of the name's first '.'-separated part, or None.'''
    parts = filename.split('.')[0].split('=')
    if len(parts) < 2:
        return None
    try:
        return int(parts[1])
    except ValueError:
        return None


def load_bandit_rewards_from_dgp(bandit_dir, all_bandit_envs, N_monte_carlo):
    '''
    Load per-environment bandit results saved with torch.save and stack them by env index.

    bandit_dir: directory (str or path-like) holding one file per env, named
        like '<prefix>=<env_idx>.<ext>'. Files whose name has no such index
        (e.g. '.DS_Store') are skipped.
    all_bandit_envs: sequence of environments. Only its length is used: files
        with env_idx >= len(all_bandit_envs) are ignored, and every index in
        range(len(all_bandit_envs)) must have a file.
    N_monte_carlo: expected number of runs; indices in range(N_monte_carlo)
        without a file are printed as missing.

    Returns a dict, each entry ordered by env index:
        'expected_rewards': np.array of reward_dict['expected_rewards']
        'rewards': np.array of reward_dict['rewards']
        'success_p': tensor of each file's success_p stacked along a new dim 0
        'action_arms': list of reward_dict['action_arms']
        'pred_prob': list of reward_dict['extras']

    Raises KeyError if a loaded file has no 'reward_dict' (or lacks an expected
    key), or if an index in range(len(all_bandit_envs)) has no file.
    '''
    n_envs = len(all_bandit_envs)
    results = {}
    for fname in os.listdir(bandit_dir):
        idx = _parse_env_index(fname)
        if idx is None or idx >= n_envs:
            continue

        path = os.path.join(bandit_dir, fname)
        # NOTE(review): torch.load unpickles its input; only use on trusted result files.
        c = torch.load(path)
        if 'reward_dict' not in c:
            raise KeyError(f"{path}: missing 'reward_dict' (found keys: {list(c.keys())})")
        reward_dict = c['reward_dict']
        results[idx] = {
            'expected_rewards': reward_dict['expected_rewards'],
            'rewards': reward_dict['rewards'],
            'success_p': c['success_p'],
            'action_arms': reward_dict['action_arms'],
            'pred_prob': reward_dict['extras'],
        }

    missing_idx = [idx for idx in range(N_monte_carlo) if idx not in results]
    if missing_idx:
        print('missing env_ids', ' '.join(str(x) for x in missing_idx))

    # Assembly needs every index in range(n_envs); fail with a clear message
    # instead of an anonymous KeyError from the list comprehension below.
    required_missing = [idx for idx in range(n_envs) if idx not in results]
    if required_missing:
        raise KeyError(f"no result file for env ids {required_missing} in {bandit_dir}")

    ordered = [results[idx] for idx in range(n_envs)]
    return {'expected_rewards': np.array([r['expected_rewards'] for r in ordered]),
            'rewards': np.array([r['rewards'] for r in ordered]),
            'success_p': torch.cat([r['success_p'].unsqueeze(0) for r in ordered], dim=0),
            'action_arms': [r['action_arms'] for r in ordered],
            'pred_prob': [r['pred_prob'] for r in ordered]}
