import copy
import logging
from typing import Iterable

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F

import models.vit as vit

# Module-level logger. ``logging.getLogger()`` called without a name returns
# the root logger, so messages use whatever handlers/level the root has.
logger = logging.getLogger()

def ortho_penalty(t):
    """Return the mean squared deviation of ``t @ t.T`` from the identity.

    The result is zero when the rows of ``t`` are orthonormal.

    Args:
        t: 2-D tensor whose rows are the vectors to be compared.

    Returns:
        A scalar tensor holding ``mean((t @ t.T - I) ** 2)``.
    """
    # Build the identity on the same device as ``t`` so the penalty works on
    # CPU as well as on any GPU, instead of always forcing the default CUDA
    # device.
    identity = torch.eye(t.shape[0], device=t.device)
    return ((t @ t.T - identity) ** 2).mean()

def tensor_prompt(a, b, c=None, ortho=False):
    """Create a trainable float parameter of shape ``(a, b)`` or ``(a, b, c)``.

    Args:
        a: Size of the first dimension.
        b: Size of the second dimension.
        c: Optional size of a third dimension; the parameter is 2-D when
            this is ``None``.
        ortho: When true the values are drawn with ``nn.init.orthogonal_``,
            otherwise with ``nn.init.uniform_``.

    Returns:
        A ``torch.nn.Parameter`` with ``requires_grad=True``.
    """
    dims = (a, b) if c is None else (a, b, c)
    # The storage starts uninitialised; the initializer below overwrites it.
    prompt = torch.nn.Parameter(torch.FloatTensor(*dims), requires_grad=True)
    initializer = nn.init.orthogonal_ if ortho else nn.init.uniform_
    initializer(prompt)
    return prompt


class CodaPrompt(nn.Module):
    """Prompt-pool model built around a frozen timm backbone.

    For every backbone block index listed in ``pos_e_prompt`` the module owns
    three parameters, registered as ``e_p_<idx>`` (prompt components),
    ``e_k_<idx>`` (keys) and ``e_a_<idx>`` (attention vectors).  Each holds
    ``e_pool`` components, of which ``int(e_pool / task_num)`` are assigned to
    every task.  ``task_count`` selects the slice belonging to the current
    task; call :meth:`process_task_count` to move on to the next task.
    """

    def __init__(self,
                 pos_e_prompt   : Iterable[int] = (0,1,2,3,4),
                 len_e_prompt   : int   = 20,
                 e_pool         : int   = 10,
                 task_num       : int   = 10,
                 num_classes    : int   = 100,
                 backbone_name  : str   = None,
                 key_dim        : int   = 768,
                 ortho_mu       : float = 0,
                 **kwargs):
        """Build the backbone and the per-layer prompt components.

        Args:
            pos_e_prompt: Indices of the backbone blocks that receive a prompt.
            len_e_prompt: Number of prompt tokens appended at each such block.
            e_pool: Total number of prompt components per layer.
            task_num: Number of tasks the pool is divided between.
            num_classes: Number of classes passed to ``timm.create_model``.
            backbone_name: timm model name; must not be ``None``.
            key_dim: Dimension of the key and attention vectors.
            ortho_mu: Weight of the orthogonality penalty added in training;
                the penalty is skipped when this is not positive.
            **kwargs: Stored on ``self.kwargs`` and otherwise unused here.
        """
        super().__init__()

        self.kwargs = kwargs
        self.ortho_mu = ortho_mu
        self.task_num = task_num
        self.num_classes = num_classes

        self.task_count = 0

        # Backbone: everything is frozen except the classifier.
        # NOTE(review): assumes the created model exposes its classifier as
        # ``fc`` -- confirm against the models registered by ``models.vit``.
        assert backbone_name is not None, 'backbone_name must be specified'
        self.add_module('backbone', timm.create_model(backbone_name, pretrained=True, num_classes=num_classes))
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.fc.weight.requires_grad = True
        self.backbone.fc.bias.requires_grad   = True

        # Slice the eprompt
        self.key_d = key_dim
        self.num_pt_per_task = int(e_pool / task_num)

        self.e_pool = e_pool
        self.len_e_prompt = len_e_prompt
        self.e_length = len(pos_e_prompt) if pos_e_prompt else 0
        self.register_buffer('pos_e_prompt', torch.tensor(pos_e_prompt, dtype=torch.int64))
        for e in self._prompt_layers():
            p = tensor_prompt(e_pool, self.len_e_prompt, self.backbone.num_features)
            k = tensor_prompt(e_pool, self.key_d)
            a = tensor_prompt(e_pool, self.key_d)
            # Re-draw the first task's components as an orthonormal set.
            p = self.gram_schmidt(p)
            k = self.gram_schmidt(k)
            a = self.gram_schmidt(a)
            setattr(self, f'e_p_{e}', p)
            setattr(self, f'e_k_{e}', k)
            setattr(self, f'e_a_{e}', a)

    def _prompt_layers(self):
        """Return the prompted block indices as a list of Python ints."""
        # One conversion instead of formatting 0-d tensors element by element.
        return self.pos_e_prompt.tolist()

    def prompt_tuning(self,
                      x        : torch.Tensor,
                      g_prompt : torch.Tensor,
                      e_prompt : torch.Tensor,
                      **kwargs):
        """Run the backbone blocks, appending prompt tokens at selected blocks.

        Args:
            x: Token sequence of shape ``(B, N, C)``.
            g_prompt: Unused; kept for interface compatibility.
            e_prompt: Prompt tokens that can be viewed as
                ``(B, e_length, len_e_prompt, C)``.
            **kwargs: Ignored.

        Returns:
            The token sequence after all blocks, shape ``(B, N, C)``.
        """
        B, N, C = x.size()

        e_prompt = e_prompt.contiguous().view(B, self.e_length, self.len_e_prompt, C)
        # Add the first positional embedding to every prompt token.
        first_pos = self.backbone.pos_embed[:, :1, :].unsqueeze(1)
        e_prompt = e_prompt + first_pos.expand(B, self.e_length, self.len_e_prompt, C)

        # Block index -> position of its prompt inside ``e_prompt``; a plain
        # dict lookup avoids a ``nonzero()`` call for every block.
        slot_of_block = {layer: slot for slot, layer in enumerate(self._prompt_layers())}

        for n, block in enumerate(self.backbone.blocks):
            slot = slot_of_block.get(n)
            if slot is not None:
                x = torch.cat((x, e_prompt[:, slot]), dim=1)

            x = block(x)
            # Drop the prompt tokens again so the length stays at N.
            x = x[:, :N, :]
        return x

    def forward(self, inputs : torch.Tensor) :
        """Classify ``inputs`` using prompts assembled from the component pool.

        A first gradient-free pass through the backbone produces a query (the
        normalised first token).  For each prompted block the query is
        weighted by the attention vectors, compared to the keys by cosine
        similarity, and the similarities weight a sum over prompt components.

        Returns:
            ``(logits, loss)`` in training mode, where ``loss`` is the
            accumulated orthogonality penalty (the int ``0`` when
            ``ortho_mu`` is not positive); only ``logits`` in eval mode.
        """
        with torch.no_grad():
            x = self.backbone.patch_embed(inputs)
            B, _, _ = x.size()

            cls_token = self.backbone.cls_token.expand(B, -1, -1)
            token_appended = torch.cat((cls_token, x), dim=1)
            x = self.backbone.pos_drop(token_appended + self.backbone.pos_embed)
            query = self.backbone.blocks(x)
            query = self.backbone.norm(query)[:, 0]

        g_p = None
        # Component range [s, f) belongs to the current task.
        s = self.task_count * self.num_pt_per_task
        f = (self.task_count+1) * self.num_pt_per_task
        loss = 0
        prompts = []
        for e in self._prompt_layers():
            K = getattr(self, f'e_k_{e}')
            A = getattr(self, f'e_a_{e}')
            p = getattr(self, f'e_p_{e}')
            if self.training:
                if self.task_count > 0:
                    # Earlier tasks' components are used but receive no gradient.
                    K = torch.cat((K[:s].detach().clone(), K[s:f]), dim=0)
                    A = torch.cat((A[:s].detach().clone(), A[s:f]), dim=0)
                    p = torch.cat((p[:s].detach().clone(), p[s:f]), dim=0)
                else:
                    K = K[s:f]
                    A = A[s:f]
                    p = p[s:f]
            else:
                # Evaluation uses every component learned so far.
                K = K[0:f]
                A = A[0:f]
                p = p[0:f]

            # Attention-weighted query per component: (b, d) x (k, d) -> (b, k, d)
            a_querry = torch.einsum('bd,kd->bkd', query, A)
            # Cosine similarity between weighted query and keys -> (b, k)
            n_K = nn.functional.normalize(K, dim=1)
            q = nn.functional.normalize(a_querry, dim=2)
            aq_k = torch.einsum('bkd,kd->bk', q, n_K)
            # Similarity-weighted sum of prompt components -> (b, len_e_prompt, d)
            prompts.append(torch.einsum('bk,kld->bld', aq_k, p))

            if self.training and self.ortho_mu > 0:
                loss += ortho_penalty(K) * self.ortho_mu
                loss += ortho_penalty(A) * self.ortho_mu
                loss += ortho_penalty(p.view(p.shape[0], -1)) * self.ortho_mu

        # Layers are laid out back to back along the token dimension.
        e_p = torch.cat(prompts, dim=1).unsqueeze(1)
        x = self.prompt_tuning(self.backbone.pos_drop(token_appended + self.backbone.pos_embed), g_p, e_p)
        x = self.backbone.norm(x)
        cls_token = x[:, 0]
        x = self.backbone.fc(cls_token)

        if self.training:
            return x, loss
        else:
            return x

    def process_task_count(self):
        """Advance to the next task and initialise its components.

        The new task's slice of every key, attention and prompt parameter is
        re-drawn with Gram-Schmidt so that it is orthogonal to the components
        of earlier tasks.  Each parameter is replaced by a new
        ``nn.Parameter`` object.  Nothing is re-initialised once
        ``task_count`` has reached ``task_num``.
        """
        self.task_count += 1

        # ``>=`` rather than ``!=``: calling this more often than there are
        # tasks must not index past the end of the pool.
        if self.task_count >= self.task_num:
            return

        # In the spirit of continual learning, the components of the new task
        # are re-initialised with Gram-Schmidt here rather than orthogonally
        # initialising the whole pool up front.
        #
        # code for this function is modified from:
        # https://github.com/legendongary/pytorch-gram-schmidt/blob/master/gram_schmidt.py
        for e in self._prompt_layers():
            K = getattr(self, f'e_k_{e}')
            A = getattr(self, f'e_a_{e}')
            P = getattr(self, f'e_p_{e}')
            k = self.gram_schmidt(K)
            a = self.gram_schmidt(A)
            p = self.gram_schmidt(P)
            setattr(self, f'e_p_{e}', p)
            setattr(self, f'e_k_{e}', k)
            setattr(self, f'e_a_{e}', a)

    def loss_fn(self, output, target):
        """Return the cross-entropy between ``output`` logits and ``target``."""
        return F.cross_entropy(output, target)

    def gram_schmidt(self, vv):
        """Return a new parameter whose current-task components are orthonormal.

        Components are indexed along the first dimension of ``vv``; a 3-D
        input is flattened to one vector per component.  With
        ``s = task_count * num_pt_per_task`` and
        ``f = (task_count + 1) * num_pt_per_task``:

        * components ``[0, s)`` are copied from ``vv`` unchanged,
        * components ``[s, f)`` are drawn from a standard normal, made
          orthogonal to all earlier components and scaled to unit norm,
        * components ``[f, end)`` are zero.

        Args:
            vv: 2-D or 3-D tensor (typically an ``nn.Parameter``).

        Returns:
            A new ``torch.nn.Parameter`` with the same shape as ``vv``.
        """
        # The result is wrapped in a fresh Parameter, so no autograd history
        # is needed for any of the intermediate in-place writes.
        with torch.no_grad():
            # check if the tensor is 3D and flatten the last two dimensions if necessary
            is_3d = vv.dim() == 3
            original_shape = vv.shape
            if is_3d:
                vv = vv.view(original_shape[0], -1)

            # swap rows and columns: one column per component
            vv = vv.T
            uu = torch.zeros_like(vv)

            # get starting point
            pt = self.num_pt_per_task
            s = int(self.task_count * pt)
            f = int((self.task_count + 1) * pt)
            if s > 0:
                uu[:, 0:s] = vv[:, 0:s]

            for k in range(s, f):
                vk = torch.randn_like(vv[:, k])
                uk = torch.zeros_like(vk)
                for j in range(k):
                    uj = uu[:, j]
                    denominator = (uj * uj).sum()
                    # A (near-)zero basis vector contributes no projection.
                    # Skipping it replaces the former "resample and retry",
                    # which could never succeed because ``uj`` does not change.
                    if denominator < 1e-8:
                        continue
                    uk = uk + (vk * uj).sum() / denominator * uj
                uu[:, k] = vk - uk

            # Normalise only after the whole slice has been orthogonalised.
            for k in range(s, f):
                uu[:, k] = uu[:, k] / uu[:, k].norm()

            # undo swapping of rows and columns
            uu = uu.T

            # return from 2D
            if is_3d:
                uu = uu.reshape(original_shape)

        return torch.nn.Parameter(uu)