import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import scipy.io as scio
import os 
import time
import csv
from Datagene import *
from Human_Model import *
from Policy_module import *
import sys

class PolicyNN_nreceiver_simplified(pl.LightningModule):
    """Fully connected policy network for the n-receiver case with two signals.

    The network maps a flat input of ``2 * (N + 1)`` features to ``2 * 2``
    outputs.  The outputs are viewed as a ``(2, 2)`` matrix per sample and
    soft-maxed along the last axis, so each of the two rows sums to one.

    Training minimises ``loss_fn_nreceiver`` while the soft-max sharpness
    ``beta`` is moved along a log-spaced schedule from ``beta`` to
    ``beta_end`` in ``Nsteps`` points.

    Args:
        N: number of receivers; fixes the input width ``2 * (N + 1)``.
        batch_size: mini-batch size of the training dataloader.
        batch_num: together with ``batch_size`` fixes the length of the
            ``train_beta`` bookkeeping vector (``batch_num * batch_size``).
        num_workers: worker processes of the training dataloader.
        tile: stored on the instance; not used inside this class.
        beta: first value of the beta schedule.
        fc_size: width of the three hidden layers.
        lr: Adam learning rate.
        beta_end: last value of the beta schedule; defaults to ``beta``.
        human_model: stored on the instance; not used inside this class.
        traindata_path: file read by ``train_dataloader``.
        valdata_path: file read by ``val_dataloader``.
        Nsteps: number of points of the beta schedule.
    """

    def __init__(self, N=2, batch_size=2**6, batch_num=100000, num_workers=8, tile=None, beta = 10, fc_size=512,
                 lr=1e-6, beta_end = None, human_model = None,traindata_path=None, valdata_path=None, Nsteps=100):
        super().__init__()
        self.N = N
        # Input layout (see loss_fn_nreceiver): 2 rows of N "uR" entries plus one "lam" entry.
        self.fc1 = nn.Linear(2 *( N+1), fc_size)
        self.fc2 = nn.Linear(fc_size,fc_size)
        self.fc4 = nn.Linear(fc_size,fc_size)
        # fc3 is the output layer; fc4 sits between fc2 and fc3 in forward().
        self.fc3 = nn.Linear(fc_size, 2*2 )
        self.sm = nn.Softmax(dim=2)

        self.batch_size = batch_size
        self.batch_num = batch_num
        self.num_workers = num_workers
        self.tile = tile
        self.beta = beta
        self.beta_end = beta_end if beta_end is not None else beta
        # Log-spaced schedule of Nsteps values from beta to beta_end.
        self.beta_evolve = torch.logspace(start=float(np.log10(self.beta)), end=float(np.log10(self.beta_end)), steps=Nsteps)
        self.evolve_step = 0
        self.lr = lr
        self.human_model = human_model
        self.traindata_path = traindata_path
        self.valdata_path = valdata_path

    def forward(self, x):
        """Return the policy as a ``(batch, 4)`` tensor.

        The four outputs are two consecutive pairs; each pair is the result
        of a soft-max and therefore sums to one.
        """
        out = self.fc1(x.float())
        out = F.relu(out)
        out = self.fc2(out)
        out = F.relu(out)
        out = self.fc4(out)
        out = F.relu(out)
        out = self.fc3(out)
        out = out.view(-1,2,2)
        out = self.sm(out)   # normalise over the last axis of the (2, 2) view
        out = out.view(-1,2*2)
        return out

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch and log diagnostics.

        The schedule index advances by one on every batch whose index is a
        multiple of 10000 (index 0 included) and is capped at the last entry
        of ``beta_evolve``.  The increment happens before the loss is
        evaluated.
        """
        x, _ = batch
        y_hat = self.forward(x)
        if batch_idx % 10000 == 0:
            # Cap at the real schedule length rather than a fixed 99 so that
            # any Nsteps works and the final beta is always reachable.
            last_step = len(self.beta_evolve) - 1
            self.evolve_step = min(self.evolve_step + 1, last_step)
        beta = self.beta_evolve[self.evolve_step]
        loss = self.loss_fn_nreceiver(x, y_hat, beta = beta, reduce=True, batch_idx=batch_idx)
        # NOTE(review): anomaly detection is a global autograd switch and slows
        # backward passes; kept because the original enables it on every step.
        torch.autograd.set_detect_anomaly(True)
        loss_arg = torch.tensor(0.0)
        start = batch_idx * self.batch_size
        beta_slice = self.train_beta[start : start + self.batch_size]
        if beta_slice.numel() > 0:
            meanbeta = torch.mean(beta_slice)
        else:
            # Past the end of train_beta the mean of an empty slice would be
            # NaN; every entry of train_beta equals beta_end, so log that.
            meanbeta = torch.tensor(float(self.beta_end))
        self.log("performance", {"iter": batch_idx, "loss": loss, "rmse": 0, "meanbeta":meanbeta, 'argloss':loss_arg,
                                 'beta':beta})
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the loss on a validation batch at the current schedule beta."""
        x, _ = batch
        y_hat = self(x)
        loss = self.loss_fn_nreceiver(x, y_hat, beta = self.beta_evolve[self.evolve_step], reduce=True)
        # rmse / entropy / argloss are logged as constant placeholders.
        rmse = torch.tensor(0.0)
        loss_arg = 0
        self.log("performance", {"iter": batch_idx, "val-loss": loss, "rmse": rmse,
                                 "entropy":0, 'argloss':loss_arg})
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate ``lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    @staticmethod
    def _load_xy(path, split):
        """Load the ``prob_x`` / ``prob_y`` entries from the file at ``path``.

        Raises:
            FileNotFoundError: if ``path`` is ``None`` or does not exist.
        """
        if path is None or not os.path.exists(path):
            # Fail loudly instead of exiting with status 0, which would tell
            # the calling shell that the run succeeded.
            raise FileNotFoundError(f'check {split} data path: {path}')
        # NOTE(review): torch.load unpickles the file; only load trusted data.
        data = torch.load(path)
        return data['prob_x'], data['prob_y']

    def train_dataloader(self):
        """Build the (unshuffled) training dataloader and the ``train_beta`` vector.

        Raises:
            FileNotFoundError: if ``traindata_path`` is missing.
        """
        x, y = self._load_xy(self.traindata_path, 'train')
        # Constant vector of beta_end, one entry per expected training sample.
        train_beta = torch.ones(self.batch_num * self.batch_size) * self.beta_end
        # Follow the module's device rather than assuming the first GPU.
        self.train_beta = train_beta.reshape(-1,1,1,1).to(device=self.device)
        ds = torch.utils.data.TensorDataset(x,y)
        dl = torch.utils.data.DataLoader(ds, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
        return dl

    def val_dataloader(self):
        """Build a validation dataloader that yields the whole set as one batch.

        Raises:
            FileNotFoundError: if ``valdata_path`` is missing.
        """
        x, y = self._load_xy(self.valdata_path, 'val')
        ds = torch.utils.data.TensorDataset(x,y)
        dl = torch.utils.data.DataLoader(ds, batch_size=x.size()[0], shuffle=False, num_workers=1)
        return dl

    def loss_fn_nreceiver(self, x, pred, beta = 1.0, reduce=False, batch_idx=-1):
        """Loss of the policy ``pred`` on inputs ``x``.

        Args:
            x: input batch, reshaped here to ``(bs, 2, N + 1)``; the first
                ``N`` entries of the last axis are used as ``uR`` and the
                last entry as ``lam``.
            pred: network output, reshaped here to ``(bs, 2, 1, 2)``.
            beta: scale applied before the soft-max over axis 1.
            reduce: if true, return the batch mean instead of per-sample values.
            batch_idx: accepted for call compatibility; not used.

        Returns:
            Tensor of shape ``(bs,)``, or a scalar when ``reduce`` is true.
        """
        N = self.N
        x = x.reshape(-1,2,N+1)
        uR = x[:,:,:N].reshape(-1,2,N,1)
        lam = x[:,:,-1].reshape(-1,2,1,1)
        pi = pred.reshape(-1,2,1,2)

        joint = pi * lam   # (bs, 2, 1, 2)
        # Normalise over axis 1; the 0.001 floor avoids division by ~0.
        # Out-of-place division keeps `joint` intact for reuse below.
        mu = joint / torch.maximum( joint.sum(axis=1,keepdim=True), torch.ones_like(joint)*0.001)
        UR = mu * uR   # (bs, 2, N, 2)
        UR_sm = F.softmax(beta * UR, dim=1)   # (bs, 2, N, 2), soft-max over axis 1
        # Weight the index-1 soft-max entry by the axis-1 sum of pi * lam,
        # sum over receivers and the last axis, average over N and negate.
        # NOTE(review): presumably index 1 is the choice being rewarded - confirm.
        weight = joint.sum(axis=1).reshape(-1,1,1,2)
        total = -(weight * UR_sm[:,1,:,:].reshape(-1,1,N,2)).sum(axis=(1,2,3))/float(N)
        if reduce:
            total = total.mean()
        return total
    
    
    
    
# Sweep driver: trains one policy network per (learning rate, hidden width)
# pair for the receiver count given as the first command-line argument.
bs = 1000
batch_num = 10
batch_size = 1024

# Number of receivers, also used to pick the data files and name the outputs.
M = int(sys.argv[1])

trainp = ''.join(['nreceivers/train_dataN=', str(M), '.pt'])
valp = ''.join(['nreceivers/val_dataN=', str(M), '.pt'])
pl.seed_everything(42, workers=True)


# Learning rate varies slowest, hidden width fastest.
for lr, fc_size in [(rate, width) for rate in [0.001,0.002,0.005,0.01,0.02] for width in [256,512,1024]]:
    model_kwargs = dict(
        N=M,
        batch_size=batch_size,
        batch_num=batch_num,
        num_workers=10,
        tile=None,
        beta=10,
        fc_size=fc_size,
        lr=lr,
        beta_end=200,
        human_model=None,
        Nsteps=100,
        traindata_path=trainp,
        valdata_path=valp,
    )
    policynn = PolicyNN_nreceiver_simplified(**model_kwargs)

    # One CSV log directory per configuration.
    csv_logger = CSVLogger("./logs", name="policy_nreceivers", version=f'policy4fc_n={M}_lr={lr}_fc={fc_size}')
    trainer = pl.Trainer(
        accelerator='gpu',
        deterministic=False,
        max_epochs=200,
        check_val_every_n_epoch=2,
        logger=csv_logger,
        enable_progress_bar=True,
    )

    tstart = time.time()
    print('training begin!')
    history = trainer.fit(policynn)
    print('training finished! time cost: ', time.time()-tstart)

    # Persist the trained weights next to the data files.
    torch.save(policynn.state_dict(), f'nreceivers/PolicyNN4fc_n={M}_lr={lr}_fc={fc_size}.pt')