import numpy as np
from scipy.stats import truncnorm
import matplotlib.pyplot as plt
from scipy.optimize import minimize_scalar
from scipy.stats import norm as scipy_norm
from common.variables import *
from embeddings.nflows_sphere_inv import *
import os
from ol.oful_proj import *
from ol.oful_bandit import *

class NewsvendorPricingGame:
    """Repeated two-player newsvendor pricing game with noisy linear demand.

    Every round the leader chooses an action ``a`` and the follower chooses
    an order quantity ``b`` and a price ``p``.  Demand is sampled from a
    normal distribution with mean ``gamma_0 + gamma_1 * p`` and standard
    deviation ``sigma`` and then floored at zero.  The leader's reward is
    ``a * b``; the follower's reward is ``p * min(demand, b) - b * a``.

    Strategies are callables with the signatures
    ``leader_strategy(history) -> a`` and
    ``follower_strategy(history, a) -> (b, p)``, where ``history`` is the
    list of per-round dicts recorded by :meth:`sim_iter`.
    """

    def __init__(self, gamma_0, gamma_1, sigma, T, leader_strategy=None, follower_strategy=None):
        """Store the demand parameters, horizon and (optional) strategies.

        Args:
            gamma_0: Demand intercept; must be positive.
            gamma_1: Demand slope with respect to price; must be negative.
            sigma: Standard deviation of the demand noise.
            T: Default number of rounds played by :meth:`simulate`.
            leader_strategy: Callable ``history -> a``; may be set later.
            follower_strategy: Callable ``(history, a) -> (b, p)``; may be
                set later.

        Raises:
            AssertionError: If ``gamma_0 <= 0`` or ``gamma_1 >= 0``.
        """
        assert gamma_0 > 0 and gamma_1 < 0, "Ensure gamma_0 > 0 and gamma_1 < 0"
        self.gamma_0 = gamma_0
        self.gamma_1 = gamma_1
        self.sigma = sigma
        self.T = T
        self.leader_strategy = leader_strategy
        self.follower_strategy = follower_strategy
        # Per-round records, all appended to in lock-step by sim_iter().
        self.leader_rewards = []
        self.follower_rewards = []
        self.leader_actions = []
        self.follower_actions = []
        self.history = []
        # Number of rounds played so far.
        self.t = 0

    def generate_demand(self, p):
        """Sample one demand realisation at price ``p``, floored at zero."""
        lower, upper = 0, np.inf
        exp_demand = self.gamma_0 + self.gamma_1 * p
        uncut_demand = np.random.normal(exp_demand, self.sigma)
        return np.clip(uncut_demand, lower, upper)

    def simulate(self, span_T=None, a=None, b=None, p=None):
        """Play ``span_T`` rounds (``self.T`` when omitted).

        ``a``, ``b`` and ``p`` are forwarded unchanged to :meth:`sim_iter`
        on every round.
        """
        if span_T is None:
            span_T = self.T
        for _ in range(span_T):
            self.sim_iter(a, b, p)

    def sim_iter(self, a=None, b=None, p=None):
        """Play a single round and record it.

        Args:
            a: Leader action.  When ``None`` both strategies are queried and
                any supplied ``b``/``p`` are ignored.
            b: Follower order quantity.
            p: Follower price.

        When ``a`` is given but ``b`` and/or ``p`` is ``None``, the missing
        follower choice(s) are taken from ``follower_strategy(history, a)``
        instead of failing on arithmetic with ``None``.

        Returns:
            ``(leader_reward, follower_reward)`` for this round.
        """
        if a is None:
            a = self.leader_strategy(self.history)
            b, p = self.follower_strategy(self.history, a)
        elif b is None or p is None:
            # Leader action was fixed by the caller; let the follower respond
            # for whatever part of its action was not supplied.
            b_response, p_response = self.follower_strategy(self.history, a)
            if b is None:
                b = b_response
            if p is None:
                p = p_response

        # generate_demand() already floors the sample at zero.
        d = self.generate_demand(p)
        leader_reward = a * b
        # Sales are capped by the ordered quantity; the order costs a per unit.
        follower_reward = p * min(d, b) - b*a
        self.leader_actions.append(a)
        self.follower_actions.append((b, p))
        self.leader_rewards.append(leader_reward)
        self.follower_rewards.append(follower_reward)
        self.history.append({
            'iteration': self.t+1,
            'leader_action': a,
            'demand': d,
            'follower_action': {'b': b, 'p': p},
            'leader_reward': leader_reward,
            'follower_reward': follower_reward
        })
        self.t += 1
        return leader_reward, follower_reward

    def get_rewards(self):
        """Return ``(leader_rewards, follower_rewards)`` lists, one entry per round."""
        return self.leader_rewards, self.follower_rewards

    def get_actions(self):
        """Return ``(leader_actions, follower_actions)``; follower entries are ``(b, p)`` tuples."""
        return self.leader_actions, self.follower_actions

    def get_history(self):
        """Return the list of per-round record dicts."""
        return self.history

def objective_function(a, gamma_0_opt, gamma_1_opt):
    """Evaluate the leader objective ``a * F^{-1}(fractile)`` at action ``a``.

    The fractile is ``1 - 2a / (gamma_0_opt / (-gamma_1_opt) + a)`` and
    ``F^{-1}`` is the quantile function of a unit-variance normal centred at
    ``gamma_0_opt / (-2 * gamma_1_opt)``.

    Returns:
        ``a`` times the quantile, or ``-np.inf`` when the fractile is
        ``<= 0`` or ``>= 1``.
    """
    centre = gamma_0_opt / (-2 * gamma_1_opt)
    price_intercept = gamma_0_opt / (-1*gamma_1_opt)
    fractile = 1 - (2 * a) / (price_intercept + a)
    # Outside (0, 1) the quantile is unbounded/undefined: mark as infeasible.
    if fractile <= 0 or fractile >= 1:
        return -np.inf
    quantile = scipy_norm.ppf(fractile, loc=centre)
    return a * quantile

def find_a_opt(gamma_0_opt, gamma_1_opt, lower_bound=0, upper_bound=1, obj_func=objective_function):
    """Maximise ``obj_func(a, gamma_0_opt, gamma_1_opt)`` over ``a`` in the given bounds.

    Uses SciPy's bounded scalar minimiser on the negated objective.

    Returns:
        The maximising ``a``, or ``None`` if the optimiser reports failure.
    """
    def negated_objective(a):
        return -obj_func(a, gamma_0_opt, gamma_1_opt)

    outcome = minimize_scalar(
        negated_objective,
        bounds=(lower_bound, upper_bound),
        method='bounded',
    )
    if not outcome.success:
        return None
    return outcome.x

def lin_reg(manifold_points, rewards, reg_lambda=1e-6):
    """Fit ridge-regularised least-squares weights for ``rewards ~ manifold_points``.

    Samples whose feature row or reward contains NaN/Inf are dropped (a
    warning reports how many) and the weights are computed as
    ``pinv(X^T X + reg_lambda * I) @ X^T y``.

    Args:
        manifold_points: 2-D torch tensor of features, one sample per row.
        rewards: torch tensor of targets aligned with the rows of
            ``manifold_points``.
        reg_lambda: Ridge coefficient added to the diagonal of ``X^T X``.

    Returns:
        The weight tensor with singleton dimensions squeezed out.
    """
    # Imported locally: this module has no explicit top-level `import warnings`.
    import warnings

    mask = ~(torch.isnan(manifold_points).any(dim=1) | torch.isnan(rewards) |
             torch.isinf(manifold_points).any(dim=1) | torch.isinf(rewards))
    n_dropped = manifold_points.size(0) - mask.sum().item()
    if n_dropped > 0:
        warnings.warn(f"Removed {n_dropped} rows due to NaNs or Infs.")
    manifold_points_clean = manifold_points[mask]
    rewards_clean = rewards[mask]
    XtX = manifold_points_clean.T @ manifold_points_clean
    # Build the identity with the same dtype/device as X^T X so the sum does
    # not fail for non-CPU inputs or silently mix precisions.
    identity = torch.eye(XtX.size(0), dtype=XtX.dtype, device=XtX.device)
    XtX_reg = XtX + reg_lambda * identity
    XtX_inv = torch.pinverse(XtX_reg)
    XtY = manifold_points_clean.T @ rewards_clean
    weights = XtX_inv @ XtY
    return weights.squeeze()

def optimistic_leader_strategy(history, lower_bound=0, upper_bound=1, kappa=KAPPA_LIU_RONG):
    """Choose the leader action by optimising against optimistic demand estimates.

    With no history the action is drawn uniformly from
    ``[lower_bound, upper_bound)``.  Otherwise observed demand is regressed
    on ``[1, leader_action]`` with ``lin_reg``, the intercept is raised and
    the slope lowered by ``kappa * sqrt(log(t) / t)`` (``t`` = number of
    recorded rounds), and ``find_a_opt`` maximises the leader objective for
    those parameters over the same bounds.

    Args:
        history: List of per-round dicts with at least the keys
            ``'demand'`` and ``'leader_action'``.
        lower_bound: Smallest admissible leader action.
        upper_bound: Largest admissible leader action.
        kappa: Scale of the optimism bonus.

    Returns:
        The chosen action, or ``None`` when ``find_a_opt`` reports failure.
    """
    if not history:
        # np.random.uniform([lo, hi]) would treat the list as `low` and keep
        # `high` at its default of 1.0, ignoring upper_bound; pass both ends.
        return np.random.uniform(lower_bound, upper_bound)
    demand_history = [h['demand'] for h in history]
    # NOTE(review): the regressor is the leader's action although the game
    # samples demand from the follower's price `p` - confirm this is intended.
    price_history = [h['leader_action'] for h in history]
    x_reg = np.column_stack((np.ones(len(price_history)), price_history))
    x_reg_tensor = torch.tensor(x_reg).float()
    demand_history_tensor = torch.tensor(demand_history).float()
    gamma_0, gamma_1 = lin_reg(x_reg_tensor, demand_history_tensor)
    t = len(history)
    # Confidence width shrinking with the number of observations.
    gamma_lp_bound = kappa * np.sqrt(np.log(t) / t)
    gamma_0_optimistic = gamma_0 + gamma_lp_bound
    gamma_1_optimistic = gamma_1 - gamma_lp_bound
    # Search over the caller's bounds rather than find_a_opt's default [0, 1].
    return find_a_opt(gamma_0_optimistic, gamma_1_optimistic,
                      lower_bound=lower_bound, upper_bound=upper_bound)

def reactive_follower_liu_rong_strategy(history, a, b_lower=B_LOWER, b_upper=B_UPPER, p_lower=P_LOWER, p_upper=P_UPPER, kappa=KAPPA_LIU_RONG):
    """Follower response ``(b, p)`` based on optimistic demand-parameter estimates.

    With no history both components are drawn at random.  Otherwise observed
    demand is regressed on ``[1, leader_action]`` with ``lin_reg``, the
    intercept is raised and the slope lowered by
    ``kappa * sqrt(log(t) / t)``, the price is set to
    ``gamma_0 / (2 * -gamma_1)`` and the order quantity to a normal quantile
    around the implied mean demand, clipped to ``[b_lower, b_upper]``.

    Args:
        history: List of per-round dicts with at least the keys
            ``'demand'`` and ``'leader_action'``.
        a: The leader's action for this round.
        b_lower: Lower clip for the order quantity.
        b_upper: Upper clip for the order quantity.
        p_lower: Accepted for signature parity; currently unused.
        p_upper: Accepted for signature parity; currently unused.
        kappa: Scale of the optimism bonus.

    Returns:
        Tuple ``(b, p)``.
    """
    if not history:
        # np.random.uniform([lo, hi]) treats the list as `low` and leaves
        # `high` at 1.0, so the *_UPPER_RAND constants were being ignored.
        # NOTE(review): the quantity is bounded by A_UPPER_RAND and the price
        # by B_UPPER_RAND - the constants look swapped relative to the (b, p)
        # return order; confirm against their definitions.
        b_init = np.random.uniform(a, A_UPPER_RAND)
        p_init = np.random.uniform(0, B_UPPER_RAND)
        return b_init, p_init
    demand_history = [h['demand'] for h in history]
    price_history = [h['leader_action'] for h in history]
    x_reg = np.column_stack((np.ones(len(price_history)), price_history))
    x_reg_tensor = torch.tensor(x_reg).float()
    demand_history_tensor = torch.tensor(demand_history).float()
    gamma_0, gamma_1 = lin_reg(x_reg_tensor, demand_history_tensor)
    t = len(history)
    # Confidence width shrinking with the number of observations.
    gamma_lp_bound = kappa * np.sqrt(np.log(t) / t)
    gamma_0_optimistic = gamma_0 + gamma_lp_bound
    gamma_1_optimistic = gamma_1 - gamma_lp_bound
    # NOTE(review): objective_function returns a * quantile (or -inf), not a
    # probability, yet the value is fed to ppf below as a fractile; inputs
    # above 1 make ppf return NaN.  Confirm which quantity was intended.
    cf = objective_function(a, gamma_0_optimistic, gamma_1_optimistic)
    p = gamma_0_optimistic/(2*(-gamma_1_optimistic))
    mean_demand = gamma_0_optimistic + gamma_1_optimistic * p
    if cf > 0:
        b = scipy_norm.ppf(cf, loc=mean_demand)
    else:
        b = b_lower
    b = np.clip(b, b_lower, b_upper)
    return b, p

def reactive_follower_strategy(history, a, b_lower=0, b_upper=1, p_lower=0, p_upper=1):
    """Deterministic follower rule that reacts linearly to the leader action.

    The order quantity is ``b_upper - a`` and the price is
    ``p_lower + a / 2``, each clipped to its own ``[lower, upper]`` range.
    ``history`` is not used.

    Returns:
        Tuple ``(b, p)``.
    """
    unclipped_quantity = b_upper - a
    unclipped_price = p_lower + a / 2
    quantity = np.clip(unclipped_quantity, b_lower, b_upper)
    price = np.clip(unclipped_price, p_lower, p_upper)
    return quantity, price

def reactive_riskless_follower_strategy(history, a, b_lower=0, b_upper=np.inf, p_lower=0, p_upper=1):
    """Follower response computed from the module-level ``GAMMA_0``/``GAMMA_1``.

    The price is the larger of ``(GAMMA_0 / -GAMMA_1 + a) / 2`` and
    ``a + 0.05``; the order quantity is the ``(p - a) / p`` quantile of a
    unit-variance normal centred at the expected demand
    ``GAMMA_0 + GAMMA_1 * p``, clipped to ``[b_lower, b_upper]``.

    ``history``, ``p_lower`` and ``p_upper`` are accepted but not used.

    Returns:
        Tuple ``(b, p)``.
    """
    # Default uses np.inf: the np.Inf alias was removed in NumPy 2.0 and would
    # raise AttributeError when this function is defined.
    p = np.max([((GAMMA_0/(-1*GAMMA_1))+a)/2, a+0.05])
    d_exp = GAMMA_0 + GAMMA_1 * p
    # Fraction of the price kept after paying a per unit.
    cf = (p-a)/p
    b = scipy_norm.ppf(cf, loc=d_exp)
    b = np.clip(b, b_lower, b_upper)
    # ppf yields NaN for fractiles outside [0, 1]; fall back to the upper bound.
    if np.isnan(b):
        b = b_upper
    return b, p

def random_leader_strategy(history, lower_bound=0, upper_bound=1):
    """Draw a leader action uniformly at random between the two bounds.

    ``history`` is ignored; it is accepted so the function fits the leader
    strategy call signature.
    """
    action = np.random.uniform(low=lower_bound, high=upper_bound)
    return action

def random_follower_strategy(history, a, b_lower=0, b_upper=1, p_lower=0, p_upper=1):
    """Draw a random follower response ``(b, p)``.

    The order quantity is uniform on ``[b_lower, b_upper)`` and the price is
    uniform between the leader action ``a`` and ``p_upper``.  The quantity is
    sampled first.  ``history`` and ``p_lower`` are not used.

    Returns:
        Tuple ``(b, p)``.
    """
    order_quantity = np.random.uniform(low=b_lower, high=b_upper)
    retail_price = np.random.uniform(low=a, high=p_upper)
    return order_quantity, retail_price

def run_liu_rong(game_liu_rong):
    """Simulate ``BANDIT_T`` rounds with the optimistic leader and Liu-Rong follower.

    The two strategies are installed on ``game_liu_rong`` (replacing any it
    already had) before simulating.  The first ``BURN_IN_T`` recorded rounds
    are dropped from every returned series, and rewards are accumulated only
    over the rounds that remain.

    Returns:
        Tuple ``(leader_cumulative_rewards, follower_cumulative_rewards,
        leader_actions, follower_actions, history)``.
    """
    game_liu_rong.leader_strategy = optimistic_leader_strategy
    game_liu_rong.follower_strategy = reactive_follower_liu_rong_strategy
    game_liu_rong.simulate(BANDIT_T)

    after_burn_in = slice(BURN_IN_T, None)
    all_leader_rewards, all_follower_rewards = game_liu_rong.get_rewards()
    all_leader_actions, all_follower_actions = game_liu_rong.get_actions()
    full_history = game_liu_rong.get_history()

    return (
        np.cumsum(all_leader_rewards[after_burn_in]),
        np.cumsum(all_follower_rewards[after_burn_in]),
        all_leader_actions[after_burn_in],
        all_follower_actions[after_burn_in],
        full_history[after_burn_in],
    )

def run_random(game_random):
    """Simulate ``BANDIT_T`` rounds with both players acting at random.

    The random leader and follower strategies are installed on
    ``game_random`` (replacing any it already had) before simulating.  The
    first ``BURN_IN_T`` recorded rounds are dropped from every returned
    series, and rewards are accumulated only over the rounds that remain.

    Returns:
        Tuple ``(leader_cumulative_rewards, follower_cumulative_rewards,
        leader_actions, follower_actions, history)``.
    """
    game_random.leader_strategy = random_leader_strategy
    game_random.follower_strategy = random_follower_strategy
    game_random.simulate(BANDIT_T)

    after_burn_in = slice(BURN_IN_T, None)
    all_leader_rewards, all_follower_rewards = game_random.get_rewards()
    all_leader_actions, all_follower_actions = game_random.get_actions()
    full_history = game_random.get_history()

    return (
        np.cumsum(all_leader_rewards[after_burn_in]),
        np.cumsum(all_follower_rewards[after_burn_in]),
        all_leader_actions[after_burn_in],
        all_follower_actions[after_burn_in],
        full_history[after_burn_in],
    )
