import torch
import numpy as np
import scipy
import scipy.optimize as so

class MLPOracle(torch.nn.Module):
    """Online MLP estimate of ``P(y = 1 | x, p)`` trained with binary cross-entropy.

    The network maps a ``d``-dimensional context ``x`` concatenated with a
    scalar ``p`` to a single sigmoid output. It uses two LeakyReLU hidden
    layers of width ``d + 1``. Each :meth:`update` appends one ``(x, p, y)``
    observation to fixed-size history buffers. It then takes ``num_grad`` Adam
    steps on the *entire* history collected since the last :meth:`reset`.

    :meth:`reset` must be called before the first :meth:`update`. It allocates
    the history buffers (on the parameters' current device), re-initialises
    the weights and rebuilds the optimiser.

    Args:
        d: Context dimension.
        T: Capacity of the history buffers, i.e. the maximum number of
            ``update`` calls between two resets.
        lr: Adam learning rate.
        num_grad: Number of full-history gradient steps per ``update``.
    """

    def __init__(self, d, T, lr=1e-2, num_grad=2):
        super().__init__()
        self.lr = lr
        self.num_grad = num_grad
        self.context_dim = d

        dim = d + 1  # context features plus the scalar p
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(in_features=dim, out_features=dim),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(in_features=dim, out_features=dim),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(in_features=dim, out_features=1),
        )
        self.sigmoid = torch.nn.Sigmoid()

        self.loss = torch.nn.BCELoss(reduction='mean')
        self.opt = torch.optim.Adam(self.parameters(), lr=self.lr)

        self.t = 0  # number of observations currently stored in the buffers
        self.T = T
        self.device = None  # populated by reset()

    def reset(self):
        """Re-initialise the weights, clear the history and rebuild the optimiser."""
        self.device = next(self.parameters()).device
        # a=0.01 matches the default negative slope of torch.nn.LeakyReLU.
        torch.nn.init.kaiming_uniform_(self.layers[0].weight, a=0.01)
        torch.nn.init.zeros_(self.layers[0].bias)
        torch.nn.init.kaiming_uniform_(self.layers[2].weight, a=0.01)
        torch.nn.init.zeros_(self.layers[2].bias)
        torch.nn.init.kaiming_uniform_(self.layers[4].weight, a=0)
        torch.nn.init.zeros_(self.layers[4].bias)

        self.X = torch.zeros(self.T, self.context_dim).to(self.device)
        self.P = torch.zeros(self.T).to(self.device)
        self.Y = torch.zeros(self.T).to(self.device)
        self.t = 0
        # A fresh Adam instance drops moment estimates from the previous run.
        self.opt = torch.optim.Adam(self.parameters(), lr=self.lr)

    def _model_input(self, x, p):
        """Build the ``(batch, d + 1)`` network input from contexts and p values.

        Both arguments are converted to the dtype/device of the first layer's
        weights. Python ints/floats, NumPy scalars/arrays and tensors of any
        float dtype are therefore accepted. A single p value is broadcast over
        a batch of contexts, and a single context is broadcast over many p
        values.
        """
        ref = self.layers[0].weight
        x = torch.as_tensor(x, dtype=ref.dtype, device=ref.device)
        x = x.reshape(-1, self.context_dim)
        p = torch.as_tensor(p, dtype=ref.dtype, device=ref.device).reshape(-1, 1)
        if p.shape[0] == 1 and x.shape[0] > 1:
            p = p.expand(x.shape[0], 1)
        elif x.shape[0] == 1 and p.shape[0] > 1:
            x = x.expand(p.shape[0], self.context_dim)
        return torch.concat([x, p], dim=1)

    def forward(self, x, p):
        """Return the predicted probabilities, shape ``(batch, 1)``.

        Args:
            x: Context(s); reshaped to ``(-1, d)``.
            p: Scalar or per-row values; reshaped to ``(-1, 1)`` and broadcast
                against ``x`` when either side has a single row.
        """
        return self.sigmoid(self.layers(self._model_input(x, p)))

    def compute_phat(self, x, num_points=100):
        """Return the grid value ``p`` in ``[0, 1]`` maximising ``p * f(x, p)``.

        ``f`` is this model's predicted probability. The search is a uniform
        grid of ``num_points`` values from 0 to 1 inclusive, evaluated for a
        single context ``x``.

        Raises:
            ValueError: if ``num_points < 1``.
        """
        if num_points < 1:
            raise ValueError(f"num_points must be >= 1, got {num_points}")
        device = self.layers[0].weight.device
        grid = torch.linspace(0, 1, num_points, device=device).reshape(-1, 1)
        with torch.no_grad():
            # The single context is broadcast across every grid value.
            prob = self(x, grid)
            best_idx = torch.argmax(grid * prob).item()
        return grid[best_idx].item()

    def update(self, x, p, y):
        """Store one observation and take ``num_grad`` BCE steps on the full history.

        Raises:
            IndexError: if ``T`` observations are already stored (call
                :meth:`reset` to start over).
        """
        if self.t >= self.T:
            raise IndexError(
                f"history buffer full: update() called more than T={self.T} "
                "times since reset()"
            )
        self.X[self.t] = x
        self.Y[self.t] = y
        self.P[self.t] = p
        self.t += 1

        n = self.t
        for _ in range(self.num_grad):
            self.opt.zero_grad()
            pred = self(self.X[:n], self.P[:n]).reshape(-1)
            loss = self.loss(pred, self.Y[:n])
            loss.backward()
            self.opt.step()


class SqMLPOracle(torch.nn.Module):
    """Online MLP regressor ``f(x, a)`` in ``(0, 1)`` trained with squared error.

    The network maps a ``d``-dimensional context ``x`` concatenated with a
    scalar ``a`` to a single sigmoid output. It uses two LeakyReLU hidden
    layers of width ``d + 1``. Each :meth:`update` appends one ``(x, a, y)``
    observation to fixed-size history buffers. It then takes ``num_grad`` Adam
    steps minimising the mean squared error over the *entire* history
    collected since the last :meth:`reset`.

    :meth:`reset` must be called before the first :meth:`update`. It allocates
    the history buffers (on the parameters' current device), re-initialises
    the weights and rebuilds the optimiser.

    Args:
        d: Context dimension.
        T: Capacity of the history buffers, i.e. the maximum number of
            ``update`` calls between two resets.
        lr: Adam learning rate.
        num_grad: Number of full-history gradient steps per ``update``.
    """

    def __init__(self, d, T, lr=1e-3, num_grad=2):
        super().__init__()
        self.lr = lr
        self.num_grad = num_grad
        self.context_dim = d

        dim = d + 1  # context features plus the scalar a
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(in_features=dim, out_features=dim),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(in_features=dim, out_features=dim),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(in_features=dim, out_features=1),
        )
        self.sigmoid = torch.nn.Sigmoid()

        self.opt = torch.optim.Adam(self.parameters(), lr=self.lr)

        self.t = 0  # number of observations currently stored in the buffers
        self.T = T
        self.device = None  # populated by reset()

    def reset(self):
        """Re-initialise the weights, clear the history and rebuild the optimiser."""
        self.device = next(self.parameters()).device
        # a=0.01 matches the default negative slope of torch.nn.LeakyReLU.
        torch.nn.init.kaiming_uniform_(self.layers[0].weight, a=0.01)
        torch.nn.init.zeros_(self.layers[0].bias)
        torch.nn.init.kaiming_uniform_(self.layers[2].weight, a=0.01)
        torch.nn.init.zeros_(self.layers[2].bias)
        torch.nn.init.kaiming_uniform_(self.layers[4].weight, a=0)
        torch.nn.init.zeros_(self.layers[4].bias)

        self.X = torch.zeros(self.T, self.context_dim).to(self.device)
        self.A = torch.zeros(self.T).to(self.device)
        self.Y = torch.zeros(self.T).to(self.device)
        self.t = 0
        # A fresh Adam instance drops moment estimates from the previous run.
        self.opt = torch.optim.Adam(self.parameters(), lr=self.lr)

    def _model_input(self, x, a):
        """Build the ``(batch, d + 1)`` network input from contexts and a values.

        Both arguments are converted to the dtype/device of the first layer's
        weights. Python ints/floats, NumPy scalars/arrays and tensors of any
        float dtype are therefore accepted. A single a value is broadcast over
        a batch of contexts, and a single context is broadcast over many a
        values.
        """
        ref = self.layers[0].weight
        x = torch.as_tensor(x, dtype=ref.dtype, device=ref.device)
        x = x.reshape(-1, self.context_dim)
        a = torch.as_tensor(a, dtype=ref.dtype, device=ref.device).reshape(-1, 1)
        if a.shape[0] == 1 and x.shape[0] > 1:
            a = a.expand(x.shape[0], 1)
        elif x.shape[0] == 1 and a.shape[0] > 1:
            x = x.expand(a.shape[0], self.context_dim)
        return torch.concat([x, a], dim=1)

    def forward(self, x, a):
        """Return the sigmoid outputs, shape ``(batch, 1)``.

        Args:
            x: Context(s); reshaped to ``(-1, d)``.
            a: Scalar or per-row values; reshaped to ``(-1, 1)`` and broadcast
                against ``x`` when either side has a single row.
        """
        return self.sigmoid(self.layers(self._model_input(x, a)))

    def compute_ahat(self, x, num_points=100):
        """Return the grid value ``a`` in ``[0, 1]`` maximising ``a * f(x, a)``.

        The search is a uniform grid of ``num_points`` values from 0 to 1
        inclusive, evaluated for a single context ``x``.

        Raises:
            ValueError: if ``num_points < 1``.
        """
        if num_points < 1:
            raise ValueError(f"num_points must be >= 1, got {num_points}")
        device = self.layers[0].weight.device
        grid = torch.linspace(0, 1, num_points, device=device).reshape(-1, 1)
        with torch.no_grad():
            # The single context is broadcast across every grid value.
            pred = self(x, grid)
            best_idx = torch.argmax(grid * pred).item()
        return grid[best_idx].item()

    def update(self, x, a, y):
        """Store one observation and take ``num_grad`` MSE steps on the full history.

        Raises:
            IndexError: if ``T`` observations are already stored (call
                :meth:`reset` to start over).
        """
        if self.t >= self.T:
            raise IndexError(
                f"history buffer full: update() called more than T={self.T} "
                "times since reset()"
            )
        self.X[self.t] = x
        self.Y[self.t] = y
        self.A[self.t] = a
        self.t += 1

        n = self.t
        for _ in range(self.num_grad):
            self.opt.zero_grad()
            pred = self(self.X[:n], self.A[:n]).reshape(-1)
            loss = torch.mean((pred - self.Y[:n]) ** 2)
            loss.backward()
            self.opt.step()