import os
import torch
from data.influence_util import get_dataset, IdxDataset
from torch.utils.data import DataLoader
import jax

from typing import Any
from functools import partial
from absl import app, flags

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from flax.training import train_state, checkpoints
from flax.jax_utils import replicate
import numpy as np
import optax
from tqdm import tqdm

from utils import ckpt, metrics
from model import resnet
import einops
import random

# ---------------------------------------------------------------------------
# Command-line flags (absl). Grouped as: run setup, dataset, optimisation,
# and regularisation / loss hyper-parameters.
# ---------------------------------------------------------------------------

# Run setup
flags.DEFINE_integer(
    'epoch_num', 5,
    help='epoch number of pre-training',
)
flags.DEFINE_enum(
    'dataset', 'cifar10c',
    ['cmnist', 'cifar10c', 'bffhq', 'cifar10_lff', 'waterbird'],
    help='training dataset',
)
flags.DEFINE_enum(
    'model', 'resnet',
    ['resnet'],
    help='network architecture',
)
flags.DEFINE_integer(
    'seed', 0,
    help='random number seed',
)
flags.DEFINE_bool(
    'eval', False,
    help='do not training',
)
flags.DEFINE_integer(
    'test_batch_size_total', 1000,
    help='total batch size (not device-wise) for evaluation',
)

# Dataset selection and loading
flags.DEFINE_string(
    "percent", "0.5pct",
    help="percentage of conflict",
)
flags.DEFINE_integer(
    'num_workers', 4,
    help='workers number',
)
flags.DEFINE_bool(
    "use_type0", False,
    help="whether to use type 0 CIFAR10C",
)
flags.DEFINE_bool(
    "use_type1", False,
    help="whether to use type 1 CIFAR10C",
)
flags.DEFINE_integer(
    "target_attr_idx", 0,
    help="target_attr_idx",
)
flags.DEFINE_string(
    "data_dir", "../dataset",
    help="path for loading data",
)

# Optimisation
flags.DEFINE_float(
    'lr', 0.001,
    help='learning rate',
)

# Regularisation and batch size (tunable for generalization)
flags.DEFINE_float(
    'weight_decay', 0,
    help='l2 regularization coeffcient',
)
flags.DEFINE_integer(
    'train_batch_size_total', 1000,
    help='total batch size (not device-wise) for training',
)

# Exponent q of the generalized cross-entropy loss weighting
flags.DEFINE_float(
    'q', 0.7,
    help='gce q',
)

# Global handle through which the rest of the file reads the flag values.
FLAGS = flags.FLAGS

class TrainState(train_state.TrainState):
    """Flax `TrainState` extended with the network's batch statistics.

    The extra field lets the batch-norm running statistics travel together
    with the parameters and optimizer state through `apply_gradients`.
    """
    # The 'batch_stats' variable collection: set from `net.init(...)` in
    # `init_state` and replaced with the device-averaged update in `opt_step`.
    batch_stats: Any

def create_lr_sched(num_train):
    """Build a warmup + cosine-decay learning-rate schedule.

    The schedule starts at 0.0, warms up to the peak learning rate over the
    first 10% of the training steps, then follows optax's cosine decay until
    the last step.

    Args:
        num_train: number of training examples. Together with
            `FLAGS.train_batch_size_total` it fixes the steps per epoch
            (integer division, so a trailing partial batch is not counted).

    Returns:
        The schedule returned by `optax.warmup_cosine_decay_schedule`.
    """
    steps_per_epoch = num_train // FLAGS.train_batch_size_total
    total_step = FLAGS.epoch_num * steps_per_epoch
    warmup_step = int(0.1 * total_step)

    # No `peak_lr` flag is defined among the flags visible in this file, so
    # reading `FLAGS.peak_lr` unconditionally raised AttributeError. Honour
    # `peak_lr` if another module registers it; otherwise use `lr`.
    peak_lr = FLAGS.peak_lr if 'peak_lr' in FLAGS else FLAGS.lr

    return optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=peak_lr,
        warmup_steps=warmup_step,
        decay_steps=total_step,
    )

def init_state(rng, batch, num_classes):
    """Create the initial training state for a ResNet-18 classifier.

    Args:
        rng: JAX PRNG key used to initialise the network variables.
        batch: example input passed to `net.init` to trace shapes.
        num_classes: number of output classes of the network.

    Returns:
        A `TrainState` holding the model's apply function, the initial
        parameters and batch statistics, and an Adam optimizer configured
        with `FLAGS.lr`.
    """
    model = resnet.ResNet18(num_classes=num_classes)
    init_vars = model.init(rng, batch)

    # Adam stays wrapped in `optax.chain` so the optimizer-state structure
    # is the one a chained transformation produces.
    optimizer = optax.chain(optax.adam(learning_rate=FLAGS.lr))

    return TrainState.create(
        apply_fn=model.apply,
        params=init_vars['params'],
        tx=optimizer,
        batch_stats=init_vars['batch_stats'],
    )

def loss_fn(params, state, batch, train):
    """Compute the confidence-weighted cross-entropy loss plus L2 penalty.

    Args:
        params: network parameters to differentiate with respect to.
        state: train state providing `apply_fn` and `batch_stats`.
        batch: dict with inputs under 'x' and targets under 'y'. The targets
            are multiplied with the softmax output and arg-maxed over the
            last axis, i.e. they are used as one-hot vectors.
        train: forwarded to the network as `train=`; when truthy the
            'batch_stats' collection is made mutable and its update returned.

    Returns:
        `(total_loss, (loss, wd, acc, new_net_state))` where `total_loss` is
        `loss + FLAGS.weight_decay * wd`, `wd` is half the squared L2 norm of
        all parameters, `acc` is the batch accuracy, and `new_net_state` is
        the mutated collections (or None when `train` is falsy).
    """
    variables = {'params': params, 'batch_stats': state.batch_stats}
    inputs, targets = batch['x'], batch['y']

    if not train:
        logits = state.apply_fn(variables, inputs, train=train)
        new_net_state = None
    else:
        logits, new_net_state = state.apply_fn(
            variables, inputs, train=train, mutable=['batch_stats'],
        )

    # Predicted probability of the target class, per example.
    probs = jax.nn.softmax(logits, axis=-1)
    conf = jnp.sum(probs * targets, axis=-1)
    # Per-example weight q * p^q; gradients do not flow through the weight.
    loss_weight = jax.lax.stop_gradient((conf ** FLAGS.q) * FLAGS.q)
    per_example_ce = optax.softmax_cross_entropy(logits, targets)
    loss = jnp.mean(loss_weight * per_example_ce)

    flat_params, _ = ravel_pytree(params)
    wd = 0.5 * jnp.sum(jnp.square(flat_params))
    loss_ = loss + FLAGS.weight_decay * wd

    predicted = jnp.argmax(logits, axis=-1)
    expected = jnp.argmax(targets, axis=-1)
    acc = jnp.mean(predicted == expected)

    return loss_, (loss, wd, acc, new_net_state)

@partial(jax.pmap, axis_name='batch')
def opt_step(rng, state, batch):
    """Run one data-parallel optimisation step across devices.

    Args:
        rng: per-device PRNG key; accepted but not used by this step.
        state: replicated train state.
        batch: per-device batch dict consumed by `loss_fn`.

    Returns:
        `(loss, wd, grad_norm, acc, new_state)`. `loss`, `wd` and `acc` are
        the per-device values from `loss_fn`; `grad_norm` is the sum of
        squared entries of the device-averaged gradient (a squared norm).
    """
    compute_grads = jax.grad(loss_fn, has_aux=True)
    grads, aux = compute_grads(state.params, state, batch, True)
    loss, wd, acc, new_net_state = aux

    # Average gradients and batch statistics over the 'batch' device axis so
    # every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name='batch')
    synced_stats = jax.lax.pmean(new_net_state['batch_stats'], axis_name='batch')
    new_state = state.apply_gradients(grads=grads, batch_stats=synced_stats)

    # Logged diagnostic computed from the averaged gradient.
    flat_grads, _ = ravel_pytree(grads)
    grad_norm = jnp.sum(jnp.square(flat_grads))
    return loss, wd, grad_norm, acc, new_state

def _evaluate(state, valid_loader, test_loader, num_classes):
    """Return formatted accuracies on the validation and test loaders."""
    acc_tr = metrics.acc_torch_dataset(state, valid_loader, num_classes)
    acc_te = metrics.acc_torch_dataset(state, test_loader, num_classes)
    return {'acc_tr': f'{acc_tr:.4f}', 'acc_te': f'{acc_te:.4f}'}


def main(_):
    """Train (unless `--eval`) the classifier and write evaluation results.

    Steps: seed the Python/NumPy/PyTorch RNGs, build the train/valid/test
    datasets and loaders, initialise (or restore) the train state, run the
    training loop, and finally write validation/test accuracy to
    `<res_dir>/last.tsv`.

    Args:
        _: unused positional argument supplied by `absl.app.run`.
    """
    # Seed every host-side RNG so data order and augmentation are repeatable.
    os.environ['PYTHONHASHSEED'] = str(FLAGS.seed)
    random.seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    torch.manual_seed(FLAGS.seed)
    torch.cuda.manual_seed(FLAGS.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Per-dataset settings. Note the training loader uses these batch sizes,
    # not FLAGS.train_batch_size_total.
    data2batch_size = {
        'cmnist': 256,
        'cifar10c': 256,
        'cifar10_lff': 256,
        'waterbird': 32,
        'bffhq': 32,
    }
    data2preprocess = {
        'cmnist': True,
        'cifar10c': True,
        'cifar10_lff': True,
        'waterbird': True,
        'bffhq': True,
    }
    data2img_shape = {
        'cmnist': (32, 32, 3),
        'cifar10c': (32, 32, 3),
        'cifar10_lff': (32, 32, 3),
        'waterbird': (224, 224, 3),
        'bffhq': (224, 224, 3),
    }

    num_devices = jax.device_count()
    train_batch_size = data2batch_size[FLAGS.dataset]

    # The three splits share every argument except the split names.
    load_split = partial(
        get_dataset,
        FLAGS.dataset,
        data_dir=FLAGS.data_dir,
        percent=FLAGS.percent,
        use_preprocess=data2preprocess[FLAGS.dataset],
        use_type0=FLAGS.use_type0,
        use_type1=FLAGS.use_type1,
    )
    train_dataset = load_split(dataset_split="train", transform_split="train")
    valid_dataset = load_split(dataset_split="valid", transform_split="valid")
    test_dataset = load_split(dataset_split="test", transform_split="valid")

    # Collect the target attribute of every training example; the number of
    # classes is derived from its maximum value.
    if FLAGS.dataset == 'cifar10_lff':
        train_target_attr = torch.LongTensor(train_dataset.query_attr)
    elif FLAGS.dataset == 'waterbird':
        train_target_attr = torch.LongTensor(train_dataset.y_array)
    else:
        # Entries of `.data` are strings whose second-to-last '_'-separated
        # field is parsed as the integer target attribute.
        train_target_attr = torch.LongTensor(
            [int(data.split('_')[-2]) for data in train_dataset.data]
        )

    num_classes = torch.max(train_target_attr).item() + 1
    train_dataset = IdxDataset(train_dataset)

    # make loader
    train_loader = DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=FLAGS.num_workers,
        pin_memory=True,
        drop_last=False,
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=FLAGS.test_batch_size_total,
        shuffle=True,
        num_workers=FLAGS.num_workers,
        pin_memory=True,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=FLAGS.test_batch_size_total,
        shuffle=True,
        num_workers=FLAGS.num_workers,
        pin_memory=True,
    )

    # Results directory is keyed by dataset, conflict ratio and hyper-params.
    hparams = [
        FLAGS.model,
        FLAGS.lr,
        FLAGS.train_batch_size_total,
        FLAGS.seed,
    ]
    hparams = '_'.join(map(str, hparams))
    res_dir = f'./res/{FLAGS.dataset}_{FLAGS.percent}/' + hparams

    print(res_dir)
    ckpt.check_dir(res_dir)
    ckpt.check_dir(os.path.join(res_dir, 'networks'))

    # define pseudo-random number generator
    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, rng_ = jax.random.split(rng)

    # initialize network and optimizer (a random dummy input fixes shapes)
    state = init_state(
        rng_,
        jax.random.normal(rng_, (1, *data2img_shape[FLAGS.dataset])),
        num_classes,
    )
    if FLAGS.eval:
        # NOTE(review): training saves under res_dir/networks/<epoch> via
        # ckpt.save_ckpt, while this restores from res_dir — confirm that the
        # two locations/layouts actually match.
        state = checkpoints.restore_checkpoint(
            res_dir,
            state,
        )
    state = replicate(state)

    if not FLAGS.eval:
        train_iter = iter(train_loader)
        train_num = len(train_dataset.dataset)
        steps_per_epoch = train_num // train_batch_size

        # Defined up front so the epoch-end logging below never hits an
        # unbound name when an epoch performs no optimisation step.
        res = {}

        pbar = tqdm(range(1, FLAGS.epoch_num + 1))
        for epoch in pbar:
            for step in range(steps_per_epoch):
                # Restart the loader only when it is exhausted; any other
                # error (worker crash, Ctrl-C, ...) must propagate.
                try:
                    _index, data, attr, _ = next(train_iter)
                except StopIteration:
                    train_iter = iter(train_loader)
                    _index, data, attr, _ = next(train_iter)

                # The trailing partial batch (drop_last=False) may not split
                # evenly across devices; trim it to a multiple of the device
                # count. On a single device this is a no-op.
                usable = (data.shape[0] // num_devices) * num_devices
                if usable == 0:
                    continue
                if usable != data.shape[0]:
                    data, attr = data[:usable], attr[:usable]

                # Split the leading axis across devices and move channels last.
                data = einops.rearrange(data, '(n b) c h w -> n b h w c', n=num_devices)
                data = jnp.asarray(data.numpy())

                label = attr[:, FLAGS.target_attr_idx]
                label = torch.nn.functional.one_hot(label, num_classes).reshape(num_devices, -1, num_classes)
                label = jnp.asarray(label.numpy())

                batch_tr = {'x': data, 'y': label}

                rng, rng_ = jax.random.split(rng)
                loss, wd, grad_norm, acc, state = opt_step(
                    replicate(rng_),
                    state,
                    batch_tr,
                )
                # Metrics come back per device; report their mean.
                res = {
                    'epoch': epoch,
                    'step' : step,
                    'acc' : f'{np.mean(jax.device_get(acc)):.4f}',
                    'loss': f'{np.mean(jax.device_get(loss)):.4f}',
                    'wd' : f'{np.mean(jax.device_get(wd)):.4f}',
                    'grad_norm' : f'{np.mean(jax.device_get(grad_norm)):.4f}',
                }
                pbar.set_postfix(res)

            # Checkpoint and log only at the hard-coded epoch 5.
            if epoch in [5]:
                ckpt.check_dir(os.path.join(res_dir, 'networks', str(epoch)))
                ckpt.save_ckpt(state, os.path.join(res_dir, 'networks', str(epoch)))
                res.update(_evaluate(state, valid_loader, test_loader, num_classes))
                ckpt.dict2tsv(res, res_dir + '/log.tsv')

    # evaluate
    res = _evaluate(state, valid_loader, test_loader, num_classes)
    ckpt.dict2tsv(res, res_dir + '/last.tsv')

# Script entry point: hand control to absl, which parses the command-line
# flags defined above and then calls `main`.
if __name__ == "__main__":
    app.run(main)
