import os
import numpy as np
import torch
import sys
from tqdm import tqdm
import re
sys.path.append("./")
os.environ["METIS_DLL"]="./lib/libmetis.so"
from models.GNOT.data_utils import get_model, get_loss_func

from utils.logging_utils import resetLogger
from models.ddno import DDNO 
from tqdm import tqdm
from args import get_inference_args
import numpy as np

import plotly.io as pio
pio.renderers.default = 'iframe'
# The line above makes plotly render figures via the 'iframe' renderer by default (presumably so notebook output displays offline).

from utils.domain import DecomposedSimplePolygonMeshDomain
from trimesh.base import Trimesh
from utils.data_utils import (get_inference_boundary_marker, 
                              transform_gt, 
                              get_inference_dolphinx_dataset, 
                              get_inference_mesh, 
                              get_inference_normalizer)
import logging

logger = logging.getLogger(__name__)

def _select_device(args):
    """Return the torch device to run inference on.

    Uses ``cuda:<args.gpu>`` when ``args.no_cuda`` is false and CUDA is
    available, otherwise the CPU.
    """
    if not args.no_cuda and torch.cuda.is_available():
        return torch.device(f"cuda:{args.gpu}")
    return torch.device("cpu")


def _parse_space_dim(dataset_name):
    """Return the first decimal digit found in ``dataset_name`` as an int.

    Raises:
        ValueError: if ``dataset_name`` contains no digit, rather than the
            opaque ``AttributeError`` that calling ``.group()`` on ``None``
            would produce.
    """
    match = re.search(r"\d", dataset_name)
    if match is None:
        raise ValueError(
            f"cannot infer the spatial dimension from dataset name {dataset_name!r}"
        )
    return int(match.group())


def _has_stagnated(history, window=10):
    """Return True once ``history[-1]`` equals ``history[-window]``.

    The recorded errors are rounded before being appended, so exact equality
    is intended: the rounded error is back at (or still at) the value it had
    ``window - 1`` iterations earlier. Returns False while fewer than
    ``window`` values have been recorded.
    """
    return len(history) >= window and history[-1] == history[-window]


def _solve_sample(model, graph, u_p, inputs_f, metric_func, device, epochs, tau):
    """Run the DDNO iteration on one test sample and return its error history.

    Args:
        model: the DDNO wrapper, already moved to ``device``.
        graph, u_p, inputs_f: the three components of one dataset item, in
            the order the dataset yields them.
        metric_func: callable ``(graph, sol, gt_sol)``. Element ``[2]`` of its
            result is recorded (NOTE(review): confirm which error that is).
        device: device the sample's tensors/graph are moved to.
        epochs: maximum number of iterations; must be >= 1.
        tau: damping factor applied to the global update.

    Returns:
        list[float]: the recorded metric after each iteration, rounded to 4
        decimals. Iteration stops early once ``_has_stagnated`` is true.
    """
    # The ground truth is transformed before anything is moved to ``device``.
    gt_sol = transform_gt(model, graph)

    inputs_f = inputs_f.to(device)
    u_p = u_p.to(device)
    graph = graph.to(device)
    gt_sol = gt_sol.to(device)

    n_parts = model.domain.n_parts
    input_func = []  # forwarded to model(...) unchanged
    history = []
    with torch.no_grad():
        bc = model.map_boundary(inputs_f[0])
        bic = (bc, None)
        sol = model.initialize(inputs_f)

        for _ in range(epochs):
            local_sols = model(sol, bic, u_p, input_func)
            # Lift each local solution to the global vector through rm[k].T and
            # keep the current ``sol`` wherever masks[k] is 0 (presumably the
            # nodes outside subdomain k -- confirm the mask convention).
            extended_sols = [
                (model.rm[k].T @ v + (1 - model.masks[k]) * sol).to(sol.device)
                for k, v in enumerate(local_sols)
            ]
            sol = (1 - tau * n_parts) * sol + tau * sum(extended_sols)
            history.append(round(float(metric_func(graph, sol, gt_sol)[2]), 4))
            if _has_stagnated(history):
                break
    return history


def main():
    """Evaluate a trained DDNO model on the inference dataset and log errors.

    For every test sample the last recorded error is logged; the list of all
    per-sample final errors is logged at the end.

    Raises:
        ValueError: if ``args.epochs`` < 1 (no iteration would run, so there
            would be no error to report), or if the dataset name contains no
            digit to infer the spatial dimension from.
    """
    resetLogger()
    args = get_inference_args()
    if args.epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {args.epochs}")

    device = _select_device(args)

    if args.test_num not in ("all", "none"):
        args.test_num = int(args.test_num)

    test_dataset = get_inference_dolphinx_dataset(args)
    args.dataset_config = test_dataset.config
    args.space_dim = _parse_space_dim(args.dataset)
    y_normalizer = test_dataset.y_normalizer
    args.normalizer = y_normalizer.to(device) if y_normalizer is not None else None

    metric_func = get_loss_func(
        name="rel2", args=args, regularizer=False, normalizer=args.normalizer
    )

    gmesh, mesh_data = get_inference_mesh(args)
    normalizer = get_inference_normalizer(args)(device)
    boundary_marker = get_inference_boundary_marker(args, gmesh)

    mesh = Trimesh(mesh_data["vertices"], mesh_data["faces"])
    domain = DecomposedSimplePolygonMeshDomain(
        mesh,
        dim=2,
        boundary_marker=boundary_marker,
        n_parts=args.n_parts,
        depth=args.depth,
    )

    local_operator = get_model(args)
    # Load tensors on the CPU so a GPU-saved checkpoint also opens on a
    # CPU-only machine. load_state_dict copies into the existing parameters,
    # so this does not change where the weights end up.
    checkpoint = torch.load(args.model_path, map_location="cpu")
    local_operator.load_state_dict(checkpoint["model"])
    # torch.no_grad() does not disable dropout or switch batch-norm to its
    # running statistics; eval() does.
    local_operator.eval()

    model = DDNO(local_operator, domain, 2, normalizer=normalizer)
    model.to(device)

    losses = []
    for graph, u_p, inputs_f in tqdm(test_dataset):
        history = _solve_sample(
            model, graph, u_p, inputs_f, metric_func, device, args.epochs, args.tau
        )
        losses.append(history[-1])
        logger.info(history[-1])
    logger.info(losses)


if __name__ == "__main__":
    main()