from abc import ABC, abstractmethod
import numpy as np
from common.variables import *
from embeddings.nflows_sphere_inv import *
import os
from ol.oful_proj import *
from ol.oful_bandit import *
from games.npg import find_a_opt
from scipy.optimize import minimize_scalar

class GameInterface(ABC):
    """Abstract contract for a two-player leader/follower (Stackelberg) game.

    Concrete games must implement all three methods below; the class itself
    cannot be instantiated.
    """

    @abstractmethod
    def get_reward(self, action_leader: float, action_follower: float) -> float:
        """Return the payoff(s) for the joint action ``(action_leader, action_follower)``.

        NOTE(review): the ``-> float`` annotation does not match
        ``StackelbergGame.get_reward``, which returns a ``(leader, follower)`` pair.
        """

    @abstractmethod
    def best_response(self, action_leader: float) -> float:
        """Return the follower's action in response to ``action_leader``."""

    @abstractmethod
    def optimize_leader(self) -> float:
        """Return the leader action selected by optimisation."""


class StackelbergGame(GameInterface):
    """Noisy two-player Stackelberg game with scalar actions.

    For leader action ``a`` and follower action ``b`` the expected payoffs are::

        leader   = theta_1 * a + theta_2 * log(1 + b**2) - 0.5 * lambda_ * a**2
        follower = -alpha_1 * b**2 + alpha_2 * a * b

    :meth:`get_reward` adds independent ``N(0, STDEV_R1)`` noise to each payoff.
    Every round played through :meth:`sim_iter` is appended to the per-player
    action/reward lists and to ``history``.
    """

    def __init__(self, theta_1: float, theta_2: float, alpha_1: float, alpha_2: float, lambda_: float, T: int=BANDIT_T):
        self.theta_1 = theta_1
        self.theta_2 = theta_2
        self.alpha_1 = alpha_1
        self.alpha_2 = alpha_2
        self.lambda_ = lambda_
        self.T = T  # default number of rounds played by simulate()
        self.leader_rewards = []
        self.follower_rewards = []
        self.leader_actions = []
        self.follower_actions = []
        self.history = []  # one dict per round played via sim_iter
        self.t = 0  # number of rounds played via sim_iter

    def _expected_rewards(self, action_leader, action_follower):
        """Return the noise-free ``(leader, follower)`` payoffs of one joint action."""
        leader_mean = (
            self.theta_1 * action_leader
            + self.theta_2 * np.log(1 + action_follower**2)
            - 0.5 * self.lambda_ * action_leader**2
        )
        follower_mean = (
            self.alpha_1 * (-action_follower**2)
            + self.alpha_2 * action_leader * action_follower
        )
        return leader_mean, follower_mean

    def get_reward(self, action_leader: float, action_follower: float) -> tuple:
        """Return noisy ``(leader_reward, follower_reward)`` for one joint action.

        Draws two samples from the global NumPy RNG, leader first.
        """
        leader_mean, follower_mean = self._expected_rewards(action_leader, action_follower)
        # NOTE(review): the follower noise also uses STDEV_R1 — confirm that a
        # separate follower standard deviation was not intended.
        leader_reward = leader_mean + np.random.normal(0, STDEV_R1)
        follower_reward = follower_mean + np.random.normal(0, STDEV_R1)
        return leader_reward, follower_reward

    def best_response(self, action_leader: float, alpha_1_estimate = None, alpha_2_estimate = None) -> float:
        """Return the follower's response ``alpha_2 * a / (2 * alpha_1)``.

        This is the stationary point in ``b`` of the follower's expected payoff
        (its maximiser when ``alpha_1 > 0``). When *both* estimates are given
        they replace the true coefficients; if either is ``None`` the true
        ``alpha_1``/``alpha_2`` are used.
        """
        if alpha_1_estimate is not None and alpha_2_estimate is not None:
            return (alpha_2_estimate * action_leader) / (2 * alpha_1_estimate)
        return (self.alpha_2 * action_leader) / (2 * self.alpha_1)

    def optimize_leader(self, bounds=(-10, 10)) -> float:
        """Return the leader action in ``bounds`` maximising the expected leader payoff.

        The follower is assumed to play its true :meth:`best_response`. The
        noise-free payoff is optimised: optimising the noisy ``get_reward``
        would make the result random and consume global RNG draws.

        Args:
            bounds: ``(low, high)`` search interval for the leader action.

        Raises:
            ValueError: if the bounded scalar optimiser reports failure.
        """
        def negative_leader_payoff(a):
            return -self._expected_rewards(a, self.best_response(a))[0]

        result = minimize_scalar(negative_leader_payoff, bounds=bounds, method='bounded')
        if not result.success:
            raise ValueError("Optimization failed to find the optimal leader action.")
        return result.x

    def simulate(self, span_T=None, a=None, b=None):
        """Play ``span_T`` rounds (default ``self.T``) through :meth:`sim_iter`.

        ``a``/``b`` are passed to every round; any left as ``None`` is sampled
        afresh in each round.
        """
        if span_T is None:
            span_T = self.T
        for _ in range(span_T):
            self.sim_iter(a, b)

    def sim_iter(self, a=None, b=None):
        """Play and record one round; return ``(leader_reward, follower_reward)``.

        A missing leader action is drawn from ``U(0, A_UPPER_STACK1)`` and a
        missing follower action from ``U(0, B_UPPER_STACK1)``. Each is sampled
        independently, so an explicitly supplied action is never overwritten.
        """
        if a is None:
            a = np.random.uniform(0, A_UPPER_STACK1)
        if b is None:
            b = np.random.uniform(0, B_UPPER_STACK1)

        leader_reward, follower_reward = self.get_reward(a, b)
        self.leader_actions.append(a)
        self.follower_actions.append(b)
        self.leader_rewards.append(leader_reward)
        self.follower_rewards.append(follower_reward)

        self.history.append({
            'iteration': self.t+1,
            'leader_action': a,
            'follower_action': b,
            'leader_reward': leader_reward,
            'follower_reward': follower_reward
        })
        self.t += 1
        return leader_reward, follower_reward

    def get_rewards(self):
        """Return ``(leader_rewards, follower_rewards)`` lists of all recorded rounds."""
        return self.leader_rewards, self.follower_rewards

    def get_actions(self):
        """Return ``(leader_actions, follower_actions)`` lists of all recorded rounds."""
        return self.leader_actions, self.follower_actions

    def get_history(self):
        """Return the list of per-round record dicts."""
        return self.history

def get_joint_action_vector_r1(history, t):
    """Return ``[leader_action, follower_action]`` of ``history[t]`` as a NumPy array."""
    entry = history[t]
    leader, follower = entry['leader_action'], entry['follower_action']
    return np.array([leader, follower])

def get_leader_reward(history, t):
    """Return the leader reward stored in round record ``history[t]``."""
    record = history[t]
    return record['leader_reward']

def get_follower_reward(history, t):
    """Return the follower reward stored in round record ``history[t]``."""
    record = history[t]
    return record['follower_reward']

def objective_function_R1(a, gamma_0, gamma_1, b):
    """Return the linear-model leader reward ``a * (gamma_0 + gamma_1 * a)``.

    ``b`` is unused; it is kept only so the signature stays compatible with
    existing callers. A former no-op expression ``gamma_0 / (2 * gamma_1)``
    was removed: it had no effect except raising ZeroDivisionError when
    ``gamma_1 == 0``.
    """
    return a * (gamma_0 + gamma_1 * a)

def optimistic_leader_strategy_R1(history, lower_bound=0, upper_bound=1, kappa=KAPPA_LIU_RONG):
    """Choose a leader action optimistically from a linear reward model (R1 setting).

    Fits ``leader_reward ~ gamma_0 + gamma_1 * leader_action`` by least squares
    over ``history``. Both coefficients are then shifted optimistically by
    ``kappa * sqrt(log(t) / t)``, with ``t = len(history)``. The result of
    ``find_a_opt`` for the shifted coefficients and
    :func:`objective_function_R1` is returned.

    Args:
        history: list of round dicts with ``'leader_action'`` and ``'leader_reward'``.
        lower_bound, upper_bound: range of the uniform random action returned
            when ``history`` is empty.
        kappa: scale of the optimism bonus.
    """
    if len(history) == 0:
        # Pass low/high separately: the old ``uniform([lo, hi])[0]`` treated the list
        # as ``low`` with ``high=1.0``, silently ignoring ``upper_bound``.
        return np.random.uniform(lower_bound, upper_bound)

    leader_actions = np.array([h['leader_action'] for h in history], dtype=float)
    leader_rewards = np.array([h['leader_reward'] for h in history], dtype=float)
    design = np.column_stack((np.ones_like(leader_actions), leader_actions))

    # lstsq gives the same OLS fit as inv(X^T X) X^T y when X has full rank, but
    # does not raise LinAlgError on a rank-deficient design (e.g. one round, or
    # every round using the same leader action).
    gammas = np.linalg.lstsq(design, leader_rewards, rcond=None)[0]
    gamma_0, gamma_1 = gammas

    t = len(history)
    bonus = kappa * np.sqrt(np.log(t) / t)

    return find_a_opt(gamma_0 + bonus, gamma_1 - bonus, obj_func=objective_function_R1)

def run_gisa_stack_r1(game,
            input_dimA = INPUT_DIM_A,
            input_dimB = INPUT_DIM_B,
            output_dim = OUTPUT_DIM,
            hidden_dim = STACK_EMB_HIDDEN_DIM_1,
            flow_length = FLOW_LENGTH,
            C_t = C_T,
            ucb_obj=None):
    """Run the GISA bandit loop on ``game`` using a spherical action embedding.

    The first ``BURN_IN_T`` rounds already recorded in ``game``'s history are
    embedded with a ``SphericalMappingNetWithFlows`` model. Those rounds must
    exist, e.g. from ``game.simulate``; otherwise indexing them raises
    IndexError. Linear reward models are fitted on the embedded points.

    Each of ``BANDIT_T`` rounds then:
      1. re-fits the reward models;
      2. asks ``gisa_algorithm`` for a joint action on the manifold;
      3. decodes it to a leader action and clips it to ``[A_LOWER, A_UPPER]``;
      4. lets the follower best-respond, clipped to ``[B_LOWER, B_UPPER]``;
      5. plays the round via ``game.sim_iter`` and appends the observation.

    Args:
        game: game exposing ``get_history``, ``best_response``, ``sim_iter``,
            ``get_rewards`` and ``get_actions`` (e.g. ``StackelbergGame``).
        input_dimA, input_dimB, hidden_dim, flow_length: embedding-model sizes.
        output_dim: dimension passed to ``OFUL``.
        C_t: parameter forwarded to ``gisa_algorithm``.
        ucb_obj: unused; accepted for call-site compatibility.

    Returns:
        dict containing:
          - burn-in data;
          - fitted burn-in weights;
          - the final manifold points;
          - the game's full reward/action lists;
          - post-burn-in cumulative rewards.
    """
    weight_path = 'saved_models/model_tests_nflow_sphere.pth'
    # NOTE(review): hash_mapping is loaded but never used in this function.
    hash_mapping = load_hash_mapping(file_path='saved_models/hash_mapping.pkl')

    model = SphericalMappingNetWithFlows(input_dimA, input_dimB, hidden_dim, flow_layers=flow_length)
    # Reuse saved weights when present; otherwise persist the fresh initialisation
    # so that later runs share the same embedding.
    if os.path.isfile(weight_path):
        model.load_state_dict(torch.load(weight_path))
    else:
        model.save_weights(weight_path)

    history = game.get_history()[:BURN_IN_T]

    joint_action_history = [get_joint_action_vector_r1(history, t) for t in range(BURN_IN_T)]
    leader_rewards_history = [get_leader_reward(history, t) for t in range(BURN_IN_T)]
    follower_rewards_history = [get_follower_reward(history, t) for t in range(BURN_IN_T)]

    # Edge-pad each [a, b] to [a, a, b]: column 0 is fed to map_to_sphere as the
    # leader input and columns 1:3 as the follower input.
    joint_action_history_adj = [np.pad(v, (1, 0), mode='edge') for v in joint_action_history]

    # Stack into one ndarray first; building a tensor from a list of ndarrays is slow.
    joint_action_history_tensor = torch.tensor(np.array(joint_action_history_adj)).float()
    manifold_points = model.map_to_sphere(joint_action_history_tensor[:,0].unsqueeze(1), joint_action_history_tensor[:,1:3]).squeeze()

    leader_rewards = torch.tensor(leader_rewards_history).float()
    follower_rewards = torch.tensor(follower_rewards_history).float()

    leader_weights = lin_reg(manifold_points, leader_rewards)
    print(f'Leader Weights (Coefficients): {leader_weights}')

    follower_weights = lin_reg(manifold_points, follower_rewards)
    print(f'Follower Weights (Coefficients): {follower_weights}')

    oful = OFUL(d=output_dim, alpha=0.1, delta=0.1, lambd=1.0)

    for t in range(BANDIT_T):

        theta_A_est = lin_reg(manifold_points, leader_rewards).detach().numpy()
        theta_B_est = lin_reg(manifold_points, follower_rewards).detach().numpy()

        joint_action_manifold, _, _ = gisa_algorithm(t, theta_A_est, theta_B_est, C_t)
        joint_action_manifold_tensor = torch.tensor(joint_action_manifold).unsqueeze(0).float()

        leader_action_inv, follower_action_inv = model.decode(joint_action_manifold_tensor)

        # NOTE(review): M_x embeds the *decoded* pair, not the (clipped leader,
        # best-response follower) pair actually played below — confirm intended.
        M_x = model.map_to_sphere(leader_action_inv, follower_action_inv).squeeze(0).double()

        # Clip the leader action before computing the best response so that the
        # follower responds to the action that is actually played.
        leader_a = np.clip(leader_action_inv[0].detach().numpy()[0], A_LOWER, A_UPPER)
        follower_b = np.clip(game.best_response(leader_a), B_LOWER, B_UPPER)

        leader_reward, follower_reward = game.sim_iter(leader_a, follower_b)

        manifold_points = torch.cat((manifold_points, M_x), dim=0)
        leader_rewards = torch.cat((leader_rewards, torch.tensor([leader_reward]).double()), dim=0)
        follower_rewards = torch.cat((follower_rewards, torch.tensor([follower_reward]).double()), dim=0)

        oful.update(manifold_points, leader_rewards)

    leader_rewards_list, follower_rewards_list = game.get_rewards()
    leader_actions_list, follower_actions_list = game.get_actions()

    leader_cumulative_rewards = np.cumsum(leader_rewards_list[BURN_IN_T:])
    follower_cumulative_rewards = np.cumsum(follower_rewards_list[BURN_IN_T:])

    return_dict = {'joint_action_history': joint_action_history,
                   'leader_rewards_history': leader_rewards_history,
                   'follower_rewards_history': follower_rewards_history,
                   'manifold_points': manifold_points,
                   'leader_weights': leader_weights,
                   'follower_weights': follower_weights,
                   'leader_rewards_list': leader_rewards_list,
                   'follower_rewards_list': follower_rewards_list,
                   'leader_cumulative_rewards': leader_cumulative_rewards,
                   'follower_cumulative_rewards': follower_cumulative_rewards,
                   'leader_actions_list': leader_actions_list,
                   'follower_actions_list': follower_actions_list,}

    return return_dict


class UCBLeader:
    """UCB-style leader that picks from a uniform grid of leader actions.

    The follower always plays ``game.best_response`` to the chosen leader
    action. Only the leader's payoff from ``game.get_reward`` updates the
    per-arm running means.
    """

    def __init__(self, game: StackelbergGame, num_actions: int = UCB_DISCRETIZATION, c: float = UCB_EXPLORE):
        self.game = game
        self.c = c  # exploration coefficient
        self.num_actions = num_actions
        # Candidate leader actions: num_actions evenly spaced points in [0, A_UPPER_STACK1].
        self.actions = np.linspace(0, A_UPPER_STACK1, num_actions)
        self.action_values = np.zeros(num_actions)  # running mean reward per arm
        self.action_counts = np.zeros(num_actions)  # number of pulls per arm
        self.total_reward = 0.0  # accumulated across every simulate_ucb call
        self.cumulative_rewards_ucb = []

    def select_action(self, t):
        """Return the arm index to pull at step ``t``.

        Steps ``0 .. num_actions - 1`` pull each arm once, in order. Afterwards
        the arm maximising ``mean + c * sqrt(log(t + 1) / count)`` is chosen.
        """
        if t >= self.num_actions:
            exploration = self.c * np.sqrt(np.log(t + 1) / (self.action_counts + 1e-10))
            return np.argmax(self.action_values + exploration)
        return t

    def update_action(self, action_index, reward, t):
        """Fold ``reward`` into the running mean of arm ``action_index`` (``t`` is unused)."""
        self.action_counts[action_index] += 1
        count = self.action_counts[action_index]
        current = self.action_values[action_index]
        self.action_values[action_index] = current + (reward - current) / count

    def simulate_ucb(self, horizon):
        """Play ``horizon`` rounds and return ``total_reward`` (summed over all calls).

        One entry per round is appended to ``cumulative_rewards_ucb``. Each entry
        is the running reward sum of *this* call, which restarts from 0 on every call.
        """
        running_total = 0
        for step in range(horizon):
            arm = self.select_action(step)
            leader_action = self.actions[arm]
            follower_action = self.game.best_response(leader_action)
            reward, _ = self.game.get_reward(leader_action, follower_action)

            self.update_action(arm, reward, step)
            self.total_reward += reward

            running_total += reward
            self.cumulative_rewards_ucb.append(running_total)

        return self.total_reward

    def get_cumulative_rewards(self):
        """Return the list of per-round running reward sums."""
        return self.cumulative_rewards_ucb
