import os
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from models.base import NeuralPDEABC
from models.networks import get_network
from utils import *

class NeuralElasticityBase(NeuralPDEABC):
    """Base model holding a neural implicit deformation field for elasticity.

    The deformation field is a network mapping ``dim``-dimensional inputs to
    ``dim``-dimensional outputs. Two frozen copies of it,
    ``deformation_field_prev`` and ``deformation_field_prev_prev``, hold earlier
    states of the field; they never receive gradients. Presumably they are the
    previous two time steps consumed by subclass time-stepping code (``step`` is
    a stub here) -- confirm against subclasses.
    """

    def __init__(self, cfg):
        super(NeuralElasticityBase, self).__init__(cfg)
        self.energy = cfg.energy
        self.dim = cfg.dim
        self.use_mesh = cfg.use_mesh
        self.mesh_path = cfg.mesh_path
        self.vis_resolution = cfg.vis_resolution

        # neural implicit network for deformation field (dim -> dim)
        self.deformation_field = get_network(cfg, self.dim, self.dim).cuda()
        # frozen snapshots of earlier field states; excluded from optimization
        self.deformation_field_prev = get_network(cfg, self.dim, self.dim).cuda()
        self.deformation_field_prev_prev = get_network(cfg, self.dim, self.dim).cuda()
        self._set_require_grads(self.deformation_field_prev, False)
        self._set_require_grads(self.deformation_field_prev_prev, False)
        self._sync_prev_fields()
        self.create_optimizer()

    def _sync_prev_fields(self):
        """Copy the current deformation field's weights into both frozen snapshots."""
        state = self.deformation_field.state_dict()
        with torch.no_grad():
            self.deformation_field_prev.load_state_dict(state)
            self.deformation_field_prev_prev.load_state_dict(state)

    @property
    def _trainable_networks(self):
        """Networks exposed to the base class as trainable.

        Only the current deformation field is listed; the snapshots are not.
        Presumably consumed by ``create_optimizer``/checkpointing in
        ``NeuralPDEABC`` -- defined outside this file.
        """
        return {'deformation': self.deformation_field}

    @NeuralPDEABC._training_loop
    def _initialize(self):
        """Return the loss that drives the deformation field output to zero.

        The loss is the mean squared network output over training samples.
        ``NeuralPDEABC._training_loop`` presumably runs the optimization loop
        over this loss dict (key ``'main'``) -- it is defined outside this file.
        Visualization runs at iteration 0 and whenever
        ``(train_iter + 1) % cfg.vis_frequency == 0``.
        """
        samples = self.sample_in_training(self.sample_resolution)

        out_wt = self.deformation_field(samples)
        loss_wt = torch.mean(out_wt ** 2)
        loss_dict = {'main': loss_wt}

        if self.tb.train_iter == 0 or (self.tb.train_iter + 1) % self.cfg.vis_frequency == 0:
            self._vis_deformation_field(self.vis_resolution, attr='deformation')
        return loss_dict

    def initialize(self):
        """Fit the deformation field to zero, sync the snapshots, and checkpoint it."""
        self.tb = self.create_tb("initialize")
        self.create_optimizer()
        self._initialize()
        # Snapshot only after fitting, so both earlier states equal the
        # initialized field rather than the untrained network weights.
        self._sync_prev_fields()
        self.save_ckpt('initialize')

    def step(self):
        """Advance one step. No-op here; presumably overridden by subclasses."""
        pass

    #################### sampling during training #######################
    def sample_in_training(self, resolution):
        """Draw training points for every entry of ``self.sample_pattern``.

        Args:
            resolution: per-axis resolution. 'random' draws
                ``resolution ** self.dim`` points; in non-mesh mode 'uniform'
                passes it to ``sample_uniform``. In mesh mode 'uniform' uses the
                mesh vertices instead.

        Returns:
            All samples concatenated along dim 0. Every piece has
            ``requires_grad`` set. In mesh mode only the first ``self.dim``
            coordinates are kept.

        Raises:
            NotImplementedError: for a pattern other than 'random' or 'uniform'.
        """
        n_random = resolution ** self.dim
        samples = []
        if self.use_mesh == True:
            for s in self.sample_pattern:
                if s == 'random':
                    random_samples = sample_mesh(self.mesh_V, self.mesh_F, n_random, self.distrib)[:, 0:self.dim]
                    samples.append(random_samples.cuda().requires_grad_(True))
                elif s == 'uniform':
                    # the mesh vertices themselves act as the fixed sample set
                    uniform_samples = self.mesh_V[:, 0:self.dim]
                    samples.append(uniform_samples.cuda().requires_grad_(True))
                else:
                    raise NotImplementedError(f"unsupported sample pattern: {s!r}")
        else:
            for s in self.sample_pattern:
                if s == 'random':
                    random_samples = sample_random(n_random, self.dim, device=self.device).requires_grad_(True)
                    samples.append(random_samples)
                elif s == 'uniform':
                    uniform_samples = sample_uniform(resolution, self.dim, device=self.device).requires_grad_(True)
                    samples.append(uniform_samples)
                else:
                    raise NotImplementedError(f"unsupported sample pattern: {s!r}")

        return torch.cat(samples, dim=0)

    ################# visualization during training #####################
    def _vis_deformation_field(self, resolution, attr="deformation"):
        """Visualize the field. No-op here; presumably overridden by subclasses."""
        pass