import os
import numpy as np
import torch
import sys
from tqdm import tqdm
import re
sys.path.append("./")
os.environ["METIS_DLL"]="./lib/libmetis.so"
from models.GNOT.data_utils import get_model, get_loss_func

from utils.logging_utils import resetLogger
from models.ddno import DDNO 
from tqdm import tqdm
from args import get_inference_args
import numpy as np

import plotly.io as pio
pio.renderers.default = 'iframe'
# The 'iframe' renderer selected above lets Plotly figures display in notebooks while offline.

from utils.domain import DecomposedSpaceTimeSimplePolygonMeshDomain
from trimesh.base import Trimesh
from utils.data_utils import (get_inference_boundary_marker, 
                              transform_gt, 
                              get_inference_dolphinx_dataset, 
                              get_inference_mesh, 
                              get_inference_normalizer)
import logging

logger = logging.getLogger(__name__)

if __name__ == "__main__":
    resetLogger()
    args = get_inference_args()

    if not args.no_cuda and torch.cuda.is_available():
        device = torch.device('cuda:{}'.format(str(args.gpu)))
    else:
        device = torch.device("cpu")
        
    kwargs = {'pin_memory': False} if args.gpu else {}

    args.test_num = int(args.test_num) if args.test_num not in ['all', 'none'] else args.test_num

    test_dataset = get_inference_dolphinx_dataset(args)
    args.dataset_config = test_dataset.config

    args.space_dim = int(re.search(r'\d', args.dataset).group())
    args.normalizer =  test_dataset.y_normalizer.to(device) if test_dataset.y_normalizer is not None else None

    loss_func = get_loss_func(name=args.loss_name,args= args, regularizer=True,normalizer=args.normalizer)
    metric_func = get_loss_func(name='rel2', args=args, regularizer=False, normalizer=args.normalizer)

    gmesh, trimesh = get_inference_mesh(args)
    normalizer = get_inference_normalizer(args)(device)
    boundary_marker = get_inference_boundary_marker(args, gmesh)


    mesh = Trimesh(trimesh['vertices'], trimesh["faces"])
    domain = DecomposedSpaceTimeSimplePolygonMeshDomain(mesh, 
                                                    dim=2, 
                                                    boundary_marker=boundary_marker, 
                                                    n_parts=args.n_parts, 
                                                    depth=args.depth,
                                                   time_step=args.time_step, 
                                                   time_span=args.time_span)
    local_operator = get_model(args)
    local_operator.load_state_dict(torch.load(args.model_path)["model"])

    model = DDNO(local_operator, domain, 2, 
             time_dependent=True, time_span=args.time_span, 
             normalizer=normalizer)
    model.to(device)

    losses = []
    iterations = []

    for data in tqdm(test_dataset):
        graph, u_p, inputs_f = data
        input_func = []
        gt_sol = transform_gt(model, graph)
        
        inputs_f = inputs_f.to(device)
        u_p = u_p.to(device)
        graph = graph.to(device)
        gt_sol = gt_sol.to(device)
        
        epochs = args.epochs

        p = model.domain.n_parts
        q = model.domain.num_interval
        tau = args.tau
        
        loss = []
        with torch.no_grad():
            boundary_input = inputs_f[1]
            initial_condition_input = inputs_f[0]
            
            bc = model.map_boundary(boundary_input)
            ic = model.map_input(initial_condition_input)
            sol = model.initialize(inputs_f)
            bic = (bc, ic)


            for i in range(epochs):

                temporal_local_sols = model(sol, bic, u_p, input_func)
                extended_temporal_sols = [sum([((model.rm[i].T @ temporal_local_sols[t][i] @ model.trm[t].T) + (1 -  model.masks[i] @ model.time_masks[t].T) * sol).to(sol.device) \
                                   for i, _ in enumerate(model.domain.subDomain)]) \
                                      for t, _ in enumerate(model.domain.subTimeInterval)]
                
                sol = (1-tau*(p*q))*sol + tau * sum(extended_temporal_sols)
                #print(metric_func(graph, sol, gt_sol))
                loss.append(round(float(metric_func(graph, sol, gt_sol)[0]), 4))
                #print(loss[-1])
                if len(loss) < 10:
                    pass
                else:
                    if loss[-1] == loss[-10]:
                        break

            #iterations.append(len(loss))
            losses.append(loss[-1])
            logger.info(loss[-1])
    #logger.info(iterations)
    logger.info(losses)