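""" Utilities for the CXRay (COVID chest X-ray) dataset.
    Provides the CXRayDataset class and the helpers used to generate the
    dataset on disk: image resizing, index.csv creation and split construction.
"""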


import torch
from torchvision import transforms
from torch.utils.data import Dataset
import torch.nn as nn
import numpy as np
import glob
import pandas as pd
import csv
from PIL import Image
import shutil
from tqdm.auto import tqdm
from multiprocessing import Pool
import argparse
import os
from collections import defaultdict
import random
import skimage.io


# ==============================================
# =           Dataset definition               =
# ==============================================

class CXRayDataset(Dataset):
    """Torch ``Dataset`` over one split directory of the CXRay dataset.

    Each split directory holds the image files plus a header-less
    ``index.csv`` whose rows are ``filename,C,N``.

    Args:
        root_dir: dataset root that contains the split sub-directories.
        split: a key of ``SPLIT_DIRS``; unknown names raise ``KeyError``.
        transform: optional callable applied to each loaded image.
    """

    # Public split name -> sub-directory of ``root_dir``.
    SPLIT_DIRS = {'train': 'train_split',
                  'val': 'val_split',
                  'val_split_c1': 'val_split_c1',
                  'val_split_c0': 'val_split_c0',
                  'train_split_c0': 'train_split_c0',
                  'train_split_c1': 'train_split_c1',
                  'test': 'test',
                  'test_c0': 'test_c0',
                  'test_c1': 'test_c1'}

    def __init__(self, root_dir, split, transform=None):
        subdir = self.SPLIT_DIRS[split]

        self.root_dir = os.path.join(root_dir, subdir)
        self.csvfile = os.path.join(self.root_dir, 'index.csv')
        # The index files are written without a header row; with pandas'
        # default header inference the first sample would become the column
        # names and silently disappear from the dataset.
        self.csv_data = pd.read_csv(self.csvfile, header=None)

        self.transform = transform

    def __len__(self):
        """Number of rows in this split's ``index.csv``."""
        return len(self.csv_data)

    def __getitem__(self, idx):
        """Return ``{'X': image, 'C': column-1 label, 'N': column-2 label}`` for row ``idx``.

        The image is read with ``skimage.io.imread`` and passed through
        ``transform`` when one was given.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_name = os.path.join(self.root_dir,
                                self.csv_data.iloc[idx, 0])
        X = skimage.io.imread(img_name)
        if self.transform:
            X = self.transform(X)
        C = self.csv_data.iloc[idx, 1]
        N = self.csv_data.iloc[idx, 2]
        return {'X': X, 'C': C, 'N': N}


class DiffusionScale:
    """Affine rescaling ``x -> 2 * x - 1`` (sends 0 to -1 and 1 to 1).

    Stateless callable, so it can be dropped into a transform pipeline.
    """

    def __init__(self):
        # Nothing to configure: the affine map is fixed.
        super().__init__()

    def __call__(self, x):
        doubled = 2 * x
        return doubled - 1.0
    

def transback(x):
    """Undo ``DiffusionScale``: affine map ``x -> 0.5 * x + 0.5`` (-1 -> 0, 1 -> 1)."""
    half = 0.5
    return x * half + half


# ===============================================
# =           Data Preparation Stuff            =
# ===============================================

# Resizing images
def resize_file(args):
    """Resize one image (converted to PIL mode 'L') into ``output_dir``.

    Written as a ``Pool.imap`` worker, hence the single tuple argument.

    Args:
        args: ``(src_file, output_dir, target_size)``; ``target_size`` is
            passed unchanged to ``torchvision.transforms.Resize``.

    Files whose extension is not .jpg/.jpeg/.png (case-insensitive) are
    copied unchanged unless source and destination are the same path;
    sub-directories are skipped. Images keep their basename, so resizing
    with ``output_dir`` equal to the source directory overwrites them.
    """
    src_file, output_dir, target_size = args
    # Compared against the lower-cased suffix, so only lower-case entries matter.
    img_exts = {'.jpeg', '.jpg', '.png'}

    if os.path.isdir(src_file):
        # glob('*') also yields sub-directories; copyfile would raise on them.
        return

    dst_file = os.path.join(output_dir, os.path.basename(src_file))
    if os.path.splitext(src_file)[-1].lower() not in img_exts:
        # Normalize before comparing so differently spelled paths to the same
        # location are not "copied onto themselves" (shutil.SameFileError).
        if os.path.abspath(src_file) != os.path.abspath(dst_file):
            shutil.copyfile(src_file, dst_file)
        return

    resize = transforms.Resize(target_size)
    # Close the source handle explicitly: workers process many files, and for
    # in-place resizing the same path is written right after being read.
    with Image.open(src_file) as img:
        dst_img = resize(img.convert('L'))
    dst_img.save(dst_file)


def resize_img_dir(base_dir, target_size, output_dir=None, num_workers=8):
    """Resize every top-level entry of ``base_dir`` in parallel via ``resize_file``.

    Args:
        base_dir: directory whose entries are processed (``~`` is expanded).
        target_size: size forwarded to ``torchvision.transforms.Resize``.
        output_dir: destination directory (created if needed); ``None`` means
            write back into ``base_dir``.
        num_workers: number of worker processes.
    """
    if output_dir is None:
        output_dir = base_dir
    base_dir = os.path.expanduser(base_dir)
    output_dir = os.path.expanduser(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    all_files = sorted(glob.glob(os.path.join(base_dir, '*')))
    pool_input = [(src, output_dir, target_size) for src in all_files]

    with Pool(processes=num_workers) as pool:
        print("Resizing all images in %s..." % base_dir)
        # imap is lazy: draining the iterator drives the work and re-raises
        # any exception raised inside a worker.
        for _ in tqdm(pool.imap(resize_file, pool_input), total=len(pool_input)):
            pass
        print("...done resizing, saved in %s" % output_dir)
    





# Generating index CSV file
def _find_x9a_label_file(root_dir):
    """Return the ``*COVIDx9A.txt`` label file in ``root_dir`` (first in sorted order)."""
    candidates = sorted(glob.glob(os.path.join(root_dir, '*COVIDx9A.txt')))
    if not candidates:
        raise FileNotFoundError("No *COVIDx9A.txt label file found in %s" % root_dir)
    if len(candidates) > 1:
        print("Warning: several label files found, using %s" % candidates[0])
    return candidates[0]


def _read_x9a_labels(x9a_txt):
    """Yield ``(img_name, diagnosis)`` from each non-blank line of a COVIDx9A file.

    Lines must have exactly four space-separated fields; the second is the
    image path, from which the first path component is removed (presumably a
    split folder prefix -- confirm against the label file format).
    """
    with open(x9a_txt) as fh:
        for line in fh:
            if line.strip() == '':
                continue
            fields = line.split(' ')
            if len(fields) != 4:
                raise ValueError("Malformed line in %s: %r" % (x9a_txt, line))
            _, img_name, diagnosis, _ = fields
            yield '/'.join(img_name.split('/')[1:]), diagnosis


def _label_to_cn(x9a_label):
    """Map a COVIDx9A diagnosis to ``(C, N)``; raise ``ValueError`` if unknown."""
    mapping = {'COVID-19': (1, 1),
               'pneumonia': (0, 1),
               'normal': (0, 0)}
    try:
        return mapping[x9a_label]
    except KeyError:
        raise ValueError("Unknown label: %s" % x9a_label) from None


def _inject_label_noise(triples, random_seed, noise_fraction):
    """Relabel a seeded random ``noise_fraction`` of the (1, 1) triples as (1, 0)."""
    groups = defaultdict(list)
    for img, c, n in triples:
        groups[(c, n)].append(img)
    # Sort first: the incoming order follows glob(), which depends on the
    # filesystem, so the same seed would otherwise pick different images.
    positives = sorted(groups[(1, 1)])
    random.seed(random_seed)
    random.shuffle(positives)
    k = int(len(positives) * noise_fraction)
    groups[(1, 0)] = positives[:k]
    groups[(1, 1)] = positives[k:]
    return [(img, c, n) for (c, n), imgs in groups.items() for img in imgs]


def get_triples(root_dir, random_seed=1234, noise_fraction=0.05):
    """Build ``(img_filename, C, N)`` triples for the images in ``root_dir``.

    ``root_dir`` (``~`` is expanded) must hold the image files and a
    ``*COVIDx9A.txt`` label file (e.g. ``train_COVIDx9A.txt``).

    Procedure:
    1. Collect the names of all entries in ``root_dir``.
    2. Attach the COVIDx9A diagnosis of every entry listed in the label file
       (a later line for the same image overrides an earlier one).
    3. Drop entries that received no diagnosis.
    4. Map diagnosis -> (C, N): 'COVID-19' -> (1, 1), 'pneumonia' -> (0, 1),
       'normal' -> (0, 0).
    5. After a seeded shuffle, relabel ``int(noise_fraction * #(1, 1))`` of
       the (1, 1) items as (1, 0).

    Returns:
        The triples, sorted (i.e. by filename).

    Raises:
        FileNotFoundError: no ``*COVIDx9A.txt`` file in ``root_dir``.
        ValueError: a malformed label line or an unknown diagnosis.
    """
    root_dir = os.path.expanduser(root_dir)
    x9a_txt = _find_x9a_label_file(root_dir)

    # 1. Every entry starts out unlabeled (None).
    labels = {os.path.basename(path): None
              for path in glob.glob(os.path.join(root_dir, '*'))}
    print("Num image files", len(labels))

    # 2. Attach diagnoses; entries in the label file with no matching file are counted.
    skipped_files = 0
    for img_name, diagnosis in _read_x9a_labels(x9a_txt):
        if img_name not in labels:
            skipped_files += 1
            continue
        labels[img_name] = diagnosis

    # 3. Keep only labeled entries.
    labeled = {name: lab for name, lab in labels.items() if lab is not None}
    print("Num items: %s | Num skipped: %s" % (len(labeled), skipped_files))

    # 4. Diagnosis -> (C, N).
    triples = [(name,) + _label_to_cn(lab) for name, lab in labeled.items()]

    # 5. Inject label noise into N.
    return sorted(_inject_label_noise(triples, random_seed, noise_fraction))

def write_csv_from_triples(root_dir, triples):
    """Write ``triples`` to ``root_dir/index.csv`` without a header row.

    Each ``(img, c, n)`` becomes one ``img,c,n`` row, terminated with the
    csv module's default CRLF line ending. An existing file is overwritten.
    """
    rows = ([img, str(c), str(n)] for img, c, n in triples)
    with open(os.path.join(root_dir, 'index.csv'), 'w', newline='') as handle:
        csv.writer(handle, delimiter=',').writerows(rows)


def full_csv_idx(root_dir, random_seed, noise_fraction):
    """Label the images in ``root_dir`` and write the result to ``root_dir/index.csv``.

    Args:
        root_dir: image directory containing the ``*COVIDx9A.txt`` label file.
        random_seed: seed for choosing which (C=1, N=1) labels become noisy.
        noise_fraction: fraction of (C=1, N=1) labels turned into (C=1, N=0).
    """
    # Expand once here: get_triples expands '~' itself but
    # write_csv_from_triples does not, so an unexpanded path would be read
    # from the home directory yet written under a literal '~' directory.
    root_dir = os.path.expanduser(root_dir)
    triples = get_triples(root_dir, random_seed=random_seed, noise_fraction=noise_fraction)
    write_csv_from_triples(root_dir, triples)



def _write_valsplit_part(lines, src_dir, dst_dir):
    """Write ``lines`` to ``dst_dir/index.csv`` and symlink each listed file from ``src_dir``."""
    with open(os.path.join(dst_dir, 'index.csv'), 'w') as index:
        for line in lines:
            index.write(line)
            filename = line.split(',')[0]
            # Use an absolute target: a relative symlink target is resolved
            # against the link's own directory, not the CWD, so a relative
            # root_dir would otherwise create dangling links.
            os.symlink(os.path.abspath(os.path.join(src_dir, filename)),
                       os.path.join(dst_dir, filename))


def valsplit(root_dir, split_seed, split_frac):
    """Randomly split ``root_dir/train`` into ``train_split`` and ``val_split``.

    Rows of ``root_dir/train/index.csv`` are shuffled with ``split_seed``; the
    first ``round(split_frac * n_rows)`` go to ``train_split``, the rest to
    ``val_split``. Each split directory receives its own ``index.csv`` and
    symlinks to the listed files in ``train``.

    Raises:
        FileExistsError: a target symlink already exists (e.g. on a re-run).
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    trainsplit_dir = os.path.join(root_dir, 'train_split')
    valsplit_dir = os.path.join(root_dir, 'val_split')

    os.makedirs(trainsplit_dir, exist_ok=True)
    os.makedirs(valsplit_dir, exist_ok=True)

    with open(os.path.join(train_dir, 'index.csv'), 'r') as fh:
        # Drop blank lines and ensure every row ends with a newline, so the
        # last row (which may lack one) does not merge with its neighbour
        # after shuffling.
        lines = [line if line.endswith('\n') else line + '\n'
                 for line in fh if line.strip()]

    random.seed(split_seed)
    random.shuffle(lines)
    train_size = round(split_frac * len(lines))

    _write_valsplit_part(lines[:train_size], train_dir, trainsplit_dir)
    _write_valsplit_part(lines[train_size:], train_dir, valsplit_dir)


def split_by_c(root_dir, split):
    """Partition one split directory by the C value in its ``index.csv``.

    ``split`` is 'train', 'val' or 'test'; 'train' and 'val' refer to the
    ``train_split`` / ``val_split`` directories. Rows with C == 0 go to
    ``<split dir>_c0`` and rows with C == 1 to ``<split dir>_c1``; each gets
    its own ``index.csv`` plus symlinks to the listed files.

    Raises:
        ValueError: unknown ``split`` or a C value other than 0 / 1.
        FileExistsError: a target symlink already exists (e.g. on a re-run).
    """
    split_dirs = {'train': 'train_split', 'val': 'val_split', 'test': 'test'}
    if split not in split_dirs:
        raise ValueError("split must be one of %s, got %r" % (sorted(split_dirs), split))

    root_dir = os.path.expanduser(root_dir)
    split = split_dirs[split]
    base_dir = os.path.join(root_dir, split)
    out_dirs = {0: os.path.join(root_dir, split + '_c0'),
                1: os.path.join(root_dir, split + '_c1')}
    for out_dir in out_dirs.values():
        os.makedirs(out_dir, exist_ok=True)

    with open(os.path.join(base_dir, 'index.csv'), 'r') as fh:
        lines = [line for line in fh if line.strip()]

    with open(os.path.join(out_dirs[0], 'index.csv'), 'w') as c0_idx:
        with open(os.path.join(out_dirs[1], 'index.csv'), 'w') as c1_idx:
            indices = {0: c0_idx, 1: c1_idx}
            for line in lines:
                fields = line.split(',')
                filename, c_val = fields[0], fields[1]
                c = int(c_val)
                if c not in indices:
                    raise ValueError("Invalid C value: %s" % c_val)
                indices[c].write(line)
                # Absolute target: relative symlink targets resolve against
                # the link's directory, so a relative root_dir would dangle.
                os.symlink(os.path.abspath(os.path.join(base_dir, filename)),
                           os.path.join(out_dirs[c], filename))






# ======================================
# =           MAIN                     =
# ======================================

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Prepare the CXRay dataset: resize images, write index.csv, build splits.')
    # Required so that forgetting --fxn is an error instead of a silent no-op.
    parser.add_argument('--fxn', type=str, required=True,
                        choices=['write_csv', 'resize', 'valsplit', 'split_by_c'],
                        help='which preparation step to run')

    # Shared arguments
    parser.add_argument('--root_dir', type=str, required=True,
                        help='dataset directory; for write_csv it must contain the images '
                             'and a *COVIDx9A.txt label file')

    # write_csv arguments
    parser.add_argument('--random_seed', type=int, default=1234)
    parser.add_argument('--noise_fraction', type=float, default=0.05)

    # resize arguments
    parser.add_argument('--hw', type=int, default=128, help='resizes image to size (hw, hw)')
    parser.add_argument('--output_dir', type=str, help='where to put the resized files (same place if ==None)')
    parser.add_argument('--num_workers', type=int, default=8, help='num processes to spawn for reshaping')

    # valsplit arguments
    parser.add_argument('--split_seed', type=int, default=1234)
    parser.add_argument('--split_frac', type=float, default=0.6749)  # 20k in trainset

    # split_by_c arguments
    parser.add_argument('--split', type=str, choices=['train', 'val', 'test'],
                        help='split directory to partition by C (required for split_by_c)')

    args = parser.parse_args()
    print("ROOT DIR", args.root_dir)

    if args.fxn == 'write_csv':
        full_csv_idx(args.root_dir, args.random_seed, args.noise_fraction)
    elif args.fxn == 'resize':
        resize_img_dir(args.root_dir, (args.hw, args.hw), output_dir=args.output_dir, num_workers=args.num_workers)
    elif args.fxn == 'valsplit':
        valsplit(args.root_dir, args.split_seed, args.split_frac)
    elif args.fxn == 'split_by_c':
        if args.split is None:
            parser.error('--split is required when --fxn split_by_c')
        split_by_c(args.root_dir, args.split)
    




