# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""
A simple, flexible implementation of a GPT model.
Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
"""

import math
from typing import Any, Mapping

import torch
import torch.nn as nn
import torch.nn.functional as F
from composer.metrics.nlp import LanguageCrossEntropy, Perplexity
from composer.models.base import ComposerModel
from torchmetrics import Metric
from torch import Tensor

import numpy as np
import os


class ComposerMosaicEnsGPT(ComposerModel):
    """Ensemble of two GPT models whose next-token predictions are merged.

    ``forward`` runs ``model1.model`` (and, unless ``ens_type == '1'``,
    ``model2.model``) on the batch and combines the results according to
    ``ens_type``:

    * ``'1'``: raw output of model 1.
    * ``'random'``: raw output of model 1 or model 2, chosen by one
      ``torch.rand`` draw per call.
    * ``'min'``: element-wise minimum of the two softmax distributions.
    * ``'am'``: arithmetic mean of the two softmax distributions.
    * ``'gm'``: geometric mean, ``sqrt(p1 + 1e-8) * sqrt(p2 + 1e-8)``.
    * ``'kl_opt'``: ``p1**lam * p2**(1 - lam)`` where ``lam`` is picked from
      ``{1/20, ..., 19/20}`` to minimise ``|sum(q * log(p2 / p1))|`` for the
      normalised mixture ``q``.
    * ``'large_entropy'``: per row, the distribution with the larger entropy.

    All modes except ``'1'`` and ``'random'`` renormalise the combined
    distribution over the last dimension and return its logarithm, flattened
    to two dimensions.

    Args:
        cfg: Config object; ``cfg.vocab_size`` is read for the metrics and
            ``cfg`` is forwarded to ``ComposerMosaicGPT`` when a member model
            is not supplied.
        model1: First member model, or ``None`` to build one from ``cfg``.
        model2: Second member model, or ``None`` to build one from ``cfg``.
        ens_type: Name of the combination rule (see above).
    """

    # Every combination rule ``forward`` knows how to handle.
    _ENS_TYPES = ('1', 'random', 'min', 'am', 'gm', 'kl_opt', 'large_entropy')

    def __init__(self, cfg, model1=None, model2=None, ens_type='min'):
        super().__init__()
        if model1 is None:
            self.model1 = ComposerMosaicGPT(cfg)
        else:
            self.model1 = model1
        if model2 is None:
            self.model2 = ComposerMosaicGPT(cfg)
        else:
            self.model2 = model2
        self.ens_type = ens_type
        self.train_metrics = {
            'LanguageCrossEntropy': LanguageCrossEntropy(cfg.vocab_size),
            'Perplexity': Perplexity(),
        }
        self.eval_metrics = {
            'LanguageCrossEntropy': LanguageCrossEntropy(cfg.vocab_size),
            'Perplexity': Perplexity(),
        }

    def get_targets(self, batch):
        """Return ``batch["labels"]`` rolled by -1 with the last column set to -100.

        -100 is the ``ignore_index`` used by :meth:`loss`.
        """
        targets = torch.roll(batch["labels"], shifts=-1)
        targets[:, -1] = -100
        return targets

    def forward(self, batch):
        """Run the member models and merge their outputs per ``self.ens_type``.

        Raises:
            ValueError: If ``self.ens_type`` is not a supported rule.
        """
        # Fail fast with a clear message; previously an unknown rule crashed
        # later with an UnboundLocalError on ``out_ens``.
        if self.ens_type not in self._ENS_TYPES:
            raise ValueError(
                f'Unknown ens_type={self.ens_type!r}; '
                f'expected one of {self._ENS_TYPES}')

        out1 = self.model1.model(batch['input_ids'],
                                 key_padding_mask=batch['attention_mask'].bool())
        if self.ens_type == '1':
            return out1

        out2 = self.model2.model(batch['input_ids'],
                                 key_padding_mask=batch['attention_mask'].bool())

        if self.ens_type == 'random':
            if torch.rand(1)[0] > .5:
                return out1
            else:
                return out2

        # From here on both outputs are probability rows of shape (-1, last_dim).
        out1 = out1.view(-1, out1.size(-1))
        out1 = F.softmax(out1, dim=-1)
        out2 = out2.view(-1, out2.size(-1))
        out2 = F.softmax(out2, dim=-1)

        if self.ens_type == 'min':
            out_ens = torch.minimum(out1, out2)
        elif self.ens_type == 'am':
            out_ens = (out1 + out2) / 2
        elif self.ens_type == 'gm':
            out_ens = ((out1+1e-8) ** .5) * ((out2+1e-8) ** .5)
        elif self.ens_type == 'kl_opt':
            opt_val = 10000
            opt_lam = -1
            # The log-ratio does not depend on lam, so compute it once.
            log_ratio = torch.log(out2 / out1)
            for lam2 in range(1, 20):
                lam = lam2 / 20
                out_ens = ((out1+1e-8)**lam) * ((out2+1e-8)**(1-lam))
                out_ens /= out_ens.sum(dim=-1, keepdim=True)
                val = torch.sum(out_ens * log_ratio)
                if abs(val) < opt_val:
                    opt_val = abs(val)
                    opt_lam = lam

            # NOTE(review): the search above smooths with 1e-8 but the final
            # mixture does not; kept as in the original.
            out_ens = (out1**opt_lam) * (out2**(1-opt_lam))
        else:  # 'large_entropy' (membership was validated above)
            ent1 = -torch.sum(out1 * torch.log(out1), dim=-1)
            ent2 = -torch.sum(out2 * torch.log(out2), dim=-1)
            # Row-wise select: take out1 where its entropy is strictly larger.
            out_ens = (ent1 > ent2).float().unsqueeze(-1) * out1 + (ent1 <= ent2).float().unsqueeze(-1) * out2
        out_ens /= out_ens.sum(dim=-1, keepdim=True)

        out_ens = torch.log(out_ens)

        return out_ens

    def eval_forward(self, batch, outputs=None):
        """Return ``outputs`` if given, otherwise compute them with :meth:`forward`."""
        return outputs if outputs is not None else self.forward(batch)

    def loss(self, outputs, batch):
        """Cross-entropy of the flattened outputs against the shifted labels."""
        targets = self.get_targets(batch)
        return F.cross_entropy(outputs.view(-1, outputs.size(-1)),
                               targets.view(-1),
                               ignore_index=-100)

    def get_metrics(self, is_train=False):
        """Return the train or eval metric dict."""
        return self.train_metrics if is_train else self.eval_metrics

    def update_metric(self, batch, outputs, metric):
        """Feed the flattened outputs and shifted labels to ``metric``."""
        outputs = outputs.view(-1, outputs.size(-1))
        targets = self.get_targets(batch).view(-1)
        metric.update(outputs, targets)


class ComposerMosaicMinGPT(ComposerModel):
    """Two-model ensemble using the normalised element-wise minimum.

    ``forward`` softmaxes the output of ``model1.model`` and ``model2.model``,
    takes the element-wise minimum, renormalises over the last dimension and
    returns the logarithm as a 2-D tensor.

    Args:
        cfg: Config object; only ``cfg.vocab_size`` is read (for the metrics).
        model1: First member; must expose a callable ``.model``.
        model2: Second member; must expose a callable ``.model``.
    """

    def __init__(self, cfg, model1, model2):
        super().__init__()
        self.model1 = model1
        self.model2 = model2
        vocab_size = cfg.vocab_size
        self.train_metrics = {
            'LanguageCrossEntropy': LanguageCrossEntropy(vocab_size),
            'Perplexity': Perplexity(),
        }
        self.eval_metrics = {
            'LanguageCrossEntropy': LanguageCrossEntropy(vocab_size),
            'Perplexity': Perplexity(),
        }

    def get_targets(self, batch):
        """Labels rolled by -1, with the last column replaced by -100 (ignored by the loss)."""
        shifted = torch.roll(batch["labels"], shifts=-1)
        shifted[:, -1] = -100
        return shifted

    def forward(self, batch):
        """Return log of the renormalised minimum of the two softmax distributions."""
        token_ids = batch['input_ids']
        raw1 = self.model1.model(token_ids,
                                 key_padding_mask=batch['attention_mask'].bool())
        raw2 = self.model2.model(token_ids,
                                 key_padding_mask=batch['attention_mask'].bool())

        # Flatten everything but the last dimension, then turn logits into probabilities.
        probs1 = F.softmax(raw1.view(-1, raw1.size(-1)), dim=-1)
        probs2 = F.softmax(raw2.view(-1, raw2.size(-1)), dim=-1)

        smallest = torch.minimum(probs1, probs2)
        row_mass = smallest.sum(dim=-1, keepdim=True)
        return torch.log(smallest / row_mass)

    def eval_forward(self, batch, outputs=None):
        """Reuse ``outputs`` when provided, otherwise run :meth:`forward`."""
        if outputs is not None:
            return outputs
        return self.forward(batch)

    def loss(self, outputs, batch):
        """Cross-entropy between flattened outputs and shifted labels (ignore_index=-100)."""
        flat_outputs = outputs.view(-1, outputs.size(-1))
        flat_targets = self.get_targets(batch).view(-1)
        return F.cross_entropy(flat_outputs, flat_targets, ignore_index=-100)

    def get_metrics(self, is_train=False):
        """Return ``train_metrics`` when ``is_train`` is truthy, else ``eval_metrics``."""
        if is_train:
            return self.train_metrics
        return self.eval_metrics

    def update_metric(self, batch, outputs, metric):
        """Update ``metric`` with flattened outputs and flattened shifted labels."""
        flat_outputs = outputs.view(-1, outputs.size(-1))
        flat_targets = self.get_targets(batch).view(-1)
        metric.update(flat_outputs, flat_targets)


class ComposerMosaicGPTSave(ComposerModel):
    """Wrap a model and dump its top-k next-token probabilities to ``.npy`` files.

    Every call to :meth:`update_metric` writes three arrays into
    ``save_dir``, named ``{name}_outputs_{ind}.npy``,
    ``{name}_indices_{ind}.npy`` and ``{name}_labels_{ind}.npy``, then
    increments ``ind``.

    Args:
        cfg: Config object; only ``cfg.vocab_size`` is read (for the metrics).
        model: Wrapper whose ``.model`` attribute is the network to evaluate.
        name: Prefix for the saved file names.
    """

    # Destination directory for the dumps; must end with a path separator
    # because file names are appended by plain string concatenation.
    save_dir = '/n/holystore01/LABS/barak_lab/Users/nvyas/logs/copyright/'
    # Number of highest-probability entries kept per row (capped at the
    # size of the last dimension).
    top_k = 4000

    def __init__(self, cfg, model, name):
        super().__init__()
        self.model = model.model
        self.name = name
        self.ind = 0  # running index used in the dump file names
        self.train_metrics = {
            'LanguageCrossEntropy': LanguageCrossEntropy(cfg.vocab_size),
            'Perplexity': Perplexity(),
        }
        self.eval_metrics = {
            'LanguageCrossEntropy': LanguageCrossEntropy(cfg.vocab_size),
            'Perplexity': Perplexity(),
        }

    def get_targets(self, batch):
        """Return ``batch["labels"]`` rolled by -1 with the last column set to -100."""
        targets = torch.roll(batch["labels"], shifts=-1)
        targets[:, -1] = -100
        return targets

    def forward(self, batch):
        """Run the wrapped network on ``input_ids`` with the boolean attention mask."""
        return self.model(batch['input_ids'],
                          key_padding_mask=batch['attention_mask'].bool())

    def eval_forward(self, batch, outputs=None):
        """Return ``outputs`` if given, otherwise compute them with :meth:`forward`."""
        return outputs if outputs is not None else self.forward(batch)

    def loss(self, outputs, batch):
        """Cross-entropy of the flattened outputs against the shifted labels."""
        targets = self.get_targets(batch)
        return F.cross_entropy(outputs.view(-1, outputs.size(-1)),
                               targets.view(-1),
                               ignore_index=-100)

    def get_metrics(self, is_train=False):
        """Return the train or eval metric dict."""
        return self.train_metrics if is_train else self.eval_metrics

    def update_metric(self, batch, outputs, metric):
        """Update ``metric`` and save the top-k probabilities, indices and labels.

        The files are written on every call, so one dump is produced per
        (batch, metric) pair this method is invoked with.
        """
        outputs = outputs.view(-1, outputs.size(-1))
        outputs = F.softmax(outputs, dim=-1)
        targets = self.get_targets(batch).view(-1)
        # NOTE(review): the metric receives softmax probabilities here, not
        # raw logits -- confirm this is what the metric expects.
        metric.update(outputs, targets)

        # Cap k so that a last dimension smaller than ``top_k`` does not
        # make torch.topk raise.
        k = min(self.top_k, outputs.size(-1))
        outputs, indices = torch.topk(outputs, k)
        print('outputs shape', outputs.shape)
        print('\nignored mass in outputs', (1 - outputs.sum(dim=-1)).mean().item())

        path_folder = self.save_dir
        os.makedirs(path_folder, exist_ok=True)
        print('creating new file')
        np.save(path_folder+f'{self.name}_outputs_{self.ind}.npy', outputs.float().detach().cpu().numpy())
        np.save(path_folder+f'{self.name}_indices_{self.ind}.npy', indices.detach().cpu().numpy())
        np.save(path_folder+f'{self.name}_labels_{self.ind}.npy', targets.detach().cpu().numpy())
        self.ind += 1
        print('creating new file - end')

class ComposerMosaicKLGPT(ComposerModel):
    """Pair of models whose "loss" is the KL divergence between their predictions.

    ``forward`` stacks the flattened outputs of ``model1.model`` and
    ``model2.model`` along dim 0; ``loss`` splits that stack in half again
    and evaluates ``F.kl_div`` on the two log-softmaxed halves.

    Args:
        cfg: Accepted for signature compatibility; not read here.
        model1: First member; must expose a callable ``.model``.
        model2: Second member; must expose a callable ``.model``.
    """

    def __init__(self, cfg, model1, model2):
        super().__init__()
        self.model1, self.model2 = model1, model2
        # No metrics are tracked for this wrapper.
        self.train_metrics = {}
        self.eval_metrics = {}

    def get_targets(self, batch):
        """Labels rolled by -1, with the last column replaced by -100."""
        shifted = torch.roll(batch["labels"], shifts=-1)
        shifted[:, -1] = -100
        return shifted

    def forward(self, batch):
        """Return model-1 rows followed by model-2 rows, each flattened to 2-D."""
        token_ids = batch['input_ids']
        raw1 = self.model1.model(token_ids,
                                 key_padding_mask=batch['attention_mask'].bool())
        raw2 = self.model2.model(token_ids,
                                 key_padding_mask=batch['attention_mask'].bool())

        flat1 = raw1.view(-1, raw1.size(-1))
        flat2 = raw2.view(-1, raw2.size(-1))
        return torch.cat((flat1, flat2), dim=0)

    def eval_forward(self, batch, outputs=None):
        """Reuse ``outputs`` when provided, otherwise run :meth:`forward`."""
        if outputs is not None:
            return outputs
        return self.forward(batch)

    def loss(self, outputs, batch):
        """``F.kl_div`` (batchmean, log-target) of the first half vs. the second half."""
        split_at = outputs.size(0) // 2
        log_p1 = F.log_softmax(outputs[:split_at], dim=-1)
        log_p2 = F.log_softmax(outputs[split_at:], dim=-1)
        return F.kl_div(log_p1, log_p2, reduction='batchmean', log_target=True)

    def get_metrics(self, is_train=False):
        """Return ``train_metrics`` when ``is_train`` is truthy, else ``eval_metrics``."""
        if is_train:
            return self.train_metrics
        return self.eval_metrics

    def update_metric(self, batch, outputs, metric):
        """Update ``metric`` with flattened outputs and flattened shifted labels."""
        flat_outputs = outputs.view(-1, outputs.size(-1))
        flat_targets = self.get_targets(batch).view(-1)
        metric.update(flat_outputs, flat_targets)

class ComposerMosaicMixMetricsGPT(ComposerModel):
    """Pair of models evaluated with a battery of divergence metrics.

    ``forward`` stacks the flattened outputs of ``model1.model`` and
    ``model2.model`` along dim 0 (model-1 rows first). Each metric in
    ``train_metrics`` splits that stack in half again and compares the two
    halves. ``eval_metrics`` is empty.

    Args:
        cfg: Accepted for signature compatibility; not read here.
        model1: First member; must expose a callable ``.model``.
        model2: Second member; must expose a callable ``.model``.
    """

    def __init__(self, cfg, model1, model2):
        super().__init__()
        self.model1 = model1
        self.model2 = model2
        # Keys are the reported metric names; the numeric suffix is the
        # threshold passed to the metric constructor.
        self.train_metrics = {
            'entropy': entropy(),
            'KL': KL(),
            'KL2': KL2(),
            'TV': TV(),
            'He2': He2(),
            'lTV': lTV(),
            'lHe2': lHe2(),
            'lHe2_alpha=.01': lHe2_alpha(.01),
            'lHe2_alpha=.02': lHe2_alpha(.02),
            'lHe2_alpha=.03': lHe2_alpha(.03),
            'lHe2_alpha=.04': lHe2_alpha(.04),
            'lHe2_alpha=.05': lHe2_alpha(.05),
            'lHe2_alpha=.06': lHe2_alpha(.06),
            'lHe2_alpha=.07': lHe2_alpha(.07),
            'lHe2_alpha=.08': lHe2_alpha(.08),
            'lHe2_alpha=.09': lHe2_alpha(.09),
            'lHe2_alpha=.1': lHe2_alpha(.1),
            'lHe2_alpha=.11': lHe2_alpha(.11),
            'lHe2_alpha=.12': lHe2_alpha(.12),
            'lHe2_alpha=.13': lHe2_alpha(.13),
            'lHe2_alpha=.14': lHe2_alpha(.14),
            'lHe2_alpha=.15': lHe2_alpha(.15),
            'lHe2_alpha=.16': lHe2_alpha(.16),
            'lHe2_alpha=.17': lHe2_alpha(.17),
            'lHe2_alpha=.18': lHe2_alpha(.18),
            'lHe2_alpha=.19': lHe2_alpha(.19),
            'lHe2_alpha=.2': lHe2_alpha(.2),
            'lTV_alpha=.04': lTV_alpha(.04),
            'lTV_alpha=.08': lTV_alpha(.08),
            'lTV_alpha=.12': lTV_alpha(.12),
            'lTV_alpha=.16': lTV_alpha(.16),
            'lTV_alpha=.2': lTV_alpha(.2),
            'lTV_alpha=.24': lTV_alpha(.24),
            'lTV_alpha=.28': lTV_alpha(.28),
            'lTV_alpha=.32': lTV_alpha(.32),
            'lTV_alpha=.36': lTV_alpha(.36),
            'lTV_alpha=.4': lTV_alpha(.4),
            'lTV_alpha=.44': lTV_alpha(.44),
            'lTV_alpha=.48': lTV_alpha(.48),
            'lTV_alpha=.52': lTV_alpha(.52),
            'lTV_alpha=.56': lTV_alpha(.56),
            'lTV_alpha=.6': lTV_alpha(.6),
            'lTV_alpha=.64': lTV_alpha(.64),
            'lTV_alpha=.68': lTV_alpha(.68),
            'lTV_alpha=.72': lTV_alpha(.72),
            'lTV_alpha=.76': lTV_alpha(.76),
            'lTV_alpha=.8': lTV_alpha(.8),
            'KL_alpha=-2': KLAlpha(-2.0),
            'KL_alpha=-1.75': KLAlpha(-1.75),
            'KL_alpha=-1.5': KLAlpha(-1.5),
            'KL_alpha=-1.25': KLAlpha(-1.25),
            'KL_alpha=-1': KLAlpha(-1.0),
            'KL_alpha=-.75': KLAlpha(-.75),
            'KL_alpha=-.5': KLAlpha(-.5),
            'KL_alpha=-.25': KLAlpha(-.25),
            'KL_alpha=0': KLAlpha(0.0),
            'KL_alpha=.25': KLAlpha(.25),
            'KL_alpha=.5': KLAlpha(.5),
            'KL_alpha=.75': KLAlpha(.75),
            'KL_alpha=1': KLAlpha(1.0),
            'KL_alpha=1.25': KLAlpha(1.25),
            'KL_alphaKL=-4': KLAlphaKL(-4),
            'KL_alphaKL=-3.5': KLAlphaKL(-3.5),
            'KL_alphaKL=-3': KLAlphaKL(-3),
            'KL_alphaKL=-2.5': KLAlphaKL(-2.5),
            'KL_alphaKL=-2': KLAlphaKL(-2),
            'KL_alphaKL=-1.5': KLAlphaKL(-1.5),
            'KL_alphaKL=-1': KLAlphaKL(-1),
            'KL_alphaKL=-.5': KLAlphaKL(-.5),
            'KL_alphaKL=0': KLAlphaKL(0.0),
            'KL_alphaKL=.5': KLAlphaKL(.5),
            'KL_alphaKL=1': KLAlphaKL(1),
            'KL_alphaKL=1.5': KLAlphaKL(1.5),
            'KL_alphaKL=2': KLAlphaKL(2.0),
            'KL_alphaKL=2.5': KLAlphaKL(2.5),
            'KL_alphaKL=3': KLAlphaKL(3),
            'KL_alphaKL=3.5': KLAlphaKL(3.5),
            'KL_alphaEnt=-.7': KLAlphaEntropy(-.7),
            'KL_alphaEnt=-.6': KLAlphaEntropy(-.6),
            'KL_alphaEnt=-.5': KLAlphaEntropy(-.5),
            'KL_alphaEnt=-.4': KLAlphaEntropy(-.4),
            'KL_alphaEnt=-.3': KLAlphaEntropy(-.3),
            'KL_alphaEnt=-.2': KLAlphaEntropy(-.2),
            'KL_alphaEnt=-.1': KLAlphaEntropy(-.1),
            'KL_alphaEnt=0': KLAlphaEntropy(0.0),
            'KL_alphaEnt=.1': KLAlphaEntropy(.1),
            'KL_alphaEnt=.2': KLAlphaEntropy(.2),
            'KL_alphaEnt=.3': KLAlphaEntropy(.3),
            # This key used to repeat '=.3', which dropped the .3 metric and
            # reported the .4 metric under the wrong name.
            'KL_alphaEnt=.4': KLAlphaEntropy(.4),
            'KL_alphaEnt=.5': KLAlphaEntropy(.5),
            'KL_alphaEnt=.6': KLAlphaEntropy(.6),
            }
        self.eval_metrics = {}

    def get_targets(self, batch):
        """Return ``batch["labels"]`` rolled by -1 with the last column set to -100."""
        targets = torch.roll(batch["labels"], shifts=-1)
        targets[:, -1] = -100
        return targets

    def forward(self, batch):
        """Return model-1 rows followed by model-2 rows, each flattened to 2-D."""
        out1 = self.model1.model(batch['input_ids'],
                                 key_padding_mask=batch['attention_mask'].bool())
        out2 = self.model2.model(batch['input_ids'],
                                 key_padding_mask=batch['attention_mask'].bool())

        out1 = out1.view(-1, out1.size(-1))
        out2 = out2.view(-1, out2.size(-1))

        out = torch.cat((out1, out2), dim=0)

        return out

    def eval_forward(self, batch, outputs=None):
        """Return ``outputs`` if given, otherwise compute them with :meth:`forward`."""
        return outputs if outputs is not None else self.forward(batch)

    def loss(self, outputs, batch):
        """Sum of the log-softmax values of both halves of ``outputs``.

        NOTE(review): this looks like a placeholder scalar rather than a
        training objective -- confirm before optimising against it.
        """
        out1 = outputs[:outputs.size(0)//2]
        out1 = F.log_softmax(out1, dim=-1)
        out2 = outputs[outputs.size(0)//2:]
        out2 = F.log_softmax(out2, dim=-1)

        return torch.sum(out2+out1)

    def get_metrics(self, is_train=False):
        """Return the train or eval metric dict."""
        return self.train_metrics if is_train else self.eval_metrics

    def update_metric(self, batch, outputs, metric):
        """Feed the flattened (stacked) outputs and shifted labels to ``metric``."""
        outputs = outputs.view(-1, outputs.size(-1))
        targets = self.get_targets(batch).view(-1)
        metric.update(outputs, targets)

class He2(Metric):
    """Running mean of ``1 - sum(sqrt(p1 * p2 + 1e-14))`` over update calls.

    ``p1`` and ``p2`` are the softmax of the first and second half of the
    logits rows. Without the 1e-14 smoothing this is the squared Hellinger
    distance between the two distributions.
    """
    full_state_update = False

    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        # Accumulated per-update value and the number of updates seen.
        self.add_state('sum_loss', default=torch.tensor(0.), dist_reduce_fx='sum')
        self.add_state('total_batches', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self, output, target):
        """Accumulate the statistic for one batch; ``target`` is unused."""
        if not isinstance(output, (Mapping, Tensor)):
            raise Exception(f'Type {type(output)} for the output is unsupported.')
        logits = output['logits'] if isinstance(output, Mapping) else output

        half = logits.size(0) // 2
        probs_a = F.softmax(logits[:half], dim=-1)
        probs_b = F.softmax(logits[half:], dim=-1)

        # Per-row overlap of the two distributions, averaged over rows.
        overlap = torch.sum((probs_a * probs_b + 1e-14) ** .5, dim=-1)
        self.sum_loss += 1.0 - overlap.mean()
        self.total_batches += 1

    def compute(self):
        """Return the accumulated value divided by the number of updates."""
        return self.sum_loss / self.total_batches

class lHe2(Metric):
    """Running mean of ``log(1 / sum(sqrt(p1 * p2 + 1e-14)))`` over update calls.

    ``p1`` and ``p2`` are the softmax of the first and second half of the
    logits rows.
    """
    full_state_update = False

    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        for state_name, initial in (('sum_loss', torch.tensor(0.)),
                                    ('total_batches', torch.tensor(0))):
            self.add_state(state_name, default=initial, dist_reduce_fx='sum')

    @staticmethod
    def _logits_of(output):
        """Pull the logits tensor out of a mapping, or pass a tensor through."""
        if isinstance(output, Mapping):
            return output['logits']
        if isinstance(output, Tensor):
            return output
        raise Exception(f'Type {type(output)} for the output is unsupported.')

    def update(self, output, target):
        """Accumulate the statistic for one batch; ``target`` is unused."""
        logits = self._logits_of(output)
        mid = logits.size(0) // 2
        first = F.softmax(logits[:mid], dim=-1)
        second = F.softmax(logits[mid:], dim=-1)

        row_overlap = torch.sum((first * second + 1e-14) ** .5, dim=-1)
        batch_value = torch.log(1.0 / row_overlap).mean()

        self.sum_loss += batch_value
        self.total_batches += 1

    def compute(self):
        """Return the accumulated value divided by the number of updates."""
        return self.sum_loss / self.total_batches

class lHe2_alpha(Metric):
    """Fraction of rows whose ``log(1 / sum(sqrt(p1 * p2 + 1e-14)))`` exceeds ``alpha``.

    ``p1`` and ``p2`` are the softmax of the first and second half of the
    logits rows. The per-batch fraction is averaged over update calls.

    Args:
        alpha: Threshold the per-row value is compared against (strictly greater).
    """
    full_state_update = False

    def __init__(self, alpha, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state('sum_loss', default=torch.tensor(0.), dist_reduce_fx='sum')
        self.add_state('total_batches', default=torch.tensor(0), dist_reduce_fx='sum')
        self.alpha = alpha

    def update(self, output, target):
        """Accumulate the exceedance fraction for one batch; ``target`` is unused."""
        if isinstance(output, Mapping):
            logits = output['logits']
        elif isinstance(output, Tensor):
            logits = output
        else:
            raise Exception(f'Type {type(output)} for the output is unsupported.')

        n_first = logits.size(0) // 2
        p_first = F.softmax(logits[:n_first], dim=-1)
        p_second = F.softmax(logits[n_first:], dim=-1)

        per_row = torch.log(1.0 / torch.sum((p_first * p_second + 1e-14) ** .5, dim=-1))
        above = per_row > self.alpha  # boolean mask of rows over the threshold
        self.sum_loss += above.float().mean()
        self.total_batches += 1

    def compute(self):
        """Return the accumulated fraction divided by the number of updates."""
        return self.sum_loss / self.total_batches


class lTV_alpha(Metric):
    """Fraction of rows whose ``log(1 / (1 - TV))`` exceeds ``alpha``.

    ``TV`` is ``sum(0.5 * |p1 - p2|)`` per row, where ``p1`` and ``p2`` are
    the softmax of the first and second half of the logits rows. The
    per-batch fraction is averaged over update calls.

    Args:
        alpha: Threshold the per-row value is compared against (strictly greater).
    """
    full_state_update = False

    def __init__(self, alpha,  dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state('sum_loss', default=torch.tensor(0.), dist_reduce_fx='sum')
        self.add_state('total_batches', default=torch.tensor(0), dist_reduce_fx='sum')
        self.alpha = alpha

    def update(self, output, target):
        """Accumulate the exceedance fraction for one batch; ``target`` is unused."""
        if not isinstance(output, (Mapping, Tensor)):
            raise Exception(f'Type {type(output)} for the output is unsupported.')
        logits = output['logits'] if isinstance(output, Mapping) else output

        cut = logits.size(0) // 2
        dist_a = F.softmax(logits[:cut], dim=-1)
        dist_b = F.softmax(logits[cut:], dim=-1)

        tv_per_row = torch.sum(.5 * torch.abs(dist_a - dist_b), dim=-1)
        log_tv = torch.log(1. / (1.0 - tv_per_row))
        exceed_fraction = (log_tv > self.alpha).float().mean()

        self.sum_loss += exceed_fraction
        self.total_batches += 1

    def compute(self):
        """Return the accumulated fraction divided by the number of updates."""
        return self.sum_loss / self.total_batches


class TV(Metric):
    """Running mean of the total-variation distance ``0.5 * sum(|p1 - p2|)``.

    ``p1`` and ``p2`` are the softmax of the first and second half of the
    logits rows; the per-row distance is averaged over rows, then over
    update calls.
    """
    full_state_update = False

    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state('sum_loss', default=torch.tensor(0.), dist_reduce_fx='sum')
        self.add_state('total_batches', default=torch.tensor(0), dist_reduce_fx='sum')

    @staticmethod
    def _logits_of(output):
        """Pull the logits tensor out of a mapping, or pass a tensor through."""
        if isinstance(output, Mapping):
            return output['logits']
        if isinstance(output, Tensor):
            return output
        raise Exception(f'Type {type(output)} for the output is unsupported.')

    def update(self, output, target):
        """Accumulate the mean distance for one batch; ``target`` is unused."""
        logits = self._logits_of(output)
        boundary = logits.size(0) // 2
        left = F.softmax(logits[:boundary], dim=-1)
        right = F.softmax(logits[boundary:], dim=-1)

        l1_per_row = torch.sum(torch.abs(left - right), dim=-1)
        self.sum_loss += .5 * torch.mean(l1_per_row)
        self.total_batches += 1

    def compute(self):
        """Return the accumulated value divided by the number of updates."""
        return self.sum_loss / self.total_batches

class lTV(Metric):
    """Running mean of ``log(1 / (1 - TV))`` over update calls.

    ``TV`` is ``sum(0.5 * |p1 - p2|)`` per row, where ``p1`` and ``p2`` are
    the softmax of the first and second half of the logits rows.
    """
    full_state_update = False

    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state('sum_loss', default=torch.tensor(0.), dist_reduce_fx='sum')
        self.add_state('total_batches', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self, output, target):
        """Accumulate the statistic for one batch; ``target`` is unused."""
        if isinstance(output, Mapping):
            logits = output['logits']
        elif isinstance(output, Tensor):
            logits = output
        else:
            raise Exception(f'Type {type(output)} for the output is unsupported.')

        split = logits.size(0) // 2
        upper_probs = F.softmax(logits[:split], dim=-1)
        lower_probs = F.softmax(logits[split:], dim=-1)

        tv_rows = torch.sum(.5 * torch.abs(upper_probs - lower_probs), dim=-1)
        batch_value = torch.mean(torch.log(1. / (1.0 - tv_rows)))

        self.sum_loss += batch_value
        self.total_batches += 1

    def compute(self):
        """Return the accumulated value divided by the number of updates."""
        return self.sum_loss / self.total_batches

class KL(Metric):
    """Running mean of ``F.kl_div`` between the two halves of the logits rows.

    Each update log-softmaxes both halves and calls
    ``F.kl_div(first, second, reduction='batchmean', log_target=True)``.
    """
    full_state_update = False

    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state('sum_loss', default=torch.tensor(0.), dist_reduce_fx='sum')
        self.add_state('total_batches', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self, output, target):
        """Accumulate the divergence for one batch; ``target`` is unused."""
        if not isinstance(output, (Mapping, Tensor)):
            raise Exception(f'Type {type(output)} for the output is unsupported.')
        logits = output['logits'] if isinstance(output, Mapping) else output

        half = logits.size(0) // 2
        log_p_first = F.log_softmax(logits[:half], dim=-1)
        log_p_second = F.log_softmax(logits[half:], dim=-1)

        # First half is the "input", second half the (log-space) "target".
        self.sum_loss += F.kl_div(log_p_first,
                                  log_p_second,
                                  reduction='batchmean',
                                  log_target=True)
        self.total_batches += 1

    def compute(self):
        """Return the accumulated value divided by the number of updates."""
        return self.sum_loss / self.total_batches

class KLAlpha(Metric):
    """Second-half probability mass on entries where ``log p1 - log p2 > alpha``.

    ``p1`` and ``p2`` are the softmax of the first and second half of the
    logits rows. Per row, the ``p2`` mass of the entries whose log-ratio
    exceeds ``alpha`` is summed; the result is averaged over rows and then
    over update calls.

    Args:
        alpha: Threshold applied to the element-wise log-ratio.
    """
    full_state_update = False

    def __init__(self, alpha, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.alpha = alpha
        self.add_state('sum_loss', default=torch.tensor(0.), dist_reduce_fx='sum')
        self.add_state('total_batches', default=torch.tensor(0), dist_reduce_fx='sum')

    @staticmethod
    def _logits_of(output):
        """Pull the logits tensor out of a mapping, or pass a tensor through."""
        if isinstance(output, Mapping):
            return output['logits']
        if isinstance(output, Tensor):
            return output
        raise Exception(f'Type {type(output)} for the output is unsupported.')

    def update(self, output, target):
        """Accumulate the thresholded mass for one batch; ``target`` is unused."""
        logits = self._logits_of(output)
        mid = logits.size(0) // 2
        second_logits = logits[mid:]

        log_p1 = F.log_softmax(logits[:mid], dim=-1)
        log_p2 = F.log_softmax(second_logits, dim=-1)

        # 1.0 where the log-ratio is above the threshold, 0.0 elsewhere.
        indicator = ((log_p1 - log_p2) > self.alpha).float()
        p2 = torch.softmax(second_logits, dim=-1)

        self.sum_loss += torch.sum(p2 * indicator, dim=-1).mean()
        self.total_batches += 1

    def compute(self):
        """Return the accumulated value divided by the number of updates."""
        return self.sum_loss / self.total_batches


class KL2(Metric):
    """Running mean of ``sum(p2 * (log p1 - log p2))`` over update calls.

    ``p1`` and ``p2`` are the softmax of the first and second half of the
    logits rows.

    NOTE(review): as written this is the negative of ``KL(p2 || p1)`` --
    confirm the sign is intended.
    """
    full_state_update = False

    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state('sum_loss', default=torch.tensor(0.), dist_reduce_fx='sum')
        self.add_state('total_batches', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self, output, target):
        """Accumulate the statistic for one batch; ``target`` is unused."""
        if isinstance(output, Mapping):
            logits = output['logits']
        elif isinstance(output, Tensor):
            logits = output
        else:
            raise Exception(f'Type {type(output)} for the output is unsupported.')

        cut = logits.size(0) // 2
        tail = logits[cut:]

        log_ratio = F.log_softmax(logits[:cut], dim=-1) - F.log_softmax(tail, dim=-1)
        weights = torch.softmax(tail, dim=-1)

        per_row = torch.sum(weights * log_ratio, dim=-1)
        self.sum_loss += per_row.mean()
        self.total_batches += 1

    def compute(self):
        """Return the accumulated value divided by the number of updates."""
        return self.sum_loss / self.total_batches

class KLAlphaKL(Metric):
    """Second-half mass on entries where ``log p1 - log p2 > alpha * kl``.

    ``p1`` and ``p2`` are the softmax of the first and second half of the
    logits rows and ``kl`` is the element-wise (``reduction='none'``)
    output of ``F.kl_div(log p1, log p2, log_target=True)``. Per row, the
    ``p2`` mass of the selected entries is summed; the result is averaged
    over rows and then over update calls.

    Args:
        alpha: Multiplier applied to the element-wise KL term in the threshold.
    """
    full_state_update = False

    def __init__(self, alpha, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.alpha = alpha
        self.add_state('sum_loss', default=torch.tensor(0.), dist_reduce_fx='sum')
        self.add_state('total_batches', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self, output, target):
        """Accumulate the thresholded mass for one batch; ``target`` is unused."""
        if not isinstance(output, (Mapping, Tensor)):
            raise Exception(f'Type {type(output)} for the output is unsupported.')
        logits = output['logits'] if isinstance(output, Mapping) else output

        half = logits.size(0) // 2
        second_logits = logits[half:]
        log_p1 = F.log_softmax(logits[:half], dim=-1)
        log_p2 = F.log_softmax(second_logits, dim=-1)

        # Element-wise threshold: one value per entry, not per row.
        pointwise_kl = F.kl_div(log_p1, log_p2, reduction='none', log_target=True)
        selected = ((log_p1 - log_p2) > self.alpha * pointwise_kl).float()

        p2 = F.softmax(second_logits, dim=-1)
        self.sum_loss += torch.sum(p2 * selected, dim=-1).mean()
        self.total_batches += 1

    def compute(self):
        """Return the accumulated value divided by the number of updates."""
        return self.sum_loss / self.total_batches

class KLAlphaEntropy(Metric):
    """Second-half mass on entries where ``log p1 - log p2 > alpha * H(p1)``.

    ``p1`` and ``p2`` are the softmax of the first and second half of the
    logits rows and ``H(p1)`` is the per-row entropy of the first half
    (computed as the cross-entropy of ``p1`` with itself). Per row, the
    ``p2`` mass of the selected entries is summed; the result is averaged
    over rows and then over update calls.

    Args:
        alpha: Multiplier applied to the per-row entropy in the threshold.
    """
    full_state_update = False

    def __init__(self, alpha, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.alpha = alpha

        self.add_state('sum_loss', default=torch.tensor(0.), dist_reduce_fx='sum')
        self.add_state('total_batches', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self, output, target):
        """Accumulate the thresholded mass for one batch; ``target`` is unused.

        Raises:
            Exception: If ``output`` is neither a mapping nor a tensor.
        """
        if isinstance(output, Mapping):
            logits = output['logits']
        elif isinstance(output, Tensor):
            logits = output
        else:
            raise Exception(f'Type {type(output)} for the output is unsupported.')

        out1 = logits[:logits.size(0)//2]

        # Cross-entropy of a distribution against itself is its entropy;
        # reduction='none' keeps one value per row.
        ent = F.cross_entropy(out1, F.softmax(out1, dim=-1), reduction='none')

        out1 = F.log_softmax(out1, dim=-1)
        out2 = logits[logits.size(0)//2:]
        out2 = F.log_softmax(out2, dim=-1)

        log_rat = out1-out2
        # Broadcast the per-row threshold across the last dimension.
        # (A debug print of the tensor shapes used to run here on every call.)
        log_rat_alpha = (log_rat > (self.alpha*ent)[:, None]).float()

        out2_probs = F.softmax(logits[logits.size(0)//2:], dim=-1)

        loss = torch.sum(out2_probs*log_rat_alpha, dim=-1).mean()

        self.sum_loss += loss

        self.total_batches += 1  #type: ignore (third-party)

    def compute(self):
        """Return the accumulated value divided by the number of updates."""
        return self.sum_loss / self.total_batches  #type: ignore (third-party)

class entropy(Metric):
    """Running mean of the average entropy of the two halves of the logits rows.

    For each half, the softmax entropy is averaged over rows; the two
    averages each contribute with weight 1/2 to the per-update value.
    """
    full_state_update = False

    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state('sum_loss', default=torch.tensor(0.), dist_reduce_fx='sum')
        self.add_state('total_batches', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self, output, target):
        """Accumulate the mean entropy for one batch; ``target`` is unused."""
        if not isinstance(output, (Mapping, Tensor)):
            raise Exception(f'Type {type(output)} for the output is unsupported.')
        logits = output['logits'] if isinstance(output, Mapping) else output

        half = logits.size(0) // 2
        # First half, then second half; each adds half of its mean entropy.
        for part in (logits[:half], logits[half:]):
            probs = F.softmax(part, dim=-1)
            log_probs = F.log_softmax(part, dim=-1)
            mean_entropy = -torch.sum(probs * log_probs, dim=-1).mean()
            self.sum_loss += mean_entropy / 2

        self.total_batches += 1

    def compute(self):
        """Return the accumulated value divided by the number of updates."""
        return self.sum_loss / self.total_batches

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""
A simple, flexible implementation of a GPT model.
Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
"""

import math
from typing import Any, Mapping

import torch
import torch.nn as nn
import torch.nn.functional as F
from composer.metrics.nlp import LanguageCrossEntropy, Perplexity
from composer.models.base import ComposerModel


class TorchCausalAttention(nn.Module):
    """Causal multi-head self-attention built on ``nn.MultiheadAttention``.

    Args:
        cfg: Config object exposing ``d_model``, ``n_heads``, ``attn_pdrop``
            and ``max_seq_len`` as attributes.
        device: Device passed to ``nn.MultiheadAttention``.
    """

    def __init__(self, cfg: Mapping[str, Any], device: str = None):
        super().__init__()
        self.mha = nn.MultiheadAttention(
            embed_dim=cfg.d_model,
            num_heads=cfg.n_heads,
            dropout=cfg.attn_pdrop,
            bias=True,
            batch_first=True,
            device=device,
        )
        # Lower-triangular ones: entry (i, j) is 1 where query i may see
        # key j. The buffer keeps its original name, dtype and values so
        # existing state dicts still load; forward() converts it into the
        # boolean form nn.MultiheadAttention expects.
        self.register_buffer(
            "mask", torch.tril(torch.ones(cfg.max_seq_len, cfg.max_seq_len)))
        self.mha.out_proj._is_residual = True

    def forward(self, x, key_padding_mask):
        """Apply causal self-attention to ``x``.

        Args:
            x: Input tensor; the second-to-last dimension is the sequence
                length, which may be shorter than ``max_seq_len``.
            key_padding_mask: Boolean mask that is True for positions to
                keep (the convention used by the callers in this file), or
                ``None`` for no padding.

        Returns:
            The tuple returned by ``nn.MultiheadAttention``; the attention
            weights entry is ``None`` because ``need_weights=False``.
        """
        seq_len = x.size(-2)
        # nn.MultiheadAttention treats a float attn_mask as an additive
        # bias, so passing the 0/1 buffer directly did not mask anything.
        # A boolean mask with True = "not allowed" gives real causality.
        attn_mask = self.mask[:seq_len, :seq_len] == 0
        if key_padding_mask is not None:
            # PyTorch uses True = "ignore this key", the inverse of the
            # keep-mask convention used by callers.
            key_padding_mask = ~key_padding_mask.bool()
        return self.mha(x,
                        x,
                        x,
                        attn_mask=attn_mask,
                        key_padding_mask=key_padding_mask,
                        need_weights=False)


class FlashCausalAttention(nn.Module):
    """Causal self-attention backed by ``flash_attn``'s ``FlashMHA``.

    Args:
        cfg: config object read via attribute access; this class uses
            ``d_model``, ``n_heads`` and ``attn_pdrop``.
        device: forwarded to ``FlashMHA``.

    Raises:
        ImportError: if ``flash_attn.flash_attention`` cannot be imported.
    """

    def __init__(self, cfg: Mapping[str, Any], device: str = None):
        super().__init__()
        # Imported here rather than at file level, so the import is only
        # attempted when this attention implementation is instantiated.
        try:
            from flash_attn.flash_attention import FlashMHA
        except ImportError as err:
            raise err

        mha_kwargs = dict(
            embed_dim=cfg.d_model,
            num_heads=cfg.n_heads,
            attention_dropout=cfg.attn_pdrop,
            bias=True,
            batch_first=True,
            causal=True,
            device=device,
        )
        self.mha = FlashMHA(**mha_kwargs)
        self.mha.out_proj._is_residual = True

    def forward(self, x, key_padding_mask):
        """Run ``FlashMHA`` on ``x``, passing ``key_padding_mask`` through."""
        result = self.mha(x,
                          key_padding_mask=key_padding_mask,
                          need_weights=False)
        return result


class GPTMLP(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Linear.

    The hidden width is ``cfg.mlp_ratio * cfg.d_model``; input and output
    width are both ``cfg.d_model``.
    """

    def __init__(self, cfg: Mapping[str, Any], device: str = None):
        super().__init__()
        model_dim = cfg.d_model
        hidden_dim = cfg.mlp_ratio * model_dim
        self.mlp_up = nn.Linear(model_dim, hidden_dim, device=device)
        self.mlp_act = nn.GELU(approximate='none')
        self.mlp_down = nn.Linear(hidden_dim, model_dim, device=device)
        # Tag read by MosaicGPT.param_init_fn via getattr(module, '_is_residual').
        self.mlp_down._is_residual = True

    def forward(self, x):
        """Project up, apply GELU, project back down."""
        hidden = self.mlp_up(x)
        hidden = self.mlp_act(hidden)
        return self.mlp_down(hidden)


class GPTBlock(nn.Module):
    """Pre-LayerNorm transformer block: attention and MLP, each with a residual.

    ``cfg.attn_impl`` selects the attention module: ``'torch'`` for
    ``TorchCausalAttention`` or ``'flash'`` for ``FlashCausalAttention``.

    Raises:
        ValueError: if ``cfg.attn_impl`` is neither ``'torch'`` nor ``'flash'``.
    """

    def __init__(self, cfg: Mapping[str, Any], device: str = None):
        super().__init__()
        self.ln_1 = nn.LayerNorm(cfg.d_model, device=device)

        impl = cfg.attn_impl
        if impl not in ('torch', 'flash'):
            raise ValueError(f'Unknown attn_impl={cfg.attn_impl}')
        attn_cls = TorchCausalAttention if impl == 'torch' else FlashCausalAttention
        self.causal_attn = attn_cls(cfg, device)

        self.ln_2 = nn.LayerNorm(cfg.d_model, device=device)
        self.mlp = GPTMLP(cfg, device=device)
        self.resid_attn_dropout = nn.Dropout(cfg.resid_pdrop)
        self.resid_mlp_dropout = nn.Dropout(cfg.resid_pdrop)

    def forward(self,
                x: torch.Tensor,
                key_padding_mask: torch.ByteTensor = None) -> torch.Tensor:
        """Return ``x`` after the attention and MLP residual updates."""
        # The attention module returns a pair; only the first item is used.
        attn_out, _ = self.causal_attn(self.ln_1(x), key_padding_mask)
        x = x + self.resid_attn_dropout(attn_out)
        mlp_out = self.mlp(self.ln_2(x))
        return x + self.resid_mlp_dropout(mlp_out)


class MosaicGPT(nn.Module):
    """GPT-style decoder: token + position embeddings, a stack of ``GPTBlock``s,
    a final LayerNorm and a bias-free output projection tied to the token
    embedding.

    ``cfg`` is read via attribute access (``name``, ``vocab_size``, ``d_model``,
    ``max_seq_len``, ``emb_pdrop``, ``n_layers``, ``device``, ``init_std``).
    """

    def __init__(self, cfg: Mapping[str, Any]):
        super().__init__()
        assert cfg.name == 'mosaic_gpt', f'Tried to build MosaicGPT model with cfg.name={cfg.name}'
        self.cfg = cfg
        dev = cfg.device
        width = cfg.d_model

        token_emb = nn.Embedding(cfg.vocab_size, width, device=dev)
        position_emb = nn.Embedding(cfg.max_seq_len, width, device=dev)
        embedding_dropout = nn.Dropout(cfg.emb_pdrop)
        layers = nn.ModuleList()
        for _ in range(cfg.n_layers):
            layers.append(GPTBlock(cfg, device=dev))
        final_norm = nn.LayerNorm(width, device=dev)

        self.transformer = nn.ModuleDict(
            dict(
                wte=token_emb,
                wpe=position_emb,
                emb_drop=embedding_dropout,
                blocks=layers,
                ln_f=final_norm,
            ))
        self.lm_head = nn.Linear(width,
                                 cfg.vocab_size,
                                 bias=False,
                                 device=dev)

        # Weight tying: lm_head shares the token-embedding matrix. The
        # _fsdp_wrap flags are cleared so that wte and lm_head end up in the
        # same FSDP block.
        for unwrapped in (self.transformer, self.transformer.wte, self.lm_head):
            unwrapped._fsdp_wrap = False
        self.lm_head.weight = self.transformer.wte.weight

        # With device='meta' initialization is skipped here.
        if cfg.device != 'meta':
            self.apply(self.param_init_fn)

    def forward(self,
                input_ids: torch.LongTensor,
                key_padding_mask: torch.ByteTensor = None):
        """Return logits for a 2-D batch of token ids.

        Raises:
            AssertionError: if the sequence length exceeds ``cfg.max_seq_len``.
        """
        _, seq_len = input_ids.size()
        assert (
            seq_len <= self.cfg.max_seq_len
        ), f"Cannot forward input with seq_len={seq_len}, this model only supports seq_len<={self.cfg.max_seq_len}"
        positions = torch.arange(0,
                                 seq_len,
                                 dtype=torch.long,
                                 device=input_ids.device).unsqueeze(0)

        embedded = self.transformer.wte(input_ids) + self.transformer.wpe(positions)
        hidden = self.transformer.emb_drop(embedded)
        for layer in self.transformer.blocks:
            hidden = layer(hidden, key_padding_mask)
        # The final LayerNorm output is detached, so gradients from the logits
        # reach only lm_head (whose weight is the tied wte matrix).
        # NOTE(review): marked "#Change" in the original; presumably a
        # deliberate modification for this variant - confirm before removing.
        hidden = self.transformer.ln_f(hidden).detach()  #Change
        return self.lm_head(hidden)

    # Param Initialization, needed for device='meta' fast initialization
    def param_init_fn(self, module):
        """Initialize ``module`` in place if it is a Linear, Embedding or LayerNorm."""
        base_std = self.cfg.init_std

        # Linear
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=base_std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

            # Layers tagged _is_residual are re-drawn with a std scaled down
            # by sqrt(2 * n_layers).
            if getattr(module, '_is_residual', False):
                residual_std = base_std / math.sqrt(2 * self.cfg.n_layers)
                module.weight.data.normal_(mean=0.0, std=residual_std)

        # Embedding
        if isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=base_std)

        # LayerNorm
        if isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    # FSDP Wrap function
    def fsdp_wrap_fn(self, module):
        """Return True for ``GPTBlock`` modules."""
        is_block = isinstance(module, GPTBlock)
        return is_block

    # Activation Checkpointing
    def activation_checkpointing_fn(self, module):
        """Return True for ``GPTBlock`` modules."""
        is_block = isinstance(module, GPTBlock)
        return is_block


class ComposerMosaicEvalGPT(ComposerModel):
    """``ComposerModel`` wrapper around ``MosaicGPT`` with a next-token
    cross-entropy loss and LanguageCrossEntropy / Perplexity metrics."""

    def __init__(self, cfg):
        super().__init__()
        self.model = MosaicGPT(cfg)

        def _fresh_metrics():
            # Separate metric instances for train and eval.
            return {
                'LanguageCrossEntropy': LanguageCrossEntropy(cfg.vocab_size),
                'Perplexity': Perplexity(),
            }

        self.train_metrics = _fresh_metrics()
        self.eval_metrics = _fresh_metrics()

    def get_targets(self, batch):
        """Return ``batch["labels"]`` shifted left by one, last column set to -100."""
        # torch.roll returns a new tensor, so the batch itself is not modified.
        shifted = torch.roll(batch["labels"], shifts=-1)
        # -100 matches the ignore_index used in loss().
        shifted[:, -1] = -100
        return shifted

    def forward(self, batch):
        """Run the wrapped model on ``input_ids`` with the boolean attention mask."""
        padding_mask = batch['attention_mask'].bool()
        return self.model(batch['input_ids'], key_padding_mask=padding_mask)

    def eval_forward(self, batch, outputs=None):
        """Reuse ``outputs`` when given, otherwise compute them from ``batch``."""
        if outputs is None:
            return self.forward(batch)
        return outputs

    def loss(self, outputs, batch):
        """Cross-entropy between flattened logits and shifted labels."""
        flat_targets = self.get_targets(batch).view(-1)
        flat_logits = outputs.view(-1, outputs.size(-1))
        return F.cross_entropy(flat_logits, flat_targets, ignore_index=-100)

    def get_metrics(self, is_train=False):
        """Return the train metrics if ``is_train`` else the eval metrics."""
        if is_train:
            return self.train_metrics
        return self.eval_metrics

    def update_metric(self, batch, outputs, metric):
        """Feed flattened logits and shifted labels to ``metric.update``."""
        flat_logits = outputs.view(-1, outputs.size(-1))
        flat_targets = self.get_targets(batch).view(-1)
        metric.update(flat_logits, flat_targets)


class ComposerMosaicMergeGPT(ComposerModel):
    """Stitch two wrapped GPT models together at a given block index.

    ``model1`` and ``model2`` are wrappers exposing the underlying network as
    ``.model`` (that is how ``forward`` reads them). The forward pass uses
    ``model1`` for the embeddings and blocks ``[0, merge_layer)``, then
    ``model2`` for blocks ``[merge_layer, ...)``, the final LayerNorm and the
    output head.

    Args:
        model1: wrapper providing the embeddings and the early blocks.
        model2: wrapper providing the late blocks, ``ln_f`` and ``lm_head``.
        merge_layer: index of the first block taken from ``model2``. Defaults
            to half of ``n_layers`` as read from ``model1``'s config.
    """

    def __init__(self, model1, model2, merge_layer=None):
        super().__init__()
        self.model1 = model1
        self.model2 = model2
        self.train_metrics = {}
        self.eval_metrics = {}
        if merge_layer is None:
            # Prefer a cfg stored on the wrapper itself; wrappers that only
            # keep it on the inner network (e.g. ComposerMosaicEvalGPT) would
            # otherwise raise AttributeError here.
            cfg = getattr(model1, 'cfg', None)
            if cfg is None:
                cfg = model1.model.cfg
            merge_layer = cfg.n_layers // 2
        self.merge_layer = merge_layer

    def get_targets(self, batch):
        """Return ``batch["labels"]`` shifted left by one, last column set to -100."""
        targets = torch.roll(batch["labels"], shifts=-1)
        # -100 matches the ignore_index used in loss().
        targets[:, -1] = -100
        return targets

    def forward(self, batch):
        """Return logits from the stitched network for ``batch``."""
        input_ids = batch['input_ids']
        key_padding_mask = batch['attention_mask'].bool()

        in_model1 = self.model1.model
        in_model2 = self.model2.model

        _, S = input_ids.size()
        pos = torch.arange(0, S, dtype=torch.long,
                           device=input_ids.device).unsqueeze(0)

        # Embeddings and the first merge_layer blocks come from model1.
        tok_emb = in_model1.transformer.wte(input_ids)
        pos_emb = in_model1.transformer.wpe(pos)
        x = in_model1.transformer.emb_drop(tok_emb + pos_emb)
        for idx, block in enumerate(in_model1.transformer.blocks):
            if idx < self.merge_layer:
                x = block(x, key_padding_mask)
        # Remaining blocks, final norm and output head come from model2.
        for idx, block in enumerate(in_model2.transformer.blocks):
            if idx >= self.merge_layer:
                x = block(x, key_padding_mask)
        x = in_model2.transformer.ln_f(x)
        logits = in_model2.lm_head(x)
        return logits

    def eval_forward(self, batch, outputs=None):
        """Reuse ``outputs`` when given, otherwise compute them from ``batch``."""
        return outputs if outputs is not None else self.forward(batch)

    def loss(self, outputs, batch):
        """Cross-entropy between flattened logits and shifted labels."""
        targets = self.get_targets(batch)
        return F.cross_entropy(outputs.view(-1, outputs.size(-1)),
                               targets.view(-1),
                               ignore_index=-100)

    def get_metrics(self, is_train=False):
        """Return the train metrics if ``is_train`` else the eval metrics (both empty by default)."""
        return self.train_metrics if is_train else self.eval_metrics

    def update_metric(self, batch, outputs, metric):
        """Feed flattened logits and shifted labels to ``metric.update``."""
        outputs = outputs.view(-1, outputs.size(-1))
        targets = self.get_targets(batch).view(-1)
        metric.update(outputs, targets)