import os
import math
import numpy as np

import torch
from torch import nn, optim
import torch.nn.functional as F

from utils import *
from filters import filters
from funcs import *
import timm

def scaled_dot_product(q, k, v, attn_mask=None, attn_drop=None):
    """Compute scaled dot-product attention.

    The attention logits are ``q @ k^T / sqrt(d_k)`` where ``d_k`` is the size
    of the last dimension of ``q``; they are soft-maxed over the last (key)
    dimension and used to weight ``v``.

    Args:
        q: Query tensor; the last two dimensions are (query length, d_k).
        k: Key tensor; the last two dimensions are (key length, d_k).
        v: Value tensor; the last two dimensions are (key length, d_v).
        attn_mask: Optional mask with at least 2 dimensions. Positions where
            the mask equals 0 are suppressed. A 2-D mask gets two leading
            singleton dimensions, a 3-D mask gets a singleton dimension
            inserted at index 1, so the mask is always broadcast as 4-D
            (presumably against ``(batch, heads, query, key)`` logits --
            confirm against callers).
        attn_drop: Optional callable (e.g. an ``nn.Dropout`` module) applied
            to the attention weights. ``None`` disables dropout.

    Returns:
        A ``(values, attention)`` tuple: the weighted values and the
        (post-dropout) attention weights.
    """

    def _expand_mask(mask):
        """Bring ``mask`` to 4 dimensions so it broadcasts over the logits."""
        assert mask.ndim >= 2, "Mask must be at least 2-dimensional with seq_length x seq_length"
        if mask.ndim == 3:
            mask = mask.unsqueeze(1)
        while mask.ndim < 4:
            mask = mask.unsqueeze(0)
        return mask

    d_k = q.size()[-1]
    attn_logits = torch.matmul(q, k.transpose(-2, -1))
    attn_logits = attn_logits / math.sqrt(d_k)

    if attn_mask is not None:
        attn_mask = _expand_mask(attn_mask)
        # -9e15 is not representable in low-precision dtypes such as float16,
        # so fall back to the dtype's most negative finite value there. For
        # float32/float64 the fill value stays exactly -9e15.
        fill_value = max(-9e15, torch.finfo(attn_logits.dtype).min)
        attn_logits = attn_logits.masked_fill(attn_mask == 0, fill_value)

    attention = F.softmax(attn_logits, dim=-1)
    # Dropout is optional: only apply it when a dropout callable was given.
    if attn_drop is not None:
        attention = attn_drop(attention)

    values = torch.matmul(attention, v)
    return values, attention

class MultiheadAttention(nn.Module):
    """Multi-head attention with separate query/key/value projections.

    Inputs are projected, split into ``num_heads`` heads of size
    ``dim // num_heads``, attended with ``scaled_dot_product`` and merged
    back through an output projection.

    Args:
        dim: Embedding size of the inputs and the output.
        num_heads: Number of attention heads; must divide ``dim``.
        bias: Whether the q/k/v projections have a bias term.
        qk_norm: Accepted for interface compatibility; currently unused.
        qk_proj: If ``False``, queries and keys are used unprojected
            (``nn.Identity``); values are always projected.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
        norm_layer: Accepted for interface compatibility; currently unused.
    """

    def __init__(
            self,
            dim,
            num_heads=8,
            bias=False,
            qk_norm=False,
            qk_proj = True,
            attn_drop=0.,
            proj_drop=0.,
            norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Not read inside this class (scaled_dot_product derives the scaling
        # from the query's last dimension); kept as a public attribute.
        self.scale = self.head_dim ** -0.5
        self.q_proj = nn.Linear(dim, dim, bias=bias) if qk_proj else nn.Identity()
        self.k_proj = nn.Linear(dim, dim, bias=bias) if qk_proj else nn.Identity()
        self.v_proj = nn.Linear(dim, dim, bias=bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def _split_heads(self, x):
        """Reshape ``(batch, length, dim)`` to ``(batch, heads, length, head_dim)``."""
        batch, length = x.shape[0], x.shape[1]
        return x.reshape(batch, length, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, query, key, value, attn_mask=None, return_attention=True):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: Tensor of shape ``(B, N, C)``.
            key: Tensor of shape ``(B, S, C)``; ``S`` may differ from ``N``.
            value: Tensor of shape ``(B, S, C)``.
            attn_mask: Optional mask forwarded to ``scaled_dot_product``.
            return_attention: If true, also return the attention weights.

        Returns:
            The output of shape ``(B, N, C)``, or an ``(output, attention)``
            tuple when ``return_attention`` is true.
        """
        B, N, C = query.shape
        # Each tensor is split using its own sequence length so that key/value
        # sequences need not be as long as the query sequence.
        v = self._split_heads(self.v_proj(value))
        q = self._split_heads(self.q_proj(query))
        k = self._split_heads(self.k_proj(key))

        x, attn = scaled_dot_product(q, k, v, attn_mask=attn_mask, attn_drop=self.attn_drop)

        # Merge the heads back: (B, heads, N, head_dim) -> (B, N, C).
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        if return_attention:
            return x, attn
        return x
        
class CosineWarmupScheduler(optim.lr_scheduler._LRScheduler):
    """Linear warm-up followed by cosine decay of the learning rate.

    The multiplicative factor rises linearly from 0 to 1 over the first
    ``warmup`` steps, then follows half a cosine period down to
    ``min_lr / base_lr`` at step ``max_iters`` and stays there afterwards.

    Args:
        optimizer: Wrapped optimizer; every param group must start with a
            learning rate equal to ``base_lr``.
        warmup: Number of warm-up steps. ``0`` disables warm-up.
        max_iters: Step at which the cosine decay reaches its minimum.
        min_lr: Learning rate reached at the end of the decay.
        base_lr: Peak learning rate, used to convert ``min_lr`` to a factor.
    """

    def __init__(self, optimizer, warmup, max_iters, min_lr=0, base_lr=1e-3):
        # These must be set before the base constructor runs, because it
        # performs an initial step that calls get_lr().
        self.warmup = warmup
        self.max_num_iters = max_iters
        self.min_lr = min_lr
        self.base_lr = base_lr
        super().__init__(optimizer)
        assert all(lr == self.base_lr for lr in self.base_lrs), \
            "all optimizer param groups must use base_lr as their learning rate"

    def get_lr(self):
        """Return the learning rate of every param group for the current step."""
        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]

    def get_lr_factor(self, epoch):
        """Return the multiplier applied to the base learning rate at ``epoch``."""
        # Linear warm-up; skipped entirely when warmup is 0 to avoid 0/0.
        if self.warmup > 0 and epoch <= self.warmup:
            return epoch * 1.0 / self.warmup

        decay_span = self.max_num_iters - self.warmup
        if decay_span > 0:
            # Clamp so the factor stays at its floor past max_iters instead of
            # climbing back up the next half of the cosine.
            progress = min(max((epoch - self.warmup) / decay_span, 0.0), 1.0)
        else:
            # No room for a decay phase: go straight to the floor.
            progress = 1.0

        scale = 0.5 * (1 + math.cos(math.pi * progress))
        # With base_lr == 0 every learning rate is 0 regardless of the factor.
        floor = self.min_lr / self.base_lr if self.base_lr else 0.0
        return scale + (1 - scale) * floor