import torch
import pytorch_lightning as pl
import numpy as np
import torch.optim as optim

from utils import *
from torch import nn
from torch_geometric.nn import MessagePassing, aggr, InstanceNorm, radius_graph, knn_graph

# Module-level grid constants for the 1D setup.
# nx: spatial points per dimension (shared_step expects nx**dim nodes per graph).
# L:  domain length; periodic() uses 2*pi/L as the base angular frequency.
# nt: not used by the code shown here; presumably the number of time steps
#     (TODO confirm).
nt = 250
nx = 50
L = 16
# Alternative 2D setup: nt, nx, L = 100, 32, 5

def periodic(x, length=None):
    """Map coordinates to periodic Fourier features with period ``length``.

    Each coordinate is turned into the angle ``a = 2*pi*x/length`` and expanded
    into ``[cos(a), sin(a), cos(2a), sin(2a)]``. The four groups are
    concatenated along dim 1, so an ``[N, d]`` input yields ``[N, 4*d]``.

    Args:
        x: tensor of coordinates with features along dim 1. It is not modified.
        length: period of the domain; defaults to the module-level ``L``.

    Returns:
        Tensor of Fourier features concatenated along dim 1.
    """
    if length is None:
        length = L
    # Out-of-place so the caller's tensor (e.g. node positions) is not
    # rescaled as a side effect; also works for integer coordinate tensors.
    angle = x * (2 / length * np.pi)
    return torch.concat(
        [torch.cos(angle), torch.sin(angle), torch.cos(2 * angle), torch.sin(2 * angle)], 1
    )


class Model(MessagePassing):
    """Single message-passing layer with mean aggregation and instance norm.

    For every edge a message ``phi([u_i, u_j - u_i, rel_pos])`` is computed.
    Messages are mean-aggregated using ``flow='target_to_source'``. Each node is
    then updated with ``gamma([u, aggregated])``, and the result is
    instance-normalised per graph.

    Args:
        gamma: node-update network applied to ``[u, aggregated message]``.
        phi: message network applied to ``[u_i, u_j - u_i, rel_pos]``.
    """

    def __init__(self, gamma, phi):
        super().__init__(aggr='mean', flow='target_to_source')
        self.gamma, self.phi = gamma, phi
        # NOTE(review): the channel count is fixed at 128 rather than derived
        # from gamma's output size; confirm it matches the node feature width.
        self.norm = InstanceNorm(128)

    def forward(self, u, edge_index, rel_pos, batch):
        """Run one round of message passing, then normalise per graph in ``batch``.

        ``rel_pos`` holds one relative-position row per edge. The caller in this
        file passes ``pos[edge_index[1]] - pos[edge_index[0]]``.
        """
        updated = self.propagate(edge_index, u=u, rel_pos=rel_pos)
        return self.norm(updated, batch)

    # PyG routes arguments to message/update by parameter name (``u_i``,
    # ``u_j``, ``rel_pos``, ``u``), so these names must not change.
    def message(self, u_i, u_j, rel_pos):
        """Edge message: ``phi`` of the central node state, the state difference and ``rel_pos``."""
        delta = u_j - u_i
        features = torch.cat((u_i, delta, rel_pos), dim=1)
        return self.phi(features)

    def update(self, aggr, u):
        """Node update: ``gamma`` of the node's own state and its aggregated messages."""
        return self.gamma(torch.cat((u, aggr), dim=1))


class FCNN(nn.Module):
    """Fully connected network: ``depth`` Linear layers with an activation between them.

    Args:
        depth: number of Linear layers; must be >= 1. ``depth == 1`` is a single
            affine map with no activation.
        width: hidden feature size used when ``depth > 1``.
        act: one of ``'tanh'``, ``'prelu'``, ``'relu'``, ``'leaky'``, ``'swish'``.
            Any other value falls back to ``nn.LayerNorm`` over the hidden
            (``width``-sized) features.
        input_dim: size of the last axis of the input.
        output_dim: size of the last axis of the output.

    Raises:
        ValueError: if ``depth < 1``.

    Note:
        A single activation module instance is reused after every hidden layer.
        Parametrised choices (``'prelu'``, the LayerNorm fallback) therefore
        share their parameters across layers.
    """

    def __init__(self, depth, width, act, input_dim, output_dim):
        super().__init__()
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")
        self.activation = self._make_activation(act, width)

        if depth == 1:
            # Built exactly once: a second, discarded Linear would consume
            # seeded RNG and shift the initialisation of later layers.
            self.layer_list = nn.Sequential(nn.Linear(input_dim, output_dim))
        else:
            layers = [nn.Linear(input_dim, width), self.activation]
            for _ in range(depth - 2):
                layers += [nn.Linear(width, width), self.activation]
            layers.append(nn.Linear(width, output_dim))
            self.layer_list = nn.Sequential(*layers)

    @staticmethod
    def _make_activation(act, width):
        """Return the activation module named by ``act`` (LayerNorm over ``width`` otherwise)."""
        if act == 'tanh':
            return nn.Tanh()
        if act == 'prelu':
            return nn.PReLU()
        if act == 'relu':
            return nn.ReLU()
        if act == 'leaky':
            return nn.LeakyReLU()
        if act == 'swish':
            return Swish()
        # The activation only ever sees hidden features of size `width`, so the
        # fallback normalisation must be sized to `width` (not output_dim).
        return nn.LayerNorm(width)

    def forward(self, input_x):
        """Apply the layer stack to ``input_x`` (features on the last axis)."""
        return self.layer_list(input_x)


# PyTorch Lightning module assembling the encoder, message-passing branch, attentional aggregation and trunk networks defined above.
class GraphDeepONet(pl.LightningModule):
    """DeepONet on graphs with an autoregressive latent roll-out.

    Pipeline (see ``forward``):
      * ``enc`` encodes each node's input window into latent coefficients.
      * Each roll-out step passes the coefficients through the stack
        ``branch_list`` of message-passing layers and adds the result back
        (residual update).
      * ``aggr`` (attentional aggregation over ``[coeff, position]``) pools the
        node coefficients into ``basis`` numbers per graph. These are
        multiplied by ``time(t)`` for each time in the output window.
      * ``Trunk`` evaluates basis functions on periodic features of the grid.
        The output is the inner product of pooled coefficients and basis.

    Args:
        hparams: namespace-like object providing every attribute read in
            ``__init__`` (``lr``, ``weight_decay``, ``step_size``, ``factor``,
            ``time_slice``, ``neighbors``, the ``depth_*``/``width*`` sizes,
            ``basis``, ``act``, ``seed``, ``dim``, ``step_train``, ``step_test``).
    """

    def __init__(self, hparams):
        super().__init__()
        self.lr = hparams.lr
        self.weight_decay = hparams.weight_decay
        self.step_size = hparams.step_size
        self.factor = hparams.factor
        self.time_slice = hparams.time_slice
        self.neighbors = hparams.neighbors
        self.depth_trunk = hparams.depth_trunk
        self.depth_enc = hparams.depth_enc
        self.depth_branch = hparams.depth_branch
        self.depth_message = hparams.depth_message
        self.depth_aggr = hparams.depth_aggr
        self.width = hparams.width
        self.width_enc = hparams.width_enc
        self.basis = hparams.basis
        self.act = hparams.act
        self.seed = hparams.seed
        self.dim = hparams.dim
        self.step_train = hparams.step_train
        self.step_test = hparams.step_test

        self.save_hyperparameters()
        # Seed before building sub-networks so their initial weights are
        # reproducible; the construction order below therefore matters.
        pl.seed_everything(self.seed)

        # Encoder: input window (time_slice values per node) -> width_enc coefficients.
        self.enc = FCNN(self.depth_enc, self.width, self.act, self.time_slice, self.width_enc)
        # NOTE: enc_ is built (it consumes RNG and is part of the state_dict)
        # but is not used by forward().
        self.enc_ = FCNN(self.depth_enc, self.width, self.act, self.width_enc + self.dim + 1, self.width_enc)

        branch_list = []
        for _ in range(self.depth_branch):
            gamma = FCNN(self.depth_message, self.width, self.act, self.width + self.width_enc, self.width_enc)
            phi = FCNN(self.depth_message, self.width, self.act, 2 * self.width_enc + self.dim, self.width)
            branch_list.append(Model(gamma, phi))
        self.branch_list = nn.ModuleList(branch_list)
        # Trunk: periodic features (4 per coordinate) -> `basis` basis-function values.
        self.Trunk = FCNN(self.depth_trunk, self.width, self.act, 4 * self.dim, self.basis)

        self.use_bias = False
        if self.use_bias:
            self.b = torch.nn.Parameter(torch.zeros(1))

        # Attentional pooling of [coeff, position] per graph into `basis` coefficients.
        self.gate_nn = FCNN(self.depth_aggr, self.width, self.act, self.width_enc + self.dim, self.basis)
        self.nn = FCNN(self.depth_aggr, self.width, self.act, self.width_enc + self.dim, self.basis)
        self.aggr = aggr.AttentionalAggregation(self.gate_nn, self.nn)
        # Time network: scalar time -> `basis` multiplicative weights.
        self.time = FCNN(self.depth_aggr, self.width, self.act, 1, self.basis)

        self.loss_fn = nn.MSELoss()

    def _time_modulate(self, coeff_mean, repeated_t, start):
        """Scale pooled coefficients by ``time(t)`` for one output window.

        Args:
            coeff_mean: pooled coefficients, one row per graph (``[B, basis]``).
            repeated_t: ``t`` tiled per graph (``[B, 1, T]`` for 1-D ``t``).
            start: index of the window's first time on the last axis of ``repeated_t``.

        Returns:
            Tensor of shape ``[B, time_slice, basis]``.
        """
        return torch.stack(
            [coeff_mean * self.time(repeated_t[:, :, start + h]) for h in range(self.time_slice)],
            dim=1,
        )

    def forward(self, grid, sensor, t, edge_index, rel_pos, batch_ind, mode='train'):
        """Predict the solution on every node for every output time.

        Args:
            grid: node coordinates, ``[B*Nx, dim]``; fed to ``Trunk`` via ``periodic``.
            sensor: node inputs, ``[B*Nx, time_slice + dim]``: the input window
                followed by the node position in the last ``dim`` columns.
            t: 1-D tensor of times. In training, output window 0 reads
                ``t[0:time_slice]``. Roll-out step ``i`` always reads
                ``t[(i+1)*time_slice:(i+2)*time_slice]``.
            edge_index: graph connectivity passed to every message-passing layer.
            rel_pos: per-edge relative positions passed to every message-passing layer.
            batch_ind: graph index of every node (``0..B-1``).
            mode: selects how many output windows are produced:
                ``'train'`` gives ``step_train + 1``, the first predicted from the
                encoded input without message passing.
                ``'valid'`` gives ``step_train`` rolled-out windows.
                Any other value gives ``step_test`` rolled-out windows.

        Returns:
            Tensor of shape ``[B, Nx, n_windows * time_slice]``.
        """
        B = int(batch_ind.max()) + 1
        ts = self.time_slice
        sensor_pos = sensor[:, -self.dim:]
        coeff = self.enc(sensor[:, :-self.dim])
        basis = self.Trunk(periodic(grid))
        # t tiled per graph: repeated_t[:, :, k] is a [B, 1] input for self.time.
        repeated_t = t.unsqueeze(0).unsqueeze(0).repeat(B, 1, 1)

        is_train = mode == 'train'
        if is_train:
            step = self.step_train + 1
        elif mode == 'valid':
            step = self.step_train
        else:
            step = self.step_test
        coeff_list = torch.zeros((B, step * ts, self.basis), device=sensor.device)

        # In training, the first output window comes from the encoded input
        # itself; every other window is produced by one roll-out step.
        first_window = 0
        if is_train:
            coeff_mean = self.aggr(torch.concat([coeff, sensor_pos], dim=1), batch_ind)
            coeff_list[:, 0:ts, :] = self._time_modulate(coeff_mean, repeated_t, 0)
            first_window = 1

        for i in range(step - first_window):
            # Residual latent update through the stack of message-passing layers.
            coeff_dif = coeff.clone()
            for branch in self.branch_list:
                coeff_dif = branch(coeff_dif, edge_index, rel_pos, batch_ind)
            coeff = coeff + coeff_dif

            coeff_mean = self.aggr(torch.concat([coeff, sensor_pos], dim=1), batch_ind)
            window = i + first_window
            coeff_list[:, window * ts:(window + 1) * ts, :] = self._time_modulate(
                coeff_mean, repeated_t, (i + 1) * ts
            )

        # u[b, n, s] = sum_k coeff_list[b, s, k] * basis[b, n, k]. Contracting
        # directly avoids materialising two [B, Nx, step*ts, basis] repeated tensors.
        u = torch.einsum("bsk,bnk->bns", coeff_list, basis.reshape(B, -1, self.basis))

        if self.use_bias:
            u += self.b.to(u.device)

        return u

    def shared_step(self, batch, step, mode='train'):
        """Compute ``(MSE loss, mean relative L2 error)`` for one batch.

        Args:
            batch: batched graph data exposing ``u0``, ``u``, ``t``, ``edge_index``,
                ``pos``, ``batch`` and ``sim_ind``.
            step: number of output windows ``forward`` produces in ``mode``.
            mode: forwarded to ``forward``.
        """
        M = nx**self.dim  # nodes per graph (module-level points per dimension)
        u0, u, t, edge_index, pos, batch_ind = batch.u0, batch.u, batch.t, batch.edge_index, batch.pos, batch.batch
        pos = pos.type(torch.float)

        with torch.no_grad():
            rel_pos = pos[edge_index[1]] - pos[edge_index[0]]
        rel_pos = rel_pos.type(torch.float)

        n_graphs = len(batch.sim_ind)
        n_values = M * step * self.time_slice
        output = self.forward(pos, u0, t, edge_index, rel_pos, batch_ind, mode).reshape(n_graphs, n_values)
        target = u.reshape(n_graphs, n_values)
        loss = self.loss_fn(target, output)
        # rel_L2_error comes from utils; presumably one value per sample, hence the mean.
        rel_error = torch.mean(rel_L2_error(target, output))

        return loss, rel_error

    def training_step(self, batch, batch_idx):
        """Train on ``step_train + 1`` windows (the first reconstructed from the input)."""
        step = self.step_train + 1
        loss, rel_error = self.shared_step(batch, step, mode='train')
        # NOTE(review): for a torch_geometric Batch, size(0) is likely the node
        # count rather than the number of graphs (batch.num_graphs); confirm.
        self.log('train_loss', loss, prog_bar=True, batch_size=batch.size(0))
        self.log('train_rel_error', rel_error, prog_bar=True, batch_size=batch.size(0))

        return loss

    def validation_step(self, batch, batch_idx):
        """Validate on ``step_train`` rolled-out windows."""
        step = self.step_train
        loss, rel_error = self.shared_step(batch, step, mode='valid')
        self.log('val_loss', loss, prog_bar=True, batch_size=batch.size(0))
        self.log('val_rel_error', rel_error, prog_bar=True, batch_size=batch.size(0))
        return loss

    def test_step(self, batch, batch_idx):
        """Test on ``step_test`` rolled-out windows."""
        step = self.step_test
        loss, rel_error = self.shared_step(batch, step, mode='test')
        self.log('test_loss', loss, prog_bar=True, batch_size=batch.size(0))
        self.log('test_rel_error', rel_error, prog_bar=True, batch_size=batch.size(0))
        return {'test_loss': loss, 'test_rel_error': rel_error}

    def configure_optimizers(self):
        """Adam (with the configured ``weight_decay``) plus a StepLR schedule."""
        # weight_decay was previously read from hparams but never applied.
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.step_size, gamma=self.factor)
        return [optimizer], [scheduler]