import os
import copy
import time
import torch
import numpy as np
import gurobipy as gp
import matplotlib.pyplot as plt
from gurobipy import GRB, quicksum as qsum

import ot_solver
import mlot_simple as mlot


class MultilayerSolver:
    """Linear program for a chained multi-layer transport problem, solved with Gurobi.

    ``P[k]`` holds the non-negative flow variables between layer ``k`` and layer
    ``k + 1`` and ``C[k]`` the matching costs. The row sums of the first plan are
    pinned to ``SOURCE``, the column sums of the last plan are pinned to
    ``TARGET``, and every node of an intermediate layer must send out exactly
    what it receives.

    Typical call order: ``create_variable()`` -> ``initProblem()`` -> ``solve()``
    -> optionally ``getP()`` / ``getLayers()``.
    """

    # Placeholder only: every instance builds its own Gurobi model in __init__.
    # A single class-level model would be shared by all solvers, so a second
    # instance would pile its variables and constraints onto the first one's.
    model = None

    P: list[list[list]]          # per-gap Gurobi variable containers, set by create_variable()
    P_gurobi = None              # per-gap solution tensors, set by getP()
    C: list[list[list]]          # C[k][i][j]: cost from node i of layer k to node j of layer k + 1
    N: int = None                # number of layers
    SOURCE: torch.Tensor         # indexable 1-D marginal of the first layer
    TARGET: torch.Tensor         # indexable 1-D marginal of the last layer

    _finish_init = False
    area_num: list[int] = None   # number of nodes in each layer

    def __init__(self, source, target, area_num, M):
        """Store the problem data and create a fresh Gurobi model.

        Args:
            source: marginal of the first layer, indexed by node.
            target: marginal of the last layer, indexed by node.
            area_num: number of nodes in each layer.
            M: list of cost matrices, one per pair of consecutive layers.

        If ``area_num`` has no length the instance is left "not ready"
        (``C`` is reset to ``None``) and the build methods will refuse to run.
        """
        self.model = gp.Model("Multi-layer")
        self.SOURCE = source
        self.TARGET = target
        self.area_num = area_num
        self.C = M
        try:
            self.N = len(area_num)
        except TypeError:
            self.C = None
            self._finish_init = False
        else:
            self._finish_init = True

    def _require_ready(self):
        """Raise ``AssertionError`` unless ``__init__`` completed.

        An explicit raise is used instead of an ``assert`` statement so the check
        survives ``python -O`` while keeping the exception type unchanged.
        """
        if not self._check_problem():
            raise AssertionError("Havn't finish initialization")

    def create_variable(self):
        """Create one block of flow variables ``P_k[i, j] >= 0`` per layer gap."""
        self._require_ready()
        self.P = [
            self.model.addVars(
                self.area_num[k], self.area_num[k + 1],
                lb=0, ub=GRB.INFINITY, name=f"P_{k}",
            )
            for k in range(self.N - 1)
        ]

    def initProblem(self, write=False):
        """Install the objective and constraints on the model.

        Args:
            write: when true, dump the model to ``multi-layer.lp``.
        """
        self._require_ready()
        # Objective: minimise the total transport cost. Coefficients are cast to
        # plain floats so NumPy and torch scalars are handled the same way.
        self.model.setObjective(
            qsum(
                float(self.C[k][i][j]) * self.P[k][i, j]
                for k in range(self.N - 1)
                for i in range(self.area_num[k])
                for j in range(self.area_num[k + 1])
            ),
            GRB.MINIMIZE,
        )

        # Marginal constraints on the first and last layer. addConstrs gives
        # every constraint its own indexed name instead of a repeated one.
        first_plan, last_plan = self.P[0], self.P[-1]
        self.model.addConstrs(
            (qsum(first_plan[i, j] for j in range(self.area_num[1])) == float(self.SOURCE[i])
             for i in range(self.area_num[0])),
            name="source",
        )
        self.model.addConstrs(
            (qsum(last_plan[j, i] for j in range(self.area_num[-2])) == float(self.TARGET[i])
             for i in range(self.area_num[-1])),
            name="target",
        )

        # Flow conservation: inflow equals outflow at every intermediate node.
        self.model.addConstrs(
            (qsum(self.P[k - 1][s, i] for s in range(self.area_num[k - 1]))
             ==
             qsum(self.P[k][i, t] for t in range(self.area_num[k + 1]))
             for k in range(1, self.N - 1) for i in range(self.area_num[k])),
            name="center_constraint",
        )

        # Optional unit constraint -- each node carries at most one unit of flow:
        # self.model.addConstrs((qsum(self.P[k][i, j] for j in range(self.area_num[k+1])) <= 1
        #                     for k in range(self.N-1) for i in range(self.area_num[k])), name="unit_constraint")

        if write:
            self.model.write("multi-layer.lp")

    def _check_problem(self):
        """Return whether ``__init__`` finished successfully."""
        return bool(self._finish_init)

    def solve(self):
        """Optimise the model silently and return ``(objective, variables)``.

        Raises:
            ValueError: if Gurobi does not end with an optimal status.
        """
        self.model.setParam('OutputFlag', 0)
        self.model.optimize()
        if self.model.status != GRB.OPTIMAL:
            raise ValueError(
                f"Optimization was not successful (Gurobi status {self.model.status})."
            )
        print('Minimum cost by gurobi:', self.model.objVal)
        return self.model.objVal, self.model.getVars()

    def getP(self):
        """Return the optimal plans as float64 tensors, one per layer gap.

        Values are read through the variable containers created by
        ``create_variable`` rather than by parsing variable names, so any number
        of layers works. Each plan is assembled on the CPU and moved to CUDA
        when it is available. The result is also cached in ``self.P_gurobi``.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        plans = []
        for k in range(self.N - 1):
            rows, cols = self.area_num[k], self.area_num[k + 1]
            layer_vars = self.P[k]
            values = np.array(
                [[layer_vars[i, j].X for j in range(cols)] for i in range(rows)],
                dtype=np.float64,
            ).reshape(rows, cols)
            plans.append(torch.from_numpy(values).to(device))
        self.P_gurobi = plans
        return plans

    def getLayers(self):
        """Return the mass on every layer implied by the solved plans.

        Layer ``k`` (for ``k < N - 1``) is the row sum of plan ``k``; the last
        layer is the column sum of the final plan. The plans are extracted on
        demand if ``getP`` has not been called yet.
        """
        if self.P_gurobi is None:
            self.getP()
        layers = [torch.sum(self.P_gurobi[i], dim=1) for i in range(self.N - 1)]
        layers.append(torch.sum(self.P_gurobi[-1], dim=0))
        return layers

    def __str__(self):
        """Describe the problem data as text (no printing side effect)."""
        lines = [f"source = {self.SOURCE}", f"target = {self.TARGET}"]
        if self.C is not None:
            lines.extend(f"C[{i}]: {cost}" for i, cost in enumerate(self.C))
        else:
            lines.append("C: None")
        lines.append(f"N = {self.N}")
        lines.append(f"area_num={self.area_num}")
        return "\n".join(lines)

    def __repr__(self):
        return self.__str__()
    

if __name__ == "__main__":

    def _generate_M(_n):
        """Sample 2-D points per layer and build inter-layer Euclidean cost matrices.

        Layer ``i`` gets ``_n[i]`` points with ``x ~ U(-1e-10, 1e-10) + 0.001 * i``
        and ``y ~ U(-10, 10)``.

        Returns:
            ``(points, costs)`` where ``points[i]`` has shape ``(_n[i], 2)`` and
            ``costs[i]`` has shape ``(_n[i], _n[i + 1])``.
        """
        layer_points = []
        costs = []
        for idx, count in enumerate(_n):
            # x is drawn before y so the random stream is consumed in a fixed order.
            xs = np.random.uniform(-1e-10, 1e-10, count) + 0.001 * idx
            ys = np.random.uniform(-10, 10, count)
            current = np.column_stack((xs, ys))
            layer_points.append(current)

            if idx == 0:
                continue
            previous = layer_points[idx - 1]
            # Broadcast to (len(previous), len(current), 2) pairwise differences.
            diff = previous[:, np.newaxis, :] - current
            costs.append(np.sqrt(np.sum(diff**2, axis=-1)))
        return layer_points, costs
    def _generate_s_t(_n):
        """Draw random source and target marginals, each normalised to sum to 1.

        The source has ``_n[0]`` entries and the target ``_n[-1]`` entries; the
        source is sampled first.
        """
        source_raw = np.random.rand(_n[0])
        target_raw = np.random.rand(_n[-1])
        source = source_raw / np.sum(source_raw)
        target = target_raw / np.sum(target_raw)
        return source, target
    def generateLineData(_n):
        """Build a random line-layout instance for the layer sizes ``_n``.

        Returns a dict with the cost matrices under ``'M'`` and the normalised
        marginals under ``'source'`` and ``'target'``. The sampled points
        themselves are discarded.
        """
        # Costs are generated before the marginals to keep the RNG order fixed.
        _, costs = _generate_M(_n)
        source, target = _generate_s_t(_n)
        return {'M': costs, 'source': source, 'target': target}
    
    def compo1(r, a):
        """Return ``(r / 2) * sqrt(1 + a^2 r^2)``.

        This is the algebraic term of the antiderivative of
        ``sqrt(1 + a^2 r^2)`` with respect to ``r``.
        """
        radicand = 1 + a*a*r*r
        return 0.5 * r * np.sqrt(radicand)
    def compo2(r, a):
        """Return ``ln(sqrt(1 + a^2 r^2) + a r) / (2 a)``.

        This is the logarithmic term of the antiderivative of
        ``sqrt(1 + a^2 r^2)`` with respect to ``r``. For a scalar ``a == 0`` the
        expression is 0/0; its limit ``r / 2`` is returned instead of dividing
        by zero.
        """
        # Only scalar zeros are special-cased; array inputs keep the
        # element-wise behaviour of the plain formula.
        if np.ndim(a) == 0 and a == 0:
            return 0.5 * r
        tmp1 = 1 + a*a*r*r
        tmp2 = np.sqrt(tmp1) + a*r
        tmp3 = 1/(2*a) * np.log(tmp2)
        return tmp3
    def helicalDist(p1, p2, draw=False):
        r"""Return the arc length of the Archimedean spiral joining two polar points.

        Each point is ``(r, theta)``.

        - With \frac{dr}{d\theta} = b the spiral is r = b (\theta - \theta_0).
        - Substituting both points gives b = \frac{1}{\theta_2 - \theta_1} and
          \theta_0 = \theta_1 - r_1 / b.
          NOTE(review): this value of b presumes r_2 - r_1 == 1 (neighbouring
          rings, as produced by the ring generator in this file) -- confirm
          before using other radii.
        - The winding direction is ambiguous, so one of the two angles is
          shifted by 2*pi when needed so that the spiral spans the shorter
          angular gap (at most pi).

        Args:
            p1, p2: the two end points as ``(r, theta)``; they are not modified.
            draw: when true, also plot the connecting curve with matplotlib.

        Returns:
            The non-negative length. When both angles coincide the spiral
            degenerates to a radial segment and ``|r2 - r1|`` is returned.
        """
        r1, theta1 = p1[0], p1[1]
        r2, theta2 = p2[0], p2[1]

        # Shift with a non-in-place addition so caller-owned array or tensor
        # elements are never mutated (this replaces the former deep copy).
        if abs(theta2 - theta1) > np.pi:
            if theta2 > theta1:
                theta1 = theta1 + 2 * np.pi
            else:
                theta2 = theta2 + 2 * np.pi

        a = theta2 - theta1
        if a == 0:
            # Equal angles: the spiral formulas divide by zero, but the path is
            # simply the straight radial segment between the two radii.
            if draw:
                plt.plot([r1 * np.cos(theta1), r2 * np.cos(theta1)],
                         [r1 * np.sin(theta1), r2 * np.sin(theta1)])
            return float(abs(r2 - r1))

        if draw:
            b = 1 / a
            # float() works for Python floats as well as NumPy/torch scalars.
            theta0 = float(theta1 - r1 / b)
            lo, hi = float(min(theta1, theta2)), float(max(theta1, theta2))
            mid_theta = np.linspace(lo, hi, 100)
            r_values = b * (mid_theta - theta0)
            x = r_values * np.cos(mid_theta)
            y = r_values * np.sin(mid_theta)
            plt.plot(x, y)

        l = compo1(r2, a) - compo1(r1, a) + compo2(r2, a) - compo2(r1, a)
        return abs(l)
    def _calculate_distance_matrices(layers):
        """Build one float64 spiral-distance matrix per pair of consecutive layers.

        Entry ``[j, k]`` of matrix ``i`` is ``helicalDist(layers[i][j], layers[i+1][k])``.
        """
        matrices = []
        for inner, outer in zip(layers, layers[1:]):
            cost = torch.zeros((len(inner), len(outer)), dtype=torch.float64)
            for row, start in enumerate(inner):
                for col, end in enumerate(outer):
                    cost[row, col] = helicalDist(start, end)
            matrices.append(cost)
        return matrices
    def _generate_ring_M(_n):
        """Place random points on concentric rings and compute their spiral costs.

        Layer ``i`` sits at radius ``i + 1`` and holds ``_n[i]`` points whose
        angles are drawn from ``U(0, 2*pi)``.

        Returns:
            ``(angle_points, costs)``: the ``(radius, angle)`` tuples per layer
            and the distance matrices between consecutive layers.
        """
        # One uniform draw per point, layer by layer, so the RNG order is fixed.
        ring_points = [
            [(ring + 1, np.random.uniform(0, 2 * np.pi)) for _ in range(count)]
            for ring, count in enumerate(_n)
        ]
        costs = _calculate_distance_matrices(ring_points)
        return ring_points, costs
    def generateRingData(_n):
        """Build a ring-layout instance, reusing cached cost matrices when they fit.

        The cache file name depends only on ``sum(_n)``, so two different layer
        layouts with the same total share a file. A cached value is therefore
        used only if it has one matrix per layer gap and each matrix has shape
        ``(_n[k], _n[k + 1])``; otherwise the costs are regenerated.

        Returns a dict with the cost matrices under ``'M'`` and the normalised
        marginals under ``'source'`` and ``'target'``.
        """
        cache_path = './results/cache/ring{}.pt'.format(sum(_n))
        expected_shapes = list(zip(_n, _n[1:]))

        _M = None
        if os.path.exists(cache_path):
            # NOTE: torch.load unpickles the file -- only load caches you created.
            cached = torch.load(cache_path)
            # presumably a list of 2-D tensors; anything else is treated as a miss
            try:
                cached_shapes = [tuple(getattr(m, 'shape', ())) for m in cached]
            except TypeError:
                cached_shapes = None
            if cached_shapes == expected_shapes:
                _M = cached

        if _M is None:
            _, _M = _generate_ring_M(_n)

        _s, _t = _generate_s_t(_n)
        data = {
            'M': _M,
            'source': _s,
            'target': _t
        }
        return data
    def sinkhorn_distance(M, P, addition=None):
        """Return the total transport cost ``sum_k <P[k], M[k]>``.

        Args:
            M: cost matrices, one per layer gap.
            P: transport plans, aligned with ``M``.
            addition: optional per-gap weights; when given, the cost of gap
                ``k`` is multiplied by ``addition[k]``. May be a list, a NumPy
                array or a tensor.

        A gap whose cost cannot be computed is reported on stdout and skipped,
        so the returned value then covers the remaining gaps only.
        """
        def _shape(seq, idx):
            # Best-effort shape lookup that cannot raise inside the error report.
            try:
                return seq[idx].shape
            except (IndexError, KeyError, TypeError, AttributeError):
                return None

        distance = 0
        for k in range(len(P)):
            try:
                layer_cost = (P[k] * M[k]).sum()
                # Identity test: "== None" on an array is element-wise and its
                # truth value is ambiguous, which used to drop the whole term.
                if addition is not None:
                    layer_cost = layer_cost * addition[k]
                distance += layer_cost
            except Exception as e:
                print("Error at {}-th: {}".format(k, e))
                print("Please check shape: P[k] {}, M[k] {}".format(_shape(P, k), _shape(M, k)))
        return distance
    def solve_with_gurobi(gp_model, write=False):
        """Build and solve ``gp_model``, timing the whole pipeline.

        Calls ``create_variable()``, ``initProblem(write)`` and ``solve()`` in
        that order.

        Returns:
            ``(objective, variables, seconds)`` where the first two items come
            from ``gp_model.solve()`` and ``seconds`` is the wall-clock time.
        """
        began = time.time()
        gp_model.create_variable()
        gp_model.initProblem(write)
        objective, solution_vars = gp_model.solve()
        elapsed = time.time() - began
        return objective, solution_vars, elapsed
    def solve_with_pot_dp(M, n, s, t, type="emd", reg=1e-5, max_iter=10000):
        """Solve via ``ot_solver``: collapse the layers to a short-path cost, then run OT.

        The short-path matrix returned by ``ot_solver.cpp_create_short_path`` is
        divided by its maximum before solving, and the resulting cost is scaled
        back by the same factor.

        Returns:
            ``(distance, seconds)``.
        """
        began = time.time()
        shortest = ot_solver.cpp_create_short_path(M, n)

        # Normalise by the largest entry; the factor is re-applied to the cost below.
        scale = np.max(shortest)
        normalised = shortest / scale
        plan = ot_solver.solve(s, t, normalised, type, reg, max_iter)
        dis_dp = scale * np.sum(normalised * plan)
        finished = time.time()
        return dis_dp, (finished - began)
    def solve_with_mlot_single(s, t, M, eps, max_iter=1000):
        """Run ``mlot.multi_sinkhorn_single`` on globally normalised costs.

        Every cost matrix is divided by the largest entry found across all of
        ``M``; that same factor is passed once per matrix as the scale record.
        Only the solver call itself is timed.

        Returns:
            ``(T, log, seconds)`` with ``T`` and ``log`` as returned by the solver.
        """
        scale = max([np.max(layer_cost) for layer_cost in M])
        normalised = [layer_cost / scale for layer_cost in M]
        scales = [scale for _ in normalised]

        began = time.time()
        T, log = mlot.multi_sinkhorn_single(s, t, normalised, eps, scales, max_iter)
        elapsed = time.time() - began

        return T, log, elapsed
    
    # Fix the NumPy RNG so the generated instance is reproducible.
    SEED = 0
    np.random.seed(SEED)

    # Nodes per layer, number of layers and total node count.
    n = [2500, 5000, 2500]
    layer, N = len(n), sum(n)

    # data = generateLineData(n)
    data = generateRingData(n)
    M, s, t = data['M'], data['source'], data['target']
    print("N = {}".format(N))
    print("n = {}".format(n))


    # Exact LP solution with Gurobi.
    gp_model = MultilayerSolver(s, t, n, M)
    dis_gurobi, gurobi_vars, time_gurobi = solve_with_gurobi(gp_model, write=False)
    print("[Gurobi]:")
    print("\tObjective= {:.4f}".format(dis_gurobi))
    print("\tTime= {:.3f}".format(time_gurobi))


    # P_gurobi = gp_model.getP()
    # all_layers = gp_model.getLayers()


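    # second_layer = all_layers[1]  # ground-truth distribution of the second layer (all experiments use three layers, so only the second layer's KL divergence is computed)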
    # last_third_layer = all_layers[-3]
    # print("Second layer:", second_layer.shape, second_layer)
    # print("Last third layer:", last_third_layer.shape, last_third_layer)
    # torch.save(second_layer, './secondLayerN={}({}).pt'.format(N, SEED))
    # torch.save(last_third_layer, './lastThirdLayerN={}({}).pt'.format(N, SEED))
    
    # Save the ground truth of every layer
    # for i, l in enumerate(all_layers):
    #     if i == 0 or i == len(all_layers)-1:
    #         continue
    #     torch.save(l, './results/cache/layer{}.pt'.format(i))
    #     print("save layer{}: shape={}".format(i, l.shape))


    # dis_dp, time_dp = solve_with_pot_dp(M, n, s, t, type="sinkhorn", reg=1e-3)
    # print("[DP]:")
    # print("\tObjective= {:.4f}".format(dis_dp))
    # print("\tTime= {:.3f}".format(time_dp))


    # T_single, log_single, time_single = solve_with_mlot_single(s, t, M, 3e-3, max_iter=10000)
    # dis_single = sinkhorn_distance(M, T_single)
    # print("[MLOT-Single]:")
    # print("\tObjective= {:.4f}".format(dis_single))
    # print("\tTime= {:.3f}".format(time_single))
