from typing import List

import numpy as np
import torch
from gymnasium import spaces

from base_env import BaseEnv

from .corrupt import get_corrupt_params

# Torch device for tensor work: the GPU when CUDA is available, otherwise the CPU.
# NOTE(review): not referenced by the code in this module section.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def sample(dim, H, var, type="uniform"):
    """Build a ``BanditEnv`` with ``dim`` randomly drawn arm means.

    Arm means are drawn i.i.d. from U(0, 1) when ``type == "uniform"`` and
    from Beta(1, 1) when ``type == "bernoulli"``.

    Args:
        dim: number of arms.
        H: context horizon, forwarded to ``BanditEnv``.
        var: reward-noise parameter, forwarded to ``BanditEnv``.
        type: reward model name, also forwarded to ``BanditEnv``.

    Returns:
        The newly constructed ``BanditEnv``.

    Raises:
        NotImplementedError: for any other ``type``; no env is built.
    """
    if type not in ("uniform", "bernoulli"):
        raise NotImplementedError

    if type == "uniform":
        arm_means = np.random.uniform(0, 1, dim)
    else:
        arm_means = np.random.beta(1, 1, dim)

    return BanditEnv(arm_means, H, var=var, type=type)


def sample_linear(arms, H, var):
    """Build a ``LinearBanditEnv`` with a random Gaussian parameter vector.

    ``theta`` has one entry per column of ``arms`` (``arms.shape[1]``). Each
    entry is drawn from N(0, 1) and divided by the square root of that count.

    Args:
        arms: arm feature matrix (indexed as 2-D via ``arms.shape[1]``).
        H: context horizon, forwarded to ``LinearBanditEnv``.
        var: reward-noise parameter, forwarded to ``LinearBanditEnv``.

    Returns:
        The newly constructed ``LinearBanditEnv``.
    """
    n_features = arms.shape[1]
    normaliser = np.sqrt(n_features)
    theta = np.random.normal(0, 1, n_features) / normaliser
    return LinearBanditEnv(theta, arms, H, var=var)


class BanditEnv(BaseEnv):
    """Single-state multi-armed bandit with optional reward poisoning.

    The observation is always the constant ``np.array([1])``. An action is a
    length-``dim`` vector, and the pulled arm is its ``argmax``. Each episode
    is one step long (``H = 1``). ``H_context`` is the number of transitions
    in the context.

    Poisoning parameters come from ``get_corrupt_params``. Poisoning is
    applied in ``transit`` on the context steps listed in ``corrupted_steps``.
    ``context_step`` counts steps since construction or the last ``reset2()``.
    """

    def __init__(self, means, H, var=0.0, type="uniform", corrupt=""):
        """
        Args:
            means: per-arm mean rewards (must support ``.shape`` and ``len``).
            H: context horizon, stored as ``H_context``.
            var: passed as the ``scale`` (standard deviation) of
                ``np.random.normal`` despite its name.
            type: ``"uniform"`` (mean plus Gaussian noise) or ``"bernoulli"``
                (0/1 reward with the arm mean as success probability).
            corrupt: corruption spec forwarded to ``get_corrupt_params``.
                Presumably ``""`` yields ``corrupt_type == ""``, i.e. no
                poisoning; confirm in ``corrupt.py``.
        """
        opt_a_index = np.argmax(means)
        self.means = means
        self.opt_a_index = opt_a_index
        # One-hot encoding of the best arm.
        self.opt_a = np.zeros(means.shape)
        self.opt_a[opt_a_index] = 1.0
        self.dim = len(means)
        self.observation_space = spaces.Box(low=1, high=1, shape=(1,))
        self.action_space = spaces.Box(low=0, high=1, shape=(self.dim,))
        self.state = np.array([1])
        self.var = var
        self.dx = 1
        self.du = self.dim
        self.topk = False
        self.type = type

        # Naming: H_context is the number of transitions in the context,
        # while H is the length of one episode (always 1 for a bandit).
        self.H_context = H
        self.H = 1

        # Initialise both counters here. Previously step() raised
        # AttributeError on context_step unless reset2() had been called.
        self.current_step = 0
        self.context_step = 0

        (
            self.corrupt_type,
            self.corrupted_steps,
            self.corrupt_magnitude,
            _,
            self.corrupted_means,
        ) = get_corrupt_params(corrupt, self.means, self.H_context)

    def get_arm_value(self, u):
        """Return the expected (clean, unpoisoned) reward of action vector ``u``."""
        return np.sum(self.means * u)

    def reset(self):
        """Start a new one-step episode and return the constant observation.

        Only ``current_step`` is reset. ``context_step`` keeps counting until
        ``reset2`` is called.
        """
        self.current_step = 0
        return self.state

    def reset2(self):
        """Reset the context-step counter used for poisoning; return the observation."""
        self.context_step = 0
        return self.state

    def transit(self, state, action):
        """Sample a reward for arm ``argmax(action)``. ``state`` is unused.

        On a corrupted context step, the "gaussian" corruption adds
        N(0, corrupt_magnitude) noise to the clean reward. The "change*" and
        "special*" corruptions redraw the reward around ``corrupted_means``.

        Returns:
            ``(copy of the observation, reward)``.

        Raises:
            NotImplementedError: if ``type`` is neither "uniform" nor "bernoulli".
        """
        action_index = np.argmax(action)
        if self.type == "uniform":
            r = self.means[action_index] + np.random.normal(0, self.var)
        elif self.type == "bernoulli":
            r = np.random.binomial(1, self.means[action_index])
        else:
            raise NotImplementedError

        if self.corrupt_type == "" or self.context_step not in self.corrupted_steps:
            # Not a corrupted step: return the clean reward.
            return self.state.copy(), r

        # Poisoning.
        if self.corrupt_type == "gaussian":
            r += np.random.normal(0, self.corrupt_magnitude)
        elif self.corrupt_type.startswith(("change", "special")):
            # NOTE(review): this uses Gaussian noise even when
            # type == "bernoulli", so the poisoned reward is not 0/1.
            # Confirm that this is intended.
            r = self.corrupted_means[action_index] + np.random.normal(0, self.var)

        return self.state.copy(), r

    def step(self, action):
        """Apply ``action`` and advance both the episode and context counters.

        Returns:
            ``(observation, reward, done, False, {})``.

        Raises:
            ValueError: if the ``H``-step episode has already ended.
        """
        if self.current_step >= self.H:
            raise ValueError("Episode has already ended")

        _, r = self.transit(self.state, action)
        self.current_step += 1
        self.context_step += 1
        done = self.current_step >= self.H

        return self.state.copy(), r, done, False, {}

    def deploy_eval(self, ctrl):
        """Run ``self.deploy(ctrl)`` (inherited from ``BaseEnv``) with ``var`` set to 0.

        The original ``var`` is restored even if ``deploy`` raises.
        """
        saved_var = self.var
        self.var = 0.0
        try:
            return self.deploy(ctrl)
        finally:
            self.var = saved_var


class BanditEnvVec(BaseEnv):
    """
    Vectorized bandit environment.

    Steps a list of bandit envs in lockstep. ``dx`` and ``du`` are copied from
    the first env. The envs are presumably expected to share them, but this
    is not checked.
    """

    def __init__(self, envs):
        self._envs: List[BanditEnv] = envs
        self._num_envs = len(envs)
        self.dx = envs[0].dx
        self.du = envs[0].du

    def reset(self):
        """Call ``reset()`` on every env and return the list of observations."""
        return [env.reset() for env in self._envs]

    def reset2(self):
        """Call ``reset2()`` on every env and return the list of observations."""
        return [env.reset2() for env in self._envs]

    def step(self, actions):
        """Step each env with its matching action.

        Returns:
            ``(next_obs, rews, dones, False, {})``. The first three are lists
            with one entry per stepped env; ``zip`` stops at the shorter of
            ``actions`` and the envs. The truncation flag is a single
            ``False``, not a per-env list.
        """
        next_obs, rews, dones = [], [], []
        for action, env in zip(actions, self._envs):
            next_ob, rew, done, _, _ = env.step(action)
            next_obs.append(next_ob)
            rews.append(rew)
            dones.append(done)
        return next_obs, rews, dones, False, {}

    @property
    def num_envs(self):
        """Number of wrapped envs."""
        return self._num_envs

    @property
    def envs(self):
        """The wrapped env list (not a copy)."""
        return self._envs

    def deploy_eval(self, ctrl):
        """Reset context counters, then roll out ``ctrl`` with every env's ``var`` at 0.

        Each env's original ``var`` is restored even if ``deploy`` raises.
        """
        self.reset2()
        saved_vars = [env.var for env in self._envs]
        for env in self._envs:
            env.var = 0.0
        try:
            return self.deploy(ctrl)
        finally:
            for env, var in zip(self._envs, saved_vars):
                env.var = var

    def deploy(self, ctrl):
        """Roll out ``ctrl`` on all envs until every env reports done.

        ``ctrl.act_numpy_vec`` receives the list of current observations and
        is expected to return one action per env. Per-step results are
        concatenated along the first axis.

        Returns:
            ``(xs, us, xps, rs)``: observations, actions, next observations
            and rewards.
        """
        x = self.reset()
        xs = []
        xps = []
        us = []
        rs = []
        done = False

        while not done:
            u = ctrl.act_numpy_vec(x)

            xs.append(x)
            us.append(u)

            x, r, done, _, _ = self.step(u)
            done = all(done)

            rs.append(r)
            xps.append(x)

        xs = np.concatenate(xs)
        us = np.concatenate(us)
        xps = np.concatenate(xps)
        rs = np.concatenate(rs)
        return xs, us, xps, rs

    def get_arm_value(self, us):
        """Original, unpoisoned values: ``sum(env.means * u)`` for each env/action pair."""
        values = [np.sum(env.means * u) for env, u in zip(self._envs, us)]
        return np.array(values)


class LinearBanditEnv(BanditEnv):
    """Bandit whose arm means are linear in arm features: ``means = arms @ theta``.

    The reward for arm ``a`` is ``means[a]`` plus N(0, var) noise. This class
    does not call ``BanditEnv.__init__`` and never applies poisoning.
    Attributes that ``BanditEnv`` exposes (``topk``, ``corrupt_type`` and the
    step counters) are set to their "off"/zero values. Without them, code
    written against ``BanditEnv`` would hit AttributeError.
    """

    def __init__(self, theta, arms, H, var=0.0):
        """
        Args:
            theta: parameter vector (presumably of length ``arms.shape[1]``).
            arms: arm feature matrix, one row per arm (assumed 2-D).
            H: context horizon, stored as ``H_context``.
            var: passed as the ``scale`` (standard deviation) of ``np.random.normal``.
        """
        self.theta = theta
        self.arms = arms
        self.means = arms @ theta
        self.opt_a_index = np.argmax(self.means)
        # One-hot encoding of the best arm.
        self.opt_a = np.zeros(self.means.shape)
        self.opt_a[self.opt_a_index] = 1.0
        self.dim = len(self.means)
        self.observation_space = spaces.Box(low=1, high=1, shape=(1,))
        self.action_space = spaces.Box(low=0, high=1, shape=(self.dim,))
        self.state = np.array([1])
        self.var = var
        self.dx = 1
        self.du = self.dim
        # Parity with BanditEnv: no top-k actions and no corruption.
        self.topk = False
        self.corrupt_type = ""

        self.H_context = H
        self.H = 1

        self.current_step = 0
        self.context_step = 0

    def get_arm_value(self, u):
        """Return the expected (noise-free) reward of action vector ``u``."""
        return np.sum(self.means * u)

    def reset(self):
        """Start a new one-step episode and return the constant observation."""
        self.current_step = 0
        return self.state

    def transit(self, x, u):
        """Sample ``means[argmax(u)] + N(0, var)``. ``x`` is unused; no poisoning."""
        a = np.argmax(u)
        r = self.means[a] + np.random.normal(0, self.var)
        return self.state.copy(), r

    def step(self, action):
        """Apply ``action`` and advance the step counters.

        Returns:
            ``(observation, reward, done, False, {})``.

        Raises:
            ValueError: if the ``H``-step episode has already ended.
        """
        if self.current_step >= self.H:
            raise ValueError("Episode has already ended")

        _, r = self.transit(self.state, action)
        self.current_step += 1
        # Kept in sync with BanditEnv.step; unused for rewards here.
        self.context_step += 1
        done = self.current_step >= self.H

        return self.state.copy(), r, done, False, {}
