from typing import Any
import numpy as np
import scipy
import torch


class Env:
    """Synthetic environment pairing a context sampler with a valuation model.

    ``generator`` is the random generator used for every draw made here; it is
    also installed on ``context_dist`` and handed to
    ``valuation_model.init_param``. It is passed to ``torch.rand`` and its
    ``.device`` attribute is read, so a ``torch.Generator`` is expected.
    """

    def __init__(self, generator, context_dist, valuation_model):
        """Store the components and draw the initial model parameters.

        Args:
            generator: random generator shared by all components.
            context_dist: object exposing ``gen_context()``; its ``generator``
                attribute is overwritten with ``generator``.
            valuation_model: callable ``(x, a) -> probability`` that also
                exposes ``init_param(generator)``.
        """
        self.generator = generator
        self.context_dist = context_dist
        # Make the context sampler draw from the same random stream.
        self.context_dist.generator = generator
        self.valuation_model = valuation_model
        self.valuation_model.init_param(generator)

    def reset(self):
        """Re-draw the valuation model parameters from the shared generator."""
        self.valuation_model.init_param(self.generator)

    def gen_context(self):
        """Return a fresh context from the context distribution."""
        return self.context_dist.gen_context()

    def act(self, x, a):
        """Play action ``a`` in context ``x`` and sample a binary outcome.

        Returns:
            ``(outcome, prob)`` where ``prob`` is ``valuation_model(x, a)`` and
            ``outcome`` is ``1.0`` when a uniform draw from ``[0, 1)`` falls
            below ``prob`` and ``0.0`` otherwise.
        """
        prob = self.valuation_model(x, a)
        unif = torch.rand(1, generator=self.generator, device=self.generator.device).item()
        return float(unif < prob), prob

    def optimal_action(self, x, num_prices=100):
        """Grid-search the action in ``[0, 1]`` maximizing ``a * model(x, a)``.

        Args:
            x: context passed through to the valuation model.
            num_prices: number of evenly spaced candidate actions on
                ``[0, 1]`` (endpoints included). Defaults to 100, the grid
                size that was previously hard-coded.

        Returns:
            ``(best_action, best_value)``; ``best_action`` is a one-element
            tensor taken from the candidate grid and ``best_value`` is the
            product of that action with the model output at that action.
        """
        # Column vector so the model is evaluated on every candidate at once.
        a = torch.linspace(0, 1, num_prices).reshape(-1, 1)
        pred = self.valuation_model(x, a)

        astar = torch.argmax(a * pred).squeeze().item()
        return a[astar], a[astar] * pred[astar]


class DataEnv:
    """Environment that replays rows of a fixed dataset.

    ``gen_context`` picks a random row and remembers its index in ``ptr``;
    ``act`` and ``optimal_action`` then refer to that same row. ``generator``
    must expose ``integers(n)`` and ``random()`` (as ``numpy.random.Generator``
    does).
    """

    def __init__(self, price_data, feature_data, optimal_price, optimal_revenue, cdf, device, generator):
        """Store the dataset and helpers.

        Args:
            price_data: per-row price values, indexable by row.
            feature_data: per-row feature vectors; the last axis is the
                feature dimension.
            optimal_price: per-row optimal price, indexable by row.
            optimal_revenue: per-row optimal revenue, indexable by row.
            cdf: callable applied to ``p - price - 0.5`` in ``act``.
            device: device the sampled context tensor is moved to.
            generator: random generator used for row selection and outcomes.
        """
        self.price = price_data
        self.features = feature_data
        self.optimal_price = optimal_price
        self.optimal_revenue = optimal_revenue
        assert self.price.shape[0] == self.features.shape[0], "dataset dimension mismatch"
        self.N = self.price.shape[0]
        self.d = self.features.shape[-1]
        self.cdf = cdf
        # Index of the row chosen by the latest gen_context(); None until then.
        self.ptr = None
        self.device = device
        self.generator = generator

    def reset(self):
        """No-op: the dataset has no parameters to re-draw."""
        return

    def _require_context(self, caller):
        """Raise if no row has been selected yet via ``gen_context``."""
        if self.ptr is None:
            # Indexing with None would silently add an axis instead of
            # selecting a row, so fail loudly.
            raise RuntimeError(f"gen_context() must be called before {caller}()")

    def gen_context(self):
        """Select a uniformly random row and return its features as a tensor."""
        idx = self.generator.integers(self.N)
        self.ptr = idx
        ret = torch.tensor(self.features[idx]).to(self.device)
        return ret

    def act(self, x, p: float):
        """Offer price ``p`` for the currently selected row.

        ``x`` is unused; the row is identified by ``ptr``.

        Returns:
            ``(outcome, prob)`` with ``prob = 1 - cdf(p - price - 0.5)`` and
            ``outcome`` equal to ``1.0`` when a uniform draw is below ``prob``.

        Raises:
            RuntimeError: if ``gen_context`` has not been called yet.
        """
        self._require_context("act")
        price = self.price[self.ptr]
        prob = 1 - self.cdf(p - price - 0.5)
        unif = self.generator.random()
        return float(unif < prob), prob

    def optimal_action(self, x):
        """Return the stored ``(optimal_price, optimal_revenue)`` of the current row.

        ``x`` is unused; the row is identified by ``ptr``.

        Raises:
            RuntimeError: if ``gen_context`` has not been called yet.
        """
        self._require_context("optimal_action")
        return self.optimal_price[self.ptr], self.optimal_revenue[self.ptr]

class GaussianContext:
    """Sampler of zero-mean Gaussian contexts with independent coordinates.

    Each of the ``d`` coordinates has standard deviation
    ``sigma / sqrt(2 * d)``.
    """

    def __init__(self, d, sigma, device, generator=None):
        """Precompute the mean and standard-deviation tensors on ``device``."""
        per_coord_scale = sigma / np.sqrt(2 * d)

        self.d = d
        self.device = device
        self.generator = generator
        self.mean = torch.zeros(d).to(device)
        self.std = per_coord_scale * torch.ones(d).to(device)

    def gen_context(self):
        """Draw one context vector of length ``d`` using ``self.generator``."""
        sample = torch.normal(self.mean, self.std, generator=self.generator)
        return sample


class UniformContext:
    """Sampler of contexts inside the ``d``-dimensional ball of radius ``R``.

    A direction is obtained by normalizing a standard normal vector, and the
    radius is ``R * u ** (1 / d)`` with ``u`` uniform on ``[0, 1)``, which
    gives a point uniformly distributed in the ball.
    """

    def __init__(self, d, device, R=1, generator=None):
        """Store the settings and the standard-normal parameters on ``device``."""
        self.generator = generator
        self.d, self.device, self.R = d, device, R
        self.mean = torch.zeros(d).to(device)
        self.std = torch.ones(d).to(device)

    def gen_context(self):
        """Draw one context vector of length ``d`` using ``self.generator``."""
        # Direction first, then radius: the order of the two draws fixes the
        # random stream.
        gaussian = torch.normal(self.mean, self.std, generator=self.generator)
        direction = gaussian / torch.norm(gaussian)
        u = torch.rand(1, generator=self.generator, device=self.device).item()
        radius = self.R * np.power(u, 1 / self.d)
        return radius * direction


class BinaryContext:
    """Sampler of contexts whose ``d`` coordinates are independent fair coin flips."""

    def __init__(self, d, device, generator=None):
        """Store the settings and the per-coordinate success probability (0.5)."""
        success_prob = 0.5 * torch.ones(d).to(device)

        self.generator = generator
        self.d = d
        self.device = device
        self.p = success_prob

    def gen_context(self):
        """Draw one 0/1-valued context vector of length ``d``."""
        draw = torch.bernoulli(self.p, generator=self.generator)
        return draw


def gaussian_cdf(x):
    """Return the CDF at ``x`` of a zero-mean truncated normal with scale 0.2.

    ``a`` and ``b`` are SciPy's standardized truncation bounds, i.e. they are
    expressed in units of ``scale`` around ``loc``, so the support is
    ``[loc + a * scale, loc + b * scale] = [-0.2, 0.2]``.
    NOTE(review): if truncation to ``[-1, 1]`` was intended, ``a``/``b`` would
    need to be ``-5``/``5`` -- confirm before changing.

    Args:
        x: scalar or array-like of evaluation points.

    Returns:
        CDF value(s) as returned by ``scipy.stats.truncnorm.cdf``.
    """
    # Imported explicitly: a bare ``import scipy`` does not guarantee that the
    # ``scipy.stats`` submodule is loaded on every SciPy release.
    from scipy.stats import truncnorm

    return truncnorm.cdf(x, a=-1, b=1, loc=0, scale=0.2)

def MoU_cdf(x):
    """Return the CDF at ``x`` of a mixture of two uniform distributions.

    The mixture puts weight 0.25 on Uniform[0, 0.25] and weight 0.75 on
    Uniform[0.25, 0.5].
    """
    upper_component = np.clip(4*(x-0.25), 0, 1)  # CDF of Uniform[0.25, 0.5]
    lower_component = np.clip(4*x, 0, 1)         # CDF of Uniform[0, 0.25]
    return 0.75*upper_component + 0.25*lower_component


class LinearModel:
    """Valuation model with a linear index: ``1 - cdf(a - (0.5 + <beta, x>))``."""

    def __init__(self, d, cdf, device):
        """Store the dimension, CDF and device; ``beta`` stays unset until ``init_param``."""
        self.d, self.cdf, self.device = d, cdf, device
        self.beta = None

    def init_param(self, generator):
        """Draw ``beta`` with i.i.d. N(0, 1/d) coordinates using ``generator``."""
        zero_mean = torch.zeros(self.d).to(self.device)
        coord_std = 1/np.sqrt(self.d)*torch.ones(self.d).to(self.device)
        self.beta = torch.normal(zero_mean, coord_std, generator=generator)

    def __call__(self, x, a):
        """Return ``1 - cdf(a - (0.5 + <beta, x>))`` for context ``x`` and action ``a``."""
        # The inner product is reduced to a Python float before use.
        index = torch.dot(self.beta, x).squeeze().item()
        threshold = a - (0.5 + index)
        return 1 - self.cdf(threshold)
    
class LogLinearModel:
    """Valuation model with a log-linear index: ``1 - cdf(2a * exp(-<beta, x>) - 1)``."""

    def __init__(self, d, cdf, device):
        """Store the dimension, CDF and device; ``beta`` stays unset until ``init_param``."""
        self.d, self.cdf, self.device = d, cdf, device
        self.beta = None

    def init_param(self, generator):
        """Draw ``beta`` with i.i.d. N(0, 1/d) coordinates using ``generator``."""
        zero_mean = torch.zeros(self.d).to(self.device)
        coord_std = 1/np.sqrt(self.d)*torch.ones(self.d).to(self.device)
        self.beta = torch.normal(zero_mean, coord_std, generator=generator)

    def __call__(self, x, a):
        """Return ``1 - cdf(2a * exp(-<beta, x>) - 1)`` for context ``x`` and action ``a``."""
        # The inner product is reduced to a Python float before use.
        index = torch.dot(self.beta, x).squeeze().item()
        rescaled = 2*a * np.exp(-index) - 1
        return 1 - self.cdf(rescaled)

# CDF for the PH model: with the CDFs above, the PH model can produce an extreme distribution of optimal prices, so a modified CDF is defined here to fit the PH model.

def PH_gaussian_cdf(x):
    """Return the CDF at ``x`` of a zero-mean truncated normal with scale 1.0.

    ``a`` and ``b`` are SciPy's standardized truncation bounds (in units of
    ``scale`` around ``loc``); with ``scale=1.0`` the support is ``[-1, 1]``.

    Args:
        x: scalar or array-like of evaluation points.

    Returns:
        CDF value(s) as returned by ``scipy.stats.truncnorm.cdf``.
    """
    # Imported explicitly: a bare ``import scipy`` does not guarantee that the
    # ``scipy.stats`` submodule is loaded on every SciPy release.
    from scipy.stats import truncnorm

    return truncnorm.cdf(x, a=-1, b=1, loc=0, scale=1.0)

class PHModel:
    """Valuation model of the form ``(1 - cdf(2a - 1)) ** exp(2 * sqrt(d) * <beta, x>)``.

    When constructed with ``gaussian_cdf``, ``PH_gaussian_cdf`` is used in its
    place; any other CDF is kept as given.
    """

    def __init__(self, d, cdf, device):
        """Store the dimension and device and select the CDF to use."""
        self.d, self.device = d, device
        self.beta = None
        # Swap in the PH-specific Gaussian CDF when the default one is requested.
        self.cdf = PH_gaussian_cdf if cdf==gaussian_cdf else cdf

    def init_param(self, generator):
        """Draw ``beta`` with i.i.d. N(0, 1/d) coordinates using ``generator``."""
        zero_mean = torch.zeros(self.d).to(self.device)
        coord_std = 1/np.sqrt(self.d)*torch.ones(self.d).to(self.device)
        self.beta = torch.normal(zero_mean, coord_std, generator=generator)

    def __call__(self, x, a):
        """Return the baseline survival ``1 - cdf(2a - 1)`` raised to ``exp(2 * sqrt(d) * <beta, x>)``."""
        baseline_survival = 1-self.cdf(2*a-1)
        # The inner product is reduced to a Python float before use.
        index = torch.dot(self.beta, x).squeeze().item()
        exponent = np.exp(2*np.sqrt(self.d)*index)
        return np.power(baseline_survival, exponent)