import numpy as np
import torch
import torch.nn.functional as F
from sklearn.datasets import fetch_openml
from torchvision import datasets, transforms

from utils import *


class LogisticBanditEnvironment:
    """Synthetic logistic bandit whose latent reward is a fixed nonlinear function.

    On construction the global torch RNG is seeded with ``seed`` and three
    hidden parameters are drawn from it (in this order):

    * ``theta_star``        -- tensor of shape ``(1, 1, d)``
    * ``theta_star_matrix`` -- tensor of shape ``(d, d)``
    * ``theta_star_gram``   -- ``theta_star_matrix^T @ theta_star_matrix``

    ``latent_function`` selects the latent reward (``'h1'`` ... ``'h6'``).
    ``T``, ``d``, ``K`` and ``S`` are stored on the instance but are not read
    by any method of this class; ``cnum`` is the size of the context pool
    used by :meth:`get_contexts`.
    """

    def __init__(self, device, T=1000, d=20, K=5, S=1, latent_function='h1', cnum=500, seed=0):
        # Seeding the *global* torch generator makes the parameter draws below
        # (and any later draws by the caller) reproducible.
        torch.manual_seed(seed)

        self.device = device
        self.T, self.d, self.K, self.S = T, d, K, S
        self.latent_function = latent_function
        self.cnum = cnum

        # Draw order matters for reproducibility: vector first, then matrix.
        theta_vec = self._sample_from_unit_ball(1, 1, d)
        self.theta_star = theta_vec.to(device)

        theta_mat = self._sample_from_unit_ball(1, d, d).squeeze(0)
        self.theta_star_matrix = theta_mat.to(device)
        self.theta_star_gram = torch.matmul(
            self.theta_star_matrix.t(), self.theta_star_matrix
        )

    def _sample_from_unit_ball(self, T, K, d):
        """Return a ``(T, K, d)`` tensor with i.i.d. entries uniform on [-1, 1).

        NOTE: despite the name, every coordinate is drawn independently, so
        the samples fill the hypercube rather than the Euclidean unit ball.
        """
        return 2 * torch.rand(T, K, d) - 1

    def get_contexts(self, T, K, d):
        """Return ``T`` context sets of shape ``(K, d)`` as a ``(T, K, d)`` tensor.

        A pool of ``self.cnum`` context sets is sampled first; the ``T``
        returned sets are then picked from that pool uniformly with
        replacement.
        """
        pool = self._sample_from_unit_ball(self.cnum, K, d)
        picks = torch.randint(0, self.cnum, (T,))
        return pool.index_select(0, picks)

    @staticmethod
    def _quadratic_form(x, matrix):
        """Return ``x M x^T`` along the last axis, keeping that axis with size 1."""
        return (torch.matmul(x, matrix) * x).sum(dim=-1, keepdim=True)

    def latent_reward(self, x):
        """Evaluate the selected latent function on ``x`` (last axis = features).

        Raises:
            ValueError: if ``self.latent_function`` is not one of ``h1``..``h6``.
        """
        kind = self.latent_function
        # <x, theta*> along the feature axis; used by h1, h2, h4 and h6.
        inner_prod = (x * self.theta_star).sum(dim=-1, keepdim=True)

        if kind == 'h1':
            return 0.2 * (inner_prod ** 4)
        if kind == 'h2':
            return 20 * torch.cos(inner_prod)
        if kind == 'h3':
            return 5 * self._quadratic_form(x, self.theta_star_matrix)
        if kind == 'h4':
            return 10 * (inner_prod ** 2)
        if kind == 'h5':
            return self._quadratic_form(x, self.theta_star_gram) * 0.5
        if kind == 'h6':
            return torch.cos(3 * inner_prod)

        raise ValueError(f"Unknown latent_function: {self.latent_function}")

    def get_rewards(self, x):
        """Sample Bernoulli rewards for ``x``.

        Returns:
            ``(reward, prob)`` where ``prob = sigmoid(latent_reward(x))`` and
            ``reward`` is a Bernoulli draw with that success probability.
            ``sigmoid`` is not defined in this file; presumably it is provided
            by the ``utils`` star import.
        """
        prob = sigmoid(self.latent_reward(x))
        return torch.bernoulli(prob), prob


class LogisticBanditRealEnvironment:
    """Bandit environment built from a real multi-class classification dataset.

    ``latent_function`` names the dataset: ``'mnist'``, ``'mushroom'``,
    ``'shuttle'``, ``'magic'``, ``'banknote'`` or ``'phoneme'``.  The dataset
    is loaded once in ``__init__`` and stored in ``self.original_dataset`` as
    a triple ``(x, y, K)``: features of shape ``(N, d)``, float class labels
    of shape ``(N,)`` and the number of classes / arms.

    Each sample becomes a bandit round with ``K`` arms.  Arm ``i`` carries the
    feature vector in the ``i``-th length-``d`` slot of a ``K*d`` vector
    (zeros elsewhere); pulling the arm equal to the true class gives reward 1.

    The ``T``, ``d``, ``K`` and ``S`` constructor arguments are stored but not
    read by any method of this class.  ``cnum`` bounds how many distinct
    dataset rows are used.
    """

    def __init__(self, device, T=1000, d=20, K=5, S=1, latent_function='mnist', cnum=500, seed=0):
        self.T = T
        self.d = d
        self.K = K
        self.S = S
        self.latent_function = latent_function
        self.device = device
        self.cnum = cnum
        # Seed BEFORE loading: the loaders draw from both the torch and the
        # numpy global RNGs (row subsampling, mushroom letter encoding), so
        # seeding afterwards would leave the dataset non-reproducible.
        torch.manual_seed(seed)
        np.random.seed(seed)
        self.original_dataset = self._download_dataset()

    # ------------------------------------------------------------------
    # Shared preprocessing helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _rescale_rows(x):
        """Min-max scale every row of ``x`` (shape ``(N, d)``) into [-1, 1]."""
        min_vals = x.min(dim=1, keepdim=True).values
        max_vals = x.max(dim=1, keepdim=True).values
        # The epsilon keeps constant rows from dividing by zero.
        return 2 * (x - min_vals) / (max_vals - min_vals + 1e-8) - 1

    @staticmethod
    def _fetch_openml_arrays(name):
        """Fetch OpenML dataset ``name`` (version 1) as ``(features, target)`` numpy arrays.

        Categorical feature columns are cast to ``object`` so that missing
        values can be filled with 0 before conversion.
        """
        features, target = fetch_openml(name, version=1, return_X_y=True)
        for col in features.select_dtypes(include='category').columns:
            features[col] = features[col].astype(object)
        features.fillna(0, inplace=True)
        return features.to_numpy(), target.to_numpy()

    def _restricted_indices(self, n):
        """Draw ``n`` indices, with replacement, from the first ``min(cnum, n)`` rows.

        The result keeps the dataset length at ``n`` while limiting it to at
        most ``cnum`` distinct rows.  Clamping to ``n`` avoids out-of-range
        indices when ``cnum`` exceeds the dataset size.
        """
        return torch.randint(0, min(self.cnum, n), (n,))

    def _resample_to_tensors(self, features, target):
        """Apply :meth:`_restricted_indices` to float numpy arrays and return float32 tensors."""
        picks = self._restricted_indices(features.shape[0]).numpy()
        x = torch.tensor(features[picks], dtype=torch.float32)
        y = torch.tensor(target[picks], dtype=torch.float32)
        return x, y

    def _random_subset(self, x, y):
        """Keep ``min(cnum, N)`` rows of ``x``/``y`` chosen without replacement."""
        n = x.size(0)
        keep = torch.randperm(n)[:min(self.cnum, n)]
        return x[keep], y[keep]

    # ------------------------------------------------------------------
    # Per-dataset loaders; each returns (features, labels, num_classes)
    # ------------------------------------------------------------------
    def _load_mnist(self):
        """MNIST training images, downsampled to 7x7 and flattened."""
        mnist_train = datasets.MNIST(root='./data', train=True, download=True,
                                     transform=transforms.ToTensor())
        # ``data`` / ``targets`` replace the deprecated ``train_data`` /
        # ``train_labels`` aliases.
        images = mnist_train.data          # expected (60000, 28, 28)
        labels = mnist_train.targets

        picks = self._restricted_indices(images.shape[0])
        images = images[picks]
        labels = labels[picks]

        images = F.interpolate(images.unsqueeze(1).float(), size=(7, 7),
                               mode='bilinear', align_corners=False)
        images = images.squeeze(1).reshape(images.size(0), -1)
        images = self._rescale_rows(images)

        return images.to(torch.float32), labels.to(torch.float32), 10

    def _load_mushroom(self):
        """Mushroom: letter codes mapped to random scalars; label 1 for class ``'e'``, else 0."""
        features, target = self._fetch_openml_arrays("mushroom")

        # Every lowercase letter gets one random value in [-1, 1), shared by
        # all columns in which that letter occurs.
        for code in range(ord('a'), ord('z') + 1):
            random_val = (np.random.rand(1) * 2 - 1) * 1
            features = np.where(features == chr(code), random_val, features)

        target = np.where(target == 'e', 1., 0.)

        features = features.astype(np.float32)
        target = target.astype(np.float32)
        x, y = self._resample_to_tensors(features, target)
        return x, y, 2

    def _load_shuttle(self):
        """Shuttle: labels are kept as loaded; ``get_rewards`` shifts them by -1."""
        features, target = self._fetch_openml_arrays("shuttle")
        features = features.astype(np.float32)
        target = target.astype(np.float32)
        x, y = self._resample_to_tensors(features, target)
        return self._rescale_rows(x), y, 7

    def _load_magic(self):
        """MagicTelescope: label 1 for class ``'g'``, else 0."""
        features, target = self._fetch_openml_arrays("MagicTelescope")
        target = np.where(target == 'g', 1., 0.)
        features = features.astype(np.float32)
        target = target.astype(np.float32)
        x, y = self._resample_to_tensors(features, target)
        return self._rescale_rows(x), y, 2

    def _load_banknote(self):
        """Banknote authentication (binary)."""
        features, target = self._fetch_openml_arrays("banknote-authentication")

        # Map whatever label values the dataset uses onto 0..K-1 (same as the
        # phoneme loader) so that ``get_rewards`` can index arms with them.
        _, class_idx = np.unique(target, return_inverse=True)

        x = torch.tensor(features.astype(np.float32), dtype=torch.float32)
        y = torch.tensor(class_idx.astype(np.float32), dtype=torch.float32)
        x, y = self._random_subset(x, y)
        return self._rescale_rows(x), y, 2

    def _load_phoneme(self):
        """Phoneme: labels mapped to 0..K-1, K taken from the data."""
        features, target = self._fetch_openml_arrays("phoneme")

        classes, class_idx = np.unique(target, return_inverse=True)

        x = torch.tensor(features.astype(np.float32), dtype=torch.float32)
        y = torch.tensor(class_idx.astype(np.float32), dtype=torch.float32)
        x, y = self._random_subset(x, y)
        return self._rescale_rows(x), y, len(classes)

    def _download_dataset(self):
        """Load and preprocess the dataset named by ``self.latent_function``.

        Returns:
            ``(x, y, K)`` -- float32 features ``(N, d)``, float32 labels
            ``(N,)`` and the number of classes.

        Raises:
            ValueError: if ``self.latent_function`` is not a known dataset.
        """
        loaders = {
            'mnist': self._load_mnist,
            'mushroom': self._load_mushroom,
            'shuttle': self._load_shuttle,
            'magic': self._load_magic,
            'banknote': self._load_banknote,
            'phoneme': self._load_phoneme,
        }
        loader = loaders.get(self.latent_function)
        if loader is None:
            raise ValueError(f"Unknown latent_function: {self.latent_function}")
        return loader()

    # ------------------------------------------------------------------
    # Bandit interface
    # ------------------------------------------------------------------
    def _row_indices(self, T):
        """Return the dataset row index used for each of ``T`` rounds.

        With at least ``T`` rows the first ``T`` are used in order.  Otherwise
        rows are drawn with replacement; the draw is cached per
        ``(num_rows, T)`` so that :meth:`get_contexts` and :meth:`get_rewards`
        refer to the same rows and labels stay aligned with contexts.
        """
        x = self.original_dataset[0]
        n = x.size(0)
        if n >= T:
            return torch.arange(T, device=x.device)

        # Created lazily so instances built without __init__ also work.
        cache = getattr(self, '_resample_cache', None)
        if cache is None:
            cache = self._resample_cache = {}
        key = (n, T)
        if key not in cache:
            cache[key] = torch.randint(0, n, (T,), device=x.device)
        return cache[key]

    def get_contexts(self, T):
        """Return arm contexts for ``T`` rounds as a ``(T, K, K*d)`` tensor on ``self.device``.

        For round ``t`` and arm ``i`` the sample's features occupy columns
        ``i*d:(i+1)*d``; all other columns are zero.
        """
        x, _, K = self.original_dataset   # x: (N, d)
        _, d = x.shape

        x_T = x[self._row_indices(T)]     # (T, d)

        contexts = torch.zeros((T, K, K * d), device=self.device)
        for arm in range(K):
            contexts[:, arm, arm * d:(arm + 1) * d] = x_T
        return contexts

    def get_rewards(self, T):
        """Return one-hot rewards for ``T`` rounds.

        Returns:
            ``(rewards, probs)`` -- both are the *same* ``(T, K, 1)`` tensor on
            ``self.device`` with a 1 at the true class of each round.  Rounds
            whose label falls outside ``[0, K)`` get an all-zero row.

        Raises:
            ValueError: if ``self.latent_function`` is not a known dataset.
        """
        if self.latent_function in ('mnist', 'mushroom', 'magic',
                                    'banknote', 'phoneme'):
            offset = 0
        elif self.latent_function in ('shuttle',):
            offset = 1    # shuttle labels are stored 1-based
        else:
            raise ValueError(f"Unknown latent_function: {self.latent_function}")

        _, y, K = self.original_dataset   # y: (N,)
        y_T = y[self._row_indices(T).to(y.device)]             # (T,)

        labels = (y_T - offset).to(torch.long).to(self.device)
        in_range = (labels >= 0) & (labels < K)
        rounds = torch.nonzero(in_range).squeeze(1)

        rewards = torch.zeros((T, K, 1), device=self.device)
        rewards[rounds, labels[in_range], 0] = 1.0

        return rewards, rewards

