import numpy as np
import time
import pandas as pd

import math
from torch.optim import Optimizer
from torch import Tensor
from typing import List

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.distributions import Beta, Bernoulli

from kumaraswamy import KumaraswamyStable
from simple_squashed_normal import TanhNormal

from sklearn.datasets import fetch_openml
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
import pandas

import torch
from torch.utils.data import Dataset, DataLoader

# import List type
from typing import List

def stable_bernoulli_log_prob(log_p, binary_rewards):
    """Bernoulli log-likelihood of ``binary_rewards`` given log-probabilities ``log_p``.

    Computes ``r * log(p) + (1 - r) * log(1 - p)`` without leaving log space.

    ``log(1 - p) = log(1 - exp(log_p))`` uses the two-branch "log1mexp" scheme:
    ``log(-expm1(log_p))`` when ``p > 1/2`` (there ``log1p(-exp(log_p))``
    cancels catastrophically -- e.g. it is -inf in float32 for
    ``log_p = -1e-10``), and ``log1p(-exp(log_p))`` otherwise.

    Args:
        log_p: Tensor of log success probabilities (values <= 0).
        binary_rewards: Tensor of observed rewards, broadcastable with ``log_p``.

    Returns:
        Element-wise log-likelihood tensor.
    """
    threshold = -math.log(2.0)
    near_one = log_p > threshold
    # Clamp each branch's input to its well-conditioned range so the branch
    # torch.where does not select cannot inject inf/nan into the gradient.
    log_q = torch.where(
        near_one,
        torch.log(-torch.expm1(log_p.clamp(min=threshold))),
        torch.log1p(-torch.exp(log_p.clamp(max=threshold))),
    )
    return binary_rewards * log_p + (1 - binary_rewards) * log_q

class VariationalBanditEncoder(nn.Module):
    """Amortized variational posterior over per-arm Bernoulli reward probabilities.

    A shared MLP maps each context row to the two parameters of a distribution
    on [0, 1] (Beta, KumaraswamyStable, or TanhNormal squashed to [0, 1]).
    Arms are chosen by drawing one posterior sample per arm and taking the
    top-M samples. Every (arm, reward) pull is kept in a replay buffer, and the
    ELBO is computed over the whole pull history.

    Args:
        input_dim: Number of features per context row.
        hidden_layers: Widths of the hidden LeakyReLU layers.
        num_bandits: Number of arms.
        top_m: Number of arms pulled per step.
        var_post: Variational family: 'ks' (KumaraswamyStable), 'beta' (Beta)
            or 'tanh-normal' (TanhNormal).
    """

    def __init__(self, 
        input_dim: int, 
        hidden_layers: List[int], 
        num_bandits: int, 
        top_m: int, 
        var_post: str = 'ks'
        ):
        super().__init__()
        assert var_post in ['ks', 'beta', 'tanh-normal']
        self.var_post = var_post
        if var_post == 'ks':
            self.q = KumaraswamyStable
        elif var_post == 'beta':
            self.q = Beta
        elif var_post == 'tanh-normal':
            self.q = TanhNormal

        # Shared MLP applied row-wise; the output layer emits the two raw
        # parameters of the variational posterior for each row.
        layers = []
        prev_dim = input_dim
        for h_dim in hidden_layers:
            layers.append(nn.Linear(prev_dim, h_dim))
            layers.append(nn.LeakyReLU())
            prev_dim = h_dim
        layers.append(nn.Linear(prev_dim, 2))
        self.mlp = nn.Sequential(*layers)

        self.num_bandits = num_bandits
        self.top_m = top_m

        # Context/reward history used by training_step_mushroom.
        self.contexts = []
        self.rewards = []
        self.arm_indices = []
        # Per-arm pull counter (not updated by any method of this class).
        self.arm_counter = torch.zeros(num_bandits)
        # Unique arm indices pulled so far; drives the entropy term.
        self.arm_set = set()

        # Replay buffer of (arm index, reward) pairs used by training_step.
        self.replay_buffer_arms = []
        self.replay_buffer_rewards = []
        # Buffer position of each arm's most recent pull (-1 = never pulled).
        # Registered as a non-persistent buffer so `.to(device)` moves it with
        # the model (indexing a CPU tensor with on-device indices fails) while
        # leaving state_dict() unchanged.
        self.register_buffer(
            'replay_buffer_arm_last_occurance',
            torch.full((num_bandits,), -1, dtype=torch.int),
            persistent=False,
        )

        # Metrics
        self.cumulative_reward = 0.0
        self.cumulative_regret = 0.0
        self.step_count = 0
        self.start_time = time.time()

        self.metrics = []

        # Mean reward of the best top-M arms; cached on the first training_step.
        self.optimal_probs = None

    def forward_beta(self, x):
        """Return Beta (alpha, beta) per row, each softplus(.) + 1e-6 and squeezed."""
        output = self.mlp(x)  # Shape: [K, 2]
        pre_alpha, pre_beta = output.chunk(2, dim=1)
        # softplus rather than exp: exp caused unstable training (NaNs).
        alpha = F.softplus(pre_alpha) + 1e-6
        beta = F.softplus(pre_beta) + 1e-6
        return alpha.squeeze(), beta.squeeze()

    def forward_ks(self, x):
        """Return the raw (log_a, log_b) heads for KumaraswamyStable, squeezed."""
        output = self.mlp(x)
        log_a, log_b = output.chunk(2, dim=1)
        return log_a.squeeze(), log_b.squeeze()

    def forward_tanh_normal(self, x):
        """Return the raw (mu, log_stdv) heads for TanhNormal, squeezed."""
        output = self.mlp(x)
        mu, log_stdv = output.chunk(2, dim=1)
        return mu.squeeze(), log_stdv.squeeze()

    def forward(self, X):
        """Dispatch to the forward head that matches ``self.var_post``."""
        if self.var_post == 'beta':
            return self.forward_beta(X)
        elif self.var_post == 'ks':
            return self.forward_ks(X)
        elif self.var_post == 'tanh-normal':
            return self.forward_tanh_normal(X)
        else:
            raise ValueError("Invalid variational posterior. Must be one of ['beta', 'ks', 'tanh-normal']")

    def _build_distribution(self, param_1, param_2):
        """Instantiate the configured variational family from the two MLP heads."""
        if self.var_post == 'tanh-normal':
            # Squash onto [0, 1] so samples are valid Bernoulli probabilities.
            return self.q(param_1, param_2, low=0, high=1)
        return self.q(param_1, param_2)  # TODO: eventually implement low/high for ks

    def _entropy(self, distribution, num_samples):
        """Per-element entropy: closed form for beta/ks, entropy_estimate for TanhNormal."""
        if self.var_post == 'tanh-normal':
            return distribution.entropy_estimate(num_samples=num_samples)
        return distribution.entropy()

    def training_step(self, X, true_probs, entropy_estimate_samples):
        """Pull the top-M arms by posterior sampling, record rewards, return -ELBO.

        Args:
            X: Arm contexts, one row per arm ([num_bandits, input_dim]).
            true_probs: Bernoulli reward probability per arm ([num_bandits]),
                used to simulate rewards and to compute regret.
            entropy_estimate_samples: Samples for the TanhNormal entropy
                estimate (ignored for 'beta' and 'ks').

        Returns:
            Scalar loss: -(replay-buffer log-likelihood + scaled entropy).
        """
        # param_1/2 is log_a/log_b for Kumaraswamy, alpha/beta for Beta, and mu/log_stdv for TanhNormal
        var_post_param_1, var_post_param_2 = self.forward(X)  # each [num_bandits]
        distributions = self._build_distribution(var_post_param_1, var_post_param_2)
        samples = distributions.rsample()  # [num_bandits]

        # Select top-M bandits based on sampled values
        _, top_m_sample_indices = torch.topk(samples, self.top_m)
        self.arm_set.update(top_m_sample_indices.tolist())

        rewards = Bernoulli(true_probs[top_m_sample_indices]).sample()  # [M]

        # Store plain Python values so the buffer never holds (device) tensors.
        self.replay_buffer_arms.extend(top_m_sample_indices.tolist())
        self.replay_buffer_rewards.extend(rewards.tolist())

        # The M newest buffer slots are [l - M, l); record them as each selected
        # arm's last occurrence.
        l = len(self.replay_buffer_arms)
        last_occ = self.replay_buffer_arm_last_occurance
        last_occ[top_m_sample_indices.to(last_occ.device)] = torch.arange(
            l - self.top_m, l, dtype=last_occ.dtype, device=last_occ.device
        )

        # Update metrics. optimal_probs is cached from the first call, which
        # assumes true_probs stays the same across steps.
        if self.optimal_probs is None:
            self.optimal_probs = true_probs.topk(self.top_m).values
        selected_probs = true_probs[top_m_sample_indices]
        regret = (self.optimal_probs - selected_probs).mean()

        self.cumulative_reward += rewards.sum().item()
        self.cumulative_regret += regret.item()
        self.step_count += 1

        ### Compute ELBO ###
        # Log-likelihood: copy each arm's current sample to every buffer slot
        # where that arm was pulled, e.g.
        #   replay_buffer_arms = [4,   7,   7,   0,   9,   ...]
        #   samples_rb         = [s_4, s_7, s_7, s_0, s_9, ...]
        # then score every stored reward under its arm's sampled probability.
        samples_rb = samples[self.replay_buffer_arms]
        rb_rewards = torch.tensor(self.replay_buffer_rewards, dtype=torch.float32, device=X.device)
        log_prob = Bernoulli(probs=samples_rb).log_prob(rb_rewards).sum()

        # Entropy of the posterior of each unique arm pulled so far
        # (NOT one term per replay-buffer entry).
        pulled_arms = torch.tensor(list(self.arm_set), dtype=torch.long, device=X.device)
        entropy = self._entropy(
            self._build_distribution(var_post_param_1[pulled_arms], var_post_param_2[pulled_arms]),
            entropy_estimate_samples,
        )
        assert len(entropy) == len(self.arm_set), f"# entropy terms != number of unique arms pulled: len(entropy): {len(entropy)}, len(self.arm_set): {len(self.arm_set)}"

        # Entropy weight (1 / num_unique_arms) ** c. Decreasing c raises the
        # weight and so promotes exploration; c = 1.0 gives the mean entropy.
        c = 1.0
        entropy_scale = (1 / len(entropy)) ** c
        entropy = entropy_scale * entropy.sum()
        elbo = log_prob + entropy
        loss = -elbo

        # Log metrics
        metrics = {
            'reward': rewards.sum().item(), 
            'regret': regret.item(),
            'cumulative_reward': self.cumulative_reward,
            'cumulative_regret': self.cumulative_regret,
            'arms': len(self.arm_set),
            'elbo': elbo.mean().item(),
            'log_prob': log_prob.item(),
            'entropy': entropy.sum().item(),
            }
        self.metrics.append(metrics)
        return loss

    def training_step_mushroom(self, X, label, entropy_scale, entropy_estimate_samples=None):
        """One online step on a binary classification task (e.g. mushroom).

        Without gradients, this method:
          1. Draws one posterior sample for context ``X``.
          2. Predicts 1 iff the sample > 0.5.
          3. Gives reward 1 when the prediction equals ``label``.
          4. Records regret as ``|sample - label|`` and logs metrics.

        It then builds the training loss over every stored context.

        Args:
            X: A single context vector ([input_dim]); a batch dim is added here.
            label: Binary ground-truth label (scalar or 1-element tensor).
            entropy_scale: Weight on the entropy term.
            entropy_estimate_samples: Samples for the TanhNormal entropy
                estimate; required when var_post == 'tanh-normal'.

        Returns:
            Scalar loss ``-mean(log_prob + entropy_scale * entropy)``.

        Raises:
            ValueError: var_post is 'tanh-normal' and no sample count was given.
        """
        if self.var_post == 'tanh-normal' and entropy_estimate_samples is None:
            raise ValueError("entropy_estimate_samples is required when var_post == 'tanh-normal'")

        with torch.no_grad():
            ### perform arm selection here ###
            X = X.unsqueeze(0)
            param_1, param_2 = self.forward(X)
            sample = self._build_distribution(param_1, param_2).rsample()
            pred = (sample > 0.5) + 0.0
            # reshape(-1) keeps every stored reward 1-D so torch.cat below also
            # works when label is 0-d.
            reward = (pred == label).float().reshape(-1)

            # save context/reward pairs
            self.contexts.append(X)
            self.rewards.append(reward)

            # Update metrics
            regret = (sample - label).abs()

            self.cumulative_reward += reward.item()
            self.cumulative_regret += regret.item()
            self.step_count += 1

            metrics = {
                'reward': reward.item(), 
                'regret': regret.item(),
                'cumulative_reward': self.cumulative_reward,
                'cumulative_regret': self.cumulative_regret
                }
            self.metrics.append(metrics)

        ### concatenate all contexts and rewards for training ###
        X_train = torch.cat(self.contexts, dim=0).to(X.device)
        rewards_train = torch.cat(self.rewards, dim=0).to(X.device)
        param_1, param_2 = self.forward(X_train)
        distributions = self._build_distribution(param_1, param_2)

        samples_train = distributions.rsample()

        # Bernoulli likelihood
        log_prob_train = Bernoulli(samples_train).log_prob(rewards_train)
        entropy_train = self._entropy(distributions, entropy_estimate_samples)

        elbo_train = log_prob_train + entropy_scale * entropy_train  # Maximizing ELBO-like objective

        loss = -elbo_train.mean()  # Since we minimize loss
        return loss

        

###### Langevin Monte Carlo ######

# From author implementation in "Langevin Monte Carlo for Thompson Sampling", https://github.com/devzhk/LMCTS/blob/97114fc7a2160ba5d45c9ef483d2284497f81be6/algo/langevin.py
# example how to train: https://github.com/devzhk/LMCTS/blob/97114fc7a2160ba5d45c9ef483d2284497f81be6/train_utils/helper.py#L59
def lmc(params: List[Tensor],
        d_p_list: List[Tensor],
        weight_decay: float,
        lr: float):
    r"""Functional Langevin Monte Carlo update: ``param -= lr * (d_p + weight_decay * param)``.

    ``d_p_list`` holds one update direction per entry of ``params`` (in the
    optimizer: gradient plus injected noise). When ``weight_decay`` is
    non-zero, the decay term is added to each direction tensor in place.
    """
    apply_decay = weight_decay != 0
    for param, direction in zip(params, d_p_list):
        if apply_decay:
            # In place: the caller's direction tensor is modified.
            direction.add_(param, alpha=weight_decay)
        param.add_(direction, alpha=-lr)

class LangevinMC(Optimizer):
    """Langevin Monte Carlo optimizer: a gradient step plus Gaussian noise.

    For every parameter ``p`` that has a gradient, ``step`` does
    ``p <- p - lr * (grad + weight_decay * p + temp * xi)``, where
    ``xi ~ N(0, I)`` and ``temp = -sqrt(2 * beta_inv / lr) * sigma``. The
    injected noise therefore has std ``sqrt(2 * lr * beta_inv) * sigma``.
    Parameter gradients are modified in place.

    Args:
        params: Parameters (or param groups) to optimize.
        lr: Step size; must be > 0 (it appears in the denominator of temp).
        beta_inv: 1 / beta; scales the injected noise variance.
        sigma: Multiplier on the noise standard deviation.
        weight_decay: L2 penalty coefficient (stored per param group).
        device: Device on which noise is drawn; defaults to the device of the
            first parameter.
    """

    def __init__(self,
                 params,              # parameters of the model
                 lr=0.01,             # learning rate
                 beta_inv=0.01,       # 1/beta, scales the noise variance
                 sigma=1.0,           # multiplier on the noise std
                 weight_decay=1.0,    # l2 penalty
                 device=None):
        # lr == 0 would divide by zero when computing temp, so reject it as well.
        if lr <= 0:
            raise ValueError('lr must be positive')
        self.beta_inv = beta_inv
        self.lr = lr
        self.sigma = sigma
        self.temp = - math.sqrt(2 * beta_inv / lr) * sigma
        self.curr_step = 0
        defaults = dict(weight_decay=weight_decay)
        super(LangevinMC, self).__init__(params, defaults)
        if device is not None:
            self.device = device
        else:
            # Draw noise where the parameters live; a fixed 'cuda:0' default
            # broke CPU models on CUDA-capable machines.
            self.device = self.param_groups[0]['params'][0].device

    def init_map(self):
        """Give each parameter a contiguous [start, length] slice of the flat noise vector.

        Every parameter is mapped, not only those with a gradient right now,
        so a parameter that first receives a gradient after step 1 still has
        a slice (previously this raised KeyError).
        """
        self.mapping = dict()
        index = 0
        for group in self.param_groups:
            for p in group['params']:
                num_param = p.numel()
                self.mapping[p] = [index, num_param]
                index += num_param
        self.total_size = index

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one Langevin update; returns ``closure()`` if given, else None."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        self.curr_step += 1
        if self.curr_step == 1:
            self.init_map()

        lr = self.lr
        noise = self.temp * torch.randn(self.total_size, device=self.device)

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            params_with_grad = []
            d_p_list = []
            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                start, length = self.mapping[p]
                # In place: the gradient itself becomes grad + noise.
                d_p_list.append(p.grad.add_(noise[start:start + length].reshape(p.shape)))
            lmc(params_with_grad, d_p_list, weight_decay, lr)
        return loss


class LMCTS(nn.Module):
    """Neural bandit baseline refit with an externally supplied (LMC) optimizer.

    A single MLP maps each arm context to one reward logit. Each step:
      1. Pick the top-M arms by the current network output.
      2. Simulate their rewards.
      3. Refit the network on every observed (context, reward) pair for
         ``inner_num_iters`` optimizer steps with a summed BCE-with-logits loss.

    The exploration noise is expected to come from the optimizer (e.g.
    LangevinMC).
    """

    def __init__(self, 
        input_dim: int, 
        hidden_layers: List[int], 
        num_bandits: int, 
        top_m: int, 
        ):
        super().__init__()

        # Shared MLP; the final layer emits one logit per context row.
        layers = []
        prev_dim = input_dim
        for h_dim in hidden_layers:
            layers.append(nn.Linear(prev_dim, h_dim))
            layers.append(nn.LeakyReLU())
            prev_dim = h_dim
        layers.append(nn.Linear(prev_dim, 1))
        self.mlp = nn.Sequential(*layers)

        self.num_bandits = num_bandits
        self.top_m = top_m

        # keep track of the past context/reward pairs
        self.contexts = []
        self.rewards = []
        self.arm_set = set()

        # Metrics
        self.cumulative_reward = 0.0
        self.cumulative_regret = 0.0
        self.step_count = 0
        self.start_time = time.time()

        self.metrics = []

    def forward(self, X):
        """Return the reward logit (log-odds) per row, shape [N, 1]."""
        return self.mlp(X)

    def training_step(self, X, true_probs, optimizer, inner_num_iters):
        """Select top-M arms, observe rewards, and refit on the full history.

        Args:
            X: Arm contexts ([K, D]).
            true_probs: Bernoulli reward probability per arm ([K]); used to
                simulate rewards and compute regret.
            optimizer: Optimizer over ``self.mlp``'s parameters.
            inner_num_iters: Number of optimizer steps on the full history.

        If a loss is NaN, the weights are restored from a snapshot taken
        before refitting. Once more than 5 NaN iterations have occurred,
        1e-4-scaled Gaussian noise is also added. The optimizer step then
        runs without a backward pass.
        """
        # Arm selection only needs the ranking, so skip building a graph.
        with torch.no_grad():
            logits = self.forward(X).squeeze()  # [K]
        _, top_m_sample_indices = torch.topk(logits, self.top_m)

        rewards = Bernoulli(true_probs[top_m_sample_indices]).sample()  # [M]

        X_top_m = X[top_m_sample_indices]

        # save context/reward pairs
        self.contexts.append(X_top_m)
        self.rewards.append(rewards)
        self.arm_set.update(top_m_sample_indices.tolist())

        X_train = torch.cat(self.contexts, dim=0).to(X.device)
        rewards_train = torch.cat(self.rewards, dim=0).to(X.device)

        # Detached snapshot of the pre-LMC weights to fall back on if the loss
        # diverges; each tensor owns its own storage.
        last_weights = [param.detach().clone() for param in self.mlp.parameters()]
        self.mlp.train()
        nan_count = 0
        loss = None
        for _ in range(inner_num_iters):
            self.mlp.zero_grad()
            logits = self.forward(X_train)

            # logits are [N, 1]; give rewards the same trailing dim.
            if logits.ndim > rewards_train.ndim:
                rewards_train = rewards_train.unsqueeze(-1)

            # Summed (not mean) BCE, matching the reference LMCTS implementation:
            # https://github.com/devzhk/LMCTS/blob/97114fc7a2160ba5d45c9ef483d2284497f81be6/train_utils/helper.py#L265
            loss = F.binary_cross_entropy_with_logits(logits, rewards_train, reduction='sum')
            if torch.isnan(loss):
                nan_count += 1
                print("Loss is Nan!...replacing weights with pre-LMC weights")
                with torch.no_grad():
                    for param, old_param in zip(self.mlp.parameters(), last_weights):
                        # copy_ rather than `param.data = old_param.data`: aliasing
                        # let later in-place updates corrupt the snapshot.
                        param.copy_(old_param)
                    if nan_count > 5:
                        print("5 Monte Carlo steps with Nan loss. Adding N(0, 1) * 1e-4 noise to pre-MC weights.")
                        for param in self.mlp.parameters():
                            param.add_(torch.randn_like(param) * 1e-4)
            else:
                loss.backward()
            optimizer.step()  # LMC is implemented in the optimizer

        self.mlp.eval()

        # Regret against always pulling the true top-M arms.
        optimal_probs = true_probs.topk(self.top_m).values
        regret = (optimal_probs - true_probs[top_m_sample_indices]).sum()

        self.cumulative_reward += rewards.sum().item()
        self.cumulative_regret += regret.item()
        self.step_count += 1

        metrics = {
            'reward': rewards.sum().item(), 
            'regret': regret.item(),
            # Negative loss of the last inner iteration; NaN when none ran.
            'log_prob': -loss.item() if loss is not None else float('nan'),
            'arms': len(self.arm_set),
            'nan_count': nan_count,
            'cumulative_reward': self.cumulative_reward,
            'cumulative_regret': self.cumulative_regret
        }
        self.metrics.append(metrics)
        return
    


""" UCI Dataset Stuff  """

continuous_dataset = ['shuttle', 'covertype']

def sample_data(dataloader):
    """Yield batches from ``dataloader`` forever, restarting it after each pass.

    Raises:
        ValueError: If a full pass over ``dataloader`` yields nothing (an empty
            source would otherwise make this loop spin forever without yielding).
    """
    while True:
        produced = False
        for batch in dataloader:
            produced = True
            yield batch
        if not produced:
            raise ValueError("sample_data: dataloader yielded no batches")

def remove_nan(arr):
    """Return ``arr`` as a NumPy array with every row containing a NaN dropped.

    The data is routed through a pandas DataFrame and filtered with
    ``DataFrame.dropna``; a 1-D input therefore comes back as a single-column
    2-D array.
    """
    return pd.DataFrame(arr).dropna().to_numpy()

class AutoUCI(Dataset):
    """OpenML/UCI classification dataset exposed as (context, label) tensors.

    Features are one-hot encoded and each row is L2-normalized. Labels are
    integer-encoded from the ``'class'`` column.

    Args:
        name: OpenML dataset name passed to ``fetch_openml``.
        dim_context: Stored as ``self.dim_context``; not used by this class.
        num_arms: Stored as ``self.num_arms``; not used by this class.
        num_data: If not None, keep only the first ``num_data`` samples
            (applied after encoding, so the feature dimension is unchanged).
        version: OpenML dataset version.
    """

    def __init__(self, name, dim_context, num_arms, num_data=None, version='active'):
        super(AutoUCI, self).__init__()
        self.dim_context = dim_context
        self.num_arms = num_arms
        self.loaddata(name, version, num_data)

    def __getitem__(self, idx):
        """Return ``(context_row, label)`` for sample ``idx``."""
        # NOTE: rows are returned as-is. The per-arm block layout used for
        # multi-class problems ([num_arms, dim_context * num_arms], with arm
        # i's copy in slice i) is not built here; binary tasks such as mushroom
        # consume the raw row.
        return self.context[idx], self.label[idx]

    def __len__(self):
        return self.label.shape[0]

    def loaddata(self, name, version, num_data):
        """Download, encode and normalize the dataset into ``self.context`` / ``self.label``."""
        ctx = fetch_openml(name=name, version=version, data_home='data', as_frame=True)
        df = ctx.frame.dropna()

        encoder_X = OneHotEncoder(sparse_output=False, drop=None)  # dense array output
        X_encoded = encoder_X.fit_transform(df.drop(columns=['class']))

        # Integer labels for the target column.
        label_encoder_y = LabelEncoder()
        y_encoded = label_encoder_y.fit_transform(df['class'])

        norms = np.linalg.norm(X_encoded, axis=1, keepdims=True)
        # Guard all-zero rows so normalization cannot divide by zero.
        X_normalized = X_encoded / np.where(norms == 0, 1.0, norms)

        self.context = torch.tensor(X_normalized).float()
        self.label = torch.tensor(y_encoded)
        self._limit_rows(num_data)

    def _limit_rows(self, num_data):
        """Keep only the first ``num_data`` samples; ``None`` keeps all of them."""
        if num_data is not None:
            self.context = self.context[:num_data]
            self.label = self.label[:num_data]