"""
Train an anti-causal predictor: a regressor fitted with an MSE loss on images
that are optionally noised by the diffusion forward process.
"""

import os
import pandas as pd
import torch as th
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.optim import AdamW
from Dataloader import loader

from pathlib import Path
import sys
sys.path.append(str(Path.cwd()))

from configs import get_configs

from Diffusion import logger, dist_util

from Diffusion.fp16_util import MixedPrecisionTrainer
from Diffusion.resample import create_named_schedule_sampler
from Diffusion.train_util import parse_resume_step_from_filename, log_loss_dict

from script_util import create_anti_causal_predictor, create_gaussian_diffusion


def main():
    """Train the anti-causal predictor, optionally on diffusion-noised images.

    Reads the default configuration, sets up distributed training and logging,
    builds the diffusion process and the predictor, and then runs
    ``config.classifier.training.iterations`` optimisation steps with periodic
    validation, metric dumps and checkpoints (checkpoints are written by
    rank 0 only).
    """
    config = get_configs.get_default_configs()
    train_cfg = config.classifier.training

    dist_util.setup_dist()
    logger.configure(Path(config.experiment_name) / ("classifier_train_" + "_".join(config.classifier.label)),
                     format_strs=["log", "stdout", "csv", "tensorboard"])

    logger.log("creating model and diffusion...")
    diffusion = create_gaussian_diffusion(config)

    model = create_anti_causal_predictor(config)
    model.to(dist_util.dev())

    # The timestep sampler is only needed (and only used) when training on
    # noised inputs.
    if train_cfg.noised:
        schedule_sampler = create_named_schedule_sampler(
            train_cfg.schedule_sampler, diffusion
        )

    logger.log("creating data loader...")
    # Load the table from the configured location so the script is not tied to
    # one machine's filesystem.
    tadpole_data = pd.read_csv(config.data.ANDI_path)

    # No conditioning on x.
    # NOTE(review): the loaders use the score model's batch size, while the
    # "samples" counter below uses the classifier's batch size -- confirm that
    # this mismatch is intended.
    loader_batch_size = config.score_model.training.batch_size
    data = loader.get_data_loader(tadpole_data, loader_batch_size, split_set='train')
    val_data = loader.get_data_loader(tadpole_data, loader_batch_size, split_set='val')

    logger.log("training...")
    # Needed for creating correct EMAs and fp16 parameters.
    dist_util.sync_params(model.parameters())

    # Resuming from a checkpoint is not implemented here; training always
    # starts at step 0.
    resume_step = 0

    mp_trainer = MixedPrecisionTrainer(
        model=model, use_fp16=train_cfg.classifier_use_fp16, initial_lg_loss_scale=16.0
    )

    model = DDP(
        model,
        device_ids=[dist_util.dev()],
        output_device=dist_util.dev(),
        broadcast_buffers=False,
        bucket_cap_mb=128,
        find_unused_parameters=False,
    )

    logger.log("creating optimizer...")
    opt = AdamW(mp_trainer.master_params, lr=train_cfg.lr,
                weight_decay=train_cfg.weight_decay)

    logger.log("training classifier model...")

    def forward_backward_log(data_loader, prefix="train"):
        """Process one batch: forward pass, loss logging and, if possible, backward.

        The backward pass is skipped when the loss does not require gradients,
        which is the case when this is called under ``th.no_grad()`` for
        validation.
        """
        data_dict = next(data_loader)
        labels = {}
        for label_name in config.classifier.label:
            # Raise explicitly instead of asserting so that the check is not
            # stripped when Python runs with -O.
            if label_name not in data_dict.keys():
                raise KeyError(f'label {label_name} are not in data_dict{data_dict.keys()}')
            labels[label_name] = data_dict[label_name].to(dist_util.dev())

        batch = data_dict["image"].to(dist_util.dev())

        # Noisy images
        if train_cfg.noised:
            t, _ = schedule_sampler.sample(batch.shape[0], dist_util.dev())
            batch = diffusion.q_sample(batch, t)
        else:
            t = th.zeros(batch.shape[0], dtype=th.long, device=dist_util.dev())
        loss_dict = get_predictor_loss(model, labels, batch, t)
        loss = th.stack(list(loss_dict.values())).sum()
        losses = {f"{prefix}_{loss_name}": loss_value.detach() for loss_name, loss_value in loss_dict.items()}
        log_loss_dict(diffusion, t, losses)

        del losses
        loss = loss.mean()
        if loss.requires_grad:
            mp_trainer.zero_grad()
            mp_trainer.backward(loss)

    total_steps = train_cfg.iterations - resume_step
    last_saved_step = None
    for step in range(total_steps):
        global_step = step + resume_step
        logger.logkv("step", global_step)
        logger.logkv(
            "samples",
            (global_step + 1) * train_cfg.batch_size * dist.get_world_size(),
        )
        if train_cfg.anneal_lr:
            set_annealed_lr(opt, train_cfg.lr, global_step / train_cfg.iterations)
        forward_backward_log(data)
        mp_trainer.optimize(opt)
        if val_data is not None and not step % train_cfg.eval_interval:
            with th.no_grad():
                with model.no_sync():
                    model.eval()
                    forward_backward_log(val_data, prefix="val")
                    model.train()
        if not step % train_cfg.log_interval:
            logger.dumpkvs()
        if (
                step
                and dist.get_rank() == 0
                and not global_step % train_cfg.save_interval
        ):
            logger.log("saving model...")
            save_model(mp_trainer, opt, global_step)
            last_saved_step = global_step

    # Final checkpoint. Skipped when no step was run (there is no step number
    # to save under) or when the last step has just been checkpointed above.
    if total_steps > 0:
        final_step = total_steps - 1 + resume_step
        if dist.get_rank() == 0 and last_saved_step != final_step:
            logger.log("saving model...")
            save_model(mp_trainer, opt, final_step)
    dist.barrier()


def get_predictor_loss(model, labels, batch, t):
    """Compute the mean-squared-error regression loss of ``model`` on a batch.

    Args:
        model: callable invoked as ``model(batch, timesteps=t)``.
        labels: mapping from label name to target tensor. Only the first
            entry is used as the regression target.
        batch: input tensor passed to the model.
        t: timestep tensor passed to the model.

    Returns:
        A dict with the single key ``"loss"`` holding the scalar MSE.
    """
    # reshape (not view) so that non-contiguous model outputs are accepted.
    prediction = model(batch, timesteps=t).reshape(-1)
    # Flatten the target as well: a (B, 1) target would otherwise broadcast
    # against the (B,) prediction to a (B, B) matrix and yield a wrong loss.
    target = list(labels.values())[0].to(th.float).reshape(-1)
    return {"loss": F.mse_loss(prediction, target, reduction="mean")}


def set_annealed_lr(opt, base_lr, frac_done):
    """Linearly decay the learning rate of every parameter group of ``opt``.

    The new rate is ``base_lr`` scaled by the fraction of training that is
    still to be done, i.e. ``1 - frac_done``.
    """
    annealed = base_lr * (1 - frac_done)
    for group in opt.param_groups:
        group["lr"] = annealed


def save_model(mp_trainer, opt, step):
    """Write model and optimizer checkpoints for ``step``; no-op off rank 0.

    Two files are written into the logger's directory: ``model{step:06d}.pt``
    with the state dict built from the trainer's master parameters, and
    ``opt{step:06d}.pt`` with the optimizer state dict.
    """
    if dist.get_rank() != 0:
        return

    model_state = mp_trainer.master_params_to_state_dict(mp_trainer.master_params)
    model_path = os.path.join(logger.get_dir(), f"model{step:06d}.pt")
    th.save(model_state, model_path)

    opt_path = os.path.join(logger.get_dir(), f"opt{step:06d}.pt")
    th.save(opt.state_dict(), opt_path)

# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
