### A human model and AI policy NN

import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import pytorch_lightning as pl
import os 
import sys

class HumanDecisionMakerNN(pl.LightningModule):
    """MLP that is fitted to (input, response) pairs of a human decision maker.

    The network maps a flat vector of ``M * (N + 1 + K)`` numbers to a flat
    vector of ``N * K`` numbers.  Before flattening, the output is viewed as
    ``(batch, N, K)`` and soft-maxed over the ``N`` axis, so each of the ``K``
    columns sums to one.

    Args:
        M, N, K: Problem sizes that fix the input/output widths.  The original
            inline note describes the input as "u^R + lambda + pi"
            (presumably receiver utility, prior and signalling scheme --
            confirm against the data generator).
        batch_size: Batch size of the training ``DataLoader``.
        batch_num: At most ``batch_size * batch_num`` training samples are used.
        num_workers: Worker count of the training ``DataLoader``.
        lr: Learning rate handed to Adam.
        traindata_path, valdata_path: Files readable by ``torch.load`` that
            contain the keys ``'human_x'`` and ``'human_y'``.
        fc1_size: Width of the three hidden layers.
        true_human: Stored on the instance; not used by the code in this class.
        data_augment: Placeholder flag; augmentation is not implemented.
        state_save_path: Prefix of the ``.pt`` files written during validation.
        reg: Regularisation setting; stored but not used by the code in this
            class.
    """

    def __init__(self, M=2,N=2,K=2,batch_size=2**6, batch_num=100000, num_workers=8,lr=1e-6, traindata_path=None, valdata_path=None, fc1_size = 512, true_human = None, data_augment=False, state_save_path = '', reg=0):
        super().__init__()
        self.fc1_size = fc1_size
        # Layer names (and their creation order) are kept so that previously
        # saved state dicts and seeded initialisations stay compatible.
        self.fc1 = nn.Linear(M * (N + 1 + K), self.fc1_size)  # u^R + lambda + pi
        self.fc2 = nn.Linear(self.fc1_size, self.fc1_size)
        self.fc4 = nn.Linear(self.fc1_size, self.fc1_size)
        self.fc3 = nn.Linear(self.fc1_size, N * K)
        self.sm = nn.Softmax(dim=1)

        self.M, self.N, self.K = M, N, K
        self.batch_size = batch_size
        self.batch_num = batch_num
        self.num_workers = num_workers
        self.lr = lr
        self.traindata_path = traindata_path
        self.valdata_path = valdata_path
        self.true_human = true_human
        self.data_augment = data_augment
        self.iter_count = 0  # number of validation steps executed so far
        self.state_save_path = state_save_path
        self.reg = reg  # regularisation part

    def forward(self, x):
        """Return the flattened ``(batch, N * K)`` response for input ``x``.

        The layers are applied in the order fc1 -> fc2 -> fc4 -> fc3 with a
        ReLU after each of the first three.
        """
        out = F.relu(self.fc1(x.float()))
        out = F.relu(self.fc2(out))
        out = F.relu(self.fc4(out))
        out = self.fc3(out)
        # Normalise over the N axis separately for each of the K columns.
        out = self.sm(out.view(-1, self.N, self.K))
        return out.view(-1, self.N * self.K)

    def configure_optimizers(self):
        """Use Adam over all parameters with learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    @staticmethod
    def _cross_entropy(y, y_hat):
        """Return ``-mean(y * log(y_hat))`` without producing NaN for ``y_hat == 0``.

        A soft-max output can underflow to exactly zero; ``0 * log(0)`` would
        then be NaN.  Flooring at the smallest positive normal value of the
        dtype leaves every other probability untouched.
        """
        floor = torch.finfo(y_hat.dtype).tiny
        return -torch.mean(y * torch.log(torch.clamp(y_hat, min=floor)))

    def training_step(self, batch, batch_idx):
        """Log MSE and cross-entropy for the batch; return the cross-entropy."""
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss_fn(y, y_hat)
        celoss = self._cross_entropy(y, y_hat)
        self.log("performance", {"iter": batch_idx, "loss": loss, "CEloss": celoss, "meanbeta":0, 'argloss':0, 'beta':0})
        return celoss

    def validation_step(self, batch, batch_idx):
        """Log validation losses and save the weights on every 10th call.

        Returns the cross-entropy of the batch.
        """
        x, y = batch
        self.iter_count += 1
        y_hat = self.forward(x)
        loss = self.loss_fn(y, y_hat)
        celoss = self._cross_entropy(y, y_hat)
        self.log("performance", {"iter": batch_idx, "val_loss": loss, "Cess": celoss, "meanbeta":0, 'argloss':0, 'beta':0})

        if self.iter_count % 10 == 0:
            torch.save(self.state_dict(), self.state_save_path + '_'+ str(self.iter_count) + '.pt')
        return celoss

    def loss_fn(self, y, pred):
        """Return the mean squared error between ``y`` and ``pred``."""
        return nn.MSELoss()(y, pred)

    def loss_fn_partial(self, y, pred, select_signal):
        """Return the cross-entropy restricted to the entries picked by ``select_signal``.

        ``select_signal`` is used to index both ``y`` and ``pred``.
        """
        # The original body read an undefined name ``y_hat`` here, so every
        # call failed with NameError; the prediction argument is ``pred``.
        return self._cross_entropy(y[select_signal], pred[select_signal])

    @staticmethod
    def _load_xy(path, hint):
        """Load ``human_x`` / ``human_y`` from ``path`` or terminate the process.

        ``hint`` is the message printed (followed by the path) when the file
        is missing or no path was configured.
        """
        if path is None or not os.path.exists(path):
            print(hint, path)
            # Non-zero status: a missing data set is a failure, not a clean exit.
            sys.exit(1)
        # NOTE(review): torch.load unpickles the file -- only load trusted data.
        data = torch.load(path)
        return data['human_x'], data['human_y']

    def train_dataloader(self):
        """Build the (unshuffled) training loader from ``self.traindata_path``."""
        x, y = self._load_xy(self.traindata_path, 'check train data path:')
        # Keep at most batch_size * batch_num samples (slicing past the end
        # is a no-op, so no explicit length check is needed).
        limit = self.batch_size * self.batch_num
        x, y = x[:limit], y[:limit]

        if self.data_augment:
            # TODO: generate more data by symmetry (not implemented).
            pass

        ds = torch.utils.data.TensorDataset(x, y)
        return torch.utils.data.DataLoader(ds, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

    def val_dataloader(self):
        """Build a loader that yields the whole validation set as one batch."""
        x, y = self._load_xy(self.valdata_path, 'check validation data path:')
        ds = torch.utils.data.TensorDataset(x, y)
        return torch.utils.data.DataLoader(ds, batch_size=x.size()[0], shuffle=False, num_workers=1)
    
    

class HumanDecisionMakerModel():
    """Closed-form (non-rational, non-Bayesian) model of a human decision maker.

    Args:
        alpha: Mixing weight between the soft-max response (``alpha``) and a
            uniform response over the ``N`` actions (``1 - alpha``).
        beta: Multiplier applied to the expected utilities before the soft-max.
        gamma: Exponent of the belief distortion ``exp(-(-log mu) ** gamma)``.
        M, N, K: Sizes used to unpack the flat input in :meth:`forward`.
    """

    _EPS = 0.001  # floor used for denominators and beliefs

    def __init__(self, alpha=1,beta=10,gamma=1,M=2,N=2,K=2):
        self.M, self.N, self.K = M, N, K
        self.alpha, self.beta, self.gamma = alpha, beta, gamma

    @classmethod
    def _floor(cls, t):
        """Element-wise ``max(t, _EPS)``."""
        return torch.maximum(t, torch.ones_like(t) * cls._EPS)

    def forward(self, x):
        """Unpack a flat input and return the response of shape ``(batch, 1, N, K)``.

        ``x`` holds, for each of the ``M`` rows, ``N`` utility values, one
        prior value and ``K`` signal probabilities, i.e. ``M * (N + 1 + K)``
        numbers per sample.
        """
        M, N, K = self.M, self.N, self.K
        # The row width is N + 1 + K.  The original used N + N + 1, which is
        # only the same number when K == N and mis-slices the input otherwise.
        x = x.reshape(-1, M, N + 1 + K)
        uR = x[:, :, :N].reshape(-1, M, N, 1)
        lam = x[:, :, N].reshape(-1, M, 1, 1)
        pi = x[:, :, N + 1:].reshape(-1, M, 1, K)
        return self.forward_info(uR, lam, pi)

    def forward_info(self, uR, lam, pi):
        """Compute the response from already unpacked tensors.

        Args:
            uR: Utilities, broadcastable as ``(batch, M, N, 1)``.
            lam: Prior, broadcastable as ``(batch, M, 1, 1)``.
            pi: Signal probabilities, broadcastable as ``(batch, M, 1, K)``.

        Returns:
            Tensor of shape ``(batch, 1, N, K)``; along axis 2 it is a mix of
            a soft-max and a uniform distribution.  If any entry of the
            result is NaN, the whole batch is replaced by the uniform
            distribution.
        """
        # Belief: pi * lam normalised over axis 1, with floored denominator.
        mu = pi * lam
        mu = mu / self._floor(mu.sum(dim=1, keepdim=True))
        mu_re = self._floor(mu)
        # Distorted belief exp(-(-log mu)^gamma), renormalised over axis 1.
        mu_h = torch.exp(-torch.pow(-torch.log(mu_re), self.gamma))
        mu_h2 = mu_h / self._floor(mu_h.sum(dim=1, keepdim=True))
        # Expected utility under the distorted belief.
        UR = (mu_h2 * uR).sum(dim=1, keepdim=True)
        UR_sm = F.softmax(self.beta * UR, dim=2)
        # Soft-max of a constant tensor == uniform distribution over axis 2.
        UR_random = F.softmax(torch.ones_like(UR), dim=2)
        UR_h = self.alpha * UR_sm + (1 - self.alpha) * UR_random
        if torch.isnan(UR_h).any():
            UR_h = F.softmax(torch.ones_like(UR), dim=2)
        return UR_h
    
    
class RationalHuman():
    """Decision maker that always picks the action with the highest expected utility.

    Args:
        M, N, K: Sizes used to unpack the flat input in :meth:`forward`.
    """

    _EPS = 0.001  # floor used for the denominator and the belief

    def __init__(self, M=2,N=2,K=2):
        self.M, self.N, self.K = M, N, K

    @classmethod
    def _floor(cls, t):
        """Element-wise ``max(t, _EPS)``."""
        return torch.maximum(t, torch.ones_like(t) * cls._EPS)

    def forward(self, x):
        """Unpack a flat input and return a one-hot response ``(batch, 1, N, K)``.

        ``x`` holds, for each of the ``M`` rows, ``N`` utility values, one
        prior value and ``K`` signal probabilities, i.e. ``M * (N + 1 + K)``
        numbers per sample.
        """
        M, N, K = self.M, self.N, self.K
        # The row width is N + 1 + K.  The original used N + N + 1, which is
        # only the same number when K == N and mis-slices the input otherwise.
        x = x.reshape(-1, M, N + 1 + K)
        uR = x[:, :, :N].reshape(-1, M, N, 1)
        lam = x[:, :, N].reshape(-1, M, 1, 1)
        pi = x[:, :, N + 1:].reshape(-1, M, 1, K)
        return self.forward_info(uR, lam, pi)

    def forward_info(self, uR, lam, pi):
        """Compute the one-hot response from already unpacked tensors.

        Args:
            uR: Utilities, broadcastable as ``(batch, M, N, 1)``.
            lam: Prior, broadcastable as ``(batch, M, 1, 1)``.
            pi: Signal probabilities, broadcastable as ``(batch, M, 1, K)``.

        Returns:
            Tensor of shape ``(batch, 1, N, K)`` that is 1 at the arg-max over
            axis 2 and 0 elsewhere.
        """
        # Belief: pi * lam normalised over axis 1, then floored at _EPS.
        mu = pi * lam
        mu = mu / self._floor(mu.sum(dim=1, keepdim=True))
        mu_re = self._floor(mu)

        UR = (mu_re * uR).sum(dim=1, keepdim=True)
        # keepdim keeps the index 4-D, which is the shape scatter_ needs.
        best = torch.argmax(UR, dim=2, keepdim=True)
        return torch.zeros_like(UR).scatter_(dim=2, index=best, src=torch.ones_like(UR))
    
class PolicyNeuralNetwork(pl.LightningModule):
    """Policy network whose training loss is defined through a human model.

    The network maps a flat vector of ``M * (2 * N + 1)`` numbers to a flat
    vector of ``M * K`` numbers.  Before flattening, the output is viewed as
    ``(batch, M, K)`` and soft-maxed over the ``K`` axis.

    Args:
        M, N, K: Problem sizes that fix the input/output widths.
        batch_size: Batch size of the training ``DataLoader``.
        batch_num: Together with ``batch_size`` fixes the length of the
            ``train_beta`` tensor created in :meth:`train_dataloader`.
        num_workers: Worker count of the training ``DataLoader``.
        tile: Stored on the instance; not used by the code in this class.
        beta: Soft-max multiplier of the built-in human response and start of
            the ``beta_evolve`` schedule.
        lr: Learning rate handed to Adam.
        beta_end: End of the ``beta_evolve`` schedule; defaults to ``beta``.
        human_model: Optional object with a ``forward(x=...)`` method; when
            ``None`` a soft-max response is used instead.
        traindata_path, valdata_path: Files readable by ``torch.load`` that
            contain the keys ``'prob_x'`` and ``'prob_y'``.
    """

    def __init__(self, M=2, N=2, K=2, batch_size=2**6, batch_num=100000, num_workers=8, tile=None, beta = 10, lr=1e-6, beta_end = None, human_model = None,traindata_path=None, valdata_path=None):
        super().__init__()
        self.M, self.N, self.K = M, N, K
        self.fc1 = nn.Linear(M * (N + N + 1), 512)  # u^S + u^R + lambda
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, M * K)
        self.sm = nn.Softmax(dim=2)

        self.batch_size = batch_size
        self.batch_num = batch_num
        self.num_workers = num_workers
        self.tile = tile
        self.beta = beta
        self.beta_end = beta if beta_end is None else beta_end
        # Log-spaced schedule from beta to beta_end with a fixed 1000 steps.
        self.beta_evolve = torch.logspace(start=float(np.log10(self.beta)), end=float(np.log10(self.beta_end)), steps=1000)
        self.evolve_step = 0  # current index into beta_evolve
        self.lr = lr
        self.human_model = human_model
        self.traindata_path = traindata_path
        self.valdata_path = valdata_path

    def forward(self, x):
        """Return the flattened ``(batch, M * K)`` policy for input ``x``."""
        out = F.relu(self.fc1(x.float()))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        # Normalise over the K axis separately for each of the M rows.
        out = self.sm(out.view(-1, self.M, self.K))
        return out.view(-1, self.M * self.K)

    def training_step(self, batch, batch_idx):
        """Compute the human-model loss for the batch, log diagnostics, return the loss.

        Expects :meth:`train_dataloader` to have run, because it creates
        ``self.train_beta``.
        """
        x, y = batch
        y_hat = self.forward(x)
        # Advance the beta schedule once every 100 batches, capped at its end.
        if batch_idx % 100 == 0:
            self.evolve_step = min(self.evolve_step + 1, len(self.beta_evolve) - 1)
        loss = self.loss_fn_human(x, y_hat, human_model=self.human_model, reduce=True)

        # NOTE(review): this switches on autograd anomaly detection globally
        # on every step; kept as in the original, but it is a debugging aid.
        torch.autograd.set_detect_anomaly(True)
        rmse = torch.sqrt(nn.MSELoss()(y, y_hat))
        lo = batch_idx * self.batch_size
        meanbeta = torch.mean(self.train_beta[lo:lo + self.batch_size])
        self.log("performance", {"iter": batch_idx, "loss": loss, "rmse": rmse, "meanbeta":meanbeta, 'argloss':0, 'beta':self.beta_evolve[self.evolve_step]})
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the human-model loss for a validation batch."""
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn_human(x, y_hat, human_model=self.human_model, reduce=True)
        rmse = torch.tensor(0.0)  # placeholder value, logged as-is
        self.log("performance", {"iter": batch_idx, "val-loss": loss, "rmse": rmse, "entropy":F.cross_entropy(y_hat, y), 'argloss':0})
        return loss

    def configure_optimizers(self):
        """Use Adam over all parameters with learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    @staticmethod
    def _load_xy(path, hint):
        """Load ``prob_x`` / ``prob_y`` from ``path`` or terminate the process.

        ``hint`` is the message printed (followed by the path) when the file
        is missing or no path was configured.
        """
        if path is None or not os.path.exists(path):
            print(hint, path)
            # Non-zero status: a missing data set is a failure, not a clean exit.
            sys.exit(1)
        # NOTE(review): torch.load unpickles the file -- only load trusted data.
        data = torch.load(path)
        return data['prob_x'], data['prob_y']

    def train_dataloader(self):
        """Build the (unshuffled) training loader and the ``train_beta`` tensor."""
        x, y = self._load_xy(self.traindata_path, 'check train data path:')

        # Per-sample beta used for the "meanbeta" diagnostic in training_step.
        # The original always placed it on 'cuda:0', which fails on machines
        # without CUDA; fall back to the CPU there.
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        train_beta = torch.ones(self.batch_num * self.batch_size) * self.beta_end
        self.train_beta = train_beta.reshape(-1, 1, 1, 1).to(device=device)

        ds = torch.utils.data.TensorDataset(x, y)
        return torch.utils.data.DataLoader(ds, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

    def val_dataloader(self):
        """Build a loader that yields the whole validation set as one batch."""
        x, y = self._load_xy(self.valdata_path, 'check val data path:')
        ds = torch.utils.data.TensorDataset(x, y)
        return torch.utils.data.DataLoader(ds, batch_size=x.size()[0], shuffle=False, num_workers=1)

    def loss_fn_human(self, x,pred,human_model = None,reduce=False,batch_idx=-1):
        """Return ``-sum(lam * pi * uS * response)`` for the policy ``pred``.

        Args:
            x: Flat network input; reshaped to ``(batch, M, 2 * N + 1)`` and
                sliced as ``uR = [:N]``, ``uS = [N:2N]``, ``lam = [2N:]``.
                NOTE(review): the ``fc1`` comment lists the order as
                "u^S + u^R + lambda" -- confirm which order the data uses.
            pred: Policy output, reshaped to ``(batch, M, 1, K)``.
            human_model: Object whose ``forward(x=...)`` returns something
                reshapeable to ``(batch, 1, N, K)``.  When ``None`` a soft-max
                response with multiplier ``self.beta`` is used.
                NOTE(review): that branch uses the fixed ``self.beta``, not
                the ``beta_evolve`` schedule -- confirm this is intended.
            reduce: If true, return the batch mean instead of per-sample values.
            batch_idx: Unused; kept for interface compatibility.
        """
        M, N, K = self.M, self.N, self.K
        x = x.reshape(-1, M, N + N + 1)
        uR = x[:, :, :N].reshape(-1, M, N, 1)
        uS = x[:, :, N:2 * N].reshape(-1, M, N, 1)
        lam = x[:, :, 2 * N:].reshape(-1, M, 1, 1)
        pi = pred.reshape(-1, M, 1, K)

        # Identity test instead of truthiness: a model object that defines
        # __len__ or __bool__ must not be mistaken for "no model".
        if human_model is None:
            # Default response: soft-max over expected utilities under the
            # belief pi * lam normalised over axis 1 (denominator floored).
            mu = pi * lam
            mu = mu / torch.maximum(mu.sum(dim=1, keepdim=True), torch.ones_like(mu) * 0.001)
            UR = (mu * uR).sum(dim=1, keepdim=True)
            response = F.softmax(self.beta * UR, dim=2)
        else:
            # Human-model input layout per row: uR (N), lam (1), pi (K).
            human_x = torch.cat(
                (uR.reshape(-1, M, N), lam.reshape(-1, M, 1), pred.reshape(-1, M, K)),
                dim=2,
            ).reshape(-1, M * (N + 1 + K))
            response = human_model.forward(x=human_x).reshape(-1, 1, N, K)

        total = -(lam * pi * uS * response).sum(dim=(1, 2, 3))
        if reduce:
            total = total.mean()
        return total