import numpy as np
import seeding as seeding
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

class Phinet(nn.Module):
    def __init__(self, state_dim, action_space, phi_dim, hidden_dim=8):
        """
        Creates a new phi network to be learned from the first task.

        Maps state feature vectors to per-action features phi(s, a),
        squashed by a sigmoid into [0, 1].

        Parameters
        ----------
        state_dim : int
            dimension of a state feature vector
        action_space : int
            number of actions in the MDP
        phi_dim : int
            dimension of phi
        hidden_dim : int, optional
            width of the hidden layer (default 8, the previously
            hard-coded value)
        """
        super(Phinet, self).__init__()
        self.state_dim = state_dim
        self.action_space = action_space
        self.phi_dim = phi_dim
        self.hidden_dim = hidden_dim

        # model layers
        self.fc1 = nn.Linear(state_dim, self.hidden_dim)
        # fc2 is not used by forward() (the second hidden layer is disabled).
        # It is kept so that state_dict keys, and the order in which the
        # layers draw their random initial weights (and hence fc3's init),
        # stay unchanged.
        self.fc2 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.fc3 = nn.Linear(self.hidden_dim, self.action_space * self.phi_dim)

    def forward(self, x):
        """Return phi for the states in ``x``.

        The output is reshaped to (batch, action_space, phi_dim); a single
        1-D state vector therefore yields a batch of size 1.
        """
        x = F.relu(self.fc1(x))
        # torch.sigmoid replaces the deprecated F.sigmoid (identical values).
        x = torch.sigmoid(self.fc3(x))
        return x.reshape([-1, self.action_space, self.phi_dim])

class synthetic_env:
    """Synthetic multi-task MDP with shared dynamics and a learnable feature map.

    States are random feature vectors (``self.states``). The reward of task
    ``k`` for action ``a`` in state ``s`` is ``phi(s)[a] @ w[:, k]``, where
    ``phi`` is a :class:`Phinet`. The transition tensor ``P`` is shared by all
    tasks and indexed as ``P[s_next, s, a]``.

    The RNG returned by ``seeding.np_random`` is used through its ``rand``,
    ``randint`` and ``choice`` methods (numpy RandomState-style API).
    """

    def __init__(self,
                 state_space=30,
                 action_space=4,
                 state_dim=4,
                 phi_dim=4,
                 gamma=0.95,
                 n_tasks=4,
                 seed=0, c=0.01,
                 tildeP=False):
        """
        Parameters
        ----------
        state_space : int
            number of states
        action_space : int
            number of actions
        state_dim : int
            dimension of each state feature vector
        phi_dim : int
            dimension of phi
        gamma : float
            discount factor
        n_tasks : int
            number of reward weight vectors; must be at least 2 because
            task 1 is overwritten with a copy of task 0
        seed : int
            seed passed to ``seeding.np_random``
        c : float
            the initial ``w_tilde`` is placed at Euclidean distance exactly
            ``c`` from ``w[:, 1]``
        tildeP : bool
            if True, mix ``P`` with the uniform next-state distribution
        """
        self.np_random, seed = seeding.np_random(seed)
        self.state_space = state_space
        self.action_space = action_space
        self.state_dim = state_dim
        self.phi_dim = phi_dim
        self.gamma = gamma

        # create state features
        self.states = (self.np_random.rand(state_space, state_dim)).astype('float32')

        # create a learnable phi
        self.phi = Phinet(state_dim, action_space, phi_dim)

        # create random reward weights. Also possible to design custom weights
        self.w_scale = 1
        self.w = self.np_random.rand(phi_dim, n_tasks)

        # to train task 1 again from scratch, pretending we don't know optimal psi, w
        self.w[:, 1] = self.w[:, 0]

        # Initial guess for w[:, 1]: beta1*w1 + (1-beta1)*w_rand equals
        # w1 - c * (w_rand - w1) / ||w_rand - w1||, i.e. a point at distance
        # exactly c from w1, on the side opposite the random point w_rand.
        w_rand = self.np_random.rand(phi_dim,)
        beta1 = 1 + c / np.linalg.norm(self.w[:, 1] - w_rand)
        print()
        self.w_tilde = nn.Parameter(torch.tensor((beta1 * self.w[:, 1] + (1 - beta1) * w_rand).reshape(-1, 1)).float())
        print('self.w_tilde', self.w_tilde)

        print('w1', np.linalg.norm(self.w[:, 1]))
        print('w init distance', np.linalg.norm(self.w[:, 1] - self.w_tilde.view(-1).detach().numpy()))

        # define arbitrary transition dynamics (shared by all tasks);
        # normalising over axis 0 makes P[:, s, a] a distribution over s_next.
        P = self.np_random.rand(state_space, state_space, action_space)
        self.P = P / P.sum(axis=0)
        if tildeP:
            # Previously referenced the undefined attribute `self.gam`
            # (AttributeError); gamma is the only discount stored on self.
            # NOTE(review): the uniform term gets weight gamma here -- confirm
            # this is the intended mixing.
            self.P = (1 - self.gamma) * self.P + self.gamma / self.state_space

    def _state_tensor(self, state):
        """Return the features of ``state`` as a tensor of shape (1, state_dim)."""
        # Fancy indexing with [state] copies a (1, state_dim) array, which
        # avoids building a tensor from a Python list of ndarrays.
        return torch.tensor(self.states[[state]])

    def R(self, state, action, task):
        """Reward ``phi(state)[action] @ w[:, task]``, returned as a detached tensor."""
        # The result is never differentiated, so skip building the graph.
        with torch.no_grad():
            phi_sa = self.phi(self._state_tensor(state))[0][action]
            return phi_sa @ torch.tensor(self.w[:, task]).float()

    def seed(self, seed):
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def reset(self):
        """Sample a start state uniformly; return (features, one-hot)."""
        # randint's upper bound is exclusive: passing state_space makes every
        # state reachable (state_space - 1 never sampled the last state).
        self.state = self.np_random.randint(0, self.state_space)
        return self.states[self.state], self.onehot(self.state, self.state_space)

    def step(self, action, task):
        """Take ``action`` under ``task``'s reward.

        Returns (next-state features, reward as float, phi(s, a), False, {}).
        phi is returned with its autograd graph attached (it is not detached).
        """
        action_ = np.squeeze(action)
        r = self.R(self.state, action_, task)
        phi = self.phi(self._state_tensor(self.state))[0][action_].clone()
        self.state = self.np_random.choice(self.state_space,
                                           p=self.P[:, self.state, action_])
        return self.states[self.state], r.detach().item(), phi, False, {}

    def set_stationary_state(self, mu):
        """Sample the current state from distribution ``mu``; return (features, one-hot)."""
        self.state = self.np_random.choice(self.state_space, p=mu)
        return self.states[self.state], self.onehot(self.state, self.state_space)

    def get_opt(self, pi, task=0):
        """Return theta solving A theta + b = 0 (the linear TD fixed point).

        With state features x_s and mu the stationary distribution of ``pi``:
            A = sum_s mu(s) x_s (gamma * E[x_{s'} | s] - x_s)^T
            b = sum_s mu(s) sum_a pi[a, s] R(s, a, task) x_s

        Parameters
        ----------
        pi : np.ndarray, shape (action_space, state_space)
            policy; pi[a, s] is the probability of action a in state s
        task : int, optional
            task whose reward is used (default 0)

        Raises
        ------
        Exception
            if the stationary-distribution system is rank deficient
        """
        mu, rank, Ps = self.get_mu(pi)
        if rank != self.state_space:
            raise Exception('Non-solvable MDP.')
        X = self.states
        # Ps[s', s] = Pr(s' | s) under pi, so row s of Ps.T @ X is E[x_{s'} | s].
        td_diff = self.gamma * (Ps.T @ X) - X
        A = X.T @ (mu[:, np.newaxis] * td_diff)
        # Expected one-step reward under pi: r_pi[s] = sum_a pi[a, s] R(s, a).
        R_sa = np.array([[float(self.R(s, a, task)) for a in range(self.action_space)]
                         for s in range(self.state_space)])
        r_pi = np.sum(pi.T * R_sa, axis=1)
        b = X.T @ (mu * r_pi)
        return -np.linalg.solve(A, b)

    def get_mu(self, pi):
        """Return (mu, rank, Ps) for policy ``pi`` of shape (action_space, state_space).

        Ps[s', s] = sum_a P[s', s, a] pi[a, s]. mu is the least-squares solution
        of (I - Ps) mu = 0 subject to sum(mu) = 1, and rank is the rank of that
        stacked system.
        """
        Ps = np.einsum('tsa,as->ts', self.P, pi)
        a = np.eye(self.state_space) - Ps
        # Append the normalisation row sum(mu) = 1.
        a = np.concatenate((a, np.ones(self.state_space)[np.newaxis, :]), axis=0)
        b = np.array([0] * self.state_space + [1])
        mu, _, rank, _ = np.linalg.lstsq(a, b, rcond=None)
        return mu, rank, Ps

    def onehot(self, index, length):
        """Return a length-``length`` float vector with a 1 at ``index``."""
        v = np.zeros(length)
        v[index] = 1
        return v

    def close(self):
        pass
        
        

if __name__ == '__main__':
    # Smoke run: random actions on task 0 of a default environment.
    env = synthetic_env()
    s, s_onehot = env.reset()
    task = 0
    actions, states, rewards = [], [], []
    for _ in range(50):
        # randint's upper bound is exclusive, so this covers every action.
        action = np.random.randint(0, env.action_space)
        actions.append(action)
        # step requires a task index and returns
        # (next-state features, reward, phi, done, info).
        state, rew, phi, done, _info = env.step(action, task)
        states.append(state)
        rewards.append(rew)
        