import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math
import scipy.linalg

from algs.base_learner import kron



class LSVI_UCB_ORA(object):
    """Least-squares value iteration with a UCB exploration bonus.

    For every step ``h`` of the horizon the agent keeps a weight vector
    ``W[h]`` and the inverse of a ridge-regularised feature Gram matrix
    ``Sigma_invs[h]``.  The optimistic value of a (observation, action) pair
    with feature ``phi`` is::

        min(phi @ W[h] + alpha * sqrt(phi @ Sigma_invs[h] @ phi), horizon)

    Features are not learned: ``get_feature`` undoes a Hadamard rotation on
    both halves of the observation and combines the result with the action
    encoding through the project helper ``kron``.

    ``W`` and ``Sigma_invs`` are stored on the CPU.  ``update`` temporarily
    switches ``self.device`` to the device passed to the constructor and
    switches back to ``"cpu"`` when it finishes.
    """

    # Temperature of the softmax that sharpens the decoded state encoding.
    _SOFTMAX_TEMPERATURE = 0.001

    def __init__(
        self,
        obs_dim,
        state_dim,
        action_dim,
        horizon,
        alpha,
        device,
        lamb=1
    ):
        """Create the agent.

        Args:
            obs_dim: Observation size; each half (``obs_dim / 2`` entries) is
                decoded separately.  ``scipy.linalg.hadamard`` requires the
                half size to be a power of two.
            state_dim: Number of decoded state coordinates kept per
                observation.
            action_dim: Number of discrete actions.
            horizon: Number of steps per episode; also the cap on Q-values.
            alpha: Multiplier of the UCB bonus.
            device: Device used for the computations inside ``update``.
            lamb: Ridge coefficient added to the feature Gram matrix.
        """
        self.obs_dim = obs_dim
        self.middle = int(obs_dim / 2)
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.horizon = horizon

        self.feature_dim = state_dim * action_dim

        # Acting happens on the CPU; only ``update`` uses ``gpu_device``.
        self.device = "cpu"
        self.gpu_device = device

        self.lamb = lamb
        self.alpha = alpha

        self.W = torch.rand((self.horizon, self.feature_dim)).to(self.device)
        # NOTE(review): starting from zeros means the bonus term is zero
        # until ``update`` has run; the textbook LSVI-UCB initialisation
        # would be I / lamb -- confirm this is intended.
        self.Sigma_invs = torch.zeros(
            (self.horizon, self.feature_dim, self.feature_dim)
        ).to(self.device)

        self.Q_max = torch.tensor(self.horizon)

        # Inverse of the Hadamard matrix, used to undo the rotation that is
        # presumably applied to the observations by the environment.
        rotation = scipy.linalg.hadamard(self.middle)
        self.A = torch.inverse(torch.as_tensor(rotation, device=self.device).float())

    def get_feature(self, obs, states, actions):
        """Build the feature matrix for a batch of (observation, action) pairs.

        Args:
            obs: 2-D float tensor, one observation per row.
            states: Unused; kept so callers do not have to change.
            actions: 2-D tensor with one action encoding per row.

        Returns:
            ``kron(actions, state_encoding)``; presumably one row of length
            ``action_dim * state_dim`` per sample -- verify against ``kron``
            in ``algs.base_learner``.
        """
        A = self.A.to(self.device)
        first_half = torch.matmul(A, obs[:, :self.middle].T).T[:, :self.state_dim]
        second_half = torch.matmul(A, obs[:, self.middle:].T).T[:, :self.state_dim]
        state_encoding = first_half + second_half
        # The low temperature pushes each row towards a one-hot vector.  The
        # softmax runs over the state coordinates (dim=1), which is also what
        # the implicit-dim behaviour picks for 2-D input.
        state_encoding = F.softmax(state_encoding / self._SOFTMAX_TEMPERATURE, dim=1)

        return kron(actions, state_encoding)

    def Q_values(self, obs, states, h):
        """Return the optimistic Q-values of every action at step ``h``.

        Args:
            obs: 2-D float tensor of observations on ``self.device``.
            states: Forwarded to ``get_feature`` (unused there).
            h: Step index selecting ``W[h]`` and ``Sigma_invs[h]``.

        Returns:
            Tensor of shape ``(len(obs), action_dim)``, capped at ``Q_max``.
        """
        n = len(obs)
        # Move the step parameters once instead of once per action.
        W_h = self.W[h].to(self.device)
        Sigma_inv_h = self.Sigma_invs[h].to(self.device)

        Qs = torch.zeros((n, self.action_dim)).to(self.device)
        for a in range(self.action_dim):
            actions = torch.zeros((n, self.action_dim)).to(self.device)
            actions[:, a] = 1
            feature = self.get_feature(obs, states, actions)

            Q_est = torch.matmul(feature, W_h)
            quad_form = torch.sum(torch.matmul(feature, Sigma_inv_h) * feature, 1)
            # Round-off can make the quadratic form slightly negative, which
            # would turn the bonus into NaN.
            ucb = torch.sqrt(torch.clamp(quad_form, min=0))

            Qs[:, a] = torch.minimum(Q_est + self.alpha * ucb, self.Q_max)

        return Qs

    def _greedy_actions(self, obs, states, h):
        """Return the argmax action of each row of ``obs`` as a flat array."""
        with torch.no_grad():
            Qs = self.Q_values(obs, states, h)
            action = torch.argmax(Qs, dim=1)

        return action.cpu().numpy().flatten()

    def act(self, obs, states, h):
        """Pick the greedy action for a single observation.

        Args:
            obs: One observation (anything ``torch.FloatTensor`` accepts).
            states: Forwarded to ``Q_values``.
            h: Step index.

        Returns:
            NumPy array of length 1 holding the chosen action index.
        """
        obs = torch.FloatTensor(obs).to(self.device).unsqueeze(0)
        return self._greedy_actions(obs, states, h)

    def act_batch(self, obs, states, h):
        """Pick the greedy action for every observation in a batch.

        Args:
            obs: Batch of observations, one per row.
            states: Forwarded to ``Q_values``.
            h: Step index.

        Returns:
            1-D NumPy array with one action index per observation.
        """
        obs = torch.FloatTensor(obs).to(self.device)
        return self._greedy_actions(obs, states, h)

    def update(self, buffers):
        """Refit ``W`` and ``Sigma_invs`` for every step, last step first.

        Args:
            buffers: One buffer per step.  Each must provide
                ``get_full(state=True)`` returning ``(obses, states, actions,
                rewards, next_obses, next_states)``; the tensors are
                presumably already on the device given to the constructor --
                verify against the buffer implementation.
        """
        assert len(buffers) == self.horizon

        self.device = self.gpu_device
        try:
            # Backwards in time: the target at step h uses the Q-values of
            # step h + 1 that were fitted in the previous iteration.
            for h in reversed(range(self.horizon)):
                obses, states, actions, rewards, next_obses, next_states = \
                    buffers[h].get_full(state=True)

                feature = self.get_feature(obses, states, actions)

                Sigma = torch.matmul(feature.T, feature) \
                    + self.lamb * torch.eye(self.feature_dim).to(self.device)
                Sigma_inv = torch.inverse(Sigma)
                self.Sigma_invs[h] = Sigma_inv

                if h == self.horizon - 1:
                    target_Q = rewards
                else:
                    next_Qs = self.Q_values(next_obses, next_states, h + 1)
                    Q_prime = torch.max(next_Qs, dim=1)[0].unsqueeze(-1)
                    target_Q = rewards + Q_prime

                self.W[h] = torch.matmul(Sigma_inv, torch.sum(feature * target_Q, 0))
        finally:
            # Always hand control back to the CPU, even if fitting failed, so
            # that later ``act`` calls do not run on the wrong device.
            self.device = "cpu"

    def save_weight(self, path):
        """Write ``W[h]`` and ``Sigma_invs[h]`` of every step into ``path``.

        Files are named ``W_<h>.pth`` and ``Sigma_<h>.pth``; the directory
        must already exist.
        """
        for h in range(self.horizon):
            torch.save(self.W[h], "{}/W_{}.pth".format(path, str(h)))
            torch.save(self.Sigma_invs[h], "{}/Sigma_{}.pth".format(path, str(h)))

    def load_weight(self, path):
        """Load the files written by ``save_weight`` from ``path``.

        Prints the step index and the loaded weights, one row per action
        block, for every step.
        """
        for h in range(self.horizon):
            self.W[h] = torch.load(
                "{}/W_{}.pth".format(path, str(h)), map_location=self.W.device
            )
            print(h)
            # One row per action instead of a hard-coded (10, 3) layout.
            print(self.W[h].reshape((self.action_dim, -1)))
            self.Sigma_invs[h] = torch.load(
                "{}/Sigma_{}.pth".format(path, str(h)),
                map_location=self.Sigma_invs.device,
            )








