"""
Poisson equation trainers for TensorGalerkin
"""

import os
import time
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from copy import deepcopy
from functools import partial
from itertools import product
from scipy.interpolate import griddata
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import StepLR

from ..core.base import TrainerBase
from ..core.config import PoissonConfig
from ..utils import save_ckpt, load_ckpt, apply_zero_boundary, manual_seed
from ..models import plain_GNN
from ..datasets import MeshGen, PoissonGen
from ..equations import PoissonEquation


class VaryfPoissonTrainer(TrainerBase):
    """
    Trainer for variable-f Poisson equation problems.

    Every training sample is a Poisson problem on one shared mesh whose source
    term f is a random truncated sine series; the matching analytical solution
    is stored for testing only. The GNN is trained with a physics-informed loss
    (mean squared PDE residual), not against the analytical solutions.
    """

    def __init__(self, config):
        """Initialize the trainer and, if requested, a TensorBoard writer."""
        super().__init__(config)

        # Set up tensorboard if requested
        if config.use_tensorboard:
            self.writer = SummaryWriter(log_dir="training_log")
            self.tags = ["train_loss", "data_error", "learning_rate"]

    def init_dataset(self):
        """Build ``config.n_samples`` random variable-f Poisson problems.

        For each sample a coefficient matrix ``a`` of shape ``[K, K]`` (entries
        uniform in ``[-1, 1)``) defines

            f(x, y) = pi / K^2     * sum_ij a_ij (i^2 + j^2)^r     sin(pi i x) sin(pi j y)
            u(x, y) = 1 / (pi K^2) * sum_ij a_ij (i^2 + j^2)^(r-1) sin(pi i x) sin(pi j y)

        with i, j = 1..K, so that ``-laplace(u) = f`` and ``u`` vanishes where
        x or y is 0 or 1.

        Sets ``self.datasets``, ``self.fs``, ``self.sols``, ``self.ndata``
        (``[n_samples, n_nodes, 3]`` = x, y, f when ``use_coord_feat`` else
        ``[n_samples, n_nodes, 1]`` = f), ``self.edge_index`` and ``self.graph``.

        Raises:
            ValueError: if ``config.n_samples`` is not positive.
        """
        class _SineSeries:
            """Callable ``coef / K^2 * sum_ij a_ij (i^2+j^2)^p sin(pi i x) sin(pi j y)``.

            Subclasses choose ``coef`` (class attribute) and the exponent ``p``
            (``_exponent``).
            """
            coef = 1.0

            def __init__(self, a, r=2.0, device='cpu'):
                self.r = r
                K = a.shape[0]
                # 'ij' is torch's historic default; passing it explicitly
                # silences the meshgrid deprecation warning.
                i, j = torch.meshgrid(torch.arange(1, K + 1), torch.arange(1, K + 1), indexing='ij')
                # Leading singleton axis broadcasts against the point axis.
                self.a = a[None, :, :].to(device)
                self.i = i[None, :, :].to(device)
                self.j = j[None, :, :].to(device)

            def _exponent(self):
                raise NotImplementedError

            def __call__(self, points):
                # Tensor.to is a no-op when the tensor is already on that device.
                dev = points.device
                a, i, j = self.a.to(dev), self.i.to(dev), self.j.to(dev)
                # [n_points, 1, 1] so x, y broadcast against the [1, K, K] modes.
                x, y = points[:, 0][:, None, None], points[:, 1][:, None, None]
                K = a.shape[1]
                modes = a * (i * i + j * j) ** self._exponent() * torch.sin(np.pi * i * x) * torch.sin(np.pi * j * y)
                return self.coef / K / K * modes.sum((-2, -1))

        class F(_SineSeries):
            """Source function f(x,y) for Poisson equation"""
            coef = np.pi

            def _exponent(self):
                return self.r

        class Sol(_SineSeries):
            """Analytical solution for Poisson equation"""
            coef = 1 / np.pi

            def _exponent(self):
                return self.r - 1

        K = self.config.K
        n_samples = self.config.n_samples
        if n_samples < 1:
            raise ValueError(f"n_samples must be positive, got {n_samples}")

        # Create mesh once (all datasets use the same mesh)
        mesh = MeshGen.init_mesh(self.config)

        fs, sols, datasets = [], [], []
        for _ in range(n_samples):
            a = 2 * (torch.rand(K, K) - 0.5)
            fs.append(F(a=a, device=self.device))
            sols.append(Sol(a=a, device=self.device))
            # Create PoissonEquation with the shared mesh and source function
            datasets.append(PoissonEquation(mesh=mesh, f=fs[-1]))

        self.datasets = datasets
        self.fs = fs
        self.sols = sols

        # Node features. Every dataset was built on the same mesh object, so the
        # node coordinates are converted to a tensor only once.
        n_nodes = self.datasets[0].mesh.points.shape[0]
        use_coords = self.config.use_coord_feat
        ndata = torch.zeros(len(self.datasets), n_nodes, 3 if use_coords else 1)
        mesh_points = torch.from_numpy(self.datasets[0].mesh.points[:, :2]).float()
        if use_coords:
            ndata[:, :, :2] = mesh_points
        f_col = 2 if use_coords else 0
        for idx, dataset in enumerate(self.datasets):
            ndata[idx, :, f_col] = dataset.f(mesh_points)

        # Create graph from the first mesh (shared by all samples)
        graph = MeshGen.mesh_to_pyg_graph(self.datasets[0].mesh)

        self.ndata = ndata.to(self.device)
        self.edge_index = graph.edge_index.to(self.device)
        self.graph = graph

    def init_model(self):
        """Build the GNN: node features -> one scalar (u) per node."""
        input_dim = 3 if self.config.use_coord_feat else 1
        output_dim = 1

        return plain_GNN(
            num_features=input_dim,
            num_hidden=self.config.n_hidden,
            num_classes=output_dim,
            num_layers=self.config.n_layers,
            gnn_type=self.config.gnn,
            dropout_in=self.config.dropout_in,
            dropout=self.config.dropout
        )

    def init_optimizer(self):
        """Initialize ``self.optimizer`` and ``self.scheduler``.

        Raises:
            NotImplementedError: for an unknown optimizer name, or when a
                scheduler is requested together with LBFGS.
        """
        optimizer_name = self.config.optimizer
        if optimizer_name == "adam":
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.lr)
        elif optimizer_name == "lbfgs":
            self.optimizer = torch.optim.LBFGS(self.model.parameters(),
                                               lr=float(0.5),
                                               max_iter=self.config.max_iter,
                                               max_eval=50000,
                                               history_size=150,
                                               line_search_fn="strong_wolfe",
                                               tolerance_change=1.0 * np.finfo(float).eps)
        else:
            raise NotImplementedError(f"Optimizer {self.config.optimizer} not implemented")

        if not self.config.use_scheduler:
            self.scheduler = None
            return
        # Reject the unsupported combination before building a scheduler.
        if optimizer_name == "lbfgs":
            raise NotImplementedError("Scheduler not implemented for LBFGS optimizer")
        self.scheduler = StepLR(self.optimizer,
                                step_size=self.config.scheduler_step_size,
                                gamma=self.config.scheduler_gamma)

    def step(self, U_t1):
        """Not applicable for Poisson equation (steady-state)"""
        raise NotImplementedError("Step method not applicable for Poisson equation")

    def multisteps(self, U_t0, steps):
        """Not applicable for Poisson equation (steady-state)"""
        raise NotImplementedError("Multisteps method not applicable for Poisson equation")

    def compute_loss(self, start_idx=None, end_idx=None):
        """Physics-informed loss on samples ``[start_idx:end_idx]``.

        Either bound may be None (standard slice semantics); with both None
        the whole dataset is used.

        Returns:
            Scalar tensor: mean over samples of the mean squared nodal residual.
        """
        node_feats = self.ndata[slice(start_idx, end_idx)]
        U = self.model(x=node_feats, edge_index=self.edge_index)

        f_col = 2 if self.config.use_coord_feat else 0
        f_values_batch = node_feats[:, :, f_col]  # [batch_size, n_nodes]

        if U.dim() == 3 and U.shape[-1] == 1:
            U = U.squeeze(-1)  # [batch_size, n_nodes]

        # All samples share one mesh and only differ in f, which is passed
        # explicitly, so any dataset's residual operator serves the whole batch.
        # NOTE(review): assumes compute_residual depends only on the mesh plus
        # the f argument given here - confirm in PoissonEquation.
        residual = self.datasets[0].compute_residual(U, f_values_batch)  # [batch_size, n_nodes]

        sample_losses = torch.mean(residual ** 2, dim=1)  # [batch_size]
        return sample_losses.mean()

    def fit(self, verbose=False):
        """Train the model, restore the best (Adam) state, save, plot, and test.

        Args:
            verbose: currently unused; kept for interface compatibility.
        """
        batch_size = self.config.batch_size
        optimizer_name = self.config.optimizer
        if optimizer_name not in ("adam", "lbfgs"):
            raise NotImplementedError(f"Optimizer {optimizer_name} not implemented")

        # os.makedirs('') raises, so only create a directory if the path has one.
        ckpt_dir = os.path.dirname(self.config.ckpt_path)
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)

        best_loss, best_iter, best_state = np.inf, -1, None
        losses = []

        if optimizer_name == "lbfgs":
            pbar = tqdm(total=self.config.max_iter, desc="Iteration:", colour="blue")

            def closure():
                # Full-batch loss; called by LBFGS once per function evaluation.
                self.optimizer.zero_grad()
                loss = self.compute_loss()
                loss.backward()
                loss_value = loss.item()
                losses.append(loss_value)
                pbar.update(1)
                pbar.set_postfix({"loss": loss_value})
                pbar.set_description(f"iteration:{pbar.n}")
                return loss

            try:
                self.optimizer.step(closure)
            finally:
                pbar.close()
        else:
            n_iter = (len(self.datasets) + batch_size - 1) // batch_size
            pbar = tqdm(total=self.config.epoch, desc="Epoch:", colour="blue")
            try:
                for ep in range(self.config.epoch):
                    for it in range(n_iter):
                        start_idx = it * batch_size
                        end_idx = (it + 1) * batch_size

                        self.optimizer.zero_grad()
                        loss = self.compute_loss(start_idx, end_idx)
                        loss.backward()
                        self.optimizer.step()

                        # The scheduler is stepped per mini-batch, not per epoch.
                        if self.scheduler is not None:
                            self.scheduler.step()

                    # Evaluate on entire dataset
                    avg_loss = self.evaluate_on_entire_dataset()
                    pbar.update(1)
                    pbar.set_postfix({"loss": avg_loss})
                    pbar.set_description(f"Epoch:{ep}")
                    losses.append(avg_loss)

                    if avg_loss < best_loss:
                        # Index into ``losses`` so the plotted marker lines up
                        # with the corresponding point of the loss curve.
                        best_iter = len(losses) - 1
                        best_state = (deepcopy(self.model.state_dict()),
                                      deepcopy(self.optimizer.state_dict()))
                        if self.scheduler is not None:
                            best_state += (deepcopy(self.scheduler.state_dict()),)
                        best_loss = avg_loss
            finally:
                pbar.close()

        # Load best state (only tracked for Adam) and save
        if best_state is not None:
            self.model.load_state_dict(best_state[0])
            self.optimizer.load_state_dict(best_state[1])
            if self.scheduler is not None:
                self.scheduler.load_state_dict(best_state[2])

        self.save_ckpt()

        # Plot loss curve
        fig, ax = plt.subplots(figsize=(12, 8))
        ax.plot(np.arange(len(losses)), losses)
        if best_iter >= 0:
            ax.scatter([best_iter], [best_loss], c='r', marker='o', label="best loss")
            ax.legend()
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title('Loss vs Epoch')
        ax.set_xlim(left=0)
        ax.set_yscale('log')
        fig.savefig('loss.png')
        plt.close(fig)

        # Run test
        self.test()

    def evaluate_on_entire_dataset(self, batch_size=16):
        """Mean physics-informed loss over all samples, without gradients.

        The model is put in eval mode for the evaluation and returned to the
        mode it was in before.

        Returns:
            float: per-sample average of ``compute_loss`` over the dataset.
        """
        n_samples = len(self.ndata)
        was_training = self.model.training
        self.model.eval()

        total_loss = 0.0
        try:
            with torch.no_grad():
                for start_idx in range(0, n_samples, batch_size):
                    end_idx = min(start_idx + batch_size, n_samples)
                    batch_loss = self.compute_loss(start_idx, end_idx)
                    # compute_loss is a per-batch mean; weight by batch size so a
                    # short final batch does not count as much as a full one.
                    total_loss += batch_loss.item() * (end_idx - start_idx)
        finally:
            self.model.train(was_training)

        return total_loss / n_samples

    def test(self, test_idx=100, batch_size=16):
        """Report errors against the analytical solutions and save plots.

        Prints the MSE between predictions and exact solutions over all
        samples, plus residual norms for sample ``test_idx``, and writes
        ``prediction.png`` and ``error.png`` for that sample.

        Args:
            test_idx: sample to visualize; clamped to the last sample when the
                dataset is smaller.
            batch_size: inference batch size.
        """
        n_samples = len(self.datasets)
        # Clamp so the end-of-training test does not crash on small datasets.
        test_idx = min(test_idx, n_samples - 1)
        mse_loss = nn.MSELoss()

        # All datasets share one mesh, so the nodes are converted only once.
        n_nodes = self.datasets[0].mesh.points.shape[0]
        points = self.datasets[0].mesh.points[:, :2]  # Only x,y coordinates
        mesh_points = torch.from_numpy(points).float()

        # Get ground truth solutions
        solutions = torch.zeros(n_samples, n_nodes, device=self.device)
        for idx, sol in enumerate(self.sols):
            solutions[idx] = sol(mesh_points)

        # Get predictions in eval mode so dropout does not perturb them.
        predictions = torch.zeros(n_samples, n_nodes, device=self.device)
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for start_idx in tqdm(range(0, n_samples, batch_size)):
                    end_idx = start_idx + batch_size
                    U = self.model(self.ndata[start_idx:end_idx], self.edge_index)
                    predictions[start_idx:end_idx] = U.reshape(-1, n_nodes)
        finally:
            self.model.train(was_training)

        # Apply boundary conditions
        boundary_mask = self.graph.boundary_mask
        predictions[:, boundary_mask] = 0.0
        solutions[:, boundary_mask] = 0.0

        l2_error = mse_loss(predictions, solutions)
        print('L2 error:', l2_error.item())

        # Compute relative error
        pred_i, sol_i = predictions[test_idx], solutions[test_idx]
        rel_error = torch.abs(pred_i - sol_i) / torch.abs(sol_i)
        rel_error[boundary_mask] = 0.0
        # Nodes where the exact solution is exactly 0 give nan/inf; zero them so
        # the interpolated error map and its colour scale stay finite.
        rel_error = torch.nan_to_num(rel_error, nan=0.0, posinf=0.0, neginf=0.0)

        # Compute residuals.
        # NOTE(review): relies on compute_residual accepting a single [n_nodes]
        # field without an explicit f - confirm in PoissonEquation.
        R_solution = self.datasets[test_idx].compute_residual(sol_i)
        print('R_solution:', mse_loss(R_solution, torch.zeros_like(R_solution)).item())

        R_prediction = self.datasets[test_idx].compute_residual(pred_i)
        print('R_prediction:', mse_loss(R_prediction, torch.zeros_like(R_prediction)).item())

        # Interpolate nodal values onto a regular 100x100 grid over [0, 1]^2.
        grid_x, grid_y = np.meshgrid(np.linspace(0, 1, 100), np.linspace(0, 1, 100))

        def to_grid(values):
            return griddata((points[:, 0], points[:, 1]), values, (grid_x, grid_y), method='linear')

        grid_sol = to_grid(sol_i.cpu().numpy())
        grid_pre = to_grid(pred_i.cpu().numpy())
        grid_error = to_grid(rel_error.cpu().numpy())

        # Linear griddata returns NaN outside the nodes' convex hull; ignore it.
        vmin = min(np.nanmin(grid_sol), np.nanmin(grid_pre))
        vmax = max(np.nanmax(grid_sol), np.nanmax(grid_pre))

        # Plot solutions
        fig, ax = plt.subplots(ncols=2, figsize=(12, 6))
        ax[0].imshow(grid_sol, origin='lower', vmin=vmin, vmax=vmax, cmap='jet')
        ax[0].set_title('Exact solution')
        im_pred = ax[1].imshow(grid_pre, origin='lower', vmin=vmin, vmax=vmax, cmap='jet')
        ax[1].set_title('Predicted solution')
        fig.colorbar(im_pred, cax=fig.add_axes([0.92, 0.1, 0.02, 0.8]))
        fig.savefig("prediction.png")
        plt.close(fig)

        # Plot error (colorbar drawn into the dedicated side axes)
        fig, ax = plt.subplots(figsize=(12, 8))
        im_err = ax.imshow(grid_error, origin='lower', cmap='jet', vmin=0, vmax=np.nanmax(grid_error))
        ax.set_title('Relative error')
        fig.colorbar(im_err, cax=fig.add_axes([0.92, 0.1, 0.02, 0.8]))
        fig.savefig("error.png")
        plt.close(fig)


# Alias kept for compatibility with the original code, which refers to this
# trainer as ``PoissonVaryfTrainer``.
PoissonVaryfTrainer = VaryfPoissonTrainer