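"""
Helpers for converting WILDS datasets into the format used by the EIIL code
base and for inferring environments with EIIL.

Adapted from https://github.com/ecreager/eiil
"""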

from wilds import get_dataset
from wilds.datasets.waterbirds_dataset import WaterbirdsDataset
from wilds.datasets.celebA_dataset import CelebADataset

from collections import defaultdict

from tqdm import tqdm

import torchvision.transforms as transforms

from torch.utils.data import DataLoader
import pathlib

import numpy as np

import torch

from torch import autograd
from torch import nn
from torch import optim

import pandas as pd
import pickle

# Registry of the dataset names accepted by the loader functions in this
# module, mapped to the WILDS dataset classes that are instantiated for them.
DATASET_DICT = dict(
    waterbirds=WaterbirdsDataset,
    celebA=CelebADataset,
)


def get_transformed_join_dataset_to_eiil(dataset_name, env_ix_metata=0, dir_path=None):
    """
    Transform the training split of a WILDS dataset into per-environment
    tensors for the EIIL code base.

    Input
        dataset_name: key of DATASET_DICT ('waterbirds' or 'celebA')
        env_ix_metata: column index in the dataset's metadata array whose
            values define the environments
        dir_path: optional directory; if given, the result is also pickled to
            ``<dir_path>/<dataset_name>_eiil.pickle`` as a dict keyed by
            environment value

    Returns
        list with one dict per environment value (ascending order), each with
        keys 'images', 'images_id', 'labels' (concatenated tensors; labels are
        float with a trailing singleton dimension added) and
        'images_filename' (list of filenames)
    """
    dataset_class = DATASET_DICT[dataset_name]
    dataset = dataset_class(download=True, get_img_idx=True)
    # Deterministic preprocessing only: resize + ImageNet mean/std normalisation.
    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    train_data = dataset.get_subset('train', transform=test_transform)
    train_loader = DataLoader(
        train_data,
        shuffle=False,
        sampler=None,
        collate_fn=dataset.collate,
        batch_size=128)

    # Enumerate environments from the SAME metadata column that is used for
    # the per-batch filtering below; otherwise a non-zero env_ix_metata would
    # be matched against values taken from the wrong column.
    env_column = dataset.metadata_array[:, env_ix_metata]
    envs = np.unique(np.asarray(env_column)).tolist()
    train_envs = {env: defaultdict(list) for env in envs}
    for img, label, img_ix in tqdm(train_loader):
        batch_envs = dataset.metadata_array[img_ix, env_ix_metata]
        for env in envs:
            filter_ix = batch_envs == env
            filtered_img_ix = img_ix[filter_ix]
            # Index with a NumPy array so the result is always a 1-D array of
            # filenames, even when exactly one (or zero) examples match.
            filtered_img_name = dataset._input_array[filtered_img_ix.cpu().numpy()]
            train_envs[env]['images'].append(img[filter_ix])
            train_envs[env]['images_id'].append(filtered_img_ix)
            train_envs[env]['labels'].append(label[filter_ix].float()[:, None])
            train_envs[env]['images_filename'].append(filtered_img_name)

    # Collapse the per-batch chunks into one tensor (or one flat list of
    # filenames) per key.
    for sub_d in train_envs.values():
        for k, chunks in sub_d.items():
            if k == 'images_filename':
                sub_d[k] = [name for chunk in chunks for name in chunk]
            else:
                sub_d[k] = torch.cat(chunks, dim=0)

    if dir_path is not None:
        file_path = pathlib.Path(dir_path) / f'{dataset_name}_eiil.pickle'
        with open(file_path, 'wb') as handle:
            pickle.dump(train_envs, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return list(train_envs.values())


def get_transformed_waterbirds_to_eiil(dataset_name, env_ix_metata=0, dir_path=None):
    """
    Transform the training split of a WILDS dataset into a single pooled
    environment (no environment split) for the EIIL code base.

    Input
        dataset_name: key of DATASET_DICT; selects the dataset class to load
        env_ix_metata: unused; kept so the signature matches
            get_transformed_join_dataset_to_eiil
        dir_path: optional directory; if given, the result is also pickled to
            ``<dir_path>/notjoin_<dataset_name>_eiil.pickle``

    Returns
        defaultdict with keys 'images', 'images_id', 'labels' (concatenated
        tensors; labels are float with a trailing singleton dimension added)
        and 'images_filename' (list of filenames)
    """
    # Load the class selected by dataset_name so that the pickle name and its
    # contents always refer to the same dataset.
    dataset_class = DATASET_DICT[dataset_name]
    dataset = dataset_class(download=True, get_img_idx=True)
    # Deterministic preprocessing only: resize + ImageNet mean/std normalisation.
    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    train_data = dataset.get_subset('train', transform=test_transform)
    train_loader = DataLoader(
        train_data,
        shuffle=False,
        sampler=None,
        collate_fn=dataset.collate,
        batch_size=128)

    train_envs = defaultdict(list)
    for img, label, img_ix in tqdm(train_loader):
        # Index with a NumPy array so the result is always a 1-D array of
        # filenames, even for a final batch holding a single example.
        img_names = dataset._input_array[img_ix.cpu().numpy()]
        train_envs['images'].append(img)
        train_envs['images_id'].append(img_ix)
        train_envs['labels'].append(label.float()[:, None])
        train_envs['images_filename'].append(img_names)

    # Collapse the per-batch chunks into one tensor (or one flat list of
    # filenames) per key.
    for k, chunks in train_envs.items():
        if k == 'images_filename':
            train_envs[k] = [name for chunk in chunks for name in chunk]
        else:
            train_envs[k] = torch.cat(chunks, dim=0)

    if dir_path is not None:
        file_path = pathlib.Path(dir_path) / f'notjoin_{dataset_name}_eiil.pickle'
        with open(file_path, 'wb') as handle:
            pickle.dump(train_envs, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return train_envs

def transform_envs_to_metadatafile(envs, dir_name, file_name='eiil_env_labels.csv',
                                   dataset_name='waterbirds'):
    """
    Write the environment assignment of every dataset example to a CSV file.

    Examples listed in ``envs`` get the position of the env dict they appear
    in; every other example of the dataset falls back to the value in column 0
    of the dataset's metadata array.

    Input
        envs: list of dict, each with 'images_id' (integer indices into the
            dataset) and 'images_filename' (filenames aligned with 'images_id')
        dir_name: directory in which the CSV is written
        file_name: str, name of the CSV file
        dataset_name: key of DATASET_DICT selecting the dataset whose examples
            are listed (default 'waterbirds')

    Returns
        the DataFrame that was written: columns 'img_id' (1-based, sorted),
        'img_filename' and 'env'
    """
    dataset = DATASET_DICT[dataset_name](download=True, get_img_idx=True)

    img_id_list = []
    img_filename_list = []
    env_list = []
    for env_index, env in enumerate(envs):
        ids = [int(x) for x in env['images_id']]
        img_id_list += ids
        img_filename_list += env['images_filename']
        env_list += [env_index] * len(ids)

    # Set membership keeps this pass linear; testing against the growing list
    # was quadratic in the dataset size.
    assigned = set(img_id_list)
    for idx, filename in enumerate(dataset._input_array):
        if idx not in assigned:
            img_id_list.append(idx)
            img_filename_list.append(filename)
            env_list.append(dataset.metadata_array[idx, 0].item())

    # Build the frame in one go instead of filling an empty, object-typed
    # frame column by column.
    df = pd.DataFrame({
        'img_id': img_id_list,
        'img_filename': img_filename_list,
        'env': env_list,
    })
    df = df.sort_values('img_id')
    df['img_id'] += 1  # ids are written 1-based
    df = df.reset_index(drop=True)

    file_path = pathlib.Path(dir_name) / file_name
    df.to_csv(file_path, index=False)
    print(f'##### EEIL env saved at {file_path} #####')

    return df


def nll(logits, y, reduction='mean'):
    """Binary cross-entropy of raw (pre-sigmoid) ``logits`` against targets ``y``.

    ``reduction`` is forwarded unchanged to PyTorch ('mean', 'sum' or 'none').
    """
    bce_with_logits = nn.functional.binary_cross_entropy_with_logits
    return bce_with_logits(logits, y, reduction=reduction)

def mean_accuracy(logits, y):
    """Fraction of examples whose thresholded prediction matches ``y``.

    A logit strictly greater than zero predicts class 1, otherwise class 0;
    a prediction counts as correct when it lies within 1e-2 of the target.
    Returns a 0-dim float tensor.
    """
    predictions = logits.gt(0.).float()
    hits = (predictions - y).abs().lt(1e-2)
    return hits.float().mean()

def penalty(logits, y):
    """Squared gradient of the mean BCE risk w.r.t. a dummy logit scale.

    The risk ``nll(logits * scale, y)`` is differentiated with respect to a
    scalar ``scale`` evaluated at 1. ``create_graph=True`` keeps the graph so
    the returned 0-dim tensor can itself be back-propagated.

    ``scale`` is created on ``logits.device`` rather than hard-coded to CUDA,
    so the function works for CPU tensors and for logits on any GPU.
    """
    scale = torch.tensor(1., device=logits.device).requires_grad_()
    loss = nll(logits * scale, y)
    grad = autograd.grad(loss, [scale], create_graph=True)[0]
    return torch.sum(grad ** 2)

def split_data_opt(envs, model, n_steps=20000, n_samples=-1, lr=0.01,
                   batch_size=None, join=True, no_tqdm=False, device='cuda'):
    """Learn a soft two-way environment assignment and split the data with it.

    A per-example logit ``env_w`` is optimised with Adam to maximise the mean,
    over the two soft environments weighted by ``sigmoid(env_w)`` and
    ``1 - sigmoid(env_w)``, of the squared gradient of the weighted risk with
    respect to a dummy logit scale. Examples are then hard-assigned by
    thresholding ``sigmoid(env_w)`` at 0.5.

    Input
        envs: if ``join`` is True, a list of env dicts; the first two entries
            are pooled as training data and the last entry is appended to the
            result unchanged (assumed to be the test env). If ``join`` is
            False, a single env dict.
        model: reference classifier; only its forward pass is used and its
            outputs are treated as constants.
        n_steps: number of Adam steps on ``env_w``.
        n_samples: when joining, keep at most this many examples from each of
            the two pooled envs; a negative value (default) keeps all of them.
        lr: Adam learning rate.
        batch_size: if truthy, run ``model`` over the images in chunks of this
            size instead of in a single forward pass.
        join: see ``envs``.
        no_tqdm: disable the progress bars.
        device: device used for the forward pass and the optimisation
            (default 'cuda', equivalent to ``.cuda()``).

    Returns
        list of env dicts: the examples with ``sigmoid(env_w) > .5``, then the
        rest, followed by ``envs[-1]`` when ``join`` is True.

    Raises
        ValueError: if ``join`` is False and ``envs`` is not a dict.
    """
    if join:
        # Pool the first two entries of `envs` (assumed to be the training envs).
        print('pooling envs')
        # A negative n_samples means "keep every example"; slicing with -1
        # would silently drop the last example of each env.
        stop = None if n_samples is None or n_samples < 0 else n_samples
        joined_train_envs = dict()
        for k in envs[0].keys():
            first, second = envs[0][k], envs[1][k]
            if k == 'images_filename':
                # Keep filenames aligned with the pooled tensors.
                joined_train_envs[k] = list(first[:stop]) + list(second[:stop])
            elif torch.is_tensor(first) and first.numel() > 1:
                joined_train_envs[k] = torch.cat((first[:stop], second[:stop]), 0)
            # Anything else (e.g. scalars stored during training) is omitted.
        print('size of pooled envs: %d' % len(joined_train_envs['images']))
    else:
        if not isinstance(envs, dict):
            raise ValueError(('When join=False, first argument should be a dict'
                              ' corresponding to the only environment.'
                              ))
        print('splitting data from single env of size %d' % len(envs['images']))
        joined_train_envs = envs

    images = joined_train_envs['images']
    # The reference model's outputs are constants for the optimisation below,
    # so no autograd graph is recorded for the forward pass.
    with torch.no_grad():
        if batch_size:
            logits = torch.cat([
                model(images[start:start + batch_size].to(device))
                for start in range(0, len(images), batch_size)
            ])
        else:
            logits = model(images.to(device))

    scale = torch.tensor(1., device=device).requires_grad_()
    loss = nll(logits * scale, joined_train_envs['labels'].to(device), reduction='none')

    # Drawn from the CPU generator and then moved, so the initial assignment
    # depends only on the CPU RNG state regardless of `device`.
    env_w = torch.randn(len(logits)).to(device).requires_grad_()
    optimizer = optim.Adam([env_w], lr=lr)

    npenalty = None  # stays None when n_steps == 0
    with tqdm(total=n_steps, position=1, bar_format='{desc}', desc='AED Loss: ', disable=no_tqdm) as desc:
        for _ in tqdm(range(n_steps), disable=no_tqdm):
            # penalty for soft env a
            lossa = (loss.squeeze() * env_w.sigmoid()).mean()
            grada = autograd.grad(lossa, [scale], create_graph=True)[0]
            penaltya = torch.sum(grada ** 2)
            # penalty for soft env b (complementary weights)
            lossb = (loss.squeeze() * (1 - env_w.sigmoid())).mean()
            gradb = autograd.grad(lossb, [scale], create_graph=True)[0]
            penaltyb = torch.sum(gradb ** 2)
            # negate so that minimising with Adam maximises the penalty
            npenalty = - torch.stack([penaltya, penaltyb]).mean()
            optimizer.zero_grad()
            # `loss` (and its graph through `scale`) is reused every step.
            npenalty.backward(retain_graph=True)
            optimizer.step()
            desc.set_description('AED Loss: %.8f' % npenalty.cpu().item())

    if npenalty is not None:
        print('Final AED Loss: %.8f' % npenalty.cpu().item())

    # Hard assignment of every pooled example to env0 or env1.
    # NOTE: envs include original data indices for qualitative investigation.
    weights = env_w.sigmoid()
    new_envs = []
    for mask in (weights > .5, weights <= .5):
        new_env = dict()
        for k, v in joined_train_envs.items():
            if k == 'images_filename':
                new_env[k] = list(np.array(v)[mask.cpu().numpy()])
            else:
                # Move the mask to v's device: tensors kept on the CPU may not
                # be indexable with a CUDA mask in recent PyTorch releases.
                new_env[k] = v[mask.to(v.device)]
        new_envs.append(new_env)
    print('size of env0: %d' % len(new_envs[0]['images']))
    print('size of env1: %d' % len(new_envs[1]['images']))

    if join:
        # NOTE: assumes the caller passes the test env as the last entry of
        # `envs` when join=True.
        new_envs.append(envs[-1])
        print('size of env2: %d' % len(new_envs[2]['images']))

    return new_envs

