### model structure 

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torch.nn as nn
from Datagene import getXY, getXY_nstate_2action, lp_solver, decode_tensor
import os 
import numpy as np

### Define model
class NeuralNetwork(pl.LightningModule):
    """MLP mapping a flattened per-row input to an M x K row-stochastic matrix ``pi``.

    The input is reshaped by the loss functions to ``(batch, M, 2N+1)``. In each of
    the M rows, the first N entries are ``u^R``, the next N are ``u^S``, and the last
    one is ``λ``. The network output is soft-maxed over K within each of the M rows
    and returned flattened as ``(batch, M*K)``.

    Training minimizes :meth:`loss_fn`, a softmax-smoothed objective with inverse
    temperature ``beta``. ``beta`` is annealed log-uniformly from ``beta`` to
    ``beta_end`` over 1000 steps. It advances one step whenever
    ``batch_idx % 100 == 0``.
    """

    def __init__(self, M=2, N=2, K=2, batch_size=2**6, batch_num=100000, num_workers=8, tile=None, beta=10, lr=1e-6, beta_end=None, outlier_mode=None, traindata_path=None, valdata_path=None):
        super().__init__()
        self.M, self.N, self.K = M, N, K
        self.fc1 = nn.Linear(M * (N + N + 1), 512)  # per row: u^R, u^S, λ (slice order used by the losses)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, M * K)
        self.sm = nn.Softmax(dim=2)  # normalizes over K within each of the M rows
        self.batch_size = batch_size
        self.batch_num = batch_num
        self.num_workers = num_workers
        self.tile = tile  # stored only; not read anywhere in this class
        self.beta = beta
        self.beta_end = beta_end if beta_end is not None else beta
        # Annealing schedule: 1000 log-spaced inverse temperatures from beta to beta_end.
        self.beta_evolve = torch.logspace(start=float(np.log10(self.beta)), end=float(np.log10(self.beta_end)), steps=1000)
        self.evolve_step = 0
        self.lr = lr
        self.outlier_mode = outlier_mode
        if outlier_mode:
            # NOTE(review): torch.load unpickles its input; only point it at trusted files.
            data = torch.load('model1out.pt')
            self.train_v = data['train'].detach().numpy()
            self.val_v = data['val'].detach().numpy()
            # Boolean masks keeping samples whose first column exceeds 0.99.
            self.train_idx = self.train_v[:, 0] > 0.99
            # NOTE(review): val_idx is computed but never applied by val_dataloader.
            self.val_idx = self.val_v[:, 0] > 0.99
        self.traindata_path = traindata_path
        self.valdata_path = valdata_path

    def forward(self, x):
        """Return ``pi`` flattened as ``(batch, M*K)``; each length-K group sums to 1."""
        out = F.relu(self.fc1(x))
        out = F.relu(self.fc2(out))
        out = self.fc3(out).view(-1, self.M, self.K)
        return self.sm(out).view(-1, self.M * self.K)

    def training_step(self, batch, batch_idx):
        """Advance the beta schedule, compute the smoothed loss and log diagnostics."""
        x, y = batch
        y_hat = self(x)
        if batch_idx % 100 == 0:
            self.evolve_step = min(self.evolve_step + 1, len(self.beta_evolve) - 1)
        loss = self.loss_fn(x, y_hat, beta=float(self.beta_evolve[self.evolve_step]), reduce=True)
        rmse = torch.sqrt(F.mse_loss(y_hat, y))
        loss_arg = self.loss_fn_argmax(x, y_hat, reduce=True)
        start = batch_idx * self.batch_size
        meanbeta = torch.mean(self.train_beta[start:start + self.batch_size])
        self.log("performance", {"iter": batch_idx, "loss": loss, "rmse": rmse, "meanbeta": meanbeta, 'argloss': loss_arg, 'beta': self.beta_evolve[self.evolve_step]})
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the smoothed loss at the current beta plus RMSE/argmax diagnostics."""
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(x, y_hat, beta=float(self.beta_evolve[self.evolve_step]), reduce=True)
        rmse = torch.sqrt(F.mse_loss(y_hat, y))
        loss_arg = self.loss_fn_argmax(x, y_hat, reduce=True)
        # NOTE(review): y_hat is already softmax output, and F.cross_entropy applies
        # log_softmax over all M*K columns again; confirm this metric is what is intended.
        self.log("performance", {"iter": batch_idx, "val-loss": loss, "rmse": rmse, "entropy": F.cross_entropy(y_hat, y), 'argloss': loss_arg})
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    @staticmethod
    def _load_xy(path, split):
        """Load ``(prob_x, prob_y)`` saved at *path*; raise FileNotFoundError if absent."""
        if path is None or not os.path.exists(path):
            raise FileNotFoundError(f'check {split} data path: {path}')
        # NOTE(review): torch.load unpickles its input; only point it at trusted files.
        data = torch.load(path)
        return data['prob_x'], data['prob_y']

    def train_dataloader(self):
        """Build the (unshuffled) training loader and reset the per-sample beta buffer."""
        x, y = self._load_xy(self.traindata_path, 'train')

        # The outlier filter must be applied before the dataset is built, otherwise it has no effect.
        if self.outlier_mode:
            keep = torch.as_tensor(self.train_idx)
            x, y = x[keep], y[keep]
            print('in train, ', x.size(), y.size())

        # Per-sample inverse temperatures, indexed by batch_idx * batch_size (hence shuffle=False).
        # Sized to cover the whole dataset so no batch reads an empty slice.
        n_beta = max(self.batch_num * self.batch_size, len(x))
        self.train_beta = torch.full((n_beta, 1, 1, 1), float(self.beta_end), device=self.device)
        ds = torch.utils.data.TensorDataset(x, y)
        return torch.utils.data.DataLoader(ds, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

    def val_dataloader(self):
        """Build a validation loader that yields the whole validation set as one batch."""
        x, y = self._load_xy(self.valdata_path, 'val')
        ds = torch.utils.data.TensorDataset(x, y)
        # max(1, ...) keeps DataLoader valid if the saved set happens to be empty.
        return torch.utils.data.DataLoader(ds, batch_size=max(1, x.size(0)), shuffle=False, num_workers=1)

    def loss_fn_single(self, x, pred, beta=10):
        """Loop-based reference of :meth:`loss_fn` for a single, unbatched sample.

        Returns the negated objective (a scalar tensor).
        """
        M, N, K = self.M, self.N, self.K
        x = x.reshape(M, N + N + 1)
        uR, uS, lam = x[:, :N], x[:, N:2 * N], x[:, 2 * N:2 * N + 1]
        pi = pred.reshape(M, K)

        # Equation 2
        def get_mu(m, k):
            mu = [pi[m1, k] * lam[m1] for m1 in range(M)]
            return mu[m] / sum(mu)

        # Equation 3
        def get_UR(n, k):
            total = 0
            for m in range(M):
                total += get_mu(m, k) * uR[m, n]
            return total

        # Equation 10. The softmax terms depend only on (n, k), so compute them once
        # up front instead of inside the innermost loop.
        exp_UR = [[torch.exp(beta * get_UR(n, k)) for n in range(N)] for k in range(K)]
        norm = [sum(row) for row in exp_UR]
        total = 0
        for m in range(M):
            for k in range(K):
                for n in range(N):
                    total += lam[m] * pi[m, k] * uS[m, n] * exp_UR[k][n] / norm[k]
        return -total

    def _unpack(self, x, pred):
        """Split flat inputs into broadcastable uR, uS (B,M,N,1), lam (B,M,1,1) and pi (B,M,1,K)."""
        M, N, K = self.M, self.N, self.K
        x = x.reshape(-1, M, N + N + 1)
        uR = x[:, :, :N].reshape(-1, M, N, 1)
        uS = x[:, :, N:2 * N].reshape(-1, M, N, 1)
        lam = x[:, :, 2 * N:].reshape(-1, M, 1, 1)
        pi = pred.reshape(-1, M, 1, K)
        return uR, uS, lam, pi

    @staticmethod
    def _posterior(lam, pi):
        """Return pi*lam normalized over the M axis (dim 1); the denominator is floored at 1e-3."""
        mu = pi * lam
        return mu / torch.maximum(mu.sum(dim=1, keepdim=True), torch.ones_like(mu) * 0.001)

    def loss_fn_argmax(self, x, pred, reduce=False):
        """Negated objective using a hard argmax over N instead of a softmax.

        Ties in the u^R-weighted score are broken by a 0.001 * u^S term. The one-hot
        selection is not differentiable, so this is a diagnostic, not a training loss.
        Returns a ``(batch,)`` tensor, or its mean when ``reduce`` is true.
        """
        uR, uS, lam, pi = self._unpack(x, pred)
        mu = self._posterior(lam, pi)
        UR = (mu * uR).sum(dim=1, keepdim=True)
        # sender utility
        US = (mu * uS).sum(dim=1, keepdim=True)

        # Argmax over N (dim 2), with a small u^S tie-breaker.
        scores = UR + 0.001 * US
        idx = torch.argmax(scores, dim=2)
        best = torch.zeros_like(scores).scatter_(dim=2, index=idx.unsqueeze(1), src=torch.ones_like(scores))
        total = -(lam * pi * uS * best).sum(dim=(1, 2, 3))
        if reduce:
            total = total.mean()
        return total

    def loss_fn(self, x, pred, beta=300, reduce=False, verbose=False, beta_change=False, batch_idx=-1):
        """Negated softmax-smoothed objective (vectorized form of :meth:`loss_fn_single`).

        With ``beta_change`` false, ``softmax(beta * UR)`` over N is used and a ``(batch,)``
        tensor is returned (or its mean if ``reduce``). With ``beta_change`` true, per-sample
        inverse temperatures from ``self.train_beta`` are used for the slice selected by
        ``batch_idx``. Samples whose median |d loss / d u^R| is <= 1e-3 have their beta
        divided by 1.5. All betas are floored at the current schedule value, written back
        to ``self.train_beta``, and the (already averaged) loss is recomputed with them.
        """
        M, N = self.M, self.N
        uR, uS, lam, pi = self._unpack(x, pred)
        if beta_change:
            # Only the adaptive-beta branch differentiates with respect to uR.
            uR.requires_grad_(True)
        mu = self._posterior(lam, pi)
        UR = (mu * uR).sum(dim=1, keepdim=True)

        if verbose:
            print(UR.shape)
            print(UR)
        # Softmax
        UR_sm = UR
        if not beta_change:
            UR_sm = F.softmax(beta * UR_sm, dim=2)
            total = -(lam * pi * uS * UR_sm).sum(dim=(1, 2, 3))
        else:
            window = slice(batch_idx * self.batch_size, (batch_idx + 1) * self.batch_size)
            floor = float(self.beta_evolve[self.evolve_step])
            UR_sm_beta = F.softmax(self.train_beta[window] * UR_sm, dim=2)
            total = -(lam * pi * uS * UR_sm_beta).sum(dim=(1, 2, 3)).mean()
            A = torch.autograd.grad(outputs=total, inputs=uR, retain_graph=True)[0]
            idx = torch.median(torch.abs(A).reshape(-1, M * N), dim=1)[0] <= 1e-3
            tmp_beta = self.train_beta[window]
            # In-place writes through the slice view update self.train_beta itself.
            self.train_beta[window][idx] = torch.maximum(tmp_beta[idx] / 1.5, torch.ones_like(tmp_beta[idx]) * floor)
            self.train_beta[window][~idx] = torch.maximum(tmp_beta[~idx], torch.ones_like(tmp_beta[~idx]) * floor)
            UR_sm_beta = F.softmax(self.train_beta[window] * UR_sm, dim=2)
            total = -(lam * pi * uS * UR_sm_beta).sum(dim=(1, 2, 3)).mean()

        if verbose:
            print('----------')
            print(UR_sm.shape)
            print(UR_sm)

        if reduce:
            total = total.mean()
        return total
    