
"""Main script for Diffusion-TTA"""
import os
import copy
import random
import warnings
import csv

import wandb
import hydra
from hydra.utils import get_original_cwd
from omegaconf import OmegaConf, open_dict
from mergedeep import merge
import numpy as np
import pickle
import torch
import torch.backends.cudnn as cudnn
from tqdm import tqdm

# Let cuDNN benchmark candidate convolution algorithms and reuse the fastest
# one for each input shape. NOTE: benchmark mode may pick different
# algorithms from run to run, so it should be turned off whenever
# bit-for-bit reproducibility is required.
torch.backends.cudnn.benchmark = True

from dataset.catalog import DatasetCatalog
from diff_tta import utils, engine
from diff_tta.vis_utils import (
    visualize_classification_with_image,
    visualize_diffusion_loss,
    visualize_classification_improvements,
)
from diff_tta.models import build

def _average_accuracy(avg_acc, per_sample_acc, label):
    """Return `avg_acc` as a float, or the mean of `per_sample_acc` if it is None.

    Accepts anything `float()` can convert (Python numbers, NumPy scalars,
    single-element tensors). Raises ValueError when neither input is usable.
    """
    if avg_acc is not None:
        return float(avg_acc)
    if per_sample_acc is None or len(per_sample_acc) == 0:
        raise ValueError(
            f"log_one_epoch needs `{label}_avg_acc` or a non-empty "
            f"`{label}_tta_acc`"
        )
    return float(sum(per_sample_acc) / len(per_sample_acc))


def log_one_epoch(config, before_avg_acc=None, after_avg_acc=None,
                  before_tta_acc=None, after_tta_acc=None):
    """Append before/after-TTA average accuracy to `<config.log_path>/log.txt`.

    Args:
        config: configuration object; reads `log_path`, `input.split` and
            `input.dataset_name`.
        before_avg_acc: average pre-TTA accuracy as a fraction (it is scaled
            by 100 and written as a percentage). If None, the mean of
            `before_tta_acc` is used instead.
        after_avg_acc: same as `before_avg_acc`, for post-TTA accuracy.
        before_tta_acc: sequence of per-sample pre-TTA accuracies; only
            consulted when `before_avg_acc` is None.
        after_tta_acc: same as `before_tta_acc`, for post-TTA accuracy.

    Raises:
        ValueError: if an average is None and its per-sample sequence is
            missing or empty.
    """
    before_pct = _average_accuracy(before_avg_acc, before_tta_acc, "before") * 100
    after_pct = _average_accuracy(after_avg_acc, after_tta_acc, "after") * 100
    improvement_pct = after_pct - before_pct

    log_path = os.path.join(config.log_path, "log.txt")
    log_dir = os.path.dirname(log_path)
    if log_dir:  # an empty log_path means "current directory": nothing to create
        os.makedirs(log_dir, exist_ok=True)

    # Label the entry with the split when one is set, else the dataset name.
    task_name = (config.input.split if config.input.split is not None
                 else config.input.dataset_name)
    content = (
        f"Task: {task_name}\n"
        f"Before TTA Avg Acc: {before_pct:.4f}%\n"
        f"After TTA Avg Acc: {after_pct:.4f}%\n"
        f"Improvement Avg: {improvement_pct:.4f}%\n"
    )
    with open(log_path, 'a') as f:
        f.write(content)

def _append_analysis_row(analysis_path, before_tta_stats_dict):
    """Append one pre-TTA prediction row to the analysis CSV file."""
    # `.item()` only works on single-element tensors, i.e. batch size 1.
    row = [
        before_tta_stats_dict["before_tta_pred_conf"].item(),
        before_tta_stats_dict["before_tta_pred_label"].item(),
        before_tta_stats_dict["before_tta_gt_label"].item(),
        before_tta_stats_dict["before_tta_correct"].item(),
    ]
    with open(analysis_path, 'a', newline='') as csv_handle:
        csv.writer(csv_handle).writerow(row)


def _save_image_stats(stats_folder_name, config, batch, losses,
                      before_tta_stats_dict, after_tta_stats_dict):
    """Pickle one image's TTA statistics to `<stats_folder_name>/<index:06d>.p`."""
    stats_dict = {
        'accum_iter': config.tta.gradient_descent.accum_iter,
        'filename': batch['filepath'],
        'losses': losses,
        'gt_idx': batch['class_idx'][0],
    }
    stats_dict = merge(stats_dict, before_tta_stats_dict, after_tta_stats_dict)
    file_index = int(batch['index'].squeeze())
    store_filename = os.path.join(stats_folder_name, f"{file_index:06d}.p")
    with open(store_filename, 'wb') as stats_file:
        pickle.dump(stats_dict, stats_file)


def tta_one_epoch(config, dataloader, tta_model, optimizer, scaler,
                  autoencoder, image_renormalizer):
    """Perform test time adaptation over the entire dataset.

    For every image: classify it, adapt via
    `engine.tta_one_image_by_gradient_descent`, then classify it again.
    Unless `config.tta.online` is set, the original weights and a fresh
    optimizer are restored after each image. Per-image metrics are logged to
    wandb. Optional outputs are a pre-TTA analysis CSV (`config.analysis`)
    and per-image pickles under `stats/<wandb run name>/`
    (`config.save_results`). The final average accuracies are appended to
    `log.txt` via `log_one_epoch`.

    `config.tta.gradient_descent.train_steps` is multiplied by `accum_iter`
    only while this function runs and is restored on exit. Repeated calls
    that share one config therefore do not compound the multiplication.

    Args:
        config: configuration object for hyper-parameters.
        dataloader: The dataloader for the dataset. Only its `.dataset` is
            used (indexed directly, one image at a time).
        tta_model: A test-time adaptation wrapper model.
        optimizer: A gradient-descent optimizer for updating classifier.
        scaler: A gradient scaler used jointly with optimizer.
        autoencoder: A pre-trained autoencoder model (e.g. VQVAE).
        image_renormalizer: An object for renormalizing images.
    """
    tta_model.eval()

    # Keep a copy of the original model state dict, so that we can reset the
    # model after each image (non-online mode).
    tta_class_state_dict = copy.deepcopy(tta_model.state_dict())

    # Enlarge batch size by accumulating gradients over multiple iterations.
    # The engine receives `config`, so the scaled step count is written into
    # it here and restored in `finally`. Without the restore, every later call
    # sharing this config would multiply it by accum_iter again.
    original_train_steps = config.tta.gradient_descent.train_steps
    config.tta.gradient_descent.train_steps = (
        original_train_steps * config.tta.gradient_descent.accum_iter
    )
    try:
        analysis_path = None
        if config.analysis:
            print("=> Analysis mode is on")
            task_name = (config.input.split if config.input.split is not None
                         else config.input.dataset_name)
            os.makedirs(config.log_path, exist_ok=True)
            analysis_path = os.path.join(
                config.log_path, f"analysis_{task_name.replace('/', '_')}.csv")
            with open(analysis_path, 'w', newline='') as csv_handle:
                csv.writer(csv_handle).writerow(
                    ['pred_conf', 'pred_label', 'gt_label', 'correct'])

        # Per-image result pickles go to stats/<wandb run name>/ (created once).
        stats_folder_name = None
        if config.save_results:
            stats_folder_name = os.path.join('stats', wandb.run.name)
            os.makedirs(stats_folder_name, exist_ok=True)

        # Start iterations
        start_index = 0
        last_index = len(dataloader.dataset)
        pbar = tqdm(range(start_index, last_index),
                    desc=f"Task {config.input.split}")
        for img_ind in pbar:
            # Upload visualizations every `log_freq` images and for the last one.
            visualize = (
                (config.log_freq > 0 and img_ind % config.log_freq == 0)
                or img_ind == last_index - 1
            )
            # The dictionary for visualization
            wandb_dict = {}

            # The dataset is indexed directly, so the DataLoader's shuffle /
            # sampler settings have no effect on the visiting order here.
            batch = dataloader.dataset[img_ind]
            batch = engine.preprocess_input(batch, config.gpu)

            # Step 1: Predict pre-TTA classification. The results are saved in
            # `before_tta_stats_dict` and `tta_model.before_tta_acc`
            before_tta_stats_dict = tta_model.evaluate(batch, before_tta=True)

            if config.analysis:
                _append_analysis_row(analysis_path, before_tta_stats_dict)

            # Step 2: TTA by gradient descent
            losses, _ = engine.tta_one_image_by_gradient_descent(
                batch, tta_model, optimizer, scaler,
                autoencoder, image_renormalizer, config,
                before_tta_stats_dict['pred_topk_idx'],
                image_num=img_ind,
                last_index=last_index
            )

            # Step 3: Predict post-TTA classification. The results are saved
            # in `after_tta_stats_dict` and `tta_model.after_tta_acc`
            after_tta_stats_dict = tta_model.evaluate(batch, after_tta=True)

            # Non-online mode: reset weights and rebuild a fresh optimizer so
            # every image starts from the original model.
            if not config.tta.online:
                tta_model.load_state_dict(tta_class_state_dict)
                optimizer = build.load_optimizer(config, tta_model)

            if visualize:
                # wandb_dict is updated in-place
                wandb_dict = visualize_classification_with_image(
                    batch, config, dataloader.dataset,
                    before_tta_stats_dict["before_tta_logits"],
                    before_tta_stats_dict["before_tta_topk_idx"],
                    before_tta_stats_dict["before_tta_pred_class_idx"],
                    before_tta_stats_dict["before_tta_topk_class_idx"],
                    wandb_dict
                )
                wandb_dict = visualize_diffusion_loss(losses, config, wandb_dict)

            # Plot accuracy curve every image
            cached_avg_acc = tta_model.get_avg_acc_with_cache(before=True, after=True)
            wandb_dict = visualize_classification_improvements(
                before_tta_stats_dict["before_tta_correct"].float(),
                after_tta_stats_dict["after_tta_correct"].float(),
                wandb_dict,
                before_avg_acc=cached_avg_acc['before'],
                after_avg_acc=cached_avg_acc['after'],
            )

            pbar.set_postfix({'before_avg_acc': f'{wandb_dict["before_avg_acc"].item():.4f}',
                              'after_avg_acc': f'{wandb_dict["after_avg_acc"].item():.4f}'})

            # Save the results to the disk
            if config.save_results:
                _save_image_stats(stats_folder_name, config, batch, losses,
                                  before_tta_stats_dict, after_tta_stats_dict)

            wandb.log(wandb_dict, step=img_ind)
    finally:
        config.tta.gradient_descent.train_steps = original_train_steps

    cached_avg_acc = tta_model.get_avg_acc_with_cache(before=True, after=True)
    log_one_epoch(config=config,
                  before_avg_acc=cached_avg_acc['before'],
                  after_avg_acc=cached_avg_acc['after'])

def get_dataset(config):
    """Instantiate the dataset selected by `config.input.dataset_name`.

    The matching `DatasetCatalog` attribute provides the class path
    (`target`) and its constructor kwargs (`train_params`). For
    `ObjectNetSubsetNew`, `config.model.use_dit` is added to those kwargs.
    """
    dataset_name = config.input.dataset_name
    catalog_entry = getattr(DatasetCatalog(config), dataset_name)

    target_path = catalog_entry['target']
    ctor_kwargs = catalog_entry['train_params']
    if dataset_name == "ObjectNetSubsetNew":
        ctor_kwargs['use_dit'] = config.model.use_dit

    return utils.instantiate_from_config({'target': target_path, 'params': ctor_kwargs})


@hydra.main(config_path="diff_tta/config", config_name="config")
def run(config):
    """Hydra entry point: finalize the config, seed RNGs and start the worker.

    Records Hydra's run directory as `config.log_dir` and the launch
    directory as `config.cwd`. Normalizes `config.gpu` (None or negative ->
    None) and changes back to the launch directory. When `config.seed` is
    set, seeds NumPy, `random` and torch and puts cuDNN in deterministic
    mode. Finally calls `run_worker`.
    """
    with open_dict(config):
        config.log_dir = os.getcwd()
        print(f"Logging files in {config.log_dir}")
        config.cwd = get_original_cwd()
        # Treat an unset or negative GPU id as "no specific GPU".
        config.gpu = (None if config.gpu is None or config.gpu < 0
                      else config.gpu)

    # Hydra automatically changes the working directory, but we stay at the
    # project directory.
    os.chdir(config.cwd)

    print(OmegaConf.to_yaml(config))

    if config.input.dataset_name == "ObjectNetDataset":
        # Like the keys added above, `use_objectnet` may not exist in a
        # struct-mode config, so it is written under open_dict.
        with open_dict(config):
            config.input.use_objectnet = True

    if config.seed is not None:
        np.random.seed(config.seed)
        random.seed(config.seed)
        torch.manual_seed(config.seed)
        cudnn.deterministic = True
        # This module turns cudnn.benchmark on at import time. Benchmark mode
        # may choose different algorithms between runs and defeat seeding, so
        # switch it off together with the deterministic flag.
        cudnn.benchmark = False
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    run_worker(config)


# Splits adapted in order, with one shared model, when
# `config.input.split == "continual_all"`.
_CONTINUAL_ALL_TASKS = (
    "gaussian_noise/5",
    "shot_noise/5",
    "impulse_noise/5",
    "defocus_blur/5",
    "glass_blur/5",
    "motion_blur/5",
    "zoom_blur/5",
    "snow/5",
    "frost/5",
    "fog/5",
    "brightness/5",
    "contrast/5",
    "elastic_transform/5",
    "pixelate/5",
    "jpeg_compression/5",
)


def _make_dataloader(config, dataset):
    """Wrap `dataset` in the batch-size-1 DataLoader used by every TTA mode."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=config.input.shuffle,
        num_workers=config.workers,
        pin_memory=True,
        sampler=None,
        drop_last=False,
    )


def _create_model_and_optimizer(config, dataset):
    """Build the TTA model stack for `dataset`'s classes plus optimizer and scaler."""
    print("=> Creating model ")
    model, autoencoder, image_renormalizer = (
        build.create_models(config, dataset.classes, dataset.class_names)
    )
    optimizer = build.load_optimizer(config, model)
    scaler = torch.cuda.amp.GradScaler()
    return model, autoencoder, image_renormalizer, optimizer, scaler


def run_worker(config):
    """Initialize wandb and run test-time adaptation in one of three modes.

    * `config.continual.is_ckpt_continual`: adapt on one task. Model and
      optimizer state are loaded from the checkpoint paths in
      `config.continual` (skipped for the first task) and saved back after.
    * `config.input.split == "continual_all"`: adapt a single model
      sequentially over every split in `_CONTINUAL_ALL_TASKS`.
    * otherwise: adapt a fresh model on `config.input.split`.
    """
    torch.set_flush_denormal(True)

    if config.gpu is not None:
        print("Use GPU: {} for training".format(config.gpu))

    wandb.init(project=config.wandb.project, name=config.input.split,
               config=OmegaConf.to_container(config, resolve=True),
               mode=config.wandb.mode)

    if config.continual.is_ckpt_continual:
        print("=> Loading dataset")
        dataset = get_dataset(config)
        dataloader = _make_dataloader(config, dataset)
        model, autoencoder, image_renormalizer, optimizer, scaler = (
            _create_model_and_optimizer(config, dataset)
        )

        if not config.continual.is_first_task:
            print("=> Continual learning, Loading model and optimizer!")
            model.load_state_dict(torch.load(config.continual.tta_model_ckpt_path))
            optimizer.load_state_dict(torch.load(config.continual.optimizer_ckpt_path))
        tta_one_epoch(config, dataloader, model, optimizer, scaler,
                      autoencoder, image_renormalizer)
        print("=> Continual learning, Saving model and optimizer!")
        torch.save(model.state_dict(), config.continual.tta_model_ckpt_path)
        torch.save(optimizer.state_dict(), config.continual.optimizer_ckpt_path)

    elif config.input.split == "continual_all":
        # One model/optimizer is shared by all tasks. It is built from the
        # first task's dataset, which is then reused for that task.
        config.input.split = _CONTINUAL_ALL_TASKS[0]
        dataset = get_dataset(config)
        model, autoencoder, image_renormalizer, optimizer, scaler = (
            _create_model_and_optimizer(config, dataset)
        )
        for task_idx, task in enumerate(_CONTINUAL_ALL_TASKS):
            config.input.split = task
            if task_idx > 0:
                print("=> Loading dataset")
                dataset = get_dataset(config)
            tta_one_epoch(config, _make_dataloader(config, dataset), model,
                          optimizer, scaler, autoencoder, image_renormalizer)

    else:
        print("=> Loading dataset")
        dataset = get_dataset(config)
        dataloader = _make_dataloader(config, dataset)
        model, autoencoder, image_renormalizer, optimizer, scaler = (
            _create_model_and_optimizer(config, dataset)
        )
        tta_one_epoch(config, dataloader, model, optimizer, scaler,
                      autoencoder, image_renormalizer)


if __name__ == '__main__':
    # `run` is wrapped by `hydra.main`, which builds `config` from
    # diff_tta/config plus command-line overrides, so no arguments are passed.
    run()
