import torch.nn as nn
import datetime
import sys
import time
from torch.utils.tensorboard import SummaryWriter

from visualize import *
from dataset import *
from util import *

from model.adversary import Classifier
from sparsity.block import sparse_criterion_with_label
from parse_args import parse_arguments
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from causal import MLPImplicitSCM
from causal import MLPFixedOrderSCM, MLPVariableOrderCausalModel, UnstructuredPrior
from model.lcm import ILCM, ELCM, SoftILCM
from model.encoder import SONEncoder, GaussianEncoder
from nets import make_mlp
from causal import HeuristicInterventionEncoder
from training import VAEMetrics
from experiment_utils import *
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

def set_logger(args, model_name, stage=''):
    """Configure the root logger to write to stdout and to a per-run log file.

    The log file is ``<args.expdir>/<model_name>/<stage>/<model_name>.log`` and is
    opened in write mode, so an existing file at that path is truncated.

    Args:
        args: run configuration; ``args.expdir`` is read for the output root and
            the whole object is logged once at INFO level.
        model_name: name used for both the sub-folder and the log file.
        stage: optional sub-folder under ``model_name``.

    Note:
        Every call adds a new console handler and a new file handler to the
        root logger; calling this more than once per process duplicates output.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(message)s')

    console = logging.StreamHandler(sys.stdout)
    console.setFormatter(formatter)
    console.setLevel(logging.DEBUG)
    logger.addHandler(console)

    print(f'Save log into {model_name}')
    folder_name = os.path.join(args.expdir, f'{model_name}/{stage}')
    # exist_ok avoids the check-then-create race when several processes
    # (e.g. distributed ranks) set up logging at the same time.
    os.makedirs(folder_name, exist_ok=True)

    handler = logging.FileHandler(f'{folder_name}/{model_name}.log', 'w', 'utf-8')
    handler.setFormatter(formatter)
    handler.setLevel(logging.DEBUG)
    logger.addHandler(handler)

    # Silence chatty third-party loggers.
    logging.getLogger('matplotlib.font_manager').disabled = True
    logging.getLogger('PIL').setLevel(logging.WARNING)
    logging.debug(f'{time.asctime(time.localtime())}')

    logging.info(args)


def preprocess_data(args, df):
    """
    Turn an EpicKitchens / ProcTHOR annotation table into data loaders.

    Steps, in order:
      1. Build noun/verb <-> integer-index dictionaries from the unique values of
         ``df['noun_class']`` and ``df['verb_class']`` and add ``noun_index`` /
         ``verb_index`` columns to ``df`` (the DataFrame passed in is modified).
      2. Record which (verb, noun) pairs occur in a boolean ``(num_verb, num_noun)`` tensor.
      3. Rebalance verbs with ``balance_stat`` (seeded by ``args.nature_seed``) and
         keep only the surviving rows.
      4. Split into in-distribution and out-of-distribution frames with ``split_df``
         along ``args.ood``, then carve out test / validation / OOD-test subsets.
      5. Pick an image transform from the prefix of ``args.encoder``.
      6. Wrap each subset in an ``ActionDataset`` and a ``DataLoader``.

    Args:
        args: run configuration. Fields read here: ``nature_seed``, ``ood``,
            ``train_size``, ``encoder``, ``translation``, ``bbox``, ``resolution``,
            ``dataset``, ``path_data``, ``mask``, ``distributed``, ``batch_size``,
            ``num_workers``.
        df: annotation DataFrame with at least ``noun_class`` and ``verb_class`` columns.

    Returns:
        Tuple ``(loader_train, loader_test, loader_ood, loader_valid,
        dict_noun_class, dict_verb_class, symmetric_verb_index, verb_block,
        bool_verb_noun)``. Note the loader order (train, test, ood, valid);
        ``symmetric_verb_index`` and ``verb_block`` are always ``None`` here.
    """

    # attributes: class name -> index ("*_index") and index -> class name ("*_class")
    dict_noun_index = {k: v for v, k in enumerate(df['noun_class'].unique())}
    dict_noun_class = {v: k for v, k in enumerate(df['noun_class'].unique())}
    dict_verb_index = {k: v for v, k in enumerate(df['verb_class'].unique())}
    dict_verb_class = {v: k for v, k in enumerate(df['verb_class'].unique())}
    df['noun_index'] = df.apply(lambda row: dict_noun_index[row.noun_class], axis=1)
    df['verb_index'] = df.apply(lambda row: dict_verb_index[row.verb_class], axis=1)

    num_instance = len(df)
    num_noun = len(df['noun_class'].unique())
    num_verb = len(df['verb_class'].unique())
    logging.info(f'Dataset stat: # instance {num_instance}, # noun {num_noun}, # verb {num_verb}')

    # symmetry
    # NOTE: symmetric-verb handling is disabled; the commented code below is kept
    # for reference and the returned value is always None.
    # if args.dataset == 'procthor':
    #     from procthor.action import action_symmetry
    #     symmetric_verb_class = action_symmetry()
    #     symmetric_verb_index = {dict_verb_index[k]: dict_verb_index[v] for k, v in symmetric_verb_class.items()}
    # else:
    #     symmetric_verb_index = None
    symmetric_verb_index = None

    # rebalance data
    # presumably df_to_dict groups row index labels by verb, by noun and by
    # (verb, noun) pair -- confirm against its definition
    dict_verb, dict_noun, dict_verb_noun = df_to_dict(df)
    stat_verb = dict_to_stat(dict_verb)
    # show_stat(stat_verb, figname=f'fig/stat_{args.dataset}_verb')

    # feasible combinations: True at [verb_index, noun_index] for every
    # (verb, noun) key seen before rebalancing
    bool_verb_noun = torch.zeros((num_verb, num_noun)).bool()
    for (verb, noun) in dict_verb_noun.keys():
        bool_verb_noun[(dict_verb_index[verb], dict_noun_index[noun])] = True

    dict_verb = balance_stat(dict_verb, stat_verb, args.nature_seed)
    stat_verb = dict_to_stat(dict_verb)
    # show_stat(stat_verb, figname=f'fig/stat_{args.dataset}_verb_reb')

    # flatten the per-verb lists and keep only rows whose index label survived
    indices = [name for names in dict_verb.values() for name in names]
    logging.info(f'{len(indices)} / {len(df)} instances are kept from rebalance')
    df = df[df.index.isin(indices)].reset_index(drop=True)

    # block preprocessing
    # NOTE: verb-block construction is disabled; the commented code below is kept
    # for reference and the returned value is always None.
    # if args.dataset == "procthor":
    #     num_blk = num_verb - int(len(symmetric_verb_index) / 2)
    #     verb_block = -torch.ones(num_verb, dtype=torch.uint8)
    #     cnt = 0
    #     for idx_verb in range(num_verb):
    #         if idx_verb in symmetric_verb_index and symmetric_verb_index[idx_verb] < idx_verb:
    #             verb_block[idx_verb] = verb_block[symmetric_verb_index[idx_verb]]
    #         else:
    #             verb_block[idx_verb] = cnt
    #             cnt += 1
    #     assert verb_block.max().int() + 1 == num_blk
    #     logging.info(f'{num_blk} blocks of latent variables')
    #
    # else:
    #     verb_block = None
    verb_block = None

    df_iid, df_ood = split_df(df, axis=args.ood, seed=args.nature_seed)
    # the six dictionaries below are not used further in this function
    dict_verb_iid, dict_noun_iid, dict_verb_noun_iid = df_to_dict(df_iid)
    dict_verb_ood, dict_noun_ood, dict_verb_noun_ood = df_to_dict(df_ood)

    # show_split(dict_verb_noun_iid, dict_verb_noun_ood, figname=f'fig/split/test.png')
    # show_split(dict_verb_noun_iid, dict_verb_noun_ood, figname=f'fig/split/split_{args.dataset}_{args.ood}_{args.seed}.pdf')

    # split: held-out size is the OOD size, capped at 5000 and at half of the
    # in-distribution rows, and never below 1
    max_num_ood = min(5000, int(0.5 * len(df_iid)))
    num_valid = max(min(len(df_ood), max_num_ood), 1)

    df_iid = df_iid.sample(frac=1, random_state=args.nature_seed)  # shuffle order
    # first num_valid shuffled in-distribution rows become the test set
    df_train = df_iid[num_valid:]
    df_test = df_iid[:num_valid]

    # ood validation set for model selection: taken from the tail of df_ood while
    # the OOD test set is taken from its head, so they overlap when df_ood has
    # fewer than 2 * num_valid rows
    if len(df_ood) < num_valid * 2:
        logging.warning("duplication between ood validation and ood test")
    else:
        logging.info("disjoint ood validation and test set")
    df_valid = df_ood[-num_valid:]
    df_ood = df_ood[:num_valid]

    # optional cap on the number of training rows
    df_train = df_train[:args.train_size]

    num_instance = len(df_train)
    num_noun = len(df_train['noun_class'].unique())
    num_verb = len(df_train['verb_class'].unique())
    logging.info(f'Training Dataset stat: # instance {num_instance}, # noun {num_noun}, # verb {num_verb}')

    # transform: chosen from the prefix of the encoder name
    if args.encoder is None:
        transform = None
    elif args.encoder[:3] in ['res', 'vit'] or args.encoder[:5] == 'group':
        # presumably the standard ImageNet normalisation statistics -- confirm
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
        normalize = transforms.Normalize(mean=mean, std=std)
        if args.translation > 0.0 and not args.bbox:
            # random crop augmentation; output size is fixed at 224 here and
            # does not follow args.resolution
            transform = transforms.Compose([
                transforms.RandomResizedCrop(224, scale=(1. - args.translation, 1.)),
                transforms.ToTensor(),
                normalize,
            ])
            # todo: clean up logic
        else:
            transform = transforms.Compose([
                transforms.Resize([args.resolution, args.resolution]),
                transforms.ToTensor(),
                normalize,
            ])
    elif args.encoder[:4] == 'slot':
        # fixed 128x128 input, no normalisation
        transform = transforms.Compose([
            transforms.Resize([128, 128]),
            transforms.ToTensor(),
        ])
        # transforms.ConvertImageDtype(dtype=torch.float32)
    elif args.encoder[:4] == 'clip':
        # presumably CLIP's published normalisation statistics -- confirm
        mean = (0.48145466, 0.4578275, 0.40821073)
        std = (0.26862954, 0.26130258, 0.27577711)
        normalize = transforms.Normalize(mean=mean, std=std)
        transform = transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform = None

    # all four subsets share the same transform and index dictionaries
    data_train = ActionDataset(args.dataset, df_train, args.path_data, transform, dict_noun_index, dict_verb_index,
                               args.mask, args.bbox)
    data_test = ActionDataset(args.dataset, df_test, args.path_data, transform, dict_noun_index, dict_verb_index,
                              args.mask, args.bbox)
    data_ood = ActionDataset(args.dataset, df_ood, args.path_data, transform, dict_noun_index, dict_verb_index,
                             args.mask, args.bbox)
    data_valid = ActionDataset(args.dataset, df_valid, args.path_data, transform, dict_noun_index, dict_verb_index,
                               args.mask, args.bbox)

    # the validation set size is not included in this message
    logging.info(f'# train: {len(data_train)}      # test: {len(data_test)}        # ood: {len(data_ood)}')

    # create loader: in distributed mode train and valid get a DistributedSampler
    # (shuffle must then be False on the DataLoader); test and ood loaders never
    # use a sampler
    if args.distributed:
        train_sampler = DistributedSampler(data_train)
        val_sampler = DistributedSampler(data_valid)
        loader_train = DataLoader(
            data_train, batch_size=args.batch_size, shuffle=False,
            num_workers=args.num_workers, pin_memory=True, worker_init_fn=seed_worker, sampler=train_sampler)
        loader_valid = DataLoader(
            data_valid, batch_size=args.batch_size, shuffle=False,
            num_workers=args.num_workers, pin_memory=True, worker_init_fn=seed_worker, sampler=val_sampler)
    else:
        loader_train = DataLoader(
            data_train, batch_size=args.batch_size, shuffle=True,
            num_workers=args.num_workers, pin_memory=True, worker_init_fn=seed_worker)
        loader_valid = DataLoader(
            data_valid, batch_size=args.batch_size, shuffle=False,
            num_workers=args.num_workers, pin_memory=True, worker_init_fn=seed_worker)
    loader_test = DataLoader(
        data_test, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True, worker_init_fn=seed_worker)
    loader_ood = DataLoader(
        data_ood, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True, worker_init_fn=seed_worker)

    return loader_train, loader_test, loader_ood, loader_valid, dict_noun_class, dict_verb_class, symmetric_verb_index, verb_block, bool_verb_noun

def _subsample_verb_classes(args, df, column):
    """Restrict ``df`` to ``args.select_actions`` randomly drawn values of ``column``.

    Returns ``df`` unchanged when ``args.select_actions`` is None. Otherwise the
    global ``random`` module is seeded with ``args.nature_seed`` before drawing,
    and the filtered frame is returned with a fresh index.
    """
    if args.select_actions is None:
        return df
    random.seed(args.nature_seed)
    candidates = df[column].unique()
    chosen = random.sample(list(candidates), min(args.select_actions, len(candidates)))
    return df[df[column].isin(chosen)].reset_index(drop=True)


def _loaders_from_annotations(args, df):
    """Run ``preprocess_data`` and reorder its result into ``set_loader``'s return order."""
    (loader_train, loader_test, loader_ood, loader_valid, dict_noun_class, dict_verb_class,
     symmetric_verb_index, verb_block, bool_verb_noun) = preprocess_data(args, df)
    # preprocess_data returns (train, test, ood, valid); callers of set_loader
    # expect (train, valid, test, ood).
    return (loader_train, loader_valid, loader_test, loader_ood, dict_noun_class, dict_verb_class,
            symmetric_verb_index, verb_block, bool_verb_noun)


def _load_tensor_dataset(path_data, split):
    """Load ``<path_data>/<split>.pt`` and wrap its tensors in a ``TensorDataset``."""
    filename = Path(path_data) / f"{split}.pt"
    logging.debug(f"Loading data from {filename}")
    # NOTE: torch.load unpickles the file; only point path_data at trusted data.
    return TensorDataset(*torch.load(filename))


def _synthetic_loader(args, dataset, shuffle=False, sampler=None):
    """Build a ``DataLoader`` with the batch size / worker settings taken from ``args``."""
    return DataLoader(dataset, batch_size=args.batch_size, shuffle=shuffle,
                      num_workers=args.num_workers, pin_memory=True,
                      worker_init_fn=seed_worker, sampler=sampler)


def set_loader(args, ratio_valid=0.2):
    """Load the dataset named by ``args.dataset`` and return its data loaders.

    Args:
        args: run configuration. ``dataset`` selects the branch ('procthor',
            'epickitchens' or 'synthetic'); ``path_data`` is the data root;
            ``select_actions`` / ``nature_seed`` optionally restrict the verb
            classes; ``distributed``, ``batch_size`` and ``num_workers`` configure
            the synthetic loaders.
        ratio_valid: not used by this function; kept so existing callers that
            pass it keep working.

    Returns:
        A 9-tuple ``(loader_train, loader_valid, loader_test, loader_ood,
        dict_noun_class, dict_verb_class, symmetric_verb_index, verb_block,
        bool_verb_noun)``. For the synthetic dataset the last six entries are None.

    Raises:
        NotImplementedError: if ``args.dataset`` is not one of the three names above.
    """
    if args.dataset == 'procthor':
        # procthor columns:
        # ['scene', 'idx', 'figure', 'noun_class', 'verb_class', 'xmin', 'ymin', 'xmax', 'ymax']
        filecsv = f'{args.path_data}/annotations.csv'
        if os.path.exists(filecsv):
            df = pd.read_csv(filecsv)
            logging.info(f'Loaded annotations from {filecsv}')
        else:
            # No merged file yet: concatenate the per-scene annotation files
            # (which have no header row) and cache the result for later runs.
            files = sorted(glob.glob(f'{args.path_data}/proc_*/annotations.csv'))
            columns = ['scene', 'idx', 'figure', 'noun_class', 'verb_class', 'xmin', 'ymin', 'xmax',
                       'ymax']
            stack = [pd.read_csv(file, header=None, names=columns) for file in files]
            df = pd.concat(stack, ignore_index=True)
            df.to_csv(filecsv, index=False)

        # Drop rows labelled with the 'none' and 'cook' verbs.
        df = df[df['verb_class'] != 'none']
        df = df[df['verb_class'] != 'cook']

        df = _subsample_verb_classes(args, df, 'verb_class')
        return _loaders_from_annotations(args, df)

    elif args.dataset == 'epickitchens':
        # epickitchens columns:
        # ['participant_id', 'video_id', 'narration_timestamp', 'start_timestamp',
        #  'stop_timestamp', 'start_frame', 'stop_frame', 'narration', 'verb',
        #  'verb_class', 'noun', 'noun_class', 'all_nouns', 'all_noun_classes',
        #  'start_score', 'stop_score']
        df_verb, df_noun = load_categories(f'{args.path_data}/annotations')
        df = load_annotations(f'{args.path_data}/prepro')
        # align var name: the raw *_class columns hold indices, so rename them
        # and look the class keys up in the category tables.
        df.rename(columns={'noun_class': 'noun_index', 'verb_class': 'verb_index'}, inplace=True)
        df['noun_class'] = df.apply(lambda row: df_noun.loc[row.noun_index].key, axis=1)
        df['verb_class'] = df.apply(lambda row: df_verb.loc[row.verb_index].key, axis=1)

        df = _subsample_verb_classes(args, df, 'verb_index')
        return _loaders_from_annotations(args, df)

    elif args.dataset == 'synthetic':
        # Regenerate data if necessary
        if not Path(args.path_data).exists():
            generate_datasets(args)

        train_data = _load_tensor_dataset(args.path_data, "train")
        val_data = _load_tensor_dataset(args.path_data, "val")
        test_data = _load_tensor_dataset(args.path_data, "test")

        # Dump the stored "nature" dictionary as text next to the data.
        nature = torch.load(Path(args.path_data) / "nature.pt")
        info = ''.join(f'{item}:{nature[item]}\n' for item in nature.keys())
        with open(os.path.join(args.path_data, 'nature_graph_data.txt'), 'w') as fp:
            fp.write(info)

        if args.distributed:
            # With a sampler the DataLoader itself must not shuffle.
            train_sampler = DistributedSampler(train_data)
            val_sampler = DistributedSampler(val_data)
            loader_train = _synthetic_loader(args, train_data, sampler=train_sampler)
            loader_valid = _synthetic_loader(args, val_data, sampler=val_sampler)
        else:
            loader_train = _synthetic_loader(args, train_data, shuffle=True)
            loader_valid = _synthetic_loader(args, val_data)
        loader_test = _synthetic_loader(args, test_data)

        return loader_train, loader_valid, loader_test, None, None, None, None, None, None

    else:
        raise NotImplementedError(f"Unknown dataset: {args.dataset}")


def set_model(args, num_action=0, num_object=0):
    """Assemble the latent causal model selected by ``args.model``.

    ``"elcm"`` / ``"betavae"`` build an ``ELCM``, ``"ilcm"`` / ``"dvae"`` build an
    ``ILCM`` (with an intervention encoder) and ``"softilcm"`` builds a
    ``SoftILCM`` (with an additional Gaussian noise encoder/decoder pair).
    For the 'procthor' and 'epickitchens' datasets an action and an object
    ``Classifier`` over ``2 * args.dim_z`` features are attached; otherwise
    both classifier slots are None.

    Args:
        args: run configuration.
        num_action: number of outputs of the action classifier.
        num_object: number of outputs of the object classifier.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``args.model`` is not one of the names above.
    """
    # Optional classification heads, only for the labelled image datasets.
    action_classifier = None
    object_classifier = None
    if args.dataset in ["procthor", "epickitchens"]:
        action_classifier = Classifier(2 * args.dim_z, num_action)
        object_classifier = Classifier(2 * args.dim_z, num_object)

    # Reject unknown model names before building anything else.
    if args.model not in ["elcm", "betavae", "ilcm", "dvae", "softilcm"]:
        raise ValueError(f"Unknown value for args.model: {args.model}")

    # Every supported model needs an SCM and an encoder/decoder pair.
    scm = create_scm(args)
    encoder, decoder = create_encoder_decoder(args)

    if args.model in ["elcm", "betavae"]:
        return ELCM(
            scm,
            encoder=encoder,
            decoder=decoder,
            intervention_prior=None,
            dim_z=args.dim_z,
            action_classifier=action_classifier,
            object_classifier=object_classifier,
        )

    if args.model in ["ilcm", "dvae"]:
        intervention_encoder = create_intervention_encoder(args)
        return ILCM(
            scm,
            encoder=encoder,
            decoder=decoder,
            intervention_encoder=intervention_encoder,
            intervention_prior=None,
            averaging_strategy=args.averaging_strategy,
            dim_z=args.dim_z,
            action_classifier=action_classifier,
            object_classifier=object_classifier,
        )

    # Remaining case: "softilcm".
    encoder_hidden = [args.encoder_hidden_units] * args.encoder_hidden_layers
    decoder_hidden = [args.decoder_hidden_units] * args.decoder_hidden_layers

    # Maps observations (dim_x) to the latent noise space (dim_z) ...
    noise_encoder = GaussianEncoder(
        encoder_type=args.encoder,
        encoder_decoder="encoder",
        hidden=encoder_hidden,
        input_features=args.dim_x,
        output_features=args.dim_z,
        fix_std=args.encoder_fix_std,
        init_std=args.encoder_std,
        min_std=args.encoder_min_std,
        amin=args.amin,
        resolution=args.resolution,
    )  # same as Encoder in GCRL
    # ... and back from dim_z to dim_x.
    noise_decoder = GaussianEncoder(
        encoder_type=args.encoder,
        encoder_decoder="decoder",
        hidden=decoder_hidden,
        input_features=args.dim_z,
        output_features=args.dim_x,
        fix_std=args.decoder_fix_std,
        init_std=args.decoder_std,
        min_std=args.decoder_min_std,
        amin=args.amin,
        resolution=args.resolution,
    )

    return SoftILCM(
        scm,
        encoder=encoder,
        decoder=decoder,
        noise_encoder=noise_encoder,
        noise_decoder=noise_decoder,
        averaging_strategy=args.averaging_strategy,
        dim_z=args.dim_z,
        adversarial=args.adversarial,
        noise_model=args.noise_model,
        num_components=args.num_components,
        num_samples=args.num_samples,
        action_classifier=action_classifier,
        object_classifier=object_classifier,
    )


def create_scm(args):
    """Build the structural causal model (or prior) selected by ``args.scm``.

    * ``"unstructured"`` -> ``UnstructuredPrior``.
    * ``"mlp"`` with a noise-centric model (``args.model`` in softilcm / ilcm /
      dvae) -> ``MLPImplicitSCM``.
    * ``"mlp"`` with any other model -> ``MLPVariableOrderCausalModel`` when
      ``args.scm_adjacency_matrix`` is "enco" or "dds", ``MLPFixedOrderSCM``
      when it is "fixed_order".

    Raises:
        NotImplementedError: for ``args.scm == "ground_truth"``.
        ValueError: for every other unsupported combination.
    """
    logging.info(f"Creating {args.scm} SCM")
    noise_centric = args.model in {"softilcm", "ilcm", "dvae"}
    scm_kind = args.scm

    if scm_kind == "ground_truth":
        raise NotImplementedError

    if scm_kind == "unstructured":  # Baseline VAE
        return UnstructuredPrior(dim_z=args.dim_z)

    if scm_kind != "mlp":
        raise ValueError(f"Unknown value for args.scm: {args.scm}")

    # From here on the SCM is MLP-based.
    if noise_centric:
        logging.info(
            f"Graph parameterization for noise-centric learning: {args.scm_adjacency_matrix}"
        )
        return MLPImplicitSCM(
            model_type=args.model,
            graph_parameterization=args.scm_adjacency_matrix,
            manifold_thickness=args.scm_manifold_thickness,
            hidden_units=args.scm_hidden_units,
            hidden_layers=args.scm_hidden_layers,
            homoskedastic=args.scm_homoskedastic,
            dim_z=args.dim_z,
            min_std=args.scm_min_std,
            var_diminish=args.var_diminish,
            init=args.scm_init
        )

    if args.scm_adjacency_matrix in {"enco", "dds"}:
        logging.info(
            f"Adjacency matrix: learnable, {args.scm_adjacency_matrix} parameterization"
        )
        return MLPVariableOrderCausalModel(
            graph_parameterization=args.scm_adjacency_matrix,
            manifold_thickness=args.scm_manifold_thickness,
            hidden_units=args.scm_hidden_units,
            hidden_layers=args.scm_hidden_layers,
            homoskedastic=args.scm_homoskedastic,
            dim_z=args.dim_z,
            enhance_causal_effects_at_init=False,
            min_std=args.scm_min_std,
        )

    if args.scm_adjacency_matrix == "fixed_order":
        logging.info("Adjacency matrix: learnable, fixed topological order")
        return MLPFixedOrderSCM(
            manifold_thickness=args.scm_manifold_thickness,
            hidden_units=args.scm_hidden_units,
            hidden_layers=args.scm_hidden_layers,
            homoskedastic=args.scm_homoskedastic,
            dim_z=args.dim_z,
            enhance_causal_effects_at_init=False,
            min_std=args.scm_min_std,
        )

    # "mlp" with an adjacency parameterization none of the branches accept.
    raise ValueError(f"Unknown value for args.scm: {args.scm}")


def create_intervention_encoder(args):
    """Build the intervention encoder named by ``args.intervention_encoder``.

    ``"learnable_heuristic"`` yields a ``HeuristicInterventionEncoder``;
    ``"mlp"`` yields a softmax-terminated MLP from ``2 * dim_z`` inputs to
    ``dim_z + 1`` outputs. Any other value raises ``ValueError``.
    """
    logging.info(f"Creating {args.intervention_encoder} intervention encoder")
    kind = args.intervention_encoder

    if kind == "learnable_heuristic":
        return HeuristicInterventionEncoder()

    if kind != "mlp":
        raise ValueError(
            f"Unknown value for cfg.model.intervention_encoder.type: "
            f"{args.intervention_encoder}"
        )

    # One output per latent dimension plus one extra class:
    # atomic or empty interventions.
    output_width = args.dim_z + 1
    hidden_widths = [args.intervention_encoder_hidden_units] * args.intervention_encoder_hidden_layers
    layer_widths = [2 * args.dim_z, *hidden_widths, output_width]
    return make_mlp(layer_widths, final_activation="softmax")


def train(args, model, criterion, optim_model, optim_discriminator, loader,
          model_interventions, pretrain, model_noise,
          deterministic_intervention_encoder, writer, graph_kwargs, epoch, step, rank=0,
          stage='training'):
    """Train ``model`` for one pass over ``loader`` and return the updated step count.

    For each batch this evaluates the step-based schedules, moves the batch to
    the GPU, runs the model and ``criterion`` and calls ``optimizer_step``.
    On rank 0 it additionally keeps running averages of the reported metrics,
    prints progress after every batch and writes the averages to ``writer``
    once the pass is finished.

    Args:
        args: run configuration; ``dataset``, ``mask``, ``pretrain_beta``,
            ``full_likelihood`` and ``likelihood_reduction`` are read here and
            ``args`` is forwarded to ``step_schedules`` and ``optimizer_step``.
        model: model to train; called as ``model(x1, x2[, s1, s2], **kwargs)``
            and expected to return ``(log_prob, model_outputs)``.
        criterion: callable returning ``(vae_loss, disc_loss, metrics)``;
            ``metrics`` must contain ``"loss"``.
        optim_model, optim_discriminator: optimizers forwarded to ``optimizer_step``.
        loader: iterable of batches. Batches have 6 elements for the
            'epickitchens' / 'procthor' datasets and 10 elements otherwise.
        model_interventions, pretrain, model_noise,
        deterministic_intervention_encoder: flags forwarded to the model.
        writer: object with ``add_scalar(tag, value, x)``; used on rank 0 only.
        graph_kwargs: extra keyword arguments forwarded to the model.
        epoch: x-axis value for the scalars written to ``writer``.
        step: global step counter at the start of this pass.
        rank: process rank; only rank 0 records and writes metrics.
        stage: suffix of every scalar tag (``"<name>/<stage>"``).

    Returns:
        ``step`` incremented once per processed batch.
    """

    model.train()

    # (key in the criterion's metrics dict, meter display name, scalar tag prefix)
    meter_specs = [
        ("loss", 'loss', "total_loss"),
        ("mse", 'MSE loss', "recon_loss"),
        ("consistency_mse", 'MSE consistency', "consistency_loss"),
        ("inverse_consistency_mse", 'MSE inverse consistency', "inverse_consistency_loss"),
        ("action_ce", 'Action CE loss', "action_cross_entropy"),
        ("object_ce", 'Object CE loss', "object_cross_entropy"),
        ("kl_epsilon", 'Kl divergence e', "kl_epsilon"),
        ("kl_noise", 'Kl divergence n', "kl_noise"),
        ("noise_proj", 'Post-intervention noise', "Post-intervention noise"),
        ("e1_proj", 'Pre-intervention epsilon', "Pre-intervention epsilon"),
        ("e2_proj", 'Post-intervention epsilon', "Post-intervention epsilon"),
        ("shifts", 'Shifts', "Shifts"),
        ("scales", 'Scales', "Scales"),
        ("kn", 'K noise', "K noise"),
        ("gn", 'G noise', "G noise"),
    ]
    meters = {key: AverageMeter(name, ':4.2f') for key, name, _ in meter_specs}

    progress = ProgressMeter(
        len(loader),
        [meters[key] for key in ("loss", "mse", "consistency_mse", "kl_epsilon", "kl_noise")])

    steps_per_epoch = len(loader)

    # NOTE(review): when ``loader`` uses a DistributedSampler, its set_epoch()
    # is not called here -- confirm the caller does so before each pass.

    nan_counter = 0
    for i, batch in enumerate(loader):
        fractional_epoch = step / steps_per_epoch

        # Step-based schedules
        (
            beta,
            beta_intervention,
            consistency_regularization_amount,
            cyclicity_regularization_amount,
            edge_regularization_amount,
            inverse_consistency_regularization_amount,
            z_regularization_amount,
            intervention_entropy_regularization_amount,
            intervention_encoder_offset,
        ) = step_schedules(args, model, fractional_epoch)

        # GPU
        if args.dataset in ["epickitchens", "procthor"]:
            x1, x2, label, noun, s1, s2 = batch
            x1 = x1.cuda()
            x2 = x2.cuda()
            label = label.cuda()
            noun = noun.cuda()
            s1 = s1.cuda()
            s2 = s2.cuda()

            intervention_labels = None
        else:
            # 10-element batch; only the two observations (positions 0, 1) and
            # the intervention labels (position 8) are used.
            x1, x2, _, _, _, _, _, _, intervention_labels, _ = batch
            x1, x2, intervention_labels = (
                x1.cuda(),
                x2.cuda(),
                intervention_labels.cuda(),
            )

            s1 = None
            s2 = None
            label = intervention_labels
            noun = None

        # Model forward pass: s1 / s2 are passed positionally only when masking is on.
        model_inputs = (x1, x2, s1, s2) if args.mask else (x1, x2)
        log_prob, model_outputs = model(
            *model_inputs,
            beta=beta,
            true_action=label,
            true_object=noun,
            beta_intervention_target=beta_intervention,
            pretrain_beta=args.pretrain_beta,
            full_likelihood=args.full_likelihood,
            likelihood_reduction=args.likelihood_reduction,
            pretrain=pretrain,
            model_noise=model_noise,
            model_interventions=model_interventions,
            deterministic_intervention_encoder=deterministic_intervention_encoder,
            intervention_encoder_offset=intervention_encoder_offset,
            **graph_kwargs,
        )

        # Loss and metrics
        vae_loss, disc_loss, metrics = criterion(
            log_prob,
            true_intervention_labels=intervention_labels,
            z_regularization_amount=z_regularization_amount,
            edge_regularization_amount=edge_regularization_amount,
            cyclicity_regularization_amount=cyclicity_regularization_amount,
            consistency_regularization_amount=consistency_regularization_amount,
            inverse_consistency_regularization_amount=inverse_consistency_regularization_amount,
            intervention_entropy_regularization_amount=intervention_entropy_regularization_amount,
            **model_outputs,
        )

        # Optimizer step
        finite = optimizer_step(args, vae_loss, disc_loss, model, optim_model, optim_discriminator, x1, x2, label, model_noise)
        if not finite:
            nan_counter += 1

        step += 1

        if rank == 0:
            # These metrics are averaged over the batch, so each one is a plain float.
            # "loss" is mandatory; every other metric is recorded only when present.
            for key, _, _ in meter_specs:
                if key == "loss" or key in metrics:
                    meters[key].update(metrics[key])

            progress.display(i + 1)

    if nan_counter:
        # Previously counted but never surfaced.
        logging.warning(
            f"optimizer_step reported a non-finite result for {nan_counter} of "
            f"{steps_per_epoch} batches (epoch {epoch}, stage {stage})"
        )

    if rank == 0:
        for key, _, tag in meter_specs:
            writer.add_scalar(f"{tag}/{stage}", meters[key].avg, epoch)

    return step

@torch.no_grad()
def validate(args, model, criteria, val_loader):
    """Evaluate ``model`` on ``val_loader`` and return the metrics dictionary.

    The model is switched to eval mode and gradients are disabled. Metrics come
    from ``compute_metrics_on_dataset``; for the synthetic dataset the DCI scores
    are merged in and the causal importance matrix is logged, for every other
    dataset the accuracy metrics are merged in instead.
    """
    model.eval()

    # Only the metrics dictionary of the three returned values is used.
    _, _, metrics = compute_metrics_on_dataset(args, model, criteria, val_loader)

    if args.dataset != "synthetic":
        metrics.update(eval_accuracy(args, model, val_loader))
        return metrics

    metrics.update(eval_dci_scores(args, model, partition="val"))

    # Collect the per-entry scores into a (dim_z, dim_z) matrix, rounded to
    # four decimals, purely for logging.
    n_latents = args.dim_z
    importance = np.zeros((n_latents, n_latents))
    for row, col in np.ndindex(n_latents, n_latents):
        importance[row, col] = round(metrics[f'causal_importance_matrix_{row}_{col}'], 4)
    logging.info('\n------- Causal Importance Matrix ------')
    logging.info(importance)

    return metrics


def embedding(model, loader, args=None):
    """Run ``model`` over every batch of ``loader`` and collect its outputs.

    Each batch must unpack into
    ``(first_img, second_img, verb, noun, first_mask, second_mask)``.
    Images and masks are moved to CUDA when it is available. The model is
    called without gradient tracking, with the masks only when ``args.mask``
    is truthy, and must return a 4-tuple whose last three items are
    ``(feat, emb1, emb2)``.

    Args:
        model: callable taking ``(first_img, second_img[, first_mask, second_mask])``.
        loader: iterable of batches as described above.
        args: object exposing a ``mask`` attribute. When omitted, the
            module-level ``args`` (set when the file runs as a script) is
            used, which keeps existing two-argument callers working.

    Returns:
        ``(verb, noun, feat, emb1, emb2)``, each concatenated over all
        batches. ``feat``, ``emb1`` and ``emb2`` are on the CPU; ``verb`` and
        ``noun`` are concatenated as yielded by the loader.

    Raises:
        NameError: if ``args`` is not passed and no module-level ``args`` exists.
    """
    if args is None:
        # Backward compatibility: older callers relied on the script-level global.
        args = globals().get("args")
        if args is None:
            raise NameError("name 'args' is not defined; pass `args` to embedding() explicitly")
    use_mask = args.mask
    on_gpu = torch.cuda.is_available()

    verbs, nouns, feats, embs1, embs2 = [], [], [], [], []

    for first_img, second_img, verb, noun, first_mask, second_mask in loader:
        if on_gpu:
            first_img = first_img.cuda()
            second_img = second_img.cuda()
            first_mask = first_mask.cuda()
            second_mask = second_mask.cuda()

        with torch.no_grad():
            if use_mask:
                _, feat, emb1, emb2 = model(first_img, second_img, first_mask, second_mask)
            else:
                _, feat, emb1, emb2 = model(first_img, second_img)

        verbs.append(verb)
        nouns.append(noun)
        # Move each batch off the GPU right away so device memory does not
        # grow with the size of the dataset.
        feats.append(feat.detach().cpu())
        embs1.append(emb1.detach().cpu())
        embs2.append(emb2.detach().cpu())

    return (
        torch.cat(verbs),
        torch.cat(nouns),
        torch.cat(feats),
        torch.cat(embs1),
        torch.cat(embs2),
    )


def tsne(model, loader_test, loader_ood, dict_noun_class, dict_verb_class, args, acc_test, acc_ood):
    """Embed the test and OOD loaders and hand each representation to ``show_tsne``.

    ``embedding`` is run on both loaders; then, for each of the three
    returned representations (``feat``, ``emb1``, ``emb2``), ``show_tsne`` is
    called with the test/OOD tensors, their noun and verb labels, the two
    class dictionaries and an output path under ``fig/tsne``. The file name
    encodes the run configuration taken from ``args`` plus the two accuracies.
    """
    model.eval()

    out_dir = 'fig/tsne'
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # embedding() returns (verb, noun, feat, emb1, emb2); keep the last three grouped.
    iid_verb, iid_noun, *iid_views = embedding(model, loader_test)
    ood_verb, ood_noun, *ood_views = embedding(model, loader_ood)

    # Checkpoint tag: file name with its last 3 characters removed and '_' -> '-'.
    ckpt = args.ckpt.split('/')[-1][:-3].replace('_', '-') if args.ckpt else 'public'

    figname = f'{args.ood}_{args.train_size}_{args.model}_{args.dim}_tran_{args.translation}_critic_{args.critic_action}_{args.critic_state}_linear_{args.linear}_mask_{args.mask}_bbox_{args.bbox}_encoder_{args.finetune}_pretrain_{ckpt}_sparse_{args.sparse}_amin_{args.amin}_seed_{args.seed}'
    figname += f'_iid_{acc_test:.2f}_ood_{acc_ood:.2f}'

    for suffix, iid_view, ood_view in zip(('feat', 'emb1', 'emb2'), iid_views, ood_views):
        show_tsne(iid_view, iid_noun, iid_verb, ood_view, ood_noun, ood_verb,
                  dict_noun_class, dict_verb_class, f'{out_dir}/{figname}_{suffix}')


def _optimizer_state_to_cuda(optimizer):
    """Move every tensor stored in ``optimizer.state`` to the current CUDA device."""
    for state in optimizer.state.values():
        for key, value in state.items():
            if isinstance(value, torch.Tensor):
                state[key] = value.cuda()


def _save_sample_pairs(dataset, dict_verb_class, dict_noun_class, tag, savedir, last_idx=10):
    """Pass the samples with index ``0..last_idx`` of ``dataset`` to ``show_pair``."""
    for idx, (first, second, verb, noun, _, _) in enumerate(dataset):
        show_pair(first, second, dict_verb_class[verb], dict_noun_class[noun], suffix=f'_{tag}_{idx}',
                  savedir=savedir)
        if idx % 10 == 0:
            # Report progress against the dataset actually being iterated.
            print(f'# {idx} / {len(dataset)}')
        if idx >= last_idx:
            break


def main(args):
    """Entry point: set up devices, logging and data, then train and validate.

    * With ``args.distributed`` the NCCL process group is initialised and the
      process is bound to its node-local GPU; otherwise ``args.gpu`` is used.
    * Rank 0 creates the TensorBoard writer, the logger and saves the config.
    * With ``args.model is None`` only a handful of test/OOD sample pairs are
      written to ``fig/pairs/``.
    * Otherwise the model is built (optionally restored from ``args.ckpt``),
      trained for ``args.start_epoch .. args.epochs - 1`` and validated every
      ``args.val_every_epoch`` epochs; rank 0 logs the metrics and saves a
      checkpoint whenever the validation loss improves.

    Args:
        args: parsed command-line namespace (see ``parse_arguments``).
    """
    if args.distributed:
        torch.distributed.init_process_group(backend='nccl')
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        # Bind to the node-local GPU. The global rank is only a valid device
        # index on a single node, so prefer LOCAL_RANK (set by torchrun) and
        # otherwise wrap the global rank around the visible device count.
        if 'LOCAL_RANK' in os.environ:
            local_rank = int(os.environ['LOCAL_RANK'])
        else:
            local_rank = rank % torch.cuda.device_count()
        torch.cuda.set_device(local_rank)
    else:
        torch.cuda.set_device(args.gpu)
        device_name = torch.cuda.get_device_name(torch.cuda.current_device())
        logging.info(device_name)
        rank = 0
        world_size = 1

    timestamp = datetime.datetime.now().astimezone().strftime("%Y%m%d_%H%M%S")
    model_name = f"{args.encoder}_{args.epochs}_{args.seed}_{timestamp}"
    set_seed(args.seed)

    if rank == 0:
        writer = SummaryWriter(log_dir=args.expdir + '/' + model_name + '/log', flush_secs=10)
        set_logger(args, model_name, stage='train')
        # NOTE(review): this passes the `datetime` module itself, not `timestamp` -
        # confirm against the signature of save_configs.
        save_configs(datetime, vars(args), model_name)
    else:
        writer = None

    loader_train, loader_valid, loader_test, loader_ood, dict_noun_class, dict_verb_class, symmetric_verb_index, verb_block, bool_verb_noun = set_loader(
        args)

    if args.model is None:
        # No model requested: only dump a few sample pairs for inspection.
        output_path = 'fig/pairs/'
        os.makedirs(output_path, exist_ok=True)
        _save_sample_pairs(loader_test.dataset, dict_verb_class, dict_noun_class, 'iid', output_path)
        _save_sample_pairs(loader_ood.dataset, dict_verb_class, dict_noun_class, 'ood', output_path)
    else:

        model = set_model(args, num_action=args.num_actions, num_object=args.num_objects)
        optim_model, optim_discriminator, scheduler = create_optimizer_and_scheduler(args, model)

        if args.ckpt:
            best_loss = load_all_model(args, model, optim_model, optim_discriminator, lr_scheduler=scheduler)
        else:
            best_loss = 1e10

        model.cuda()
        if args.distributed:
            model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=True)

        # Optimizer state (e.g. restored from a checkpoint) must live on the
        # same device as the parameters.
        _optimizer_state_to_cuda(optim_model)
        if optim_discriminator is not None:
            _optimizer_state_to_cuda(optim_discriminator)

        criterion = VAEMetrics(args.dim_z)

        if args.finetune:
            if args.model[:4] == 'clip':
                grad_param(model.encoder.visual)
            else:
                grad_param(model)
        param_all, param_tra = count_param(model)
        logging.info(f'count parameters: total = {param_all}, trainable = {param_tra}')

        step = args.start_epoch * len(loader_train)
        # Keep `epoch` defined for the final save even if the loop below runs
        # zero times (args.start_epoch >= args.epochs).
        epoch = args.start_epoch - 1
        for epoch in range(args.start_epoch, args.epochs):
            # Graph sampling settings
            graph_kwargs = determine_graph_learning_settings(args, epoch, model)

            # Epoch-based schedules
            model_interventions, pretrain, deterministic_intervention_encoder, model_noise = epoch_schedules(
                args, model, epoch, optim_model, loader_valid
            )

            step = train(args, model, criterion, optim_model, optim_discriminator,
                         loader_train, model_interventions, pretrain, model_noise,
                         deterministic_intervention_encoder, writer, graph_kwargs, epoch, step, rank)

            # LR scheduler (not stepped after the last epoch)
            if scheduler is not None and epoch < args.epochs - 1:
                scheduler.step()

                # Optionally reset Adam stats
                if (
                        args.lr_schedule == "cosine_restarts_reset"
                        and (epoch + 1) % args.lr_schedule_restart_every_epochs == 0
                        and epoch + 1 < args.epochs
                ):
                    logging.info(f"Resetting optimizer at epoch {epoch + 1}")
                    reset_optimizer_state(optim_model)

            if epoch % args.val_every_epoch == 0:
                if args.distributed:
                    # Ensure all ranks have finished the epoch before validation
                    dist.barrier()
                    metrics = validate(args, model, criterion, loader_valid)
                    dist.barrier()  # Synchronize all processes before gathering

                    # Sum every metric onto rank 0, then divide by the number
                    # of ranks so the reported value is the cross-rank mean
                    # rather than a sum that grows with the GPU count.
                    if args.dataset == "synthetic":
                        causal_disentanglement = torch.tensor(metrics['causal_disentanglement']).cuda()
                        dist.reduce(causal_disentanglement, dst=0, op=dist.ReduceOp.SUM)
                        causal_disentanglement = causal_disentanglement / world_size
                    else:
                        action_accuracy = torch.tensor(metrics['action_accuracy']).cuda()
                        dist.reduce(action_accuracy, dst=0, op=dist.ReduceOp.SUM)
                        action_accuracy = action_accuracy / world_size
                        object_accuracy = torch.tensor(metrics['object_accuracy']).cuda()
                        dist.reduce(object_accuracy, dst=0, op=dist.ReduceOp.SUM)
                        object_accuracy = object_accuracy / world_size

                    val_loss = metrics['loss']
                    dist.reduce(val_loss, dst=0, op=dist.ReduceOp.SUM)
                    val_loss = val_loss / world_size

                else:
                    metrics = validate(args, model, criterion, loader_valid)
                    if args.dataset == "synthetic":
                        causal_disentanglement = metrics['causal_disentanglement']

                    else:
                        action_accuracy = metrics['action_accuracy']
                        object_accuracy = metrics['object_accuracy']

                    val_loss = metrics['loss']

                # Rank 0 now has the aggregated metrics, print or save them as needed
                if rank == 0:
                    if args.dataset == "synthetic":
                        # Print DCI disentanglement score
                        logging.info(f"Epoch {epoch}: causal disentanglement = {causal_disentanglement:.2f}")
                        writer.add_scalar("causal_disentanglement/validation", causal_disentanglement, epoch)

                    else:
                        logging.info(f"Epoch {epoch}: Action accuracy = {action_accuracy:.2f}")
                        logging.info(f"Epoch {epoch}: Object accuracy = {object_accuracy:.2f}")
                        writer.add_scalar("action_accuracy/validation", action_accuracy, epoch)
                        writer.add_scalar("object_accuracy/validation", object_accuracy, epoch)

                    new_val_loss = val_loss.item()
                    # save checkpoints
                    if new_val_loss < best_loss:  # TODO change this
                        best_loss = new_val_loss
                        save_all_model(args, model, model_name, optim_model, optim_discriminator, epoch, best_loss, timestamp)

        if rank == 0:
            writer.close()

        # NOTE(review): the two calls below run on every rank, not only rank 0 -
        # confirm that save_all_model tolerates concurrent writers.
        set_manifold_thickness(args, model, None)
        save_all_model(args, model, model_name, optim_model, optim_discriminator, epoch, best_loss, timestamp)

        # tsne(model, loader_test, loader_ood, dict_noun_class, dict_verb_class, args, acc_test, acc_ood)

    print('Well done')


if __name__ == "__main__":
    # Keep the name `args` at module level: embedding() reads a global `args`
    # when it is called without an explicit one.
    args = parse_arguments()
    gpu_count = torch.cuda.device_count()
    print("Number of GPUs:", gpu_count)
    main(args)
