from collections import OrderedDict

import math
import random
import torch
from torch import optim
from einops import rearrange, repeat
from copy import deepcopy
from utils import rsvrBase


# Default compute device: CUDA when torch reports it is available, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def inner_adapt(P, wrapper, task_data, inner_lr=1e-2, num_steps=4,
                first_order=False, params=None, order=None):
    """Adapt batched parameters to ``task_data`` with unrolled inner-loop SGD.

    The loop runs ``num_steps`` stages. Each stage runs ``num_iters`` iterations,
    where ``num_iters`` is ``P.inner_iter`` while ``wrapper.training`` is true
    and ``P.tto`` otherwise. Every iteration is delegated to
    ``inner_loop_step_img`` (``P.data_type == 'img'``) or
    ``inner_loop_step_video`` (``P.data_type == 'video'``).

    Args:
        P: config namespace. Reads ``data_type``, ``inner_iter`` and ``tto``;
            the step functions read further fields.
        wrapper: model wrapper exposing ``get_batch_params`` and ``training``.
        task_data: input batch; ``task_data.size(0)`` is used as the batch size.
        inner_lr: inner-loop SGD step size.
        num_steps: number of inner stages.
        first_order: forwarded to the step function (no graph through grads).
        params: initial parameters handed to ``wrapper.get_batch_params``.
        order: unused; kept for backward compatibility of the signature.

    Returns:
        Tuple ``(params, losses_in, loss_in_log, res_in_log, grads_in)``:
        the adapted params; every iteration's loss stacked on dim 0; the
        divided loss of the last iteration of each stage stacked on dim 0;
        a list with the last iteration's reconstruction per stage; and a list
        with every iteration's gradient tuple.

    Raises:
        ValueError: if ``P.data_type`` is neither ``'img'`` nor ``'video'``.
    """
    # Resolve the per-iteration step once; previously an unknown data_type
    # surfaced as an UnboundLocalError on `loss_in` deep inside the loop.
    if P.data_type == 'img':
        step_fn = inner_loop_step_img
    elif P.data_type == 'video':
        step_fn = inner_loop_step_video
    else:
        raise ValueError(
            f"unsupported P.data_type: {P.data_type!r} (expected 'img' or 'video')"
        )

    batch_size = task_data.size(0)
    params = wrapper.get_batch_params(params, batch_size)

    # inner gradient step
    losses_in, grads_in = [], []  # one entry per (stage, iteration)
    loss_in_log, res_in_log = [], []  # one entry per stage (last iteration only)

    num_iters = P.inner_iter if wrapper.training else P.tto
    for step_inner in range(num_steps):
        for step_iter in range(num_iters):
            params, loss_in, loss_in_div, res_in, grad_in = step_fn(
                P, wrapper, params, task_data,
                inner_lr, first_order, step_inner, step_iter,
            )
            losses_in.append(loss_in)
            grads_in.append(grad_in)

            # Only the final iteration of each stage is logged.
            if step_iter == num_iters - 1:
                loss_in_log.append(loss_in_div)
                res_in_log.append(res_in)

    return params, torch.stack(losses_in), torch.stack(loss_in_log), res_in_log, grads_in


def inner_loop_step_img(P, wrapper, params, task_data, inner_lr=1e-2, first_order=False, step_inner=-1, step_iter=-1):
    """Take one functional SGD step on ``params`` for an image batch.

    The wrapper is run on ``task_data`` with ``params``. The resulting loss map
    is also split into ``P.inner_step`` tiles via ``divide_loss(..., P.order)``.
    When ``P.incremental`` is set, only tile ``step_inner`` drives the
    gradient; otherwise the whole loss map does.

    Args:
        P: config namespace. Reads inner_step, incremental, order and oml, and
            oml_layer when oml is set.
        wrapper: called as ``wrapper(task_data, params=..., step_inner=...,
            step_iter=...)`` and returning ``(loss_in, res_in)``.
        params: OrderedDict of parameters to differentiate and update.
        task_data: 4-D batch ``(b, c, h, w)``.
        inner_lr: SGD step size.
        first_order: if True, gradients are taken without ``create_graph``.
        step_inner, step_iter: stage / iteration indices forwarded to the wrapper.

    Returns:
        ``(updated_params, loss_in, loss_in_div, res_in, grads)``, where
        ``grads`` is the raw tuple from ``torch.autograd.grad`` (entries may
        be None).
    """
    tile_count = P.inner_step
    incremental = P.incremental

    batch, _, _, _ = task_data.size()
    wrapper.decoder.zero_grad()

    with torch.enable_grad():
        loss_in, res_in = wrapper(task_data, params=params, step_inner=step_inner, step_iter=step_iter)
        loss_in_div = divide_loss(loss_in, tile_count, P.order)  # (t, b, c, h//n, w//n)

        source = loss_in_div[step_inner] if incremental else loss_in
        per_sample_loss = source.view(batch, -1).mean(dim=-1)

        # mean * batch == sum of the per-sample losses.
        # NOTE(review): inner_loop_step_video differentiates `loss_grad.mean()`
        # without the `* b` factor, so its gradients are 1/b of what this path
        # produces for the same loss -- confirm the difference is intended.
        objective = per_sample_loss.mean() * batch
        grads = torch.autograd.grad(
            objective,
            params.values(),
            create_graph=not first_order,
            allow_unused=True,
        )

        updated_params = OrderedDict()
        for (name, param), grad in zip(params.items(), grads):
            step = 0. if grad is None else grad
            # With P.oml only entries of `layers.{P.oml_layer}.linear` adapt.
            adapt = not P.oml or f'layers.{P.oml_layer}.linear' in name
            updated_params[name] = param - inner_lr * step if adapt else param

    return updated_params, loss_in, loss_in_div, res_in, grads


def _prog_width(P, stage):
    """Return the hidden width unlocked after ``stage`` progressive stages.

    ``P.dim_hidden`` is split into ``P.inner_step`` equal chunks; the remainder
    ``P.dim_hidden % P.inner_step`` belongs to every stage, including stage 0.
    """
    return P.dim_hidden // P.inner_step * stage + P.dim_hidden % P.inner_step


def _truncate_params_prog(P, wrapper, params, batch_size, step_inner, step_iter):
    """Build the width-truncated parameter dict used by progressive training.

    Only these entries are truncated to ``_prog_width(P, step_inner + 1)`` features:
    - ``layers.{P.oml_layer}.linear``: weight rows and bias entries;
    - ``layers.{P.oml_layer + 1}.linear``: weight columns.

    On the first iteration of a stage (``step_iter == 0``), those tensors are
    sliced from ``wrapper.get_batch_params(None, batch_size)``. Their first
    ``_prog_width(P, step_inner)`` features are then overwritten in place with
    the values from ``params``. On later iterations ``params`` is used as-is.
    Entries whose name contains neither 'bias' nor 'weight' are omitted.
    """
    feats_new = _prog_width(P, step_inner + 1)
    feats_prev = _prog_width(P, step_inner)
    cur_key = f'layers.{P.oml_layer}.linear'
    next_key = f'layers.{P.oml_layer + 1}.linear'
    # Fetched once, and only when it is read (step_iter == 0); the previous
    # version re-fetched it for every weight/bias entry on every iteration.
    base = wrapper.get_batch_params(None, batch_size) if step_iter == 0 else None

    truncated = OrderedDict()
    for name, param in params.items():
        if 'bias' in name:
            if cur_key in name and step_iter == 0:
                bias = base[name][..., :feats_new]
                bias[..., :feats_prev] = param[..., :feats_prev]
                truncated[name] = bias
            else:
                truncated[name] = param
        elif 'weight' in name:
            if cur_key in name:
                if step_iter == 0:
                    weight = base[name][..., :feats_new, :P.dim_hidden]
                    weight[..., :feats_prev, :P.dim_hidden] = param[..., :feats_prev, :P.dim_hidden]
                    truncated[name] = weight
                else:
                    truncated[name] = param
            elif next_key in name:
                if step_iter == 0:
                    weight = base[name][..., :P.dim_hidden, :feats_new]
                    weight[..., :P.dim_hidden, :feats_prev] = param[..., :P.dim_hidden, :feats_prev]
                    truncated[name] = weight
                else:
                    truncated[name] = param
            else:
                truncated[name] = param
    return truncated


def _zero_slice_(grad, index, width):
    """Zero ``grad[index]`` in place.

    This is a no-op when ``grad`` is not a tensor (a missing gradient replaced
    by ``0.``, which previously raised TypeError) or when ``width`` is 0.
    """
    if torch.is_tensor(grad) and width > 0:
        grad[index] = 0


def _prog_update(P, params, grads, inner_lr, step_inner):
    """Apply the SGD update for progressive training.

    When ``P.frozen`` is set, gradients for the first
    ``_prog_width(P, step_inner)`` features are zeroed in place before the step:
    weight rows and bias of ``layers.{P.oml_layer}.linear``, and weight columns
    of ``layers.{P.oml_layer + 1}.linear``. In-place zeroing also affects the
    tensors in ``grads`` that the caller returns.

    The next layer's bias and all other entries pass through unchanged. Next-layer
    entries containing neither 'weight' nor 'bias' are omitted.
    """
    width = _prog_width(P, step_inner) if P.frozen else 0  # 0 == unfrozen
    cur_key = f'layers.{P.oml_layer}.linear'
    next_key = f'layers.{P.oml_layer + 1}.linear'

    updated = OrderedDict()
    for (name, param), grad in zip(params.items(), grads):
        if grad is None:
            grad = 0.
        if cur_key in name:
            if 'weight' in name:
                _zero_slice_(grad, (Ellipsis, slice(0, width), slice(None)), width)
            elif 'bias' in name:
                _zero_slice_(grad, (Ellipsis, slice(0, width)), width)
            updated[name] = param - inner_lr * grad
        elif next_key in name:
            if 'weight' in name:
                _zero_slice_(grad, (Ellipsis, slice(None), slice(0, width)), width)
                updated[name] = param - inner_lr * grad
            elif 'bias' in name:
                updated[name] = param
        else:
            updated[name] = param
    return updated


def inner_loop_step_video(P, wrapper, params, task_data, inner_lr=1e-2, first_order=False, step_inner=-1, step_iter=-1):
    """Take one functional SGD step on ``params`` for a video batch.

    ``task_data`` must be 5-D with shape ``(b, t, c, h, w)``. The wrapper's loss
    is rearranged from ``b t c h w`` to ``t b c h w``. When ``P.incremental`` is
    set, only frame ``step_inner`` drives the gradient.

    With ``P.prog``, the parameters are first truncated to the width unlocked at
    ``step_inner``, and the update optionally freezes earlier features (see
    ``_truncate_params_prog`` / ``_prog_update``). Without ``P.prog``, every
    entry is updated. When ``P.oml`` is set, only ``layers.{P.oml_layer}.linear``
    is updated.

    Returns:
        ``(updated_params, loss_in, loss_in_div, res_in, grads)``, where
        ``grads`` is the tuple from ``torch.autograd.grad`` (entries may be None).
    """
    is_incremental = P.incremental
    b, _, _, _, _ = task_data.size()
    wrapper.decoder.zero_grad()

    if P.prog:
        # Size the fresh batch params by the actual batch (as inner_adapt does)
        # rather than P.batch_size, so a differently sized batch still lines up.
        active_params = _truncate_params_prog(P, wrapper, params, b, step_inner, step_iter)
    else:
        active_params = params

    with torch.enable_grad():
        loss_in, res_in = wrapper(task_data, params=active_params, step_inner=step_inner, step_iter=step_iter)
        loss_in_div = rearrange(loss_in, 'b t c h w -> t b c h w')

        if is_incremental:
            loss_grad = loss_in_div[step_inner].view(b, -1).mean(dim=-1)
        else:
            loss_grad = loss_in.view(b, -1).mean(dim=-1)

        # NOTE(review): unlike inner_loop_step_img this is not multiplied by b,
        # so gradients are 1/b of the image path's for the same loss -- confirm.
        grads = torch.autograd.grad(
            loss_grad.mean(),
            active_params.values(),
            create_graph=not first_order,
            allow_unused=True,
        )

        if P.prog:
            updated_params = _prog_update(P, active_params, grads, inner_lr, step_inner)
        else:
            updated_params = OrderedDict()
            for (name, param), grad in zip(params.items(), grads):
                if grad is None:
                    grad = 0.
                if not P.oml or f'layers.{P.oml_layer}.linear' in name:
                    updated_params[name] = param - inner_lr * grad
                else:
                    updated_params[name] = param

    return updated_params, loss_in, loss_in_div, res_in, grads


def divide_loss(loss, num_steps, order):
    """Split a 4-D loss map into equally sized tiles stacked on a new dim 0.

    Args:
        loss: tensor of shape ``(b, c, h, w)``.
        num_steps: requested number of tiles. For ``'raster'`` the grid is
            ``n x n`` with ``n = math.isqrt(num_steps)``, so a non-square
            ``num_steps`` yields only ``n * n`` tiles.
        order: ``'raster'`` (n x n grid, row-major: top-left tile first,
            moving right then down), ``'rowwise'`` (``num_steps`` horizontal
            bands, top to bottom) or ``'colwise'`` (``num_steps`` vertical
            bands, left to right).

    Returns:
        Tensor of shape ``(t, b, c, h', w')`` holding the ``t`` tiles.

    Raises:
        ValueError: if the tiled spatial size cannot be split into equal,
            non-empty tiles (previously a cryptic stack/split/range error).
        NotImplementedError: for any other ``order``.
    """
    def _require_even_split(size, parts, axis):
        # torch.stack needs equal-sized tiles, and a zero tile size breaks
        # range()/torch.split, so reject those inputs with a clear message.
        if parts < 1 or size < parts or size % parts:
            raise ValueError(
                f'cannot split {axis} of size {size} into {parts} equal non-empty parts'
            )

    _, _, h, w = loss.size()

    if order == 'raster':
        n = math.isqrt(num_steps)
        _require_even_split(h, n, 'height')
        _require_even_split(w, n, 'width')
        tile_h, tile_w = h // n, w // n
        tiles = [loss[..., i:i + tile_h, j:j + tile_w]
                 for i in range(0, h, tile_h) for j in range(0, w, tile_w)]
        return torch.stack(tiles)  # (t, b, c, h//n, w//n)
    elif order == 'rowwise':
        _require_even_split(h, num_steps, 'height')
        return torch.stack(torch.split(loss, h // num_steps, dim=-2))
    elif order == 'colwise':
        _require_even_split(w, num_steps, 'width')
        return torch.stack(torch.split(loss, w // num_steps, dim=-1))
    else:
        raise NotImplementedError()
