import numpy as np
import torch
import torch.nn.functional as F
from sklearn.datasets import fetch_openml
from torchvision import datasets, transforms

from utils import *


class LogisticBanditEnvironment:
    """Synthetic logistic bandit with a configurable latent reward function.

    A hidden parameter vector ``theta_star`` (shape ``(1, 1, d)``) and a hidden
    matrix ``theta_star_matrix`` (shape ``(d, d)``) are drawn once at
    construction. Rewards are Bernoulli draws whose success probability is
    ``sigmoid(latent_reward(x))``.

    Supported ``latent_function`` values:
        * ``'h1'``: ``0.2 * <x, theta_star> ** 4``
        * ``'h2'``: ``20 * cos(<x, theta_star>)``
        * ``'h3'``: ``5 * x^T theta_star_matrix x``

    NOTE: the constructor calls ``torch.manual_seed(seed)``, which reseeds
    torch's *global* RNG.
    """

    def __init__(self, device, T=1000, d=20, K=5, S=1, latent_function='h1', cnum=500, seed=0):
        """Store the configuration and draw the hidden parameters.

        Args:
            device: torch device the hidden parameters are moved to.
            T: stored on the instance; not read by any method of this class.
            d: dimensionality of ``theta_star`` / ``theta_star_matrix``.
            K: stored on the instance; not read by any method of this class.
            S: stored on the instance; not read by any method of this class.
            latent_function: one of ``'h1'``, ``'h2'``, ``'h3'``.
            cnum: size of the context pool used by ``get_contexts``.
            seed: value passed to ``torch.manual_seed``.
        """
        self.device = device
        self.T = T
        self.d = d
        self.K = K
        self.S = S
        self.latent_function = latent_function
        self.cnum = cnum
        torch.manual_seed(seed)

        self.theta_star = self._sample_from_unit_ball(1, 1, d).to(device)
        self.theta_star_matrix = (
            self._sample_from_unit_ball(1, d, d)
            .squeeze(0)
            .to(device)
        )

    def _sample_from_unit_ball(self, T, K, d):
        """Return a ``(T, K, d)`` tensor with entries uniform in ``[-1, 1)``.

        NOTE(review): despite the name, every coordinate is sampled
        independently, so samples lie in the hypercube ``[-1, 1)^d`` rather
        than the Euclidean unit ball. Kept as-is so existing experiments are
        unaffected.
        """
        return 2 * torch.rand(T, K, d) - 1

    def get_contexts(self, T, K, d):
        """Draw ``T`` rounds of contexts from a pool of ``self.cnum`` candidates.

        A pool of shape ``(cnum, K, d)`` is sampled first; each round then
        picks one pool entry uniformly at random (with replacement).

        Returns:
            Tensor of shape ``(T, K, d)``. It is not moved to ``self.device``
            here; callers must do so before passing it to ``latent_reward``
            if the hidden parameters live on another device.
        """
        pool = self._sample_from_unit_ball(self.cnum, K, d)
        picks = torch.randint(0, self.cnum, (T,))
        return pool[picks]

    def latent_reward(self, x):
        """Evaluate the configured latent function on contexts ``x``.

        Args:
            x: tensor whose last dimension has size ``d`` and which is on the
                same device as the hidden parameters.

        Returns:
            Tensor shaped like ``x`` with the last dimension reduced to 1.

        Raises:
            ValueError: if ``self.latent_function`` is not a supported name
                (previously this silently returned ``None``).
        """
        kind = self.latent_function
        if kind in ('h1', 'h2'):
            inner_prod = torch.sum(x * self.theta_star, dim=-1, keepdim=True)
            if kind == 'h1':
                return 0.2 * (inner_prod ** 4)
            return 20 * torch.cos(inner_prod)
        if kind == 'h3':
            # Quadratic form x^T M x, evaluated per context vector.
            norm_val = torch.sum(
                torch.matmul(x, self.theta_star_matrix) * x,
                dim=-1,
                keepdim=True
            )
            return 5 * norm_val
        raise ValueError(
            f"Unknown latent_function {kind!r}; expected 'h1', 'h2' or 'h3'."
        )

    def get_rewards(self, x):
        """Sample Bernoulli rewards for contexts ``x``.

        Returns:
            ``(reward, prob)`` where ``prob = sigmoid(latent_reward(x))`` and
            ``reward`` is a Bernoulli sample drawn with those probabilities.
        """
        latent = self.latent_reward(x)
        # `sigmoid` is provided by the module-level `from utils import *`.
        prob = sigmoid(latent)
        reward = torch.bernoulli(prob)
        return reward, prob


class LogisticBanditRealEnvironment:
    """Bandit environment built from a real classification dataset.

    ``latent_function`` selects the dataset: ``'mnist'`` (10 arms),
    ``'mushroom'`` (2 arms) or ``'shuttle'`` (7 arms). Each sample becomes one
    round; arm ``i`` sees the sample's feature vector placed in the ``i``-th
    block of an otherwise zero vector, and only the arm matching the sample's
    label pays reward 1.

    The loaded dataset is resampled so that every row is drawn (with
    replacement) from the first ``cnum`` rows of the raw data.

    NOTE: the constructor reseeds the *global* torch and NumPy RNGs.
    """

    def __init__(self, device, T=1000, d=20, K=5, S=1, latent_function='mnist', cnum=500, seed=0):
        """Store the configuration, seed the RNGs and load the dataset.

        Args:
            device: torch device used for the tensors returned by
                ``get_contexts`` / ``get_rewards``.
            T, d, K, S: stored on the instance; the methods of this class take
                the number of arms and the feature size from the dataset.
            latent_function: ``'mnist'``, ``'mushroom'`` or ``'shuttle'``.
            cnum: number of leading raw samples the dataset is resampled from.
            seed: seed for ``torch.manual_seed`` and ``np.random.seed``.

        Raises:
            ValueError: if ``latent_function`` is not a supported dataset.
        """
        self.T = T
        self.d = d
        self.K = K
        self.S = S
        self.latent_function = latent_function
        self.device = device
        self.cnum = cnum
        # Seed *before* loading: the loader draws random subsample indices
        # (torch) and random categorical encodings (NumPy), so seeding
        # afterwards left the dataset itself uncontrolled by `seed`.
        torch.manual_seed(seed)
        np.random.seed(seed)
        self.original_dataset = self._download_dataset()

    def _download_dataset(self):
        """Load the dataset named by ``self.latent_function``.

        Returns:
            ``(features, labels, num_classes)`` where ``features`` is a 2-D
            float32 tensor with one row per sample, ``labels`` is a 1-D
            float32 tensor, and ``num_classes`` is an int.

        Raises:
            ValueError: for an unsupported dataset name (previously this
                silently returned ``None``).
        """
        loaders = {
            'mnist': self._load_mnist,
            'mushroom': self._load_mushroom,
            'shuttle': self._load_shuttle,
        }
        try:
            loader = loaders[self.latent_function]
        except KeyError:
            raise ValueError(
                f"Unknown latent_function {self.latent_function!r}; "
                "expected 'mnist', 'mushroom' or 'shuttle'."
            ) from None
        return loader()

    def _subsample_indices(self, n):
        """Draw ``n`` row indices uniformly from the first ``cnum`` rows."""
        # Clamp so a `cnum` larger than the dataset cannot index out of range.
        return torch.randint(0, min(self.cnum, n), (n,))

    @staticmethod
    def _rescale_rows(x):
        """Min-max rescale every row of ``x`` to ``[-1, 1]`` independently."""
        lo = x.min(dim=1, keepdim=True).values
        hi = x.max(dim=1, keepdim=True).values
        # The epsilon guards against division by zero on constant rows.
        return 2 * (x - lo) / (hi - lo + 1e-8) - 1

    def _load_mnist(self):
        """Load the MNIST training split as flattened 7x7 images (49 features)."""
        mnist_train = datasets.MNIST(root='./data', train=True, download=True,
                                     transform=transforms.ToTensor())
        # `data` / `targets` replace the deprecated `train_data` / `train_labels`.
        images = mnist_train.data
        labels = mnist_train.targets

        s_ind = self._subsample_indices(images.shape[0])
        images = images[s_ind]
        labels = labels[s_ind]

        # Downsample each image to 7x7 (the channel dim is needed by interpolate).
        images = F.interpolate(images.unsqueeze(1).float(), size=(7, 7),
                               mode='bilinear', align_corners=False)
        images = images.flatten(start_dim=1)
        images = self._rescale_rows(images)

        return images.to(torch.float32), labels.to(torch.float32), 10

    def _load_mushroom(self):
        """Load OpenML "mushroom"; label 1 marks class ``'e'``, 0 otherwise."""
        features, labels = fetch_openml("mushroom", version=1, return_X_y=True)
        # Categorical columns reject `fillna(0)` unless cast to object first.
        for col in features.select_dtypes(include='category').columns:
            features[col] = features[col].astype(object)

        features.fillna(0, inplace=True)
        features = features.to_numpy()
        labels = labels.to_numpy()

        # Encode every lowercase-letter category as one random value in
        # [-1, 1); the same letter gets the same value in every column.
        for code in range(ord('a'), ord('z') + 1):
            random_val = np.random.rand(1) * 2 - 1
            features = np.where(features == chr(code), random_val, features)

        labels = np.where(labels == 'e', 1., 0.)

        features = torch.tensor(features.astype(np.float32), dtype=torch.float32)
        labels = torch.tensor(labels.astype(np.float32), dtype=torch.float32)

        s_ind = self._subsample_indices(features.shape[0])
        return features[s_ind], labels[s_ind], 2

    def _load_shuttle(self):
        """Load OpenML "shuttle" with row-wise rescaled features.

        Labels are returned as loaded (``get_rewards`` subtracts 1 from them
        to obtain an arm index).
        """
        features, labels = fetch_openml("shuttle", version=1, return_X_y=True)
        for col in features.select_dtypes(include='category').columns:
            features[col] = features[col].astype(object)

        features.fillna(0, inplace=True)
        features = torch.tensor(features.to_numpy().astype(np.float32), dtype=torch.float32)
        labels = torch.tensor(labels.to_numpy().astype(np.float32), dtype=torch.float32)

        s_ind = self._subsample_indices(features.shape[0])
        features = self._rescale_rows(features[s_ind])
        return features, labels[s_ind], 7

    def get_contexts(self, T):
        """Build the per-arm contexts for the first ``T`` samples.

        For arm ``i`` the sample's ``d`` features occupy columns
        ``[i*d, (i+1)*d)`` and every other column is zero.

        Returns:
            Tensor of shape ``(T, K, K*d)`` on ``self.device``.
        """
        x, _, K = self.original_dataset
        _, d = x.shape
        head = x[:T, :].to(self.device)
        contexts = torch.zeros((T, K, K * d), device=self.device)  # (T, K, K*d)
        for i in range(K):
            contexts[:, i, i * d:(i + 1) * d] = head
        return contexts  # (T, K, K*d)

    def get_rewards(self, T):
        """Return one-hot rewards for the first ``T`` samples.

        The arm whose index equals the sample's class gets reward 1, all other
        arms get 0.

        Returns:
            ``(rewards, probs)``, both of shape ``(T, K, 1)`` on
            ``self.device``. ``probs`` is the *same tensor object* as
            ``rewards``.

        Raises:
            ValueError: for an unsupported dataset name.
        """
        _, y, K = self.original_dataset
        if self.latent_function in ('mnist', 'mushroom'):
            ind = y
        elif self.latent_function == 'shuttle':
            ind = y - 1  # stored labels are shifted down by one to get the arm index
        else:
            raise ValueError(
                f"Unknown latent_function {self.latent_function!r}; "
                "expected 'mnist', 'mushroom' or 'shuttle'."
            )

        rewards = torch.zeros((T, K, 1), device=self.device)
        rows = torch.arange(T, device=self.device)
        cols = ind[:T].long().to(self.device)
        # Vectorised equivalent of setting rewards[t, ind[t], 0] = 1 for each t.
        rewards[rows, cols, 0] = 1
        probs = rewards
        return rewards, probs  # (T, K, 1), (T, K, 1)
