import os
import torch
from data.influence_util import get_dataset, IdxDataset
from torch.utils.data import DataLoader
import jax

from typing import Any
from absl import app, flags
from flax.training import train_state, checkpoints
import numpy as np
import optax

from model import resnet

import os
import random

from influence.estimate import compute_influence
import pandas as pd


# Command-line flags for the influence-estimation script.
# Flag names, types and defaults are part of the CLI; only help texts that
# described the wrong flag (copy-paste leftovers) or were misspelled changed.

# additional hyper-parameters
flags.DEFINE_enum(
    'dataset', 'cifar10c',
    ['cmnist', 'cifar10c', 'bffhq', 'cifar10_lff', 'waterbird'],
    help='training dataset')
flags.DEFINE_enum(
    'model', 'resnet', ['resnet'],
    help='network architecture')
flags.DEFINE_integer(
    'seed', 0,
    help='random number seed')
flags.DEFINE_integer(
    'test_batch_size_total', 500,
    help='total batch size (not device-wise) for evaluation')
flags.DEFINE_integer(
    'target_epoch', 5,
    help='target_epoch')

# Dataset Spec
flags.DEFINE_string(
    "percent", "0.5pct",
    help="percentage of conflict")
flags.DEFINE_integer(
    'num_workers', 4,
    help='workers number')
flags.DEFINE_bool(
    "use_type0", False,
    help="whether to use type 0 CIFAR10C")
flags.DEFINE_bool(
    "use_type1", False,
    help="whether to use type 1 CIFAR10C")
flags.DEFINE_integer(
    "target_attr_idx", 0,
    help="target_attr_idx")
flags.DEFINE_integer(
    'batch_size', 30,
    help='train_batch_size')
# Used as the batch size of the second, shuffled loader (train_loader_hess).
flags.DEFINE_integer(
    'hess_batch_size', 2000,
    help='batch size of the Hessian data loader (train_loader_hess)')
# Forwarded to get_dataset as data_dir.
flags.DEFINE_string(
    "data_dir", "../dataset",
    help="root directory of the datasets")

# Optimization Spec
flags.DEFINE_float(
    'lr', 0.001,
    help='learning rate')

# tunable hparams for generalization
flags.DEFINE_float(
    'weight_decay', 0,
    help='l2 regularization coefficient')
flags.DEFINE_integer(
    'train_batch_size_total', 1000,
    help='total batch size (not device-wise) for training')

FLAGS = flags.FLAGS

class TrainState(train_state.TrainState):
    """Flax ``TrainState`` extended with one extra ``batch_stats`` field."""

    # Populated in init_state from variables['batch_stats'] of the
    # initialised network; typed as Any, so no structure is enforced here.
    batch_stats: Any

def init_state(rng, batch, num_classes):
    """Create a freshly initialised ``TrainState`` for a ResNet18.

    Args:
        rng: PRNG key forwarded to the network's ``init``.
        batch: example input forwarded to ``init`` together with ``rng``.
        num_classes: passed to ``resnet.ResNet18`` as ``num_classes``.

    Returns:
        A ``TrainState`` carrying the initial ``params`` and ``batch_stats``
        collections and an Adam optimizer with learning rate ``FLAGS.lr``.
    """
    model = resnet.ResNet18(num_classes=num_classes)
    init_vars = model.init(rng, batch)
    initial_params = init_vars['params']
    initial_batch_stats = init_vars['batch_stats']

    # Adam stays wrapped in optax.chain so the optimizer-state structure is
    # exactly the one the original code produced.
    optimizer = optax.chain(optax.adam(learning_rate=FLAGS.lr))

    return TrainState.create(
        apply_fn=model.apply,
        params=initial_params,
        tx=optimizer,
        batch_stats=initial_batch_stats,
    )

def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and pin cuDNN to deterministic mode."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _load_target_labels(dataset_name, dataset):
    """Return the per-sample target labels of ``dataset`` as a ``torch.LongTensor``."""
    if dataset_name == 'cifar10_lff':
        return torch.LongTensor(dataset.query_attr)
    if dataset_name == 'waterbird':
        return torch.LongTensor(dataset.y_array)
    # For the remaining datasets each entry of ``dataset.data`` is a string
    # (presumably a file path - confirm against get_dataset) whose
    # second-to-last '_'-separated field is the integer label.
    return torch.LongTensor(
        [int(entry.split('_')[-2]) for entry in dataset.data]
    )


def main(_):
    """Restore a trained network and write per-sample influence scores to disk.

    Reads its configuration from ``FLAGS``. The checkpoint is looked up in
    ``./res/<dataset>_<percent>/<model>_<lr>_<train_batch_size_total>_<seed>/networks/<target_epoch>``
    and the result is written next to it as ``influence_train_df.tsv`` with
    the columns ``index``, ``true_label``, ``bias_label`` and
    ``self_influence``.

    Args:
        _: positional argument supplied by ``app.run``; unused.

    Raises:
        FileNotFoundError: if no checkpoint could be restored from the
            checkpoint directory.
    """
    _seed_everything(FLAGS.seed)

    data2preprocess = {
        'cmnist': True,
        'cifar10c': True,
        'cifar10_lff': True,
        'bffhq': True,
        'waterbird': True,
    }
    data2img_shape = {
        'cmnist': (32, 32, 3),
        'cifar10c': (32, 32, 3),
        'cifar10_lff': (32, 32, 3),
        'bffhq': (224, 224, 3),
        'waterbird': (224, 224, 3),
    }

    # torch data loader
    train_dataset = get_dataset(
        FLAGS.dataset,
        data_dir=FLAGS.data_dir,
        dataset_split="train",
        transform_split="train",
        percent=FLAGS.percent,
        use_preprocess=data2preprocess[FLAGS.dataset],
        use_type0=FLAGS.use_type0,
        use_type1=FLAGS.use_type1,
    )

    # The number of classes is derived from the largest label present.
    train_target_attr = _load_target_labels(FLAGS.dataset, train_dataset)
    num_classes = torch.max(train_target_attr).item() + 1
    train_dataset = IdxDataset(train_dataset)

    # Two loaders over the same dataset: an ordered one with the small batch
    # size and a shuffled one with the (larger) hess_batch_size.
    train_loader = DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        shuffle=False,
        num_workers=FLAGS.num_workers,
        pin_memory=True,
        drop_last=False,
    )
    train_loader_hess = DataLoader(
        train_dataset,
        batch_size=FLAGS.hess_batch_size,
        shuffle=True,
        num_workers=FLAGS.num_workers,
        pin_memory=True,
        drop_last=False,
    )

    hparams = '_'.join(map(str, [
        FLAGS.model,
        FLAGS.lr,
        FLAGS.train_batch_size_total,
        FLAGS.seed,
    ]))
    res_dir = f'./res/{FLAGS.dataset}_{FLAGS.percent}/' + hparams
    cur_directory = os.path.join(res_dir, 'networks', str(FLAGS.target_epoch))

    # define pseudo-random number generator
    _, init_rng = jax.random.split(jax.random.PRNGKey(FLAGS.seed))

    # initialize network and optimizer; the random input only provides the
    # shape (1, H, W, C) that the network is initialised with.
    initial_state = init_state(
        init_rng,
        jax.random.normal(init_rng, (1, *data2img_shape[FLAGS.dataset])),
        num_classes,
    )
    state = checkpoints.restore_checkpoint(cur_directory, initial_state)
    # restore_checkpoint hands back the target object itself when it finds
    # nothing to restore. Without this check the script would silently score
    # a randomly initialised network.
    if state is initial_state:
        raise FileNotFoundError(
            f'No checkpoint could be restored from {cur_directory!r}; '
            'check --dataset, --percent, --model, --lr, '
            '--train_batch_size_total, --seed and --target_epoch.'
        )

    output = compute_influence(
        state, train_loader, train_loader_hess, num_classes, len(train_dataset)
    )

    # One row per sample; np.stack promotes all four columns to a common dtype.
    columns = np.stack(
        [
            output['index'],
            output['true_label'],
            output['bias_label'],
            output['influence'],
        ],
        axis=1,
    )
    df = pd.DataFrame(
        columns,
        columns=['index', 'true_label', 'bias_label', 'self_influence'],
    )
    df.to_csv(
        os.path.join(cur_directory, 'influence_train_df.tsv'),
        sep='\t',
        index=False,
    )


# Script entry point: hand control to absl's app.run, which invokes main.
if __name__ == "__main__":
    app.run(main)