import pathlib
import time
import logging
import csv
import re
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
import tree
from skimage.util import view_as_blocks

try:
    import wandb
except ImportError:
    pass

class Logger(object):
    '''Experiment logger that fans metrics out to wandb and/or a CSV file.

    Text messages go through ``print`` when wandb is enabled and through the
    ``logging`` module otherwise. Metric rows passed to :meth:`log` are kept in
    memory and written to ``<log_dir>/log.csv`` when :meth:`exit` is called.

    Args:
        args: run configuration. Only used when ``use_wandb`` is true, where it
            is passed as the wandb ``config`` and ``args.exp_id`` is used as
            the run id.
        verbose: truthy selects the ``DEBUG`` level for ``logging``, otherwise
            ``INFO``.
        log_dir: directory for ``log.csv``; created if missing. ``None``
            disables writing the CSV file.
        log_csv: whether to accumulate metric rows for the CSV file.
        use_wandb: whether to initialise a wandb run and forward metrics to it.

    Raises:
        ImportError: if ``use_wandb`` is true but ``wandb`` is not importable.
    '''

    def __init__(self, args, verbose=0, log_dir=None, log_csv=True, use_wandb=True):
        if log_dir is not None:
            self.log_dir = pathlib.Path(log_dir)
            self.log_dir.mkdir(parents=True, exist_ok=True)
        else:
            self.log_dir = None

        # TODO add handler to log to file
        logging.basicConfig(
            level=logging.DEBUG if verbose else logging.INFO,
            format='[%(asctime)s %(levelname)s]: %(message)s'
        )

        self.log_csv = log_csv
        if self.log_csv:
            self.columns = set()  # union of every metric name seen so far
            self.data = []        # one dict of metrics per log() call
            self.idx = []         # step value of each log() call

        self.use_wandb = use_wandb
        if self.use_wandb:
            # The module-level import of wandb is optional; fail with a clear
            # message instead of a NameError when it was not available.
            if 'wandb' not in globals():
                raise ImportError(
                    "use_wandb=True requires the 'wandb' package, which could not be imported"
                )
            wandb.init(project='prompt-subtasks', id=args.exp_id, config=args)

    def _wandb_print(self, msg, level='INFO'):
        '''Print ``msg`` with a timestamp and level tag (no level filtering).'''
        # wandb doesn't track logging.info/debug/msg correctly...
        print(f'[{time.asctime()} {level}]: {msg}')

    def log(self, values: dict, step=None):
        '''Record one row of metrics, to wandb and/or the in-memory CSV buffer.'''
        if self.use_wandb:
            wandb.log(values, step=step)
        if self.log_csv:
            self.idx.append(step)
            self.data.append(dict(values))  # copy so later caller mutation is harmless
            self.columns.update(values.keys())

    def info(self, msg):
        '''Emit an INFO message.'''
        if self.use_wandb:
            self._wandb_print(msg, 'INFO')
        else:
            logging.info(msg)

    def debug(self, msg):
        '''Emit a DEBUG message (always printed when wandb is enabled).'''
        if self.use_wandb:
            self._wandb_print(msg, 'DEBUG')
        else:
            logging.debug(msg)

    def error(self, msg):
        '''Emit an ERROR message.'''
        if self.use_wandb:
            self._wandb_print(msg, 'ERROR')
        else:
            logging.error(msg)

    def exit(self):
        '''Write the accumulated metric rows to ``<log_dir>/log.csv``.

        Columns are ``step`` followed by the metric names in sorted order;
        metrics missing from a row are written as empty cells. Nothing is
        written when CSV logging is disabled or no ``log_dir`` was given.
        '''
        if not self.log_csv:
            return

        cols = sorted(self.columns)
        formatted_columns = ['step'] + cols
        formatted_data = [
            [step] + [row.get(col) for col in cols]
            for step, row in zip(self.idx, self.data)
        ]
        # Uploading the table to wandb is currently disabled:
        # if self.use_wandb:
        #     table = wandb.Table(columns=formatted_columns, data=formatted_data)
        #     wandb.log({'table': table})
        if self.log_dir is not None:
            # newline='' lets the csv module control line endings itself;
            # without it, platforms that translate '\n' get blank rows.
            with (self.log_dir / 'log.csv').open('w', newline='') as tmp:
                writer = csv.writer(tmp)
                writer.writerow(formatted_columns)
                writer.writerows(formatted_data)

class RunningAverages(object):
    '''Weighted running means of named scalar metrics.

    ``metrics`` maps each key to a ``(count, average)`` tuple, where ``count``
    is the total weight pushed so far and ``average`` the weighted mean.
    '''

    def __init__(self):
        self.metrics = {}

    def push(self, key, value, batch_count):
        '''Fold ``value`` (a mean over ``batch_count`` items) into ``key``.

        A push with ``batch_count == 0`` is ignored: it carries no weight, and
        applying the update would either divide by zero (new key) or let a NaN
        value poison the existing average.
        '''
        if batch_count == 0:
            return
        count, average = self.metrics.get(key, (0, 0.0))
        count += batch_count
        average = average * (count - batch_count) / count \
            + value * batch_count / count
        self.metrics[key] = (count, average)

    def pushes(self, keyvalues, batch_count):
        '''Push every ``key: value`` pair of ``keyvalues`` with the same weight.'''
        for key, value in keyvalues.items():
            self.push(key, value, batch_count)

    def merge(self, ra):
        '''Fold the metrics of another ``RunningAverages`` into this one.'''
        for key, (count, val) in ra.metrics.items():
            self.push(key, val, count)

    def get_values(self):
        '''Return a ``{key: average}`` dict, dropping the counts.'''
        return {key: val[1] for key, val in self.metrics.items()}

def token_stats(inputs, targets, loss_mask, tasks=None, shift=False):
    '''Compare predicted and target token ids and return batch-averaged metrics.

    Metrics pushed (each weighted by the batch size):
      * ``token_accuracy``: per row, matching masked tokens / ``loss_mask`` sum.
      * ``full_accuracy``: 1 where the matching count reaches the mask sum.
      * ``diverge_point``: per row, the share of masked tokens seen up to and
        including the first masked mismatch (the last position is used when
        there is none).
      * ``subtask_accuracy``: share of contiguous masked runs that are matched
        entirely.
    When ``tasks`` is given, the same four metrics are also pushed per distinct
    task under ``task{task}_...`` keys, weighted by the number of rows.

    Args:
        inputs: predicted token ids; presumably (batch, seq) -- the code
            reduces over dim 1 and indexes ``[:, -1]``.
        targets: target token ids, same shape as ``inputs``.
        loss_mask: mask over positions; entries ``> 0.5`` count as scored.
        tasks: optional per-row task labels (one per batch row).
        shift: if true, drop the last position of ``inputs`` and the first of
            ``targets``/``loss_mask`` before comparing.

    Returns:
        RunningAverages holding the metrics above.
    '''
    if shift:
        inputs = inputs[..., :-1].contiguous()
        targets = targets[..., 1:].contiguous()
        loss_mask = loss_mask[..., 1:].contiguous()

    stats = RunningAverages()
    n_rows = inputs.shape[0]

    active = loss_mask > 0.5
    agree = inputs == targets

    # Token-level and whole-sequence accuracy.
    n_correct = (agree & active).sum(dim=1)
    mask_total = loss_mask.sum(dim=1)
    token_accuracy = n_correct / mask_total
    full_accuracy = (n_correct >= mask_total).float()
    stats.push('token_accuracy', token_accuracy.mean().item(), n_rows)
    stats.push('full_accuracy', full_accuracy.mean().item(), n_rows)

    # Position of the first masked mismatch, normalised by the mask size.
    mismatch = inputs.ne(targets) & active
    mismatch[:, -1] = True  # guarantees argmax finds something on a perfect row
    first_bad = mismatch.int().argmax(dim=1)
    active_so_far = torch.cumsum(active, dim=1)
    rank_at_bad = active_so_far[range(n_rows), first_bad]
    diverge_point = rank_at_bad / active.sum(dim=1)
    stats.push('diverge_point', diverge_point.mean().item(), n_rows)

    # Subtask accuracy: number each contiguous masked run via its rising edge.
    padded = F.pad(loss_mask.int(), (1, 0), value=0)
    rising = (padded[:, 1:] - padded[:, :-1]) == 1
    segment_id = rising.cumsum(1) * active
    solved = torch.zeros(inputs.shape[0], device=inputs.device)
    present = torch.zeros(inputs.shape[0], device=inputs.device)
    last_segment = segment_id.max().item()
    for seg in range(1, last_segment + 1):
        in_segment = segment_id == seg
        hits = (agree & in_segment).sum(1)
        size = in_segment.sum(1) + 1e-5  # epsilon avoids 0/0 on rows lacking this run
        solved += (hits / size) > (1 - 1e-3)
        present += size > 0.5
    subtask_accuracy = solved / (present + 1e-3)
    stats.push('subtask_accuracy', subtask_accuracy.mean().item(), n_rows)

    if tasks is None:
        return stats

    for task in set(tasks):
        rows = [task == t for t in tasks]
        n_task = sum(rows)
        stats.push(f'task{task}_token_accuracy', token_accuracy[rows].mean().item(), n_task)
        stats.push(f'task{task}_full_accuracy', full_accuracy[rows].mean().item(), n_task)
        stats.push(f'task{task}_diverge_point', diverge_point[rows].mean().item(), n_task)
        stats.push(f'task{task}_subtask_accuracy', subtask_accuracy[rows].mean().item(), n_task)

    return stats

def _seq_to_subtasks(seq, sep=r'\d+\.'):
    '''Split ``seq`` on the regex ``sep`` and return the stripped pieces.

    The piece before the first separator is always discarded (for numbered
    lists it is the empty text preceding "1.").
    '''
    _, *pieces = re.split(sep, seq)
    return [piece.strip() for piece in pieces]

def subtask_stats(samples, targets, task_type='cooking'):
    '''Score generated subtask sequences against target sequences.

    Each sample/target string is split into subtasks (numbered items for
    ``'cooking'``, `` [SEP] ``-separated items for ``'alf'``). For every target
    subtask two booleans are recorded: exact match with the sample subtask at
    the same index, and recall (the target subtask appears anywhere in the
    sample).

    Metrics pushed:
      * ``{type}_accuracy`` / ``{type}_recall``: per subtask, keyed by the
        lower-cased first word of the target subtask.
      * ``subtask_accuracy`` / ``subtask_recall``: per sequence mean, weighted
        by the number of target subtasks.
      * ``nsubtask_accuracy`` / ``nsubtask_recall``: per sequence mean, weight 1.

    Targets that contain no subtasks are skipped, and per-type metrics are
    skipped for empty subtask strings (they have no first word).

    Args:
        samples: iterable of generated strings.
        targets: iterable of reference strings, paired with ``samples``.
        task_type: ``'cooking'`` or ``'alf'``.

    Returns:
        RunningAverages with the metrics above.

    Raises:
        NotImplementedError: for any other ``task_type``.
    '''
    metrics = RunningAverages()
    if task_type == 'cooking':
        sep = r'\d+\.'
    elif task_type == 'alf':
        sep = r' \[SEP\] '
    else:
        raise NotImplementedError

    for sample, target in zip(samples, targets):
        st_sample = _seq_to_subtasks(sample, sep)
        st_target = _seq_to_subtasks(target, sep)
        if not st_target:
            # Nothing to score: np.mean([]) is NaN and the zero weight would
            # either divide by zero or poison the running averages.
            continue

        sample_set = set(st_sample)  # O(1) recall lookups
        exact_match = []
        recall = []
        for i, t in enumerate(st_target):
            s = st_sample[i] if i < len(st_sample) else ''
            exact_match.append(t == s)
            recall.append(t in sample_set)

            words = t.split(maxsplit=1)
            if words:  # an empty subtask has no type to key on
                t_type = words[0].lower()
                metrics.push(f'{t_type}_accuracy', exact_match[-1], 1)
                metrics.push(f'{t_type}_recall', recall[-1], 1)

        metrics.push('subtask_accuracy', np.mean(exact_match), len(st_target))
        metrics.push('subtask_recall', np.mean(recall), len(st_target))
        metrics.push('nsubtask_accuracy', np.mean(exact_match), 1)
        metrics.push('nsubtask_recall', np.mean(recall), 1)

    return metrics

def guided_gen_stats(sample_actions, sample_obs, target_actions, target_obs):
    '''Score one generated trajectory against its target trajectory.

    The four sequences are walked in lockstep (stopping at the shortest). For
    each step the action is checked for exact match and for recall (the target
    action appears anywhere in ``sample_actions``), and the observation for
    exact match.

    Metrics pushed:
      * ``{type}_accuracy`` / ``{type}_recall``: per step, keyed by the
        lower-cased first word of the target action.
      * ``full_accuracy``, ``subtask_accuracy``, ``subtask_recall``,
        ``obs_accuracy``, ``full_obs_accuracy``: per trajectory, weighted by
        ``len(sample_actions)``.

    If no step could be compared, an empty ``RunningAverages`` is returned
    instead of pushing NaN / zero-weight values. Per-type metrics are skipped
    for empty target actions (they have no first word).

    Returns:
        RunningAverages with the metrics above.
    '''
    metrics = RunningAverages()
    traj_length = len(sample_actions)
    a_exact_match = []
    o_exact_match = []
    a_recall = []
    for s_action, s_obs, t_action, t_obs in zip(sample_actions, sample_obs, target_actions, target_obs):
        a_exact_match.append(s_action == t_action)
        a_recall.append(t_action in sample_actions)

        words = t_action.split(maxsplit=1)
        if words:  # an empty action has no type to key on
            action_type = words[0].lower()
            metrics.push(f'{action_type}_accuracy', a_exact_match[-1], 1)
            metrics.push(f'{action_type}_recall', a_recall[-1], 1)

        o_exact_match.append(s_obs == t_obs)

    if not a_exact_match:
        # No aligned steps: np.mean([]) is NaN and a zero weight would divide
        # by zero inside RunningAverages.push.
        return metrics

    metrics.push('full_accuracy', np.all(a_exact_match), traj_length)
    metrics.push('subtask_accuracy', np.mean(a_exact_match), traj_length)
    metrics.push('subtask_recall', np.mean(a_recall), traj_length)
    metrics.push('obs_accuracy', np.mean(o_exact_match), traj_length)
    metrics.push('full_obs_accuracy', np.all(o_exact_match), traj_length)

    return metrics

def shifted_cross_ent(inputs, logits, loss_mask=None):
    '''Per-token next-token cross entropy.

    Position ``i`` of ``logits`` is scored against token ``i + 1`` of
    ``inputs``, so the result has one fewer position than ``inputs``.

    Args:
        inputs: token ids; the last dimension is the sequence.
        logits: scores with a trailing vocabulary dimension, aligned with
            ``inputs`` on the sequence dimension.
        loss_mask: optional mask shaped like ``inputs``; the loss at each
            predicted token is multiplied by the mask value of that token.

    Returns:
        Tensor of unreduced losses shaped like ``inputs[..., 1:]``.
    '''
    # Shift so that tokens < n predict n
    shift_labels = inputs[..., 1:].contiguous()
    shift_logits = logits[..., :-1, :].contiguous()
    # Calculate per-token loss (`reduce=False` is deprecated in favour of
    # `reduction='none'`)
    loss_fct = CrossEntropyLoss(reduction='none')
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    if loss_mask is not None:
        loss_mask = loss_mask[..., 1:].contiguous().view(-1)
        loss *= loss_mask
    return loss.view(shift_labels.shape)

def keytoken_weighted_loss(inputs, logits, keytoken_ids, alpha=1.0):
    '''Next-token cross entropy, up-weighted for samples containing key tokens.

    Each sample's mean token loss is multiplied by
    ``alpha * (1 + number of key-token occurrences in that sample)`` and the
    result is averaged over the batch.

    Args:
        inputs: token ids; indexed as (batch, seq) by the weighting step.
        logits: scores with a trailing vocabulary dimension.
        keytoken_ids: iterable of token ids to up-weight; may be empty, in
            which case every sample gets weight ``alpha``.
        alpha: global scale applied to the weights.

    Returns:
        Scalar tensor with the weighted mean loss.
    '''
    # Shift so that tokens < n predict n
    shift_labels = inputs[..., 1:].contiguous()
    shift_logits = logits[..., :-1, :].contiguous()
    # Calculate per-token loss (`reduce=False` is deprecated in favour of
    # `reduction='none'`)
    loss_fct = CrossEntropyLoss(reduction='none')
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    # Resize and average loss per sample
    loss_per_sample = loss.view(shift_logits.size(0), shift_logits.size(1)).mean(dim=1)
    # Count key-token occurrences per sample (over the unshifted inputs)
    hits = [(inputs == kt).float() for kt in keytoken_ids]
    if hits:
        weights = torch.stack(hits).sum(dim=[0, 2])
    else:
        # torch.stack rejects an empty list; no key tokens means no extra weight
        weights = torch.zeros_like(loss_per_sample)
    weights = alpha * (1.0 + weights)
    # Calculate weighted average
    weighted_loss = (loss_per_sample * weights).mean()
    return weighted_loss

def lerp(t, warmup, peak):
    '''Linear ramp: 0 at or before ``warmup``, 1 at or after ``peak``.

    ``t`` is clipped to ``[warmup, peak]`` and rescaled to ``[0, 1]``.
    '''
    span = peak - warmup
    elapsed = np.clip(t, warmup, peak) - warmup
    return elapsed / span

def qerp(t, p0, p1, p2):
    '''Quadratic interpolation through control points ``p0, p1, p2`` at ``t``.

    Blends the two linear interpolations ``p0 -> p1`` and ``p1 -> p2``;
    yields ``p0`` at ``t = 0`` and ``p2`` at ``t = 1``.
    '''
    u = 1 - t
    first_leg = u * p0 + t * p1
    second_leg = u * p1 + t * p2
    return u * first_leg + t * second_leg

def cerp(t, p0, p1, p2, p3):
    '''Cubic interpolation through control points ``p0..p3`` at ``t``.

    Blends the quadratic interpolations over ``(p0, p1, p2)`` and
    ``(p1, p2, p3)``.
    '''
    leading = qerp(t, p0, p1, p2)
    trailing = qerp(t, p1, p2, p3)
    return (1 - t) * leading + t * trailing

def get_lw_scheduler(lw_id):
    '''Build a schedule function ``t -> value`` from a dotted spec string.

    ``lw_id`` has the form ``'<kind>.<a0>.<a1>...'``; every argument is parsed
    as an integer and divided by 100. Supported kinds:
      * ``unif``: constant ``a0``.
      * ``lerp``: ``a0 + (a1 - a0) * lerp(t, a2, a3)``.
      * ``qerp``: ``qerp(t, a0, a2, a1)``.
      * ``cerp``: ``cerp(t, a0, a2, a3, a1)``.

    Raises:
        NotImplementedError: for an unknown kind.
    '''
    pieces = lw_id.split('.')
    sched_id = pieces[0]
    args = [int(piece) / 100. for piece in pieces[1:]]

    def constant(t):
        return args[0]

    def linear(t):
        return args[0] + (args[1] - args[0]) * lerp(t, args[2], args[3])

    def quadratic(t):
        return qerp(t, args[0], args[2], args[1])

    def cubic(t):
        return cerp(t, args[0], args[2], args[3], args[1])

    schedules = {
        'unif': constant,
        'lerp': linear,
        'qerp': quadratic,
        'cerp': cubic,
    }
    if sched_id not in schedules:
        raise NotImplementedError(f'lw scheduler {sched_id} not implemented')
    return schedules[sched_id]

def map_structure_batched(func, struct, pre_func=lambda _: _, post_func=lambda _: _, batch_size=4):
    '''
    Apply ``func`` to the leaves of ``struct`` in chunks of ``batch_size``.

    struct : tree of tensors (stacked along a new leading dim, so the leaves
        must be stackable)
    func : called on each chunk of at most ``batch_size`` stacked leaves
    pre_func / post_func : applied to the full stack before chunking and to
        the concatenated outputs afterwards (identity by default)

    Returns a tree shaped like ``struct`` whose leaves are the rows of the
    processed output.
    '''
    stacked = pre_func(torch.stack(tree.flatten(struct)))  # N x tensor dims
    # each piece is (at most) batch_size x tensor dims
    results = [func(piece) for piece in torch.split(stacked, batch_size)]
    merged = post_func(torch.cat(results))  # N x new tensor dims
    return tree.unflatten_as(struct, list(merged))  # tree of new tensor dims

def map_structure_stacked(func, struct):
    '''Apply ``func`` once to all leaves of ``struct`` stacked together.

    The leaves are stacked along a new leading dim, passed through ``func``,
    and the rows of the result are placed back into the shape of ``struct``.
    '''
    stacked = torch.stack(tree.flatten(struct))  # N x tensor dims
    mapped = func(stacked)
    return tree.unflatten_as(struct, list(mapped))

def unravel_as(struct, outer):
    '''Spread the values of ``struct`` over the matching sub-trees of ``outer``.

    ``struct`` is used as the shallow structure (converted to a list first if
    ``tree`` does not consider it nested). For every position in it, each leaf
    of the corresponding sub-tree of ``outer`` is replaced by the value of
    ``struct`` at that position.
    '''
    shallow = struct if tree.is_nested(struct) else list(struct)

    def _fill(inner, val):
        # Replace every leaf of `inner` with the single value `val`.
        return tree.map_structure(lambda _: val, inner)

    return tree.map_structure_up_to(shallow, _fill, outer, shallow)

def patchify(im, grid_dim):
    '''Cut a batch of channel-first images into a grid of patches.

    Args:
        im: tensor of shape (B, 3, H, W).
        grid_dim: ``(rows, cols)`` of the patch grid.

    Returns:
        Tensor of shape (B * rows * cols, 3, H // rows, W // cols), ordered by
        image, then grid row, then grid column.
    '''
    B, C, H, W = im.shape
    assert C == 3
    rows, cols = grid_dim[0], grid_dim[1]
    patch_h = H // rows
    patch_w = W // cols
    # Non-overlapping windows along H, then along W: B x C x rows x cols x ph x pw
    tiles = im.unfold(2, patch_h, patch_h).unfold(3, patch_w, patch_w)
    # Bring the grid indices next to the batch dim before flattening them.
    tiles = tiles.permute(0, 2, 3, 1, 4, 5).contiguous()
    return tiles.view(B * rows * cols, C, patch_h, patch_w)

def patchify_np(im, grid_dim):
    '''Cut a batch of channel-last images into a grid of patches (NumPy).

    NumPy counterpart of ``patchify``: patches are contiguous image regions,
    ordered by image, then grid row, then grid column.

    NOTE(review): the previous implementation passed the grid dimensions to
    ``view_as_blocks`` as the block shape, which produced strided
    sub-samplings of the image in grid-major order instead of contiguous
    patches. Confirm no caller relied on that layout.

    Args:
        im: array of shape (B, H, W, 3).
        grid_dim: ``(rows, cols)`` of the patch grid.

    Returns:
        Array of shape (B * rows * cols, H // rows, W // cols, 3). Trailing
        rows/columns that do not fill a whole patch are dropped.
    '''
    B, H, W, C = im.shape
    assert C == 3
    rows, cols = grid_dim[0], grid_dim[1]
    H_new = H // rows
    W_new = W // cols
    # Drop any remainder so the image divides evenly into the grid.
    im = im[:, :rows * H_new, :cols * W_new, :]
    # B x rows x H_new x cols x W_new x C -> B x rows x cols x H_new x W_new x C
    x = im.reshape(B, rows, H_new, cols, W_new, C).transpose(0, 1, 3, 2, 4, 5)
    x = np.ascontiguousarray(x).reshape(B * rows * cols, H_new, W_new, C)
    return x
