import argparse
import os
import torch
from pkdataloader import PKDataloader
from dtidataloader import DTIDataloader

from imagand import SDT, EMA

from torch import nn
import math

from diffusers.optimization import get_scheduler
from tqdm import tqdm

from dkiloader import LoadDKI

from utils import *

from sklearn.metrics import mean_squared_error
import csv

from sklearn.preprocessing import QuantileTransformer, PowerTransformer
from diffusion import DDIMScheduler

# Command-line interface for the training script. argparse derives each
# attribute name from its long flag, so ``--n_timesteps`` is read back as
# ``args.n_timesteps`` and so on.
parser = argparse.ArgumentParser()

# Optimisation / diffusion hyper-parameters.
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--wd", type=float, default=5e-2)
parser.add_argument("--warmup", type=int, default=200)
parser.add_argument("--n_timesteps", type=int, default=2000)
parser.add_argument("--n_inference_timesteps", type=int, default=150)
parser.add_argument("--num_epochs", type=int, default=3000)
parser.add_argument("--batch_size", type=int, default=256)
parser.add_argument("--gamma", type=float, default=0.994)

# Input and output locations.
parser.add_argument("--data_dir", type=str, default="./data")
parser.add_argument("--save_dir", type=str, default="./output")

# Which dataset(s) to train on.
parser.add_argument("--data_type", choices=["pk", "dti", "both"], default="pk")
parser.add_argument("--generate_both", action="store_true")

# Optional DKI side inputs (off unless the flag is given).
parser.add_argument("--use_drug_dki", action="store_true")
parser.add_argument("--use_target_dki", action="store_true")

# Target preprocessing and the embedding model handed to the dataloaders.
parser.add_argument(
    "--transform",
    choices=["gaussian", "uniform", "power", "none"],
    default="gaussian",
)
parser.add_argument(
    "--embed_model",
    choices=["t5", "deberta", "chemberta_zinc", "chemberta_10m"],
    default="t5",
)

args = parser.parse_args()

# Create the output directory (and any missing parents) up front.
# ``exist_ok=True`` makes this safe to re-run and avoids a separate existence
# check. Creating ``os.path.dirname(args.save_dir)`` is unnecessary because
# makedirs already creates parents, and it fails outright when the directory
# is given as a bare name, since its dirname is the empty string.
os.makedirs(args.save_dir, exist_ok=True)

noise = args.transform

# Pick the sklearn transformers handed to the PK and DTI dataloaders. Each
# dataset gets its own instance. With ``--transform none`` both stay ``None``.
pk_transform, dti_transform = None, None
if args.transform == "gaussian":
    print("Using Gaussian Transform...")
    pk_transform = QuantileTransformer(n_quantiles=50, output_distribution='normal')
    dti_transform = QuantileTransformer(n_quantiles=50, output_distribution='normal')
elif args.transform == "uniform":
    print("Using Uniform Transform...")
    pk_transform = QuantileTransformer(n_quantiles=50, output_distribution='uniform')
    dti_transform = QuantileTransformer(n_quantiles=50, output_distribution='uniform')
elif args.transform == "power":
    print("Using Power Transform...")
    pk_transform = PowerTransformer()
    dti_transform = PowerTransformer()
else:
    # No transformer object is created here.
    # NOTE(review): presumably the dataloaders apply min-max scaling
    # themselves when given ``None`` - confirm against their source.
    print("Using min-max Transform...")

# Optional DKI side inputs, loaded only when the matching flag is set.
# NOTE(review): both embedding files are hard-coded under ./data instead of
# being derived from --data_dir; confirm whether they should follow it.
dkiloader = LoadDKI(args.data_dir)
drug_dki = None
if args.use_drug_dki:
    drug_dki = dkiloader.get_drug_dki("./data/pk_dti_fp_embeddings.npy")

target_dki = None
if args.use_target_dki:
    target_dki = dkiloader.get_target_dki("./data/prot_emb.pkl")

pk_dataloader = PKDataloader(
    args.embed_model,
    args.data_dir,
    pk_transform=pk_transform,
    drug_dki=drug_dki,
)
dti_dataloader = DTIDataloader(
    args.embed_model,
    args.data_dir,
    dti_transform=dti_transform,
    drug_dki=drug_dki,
    target_dki=target_dki,
)

pk_set = pk_dataloader.dataset
dti_set = dti_dataloader.dataset

# ``dmss`` values are passed to ``sample_noise`` by train() / evaluate().
# ``combined_dmss`` is the PK entries followed by the DTI entries.
dmss = pk_set.dmss if args.data_type == "pk" else dti_set.dmss
combined_dmss = pk_set.dmss + dti_set.dmss

# PK: 90% train / 10% validation.
# DTI: 10% held out as the test set; the remaining 90% is split again into
# 98% train / 2% validation.
pk_trainset, pk_valset = torch.utils.data.random_split(pk_set, [0.9, 0.10])
dti, dti_testset = torch.utils.data.random_split(dti_set, [0.9, 0.1])
dti_trainset, dti_valset = torch.utils.data.random_split(dti, [0.98, 0.02])

pk_trainloader = torch.utils.data.DataLoader(pk_trainset, batch_size=args.batch_size, shuffle=True)
pk_valloader = torch.utils.data.DataLoader(pk_valset, batch_size=args.batch_size, shuffle=True)

# The DTI loaders use half the configured batch size.
dti_trainloader = torch.utils.data.DataLoader(dti_trainset, batch_size=args.batch_size//2, shuffle=True)
dti_valloader = torch.utils.data.DataLoader(dti_valset, batch_size=args.batch_size//2, shuffle=True)
dti_testloader = torch.utils.data.DataLoader(dti_testset, batch_size=args.batch_size//2, shuffle=True)

# Optimizer steps per epoch. train() advances the LR scheduler and the EMA
# once per *batch*, so this must count batches (``len(loader)``), not samples
# (``len(dataset)``); counting samples inflates the total step count by
# roughly the batch size and the cosine schedule never meaningfully decays.
if args.data_type == "pk":
    steps_per_epoch = len(pk_trainloader)
elif args.data_type == "dti":
    steps_per_epoch = len(dti_trainloader)
else:
    # "both": the training loop runs 50 PK passes and one DTI pass per epoch.
    steps_per_epoch = len(dti_trainloader) + 50 * len(pk_trainloader)

# Use the GPU when one is available; otherwise fall back to the CPU so the
# script still starts on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model configuration.
# NOTE(review): ``dti_y_dim = 3`` lines up with the three trailing columns
# that evaluate() slices off as the DTI part of the sampled noise; keep the
# two in sync if this value changes.
model = SDT(
    time_dim=64,
    drug_cond_size=768,
    prot_cond_size=1024,
    pk_y_dim=12,
    dti_y_dim=3,
    patch_size=16,
    dim=256,
    depth=12,
    heads=16,
    mlp_dim=768,
    dropout=0.1,
    emb_dropout=0.1,
)
model.to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {total_params}")

def train(model, ema, gamma, dataloader, noise_scheduler, optimizer, lr_scheduler):
    """Train ``model`` for one pass over ``dataloader`` on noise prediction.

    For every batch, noise is added to the PK and/or DTI targets at random
    timesteps, the model predicts that noise, and a mask-weighted MSE is
    back-propagated. After each optimizer step the LR scheduler and the EMA
    are advanced.

    Reads the module-level names ``device``, ``pk_set``, ``dti_set`` and
    ``sample_noise``.

    Args:
        model: network called with ``timesteps``, ``drug_cond``,
            ``prot_cond``, ``pk``, ``dti``, ``drug_dki`` and ``target_dki``;
            returns a ``(pk_noise_pred, dti_noise_pred)`` pair whose entries
            may be ``None``.
        ema: object exposing ``update_params(gamma)`` and
            ``update_gamma(step)``.
        gamma: EMA coefficient used for the first update of this call; later
            updates use the value returned by ``ema.update_gamma``.
        dataloader: iterable of dict batches. Keys ``'de'`` and ``'ma'`` are
            required; ``'pe'``, ``'pk'``, ``'dti'``, ``'dd'`` and ``'td'``
            are optional.
        noise_scheduler: provides ``num_train_timesteps`` and ``add_noise``.
        optimizer: optimizer stepped once per batch.
        lr_scheduler: learning-rate scheduler stepped once per batch.

    Returns:
        The mean of the per-batch losses, or ``0.0`` when the dataloader
        yields no batches.
    """
    model.train()
    running_loss = 0.0
    global_step = 0
    # Element-wise losses are needed so the mask can be applied before reducing.
    mse_loss = nn.MSELoss(reduction='none')
    for batch in tqdm(dataloader):
        de = batch['de'].to(device).float()
        pe = batch['pe'].to(device).float() if 'pe' in batch else None
        pk = batch['pk'].to(device).float() if 'pk' in batch else None
        dti = batch['dti'].to(device).float() if 'dti' in batch else None
        mask = batch['ma'].to(device)
        ddki = batch['dd'].to(device).float() if 'dd' in batch else None
        tdki = batch['td'].to(device).float() if 'td' in batch else None
        bs = de.shape[0]

        # Noise targets for both heads; each is used only if the model
        # returns the corresponding prediction.
        pk_noise = torch.tensor(sample_noise(bs, pk_set.dmss)).to(device).float()
        dti_noise = torch.tensor(sample_noise(bs, dti_set.dmss)).to(device).float()

        # One random diffusion timestep per sample in the batch.
        timesteps = torch.randint(0,
                                  noise_scheduler.num_train_timesteps,
                                  (bs,),
                                  device=device).long()

        noisy_pk = None
        noisy_dti = None
        if pk is not None:
            noisy_pk = noise_scheduler.add_noise(pk, pk_noise, timesteps)

        if dti is not None:
            noisy_dti = noise_scheduler.add_noise(dti, dti_noise, timesteps)

        optimizer.zero_grad()
        pk_noise_pred, dti_noise_pred = model(
            timesteps=timesteps,
            drug_cond=de,
            prot_cond=pe,
            pk=noisy_pk,
            dti=noisy_dti,
            drug_dki=ddki,
            target_dki=tdki
        )

        loss = 0
        if pk_noise_pred is not None:
            loss += mse_loss(pk_noise_pred, pk_noise)

        if dti_noise_pred is not None:
            loss += mse_loss(dti_noise_pred, dti_noise)

        # Average only over the entries selected by the mask. Clamping the
        # denominator keeps a batch with an all-empty mask from producing
        # 0/0 = NaN, which would otherwise poison the weights on step().
        loss = (loss * mask.float()).sum()
        non_zero_elements = mask.sum().clamp(min=1)
        mse_loss_val = loss / non_zero_elements
        mse_loss_val.backward()
        optimizer.step()
        lr_scheduler.step()

        # NOTE(review): ``global_step`` restarts at 0 on every call, so
        # ``ema.update_gamma`` sees a per-call step index rather than a
        # run-wide one - confirm this is what the EMA schedule expects.
        ema.update_params(gamma)
        gamma = ema.update_gamma(global_step)

        running_loss += mse_loss_val.item()
        global_step += 1

    # Avoid ZeroDivisionError when the dataloader is empty.
    return running_loss / max(global_step, 1)

def evaluate(
    name,
    e,
    ema,
    dataloader,
    noise_scheduler,
    n_inference_timesteps,
    pk_transform=None,
    dti_transform=None
):
    """Generate samples with the EMA model, score them, and dump them to CSV.

    For each batch, starting noise is sampled, ``noise_scheduler.generate``
    is run with ``ema.ema_model``, and the masked MSE against the ground
    truth is computed for both the raw noise ("before") and the generated
    values. Generated rows are written to
    ``<args.save_dir>/<name>_<e>_dict.csv`` as ``[sm, tg, values]``.

    Reads the module-level names ``args``, ``pk_set``, ``combined_dmss`` and
    ``sample_noise``.

    Args:
        name: label used as the first part of the CSV file name.
        e: epoch number used as the second part of the CSV file name.
        ema: object whose ``ema_model`` attribute is the model to evaluate.
        dataloader: iterable of dict batches. Keys ``'sm'``, ``'ma'`` and
            ``'de'`` are required; ``'pe'``, ``'pk'``, ``'dti'``, ``'dd'``,
            ``'td'`` and ``'tg'`` are optional.
        noise_scheduler: provides ``set_timesteps`` and ``generate``.
        n_inference_timesteps: number of sampling steps.
        pk_transform: optional fitted transformer; when given, its
            ``inverse_transform`` is applied to generated PK rows before
            they are stored.
        dti_transform: same as ``pk_transform`` for generated DTI rows.

    Returns:
        A 4-tuple of per-batch averages:
        ``(generated PK+DTI MSE, raw-noise PK+DTI MSE,
        (generated PK MSE, raw-noise PK MSE),
        (generated DTI MSE, raw-noise DTI MSE))``.
        All entries are 0 when the dataloader yields no batches.
    """
    ema.ema_model.eval()
    before_mse_pk, before_mse_dti = 0, 0
    running_mse_pk, running_mse_dti = 0, 0
    global_step = 0
    # Maps (sm, tg) string pairs to the generated row; a later batch with
    # the same pair overwrites the earlier one.
    vals = {}
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    ema.ema_model.to(device)
    noise_scheduler.set_timesteps(n_inference_timesteps)

    # Factor applied to generated values before ``inverse_transform``.
    # NOTE(review): presumably undoes a scaling applied when the data was
    # prepared - confirm against the dataloaders.
    rescale = 5.1993

    with torch.no_grad():
        for batch in tqdm(dataloader):
            sm = batch['sm']
            mask = batch['ma']
            de = batch['de'].to(device).float()
            pe = batch['pe'].to(device).float() if 'pe' in batch else None
            pk = batch['pk'].to(device).float() if 'pk' in batch else None
            dti = batch['dti'].to(device).float() if 'dti' in batch else None
            ddki = batch['dd'].to(device).float() if 'dd' in batch else None
            tdki = batch['td'].to(device).float() if 'td' in batch else None
            bs = de.shape[0]

            # Batches without a 'tg' entry get the literal string "None".
            tg = batch['tg'] if 'tg' in batch else ["None"] * len(sm)

            # Starting noise. When 'pe' is present the noise is drawn for
            # the combined layout and split: the last three columns are the
            # DTI part, everything before them the PK part.
            ys_pk = None
            ys_dti = None
            if pk is not None:
                ys = sample_noise(bs, pk_set.dmss)
                ys = torch.tensor(ys).to(device).float()
                ys_pk = ys

            if pe is not None:
                ys = sample_noise(bs, combined_dmss)
                ys = torch.tensor(ys).to(device).float()
                ys_pk = ys[:, :-3]
                ys_dti = ys[:, -3:]

            # Baseline: error of the untouched noise against the ground truth.
            raw_pk_mse = 0
            raw_dti_mse = 0
            if pk is not None:
                raw_pk_mse = mean_squared_error(pk[mask].flatten().cpu(), ys_pk[mask].flatten().cpu())
            if pe is not None:
                raw_dti_mse = mean_squared_error(dti[mask].flatten().cpu(), ys_dti[mask].flatten().cpu())

            gen_pk, gen_dti = noise_scheduler.generate(
                ema.ema_model,
                de,
                pe,
                ys_pk,
                ys_dti,
                drug_dki=ddki,
                target_dki=tdki,
                num_inference_steps=n_inference_timesteps,
                eta=0.01,
                use_clipped_model_output=True,
                device=device
            )

            dti_mse = 0
            pk_mse = 0
            if pk is not None:
                pk_mse = mean_squared_error(pk[mask].flatten().cpu(), gen_pk[mask].flatten().cpu())
            if pe is not None:
                dti_mse = mean_squared_error(dti[mask].flatten().cpu(), gen_dti[mask].flatten().cpu())

            pk_rows = list(gen_pk.cpu().numpy())
            dti_rows = list(gen_dti.cpu().numpy()) if gen_dti is not None else None

            for idx, s in enumerate(sm):
                pk_row = pk_rows[idx]
                if pk_transform is not None:
                    pk_row = pk_transform.inverse_transform(
                        (np.array(pk_row) * rescale).reshape(1, -1))[0]
                row = pk_row
                if dti_rows is not None:
                    dti_row = dti_rows[idx]
                    if dti_transform is not None:
                        dti_row = dti_transform.inverse_transform(
                            (np.array(dti_row) * rescale).reshape(1, -1))[0]
                    row = np.concatenate([pk_row, dti_row], axis=0)
                # Tuple key instead of a "s,t" string: no later split is
                # needed, so a comma inside either field cannot break it.
                vals["{}".format(s), "{}".format(tg[idx])] = row

            before_mse_pk += raw_pk_mse
            before_mse_dti += raw_dti_mse
            running_mse_pk += pk_mse
            running_mse_dti += dti_mse
            global_step += 1

    # Build the path with os.path.join so the file lands inside save_dir
    # even when it has no trailing separator. newline='' is what the csv
    # module expects from files it writes to.
    out_path = os.path.join(args.save_dir, '{}_{}_dict.csv'.format(name, e))
    with open(out_path, 'w', newline='') as csv_file:
        writer = csv.writer(csv_file)
        for (s, t), value in vals.items():
            writer.writerow([s, t, value])

    # Avoid ZeroDivisionError when the dataloader is empty.
    steps = max(global_step, 1)
    return (running_mse_pk + running_mse_dti) / steps, \
           (before_mse_pk + before_mse_dti) / steps, \
           (running_mse_pk / steps, before_mse_pk / steps), \
           (running_mse_dti / steps, before_mse_dti / steps)

# Total number of optimizer steps, shared by the EMA and the LR schedule.
total_num_steps = (steps_per_epoch * args.num_epochs)

ema = EMA(model, args.gamma, total_num_steps)
ns = DDIMScheduler(num_train_timesteps=args.n_timesteps,
                   beta_start=0.,
                   beta_end=0.7,
                   beta_schedule="cosine")

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=args.lr,
    weight_decay=args.wd,
)

lr_scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=args.warmup,
    num_training_steps=total_num_steps,
)

# In "both" mode every epoch runs this many PK passes followed by one DTI pass.
PK_PASSES_PER_EPOCH = 50

# Accumulated run log; rewritten to <save_dir>/output.txt after every epoch.
l = ""
# NOTE(review): starting at 1 means no checkpoint is written until the
# validation MSE drops below 1 - confirm this threshold is intended.
best_mse = 1
for e in range(args.num_epochs):

    loss = 0
    if args.data_type == "pk":
        loss = train(model, ema, args.gamma, pk_trainloader, ns, optimizer, lr_scheduler)
    elif args.data_type == "dti":
        loss = train(model, ema, args.gamma, dti_trainloader, ns, optimizer, lr_scheduler)
    else:
        for pk_e in range(PK_PASSES_PER_EPOCH):
            pk_loss = train(model, ema, args.gamma, pk_trainloader, ns, optimizer, lr_scheduler)
            print(e, "pk_avgloss {}".format(pk_loss))
            loss += pk_loss

        # Validate on PK after the PK passes, before the DTI pass.
        mse, bmse, pkmse, dtibmse = evaluate(f"pk_partial_{pk_e}", e, ema, pk_valloader, ns, args.n_inference_timesteps, pk_transform, dti_transform)
        print(e, "pk_partial avgloss {}, avgvalmse {}, beforemse: {}, {}, {}".format(loss, mse, bmse, pkmse, dtibmse))
        dti_loss = train(model, ema, args.gamma, dti_trainloader, ns, optimizer, lr_scheduler)
        print(e, "dti_avgloss {}".format(dti_loss))
        loss += dti_loss

        # Mean over all passes of this epoch (PK passes plus the DTI pass).
        loss /= PK_PASSES_PER_EPOCH + 1

    if (e % 10 == 0) and (e > 0):

        if args.data_type == "pk":
            mse, bmse, pkmse, dtibmse = evaluate("pk", e, ema, pk_valloader, ns, args.n_inference_timesteps, pk_transform, dti_transform)
        else:
            if args.data_type == "dti":
                mse, bmse, pkmse, dtibmse = evaluate("dti", e, ema, dti_valloader, ns, args.n_inference_timesteps, pk_transform, dti_transform)
            else:
                mse, bmse, pkmse, _ = evaluate("pk", e, ema, pk_valloader, ns, args.n_inference_timesteps, pk_transform, dti_transform)
                mse, bmse, _, dtibmse = evaluate("dti", e, ema, dti_valloader, ns, args.n_inference_timesteps, pk_transform, dti_transform)

            if e % 50 == 0:
                # Keep test-set metrics in their own variables and log entry
                # so they are never reported as validation numbers and never
                # drive the best-checkpoint selection below.
                test_mse, test_bmse, _, test_dtimse = evaluate("dti_test", e, ema, dti_testloader, ns, args.n_inference_timesteps, pk_transform, dti_transform)
                print(e, "avgtestmse {}, beforemse: {}, {}".format(test_mse, test_bmse, test_dtimse))
                l += str({
                    "type": "test",
                    "e": e,
                    "avgtestmse": test_mse,
                    "beforemse": test_bmse,
                    "dti": test_dtimse,
                }) + "\n"

        print(e, "avgloss {}, avgvalmse {}, beforemse: {}, {}, {}".format(loss, mse, bmse, pkmse, dtibmse))
        l += str({
            "type": "val",
            "e": e,
            "avgloss": loss,
            "avgvalmse": mse,
            "beforemse": bmse,
            "pk": pkmse,
            "dti": dtibmse,
        }) + "\n"

        if mse < best_mse:
            best_mse = mse

            # Dump test-set generations for the new best model (DTI modes only).
            if e > 10 and args.data_type != "pk":
                _, _, _, _ = evaluate("dti_best", e, ema, dti_testloader, ns, args.n_inference_timesteps, pk_transform, dti_transform)
            # os.path.join keeps the checkpoint inside save_dir even when the
            # directory is given without a trailing separator.
            torch.save({
                'e': e,
                'ema_model': ema.ema_model.state_dict(),
                'model': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, os.path.join(args.save_dir, "best_model.pt"))
    else:
        print(e, "avgloss {}".format(loss))
        l += str({
            "type": "train",
            "e": e,
            "avgloss": loss,
        }) + "\n"

    # Rewrite the full log every epoch so progress survives an interruption.
    with open(os.path.join(args.save_dir, 'output.txt'), 'w') as file:
        file.write(l)