import torch
from torch import nn
import torch.optim as optim
from tqdm import tqdm
from utils.Jexc import exc_operator
from utils.Jorthogonal_test import Jtest
import time
import os
from UHstruct.fobj_val import UltraE_fval_obj
from UHstruct.get_hit_ranks import UltraE_test_hits_rank

# Fix torch's global RNG seed as soon as this module is imported, so the
# torch.randn draws used to initialise admm's dual variables are repeatable.
# Note: this is a module-level side effect that affects every importer.
torch.random.manual_seed(0)
class admm(nn.Module):
    """ADMM-style solver that iterates on ``Q`` together with an auxiliary ``Y``.

    ``J = diag(I_p, -I_{d-p})`` is built in ``__init__``. Each sweep of
    :meth:`updateR` performs a closed-form ``Q`` step, a dual step on ``ksi``
    driven by the residual ``Q^T Y - J``, a closed-form ``Y`` step, and a dual
    step on ``Pi`` driven by the residual ``J Q - Y``.

    Args:
        T, H, P: tensors forwarded unchanged to ``UltraE_fval_obj``; moved to
            ``config_yaml["device"]``.
        Q: initial iterate, wrapped in an ``nn.Parameter`` so autograd can
            populate ``Q.grad``.
        config_yaml: parsed config dict. Reads ``device``, ``run.maxiter``,
            ``run.stop_t``, ``run.stop_lr``, ``datafeature.d``,
            ``datafeature.p`` and ``ADMM.lipsconst``.
        lambda_val: penalty weight, also used as the dual step size.
        mode: ``'d'`` doubles ``lambda_val`` every other iteration until it
            reaches 1e12; any other value keeps it fixed.
    """

    def __init__(self, T, H, Q, P, config_yaml, lambda_val, mode='d'):
        super().__init__()
        self.device = config_yaml["device"]
        self.T = T.to(self.device)
        self.H = H.to(self.device)
        self.Q = nn.Parameter(data=Q.to(self.device))
        self.P = P.to(self.device)

        self.config_yaml = config_yaml
        run_cfg = config_yaml["run"]
        # Numeric config values are coerced with float(), as the original
        # already did for maxiter/stop_lr/lipsconst (a YAML value such as
        # ``1e3`` may load as a string).
        self.maxiter = int(float(run_cfg["maxiter"]))
        self.stopt = float(run_cfg["stop_t"])  # wall-clock budget in seconds
        self.stoplr = float(run_cfg["stop_lr"])
        self.d = config_yaml["datafeature"]["d"]
        self.p = config_yaml["datafeature"]["p"]
        self.lipsconst = float(config_yaml["ADMM"]["lipsconst"])

        self.stopindex = 0  # becomes 1 once the time budget is exhausted
        self.myeps = 1e-8
        self.lambda_val = lambda_val
        self.mode = mode
        self.Xgrad = None

        # J = diag(I_p, -I_{d-p})
        J = torch.eye(self.d)
        J[self.p:, self.p:] = -1 * torch.eye(self.d - self.p)
        self.J = J.to(torch.float32).to(self.device)

        # Dual variables: small random init generated on the default device,
        # then moved (draw order Pi -> ksi is kept for reproducibility).
        self.Pi = torch.randn(self.Q.size()).to(torch.float32).to(self.device) * 0.1
        self.ksi = torch.randn(self.Q.size()).to(torch.float32).to(self.device) * 0.1
        # detach(): Y is an independent copy and must not keep an autograd
        # edge back to the Parameter Q.
        self.Y = self.Q.detach().clone().to(torch.float32).to(self.device)

        self.Jerr0 = Jtest(self.Q, self.p)

    def Train(self):
        """Run up to ``maxiter`` iterations or until ``stop_t`` seconds elapse.

        Returns:
            tuple ``(hist_Obj, Q, hist_t, hist_err)`` where ``n`` is the number
            of iterations actually executed:

            - ``hist_Obj``: 1-D tensor of length ``n``, the value of
              ``UltraE_fval_obj`` recorded before each iteration;
            - ``Q``: the parameter itself;
            - ``hist_t``: ``(n, 1)`` elapsed seconds after each iteration;
            - ``hist_err``: ``(n, 1)`` ``Jtest`` value before each iteration.
        """
        hist_Obj = torch.zeros(self.maxiter, 1).to(self.device)
        hist_t = torch.zeros(self.maxiter, 1).to(self.device)
        hist_err = torch.zeros(self.maxiter, 1).to(self.device)
        n_done = 0
        start_time = time.time()

        for it in tqdm(range(self.maxiter), desc='Outiter'):
            with torch.no_grad():
                hist_Obj[it] = UltraE_fval_obj(self.T, self.H, self.P, self.Q, self.config_yaml)
                hist_err[it] = Jtest(self.Q, self.p)

            self.stepforward()

            # Continuation: grow the penalty every other iteration; stop
            # doubling once it reaches 1e12.
            if self.mode == 'd' and it % 2 == 0 and self.lambda_val < 1e12:
                self.lambda_val *= 2

            elapsed = time.time() - start_time
            hist_t[it] = elapsed
            n_done = it + 1
            if elapsed > self.stopt:
                self.stopindex = 1
            if self.stopindex == 1:
                break

        # Truncate by iteration count. Filtering on ``hist_Obj != 0`` would
        # drop legitimate zero objectives and misalign the three histories.
        return hist_Obj[:n_done].reshape(-1), self.Q, hist_t[:n_done], hist_err[:n_done]

    def stepforward(self):
        """Back-propagate the objective into ``Q.grad`` and apply one sweep."""
        fval = UltraE_fval_obj(self.T, self.H, self.P, self.Q, self.config_yaml)
        fval.backward()

        # Zero out non-finite gradient entries (NaN and +/-inf) so a single
        # bad entry cannot poison the linear solves in updateR.
        grad = self.Q.grad
        grad[~torch.isfinite(grad)] = 0

        try:
            self.updateR()
        finally:
            # Always clear the gradient so the next backward() does not
            # accumulate into a stale value, even if updateR raised.
            self.Q.grad = None

    @torch.no_grad()
    def updateR(self):
        """Perform one sweep: Q-step, ksi dual step, Y-step, Pi dual step.

        ``Q.grad`` is used as the linearisation of the objective at the current
        ``Q``, with proximal weight ``lipsconst``. If a pseudo-inverse fails
        with a torch ``RuntimeError`` (this includes
        ``torch.linalg.LinAlgError``), that primal block keeps its previous
        value. Other exceptions propagate.
        """
        grad = self.Q.grad.to(torch.float32)
        lam = self.lambda_val
        eye = torch.eye(self.d).to(self.device)

        # Sub-problem 1 (closed form):
        #   ((L + lam) I + lam Y Y^T) Q = -(g - L Q + Y ksi^T + J Pi - lam Y J - lam J Y)
        try:
            lhs = (self.lipsconst + lam) * eye + lam * self.Y @ self.Y.t()
            rhs = -(grad - self.lipsconst * self.Q + self.Y @ self.ksi.t() + self.J @ self.Pi
                    - lam * self.Y @ self.J - lam * self.J @ self.Y)
            self.Q.data = torch.pinverse(lhs) @ rhs
        except RuntimeError:
            pass  # keep the previous Q
        self.Q.data = torch.clamp(self.Q, -100, 100)

        self.ksi += lam * (self.Q.T @ self.Y - self.J)
        self.ksi = torch.clamp(self.ksi, -10, 10)

        # Sub-problem 2 (closed form):
        #   lam (Q Q^T + I) Y = -(Q ksi - J Pi - lam Q J - lam J Q)
        # NOTE(review): differentiating <Pi, J Q - Y> w.r.t. Y gives -Pi, while
        # the Q-step above uses J @ Pi (consistent with that term). Confirm
        # whether ``- self.J @ self.Pi`` here is intended.
        try:
            self.Y = (torch.pinverse(lam * (self.Q @ self.Q.t() + eye))
                      @ (-(self.Q @ self.ksi - self.J @ self.Pi
                           - lam * self.Q @ self.J - lam * self.J @ self.Q)))
        except RuntimeError:
            pass  # keep the previous Y

        self.Pi += lam * (self.J @ self.Q - self.Y)
        self.Pi = torch.clamp(self.Pi, -10, 10)



