import torch
from torch import nn
import time
from tqdm import tqdm
from torch.autograd import Variable
from utils.initial_x import getU, getV, getH

import torch.optim as optim
from UHstruct.fobj_val import UltraE_fval_obj
from UHstruct.get_hit_ranks import UltraE_test_hits_rank


class CSDM_al(nn.Module):
    """Optimise a factorised matrix ``Q = U @ H @ V`` against ``UltraE_fval_obj``.

    ``U``, ``H`` and ``V`` are produced by the project helpers ``getU``,
    ``getH`` and ``getV`` from the learnable parameters ``theta``, ``mu``
    (clamped to [-1, 1]) and ``ksi``. The parameters are updated with Adagrad.
    All run settings are read from ``config_yaml``.
    """

    def __init__(self, T, H, theta, mu, ksi, Q, P, config_yaml):
        super().__init__()
        self.device = config_yaml["device"]
        self.T = T.to(self.device)
        self.dataH = H.to(self.device)
        # Initial Q; getM() below replaces it with the Q built from the parameters.
        self.Q = Q.to(self.device)
        self.P = P.to(self.device)

        self.config_yaml = config_yaml
        run_cfg = config_yaml["run"]
        # Numeric run settings go through float() because YAML loaders such as
        # PyYAML read exponent literals like ``1e3`` as strings.
        self.maxiter = int(float(run_cfg["maxiter"]))
        self.Cmaxiter = int(float(run_cfg["Cmaxiter"]))
        self.d = config_yaml["datafeature"]["d"]
        self.p = config_yaml["datafeature"]["p"]
        # Wall-clock budget in seconds for Train(). Cast like the other numeric
        # settings so a string value does not break the elapsed-time comparison.
        self.stopt = float(run_cfg["stop_t"])
        self.stoplr = float(run_cfg["stop_lr"])
        # Set to 1 by Train() once the time budget is exceeded.
        self.stopindex = 0

        self.theta = nn.Parameter(data=theta)
        self.mu = nn.Parameter(data=mu)
        self.ksi = nn.Parameter(data=ksi)
        self.getM()

        self.inneriter = int(1e2)
        # J = diag(+1 x p, -1 x (d - p)).
        J = torch.eye(self.d)
        J[self.p:, self.p:] = -1 * torch.eye(self.d - self.p)
        self.J = J.to(self.device)
        self.lambdaa = 0
        self.myeps = 1e-8
        self.optimizer = optim.Adagrad(
            [self.theta, self.mu, self.ksi],
            lr=float(run_cfg["lr_eb"]),
            weight_decay=float(run_cfg["weight_decay"]),
        )

    def Train(self):
        """Run the outer/inner optimisation loop.

        Each outer iteration does two things. First, it records the current
        objective without tracking gradients. Then it performs ``Cmaxiter``
        optimiser steps, rebuilding ``Q`` after each step. The loop ends after
        ``maxiter`` outer iterations, or as soon as the elapsed wall-clock time
        exceeds ``stop_t`` seconds.

        Returns:
            tuple: ``(hist_Obj, Q, hist_t)``.
                - ``hist_Obj``: 1-D tensor holding the objective recorded at
                  the start of each completed outer iteration.
                - ``Q``: the current factorised matrix.
                - ``hist_t``: tensor of shape ``(k, 1)`` holding the elapsed
                  seconds at the end of each of those iterations.
        """
        hist_Obj = torch.zeros(self.maxiter, 1).to(self.device)
        hist_t = torch.zeros(self.maxiter, 1).to(self.device)
        # Each call gets a fresh time budget (start_time is reset below),
        # so clear any stop flag left over from a previous call.
        self.stopindex = 0
        n_done = 0
        start_time = time.time()
        for it in tqdm(range(self.maxiter), desc='Outiter'):
            with torch.no_grad():
                hist_Obj[it] = UltraE_fval_obj(self.T, self.dataH, self.P, self.Q, self.config_yaml)

            # Re-allocated every outer iteration. Nothing in this class writes
            # to it; presumably consumed elsewhere -- TODO confirm.
            self.innerhist_obj = torch.zeros(int(self.Cmaxiter/self.config_yaml['run']['inner_dispgap'])).to(self.device)
            for _ in range(self.Cmaxiter):
                self.step_forward()
                self.getM()

            n_done = it + 1
            hist_t[it] = time.time() - start_time
            if hist_t[it] > self.stopt:
                self.stopindex = 1
            if self.stopindex == 1:
                break

        # Trim by the number of completed iterations rather than filtering out
        # zeros. An objective of exactly 0 is a valid value, and dropping it
        # would misalign hist_Obj with hist_t.
        hist_Obj = hist_Obj[:n_done].reshape(-1)
        hist_t = hist_t[:n_done]
        return hist_Obj, self.Q, hist_t

    def step_forward(self):
        """Take one optimiser step on ``theta``, ``mu`` and ``ksi``.

        The step proceeds as follows:
            1. Evaluate the objective on the current ``Q`` and back-propagate.
            2. Replace NaN gradient entries with 0.
            3. Step the optimiser and clear the gradients.

        ``backward()`` consumes the graph behind ``Q``, so ``getM()`` must be
        called before the next step, as ``Train`` does.
        """
        fval = UltraE_fval_obj(self.T, self.dataH, self.P, self.Q, self.config_yaml)
        fval.backward()

        params = (self.theta, self.ksi, self.mu)
        for param in params:
            # A parameter the objective does not depend on has no gradient.
            if param.grad is not None:
                # Zero NaN entries so they do not propagate into the parameters
                # or the optimiser's accumulated state.
                param.grad.masked_fill_(torch.isnan(param.grad), 0)

        self.optimizer.step()
        for param in params:
            param.grad = None

    def getM(self):
        """Rebuild ``U``, ``H``, ``V`` and ``Q = U @ (H @ V)`` from the current parameters.

        ``mu`` is clamped to [-1, 1] before being passed to ``getH``. Each
        factor is cast to float32 and moved to ``self.device``.
        ``retain_grad()`` is called on every factor and on ``Q``, so their
        ``.grad`` is populated after ``backward()``.
        """
        self.U = getU(self.theta).to(torch.float32).to(self.device)
        self.H = getH(self.d, self.p, torch.clamp(self.mu, -1, 1)).to(torch.float32).to(self.device)
        self.V = getV(self.ksi).to(torch.float32).to(self.device)
        self.Q = torch.mm(self.U, torch.mm(self.H, self.V)).to(self.device)

        for factor in (self.U, self.V, self.H, self.Q):
            factor.retain_grad()

