import numpy as np
from scipy.optimize import minimize
from common.variables import *
from embeddings.nflows_sphere_inv import *
import os
from ol.oful_proj_N import *
from ol.oful_bandit import *

# Module-level problem defaults taken from the shared configuration. They are
# read by the free functions follower_best_response() and
# leader_optimization() defined further down in this file.
n = INPUT_DIM_A
theta_A, theta_B = THETA_A, THETA_B
class StackelbergGame:
    """Leader/follower (Stackelberg) game with noisy linear-quadratic rewards.

    Both players pick n-dimensional action vectors ``a`` (leader) and ``b``
    (follower). The rewards computed by :meth:`get_reward` are

        leader:   theta_A . (a - b) - theta_A . a**2 + noise
        follower: theta_B . (a - b) - theta_B . b**2 + noise

    where each noise term is an independent draw from
    ``np.random.normal(0, STDEV_RN)``. Every round played through
    :meth:`sim_iter` is appended to the per-instance logs.
    """

    def __init__(self, n, C_A=1.0, C_B=1.0, theta_a=None, theta_b=None, T=BANDIT_T):
        """Store the game parameters and create empty logs.

        Args:
            n: Dimension of each player's action vector.
            C_A: Leader budget; used as the per-coordinate upper bound and as
                the cap on ``sum(|theta_A * a|)`` in :meth:`optimize_leader`.
            C_B: Follower budget, used the same way in :meth:`best_response`.
            theta_a: Leader reward parameter; drawn with ``np.random.rand(n)``
                when omitted.
            theta_b: Follower reward parameter; drawn with
                ``np.random.rand(n)`` when omitted.
            T: Default number of rounds played by :meth:`simulate`.
        """
        self.n = n
        self.C_A = C_A
        self.C_B = C_B
        self.T = T
        self.leader_rewards = []
        self.follower_rewards = []
        self.leader_actions = []
        self.follower_actions = []
        self.history = []
        self.t = 0  # number of rounds played so far

        # The leader parameter is resolved before the follower parameter so
        # that the order of random draws is stable when both are omitted.
        self.theta_A = theta_a if theta_a is not None else np.random.rand(n)
        self.theta_B = theta_b if theta_b is not None else np.random.rand(n)

    def _leader_utility(self, action_leader, action_follower):
        """Return the noise-free leader reward for a joint action."""
        return (np.dot(self.theta_A, action_leader - action_follower)
                - np.dot(self.theta_A, action_leader ** 2))

    def _follower_utility(self, action_leader, action_follower):
        """Return the noise-free follower reward for a joint action."""
        return (np.dot(self.theta_B, action_leader - action_follower)
                - np.dot(self.theta_B, action_follower ** 2))

    def get_reward(self, action_leader, action_follower):
        """Return ``(leader_reward, follower_reward)`` for one joint action.

        Each reward is the player's noise-free utility plus an independent
        Gaussian noise draw; the leader's noise is drawn first.
        """
        leader_reward = (self._leader_utility(action_leader, action_follower)
                         + np.random.normal(0, STDEV_RN))
        follower_reward = (self._follower_utility(action_leader, action_follower)
                           + np.random.normal(0, STDEV_RN))
        return leader_reward, follower_reward

    def best_response(self, action_leader):
        """Return the follower action maximising the noise-free follower reward.

        The maximisation is over ``0 <= b_i <= C_B`` subject to
        ``sum(|theta_B * b|) <= C_B`` and starts from the zero vector.

        Args:
            action_leader: Leader action, or a length-2 container whose first
                element is the leader action.

        Returns:
            The optimiser's solution on success; otherwise a random vector
            drawn uniformly from ``[-B_UPPER_STACK_N, B_UPPER_STACK_N]``.
        """
        # Length-2 inputs are unwrapped to their first element, as before.
        # When n == 2 a plain action vector also has length 2, so it is only
        # unwrapped if its first element is itself array-like; otherwise a
        # genuine 2-D action would be truncated to its first coordinate.
        if len(action_leader) == 2 and (self.n != 2 or np.ndim(action_leader[0]) > 0):
            action_leader = action_leader[0]

        def follower_obj(action_follower):
            return -self._follower_utility(action_leader, action_follower)

        cons = {'type': 'ineq', 'fun': lambda b: self.C_B - np.sum(np.abs(self.theta_B * b))}
        bounds = [(0.0, self.C_B)] * self.n

        res = minimize(follower_obj, np.zeros(self.n), constraints=cons, bounds=bounds)

        if res.success:
            return res.x
        # Best-effort fallback: keep the simulation running with a random action.
        return np.random.uniform(-B_UPPER_STACK_N, B_UPPER_STACK_N, size=self.n)

    def optimize_leader(self):
        """Maximise the noise-free leader reward against the follower's best response.

        The maximisation is over ``0 <= a_i <= C_A`` subject to
        ``sum(|theta_A * a|) <= C_A`` and starts from the zero vector.

        Returns:
            ``(leader_action, leader_value)``. When the optimiser fails, a
            random action from ``np.random.rand(n)`` is returned together
            with its noise-free value against the follower's best response,
            so both paths return the same shape of result.
        """
        def leader_obj(action_leader):
            action_follower = self.best_response(action_leader)
            return -self._leader_utility(action_leader, action_follower)

        cons = {'type': 'ineq', 'fun': lambda a: self.C_A - np.sum(np.abs(self.theta_A * a))}
        bounds = [(0.0, self.C_A)] * self.n

        res = minimize(leader_obj, np.zeros(self.n), constraints=cons, bounds=bounds)

        if res.success:
            return res.x, -res.fun

        # best_response returns a single action vector, not a pair, so it
        # must not be tuple-unpacked here.
        a_opt_filler = np.random.rand(self.n)
        b_br = self.best_response(a_opt_filler)
        return a_opt_filler, self._leader_utility(a_opt_filler, b_br)

    def simulate(self, span_T=None, a=None, b=None):
        """Play ``span_T`` rounds (default ``self.T``), passing ``a`` and ``b`` to every round."""
        if span_T is None:
            span_T = self.T
        for _ in range(span_T):
            self.sim_iter(a, b)

    def sim_iter(self, a=None, b=None):
        """Play one round, log it, and return ``(leader_reward, follower_reward)``.

        Args:
            a: Leader action; sampled uniformly from
                ``[-A_UPPER_STACK_N, A_UPPER_STACK_N]`` when omitted.
            b: Follower action; sampled uniformly from
                ``[-B_UPPER_STACK_N, B_UPPER_STACK_N]`` when omitted.
        """
        if a is None:
            a = np.random.uniform(-A_UPPER_STACK_N, A_UPPER_STACK_N, size=self.n)

        if b is None:
            # Only the follower action is sampled here; the leader action
            # chosen or supplied above must stay untouched.
            b = np.random.uniform(-B_UPPER_STACK_N, B_UPPER_STACK_N, size=self.n)

        leader_reward, follower_reward = self.get_reward(a, b)
        self.leader_actions.append(a)
        self.follower_actions.append(b)
        self.leader_rewards.append(leader_reward)
        self.follower_rewards.append(follower_reward)

        self.history.append({
            'iteration': self.t + 1,  # 1-based round number
            'leader_action': a,
            'follower_action': b,
            'leader_reward': leader_reward,
            'follower_reward': follower_reward
        })
        self.t += 1
        return leader_reward, follower_reward

    def get_rewards(self):
        """Return ``(leader_rewards, follower_rewards)`` logged so far."""
        return self.leader_rewards, self.follower_rewards

    def get_actions(self):
        """Return ``(leader_actions, follower_actions)`` logged so far."""
        return self.leader_actions, self.follower_actions

    def get_history(self):
        """Return the list of per-round record dictionaries."""
        return self.history

def get_joint_action_vector_rn(history, t):
    """Stack the leader and follower actions of record ``t`` into one array.

    The leader action is the first row and the follower action the second.
    """
    entry = history[t]
    stacked = [entry[key] for key in ('leader_action', 'follower_action')]
    return np.array(stacked)

def get_leader_reward(history, t):
    """Return the leader's reward stored in the history record at index ``t``."""
    entry = history[t]
    return entry['leader_reward']

def get_follower_reward(history, t):
    """Return the follower's reward stored in the history record at index ``t``."""
    entry = history[t]
    return entry['follower_reward']

def run_gisa_stack_rn(game,
            input_dimA = INPUT_DIM_A,
            input_dimB = INPUT_DIM_B,
            output_dim = OUTPUT_DIM,
            hidden_dim = STACK_EMB_HIDDEN_DIM_1,
            flow_length = FLOW_LENGTH,
            C_t = C_T,
            bandit_t = BANDIT_T,
            ucb_obj = None):
    """Run the GISA bandit loop on ``game`` after its burn-in rounds.

    The first ``BURN_IN_T`` records of ``game``'s history are mapped onto the
    manifold with the flow model and used to fit initial reward regressions.
    Then, for ``bandit_t`` rounds, a manifold point is chosen with
    ``gisa_algorithm``, decoded into a leader action, played against the
    follower's best response via ``game.sim_iter`` and added to the
    regression data.

    Args:
        game: Game whose history already holds at least ``BURN_IN_T`` rounds
            (fewer raises ``IndexError`` while reading the burn-in records).
        input_dimA, input_dimB, hidden_dim, flow_length: Passed to
            ``SphericalMappingNetWithFlows``.
        output_dim: Manifold dimension passed to ``OFUL`` and
            ``gisa_algorithm``.
        C_t: Passed unchanged to ``gisa_algorithm``.
        bandit_t: Number of bandit rounds to play and to report.
        ucb_obj: Unused; kept for call-site compatibility.

    Returns:
        A dictionary with the post-burn-in joint actions and rewards, the
        accumulated manifold points, the burn-in regression weights, the
        game's full reward/action logs and the post-burn-in cumulative
        rewards.
    """
    weight_path = 'saved_models/model_tests_nflow_sphere.pth'
    # NOTE(review): the returned mapping is not used below; the call is kept
    # in case load_hash_mapping has side effects -- confirm and remove.
    load_hash_mapping(file_path='saved_models/hash_mapping.pkl')

    model = SphericalMappingNetWithFlows(input_dimA, input_dimB, hidden_dim, flow_layers=flow_length)
    if os.path.isfile(weight_path):
        model.load_state_dict(torch.load(weight_path))
    else:
        # No saved weights yet: persist the freshly initialised ones.
        model.save_weights(weight_path)

    history = game.get_history()[:BURN_IN_T]

    joint_action_history = [get_joint_action_vector_rn(history, t) for t in range(0, BURN_IN_T)]
    leader_rewards_history = [get_leader_reward(history, t) for t in range(0, BURN_IN_T)]
    follower_rewards_history = [get_follower_reward(history, t) for t in range(0, BURN_IN_T)]

    # Index 0 along dim 1 holds the leader action, index 1 the follower action.
    joint_action_history_tensor = torch.tensor(joint_action_history).float()
    manifold_points = model.map_to_sphere(joint_action_history_tensor[:,0].unsqueeze(1), joint_action_history_tensor[:,1:3]).squeeze()

    leader_rewards = torch.tensor(leader_rewards_history).float()
    follower_rewards = torch.tensor(follower_rewards_history).float()

    # Burn-in-only regression weights; these are what the result reports.
    leader_weights = lin_reg(manifold_points, leader_rewards)
    print(f'Leader Weights (Coefficients): {leader_weights}')

    follower_weights = lin_reg(manifold_points, follower_rewards)
    print(f'Follower Weights (Coefficients): {follower_weights}')

    oful = OFUL(d=output_dim, alpha=0.1, delta=0.1, lambd=1.0)

    for t in range(bandit_t):

        # Refit on all data gathered so far (burn-in plus earlier bandit rounds).
        theta_A_est = lin_reg(manifold_points, leader_rewards).detach().numpy()
        theta_B_est = lin_reg(manifold_points, follower_rewards).detach().numpy()

        # Use the caller-supplied dimension so it agrees with the OFUL instance above.
        joint_action_manifold,_, _ = gisa_algorithm(t, theta_A_est, theta_B_est, C_t, D = output_dim)
        joint_action_manifold_tensor = torch.tensor(joint_action_manifold).unsqueeze(0).float()

        leader_action_inv, follower_action_inv = model.decode(joint_action_manifold_tensor)

        print(f'Leader Action: {leader_action_inv}')
        print(f'Follower Action: {follower_action_inv}')

        # The manifold point is computed from the decoded actions, before the
        # clipping and best-response steps below.
        M_x = model.map_to_sphere(leader_action_inv, follower_action_inv).squeeze(0).double()

        leader_a = leader_action_inv[0].detach().numpy()
        follower_b = follower_action_inv[0].detach().numpy()

        leader_a = np.clip(leader_a, A_LOWER, A_UPPER)
        follower_b = np.clip(follower_b, B_LOWER, B_UPPER)

        # NOTE(review): this replaces the decoded/clipped follower action just
        # computed, and follower_best_response reads the module-level theta_B
        # rather than game.theta_B -- confirm both are intended.
        follower_b = follower_best_response(leader_a)

        leader_reward, follower_reward = game.sim_iter(leader_a, follower_b)
        print(f'Leader Reward GISA: {leader_reward}')

        M_x = M_x.unsqueeze(0)
        manifold_points = torch.cat((manifold_points, M_x), dim=0)
        leader_rewards = torch.cat((leader_rewards, torch.tensor([leader_reward]).double()), dim=0)
        follower_rewards = torch.cat((follower_rewards, torch.tensor([follower_reward]).double()), dim=0)

        oful.update(manifold_points, leader_rewards)

    leader_rewards_list, follower_rewards_list = game.get_rewards()
    leader_actions_list, follower_actions_list = game.get_actions()

    history = game.get_history()[BURN_IN_T:]

    leader_cumulative_rewards = np.cumsum(leader_rewards_list[BURN_IN_T:])
    follower_cumulative_rewards = np.cumsum(follower_rewards_list[BURN_IN_T:])

    # Report exactly the rounds played above; indexing by the global BANDIT_T
    # would overrun or truncate the history whenever bandit_t differs from it.
    joint_action_history = [get_joint_action_vector_rn(history, t) for t in range(bandit_t)]
    leader_rewards_history = [get_leader_reward(history, t) for t in range(bandit_t)]
    follower_rewards_history = [get_follower_reward(history, t) for t in range(bandit_t)]

    return_dict = {'joint_action_history': joint_action_history,
                   'leader_rewards_history': leader_rewards_history,
                   'follower_rewards_history': follower_rewards_history,
                   'manifold_points': manifold_points,
                   'leader_weights': leader_weights,
                   'follower_weights': follower_weights,
                   'leader_rewards_list': leader_rewards_list,
                   'follower_rewards_list': follower_rewards_list,
                   'leader_cumulative_rewards': leader_cumulative_rewards,
                   'follower_cumulative_rewards': follower_cumulative_rewards,
                   'leader_actions_list': leader_actions_list,
                   'follower_actions_list': follower_actions_list,}

    return return_dict

class UCBLeaderN:
    """Upper-confidence-bound leader over a discretised grid of actions.

    The leader's action space is a regular grid with ``num_actions_per_dim``
    points per coordinate on ``[-game.C_A, game.C_A]``. Each grid point is an
    arm whose running mean reward and pull count are tracked.
    """

    def __init__(self, game: StackelbergGame, num_actions_per_dim: int = UCB_DISCRETIZATION, c: float = UCB_EXPLORE):
        """Build the action grid and zero the per-arm statistics.

        Args:
            game: Game supplying ``C_A``, ``n``, ``best_response`` and
                ``get_reward``.
            num_actions_per_dim: Grid resolution along each coordinate.
            c: Multiplier on the exploration bonus.
        """
        self.game = game
        self.c = c
        self.num_actions_per_dim = num_actions_per_dim

        # Cartesian product of the same 1-D axis repeated game.n times,
        # flattened to one action per row.
        axis_points = np.linspace(-game.C_A, game.C_A, num_actions_per_dim)
        grid = np.meshgrid(*([axis_points] * game.n))
        self.actions = np.array(grid).T.reshape(-1, game.n)
        self.num_actions = self.actions.shape[0]

        self.action_values = np.zeros(self.num_actions)
        self.action_counts = np.zeros(self.num_actions)
        self.total_reward = 0.0
        self.cumulative_rewards_ucb = []

    def select_action(self, t):
        """Return the index of the arm to play at step ``t``.

        While ``t`` is smaller than the number of arms, arm ``t`` is returned;
        afterwards the arm with the largest mean-plus-bonus score is chosen.
        """
        if t >= self.num_actions:
            # The small constant guards against division by a zero pull count.
            bonus = self.c * np.sqrt(np.log(t + 1) / (self.action_counts + 1e-10))
            return np.argmax(self.action_values + bonus)
        return t

    def update_action(self, action_index, reward):
        """Fold ``reward`` into the running mean of arm ``action_index``."""
        self.action_counts[action_index] += 1
        pulls = self.action_counts[action_index]
        error = reward - self.action_values[action_index]
        self.action_values[action_index] += error / pulls

    def simulate_ucb(self, horizon):
        """Play ``horizon`` steps against the follower's best response.

        Appends the running reward sum of this call to
        ``cumulative_rewards_ucb`` after every step and returns
        ``self.total_reward``.
        """
        running_total = 0
        for step in range(horizon):
            arm = self.select_action(step)
            leader_action = self.actions[arm]
            follower_action = self.game.best_response(leader_action)

            reward, _ = self.game.get_reward(leader_action, follower_action)

            self.update_action(arm, reward)
            self.total_reward += reward

            running_total += reward
            self.cumulative_rewards_ucb.append(running_total)

        return self.total_reward

    def get_cumulative_rewards(self):
        """Return the list of running reward sums recorded by ``simulate_ucb``."""
        return self.cumulative_rewards_ucb

def f_quadratic(a):
    """Return ``a`` raised to the second power (element-wise for arrays)."""
    return pow(a, 2)

def g_quadratic(b):
    """Return ``b`` raised to the second power (element-wise for arrays)."""
    return pow(b, 2)

def follower_best_response(a):
    """Return the follower action maximising its noise-free reward against ``a``.

    Uses the module-level ``theta_B``, ``C_B`` and ``n``. The search runs over
    ``0 <= b_i <= C_B`` subject to ``sum(|theta_B * b|) <= C_B`` and starts
    from the zero vector.

    Raises:
        ValueError: If the optimiser reports failure.
    """
    def negative_utility(b):
        utility = np.dot(theta_B, a - b) - np.dot(theta_B, g_quadratic(b))
        return -utility

    def budget_slack(b):
        # Non-negative exactly when the weighted L1 budget is respected.
        return C_B - np.sum(np.abs(theta_B * b))

    result = minimize(
        negative_utility,
        np.zeros(n),
        constraints={'type': 'ineq', 'fun': budget_slack},
        bounds=[(0.0, C_B)] * n,
    )

    if not result.success:
        raise ValueError("Follower optimization failed.")
    return result.x

def leader_optimization():
    """Maximise the noise-free leader reward against the follower's best response.

    Uses the module-level ``theta_A``, ``C_A`` and ``n``. The search runs over
    ``0 <= a_i <= C_A`` subject to ``sum(|theta_A * a|) <= C_A`` and starts
    from the zero vector.

    Returns:
        ``(leader_action, leader_value)`` where ``leader_value`` is the
        maximised objective.

    Raises:
        ValueError: If the optimiser reports failure (or if the inner
            follower optimisation raises).
    """
    def negative_utility(a):
        b_star = follower_best_response(a)
        utility = np.dot(theta_A, a - b_star) - np.dot(theta_A, f_quadratic(a))
        return -utility

    def budget_slack(a):
        # Non-negative exactly when the weighted L1 budget is respected.
        return C_A - np.sum(np.abs(theta_A * a))

    result = minimize(
        negative_utility,
        np.zeros(n),
        constraints={'type': 'ineq', 'fun': budget_slack},
        bounds=[(0.0, C_A)] * n,
    )

    if not result.success:
        raise ValueError("Leader optimization failed.")
    # The objective was negated for minimisation, so flip the sign back.
    return result.x, -result.fun
