import sys
import os
import torch
import numpy as np
import shutil
from tqdm import trange

class pyOMT_raw():
    """
    Semi-discrete optimal transport solver with two solver options:

    1) 'adam': Adam-based stochastic subgradient method
       - self.d_adam_m / self.d_adam_v hold the first/second moment estimates
       - self.run_adam()
    2) 'drag': DRAG solver (entropic regularisation with decreasing strength)
       - eps_t = 0.1/t^(1/3), gamma_t = sqrt(num_P)/t^(2/3)
       - self.run_drag()

    Kept for compatibility with gen_P:
      pre_cal(count=0) -> draws a fresh batch of source samples into d_volP
      cal_measure()    -> fills self.d_U with scores h_i - cost(x_b, p_i)
      self.d_U         -> shape [num_P, bat_size_n], so gen_P can do
                          torch.sort(self.d_U, dim=0, ...)

    Source samples lie in [-0.5, 0.5]^dim (see pre_cal) and the transport
    cost is 0.5 * ||x - p||^2.
    """

    def __init__(self, h_P, num_P, dim, max_iter, bat_size_P, bat_size_n, solver_type="drag", lr=1e-3):
        """
        Args:
            h_P: CPU Tensor [num_P, dim], possibly float64. Converted to float32.
            num_P: number of target points
            dim: dimension
            max_iter: total number of steps
            bat_size_P: chunk size for partial reads (only validated here)
            bat_size_n: batch size of MC samples
            solver_type: either "adam" or "drag"
            lr: step size, used only if solver_type=="adam"

        Raises:
            SystemExit: if num_P is not a multiple of bat_size_P.
            ValueError: if solver_type is not "adam" or "drag".
        """
        if num_P % bat_size_P != 0:
            sys.exit('Error: (num_P) is not a multiple of (bat_size_P)')

        self.num_P = num_P
        self.dim = dim
        self.max_iter = max_iter
        self.bat_size_P = bat_size_P
        self.bat_size_n = bat_size_n

        # Which solver do we use?
        if solver_type not in ["adam", "drag"]:
            raise ValueError(f"solver_type must be 'adam' or 'drag', got {solver_type}")
        self.solver_type = solver_type
        self.lr = lr  # used only if solver_type=="adam"

        # Convert target measure to float32
        self.h_P = h_P.float()

        # Work in float32 on CUDA when available, otherwise on CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Lazily-filled device copy of h_P (see _target_on_device) so the
        # solvers do not re-upload the whole target set on every iteration.
        self._d_P = None
        self._d_P_src = None

        # The dual potentials h
        self.d_h = torch.zeros(num_P, dtype=torch.float32, device=self.device)

        # Storing MC samples in d_volP
        self.d_volP = torch.empty((bat_size_n, dim), dtype=torch.float32, device=self.device)

        # Keep a 2D buffer self.d_U so gen_P can do torch.sort(self.d_U, dim=0,...)
        # shape [num_P, bat_size_n]
        self.d_U = torch.empty((num_P, bat_size_n), dtype=torch.float32, device=self.device)

        # Quasi-random sampler
        self.qrng = torch.quasirandom.SobolEngine(dimension=self.dim)

        # For Adam solver only:
        self.d_adam_m = torch.zeros(num_P, dtype=torch.float32, device=self.device)
        self.d_adam_v = torch.zeros(num_P, dtype=torch.float32, device=self.device)

        # GPU memory statistics are meaningless on a CPU-only run.
        if self.device.type == 'cuda':
            print(f"Allocated GPU memory: {torch.cuda.memory_allocated()/1e6:.3f} MB")
            print(f"Reserved GPU memory: {torch.cuda.memory_reserved()/1e6:.3f} MB")

    def _target_on_device(self):
        """
        Return the target points h_P as float32 on self.device.

        The copy is cached and rebuilt only when self.h_P is reassigned.
        NOTE: in-place edits to h_P on a CUDA run are not picked up. Reassign
        h_P instead.
        """
        if getattr(self, '_d_P', None) is None or getattr(self, '_d_P_src', None) is not self.h_P:
            self._d_P = self.h_P.to(self.device, torch.float32)
            self._d_P_src = self.h_P
        return self._d_P

    def _cost(self):
        """Return 0.5 * squared Euclidean distances, shape [bat_size_n, num_P]."""
        return 0.5 * torch.cdist(self.d_volP, self._target_on_device(), p=2).pow(2)

    def pre_cal(self, count=0):
        """
        Draw a fresh batch of source samples into self.d_volP.

        gen_P calls p_s.pre_cal(ii); 'count' is ignored. Samples come from the
        Sobol sequence in [0, 1]^dim and are shifted to [-0.5, 0.5]^dim.
        """
        self.qrng.draw(self.bat_size_n, out=self.d_volP)  # [0,1]
        self.d_volP.add_(-0.5)                            # [-0.5,0.5]

    def cal_measure(self):
        """
        Fill self.d_U with the score matrix for the current samples.

        The scores are d_U[i, b] = h_i - cost(x_b, p_i), stored as
        [num_P, bat_size_n] so gen_P can sort along dim 0 and pick the top
        indices.
        """
        scores = self.d_h.unsqueeze(0) - self._cost()  # [B,M]
        self.d_U.copy_(scores.transpose(0, 1))         # [M,B]

    def run_drag(self):
        """
        Run the entropic DRAG solver for max_iter steps.

        Schedules:
          eps_t   = 0.1 / t^(1/3)
          gamma_t = sqrt(num_P) / t^(2/3)

        The method works as follows:
        - d_h is re-initialised uniformly in [-0.01, 0.01], which overwrites
          any value set via set_h.
        - Each step takes a softmax (temperature eps_t) of the scores over
          targets and uses the gap between the mean assignment mass and the
          uniform target weights 1/num_P as the gradient.
        - The final potentials are saved to './h_final_drag.pt'.

        Returns:
            list of gradient norms, recorded every 10 iterations.
        """
        vect_nu = torch.ones(self.num_P, dtype=torch.float32, device=self.device)/self.num_P
        self.d_h.uniform_(-0.01, 0.01)
        grad_history = []

        pbar = trange(1, self.max_iter+1, desc="drag (EOT)", dynamic_ncols=True)
        for t in pbar:
            eps_t = 0.1 / (t**(1/3))
            gamma_t = (self.num_P**0.5) / (t**(2/3))

            self.pre_cal()
            scores = (self.d_h.unsqueeze(0) - self._cost())/eps_t
            chi = torch.nn.functional.softmax(scores, dim=1)
            grad = chi.mean(dim=0) - vect_nu

            self.d_h -= gamma_t * grad
            g_norm = grad.norm().item()
            if t % 10 == 0:
                grad_history.append(g_norm)
            pbar.set_postfix({"eps": f"{eps_t:.4f}",
                              "gamma": f"{gamma_t:.4f}",
                              "grad_norm": f"{g_norm:.4e}"})

        print("drag done. Saving final potentials to './h_final_drag.pt'")
        torch.save(self.d_h, './h_final_drag.pt')

        return grad_history

    def run_adam(self):
        """
        Run the Adam-based subgradient solver for max_iter steps.

        Each step proceeds as follows:
        1. Assign every sample to its argmax-score target.
        2. Form the subgradient from the empirical bin frequencies minus
           1/num_P.
        3. Apply an Adam update with step size self.lr.
        4. Re-centre d_h to zero mean.

        The final potentials are saved to './h_final_adam.pt'.

        NOTE: the moment buffers self.d_adam_m / self.d_adam_v persist across
        calls, while the bias-correction counter restarts at t=0 on each call.

        Returns:
            list of subgradient norms, recorded every 10 iterations.
        """
        beta1 = 0.9
        beta2 = 0.999
        eps = 1e-8

        pbar = trange(self.max_iter, desc="Adam Solver", dynamic_ncols=True)
        grad_history = []
        for t in pbar:
            # 1) sample from the source
            self.pre_cal()

            # 2) argmax subgradient over scores of shape [B, M]
            scores = self.d_h.unsqueeze(0) - self._cost()
            argmax_idx = torch.argmax(scores, dim=1)  # [B]
            bin_counts = torch.bincount(argmax_idx, minlength=self.num_P)
            # subgradient = empirical cell mass - 1/num_P
            d_g = bin_counts.float()/float(self.bat_size_n) - 1.0/self.num_P

            # 3) Adam moment updates
            self.d_adam_m.mul_(beta1).add_(d_g*(1-beta1))
            self.d_adam_v.mul_(beta2).addcmul_(d_g, d_g, value=(1-beta2))

            # bias-corrected moments
            m_hat = self.d_adam_m/(1-beta1**(t+1))
            v_hat = self.d_adam_v/(1-beta2**(t+1))

            # step size is self.lr
            self.d_h -= self.lr*m_hat/(v_hat.sqrt()+eps)

            # 4) potentials are defined up to a constant; keep them centred
            self.d_h -= self.d_h.mean()

            g_norm = d_g.norm().item()
            if t % 10 == 0:
                grad_history.append(g_norm)
            pbar.set_postfix({"g_norm": f"{g_norm:.4e}"})

        print("Adam-based solver done. Saving final potentials to './h_final_adam.pt'")
        torch.save(self.d_h, './h_final_adam.pt')
        return grad_history

    def set_h(self, h_tensor):
        """Copy h_tensor into the dual potentials d_h in place."""
        self.d_h.copy_(h_tensor)

# ------- Helper: train_omt -----------
def train_omt(p_s):
    """
    Run the solver selected by ``p_s.solver_type`` and return its result.

    "drag" dispatches to ``p_s.run_drag()`` and "adam" to ``p_s.run_adam()``.
    Only the selected method is looked up on ``p_s``.

    Raises:
        ValueError: for any other solver_type.
    """
    dispatch = (("drag", "run_drag"), ("adam", "run_adam"))
    for kind, method_name in dispatch:
        if p_s.solver_type == kind:
            return getattr(p_s, method_name)()
    raise ValueError(f"Unknown solver_type {p_s.solver_type}")

# Utilities if you need them
def clear_file_in_folder(folder):
    """
    Delete every file, symlink and sub-directory inside ``folder``.

    The folder itself is kept. A folder that does not exist is treated as
    already empty, so this is a no-op. Deletion is best-effort: an entry
    that cannot be removed is reported on stdout and the loop moves on.
    """
    try:
        entries = os.listdir(folder)
    except FileNotFoundError:
        # Nothing was ever written here; nothing to clear.
        return
    for filename in entries:
        file_path = os.path.join(folder, filename)
        try:
            # Links are unlinked (never followed into) before the directory
            # branch, so rmtree only sees real sub-directories.
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except OSError as e:
            print(f"Failed to delete {file_path}, reason: {e}")

def clear_temp_data():
    """Empty the scratch folders './adam_m', './adam_v' and 'h' via clear_file_in_folder."""
    for scratch_dir in ('./adam_m', './adam_v', 'h'):
        clear_file_in_folder(scratch_dir)


    # np.savetxt('./h_final.csv',p_s.d_h.cpu().numpy(), delimiter=',')