import logging
import time
import numpy as np
import torch
from base_config import BaseConfig


# Output helper shared by the training code.
def log(obj):
    """Print *obj* to stdout, then record it at INFO level on the root logger."""
    for sink in (print, logging.info):
        sink(obj)


class PINNConfig(BaseConfig):
    """Network-based solver for an eigenpair of a (possibly deflated) -Laplacian.

    Each training step evaluates ``L u`` for the network output ``u``. It then
    pulls the normalised ``L u`` towards the previously stored iterate
    ``self.u`` and estimates the eigenvalue with the Rayleigh quotient
    <Lu, u> / <u, u>.

    Dirichlet boundary values are enforced exactly by multiplying the network
    output with a factor that vanishes at ``lb`` and ``ub`` (see ``net_model``).

    ``model_id`` selects the operator:
        1: -Δu - shift1 * u
        2: additionally subtracts max_val  * v (v^T u) with v = max_vec
        3: additionally subtracts max_val1 * v (v^T u) with v = max_vec1
        4: additionally subtracts max_val2 * v (v^T u) with v = max_vec2
    """

    def __init__(self, param_dict, train_dict, model, model_id):
        """Build the solver state.

        Args:
            param_dict: dict with keys 'lb', 'ub', 'device' ('cpu' or 'cuda'),
                'path' and 'root_path' (checkpoint location).
            train_dict: dict with keys 'lambda_' (reference eigenvalue, used
                only to report errors), 'x' (2-D array of points, one column
                per dimension), 'd' (dimension), 'N_R' (length of the iterate
                vector; presumably the number of points) and 'M'.
            model: network whose ``forward`` is applied to the shifted
                coordinates.
            model_id: 1-4, selects how many eigenpairs are deflated (see the
                class docstring).
        """
        super().__init__()
        self.init(loss_name='sum')
        self.model = model
        self.model_id = model_id
        # Domain bounds, device and checkpoint location.
        lb, ub, self.device, self.path, self.root_path = self.unzip_param_dict(
            param_dict=param_dict)
        self.lb = self.data_loader(lb, requires_grad=False)
        self.ub = self.data_loader(ub, requires_grad=False)

        # Training parameters.
        self.lambda_true, x, self.d, self.N_R, self.M = self.unzip_train_dict(
            train_dict=train_dict)

        # One tensor per coordinate so second derivatives can be taken with
        # respect to each dimension separately.
        self.x = [self.data_loader(x[:, i:i + 1]) for i in range(self.d)]

        # Random unit-norm initial iterate.
        u = np.random.rand(self.N_R, 1)
        u = u / np.linalg.norm(u)
        self.u = self.data_loader(u, requires_grad=False)
        self.lambda_last = 1
        # Spectral shifts per model_id. Models 1-3 fix theirs once warm-up ends;
        # shift4 is never changed by this class.
        self.shift1 = 0
        self.shift2 = 0
        self.shift3 = 0
        self.shift4 = 0
        self.tmp_lambda = 0
        self.lambda_ = None  # eigenvalue estimate of the best checkpoint so far
        self.eigenvec = None  # iterate stored alongside that checkpoint

    def init(self, loss_name='mean', model_name='PINN'):
        """Reset training bookkeeping.

        Args:
            loss_name: 'sum' selects MSELoss(reduction='sum'); anything else
                selects reduction='mean'.
            model_name: name used when saving checkpoints.
        """
        self.start_time = None
        # A checkpoint is saved whenever the monitored loss drops below this.
        self.min_loss = 1e20
        # Number of completed training steps.
        self.nIter = 0
        reduction = 'sum' if loss_name == 'sum' else 'mean'
        self.loss_fn = torch.nn.MSELoss(reduction=reduction)
        self.model_name = model_name

    def unzip_param_dict(self, param_dict):
        """Return (lb, ub, device, path, root_path) from *param_dict*."""
        keys = ('lb', 'ub', 'device', 'path', 'root_path')
        return tuple(param_dict[key] for key in keys)

    def unzip_train_dict(self, train_dict):
        """Return (lambda_, x, d, N_R, M) from *train_dict*."""
        keys = ('lambda_', 'x', 'd', 'N_R', 'M')
        return tuple(train_dict[key] for key in keys)

    def net_model(self, x):
        """Evaluate the network with the Dirichlet boundary factor applied.

        Args:
            x: either a list of d column tensors (as stored in ``self.x``), or
                a single 2-D tensor whose columns are the d coordinates.

        Returns:
            g(x) * model(coor_shift(x)), where
            g(x) = prod_i (exp(x_i - lb_i) - 1) * (exp(-(x_i - ub_i)) - 1)
            vanishes whenever some x_i equals lb_i or ub_i.
        """
        if isinstance(x, list):
            # Keep the caller's column tensors so autograd derivatives with
            # respect to each of them stay connected.
            columns = x
            X = torch.cat(x, 1)
        else:
            # A single tensor: take coordinate columns rather than rows.
            X = x
            columns = [x[:, i:i + 1] for i in range(self.d)]
        # NOTE(review): coor_shift comes from BaseConfig; presumably it rescales
        # X using lb/ub before it is fed to the network.
        X = self.coor_shift(X, self.lb, self.ub)
        result = self.model.forward(X)
        g_x = 1
        for i in range(self.d):
            xi = columns[i]
            g_x = g_x * (torch.exp(xi - self.lb[i]) - 1) \
                * (torch.exp(-(xi - self.ub[i])) - 1)
        return g_x * result

    def forward(self, x):
        """Alias for ``net_model``."""
        return self.net_model(x)

    def _laplacian(self, u, x):
        """Sum compute_grad(compute_grad(u, x_i), x_i) over the d coordinates.

        This is the Laplacian of u, assuming compute_grad returns du/dx_i.
        """
        u_xx = None
        for i in range(self.d):
            xi = x[i]
            u_xixi = self.compute_grad(self.compute_grad(u, xi), xi)
            u_xx = u_xixi if u_xx is None else u_xx + u_xixi
        return u_xx

    def _shift(self):
        """Return the spectral shift belonging to the current model_id."""
        return (self.shift1, self.shift2, self.shift3, self.shift4)[self.model_id - 1]

    def _apply_operator(self, u, u_xx, pairs):
        """Return L u = -u_xx - sum(val * v (v^T u)) - shift * u.

        Only the first ``model_id - 1`` (val, vec) entries of *pairs* are used.

        Raises:
            ValueError: if model_id is not one of 1, 2, 3, 4.
        """
        if self.model_id not in (1, 2, 3, 4):
            raise ValueError(
                'model_id must be 1, 2, 3 or 4, got %r' % (self.model_id,))
        Lu = -u_xx
        for val, vec in pairs[:self.model_id - 1]:
            Lu = Lu - val * torch.matmul(vec, torch.matmul(vec.transpose(0, 1), u))
        return Lu - self._shift() * u

    def _log_progress(self, loss, loss_PM, every=10):
        """Every *every* steps, log losses, eigenvalue errors, LR and timing."""
        if np.remainder(self.nIter, every) != 0:
            return
        loss_PM = self.detach(loss_PM)
        if self.lambda_ is None:
            # No checkpoint yet (e.g. the residual has been NaN), so there is
            # no estimate to compare with the reference value.
            abs_lambda = rel_lambda = None
        else:
            abs_lambda = np.abs(self.lambda_true - self.lambda_)
            rel_lambda = abs_lambda / self.lambda_true
        lr = (None if self.optimizer is None
              else self.optimizer.state_dict()['param_groups'][0]['lr'])

        log_str = (str(self.optimizer_name) + ' Iter ' + str(self.nIter)
                   + ' Loss ' + str(loss)
                   + ' lambda_ ' + str(self.tmp_lambda)
                   + ' loss_PM ' + str(loss_PM)
                   + ' min_loss ' + str(self.min_loss)
                   + ' lambda ' + str(self.lambda_)
                   + ' abs_lambda ' + str(abs_lambda)
                   + ' rel_lambda ' + str(rel_lambda)
                   + ' LR ' + str(lr))
        log(log_str)

        elapsed = time.time() - self.start_time
        print('Time: %.4fs Per %d Iterators' % (elapsed, every))
        logging.info('Time: %.4f s Per %d Iterators' % (elapsed, every))
        self.start_time = time.time()

    def optimize_one_epoch(self, max_val=None, max_vec=None, max_val1=None,
                           max_vec1=None, max_val2=None, max_vec2=None,
                           optimizer=None):
        """Run one training step and backpropagate its loss.

        Args:
            max_val, max_vec: first deflated eigenpair (model_id >= 2).
            max_val1, max_vec1: second deflated eigenpair (model_id >= 3).
            max_val2, max_vec2: third deflated eigenpair (model_id == 4).
            optimizer: optimizer whose learning rate is reported in the log.
                This method only calls ``backward``; it does not step the
                optimizer.

        Returns:
            (self.loss, self.lambda_, self.eigenvec): the loss just
            backpropagated, and the eigenvalue estimate and iterate stored at
            the best checkpoint so far.

        Raises:
            ValueError: if model_id is not one of 1, 2, 3, 4.
        """
        if self.start_time is None:
            self.start_time = time.time()
        self.optimizer = optimizer
        # Steps after which the shift is frozen. Integer division keeps the
        # equality test reachable when N_R is not a multiple of 20.
        warmup_iter = self.N_R // 20

        x = self.x
        u = self.forward(x)
        u_xx = self._laplacian(u, x)
        pairs = [(max_val, max_vec), (max_val1, max_vec1), (max_val2, max_vec2)]
        Lu = self._apply_operator(u, u_xx, pairs)
        # Residual of L u = lambda u against the stored iterate; this is the
        # value monitored for checkpointing.
        tmp_loss = self.loss_func(Lu - self.lambda_last * self.u)

        # Normalised L u is pulled towards the stored iterate.
        u1 = Lu / torch.norm(Lu, p=2)
        loss_PM = self.loss_func(u1, self.u)
        # NOTE(review): the stored iterate is the raw network output, whereas u1
        # is normalised; confirm that this scale mismatch is intended.
        self.u = self.data_loader(self.detach(u), requires_grad=False)
        alpha_PM = 1  # weight of the power-method loss
        self.loss = loss_PM * alpha_PM

        self.loss.backward()
        self.nIter = self.nIter + 1

        # Rayleigh quotient <Lu, u> / <u, u>; self.u now holds a detached u.
        Luu = torch.sum(Lu * self.u)
        uu = torch.sum(u ** 2)
        lambda_ = self.detach(Luu / uu).max()
        self.lambda_last = lambda_

        if self.nIter == warmup_iter:
            # End of warm-up: fix the shift at (current estimate - 2) for
            # models 1-3 and restart the residual scaling.
            if self.model_id == 1:
                self.shift1 = lambda_ - 2
            elif self.model_id == 2:
                self.shift2 = lambda_ - 2
            elif self.model_id == 3:
                self.shift3 = lambda_ - 2
            self.lambda_last = 1
        # Eigenvalue of the unshifted operator, for logging.
        self.tmp_lambda = lambda_ + self._shift()

        # Checkpoint whenever the eigen-residual improves.
        loss = self.detach(tmp_loss)
        if loss < self.min_loss:
            if self.nIter > warmup_iter:
                self.lambda_ = lambda_ + self._shift()
            else:
                self.lambda_ = lambda_
            self.min_loss = loss
            self.eigenvec = self.u
            PINNConfig.save(net=self,
                            path=self.root_path + '/' + self.path,
                            name=self.model_name)

        self._log_progress(loss, loss_PM)
        return self.loss, self.lambda_, self.eigenvec
