import pytorch_lightning as pl
import numpy as np 
import torch
import torch.optim as optim

from torch import nn 
from torch.nn import Softmax
from pytorch_lightning import LightningModule
from torch_geometric.nn import MessagePassing, aggr, InstanceNorm
from utils import *

# Module-level problem constants.
# NOTE(review): `nt` is not referenced anywhere in this file; presumably the
# number of time samples per trajectory -- confirm against the dataset.
nt = 250
# Number of copies of each time value made when building sensor/query grids
# in VIDON.process_batch (presumably the number of spatial points -- confirm).
nx = 50
# Period used by periodic_don: inputs are scaled by 2*pi/L before cos/sin.
L = 16

def periodic_don(x):
    """Return first- and second-harmonic Fourier features of ``x``.

    ``x`` is scaled by ``2*pi/L`` (``L`` is the module-level period) and the
    result is expanded into ``[cos, sin, cos(2.), sin(2.)]`` concatenated
    along dimension 2, so that dimension becomes four times larger.

    The input tensor is NOT modified. The scaling is done out of place so
    that passing a view (e.g. ``grid[:, :, 1:2]``) does not silently rescale
    the caller's tensor, and so that leaf tensors requiring grad are accepted.

    Args:
        x: tensor with at least 3 dimensions (concatenation is along dim 2).

    Returns:
        Tensor shaped like ``x`` except that dim 2 is four times as large.
    """
    phase = x * (2 / L * np.pi)
    return torch.concat(
        [
            torch.cos(phase),
            torch.sin(phase),
            torch.cos(2 * phase),
            torch.sin(2 * phase),
        ],
        2,
    )

class FCNN(nn.Module):
    """Fully connected network: Linear -> act -> ... -> Linear.

    The network has one input layer, ``depth - 2`` hidden ``width x width``
    layers (none when ``depth <= 2``) and one output layer; every layer except
    the last is followed by the activation.

    A single activation module instance is reused after every layer. For
    ``'prelu'`` this means the learnable slope is shared by all layers; this
    is kept as-is so existing checkpoints keep loading.

    Args:
        depth: total number of linear layers when ``depth >= 2``.
        width: hidden layer width.
        act: activation name, one of ``'tanh'``, ``'prelu'``, ``'relu'``.
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.

    Raises:
        ValueError: if ``act`` is not a supported activation name.
    """

    def __init__(self, depth, width, act, input_dim, output_dim):
        super().__init__()
        if act == 'tanh':
            self.activation = nn.Tanh()
        elif act == 'prelu':
            self.activation = nn.PReLU()
        elif act == 'relu':
            self.activation = nn.ReLU()
        else:
            # Fail here with a clear message instead of printing and then
            # crashing later with an AttributeError on ``self.activation``.
            raise ValueError(
                f"unsupported activation {act!r}; "
                "expected 'tanh', 'prelu' or 'relu'"
            )

        # Layers are created in input -> hidden -> output order so parameter
        # initialisation under a fixed seed is unchanged.
        layers = [nn.Linear(input_dim, width), self.activation]
        for _ in range(depth - 2):
            layers.append(nn.Linear(width, width))
            layers.append(self.activation)
        layers.append(nn.Linear(width, output_dim))
        self.layer_list = nn.Sequential(*layers)

    def forward(self, input_x):
        """Apply the network to ``input_x`` (last dim must be ``input_dim``)."""
        return self.layer_list(input_x)
    
    
class VIDON(pl.LightningModule):
    """Operator-learning Lightning module with multi-head gated pooling.

    Sensor samples (position, time, value) are encoded by three FCNNs and
    summed. Each of ``head`` heads pools the encodings over the sensor
    dimension with softmax-gated weights; the concatenated head outputs are
    mapped by the branch network to ``basis`` coefficients. The trunk network
    maps each query point ``(t, periodic features of x)`` to ``basis`` values,
    and the prediction is the dot product of coefficients and basis values.

    Args:
        hparams: object exposing the attributes read in ``__init__``
            (``depth_*``, ``width_*``, ``act``, ``dim``, ``basis``, ``head``,
            ``time_slice``, ``step_train``, ``step_test``, ``lr``, ``factor``,
            ``step_size``).

    NOTE(review): the trunk input width is ``4*dim + 1`` while ``forward``
    always feeds 1 time column + 4 periodic columns, so this appears to
    assume ``dim == 1`` -- confirm.
    """

    def __init__(self, hparams):
        super().__init__()
        self.depth_branch = hparams.depth_branch
        self.depth_trunk = hparams.depth_trunk
        self.depth_enc = hparams.depth_enc
        self.depth_aggr = hparams.depth_aggr
        self.width_branch = hparams.width_branch
        self.width_trunk = hparams.width_trunk
        self.width_enc = hparams.width_enc
        self.width_aggr = hparams.width_aggr
        self.act = hparams.act
        self.dim = hparams.dim
        self.basis = hparams.basis
        self.head = hparams.head
        # Number of leading time steps used as input sensors.
        self.time_slice = hparams.time_slice
        # Number of `time_slice`-sized windows predicted in train/val and test.
        self.step_train = hparams.step_train
        self.step_test = hparams.step_test
        self.lr = hparams.lr
        # StepLR decay factor (gamma) and period, see configure_optimizers.
        self.factor = hparams.factor
        self.step_size = hparams.step_size

        self.save_hyperparameters()

        self.loss_fn = nn.MSELoss()

        # Sensor encoders; their outputs are summed in forward().
        self.enc_value = FCNN(self.depth_enc, self.width_enc, self.act, 1, self.width_enc)
        self.enc_pos = FCNN(self.depth_enc, self.width_enc, self.act, self.dim, self.width_enc)
        self.enc_t = FCNN(self.depth_enc, self.width_enc, self.act, 1, self.width_enc)

        # Gate/value networks are created interleaved (gate, value, gate, ...)
        # so that parameter initialisation under a fixed seed is unchanged.
        gate_nets, value_nets = [], []
        for _ in range(self.head):
            gate_nets.append(FCNN(self.depth_aggr, self.width_aggr, self.act, self.width_enc, self.basis))
            value_nets.append(FCNN(self.depth_aggr, self.width_aggr, self.act, self.width_enc, self.basis))
        # Attribute names are part of the state_dict keys; do not rename.
        self.gate_nn = nn.ModuleList(gate_nets)
        self.nn = nn.ModuleList(value_nets)

        self.branch = FCNN(self.depth_branch, self.width_branch, self.act, self.head * self.basis, self.basis)
        self.trunk = FCNN(self.depth_trunk, self.width_trunk, self.act, 4 * self.dim + 1, self.basis)
        # NOTE(review): b0 is registered but never used in forward();
        # kept so existing checkpoints still load.
        self.b0 = torch.nn.Parameter(torch.tensor(0.), requires_grad=True)
        self.num_basis = self.basis

    def forward(self, grid, sensor_pos, sensor_t, u0, step):
        """Evaluate the model on query points ``grid``.

        Args:
            grid: query points, column 0 is passed to the trunk unchanged and
                column 1 is expanded by ``periodic_don``.
            sensor_pos, sensor_t, u0: per-sensor position, time and value;
                the sensor dimension is dim 1.
            step: unused; kept for backward compatibility with callers.

        Returns:
            Tensor with one value per query point (dims 0 and 1 of ``grid``).
        """
        trunk_in = torch.concat([grid[:, :, 0:1], periodic_don(grid[:, :, 1:2])], 2)
        basis = self.trunk(trunk_in)
        encode = self.enc_pos(sensor_pos) + self.enc_t(sensor_t) + self.enc_value(u0)

        # One pooled vector per head: softmax over the sensor dimension gates
        # the value network's output, which is then summed over sensors.
        pooled = [
            torch.sum(torch.softmax(gate_net(encode), dim=1) * value_net(encode), dim=1)
            for gate_net, value_net in zip(self.gate_nn, self.nn)
        ]
        weights = self.branch(torch.cat(pooled, dim=-1))

        # Contract the basis dimension directly; broadcasting over the query
        # dimension avoids materialising a repeated copy of `weights`.
        return torch.einsum('bij,bj->bi', basis, weights)

    def process_batch(self, batch, step):
        """Build model inputs from ``batch``, run the model, compute MSE.

        Args:
            batch: mapping with keys ``'u'``, ``'t'`` and ``'x'``.
                Assumes ``u`` is (batch, space, time), ``t`` is (batch, time)
                and ``x`` is (batch, space, dim) -- TODO confirm.
            step: number of ``time_slice``-sized windows to predict after the
                initial window.

        Returns:
            ``(loss, output, u)`` where ``output`` and ``u`` are flattened to
            two dimensions.
        """
        window = self.time_slice
        horizon = window * (step + 1)

        def prepare(tensor):
            # Every tensor goes to the module's device/dtype, not just t and x.
            return tensor.to(device=self.device, dtype=torch.float)

        num_batch = batch['u'].size()[0]
        pos = prepare(batch['x'])

        # Values in the input window (sensors) and in the predicted window.
        u0 = prepare(batch['u'][:, :, :window])
        u = prepare(batch['u'][:, :, window:horizon])
        u0 = u0.permute(0, 2, 1).reshape(num_batch, -1, 1)
        u = u.permute(0, 2, 1).reshape(num_batch, -1)

        # Sensor coordinates: positions tiled per time step, each time value
        # repeated `nx` times.
        sensor_pos = pos.repeat(1, window, 1)
        sensor_t = prepare(batch['t'][:, :window]).unsqueeze(dim=2)
        sensor_t = sensor_t.repeat(1, 1, nx).reshape(num_batch, -1, 1)

        # Query grid (t, x) over the predicted window, same layout as above.
        t = prepare(batch['t'][:, window:horizon]).reshape(num_batch, -1, 1)
        t = t.repeat(1, 1, nx).reshape(num_batch, -1, 1)
        query_pos = pos.reshape(num_batch, -1, 1).repeat(1, window * step, 1)
        grid = torch.cat((t, query_pos), dim=2)

        output = self(grid, sensor_pos, sensor_t, u0, step).reshape(num_batch, -1)
        loss = self.loss_fn(output, u)
        return loss, output, u

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss and relative L2 error."""
        loss, output, u = self.process_batch(batch, self.step_train)
        train_rel_error = torch.mean(rel_L2_error(output.detach().cpu(), u.detach().cpu()))
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_rel_error', train_rel_error, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and relative L2 error.

        Uses ``step_train`` as the prediction horizon.
        """
        loss, output, u = self.process_batch(batch, self.step_train)
        val_rel_error = torch.mean(rel_L2_error(output, u))
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_rel_error', val_rel_error, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss and relative L2 error.

        Uses ``step_test`` as the prediction horizon.
        """
        loss, output, u = self.process_batch(batch, self.step_test)
        test_rel_error = torch.mean(rel_L2_error(output, u))
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_rel_error', test_rel_error, prog_bar=True)
        return {'test_loss': loss, 'test_rel_error': test_rel_error}

    def configure_optimizers(self):
        """Return Adam plus a StepLR schedule (``step_size``, gamma=``factor``)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.step_size, gamma=self.factor)
        return [optimizer], [scheduler]