import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import scipy.io as scio
import os 
import time
import csv
from Datagene import *
from Human_Model import *
from Policy_module import *
import sys

class policyNN_single(pl.LightningModule):
    """Fully connected policy network trained with a softmax-relaxed loss.

    The network maps a flat input of size ``M * (2 * N + 1)`` to ``M * K``
    outputs.  The outputs are reshaped to ``(batch, M, K)`` and normalised
    with a softmax over the ``K`` axis, so each of the ``M`` rows is a
    probability vector, and are then flattened back to ``(batch, M * K)``.

    Each of the ``M`` input rows is sliced by the loss functions into ``N``
    values named ``uR``, ``N`` values named ``uS`` and one value named
    ``lam`` (presumably receiver utilities, sender utilities and a prior
    weight -- confirm against the data generator).

    Args:
        M, N, K: problem dimensions (see above).
        batch_size: batch size of the training data loader.
        batch_num: stored on the instance; not used inside this class.
        num_workers: worker count of the training data loader.
        tile: stored on the instance; not used inside this class.
        beta: initial softmax sharpness used by ``loss_fn`` during training.
        lr: Adam learning rate.
        beta_end: final softmax sharpness; defaults to ``beta`` (constant
            schedule) when ``None``.
        traindata_path, valdata_path: ``torch.load``-able files holding a
            dict with the keys ``'prob_x'`` and ``'prob_y'``.
        Nsteps: number of log-spaced entries in the beta schedule.
        fc_size: width of the hidden layers.
    """

    def __init__(self, M=2, N=2, K=2, batch_size=2**6, batch_num=100000, num_workers=8, tile=None, beta=10,
                 lr=1e-6, beta_end=None, traindata_path=None, valdata_path=None, Nsteps=100, fc_size=512):
        super().__init__()
        self.M, self.N, self.K = M, N, K

        # Layers are created in the original order (fc1, fc2, fc3, fc4) so
        # that seeded initialisation and state-dict keys stay unchanged.
        # forward() applies them as fc1 -> fc2 -> fc4 -> fc3.
        self.fc1 = nn.Linear(M * (N + N + 1), fc_size)  # per row: N (uR) + N (uS) + 1 (lam)
        self.fc2 = nn.Linear(fc_size, fc_size)
        self.fc3 = nn.Linear(fc_size, M * K)
        self.fc4 = nn.Linear(fc_size, fc_size)
        self.sm = nn.Softmax(dim=2)

        self.batch_size = batch_size
        self.batch_num = batch_num
        self.num_workers = num_workers
        self.tile = tile
        self.beta = beta
        self.beta_end = beta_end if beta_end is not None else beta
        # Log-spaced schedule from beta to beta_end with a fixed number of steps.
        self.beta_evolve = torch.logspace(start=float(np.log10(self.beta)),
                                          end=float(np.log10(self.beta_end)),
                                          steps=Nsteps)
        self.evolve_step = 0
        self.lr = lr
        self.traindata_path = traindata_path
        self.valdata_path = valdata_path

    def forward(self, x):
        """Return the flattened ``(batch, M * K)`` policy for input ``x``."""
        out = F.relu(self.fc1(x))
        out = F.relu(self.fc2(out))
        out = F.relu(self.fc4(out))
        out = self.fc3(out)
        # Softmax over the K axis of each of the M rows, then flatten again.
        out = self.sm(out.view(-1, self.M, self.K))
        return out.view(-1, self.M * self.K)

    def training_step(self, batch, batch_idx):
        """Compute the softmax-relaxed loss for one batch and log diagnostics.

        Whenever ``batch_idx`` is a multiple of 100 (including batch 0 of
        every epoch) the position in the beta schedule advances by one,
        clamped to the last entry of ``self.beta_evolve``.
        """
        x, _ = batch
        y_hat = self(x)
        if batch_idx % 100 == 0:
            # Clamp to the real schedule length instead of a hard-coded 99 so
            # that any Nsteps value is handled.
            last_step = len(self.beta_evolve) - 1
            self.evolve_step = min(self.evolve_step + 1, last_step)
        beta_now = self.beta_evolve[self.evolve_step]
        loss = self.loss_fn(x, y_hat, beta=float(beta_now), reduce=True)
        # The argmax loss is only logged, so no autograd graph is needed.
        with torch.no_grad():
            loss_arg = self.loss_fn_argmax(x, y_hat, reduce=True)
        self.log("performance", {"iter": batch_idx, "loss": loss, "rmse": 0, "meanbeta": 0,
                                 'argloss': loss_arg, 'beta': beta_now})
        return loss

    def validation_step(self, batch, batch_idx):
        """Evaluate both losses on a validation batch at the current beta."""
        x, _ = batch
        y_hat = self(x)
        loss = self.loss_fn(x, y_hat, beta=float(self.beta_evolve[self.evolve_step]), reduce=True)
        rmse = torch.tensor(0.0)  # placeholder value kept in the logged dict
        loss_arg = self.loss_fn_argmax(x, y_hat, reduce=True)
        self.log("performance", {"iter": batch_idx, "val-loss": loss, "rmse": rmse, "entropy": 0,
                                 'argloss': loss_arg})
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    @staticmethod
    def _load_xy(data_path, split):
        """Load the ``'prob_x'`` / ``'prob_y'`` tensors stored at ``data_path``.

        Raises:
            FileNotFoundError: if ``data_path`` is ``None`` or does not exist.
        """
        if data_path is None or not os.path.exists(data_path):
            raise FileNotFoundError(f'check {split} data path: {data_path}')
        # NOTE(review): torch.load unpickles the file; only load trusted data.
        data = torch.load(data_path)
        return data['prob_x'], data['prob_y']

    def train_dataloader(self):
        """Return the unshuffled training loader built from ``traindata_path``."""
        x, y = self._load_xy(self.traindata_path, 'train')
        ds = torch.utils.data.TensorDataset(x, y)
        return torch.utils.data.DataLoader(ds, batch_size=self.batch_size, shuffle=False,
                                           num_workers=self.num_workers)

    def val_dataloader(self):
        """Return a loader that yields the whole validation set as one batch."""
        x, y = self._load_xy(self.valdata_path, 'val')
        ds = torch.utils.data.TensorDataset(x, y)
        return torch.utils.data.DataLoader(ds, batch_size=x.size()[0], shuffle=False, num_workers=1)

    def _unpack(self, x, pred):
        """Split ``x`` and ``pred`` into broadcastable 4-D factors.

        Returns:
            ``(uR, uS, lam, pi, mu)`` with shapes ``(B, M, N, 1)``,
            ``(B, M, N, 1)``, ``(B, M, 1, 1)``, ``(B, M, 1, K)`` and
            ``(B, M, 1, K)``.  ``mu`` is ``pi * lam`` normalised over the
            ``M`` axis; the normaliser is floored at 0.001 to avoid a
            division by zero.
        """
        M, N, K = self.M, self.N, self.K
        x = x.reshape(-1, M, N + N + 1)
        uR = x[:, :, :N].reshape(-1, M, N, 1)
        uS = x[:, :, N:2 * N].reshape(-1, M, N, 1)
        lam = x[:, :, 2 * N:].reshape(-1, M, 1, 1)
        pi = pred.reshape(-1, M, 1, K)

        mu = pi * lam
        mu = mu / torch.maximum(mu.sum(dim=1, keepdim=True), torch.ones_like(mu) * 0.001)
        return uR, uS, lam, pi, mu

    def loss_fn_argmax(self, x, pred, reduce=False):
        """Loss with a hard (argmax) choice over the ``N`` axis.

        For every column of ``pred`` the index along the ``N`` axis that
        maximises ``sum_M(mu * uR) + 0.001 * sum_M(mu * uS)`` is selected
        (the small ``uS`` term only affects near-ties).  The result is the
        negative sum of ``lam * pi * uS`` at the selected indices.

        Returns:
            A ``(B,)`` tensor, or its mean as a scalar when ``reduce`` is true.
        """
        uR, uS, lam, pi, mu = self._unpack(x, pred)
        UR = (mu * uR).sum(dim=1, keepdim=True)
        US = (mu * uS).sum(dim=1, keepdim=True)

        score = UR + 0.001 * US
        best = torch.argmax(score, dim=2, keepdim=True)  # (B, 1, 1, K)
        choice = torch.zeros_like(score).scatter_(2, best, 1.0)  # one-hot along the N axis

        total = -(lam * pi * uS * choice).sum(dim=(1, 2, 3))
        if reduce:
            total = total.mean()
        return total

    def loss_fn(self, x, pred, beta=300, reduce=False, verbose=False, beta_change=False, batch_idx=-1):
        """Differentiable loss using a softmax choice over the ``N`` axis.

        Same quantity as ``loss_fn_argmax`` but the hard selection is replaced
        by ``softmax(beta * sum_M(mu * uR))`` over the ``N`` axis; larger
        ``beta`` makes the selection sharper.

        Args:
            x: flat inputs, reshapeable to ``(B, M, 2 * N + 1)``.
            pred: network output, reshapeable to ``(B, M, K)``.
            beta: softmax sharpness.
            reduce: return the batch mean instead of per-sample values.
            verbose: print the intermediate tensors and their shapes.
            beta_change: retired option; must stay ``False``.
            batch_idx: unused, kept for interface compatibility.

        Returns:
            A ``(B,)`` tensor, or its mean as a scalar when ``reduce`` is true.

        Raises:
            NotImplementedError: if ``beta_change`` is true.
        """
        if beta_change:
            # The retired branch never produced a result; fail explicitly
            # instead of crashing later on an unassigned variable.
            raise NotImplementedError('loss_fn(beta_change=True) should no longer be called')

        uR, uS, lam, pi, mu = self._unpack(x, pred)
        # NOTE(review): gradient tracking on uR is kept from the original
        # code; nothing in this class reads uR.grad -- confirm it is needed.
        uR = uR.requires_grad_(True)
        UR = (mu * uR).sum(dim=1, keepdim=True)

        if verbose:
            print(UR.shape)
            print(UR)

        UR_sm = F.softmax(beta * UR, dim=2)
        total = -(lam * pi * uS * UR_sm).sum(dim=(1, 2, 3))

        if verbose:
            print('----------')
            print(UR_sm.shape)
            print(UR_sm)

        if reduce:
            total = total.mean()
        return total
    
    
# Hyper-parameter sweep over learning rate and hidden-layer width.
# Usage: python <this file> M N   (K is set equal to N)
bs = 1000
batch_num = 16
batch_size = 64
M = int(sys.argv[1])
N = int(sys.argv[2])
K = N
path = ''
tstart = time.time()  # measured once, so the reported time is cumulative across runs
trainp = path + f'onereceiver/train_dataM={M}N={N}K={K}.pt'
valp = path + f'onereceiver/val_dataM={M}N={N}K={K}.pt'
pl.seed_everything(42, workers=True)
for lr in [0.001, 0.002]:
    for fc_size in [256, 512, 1024]:
        # fc_size must be forwarded to the model; otherwise every run trains
        # the default width while being logged and saved under a different one.
        policynn = policyNN_single(M=M, N=N, K=K, batch_size=batch_size, batch_num=batch_num, num_workers=10,
                                   tile=None, beta=10, lr=lr, beta_end=200, Nsteps=100,
                                   traindata_path=trainp, valdata_path=valp, fc_size=fc_size)
        csv_logger = CSVLogger("./logs", name="policy_onereceiver",
                               version=f'policy_4fc_M{M}N{N}K{K}lr{lr}fc{fc_size}')
        trainer = pl.Trainer(accelerator='gpu', deterministic=False, max_epochs=200, check_val_every_n_epoch=10,
                             logger=csv_logger, enable_progress_bar=False)
        print('training begin!')
        history = trainer.fit(policynn)
        print('training finished! time cost: ', time.time()-tstart, ' for parameter:(M,lr,fc_size)=', M,lr,fc_size)
        torch.save(policynn.state_dict(), f'onereceiver/PolicyNN_4fc_M{M}N{N}K{K}lr{lr}fc{fc_size}.pt')
    