import torch
from torch import nn
from pytorch_lightning import LightningModule
import pytorch_lightning as pl
import numpy as np
import torch.optim as optim
from utils import *

# Grid constants shared by the model code below.
nt = 250  # presumably the total number of time levels; not referenced in this section
nx = 50   # spatial points per snapshot (branch input size is nx * time_slice)
L = 16    # spatial period used by periodic_don for its Fourier features

class FCNN(nn.Module):
    """Fully-connected network ``Linear -> act -> ... -> act -> Linear``.

    Args:
        depth: total number of ``Linear`` layers. Values below 2 still build
            the input and output layers (i.e. 2 layers).
        width: hidden layer width.
        act: activation name, one of ``'tanh'``, ``'prelu'`` or ``'relu'``.
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.

    Raises:
        ValueError: if ``act`` is not one of the supported names.
    """

    def __init__(self, depth, width, act, input_dim, output_dim):
        super().__init__()
        activations = {'tanh': nn.Tanh, 'prelu': nn.PReLU, 'relu': nn.ReLU}
        if act not in activations:
            # Previously this only printed a message and then crashed with an
            # opaque AttributeError on `self.activation`; fail fast instead.
            raise ValueError(
                f"activation error!! unsupported act={act!r}; "
                f"expected one of {sorted(activations)}"
            )
        # NOTE: a single activation instance is reused after every hidden
        # layer, so with 'prelu' all layers share one learnable slope.
        self.activation = activations[act]()

        # Layers are created in the same order as before (input, hidden...,
        # output) so parameter initialisation consumes the RNG identically.
        layers = [nn.Linear(input_dim, width), self.activation]
        for _ in range(depth - 2):
            layers += [nn.Linear(width, width), self.activation]
        layers.append(nn.Linear(width, output_dim))
        # Attribute name kept as `layer_list` so state_dict keys are unchanged.
        self.layer_list = nn.Sequential(*layers)

    def forward(self, input_x):
        """Apply the network to ``input_x`` (features on the last dimension)."""
        return self.layer_list(input_x)

def periodic_don(x, period=None):
    """Map coordinates to periodic Fourier features.

    Computes ``[cos(w*x), sin(w*x), cos(2*w*x), sin(2*w*x)]`` with
    ``w = 2*pi / period`` and concatenates them along dimension 2, so each
    input column yields four feature columns.

    Args:
        x: coordinate tensor (3-D, features on dim 2). It is NOT modified.
        period: spatial period; defaults to the module-level constant ``L``.

    Returns:
        Float tensor whose dim 2 is four times that of ``x``.
    """
    if period is None:
        period = L
    # Out-of-place scaling: the former `x *= ...` silently rescaled the
    # caller's tensor (in `forward`, a view into the model's input grid) and
    # failed outright on integer tensors.
    x = x * (2 / period * np.pi)
    return torch.cat(
        [torch.cos(x), torch.sin(x), torch.cos(2 * x), torch.sin(2 * x)], 2
    )

class DeepONet(LightningModule):
    """DeepONet surrogate trained with PyTorch Lightning.

    A branch net encodes the first ``time_slice`` snapshots of the solution,
    a trunk net encodes (time, position) query points, and the prediction at
    each query point is the dot product of the two plus a scalar bias ``b0``.

    ``hparams`` must provide the attributes read in ``__init__``
    (``depth_branch``, ``depth_trunk``, ``width``, ``act``, ``time_slice``,
    ``dim``, ``basis``, ``step_train``, ``step_test``, ``lr``, ``step_size``,
    ``factor``, ``seed``).
    """

    def __init__(self, hparams):
        super().__init__()
        self.depth_branch = hparams.depth_branch
        self.depth_trunk = hparams.depth_trunk
        self.width = hparams.width
        self.act = hparams.act
        self.time_slice = hparams.time_slice
        self.dim = hparams.dim
        self.basis = hparams.basis
        self.step_train = hparams.step_train
        self.step_test = hparams.step_test
        self.lr = hparams.lr
        self.step_size = hparams.step_size
        self.factor = hparams.factor
        self.seed = hparams.seed
        # Seed BEFORE building the networks so weight initialisation is
        # reproducible from hparams.seed (it used to be seeded only after the
        # layers had already been initialised).
        pl.seed_everything(self.seed)
        # Branch: flattened initial window (time_slice snapshots of nx points)
        # -> `basis` coefficients.
        self.branch = FCNN(self.depth_branch, self.width, self.act, nx*self.time_slice, self.basis)
        # Trunk: time (1 input) + 4 periodic features per spatial coordinate
        # -> `basis` functions evaluated at the query point.
        self.trunk = FCNN(self.depth_trunk, self.width, self.act, 4*self.dim + 1, self.basis)
        # Scalar output bias of the DeepONet expansion.
        self.b0 = torch.nn.Parameter(torch.tensor(0.), requires_grad=True)
        self.loss_fn = nn.MSELoss()

        self.save_hyperparameters()

    def forward(self, x_, u_, step):
        """Evaluate the operator at a set of query points.

        Args:
            x_: query grid (batch, n_points, 2); column 0 is time, column 1 is
                position (as built by ``prepare_data_for_step``).
            u_: branch input (batch, time_slice * nx).
            step: number of predicted time windows. The computation no longer
                needs it (branch coefficients are broadcast over however many
                points ``x_`` holds); kept for API compatibility.

        Returns:
            Predictions of shape (batch, n_points).
        """
        # Time is fed as-is; position is mapped to periodic features.
        trunk_in = torch.cat([x_[:, :, 0:1], periodic_don(x_[:, :, 1:2])], 2)
        basis = self.trunk(trunk_in)   # (batch, n_points, basis)
        weights = self.branch(u_)      # (batch, basis)
        # Broadcast the coefficients over all query points instead of
        # materialising a (batch, n_points, basis) repeated copy whose size
        # was tied to the global `nx`.
        output = torch.einsum('bpk,bk->bp', basis, weights)
        # b0 was previously declared but never used, so it got no gradient
        # (and tripped DDP's unused-parameter check). Existing checkpoints hold
        # b0 == 0, so their predictions are unchanged.
        return output + self.b0

    def _shared_step(self, batch, step):
        """Run the model on ``batch`` over ``step`` windows; return (loss, mean relative L2 error)."""
        num_batch = batch['u'].size(0)
        t, pos, u0, u = self.prepare_data_for_step(batch, num_batch, step)
        grid = torch.cat((t, pos), dim=2)

        output = self(grid, u0, step).reshape(num_batch, -1)
        loss = self.loss_fn(output, u)
        # Monitoring metric only: keep it out of the autograd graph.
        rel_error = torch.mean(rel_L2_error(output.detach(), u.detach()))
        return loss, rel_error

    def training_step(self, batch, batch_idx):
        """Optimise on ``step_train`` predicted windows."""
        loss, rel_error = self._shared_step(batch, self.step_train)
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_rel_error', rel_error, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Validate on the same horizon as training (``step_train`` windows)."""
        loss, rel_error = self._shared_step(batch, self.step_train)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_rel_error', rel_error, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Test on the (possibly longer) ``step_test`` horizon."""
        loss, rel_error = self._shared_step(batch, self.step_test)
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_rel_error', rel_error, prog_bar=True)
        return {'test_loss': loss, 'test_rel_error': rel_error}

    def prepare_data_for_step(self, batch, num_batch, step):
        """Build the query grid, branch input and targets for ``step`` windows.

        Presumed batch layout (confirm against the dataset):
        ``batch['t']`` (batch, T), ``batch['x']`` (batch, n_space) and
        ``batch['u']`` (batch, n_space, T).

        The first ``time_slice`` time levels form the branch input. The next
        ``time_slice * step`` levels are the targets. Targets and grid points
        are flattened time-major, so row ``i * n_space + j`` is time level
        ``i`` at spatial point ``j`` in both.

        Returns:
            t: (batch, n_out * n_space, 1) time of each query point.
            pos: (batch, n_out * n_space, 1) position of each query point.
            u0: (batch, time_slice * n_space) flattened initial window.
            u: (batch, n_out * n_space) flattened targets.
        """
        ts = self.time_slice
        n_out = ts * step  # number of predicted time levels
        window = slice(ts, ts * (step + 1))

        t = batch['t'][:, window].to(self.device).to(torch.float)
        pos = batch['x'].to(self.device).to(torch.float)
        u_all = batch['u']
        # (batch, space, time) -> (batch, time, space), then flatten time-major.
        u0 = u_all[:, :, :ts].to(self.device).permute(0, 2, 1).reshape(num_batch, -1).to(torch.float)
        u = u_all[:, :, window].to(self.device).permute(0, 2, 1).reshape(num_batch, -1).to(torch.float)

        pos = pos.view(num_batch, -1, 1)  # (batch, n_space, 1)
        n_space = pos.size(1)             # taken from the data, not the global nx
        # Repeat each time value once per spatial point ...
        t = t.reshape(num_batch, -1, 1).repeat(1, 1, n_space).view(num_batch, -1, 1)
        # ... and tile the spatial grid once per predicted time level.
        pos = pos.repeat(1, n_out, 1)
        return t, pos, u0, u

    def configure_optimizers(self):
        """Adam with StepLR: lr is multiplied by ``factor`` every ``step_size`` scheduler steps.

        With Lightning's default scheduler interval, a scheduler step happens
        once per epoch.
        """
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.step_size, gamma=self.factor)
        return [optimizer], [scheduler]