import argparse
import os
import click
import csv
import shutil
from pathlib import Path
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

from wilds.datasets.waterbirds_dataset import WaterbirdsDataset
from wilds.datasets.celebA_dataset import CelebADataset
from wilds.datasets.coco_places_dataset import COCOonPlacesDataset
from models.initializer import initialize_model
import configs.supported as supported
from utils.misc import ParseKwargs, detach_and_clone, collate_list
from examples.utils.print_logger import get_logger
# Module-wide logger used for progress messages; the level is forwarded to get_logger.
LOGGER = get_logger(__name__, level='DEBUG')

def _parse_args():
    """Parse and return the command-line configuration for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--root_dir', required=True,
                        help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')
    parser.add_argument("--dataset", type=str, default='waterbirds',
                        help='One of: waterbirds, celebA, coco_on_places.')
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--model', choices=supported.models)
    parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},
        help='keyword arguments for model initialization passed as key1=value1 key2=value2')
    parser.add_argument('--process_outputs_function', choices=supported.process_outputs_functions,
                        default='multiclass_logits_to_pred')
    parser.add_argument('--log_dir', default='./logs')
    parser.add_argument('--seed', type=int, default=0)
    return parser.parse_args()


def _prepare_output_dir(log_dir):
    """Create a fresh ``<log_dir>/env_jtt`` directory and return its path.

    If the directory already exists, the user is asked for confirmation before it
    is deleted; with ``abort=True`` a negative answer makes click abort the script.
    """
    output_dir = os.path.join(log_dir, 'env_jtt')
    if Path(output_dir).exists():
        if click.confirm("Removing directory ? ({}).".format(output_dir), abort=True):
            shutil.rmtree(output_dir)
    os.mkdir(output_dir)
    return output_dir


def _load_dataset(config):
    """Instantiate the dataset named by ``config.dataset`` with ``get_img_idx=True``.

    Raises:
        ValueError: if ``config.dataset`` is not one of the supported names.
    """
    # name -> (human-readable name for the log line, dataset class)
    datasets = {
        'waterbirds': ('Waterbirds', WaterbirdsDataset),
        'celebA': ('CelebA', CelebADataset),
        'coco_on_places': ('COCO-on-Places', COCOonPlacesDataset),
    }
    if config.dataset not in datasets:
        raise ValueError(f"Dataset {config.dataset} not recognized")
    display_name, dataset_cls = datasets[config.dataset]
    LOGGER.info(f'Loading {display_name} dataset')
    return dataset_cls(root_dir=config.root_dir, get_img_idx=True)


def _build_train_loader(dataset, batch_size):
    """Return an unshuffled DataLoader over the 'train' split with eval-style transforms."""
    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # Standard ImageNet mean/std — presumably matching the training pipeline; confirm.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    train_data = dataset.get_subset('train', transform=test_transform)
    return DataLoader(
        train_data,
        shuffle=False,  # deterministic pass; every train example is visited exactly once
        sampler=None,
        pin_memory=True,
        num_workers=10,
        collate_fn=dataset.collate,
        batch_size=batch_size)


def _load_model(config, n_classes):
    """Build the model, load the best checkpoint for (dataset, seed), and return it on GPU in eval mode."""
    model = initialize_model(config, n_classes)
    weight_file = os.path.join(config.log_dir, f'{config.dataset}_seed:{config.seed}_epoch:best_model.pth')
    # Map every storage to CPU at load time, independent of the device it was saved from.
    checkpoint = torch.load(weight_file, map_location=lambda storage, loc: storage)
    # The algorithm's state dict is expected to prefix model parameters with 'model.'.
    # Strip only that leading prefix; a global str.replace would also mangle any
    # key that happens to contain 'model.' further in (e.g. a submodule named '*model').
    prefix = 'model.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint['algorithm'].items()
    }
    model.load_state_dict(state_dict)
    model.cuda().eval()
    return model


def _predict_train_split(model, train_loader, process_outputs_function):
    """Run ``model`` over ``train_loader``.

    Returns:
        (y_pred, y_true, ind_list): collated predictions (moved to CPU), collated
        labels, and the per-example indices yielded by the loader, all in loader order.
    """
    epoch_y_true, epoch_y_pred, ind_list = [], [], []
    with torch.no_grad():
        for data, y_true, indices in tqdm(train_loader):
            outputs = model(data.cuda())
            epoch_y_true.append(detach_and_clone(y_true))
            y_pred = detach_and_clone(outputs)
            if process_outputs_function is not None:
                y_pred = supported.process_outputs_functions[process_outputs_function](y_pred)
            epoch_y_pred.append(y_pred)
            ind_list.extend(indices.tolist())
    return collate_list(epoch_y_pred).cpu(), collate_list(epoch_y_true), ind_list


def _build_env_rows(dataset, ind_list, corrects):
    """Return one ``[img_id, img_filename, env]`` row for every image in ``dataset``.

    Images that were scored (their index appears in ``ind_list``) get ``env`` from
    ``corrects``; all others fall back to ``dataset.metadata_array[idx, 0]``
    (presumably the dataset's spurious attribute — confirm per dataset).
    """
    # Full-dataset index -> position in ``corrects``. A dict gives O(1) lookups
    # (a list membership test per image is quadratic overall) and does not rely
    # on the loader yielding indices in ascending order.
    position = {ds_idx: pos for pos, ds_idx in enumerate(ind_list)}
    rows = []
    for idx in tqdm(range(len(dataset._input_array))):
        pos = position.get(idx)
        if pos is not None:
            env = corrects[pos].item()
        else:
            env = dataset.metadata_array[idx, 0].item()
        rows.append([idx, dataset._input_array[idx], env])
    return rows


def main():
    """Infer correct-vs-error environment labels for the training set and save them as CSV.

    Loads the best checkpoint for ``--dataset``/``--seed`` from ``--log_dir``,
    predicts on the unshuffled train split, and writes
    ``<log_dir>/env_jtt/env_labels.csv`` with columns ``img_id, img_filename, env``.
    For training images ``env`` is 1 when the prediction equals the label and 0
    otherwise; for non-training images it is ``metadata_array[idx, 0]``.
    """
    config = _parse_args()
    output_dir = _prepare_output_dir(config.log_dir)

    dataset = _load_dataset(config)
    train_loader = _build_train_loader(dataset, config.batch_size)

    LOGGER.info(f'Loading {config.model} model')
    model = _load_model(config, dataset._n_classes)

    LOGGER.info('Looping through training dataset')
    epoch_y_pred, epoch_y_true, ind_list = _predict_train_split(
        model, train_loader, config.process_outputs_function)

    LOGGER.info('Creating environments correct vs errors in training set')
    corrects = (epoch_y_pred == epoch_y_true).int()
    if len(corrects) != len(ind_list):
        raise RuntimeError(
            f'Got {len(corrects)} predictions but {len(ind_list)} example indices')
    n_errors = len(corrects) - corrects.sum().item()
    LOGGER.info(f'Number of errors in training dataset: {n_errors}')

    LOGGER.info('Saving predicted environments for each training image')
    rows = _build_env_rows(dataset, ind_list, corrects)
    # newline='' lets the csv module control line endings (avoids blank rows on Windows).
    with open(os.path.join(output_dir, 'env_labels.csv'), mode='w', newline='') as csv_file:
        csv_writer = csv.writer(csv_file, delimiter=',')
        csv_writer.writerow(['img_id', 'img_filename', 'env'])
        csv_writer.writerows(rows)

if __name__ == '__main__':
    # Run the script only when executed directly, not when imported.
    main()