from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed.nn

@dataclass
class BufferState:
    """Fixed-capacity FIFO of paired ``img``/``txt`` feature rows.

    Backing tensors of shape ``[cap, dim]`` are allocated lazily by the first
    non-empty :meth:`add`, which also sets ``dim``, ``device`` and ``dtype`` to
    those of that batch. :meth:`get` returns the stored rows oldest first.
    """

    cap: int
    dim: int
    device: torch.device
    dtype: torch.dtype
    ptr: int = 0  # write cursor; once the store has wrapped, rows at/after ptr are the oldest
    size: int = 0  # number of valid rows, never more than cap
    img: Optional[torch.Tensor] = None  # [cap, dim] once allocated
    txt: Optional[torch.Tensor] = None  # [cap, dim] once allocated

    def _maybe_init(self, dim: int, device: torch.device, dtype: torch.dtype):
        """Allocate the backing tensors on first use; later calls are no-ops."""
        if self.img is not None and self.txt is not None:
            return
        self.dim, self.device, self.dtype = dim, device, dtype
        shape = (self.cap, dim)
        self.img = torch.empty(shape, device=device, dtype=dtype)
        self.txt = torch.empty(shape, device=device, dtype=dtype)

    def add(self, img: torch.Tensor, txt: torch.Tensor):
        """Push a batch of paired ``[N, D]`` rows, evicting the oldest rows when full.

        Does nothing when ``cap == 0`` or the batch is empty. A batch with
        ``N >= cap`` rows replaces the whole store with its last ``cap`` rows.
        """
        if self.cap == 0 or img.numel() == 0:
            return
        assert img.shape == txt.shape and img.dim() == 2, "img/txt must be [N, D] and same shape"
        n, d = img.shape
        self._maybe_init(d, img.device, img.dtype)
        cap = self.cap

        if n >= cap:
            # The batch alone fills the store: keep only its newest `cap` rows.
            self.img.copy_(img[-cap:])
            self.txt.copy_(txt[-cap:])
            self.ptr, self.size = 0, cap
            return

        stop = self.ptr + n
        if stop > cap:
            # Not enough room after the cursor: shift the kept rows left by the
            # overflow and place the new batch at the tail.
            overflow = stop - cap
            for store, new in ((self.img, img), (self.txt, txt)):
                store[:-overflow] = store[overflow:].clone()
                store[-n:] = new
            self.ptr = cap
        else:
            self.img[self.ptr:stop] = img
            self.txt[self.ptr:stop] = txt
            self.ptr = stop
        self.size = min(self.size + n, cap)

    def get(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(img, txt)`` with rows ordered oldest to newest.

        When nothing is stored, returns two ``[0, dim]`` tensors (``[0, 0]`` on
        CPU as float32 if storage was never allocated).
        """
        if self.size == 0:
            allocated = self.img is not None
            device = self.device if allocated else torch.device("cpu")
            dtype = self.dtype if allocated else torch.float32
            width = self.dim if allocated else 0
            return (torch.empty((0, width), device=device, dtype=dtype),
                    torch.empty((0, width), device=device, dtype=dtype))

        if self.size < self.cap or self.ptr == 0:
            return self.img[: self.size], self.txt[: self.size]

        # Full, wrapped store: rows from the cursor onward are older than those before it.
        split = self.ptr
        img = torch.cat((self.img[split:], self.img[:split]))
        txt = torch.cat((self.txt[split:], self.txt[:split]))
        return img, txt


class CLIPBufferLoss(nn.Module):
    """Symmetric CLIP (InfoNCE) loss whose negatives are extended with a feature buffer.

    Each step, features are gathered across processes (when there are several),
    concatenated with the buffered features from earlier steps, and scored with
    scaled dot products. Afterwards the gathered, detached features of the step
    can be pushed into the buffer.

    Args:
        accelerator: object exposing ``num_processes``, ``process_index`` (or
            ``local_process_index``) and ``gather`` -- presumably an
            ``accelerate.Accelerator``; confirm against the caller.
        local_loss: use only this process's rows as anchors (logits of shape
            ``[B_local, N_all]``) instead of every gathered row.
        local_loss2: when ``local_loss`` is off, exclude buffered rows from the
            anchors so only the gathered current batch contributes to the loss.
        gather_with_grad: gather with ``torch.distributed.nn.all_gather`` so that
            gradients flow into every process's features.
        buffer_size: capacity (in rows) of the feature buffer; 0 disables it.
        dtype: stored on the module and handed to the buffer, but the buffer
            allocates storage with the dtype of the first features it receives.
        cache_labels: reuse the integer label tensor between calls of equal size.
    """

    def __init__(
        self,
        accelerator,
        local_loss: bool = False,
        local_loss2: bool = False,
        gather_with_grad: bool = False,
        buffer_size: int = 0,
        dtype: torch.dtype = torch.float16,
        cache_labels: bool = False,
    ) -> None:
        super().__init__()
        self.accelerator = accelerator
        self.dtype = dtype
        self.local_loss = local_loss
        self.local_loss2 = local_loss2
        self.gather_with_grad = gather_with_grad
        self.buffer = BufferState(cap=buffer_size, dim=0, device=torch.device("cpu"), dtype=dtype)

        # label cache state (only filled when cache_labels is set)
        self.cache_labels = cache_labels
        self.prev_num_logits = 0
        self.labels = {}

    def _world_size(self) -> int:
        """Number of processes; 1 when the accelerator does not report it."""
        return getattr(self.accelerator, "num_processes", 1)

    def _rank(self) -> int:
        """Global index of this process, i.e. its slot in ``accelerator.gather`` output."""
        # Gathered tensors are concatenated in global-rank order. local_process_index
        # only identifies the process within its node, so it is wrong on multi-node runs.
        acc = self.accelerator
        return getattr(acc, "process_index", getattr(acc, "local_process_index", 0))

    def _gather_features(self, preds: torch.Tensor, targs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(preds, targs)`` from all processes, concatenated along dim 0 by rank.

        On a single process the inputs are returned unchanged. With
        ``gather_with_grad`` the gather is differentiable. Otherwise it runs without
        grad; unless ``local_loss`` is set, this process's slice is replaced by the
        original tensors so gradients still reach the local features. With
        ``local_loss`` the anchors are the local tensors themselves, so the gathered
        copies serve only as (non-differentiable) columns.
        """
        world = self._world_size()
        if world <= 1:
            return preds, targs

        if self.gather_with_grad:
            all_preds = torch.cat(torch.distributed.nn.all_gather(preds), dim=0)
            all_targs = torch.cat(torch.distributed.nn.all_gather(targs), dim=0)
            return all_preds, all_targs

        with torch.no_grad():
            gathered_preds = self.accelerator.gather(preds)
            gathered_targs = self.accelerator.gather(targs)
        if self.local_loss:
            return gathered_preds, gathered_targs

        # assumes every process contributes the same batch size (equal-sized chunks)
        pred_chunks = list(gathered_preds.chunk(world, dim=0))
        targ_chunks = list(gathered_targs.chunk(world, dim=0))
        rank = self._rank()
        pred_chunks[rank] = preds
        targ_chunks[rank] = targs
        return torch.cat(pred_chunks, dim=0), torch.cat(targ_chunks, dim=0)

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        """Return integer class targets ``arange(num_logits)``.

        In ``local_loss`` mode with several processes the labels are offset by
        ``num_logits * rank`` so they index this process's slice of the gathered
        columns. Cached per device when ``cache_labels`` is set.
        """
        if self.prev_num_logits != num_logits or device not in self.labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if self._world_size() > 1 and self.local_loss:
                labels = labels + num_logits * self._rank()
            if self.cache_labels:
                self.labels[device] = labels
                self.prev_num_logits = num_logits
        else:
            labels = self.labels[device]
        return labels

    def _buffer_features(self, device):
        """Return buffered ``(preds, targs)`` moved to ``device``, or ``None`` if none are used."""
        if self.buffer.cap > 0 and self.buffer.size > 0:
            buf_preds, buf_targs = self.buffer.get()
            return buf_preds.to(device=device), buf_targs.to(device=device)
        return None

    def _compute_logits(self, preds, targs, temp, logit_bias):
        """Return ``(logits_per_brain, logits_per_clip, gathered_preds, gathered_targs)``.

        Columns are the gathered current batch followed by the buffered features.
        Rows are the local batch when ``local_loss``, otherwise every column item.
        """
        buffered = self._buffer_features(preds.device)
        gathered_preds, gathered_targs = self._gather_features(preds, targs)
        all_preds, all_targs = gathered_preds, gathered_targs
        if buffered is not None:
            all_preds = torch.cat([gathered_preds, buffered[0]], dim=0)
            all_targs = torch.cat([gathered_targs, buffered[1]], dim=0)

        if self.local_loss:
            anchor_preds, anchor_targs = preds, targs
        else:
            anchor_preds, anchor_targs = all_preds, all_targs
        logits_per_brain = (anchor_preds @ all_targs.t()) / temp
        logits_per_clip = (anchor_targs @ all_preds.t()) / temp

        if logit_bias is not None:
            logits_per_brain += logit_bias
            logits_per_clip += logit_bias
        return logits_per_brain, logits_per_clip, gathered_preds, gathered_targs

    def get_logits(self, preds, targs, temp=0.125, logit_bias=None):
        """Return ``(logits_per_brain, logits_per_clip)`` without modifying the buffer."""
        logits_per_brain, logits_per_clip, _, _ = self._compute_logits(preds, targs, temp, logit_bias)
        return logits_per_brain, logits_per_clip

    def forward(
        self,
        preds: torch.Tensor,
        targs: torch.Tensor,
        temp=0.125,
        add_to_buffer: bool = True,
        logit_bias=None
    ) -> torch.Tensor:
        """Return the symmetric CLIP loss (mean of brain->clip and clip->brain cross-entropy).

        Negatives are the gathered current batch plus buffered (detached) features.
        Anchors are this process's batch when ``local_loss``; the gathered current
        batch when ``local_loss2``; otherwise every row, buffered ones included.
        When ``add_to_buffer`` is true and the buffer is enabled, the gathered
        features of this step are pushed into the buffer after the loss is computed.
        """
        assert preds.shape == targs.shape and preds.dim() == 2
        device = preds.device

        logits_per_brain, logits_per_clip, gathered_preds, gathered_targs = self._compute_logits(
            preds, targs, temp, logit_bias
        )

        # Buffered rows come last, so excluding them as anchors keeps the leading rows.
        n_drop = self.buffer_size() if (self.local_loss2 and not self.local_loss) else 0
        n_anchor = logits_per_brain.shape[0] - n_drop
        labels = self.get_ground_truth(device, n_anchor)

        loss_b2c = F.cross_entropy(logits_per_brain[:n_anchor], labels)
        loss_c2b = F.cross_entropy(logits_per_clip[:n_anchor], labels)
        loss = (loss_b2c + loss_c2b) / 2

        if add_to_buffer and self.buffer.cap > 0:
            with torch.no_grad():
                self.buffer.add(gathered_preds.detach(), gathered_targs.detach())
        return loss

    @torch.no_grad()
    def buffer_size(self) -> int:
        """Number of feature rows currently held in the buffer."""
        return self.buffer.size

    @torch.no_grad()
    def clear_buffer(self):
        """Mark the buffer empty (storage is kept and overwritten on the next add)."""
        self.buffer.ptr = 0
        self.buffer.size = 0

    def extra_repr(self) -> str:
        return (
            f"local_loss={self.local_loss}, local_loss2={self.local_loss2}, "
            f"gather_with_grad={self.gather_with_grad}, "
            f"buffer_cap={self.buffer.cap} (size={self.buffer.size})"
        )

class BiMixCoBufferLoss(CLIPBufferLoss):
    """Bidirectional MixCo loss with a feature buffer.

    When ``perm`` and ``betas`` are passed to :meth:`forward`, row ``i`` of the
    current batch gets soft target ``betas[i]`` on item ``i`` and ``1 - betas[i]``
    on item ``perm[i]`` (presumably the caller mixed the brain inputs that way --
    confirm). The brain->clip loss uses this mixing matrix and the clip->brain loss
    its transpose, because ``logits_per_clip`` is the transpose of
    ``logits_per_brain``. Without ``perm``/``betas`` the loss is the hard-label
    CLIP loss of :class:`CLIPBufferLoss`.

    ``labels_buffer`` (CPU, oldest first) holds the mixing matrix of the buffered
    items so that buffered rows can act as anchors with their own soft targets.
    """

    def __init__(
        self,
        accelerator,
        local_loss: bool = False,
        local_loss2: bool = False,
        gather_with_grad: bool = False,
        buffer_size: int = 0,
        dtype: torch.dtype = torch.float16,
        cache_labels: bool = False,
    ) -> None:
        super().__init__(accelerator, local_loss, local_loss2, gather_with_grad, buffer_size, dtype, cache_labels)
        # [size, size] soft labels of the buffered items, kept on CPU
        self.labels_buffer = torch.tensor([], device=torch.device("cpu"))

    def _global_rank(self) -> int:
        """Index of this process in the output of ``accelerator.gather`` (0 when single-process)."""
        acc = self.accelerator
        if getattr(acc, "num_processes", 1) <= 1:
            return 0
        # gather concatenates in global-rank order; local_process_index is per node only
        return getattr(acc, "process_index", getattr(acc, "local_process_index", 0))

    def _mixco_targets(self, perm, betas, device) -> torch.Tensor:
        """Return the ``[G, G]`` mixing matrix of the gathered current batch.

        Row ``i`` holds ``betas[i]`` at column ``i`` and ``1 - betas[i]`` at column
        ``perm[i]`` (the latter overwrites the former when ``perm[i] == i``). With
        several processes, ``perm``/``betas`` are gathered and each process's
        permutation indices are offset by ``rank * B`` to address the gathered batch.
        """
        num_processes = getattr(self.accelerator, "num_processes", 1)
        if num_processes > 1:
            local_bs = perm.shape[0]
            perm = self.accelerator.gather(perm)
            betas = self.accelerator.gather(betas)
            perm = torch.cat(
                [p + i * local_bs for i, p in enumerate(perm.chunk(num_processes, dim=0))], dim=0
            )
        perm = perm.to(device)
        betas = betas.to(device)
        targets = torch.diag(betas)
        targets[torch.arange(betas.shape[0], device=device), perm] = 1 - betas
        return targets

    def _soft_targets(self, device, perm, betas):
        """Return ``(brain_to_clip, clip_to_brain, current)`` soft targets.

        ``current`` is the ``[G, G]`` mixing matrix of the gathered batch. In
        ``local_loss`` mode the targets are this process's rows of ``current`` and
        ``current.T``, zero-padded for the buffered columns (``[B, G + n_buf]``).
        Otherwise they are ``block_diag(current, labels_buffer)`` and its transpose.
        """
        current = self._mixco_targets(perm, betas, device)
        n_buf = self.buffer_size()
        if self.local_loss:
            local_bs = perm.shape[0]
            start = self._global_rank() * local_bs
            rows = slice(start, start + local_bs)
            pad = torch.zeros((local_bs, n_buf), device=current.device, dtype=current.dtype)
            return (
                torch.cat([current[rows], pad], dim=1),
                torch.cat([current.t()[rows], pad], dim=1),
                current,
            )
        full = current
        if n_buf > 0:
            assert self.labels_buffer.shape == (n_buf, n_buf), (
                f"labels_buffer shape {tuple(self.labels_buffer.shape)} does not match buffer size {n_buf}"
            )
            full = torch.block_diag(current, self.labels_buffer.to(device=device, dtype=current.dtype))
        return full, full.t(), current

    def get_ground_truth(self, device, num_logits, perm=None, betas=None) -> torch.Tensor:
        """Return brain->clip targets.

        Without ``perm``/``betas``: integer labels from :class:`CLIPBufferLoss`
        (cached when ``cache_labels`` is set). With them: the soft target matrix,
        rebuilt every call because it depends on this step's mixing; ``num_logits``
        is then unused.
        """
        if perm is None or betas is None:
            return super().get_ground_truth(device, num_logits)
        return self._soft_targets(device, perm, betas)[0]

    def _append_buffer_labels(self, new_labels: torch.Tensor) -> None:
        """Append ``new_labels`` (``[G, G]``) to ``labels_buffer``, keeping the newest ``cap`` items.

        The feature buffer also keeps only its most recent ``cap`` rows, so the two
        stay aligned. Target mass that pointed at evicted items is dropped with them.
        """
        new_labels = new_labels.detach().cpu()
        if self.buffer_size() == 0 or self.labels_buffer.dim() != 2:
            merged = new_labels
        else:
            merged = torch.block_diag(self.labels_buffer, new_labels.to(self.labels_buffer.dtype))
        cap = self.buffer.cap
        if merged.shape[0] > cap:
            merged = merged[-cap:, -cap:]
        self.labels_buffer = merged

    def forward(
        self,
        preds: torch.Tensor,
        targs: torch.Tensor,
        temp=0.125,
        perm=None,
        betas=None,
        add_to_buffer: bool = True,
        logit_bias=None
    ) -> torch.Tensor:
        """Return the symmetric MixCo loss (soft targets) or CLIP loss (hard targets).

        Negatives are the gathered current batch followed by buffered (detached)
        features. Anchors are this process's batch when ``local_loss``; the gathered
        current batch when ``local_loss2``; otherwise every row. When
        ``add_to_buffer`` is true and the buffer is enabled, the gathered features
        and their label matrix (identity for hard targets) are pushed afterwards.
        """
        assert preds.shape == targs.shape and preds.dim() == 2
        device = preds.device

        gathered_preds, gathered_targs = self._gather_features(preds, targs)
        all_preds, all_targs = gathered_preds, gathered_targs
        if self.buffer.cap > 0 and self.buffer.size > 0:
            buf_preds, buf_targs = self.buffer.get()
            all_preds = torch.cat([gathered_preds, buf_preds.to(device=device)], dim=0)
            all_targs = torch.cat([gathered_targs, buf_targs.to(device=device)], dim=0)

        if self.local_loss:
            anchor_preds, anchor_targs = preds, targs
        else:
            anchor_preds, anchor_targs = all_preds, all_targs
        logits_per_brain = (anchor_preds @ all_targs.t()) / temp
        logits_per_clip = (anchor_targs @ all_preds.t()) / temp
        if logit_bias is not None:
            logits_per_brain += logit_bias
            logits_per_clip += logit_bias

        assert logits_per_brain.shape == logits_per_clip.shape, f"logits shape must match, but get {logits_per_brain.shape} and {logits_per_clip.shape}"

        # Buffered rows come last; with local_loss2 they are not used as anchors.
        n_drop = self.buffer_size() if (self.local_loss2 and not self.local_loss) else 0
        n_anchor = logits_per_brain.shape[0] - n_drop

        current = None
        if perm is not None and betas is not None:
            target_b2c, target_c2b, current = self._soft_targets(device, perm, betas)
            target_b2c, target_c2b = target_b2c[:n_anchor], target_c2b[:n_anchor]
        else:
            # integer labels are symmetric: the same vector serves both directions
            target_b2c = target_c2b = self.get_ground_truth(device, n_anchor)

        loss_b2c = F.cross_entropy(logits_per_brain[:n_anchor], target_b2c)
        loss_c2b = F.cross_entropy(logits_per_clip[:n_anchor], target_c2b)
        loss = (loss_b2c + loss_c2b) / 2

        if add_to_buffer and self.buffer.cap > 0:
            if current is None:
                # hard targets: each gathered item is its own (one-hot) target
                current = torch.diag(torch.ones(gathered_preds.shape[0], device=device))
            self._append_buffer_labels(current)
            with torch.no_grad():
                self.buffer.add(gathered_preds.detach(), gathered_targs.detach())
        return loss

class SoftCLIPBufferLoss(CLIPBufferLoss):
    """CLIP-style loss with soft targets derived from target/target similarity.

    Each anchor row is trained toward ``softmax(anchor_targ @ candidate_targs.T / temp)``
    instead of a one-hot label; both directions (brain->clip and clip->brain) share
    that target. Candidates are the gathered current batch followed by any buffered
    features.
    """

    def __init__(
        self,
        accelerator,
        local_loss: bool = False,
        local_loss2: bool = False,
        gather_with_grad: bool = False,
        buffer_size: int = 0,
        dtype: torch.dtype = torch.float16,
        cache_labels: bool = False,
    ) -> None:
        super().__init__(
            accelerator,
            local_loss=local_loss,
            local_loss2=local_loss2,
            gather_with_grad=gather_with_grad,
            buffer_size=buffer_size,
            dtype=dtype,
            cache_labels=cache_labels,
        )

    def forward(
        self,
        preds: torch.Tensor,  # [B, D]
        targs: torch.Tensor,  # [B, D]
        temp=0.125,
        add_to_buffer: bool = True,
        logit_bias=None,
    ) -> torch.Tensor:
        """Return the mean of the brain->clip and clip->brain soft cross-entropies.

        Anchors (loss rows) are this process's batch when ``local_loss``; the
        gathered batch without buffered rows when ``local_loss2`` (and the buffer is
        non-empty); otherwise every candidate, buffered rows included. Buffered
        features were stored detached. When ``add_to_buffer`` is true and the buffer
        is enabled, the gathered features of this step are pushed afterwards.
        """
        assert preds.shape == targs.shape and preds.dim() == 2

        use_buffer = self.buffer.cap > 0 and self.buffer.size > 0
        if use_buffer:
            target_device = preds.device
            stored_preds, stored_targs = self.buffer.get()
            if stored_preds.device != target_device:
                stored_preds = stored_preds.to(device=target_device)
                stored_targs = stored_targs.to(device=target_device)

        gathered_preds, gathered_targs = self._gather_features(preds, targs)
        cand_preds, cand_targs = gathered_preds, gathered_targs
        if use_buffer:
            cand_preds = torch.cat([gathered_preds, stored_preds], dim=0)
            cand_targs = torch.cat([gathered_targs, stored_targs], dim=0)

        if self.local_loss:
            anchor_preds, anchor_targs = preds, targs
        else:
            anchor_preds, anchor_targs = cand_preds, cand_targs
        logits_per_brain = (anchor_preds @ cand_targs.t()) / temp
        logits_per_clip = (anchor_targs @ cand_preds.t()) / temp
        clip_clip = (anchor_targs @ cand_targs.t()) / temp  # logits of the soft targets
        if logit_bias is not None:
            logits_per_brain += logit_bias
            logits_per_clip += logit_bias

        n_stored = self.buffer_size()
        drop_stored = self.local_loss2 and not self.local_loss and n_stored > 0
        rows = slice(None, -n_stored) if drop_stored else slice(None)

        def soft_xent(logits):
            return -(logits[rows].log_softmax(-1) * clip_clip[rows].softmax(-1)).sum(-1).mean()

        loss = (soft_xent(logits_per_brain) + soft_xent(logits_per_clip)) / 2

        if add_to_buffer and self.buffer.cap > 0:
            with torch.no_grad():
                self.buffer.add(gathered_preds.detach(), gathered_targs.detach())
        return loss