from matplotlib.pyplot import axis
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math
import scipy.linalg

from algs.base_learner import kron

from utils import median_trick

from sklearn.kernel_approximation import RBFSampler

import multiprocessing



def get_inverse_sigma(replay_buffer, rbf_feature, lamb, feature_dim, h, queue):
    """Worker: build the regularised feature covariance for one step and invert it.

    Reads up to the 100000 most recent transitions from ``replay_buffer``,
    featurises the concatenated (observation, action) rows and puts the list
    ``[h, Sigma_inv, feature]`` on ``queue``, where
    ``Sigma_inv = inv(feature.T @ feature + lamb * I)``.

    Args:
        replay_buffer: object exposing ``get_full_np(recent_size=...)`` that
            returns ``(obses, actions, rewards, next_obses)``.
        rbf_feature: object exposing ``fit_transform`` (an ``RBFSampler`` in
            this file).
        lamb: ridge coefficient added on the diagonal.
        feature_dim: size of the identity matrix; must equal the number of
            feature columns produced by ``rbf_feature``.
        h: step index, echoed back so the consumer can match results that
            arrive out of order.
        queue: object exposing ``put``; receives the result list.
    """
    # Rewards and next observations are returned by the buffer but are not
    # needed to build the covariance.
    obses, actions, _rewards, _next_obses = replay_buffer.get_full_np(
        recent_size=100000
    )
    state_action = np.concatenate((obses, actions), axis=-1)

    # Inputs are divided by a fixed constant of 5 before featurisation.
    phi = rbf_feature.fit_transform(state_action / 5)

    regularised_cov = phi.T @ phi + lamb * np.eye(feature_dim)
    queue.put([h, np.linalg.inv(regularised_cov), phi])

def get_Q(obs, action_dim, a, rbf_feature, W, Sigma_inv, alpha, Q_max, queue):
    """Worker: optimistic Q-values of action ``a`` for every row of ``obs``.

    Each observation is concatenated with the one-hot encoding of ``a``,
    divided by 5 and featurised.  The value put on ``queue`` is the list
    ``[a, Q]`` with
    ``Q = min(feature @ W + alpha * sqrt(diag(feature @ Sigma_inv @ feature.T)), Q_max)``.

    Args:
        obs: 2-D array-like of observations, one per row.
        action_dim: number of discrete actions (length of the one-hot vector).
        a: index of the action being evaluated.
        rbf_feature: object exposing ``fit_transform``.
        W: weight vector of the linear Q estimate.
        Sigma_inv: inverse regularised feature covariance.
        alpha: multiplier of the exploration bonus.
        Q_max: upper clip applied to the optimistic value.
        queue: object exposing ``put``; receives ``[a, Q]``.
    """
    one_hot = np.zeros((len(obs), action_dim))
    one_hot[:, a] = 1

    state_action = np.concatenate((obs, one_hot), axis=-1)
    phi = rbf_feature.fit_transform(state_action / 5)

    # Row-wise quadratic form phi_i^T Sigma_inv phi_i without building the
    # full (N, N) product.
    quad_form = np.sum((phi @ Sigma_inv) * phi, 1)
    bonus = np.sqrt(quad_form)

    optimistic = phi @ W + alpha * bonus
    queue.put([a, np.minimum(optimistic, Q_max)])

class LSVI_UCB(object):
    """Per-step linear Q estimates with a UCB bonus over random RBF features.

    For every step ``h`` in ``range(horizon)`` the object keeps a weight
    vector ``W[h]`` and an inverse regularised feature covariance
    ``Sigma_invs[h]``.  The optimistic value of an (observation, action) pair
    with feature ``phi`` is
    ``min(phi @ W[h] + alpha * sqrt(phi @ Sigma_invs[h] @ phi), Q_max)``
    where ``Q_max`` equals ``horizon``.
    """

    def __init__(
        self,
        obs_dim,
        state_dim,
        action_dim,
        horizon,
        alpha,
        device,
        variable_latent,
        env_temperature,
        feature_size = 1000,
        seed=12,
        lamb = 1
    ):
        """Store hyper-parameters and allocate per-step parameters.

        Args:
            obs_dim: stored as-is; not otherwise used by this class.
            state_dim: stored as-is; not otherwise used by this class.
            action_dim: number of discrete actions.
            horizon: number of steps; also used as the Q-value clip.
            alpha: multiplier of the exploration bonus.
            device: accepted for interface compatibility; the instance always
                uses ``"cpu"``.
            variable_latent: stored as-is.
            env_temperature: stored as-is.
            feature_size: number of random features.
            seed: ``random_state`` of the feature sampler.
            lamb: ridge coefficient used when building the covariance.
        """
        self.obs_dim = obs_dim
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.horizon = horizon

        self.feature_dim = feature_size

        # The `device` argument is ignored: everything here runs on CPU.
        self.device = "cpu"

        self.lamb = lamb
        self.alpha = alpha

        self.W = np.random.rand(self.horizon, self.feature_dim)
        self.Sigma_invs = np.zeros((self.horizon, self.feature_dim, self.feature_dim))

        self.Q_max = self.horizon

        self.variable_latent = variable_latent
        self.env_temperature = env_temperature

        self.rbf_feature = RBFSampler(gamma=1, random_state=seed, n_components=feature_size)

        # Only read by get_feature; the other code paths divide by a literal 5.
        self.bandwidth = np.ones(horizon) * 5

        print(self.alpha)
        print(self.env_temperature)

    def Q_values(self, obs, h, eval=False):
        """Return optimistic Q-values of shape ``(len(obs), action_dim)``.

        Args:
            obs: 2-D array-like of observations, one per row.
            h: step index selecting ``W[h]`` and ``Sigma_invs[h]``.
            eval: when true, actions are evaluated sequentially in this
                process; otherwise one worker process per action runs the
                module-level ``get_Q``.  (The name shadows the builtin but is
                part of the public signature.)
        """
        Qs = np.zeros((len(obs), self.action_dim))

        if eval:
            for a in range(self.action_dim):
                actions = np.zeros((len(obs), self.action_dim))
                actions[:, a] = 1
                obsac = np.concatenate((obs, actions), axis=-1)
                feature = self.rbf_feature.fit_transform(obsac / 5)
                Q_est = np.matmul(feature, self.W[h])
                ucb = np.sqrt(np.sum(np.matmul(feature, self.Sigma_invs[h]) * feature, 1))
                Qs[:, a] = np.minimum(Q_est + self.alpha * ucb, self.Q_max)
            return Qs

        queue = multiprocessing.Queue()
        workers = [
            multiprocessing.Process(
                target=get_Q,
                args=(
                    obs, self.action_dim, a, self.rbf_feature,
                    self.W[h], self.Sigma_invs[h], self.alpha, self.Q_max, queue,
                ),
            )
            for a in range(self.action_dim)
        ]
        for worker in workers:
            worker.start()

        # Results arrive in completion order; each carries its action index.
        for _ in workers:
            a, Q = queue.get()
            Qs[:, a] = Q

        # Reap the children.  Joining only after the queue is fully drained
        # avoids blocking on a child that is still flushing its result.
        for worker in workers:
            worker.join()

        return Qs

    def act(self, obs, h):
        """Return the greedy action for a single observation at step ``h``.

        Args:
            obs: one observation (1-D array-like).
            h: step index.

        Returns:
            A flat NumPy integer array with one element: the index of the
            action with the largest optimistic Q-value.
        """
        # Q_values works on and returns NumPy arrays, so the batch dimension
        # and the argmax are handled with NumPy as well.
        obs = np.asarray(obs, dtype=np.float32)[np.newaxis, ...]
        Qs = self.Q_values(obs, h)
        return np.argmax(Qs, axis=1).flatten()

    def act_batch(self, obs, h):
        """Return the greedy action index for every row of ``obs`` at step ``h``.

        Uses the sequential (``eval=True``) path of ``Q_values``.
        """
        Qs = self.Q_values(obs, h, eval=True)
        return np.argmax(Qs, axis=1)

    def get_feature(self, obsac, h):
        """Featurise ``obsac`` scaled by ``bandwidth[h]``; returns a float tensor."""
        phi = self.rbf_feature.fit_transform(obsac / self.bandwidth[h])
        return torch.as_tensor(phi, device=self.device).float()

    def update(self, buffers):
        """Refit ``Sigma_invs`` and ``W`` for every step from the replay buffers.

        The inverse covariances are computed in parallel, one worker process
        per step (module-level ``get_inverse_sigma``).  The weights are then
        solved backwards from the last step so that the target of step ``h``
        uses the freshly updated parameters of step ``h + 1``.

        Args:
            buffers: sequence of ``horizon`` replay buffers, each exposing
                ``get_full_np(recent_size=...)`` returning
                ``(obses, actions, rewards, next_obses)``.  ``rewards`` is
                presumably of shape ``(N, 1)`` so that it broadcasts against
                the ``(N, feature_dim)`` features -- confirm against the
                buffer implementation.
        """
        assert len(buffers) == self.horizon
        queue = multiprocessing.Queue()
        cur_feature = {}

        workers = [
            multiprocessing.Process(
                target=get_inverse_sigma,
                args=(buffers[h], self.rbf_feature, self.lamb, self.feature_dim, h, queue),
            )
            for h in range(self.horizon)
        ]
        for worker in workers:
            worker.start()

        # Results arrive in completion order; each carries its step index.
        for _ in workers:
            step, sigma_inverse, feature = queue.get()
            self.Sigma_invs[step] = sigma_inverse
            cur_feature[step] = feature

        # Reap the children once every result has been received.
        for worker in workers:
            worker.join()

        for h in reversed(range(self.horizon)):
            obses, actions, rewards, next_obses = buffers[h].get_full_np(recent_size=100000)
            if h == self.horizon - 1:
                # No future value after the last step.
                target_Q = rewards
            else:
                Q_prime = np.max(self.Q_values(next_obses, h + 1), axis=1).reshape(-1, 1)
                target_Q = rewards + Q_prime
            # Ridge-regression solution: Sigma^{-1} * sum_i phi_i * target_i.
            self.W[h] = np.matmul(self.Sigma_invs[h], np.sum(cur_feature[h] * target_Q, 0))

    def save_weight(self, path):
        """Write ``W[h]`` and ``Sigma_invs[h]`` to ``path`` as ``.npy`` files."""
        for h in range(self.horizon):
            np.save("{}/W_{}".format(path,str(h)),self.W[h])
            np.save("{}/Sigma_{}".format(path,str(h)),self.Sigma_invs[h])

    def load_weight(self, path):
        """Load ``W[h]`` and ``Sigma_invs[h]`` from the files written by ``save_weight``.

        Works for any ``feature_dim``; the loaded arrays must match the
        shapes allocated for this instance.
        """
        for h in range(self.horizon):
            self.W[h] = np.load("{}/W_{}.npy".format(path,str(h)))
            print(h)
            self.Sigma_invs[h] = np.load("{}/Sigma_{}.npy".format(path,str(h)))








