import os, random, argparse, pandas as pd, numpy as np
from tqdm import tqdm
import torch, torch.nn as nn
import sys
import pandas as pd
import anndata as ad
import scanpy as sc
# CUDA_LAUNCH_BLOCKING=1 is the CUDA switch for synchronous kernel launches.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Device handle used for the networks created further down.
device = torch.device('cuda')

# Reproducibility: seed the Python, NumPy and PyTorch generators with one value.
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.random.manual_seed):
    _seed_rng(SEED)

# load package requirements
from DeepRUOT.losses import MMD_loss, OT_loss1, OT_loss2, Density_loss, Local_density_loss 
from DeepRUOT.utils import group_extract, sample, to_np, generate_steps, cal_mass_loss, parser, _valid_criterions
from DeepRUOT.plots import plot_comparision, plot_losses
from DeepRUOT.train import train_un1
from DeepRUOT.models import velocityNet, growthNet, scoreNet, dediffusionNet, indediffusionNet, FNet, ODEFunc2
from DeepRUOT.constants import ROOT_DIR, DATA_DIR, NTBK_DIR, IMGS_DIR, RES_DIR
from DeepRUOT.exp import setup_exp
from DeepRUOT.eval import generate_plot_data
from torchdiffeq import odeint_adjoint as odeint

import torch.optim as optim

# Command line: <script> <dataset file> <output suffix>.
if len(sys.argv) < 3:
    raise SystemExit(f"usage: {sys.argv[0]} <dataset.pt> <output_suffix>")

# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, so only point this script at dataset files from a trusted source.
data = torch.load(sys.argv[1], weights_only = False)
suffix = sys.argv[2]  # output suffix appended to output files

# data['x'] holds one row per observation; data['t_idx'] holds its time-point index.
dim = data['x'].shape[1]
# Flat table: first column 'samples' is the time-point index, then x1..x{dim}.
df = pd.DataFrame(np.hstack([data['t_idx'][:, None], data['x']]), columns = ['samples', ] + [f'x{i+1}' for i in range(dim)])
T = len(np.unique(data['t_idx']))
if T < 2:
    # sigma below divides by (T - 1); fail with a clear message instead of
    # an opaque ZeroDivisionError.
    raise ValueError(f"need at least two distinct time points in 't_idx', got {T}")
# DeepRUOT assumes the time interval is [0, T-1], so we need to rescale the diffusion coefficient accordingly
sigma = (0.25 / (T-1))**0.5
f_net = FNet(in_out_dim=dim, hidden_dim=128, n_hiddens=3, activation='leakyrelu')

import sys
# Parse a fixed DeepRUOT argument list. Passing it to parse_args() directly
# keeps the process-wide sys.argv (the real command line) untouched.
args = parser.parse_args([
    '-d', 'file',
    '-c', 'ot1',
    '-n', 'emt',
])
opts = vars(args)

# Sorted distinct values of the first column of df ('samples'), and the
# steps generate_steps() derives from them.
groups = sorted(df.iloc[:, 0].unique())
steps = generate_steps(groups)
use_geo = opts['use_geo']
model_layers = opts['model_layers']
# Number of coordinate columns (every column except 'samples').
model_features = len(df.columns) - 1

# Adam with its default hyper-parameters over every parameter of f_net.
optimizer = torch.optim.Adam(f_net.parameters())

# Use default DeepRUOT options
opts['criterion'] = 'ot1'
criterion = _valid_criterions[opts['criterion']]()

# Sampling options (sample_size is overwritten at the end of this section).
sample_size = (opts['sample_size'],)
sample_with_replacement = opts['sample_with_replacement']
apply_losses_in_time = opts['apply_losses_in_time']

# Epoch / batch counts; the local epoch count is pinned to 1 rather than
# taken from opts['local_epochs'].
n_local_epochs = 1
n_epochs, n_post_local_epochs = opts['epochs'], opts['local_post_epochs']
n_batches = opts['batches']

# Hold-out configuration.
hold_one_out, hold_out = opts['hold_one_out'], opts['hold_out']

# Loss-related options.
hinge_value, top_k = opts['hinge_value'], opts['top_k']
lambda_density, lambda_density_local = opts['lambda_density'], opts['lambda_density_local']
use_density_loss, use_local_density = opts['use_density_loss'], opts['use_local_density']
lambda_local, lambda_global = opts['lambda_local'], opts['lambda_global']

# Plotting / trajectory options.
n_points, n_trajectories, n_bins = opts['n_points'], opts['n_trajectories'], opts['n_bins']

# Loss histories: one list per 't0:t1' step, plus batch-level and global lists.
local_losses = dict((f'{t0}:{t1}', []) for (t0, t1) in steps)
batch_losses, globe_losses = [], []

f_net = f_net.to(device)

# Number of rows at time point 0.
initial_size = len(df[df['samples'] == 0].x1)
# Row count per time point, expressed relative to the first time point.
sample_sizes = df.groupby('samples').size()
ref0 = sample_sizes.div(sample_sizes.iloc[0])
relative_mass = torch.tensor(ref0.to_numpy())

# Sample as many points per draw as there are rows at time point 0.
sample_size = (len(df[df['samples'] == 0.0].values),)
exp_dir = "."

# Pretrain: distribution reconstruction training 
# Single checkpoint path shared by train_un1 (passed as best_model_path) and
# by the reloads after each stage.
best_model_path = exp_dir + f"/best_model_{suffix}.pt"


def _run_pretraining_stage(n_inner, lambda_mass):
    """Run one pretraining stage of ``f_net`` through ``train_un1``.

    Both pretraining stages share every setting except the two parameters
    below. Losses returned by ``train_un1`` are appended to the module-level
    ``local_losses``, ``batch_losses`` and ``globe_losses`` histories.

    Args:
        n_inner: Forwarded as the fifth positional argument of ``train_un1``
            (presumably its batch/iteration count -- TODO confirm against
            ``DeepRUOT.train.train_un1``).
        lambda_mass: Forwarded as ``lambda_mass`` to ``train_un1``.
    """
    for _ in tqdm(range(n_local_epochs), desc='Pretraining Epoch'):
        l_loss, b_loss, g_loss = train_un1(
            f_net, df, groups, optimizer, n_inner,
            criterion=criterion, use_cuda=True,
            local_loss=True, global_loss=False, apply_losses_in_time=apply_losses_in_time,
            hold_one_out=hold_one_out, hold_out=hold_out,
            hinge_value=hinge_value, lambda_ot=0.1, lambda_mass=lambda_mass, lambda_energy=0.001,
            use_pinn=False, use_penalty=False, use_density_loss=False, lambda_density=10,
            top_k=top_k, sample_size=sample_size, relative_mass=relative_mass, initial_size=initial_size,
            sample_with_replacement=sample_with_replacement, logger=None, device=device,
            best_model_path=best_model_path,
        )
        for k, v in l_loss.items():
            local_losses[k].extend(v)
        batch_losses.extend(b_loss)
        globe_losses.extend(g_loss)


def _restore_best_f_net():
    """Load the checkpoint at ``best_model_path`` into ``f_net`` and move it to ``device``."""
    f_net.load_state_dict(torch.load(best_model_path, map_location=torch.device('cpu')))
    f_net.to(device)


# Stage 1: mass term enabled (lambda_mass=1).
_run_pretraining_stage(50, lambda_mass=1)
_restore_best_f_net()

# Freeze the g_net sub-network of f_net before the second stage.
for param in f_net.g_net.parameters():
    param.requires_grad = False

# Pretrain: velocity
# Stage 2: mass term disabled (lambda_mass=0), g_net frozen.
_run_pretraining_stage(30, lambda_mass=0)
_restore_best_f_net()

# Re-package the coordinate columns of df into an AnnData object.
n = dim
samples = df['samples'].values
column_names = [f'x{col}' for col in range(1, n + 1)]

obsm_data = df[column_names].values
print("obsm_data shape:", obsm_data.shape)

# obs is indexed by the time-point labels; coordinates live in obsm['X_pca'].
adata = ad.AnnData(obs=pd.DataFrame(index=samples))
adata.obsm['X_pca'] = obsm_data
adata_loaded = adata

adata.obs['samples'] = df['samples'].values

n_times = len(adata.obs["samples"].unique())
print(n_times)

# One coordinate array per time point.
# NOTE(review): selecting with range(n_times) assumes the labels are exactly
# 0..n_times-1; any other labelling yields empty arrays here -- confirm.
X = []
for t in range(n_times):
    X.append(adata.obsm["X_pca"][adata.obs["samples"] == t])

from DeepRUOT.utils import OTPlanSampler, ConditionalFlowMatcher, ExactOptimalTransportConditionalFlowMatcher, TargetConditionalFlowMatcher,SchrodingerBridgeConditionalFlowMatcher, VariancePreservingConditionalFlowMatcher,generate_state_trajectory
from DeepRUOT.models import scoreNet2

# Batch size equals the number of rows at time point 0.
batch_size = len(df[df['samples'] == 0].x1)
time = torch.Tensor(groups)

# Schrodinger-bridge conditional flow matcher using the rescaled sigma.
SF2M = SchrodingerBridgeConditionalFlowMatcher(sigma=sigma)

# Score network in float32 on the training device, optimised with Adam (lr 1e-4).
sf2m_score_model = scoreNet2(in_out_dim=dim, hidden_dim=128, activation='leakyrelu')
sf2m_score_model = sf2m_score_model.float().to(device)
sf2m_optimizer = torch.optim.Adam(sf2m_score_model.parameters(), lr=1e-4)

# State trajectory produced by generate_state_trajectory from the pretrained f_net.
trajectory = generate_state_trajectory(X, n_times, batch_size, f_net, time, device)

from DeepRUOT.utils import get_batch

# Train the score model for 10,000 steps: a score-matching loss plus a hinge
# penalty on positive outputs of the score model.
lambda_penalty = 1  # 0 default, 1 if numerical issues
for i in tqdm(range(10_000)):
    sf2m_optimizer.zero_grad()
    t, xt, ut, eps = get_batch(SF2M, X, trajectory, batch_size, n_times, return_noise=True)
    t = torch.unsqueeze(t, 1)
    # t % 1 is the fractional part of t, passed to compute_lambda.
    lambda_t = SF2M.compute_lambda(t % 1)
    value_st = sf2m_score_model(t, xt)
    st = sf2m_score_model.compute_gradient(t, xt)
    # Largest positive model output in the batch (0 when none is positive);
    # computed once and reused for both the penalty and the progress print.
    positive_st = torch.relu(value_st)
    max_positive_st = torch.max(positive_st)
    penalty = lambda_penalty * max_positive_st
    # NOTE(review): t was unsqueezed to a trailing dim of 1 above; if
    # compute_lambda keeps that shape, lambda_t[:, None] is 3-D and would
    # broadcast against a 2-D st to (N, N, dim) -- confirm the shapes.
    score_loss = torch.mean((lambda_t[:, None] * st + eps) ** 2)
    if i % 100 == 0:
        print(max_positive_st)
        print(f"{i}:  {score_loss.item():0.2f}")
    loss = score_loss + penalty
    loss.backward()
    sf2m_optimizer.step()

# Persist the trained score model next to the other outputs.
torch.save(sf2m_score_model.state_dict(), os.path.join(exp_dir, f"score_model_{suffix}.pt"))

# Coordinates (all columns after 'samples') of the rows at time point 0.
datatime0 = torch.tensor(df.loc[df['samples'] == 0.0].to_numpy()[:, 1:])

# Reload both networks from their checkpoints (mapped to CPU on load), then
# move them back to the training device.
_cpu = torch.device('cpu')

_score_ckpt = os.path.join(exp_dir, f"score_model_{suffix}.pt")
sf2m_score_model.load_state_dict(torch.load(_score_ckpt, map_location=_cpu))
sf2m_score_model.to(device)

_f_net_ckpt = os.path.join(exp_dir+f"/best_model_{suffix}.pt")
f_net.load_state_dict(torch.load(_f_net_ckpt, map_location=_cpu))
f_net.to(device)

from DeepRUOT.train import train_all
# Final stage: train f_net and the score model together through train_all.
# NOTE(review): `device` is rebound to the string 'cpu' here, although both
# networks were moved to the CUDA device just before this point and train_all
# is still called with use_cuda=True -- confirm which device train_all expects.
device='cpu'
# New optimizer over the parameters of both networks: plain SGD, learning rate 1e-5.
# NOTE(review): f_net.g_net parameters had requires_grad set to False during
# pretraining and this script never re-enables them, so unless train_all does,
# g_net stays frozen in this stage -- confirm this is intended.
optimizer = torch.optim.SGD(list(f_net.parameters())+list(sf2m_score_model.parameters()),1e-5)

for epoch in tqdm(range(n_local_epochs), desc='Training Epoch'):
	# Same data, groups and criterion as in pretraining; this call additionally
	# passes the score model, the time-0 data (datatime0) and sigma (as `sigmaa`),
	# and enables use_pinn / use_penalty.
	l_loss, b_loss, g_loss = train_all(
		f_net, df, groups, optimizer,10,
		criterion = criterion, use_cuda = True,
		local_loss=True, global_loss=False, apply_losses_in_time=apply_losses_in_time,
		hold_one_out=hold_one_out, hold_out=hold_out, sf2m_score_model=sf2m_score_model,
		 hinge_value=hinge_value,datatime0=datatime0,device=device, lambda_initial=0.1,
		 use_pinn=True, use_penalty=True,use_density_loss=False,lambda_density=10,
		top_k = top_k, sample_size = sample_size,relative_mass=relative_mass,initial_size=initial_size,
		sample_with_replacement = sample_with_replacement, logger=None, sigmaa=sigma,lambda_pinn=1,
	)
	# Append this epoch's losses to the running histories.
	for k, v in l_loss.items():  
		local_losses[k].extend(v)
	batch_losses.extend(b_loss)
	globe_losses.extend(g_loss)

# Write the final weights of both networks; note these file names carry no extension.
torch.save(sf2m_score_model.state_dict(), os.path.join(exp_dir, f"score_model_result_{suffix}"))
torch.save(f_net.state_dict(), os.path.join(exp_dir, f"model_result_{suffix}"))


