import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from rf import RF

########################################################################
    
class Transitions_Dataset(Dataset):
    """
        Dataset of transitions read from a file via ``np.load``.

        The loaded object must be indexable by the keys ``"state"``, ``"reward"``,
        ``"next_state"`` and ``"done"``. Each entry is copied into a torch tensor
        with the requested ``dtype`` on the requested ``device``.
    """
    def __init__(self, path: str, device: str = "cpu", dtype=torch.float64) -> None:
        """
            Args:
                path: file passed to ``np.load``.
                device: device on which the four tensors are created.
                dtype: dtype applied to all four tensors (including ``done``).
        """
        keys = ("state", "reward", "next_state", "done")
        data = np.load(path)
        try:
            # torch.tensor copies, so the tensors remain valid after the file is closed.
            self.state, self.reward, self.next_state, self.done = (
                torch.tensor(data[key], dtype=dtype, device=device) for key in keys
            )
        finally:
            # An .npz archive keeps a file handle open until closed; objects
            # without a close() method (e.g. a plain ndarray) are left alone.
            close = getattr(data, "close", None)
            if close is not None:
                close()

    def __len__(self) -> int:
        """Return the number of transitions (length of ``state`` along dim 0)."""
        return len(self.state)

    def __getitem__(self, idx) -> tuple:
        """Return ``(state, reward, next_state, done)`` at index ``idx``."""
        return self.state[idx], self.reward[idx], self.next_state[idx], self.done[idx]
    
########################################################################

class LSTD(nn.Module):
    """
        Least-squares temporal-difference (LSTD) estimator on top of a feature map.

        Keeps the inverse ``A_inv`` of the regularised LSTD matrix, updated one
        transition at a time with a rank-one (Sherman-Morrison) step, and the
        reward-weighted feature sum ``b``. The weight vector is ``W = A_inv @ b``.

        Adapted from https://github.com/chrodan/tdlearn/blob/master/td.py
    """
    def __init__(self, 
                 features_fun: nn.Module = RF, 
                 weight_decay: float = 0.01, 
                 gamma: float = 0.99, 
                 batch_size: int = 100,
                 device: str = "cpu", 
                 dtype=torch.float64) -> None:
        """
            Args:
                features_fun: module mapping states to feature vectors; it must
                    expose an ``outputs`` attribute giving the feature dimension.
                weight_decay: ridge term; ``A_inv`` starts as ``I / weight_decay``.
                gamma: discount applied to the next-state features.
                batch_size: DataLoader batch size used by ``learn_offline``.
                device: device for ``W``, ``A_inv``, ``b`` and the feature module.
                dtype: dtype for ``W``, ``A_inv`` and ``b``.
        """
        super().__init__()
        # NOTE(review): ``.to`` is called on ``features_fun`` as on a module
        # instance; if the default ``RF`` is a class rather than an instance,
        # callers must pass an instance explicitly -- confirm against rf.py.
        self.features_fun = features_fun.to(device=device)
        self.weight_decay, self.gamma, self.batch_size, self.device, self.dtype = weight_decay, gamma, batch_size, device, dtype
        self.num_features = self.features_fun.outputs
        self.reset()

    def reset(self)->None:
        """Reinitialise ``W`` and ``b`` to zeros and ``A_inv`` to ``I / weight_decay``."""
        self.W=torch.zeros((self.num_features,), dtype=self.dtype, device=self.device)
        self.A_inv = torch.eye(self.num_features, dtype=self.dtype, device=self.device) / self.weight_decay
        self.b = torch.zeros(self.num_features, dtype=self.dtype, device=self.device)

    @staticmethod
    def update_torch(A_inv, b, f, delta_f, reward):
        """
            Apply one rank-one update, i.e. ``A <- A + outer(f, delta_f)`` expressed
            on the inverse, and ``b <- b + f * reward``.

            ``A_inv`` and ``b`` are modified in place and also returned.
        """
        L = A_inv @ f
        K = L / (1 + delta_f @ L)
        # The right-hand side is fully evaluated before the in-place subtraction.
        A_inv -= torch.outer(K, delta_f @ A_inv)
        b += f * reward
        return A_inv, b

    def update(self, feature:torch.Tensor, reward:torch.Tensor, next_feature:torch.Tensor)->None:
        """Fold one transition, given as features, into ``A_inv`` and ``b``."""
        # No autograd: the statistics are updated in place, so a recorded graph
        # could not be backpropagated and would only grow with every transition.
        with torch.no_grad():
            self.A_inv, self.b = self.update_torch(self.A_inv, self.b, feature, feature - self.gamma * next_feature, reward)

    def learn_offline(self, path:str, verbose:bool=True)->torch.Tensor:
        """
            Process every transition stored at ``path`` in order and solve for ``W``.

            The statistics are not reset here, so repeated calls accumulate.

            Args:
                path: file handed to ``Transitions_Dataset``.
                verbose: wrap the loader in a tqdm progress bar.

            Returns:
                The weight vector ``W = A_inv @ b`` (also stored on ``self.W``).
        """
        dataset = Transitions_Dataset(path=path, device=self.device, dtype=self.dtype)
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)
        if verbose:
            dataloader = tqdm(dataloader, desc="Processing transitions")
        with torch.no_grad():
            for state, reward, next_state, done in dataloader:
                feature = self.features_fun(state)
                last = state.shape[0] - 1
                for i in range(last):
                    # For a non-terminal row the features of the following row are
                    # reused as the next-state features (assumes rows are consecutive
                    # steps of one trajectory -- TODO confirm); terminal rows
                    # featurise ``next_state`` explicitly.
                    next_feature = self.features_fun(next_state[i]) if done[i] else feature[i + 1]
                    self.update(feature=feature[i], reward=reward[i], next_feature=next_feature)
                # The last row of a batch has no following row inside the batch.
                self.update(feature=feature[-1], reward=reward[-1], next_feature=self.features_fun(next_state[-1]))
            self.W = self.A_inv.mv(self.b)
        return self.W

    def save_weights(self, path:str)->None:
        """Save ``W`` and the feature module's state dict to ``path``."""
        torch.save({
            'W': self.W,
            'feature_fun_state_dict': self.features_fun.state_dict(),}, path)

    def load_weights(self, path:str)->None:
        """Load ``W`` and the feature module's state dict from ``path``."""
        # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
        # map_location puts the tensors on this instance's device even when the
        # checkpoint was written from a different one.
        checkpoint = torch.load(path, map_location=self.device)
        self.W = checkpoint['W']
        self.features_fun.load_state_dict(checkpoint['feature_fun_state_dict'])
    