import wandb
from nn.cola_nn import cola_parameterize, get_model_summary_and_flops
import nn
from ffcv.fields.basics import IntDecoder
from ffcv.fields.rgb_image import CenterCropRGBImageDecoder, \
    RandomResizedCropRGBImageDecoder
from ffcv.transforms import ToTensor, ToDevice, Squeeze, NormalizeImage, \
    RandomHorizontalFlip, ToTorchImage, ImageMixup, LabelMixup
from ffcv.loader import Loader, OrderOption
from ffcv.pipeline.operation import Operation
from fastargs.validation import And, OneOf
from fastargs import Param, Section
from fastargs.decorators import param
from fastargs import get_current_config
from math import prod
from argparse import ArgumentParser
from pathlib import Path
from typing import List
from uuid import uuid4
import json
import time
import os
from tqdm import tqdm
import numpy as np
import torchmetrics
import torch as ch
from torch.cuda.amp import GradScaler
from torch.cuda.amp import autocast
from contextlib import nullcontext
import torch.distributed as dist

# Let cuDNN benchmark candidate convolution algorithms and cache the fastest one.
ch.backends.cudnn.benchmark = True
# NOTE(review): the next two lines construct autograd profiler context managers
# without entering them (there is no `with`), so they likely have no effect at
# all — confirm and consider removing.
ch.autograd.profiler.emit_nvtx(False)
ch.autograd.profiler.profile(False)

# Model shape. The base_* values describe the reference model handed to
# cola_parameterize; the un-prefixed width/depth/heads/dim_head describe the
# model actually trained, where -1 means "inherit the base_* value".
Section('model', 'model details').params(
    arch=Param(str, default='ViT'),
    base_width=Param(int, default=256),
    base_depth=Param(int, default=6),
    base_heads=Param(int, default=4),
    base_dim_head=Param(int, default=64),
    base_ffn_expansion=Param(int, default=4),
    width=Param(int, default=-1),
    depth=Param(int, default=-1),
    heads=Param(int, default=-1),
    dim_head=Param(int, default=-1),
    ffn_expansion=Param(int, default=4),
    patch_size=Param(int, default=16),
    in_channels=Param(int, default=3),
    resolution=Param(int, default=224),
    fixup=Param(int, default=0),
    dropout=Param(float, default=0.),
)

# Structured-layer options forwarded to cola_parameterize.
Section('cola', 'cola details').params(
    struct=Param(str, default='none'),
    layers=Param(str, default='all_but_last'),
    tt_rank=Param(int, default=1),
    tt_dim=Param(int, default=2),
)

# ffcv dataset locations and loader settings.
Section('data', 'data related stuff').params(
    train_dataset=Param(str, '.dat file to use for training', required=True),
    val_dataset=Param(str, '.dat file to use for validation', required=True),
    num_workers=Param(int, 'The number of workers', required=True),
    in_memory=Param(int, 'does the dataset fit in memory? (1/0)', required=True),
)

# Output folder, logging cadence and wandb settings.
Section('logging', 'how to log stuff').params(
    folder=Param(str, 'log location', required=True),
    log_level=Param(int, '0 if only at end 1 otherwise', default=1),
    log_freq=Param(int, 'how often to log', default=100),
    use_wandb=Param(int, 'use wandb?', default=1),
    wandb_project=Param(str, 'wandb project name', default='imagenet'),
    wandb_group=Param(str, 'wandb group name', default=''),
)

# Optimization hyper-parameters.
Section('training', 'training hyper param stuff').params(
    schedule=Param(OneOf(['none', 'cosine']), default='none'),
    lr=Param(float, 'lr', default=3e-4),
    input_lr_mult=Param(float, default=1),
    init_mult_1=Param(float, default=1),
    init_mult_2=Param(float, default=1),
    lr_mult_1=Param(float, default=1),
    lr_mult_2=Param(float, default=1),
    eval_only=Param(int, 'eval only?', default=0),
    batch_size=Param(int, 'The batch size', default=512),
    optimizer=Param(And(str, OneOf(['adamw'])), 'The optimizer', default='adamw'),
    weight_decay=Param(float, 'weight decay', default=0),
    epochs=Param(int, 'number of epochs', default=30),
    warmup_epochs=Param(int, 'number of warmup epochs', default=0),
    label_smoothing=Param(float, 'label smoothing parameter', default=0.),
    mixup=Param(float, 'mixup parameter', default=0.),
    distributed=Param(int, 'is distributed?', default=0),
    mixed_prec=Param(int, 'mixed precision?', default=1),
    clip_grad=Param(float, 'clip grad', default=0),
)

# Evaluation-time options.
Section('validation', 'Validation parameters stuff').params(
    lr_tta=Param(int, 'should do lr flipping/avging at test time', default=0),
)

# Process-group rendezvous settings used when training.distributed is set.
Section('dist', 'distributed training options').params(
    world_size=Param(int, 'number gpus', default=1),
    address=Param(str, 'address', default='localhost'),
    port=Param(str, 'port', default='12355'),
)

# Per-channel ImageNet RGB mean/std, multiplied by 255 — presumably because
# NormalizeImage receives raw 0-255 pixel values from the ffcv decoders.
IMAGENET_MEAN = 255 * np.array([0.485, 0.456, 0.406])
IMAGENET_STD = 255 * np.array([0.229, 0.224, 0.225])
# Conventional 224/256 center-crop ratio used for the validation decoder.
DEFAULT_CROP_RATIO = 224 / 256


@param('training.epochs')
@param('lr.lr_peak_epoch')
def get_cyclic_lr(epoch, epochs, lr_peak_epoch):
    """Triangular LR multiplier for `epoch`.

    Linearly ramps from 1e-4 at epoch 0 up to 1 at `lr_peak_epoch`, then linearly
    down to 0 at `epochs`.

    NOTE(review): no 'lr' section is declared in this file, so 'lr.lr_peak_epoch'
    presumably cannot be resolved from the config, and this helper is not
    registered in ImageNetTrainer.get_lr_sched_mult — it appears to be unused.
    """
    knot_epochs = (0, lr_peak_epoch, epochs)
    knot_mults = (1e-4, 1, 0)
    (mult,) = np.interp([epoch], knot_epochs, knot_mults)
    return mult


@param('training.warmup_epochs')
def get_constant_mult(epoch, warmup_epochs):
    """LR multiplier for the 'none' schedule: linear warmup, then a constant 1.

    During the first `warmup_epochs` epochs the multiplier is
    `epoch / warmup_epochs`; afterwards it is 1.
    """
    in_warmup = epoch < warmup_epochs
    return epoch / warmup_epochs if in_warmup else 1


@param('training.warmup_epochs')
@param('training.epochs')
def get_cosine_mult(epoch, warmup_epochs, epochs):
    """LR multiplier for the 'cosine' schedule: linear warmup, then cosine decay to 0.

    Args:
        epoch: (Possibly fractional) epoch at which to evaluate the schedule.
        warmup_epochs: Length of the linear warmup phase.
        epochs: Total number of training epochs; the multiplier reaches 0 here.

    Returns:
        `epoch / warmup_epochs` during warmup, otherwise
        `0.5 * (1 + cos(pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))`.
        If there is no decay phase (`epochs <= warmup_epochs`), the peak
        multiplier 1.0 is held instead of dividing by zero.
    """
    if epoch < warmup_epochs:
        return epoch / warmup_epochs
    decay_epochs = epochs - warmup_epochs
    if decay_epochs <= 0:
        # The training loop queries the multiplier at `epochs`. When warmup spans all
        # of training that point has zero decay length, so hold the peak rate
        # (continuous with the end of warmup).
        return 1.0
    return 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / decay_epochs))


class ImageNetTrainer:
    """Per-process ImageNet trainer configured entirely through fastargs.

    Each instance owns one GPU. It builds ffcv train/val loaders, the
    cola-parameterized model with its optimizer and grad scaler, and then trains,
    evaluates, checkpoints and logs. File/wandb logging and checkpointing happen
    on rank 0 only.
    """

    @param('training.distributed')
    def __init__(self, gpu, distributed):
        """Set up loaders, model, optimizer and logger for the process on `gpu`.

        Args:
            gpu: Local CUDA device index; also used as the process rank when
                distributed.
            distributed: Injected from 'training.distributed'; when truthy the
                process group is initialized before anything touches the GPU.
        """
        self.all_params = get_current_config()
        self.gpu = gpu
        # Read up-front so that every later method can rely on it.
        self.mixed_prec = self.all_params['training.mixed_prec']

        self.uid = str(uuid4())

        if distributed:
            self.setup_distributed()

        self.train_loader = self.create_train_loader()
        self.val_loader = self.create_val_loader()
        self.example_imgs = next(iter(self.val_loader))[0]  # fixed batch for feature tracking
        self.model, self.loss, self.val_loss, self.scaler, self.optimizer, self.info = self.create_model_and_scaler()
        # Must follow create_model_and_scaler: the logger reads target_config, info and
        # the resolved depth/width/heads/dim_head (via runname).
        self.initialize_logger()
        self.prev_hs = None  # prev hidden states for feature tracking

    @param('dist.address')
    @param('dist.port')
    @param('dist.world_size')
    def setup_distributed(self, address, port, world_size):
        """Join the NCCL process group as rank `self.gpu` and bind this process to that GPU."""
        os.environ['MASTER_ADDR'] = address
        os.environ['MASTER_PORT'] = port

        dist.init_process_group("nccl", rank=self.gpu, world_size=world_size)
        ch.cuda.set_device(self.gpu)

    def cleanup_distributed(self):
        """Tear down the default process group."""
        dist.destroy_process_group()

    @param('training.schedule')
    def get_lr_sched_mult(self, epoch, schedule):
        """Return the LR multiplier of the configured schedule at `epoch`."""
        lr_schedules = {'none': get_constant_mult, 'cosine': get_cosine_mult}

        return lr_schedules[schedule](epoch)

    # resolution tools
    @param('model.resolution')
    def get_resolution(self, resolution):
        """Return the (square) training resolution; constant across epochs."""
        return resolution

    @param('data.train_dataset')
    @param('data.num_workers')
    @param('training.batch_size')
    @param('training.mixup')
    @param('training.distributed')
    @param('training.mixed_prec')
    @param('data.in_memory')
    def create_train_loader(self, train_dataset, num_workers, batch_size, mixup, distributed, mixed_prec, in_memory):
        """Build the shuffled ffcv training loader.

        Images: random-resized crop + horizontal flip (+ mixup if `mixup > 0`),
        moved to this GPU and normalized to fp16 (mixed precision) or fp32.
        The random-crop decoder is stored on `self.decoder` so `train` can
        update its output size each epoch.
        """
        this_device = f'cuda:{self.gpu}'
        train_path = Path(train_dataset)
        assert train_path.is_file(), f'train dataset not found: {train_path}'

        res = self.get_resolution()
        self.decoder = RandomResizedCropRGBImageDecoder((res, res))
        image_pipeline: List[Operation] = [self.decoder, RandomHorizontalFlip()]
        label_pipeline: List[Operation] = [IntDecoder()]
        if mixup > 0:
            image_pipeline.extend([ImageMixup(alpha=mixup, same_lambda=True)])
            label_pipeline.extend([LabelMixup(alpha=mixup, same_lambda=True)])
        image_pipeline.extend([
            ToTensor(),
            ToDevice(ch.device(this_device), non_blocking=True),
            ToTorchImage(),
            NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float16 if mixed_prec else np.float32)
        ])

        label_pipeline.extend([ToTensor(), Squeeze(), ToDevice(ch.device(this_device), non_blocking=True)])

        # RANDOM is used even in the distributed case: QUASI_RANDOM hurts performance.
        order = OrderOption.RANDOM
        loader = Loader(train_dataset, batch_size=batch_size, num_workers=num_workers, order=order, os_cache=in_memory,
                        drop_last=True, pipelines={
                            'image': image_pipeline,
                            'label': label_pipeline
                        }, distributed=distributed)

        return loader

    @param('data.val_dataset')
    @param('data.num_workers')
    @param('training.batch_size')
    @param('model.resolution')
    @param('training.distributed')
    @param('training.mixed_prec')
    def create_val_loader(self, val_dataset, num_workers, batch_size, resolution, distributed, mixed_prec):
        """Build the sequential ffcv validation loader (center crop, no augmentation)."""
        this_device = f'cuda:{self.gpu}'
        val_path = Path(val_dataset)
        assert val_path.is_file(), f'val dataset not found: {val_path}'
        res_tuple = (resolution, resolution)
        cropper = CenterCropRGBImageDecoder(res_tuple, ratio=DEFAULT_CROP_RATIO)
        image_pipeline = [
            cropper,
            ToTensor(),
            ToDevice(ch.device(this_device), non_blocking=True),
            ToTorchImage(),
            NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float16 if mixed_prec else np.float32)
        ]

        label_pipeline = [IntDecoder(), ToTensor(), Squeeze(), ToDevice(ch.device(this_device), non_blocking=True)]

        loader = Loader(val_dataset, batch_size=batch_size, num_workers=num_workers, order=OrderOption.SEQUENTIAL,
                        drop_last=False, pipelines={
                            'image': image_pipeline,
                            'label': label_pipeline
                        }, distributed=distributed)
        return loader

    @param('training.epochs')
    @param('logging.log_level')
    def train(self, epochs, log_level):
        """Train for `epochs` epochs, then run a final evaluation and save the weights.

        With `log_level > 0`, it also evaluates and logs after every epoch, and
        rank 0 overwrites `checkpoint.pt` in the log folder with the latest
        epoch/model/optimizer state. Rank 0 always writes `final_weights.pt`.
        """
        epoch = -1  # keeps the final log well-defined when epochs == 0
        for epoch in range(epochs):
            res = self.get_resolution()
            self.decoder.output_size = (res, res)
            train_loss, train_acc = self.train_loop(epoch)

            if log_level > 0:
                extra_dict = {'train_loss': train_loss, 'train_acc': train_acc, 'epoch': epoch}

                self.eval_and_log(extra_dict)
                if self.gpu == 0:
                    state = {
                        'epoch': epoch,
                        'state_dict': self.model.state_dict(),
                        'optimizer': self.optimizer.state_dict(),
                    }
                    ch.save(state, self.log_folder / 'checkpoint.pt')

        self.eval_and_log({'epoch': epoch})
        if self.gpu == 0:
            ch.save(self.model.state_dict(), self.log_folder / 'final_weights.pt')

    def eval_and_log(self, extra_dict=None):
        """Run validation, log its stats merged with `extra_dict` (rank 0), and return the stats.

        Args:
            extra_dict: Optional extra key/values to log alongside the validation
                stats; its keys win on collision.

        Returns:
            Dict with 'top_1', 'top_5' and 'loss' computed over the validation set.
        """
        extra_dict = {} if extra_dict is None else extra_dict
        start_val = time.time()
        stats = self.val_loop()
        val_time = time.time() - start_val
        if self.gpu == 0:
            self.log(
                dict(
                    {
                        'current_lr': self.optimizer.param_groups[0]['lr'],
                        'val_loss': stats['loss'],
                        'val_acc': stats['top_1'],
                        'top_5': stats['top_5'],
                        'val_time': val_time
                    }, **extra_dict))

        return stats

    @param('training.distributed')
    @param('model.arch')
    @param('model.base_width')
    @param('model.base_depth')
    @param('model.base_heads')
    @param('model.base_dim_head')
    @param('model.base_ffn_expansion')
    @param('model.width')
    @param('model.depth')
    @param('model.heads')
    @param('model.dim_head')
    @param('model.ffn_expansion')
    @param('model.patch_size')
    @param('model.in_channels')
    @param('model.resolution')
    @param('model.fixup')
    @param('model.dropout')
    @param('cola.struct')
    @param('cola.layers')
    @param('cola.tt_rank')
    @param('cola.tt_dim')
    @param('training.lr')
    @param('training.weight_decay')
    @param('training.input_lr_mult')
    @param('training.init_mult_1')
    @param('training.init_mult_2')
    @param('training.lr_mult_1')
    @param('training.lr_mult_2')
    @param('training.optimizer')
    @param('training.label_smoothing')
    def create_model_and_scaler(self, distributed, arch, base_width, base_depth, base_heads, base_dim_head, base_ffn_expansion,
                                width, depth, heads, dim_head, ffn_expansion, patch_size, in_channels, resolution, fixup, dropout,
                                struct, layers, tt_rank, tt_dim, lr, weight_decay, input_lr_mult, init_mult_1, init_mult_2,
                                lr_mult_1, lr_mult_2, optimizer, label_smoothing):
        """Build the model and optimizer via cola_parameterize, plus losses and grad scaler.

        Target-model dimensions equal to -1 fall back to the base_* values.
        If only one of heads/dim_head is given, the other is derived as
        `width // given`. Side effects: sets self.depth/width/heads/dim_head and
        self.target_config (used for the run name and wandb config).

        Returns:
            (model, train_loss_fn, val_loss_fn, grad_scaler, optimizer, info). The
            model is wrapped in DDP when distributed. The training loss uses
            label smoothing; the validation loss does not. `info` is the summary
            from get_model_summary_and_flops.
        """
        assert optimizer == 'adamw', 'Only adamw supported'
        # NOTE(review): the scaler is enabled even when training.mixed_prec == 0.
        scaler = GradScaler()
        model_builder = getattr(nn, arch)
        input_shape = (1, in_channels, resolution, resolution)
        base_config = dict(dim_in=prod(input_shape), dim_out=1000, depth=base_depth, width=base_width, heads=base_heads,
                           dim_head=base_dim_head, ffn_expansion=base_ffn_expansion, patch_size=patch_size,
                           in_channels=in_channels, image_size=resolution, fixup=fixup, dropout=dropout)

        # -1 means "inherit from the base config".
        depth = base_depth if depth == -1 else depth
        width = base_width if width == -1 else width
        if heads == -1 and dim_head == -1:
            heads, dim_head = base_heads, base_dim_head
        elif heads != -1:
            # NOTE(review): an explicitly configured dim_head is silently overridden
            # here when heads is also set.
            dim_head = width // heads
        else:
            heads = width // dim_head
        # update width, depth, etc. for logging
        self.depth, self.width, self.heads, self.dim_head = depth, width, heads, dim_head
        target_config = dict(dim_in=prod(input_shape), dim_out=1000, depth=depth, width=width, heads=heads, dim_head=dim_head,
                             ffn_expansion=ffn_expansion, patch_size=patch_size, image_size=resolution, in_channels=in_channels,
                             fixup=fixup, dropout=dropout)
        self.target_config = target_config

        def extra_lr_mult_fn(param_name):
            # Per-parameter LR multiplier, selected by substring of the parameter name.
            if 'to_patch_embedding' in param_name or 'input_layer' in param_name:
                return input_lr_mult
            elif 'matrix_params.0' in param_name:
                return lr_mult_1
            elif 'matrix_params.1' in param_name:
                return lr_mult_2
            else:
                return 1

        # NOTE(review): extra_init_mult_fn is never passed to cola_parameterize, so
        # training.init_mult_1 / init_mult_2 currently have no effect — confirm intent.
        def extra_init_mult_fn(param_name):
            if 'matrix_params.0' in param_name:
                print(f'scaling {param_name} std by {init_mult_1}')
                return init_mult_1
            elif 'matrix_params.1' in param_name:
                print(f'scaling {param_name} std by {init_mult_2}')
                return init_mult_2
            else:
                return 1

        def zero_init_fn(weight, name):
            # Zero-initialize weights that carry a truthy `zero_init` attribute.
            return hasattr(weight, 'zero_init') and weight.zero_init

        cola_kwargs = dict(tt_dim=tt_dim, tt_rank=tt_rank)
        optim_kwargs = dict(weight_decay=weight_decay, betas=(0.9, 0.95))
        model, optimizer = cola_parameterize(model_builder, base_config=base_config, lr=lr, target_config=target_config,
                                             struct=struct, layer_select_fn=layers, zero_init_fn=zero_init_fn,
                                             extra_lr_mult_fn=extra_lr_mult_fn, device='cuda', cola_kwargs=cola_kwargs,
                                             optim_kwargs=optim_kwargs)
        fake_input = ch.randn(*input_shape).to('cuda')
        info = get_model_summary_and_flops(model, fake_input)
        model = model.to(self.gpu)
        if distributed:
            model = ch.nn.parallel.DistributedDataParallel(model, device_ids=[self.gpu])
            # Re-expose the feature-tracking hooks on the DDP wrapper.
            if hasattr(model.module, 'get_features'):
                model.get_features = model.module.get_features
                model.clear_features = model.module.clear_features
        loss = ch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
        val_loss = ch.nn.CrossEntropyLoss()
        return model, loss, val_loss, scaler, optimizer, info

    @staticmethod
    def _mean_item(values):
        """Mean of a list of 0-d tensors as a Python float; NaN for an empty list.

        Stacks on the tensors' own device so only one host sync is needed.
        """
        if not values:
            return float('nan')
        return ch.stack(values).mean().item()

    @param('logging.log_level')
    @param('logging.log_freq')
    @param('training.clip_grad')
    @param('training.mixup')
    def train_loop(self, epoch, log_level, log_freq, clip_grad, mixup):
        """Run one training epoch and return (mean train loss, mean train accuracy in %).

        Within the epoch, every param group's LR is linearly interpolated per
        iteration between the schedule multipliers at `epoch` and `epoch + 1`.
        It is restored to its raw value afterwards. Every `log_freq` iterations
        (when `log_level > 0`) the running means and feature metrics are logged.
        """
        model = self.model
        model.train()
        losses = []
        accs = []

        raw_lrs = [group['lr'] for group in self.optimizer.param_groups]
        lr_sched_mult_start, lr_sched_mult_end = self.get_lr_sched_mult(epoch), self.get_lr_sched_mult(epoch + 1)
        iters = len(self.train_loader)
        lr_sched_mults = np.interp(np.arange(iters), [0, iters], [lr_sched_mult_start, lr_sched_mult_end])

        iterator = tqdm(self.train_loader)
        for ix, (images, target) in enumerate(iterator):
            # Training start
            for param_group, raw_lr in zip(self.optimizer.param_groups, raw_lrs):
                param_group['lr'] = raw_lr * lr_sched_mults[ix]

            self.optimizer.zero_grad(set_to_none=True)
            with autocast() if self.mixed_prec else nullcontext():
                output = self.model(images)
                if mixup > 0:
                    # Mixup labels are read as columns (target, permuted target, mixing weight).
                    target_perm = target[:, 1].long()
                    weight = target[0, 2].squeeze()
                    target = target[:, 0].long()
                    if weight != -1:
                        loss_train = self.loss(output, target) * weight + self.loss(output, target_perm) * (1 - weight)
                    else:
                        # A weight of -1 is treated as an unmixed batch.
                        loss_train = self.loss(output, target)
                        target_perm = None
                else:
                    loss_train = self.loss(output, target)
            self.scaler.scale(loss_train).backward()

            # Unscale first so gradient clipping operates on the true gradient norm.
            self.scaler.unscale_(self.optimizer)

            if clip_grad > 0:
                ch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)

            # Gradients are already unscaled; step() still skips the update on inf/NaN.
            self.scaler.step(self.optimizer)
            self.scaler.update()

            # Kept as device tensors; reduced with a single sync in _mean_item.
            losses.append(loss_train.detach())
            accs.append((output.argmax(dim=-1) == target).float().mean().detach() * 100)
            # Training end

            # Logging start
            if log_level > 0 and ix % log_freq == 0:
                feature_metrics = self.get_features_metrics(self.example_imgs)
                cur_lr = self.optimizer.param_groups[0]['lr']

                msg = f'E {epoch} | L {loss_train:.2f} | A {accs[-1]:.1f}'
                iterator.set_description(msg)

                self.log(
                    dict(
                        {
                            'train_loss': self._mean_item(losses),
                            'train_acc': self._mean_item(accs),
                            'lr': cur_lr,
                        }, **feature_metrics))
            # Logging end
        # Reset lr
        for param_group, raw_lr in zip(self.optimizer.param_groups, raw_lrs):
            param_group['lr'] = raw_lr
        return self._mean_item(losses), self._mean_item(accs)

    def get_features_metrics(self, val_imgs):
        """Collect hidden-state and per-parameter statistics on a fixed batch.

        Active only when the model exposes get_features/clear_features; otherwise
        returns an empty dict. For each tracked feature i it reports the RMS
        ('h_i'), the standard deviation ('std_i'), and the RMS change since the
        previous call ('dh_i', 0 on the first call). It also reports any
        rms/x/out/scale attributes found on parameters. The model is left in
        train mode.
        """
        metrics = {}
        if not hasattr(self.model, 'get_features'):
            return metrics

        self.model.clear_features()
        self.model.eval()  # features only saved in eval mode
        try:
            # Read-only probe: no autograd graph is needed. Without no_grad, a graph
            # could stay alive through self.prev_hs, and DDP would prepare for a
            # backward pass that never happens.
            with ch.no_grad():
                with autocast() if self.mixed_prec else nullcontext():
                    self.model(val_imgs)
        finally:
            self.model.train()

        hs = [ch.cat(h.buffer, dim=0) for h in self.model.get_features()]  # list of tensors
        if self.prev_hs is None:
            self.prev_hs = hs
        for i, (h, prev_h) in enumerate(zip(hs, self.prev_hs)):
            metrics[f'h_{i}'] = rms(h).item()  # should be O(1)
            metrics[f'std_{i}'] = ch.std(h).item()
            metrics[f'dh_{i}'] = rms(h - prev_h).item()  # should be O(1)
        self.prev_hs = hs

        # go through all params
        for name, p in self.model.named_parameters():
            if hasattr(p, 'rms'):
                metrics[f'rms/{name}'] = p.rms
            if hasattr(p, 'x'):
                metrics[f'in/{name}'] = p.x
            if hasattr(p, 'out'):
                metrics[f'out/{name}'] = p.out
            if hasattr(p, 'scale'):
                metrics[f'scale/{name}'] = p.scale
        return metrics

    @param('validation.lr_tta')
    def val_loop(self, lr_tta):
        """Evaluate on the validation set and return {'top_1', 'top_5', 'loss'}; meters are reset afterwards.

        With `lr_tta`, the logits of the image and of its flip along dim 3 (width,
        assuming NCHW) are summed before scoring.
        """
        model = self.model
        model.eval()

        with ch.no_grad():
            with autocast() if self.mixed_prec else nullcontext():
                for images, target in tqdm(self.val_loader):
                    output = self.model(images)
                    if lr_tta:
                        output += self.model(ch.flip(images, dims=[3]))

                    for k in ['top_1', 'top_5']:
                        self.val_meters[k](output, target)

                    loss_val = self.val_loss(output, target)  # no label smoothing
                    self.val_meters['loss'](loss_val)

        stats = {k: m.compute().item() for k, m in self.val_meters.items()}
        for meter in self.val_meters.values():
            meter.reset()
        return stats

    @param('logging.folder')
    @param('logging.use_wandb')
    @param('logging.wandb_project')
    @param('logging.wandb_group')
    def initialize_logger(self, folder, use_wandb, wandb_project, wandb_group):
        """Create validation meters and, on rank 0, the run folder, params.json and wandb run.

        The run folder is `<folder>/<runname>/<uid>`. The wandb config merges the
        flattened fastargs params (keyed by their last path component),
        self.target_config and self.info.
        """
        self.val_meters = {
            'top_1': torchmetrics.Accuracy(task='multiclass', num_classes=1000).to(self.gpu) * 100,
            'top_5': torchmetrics.Accuracy(task='multiclass', num_classes=1000, top_k=5).to(self.gpu) * 100,
            'loss': MeanScalarMetric().to(self.gpu)
        }

        if self.gpu == 0:
            folder = (Path(folder) / self.runname() / str(self.uid)).absolute()
            folder.mkdir(parents=True)

            self.log_folder = folder
            self.start_time = time.time()

            print(f'=> Logging in {self.log_folder}')
            params = {'.'.join(k): self.all_params[k] for k in self.all_params.entries.keys()}

            with open(folder / 'params.json', 'w+') as handle:
                json.dump(params, handle)
            if use_wandb:
                config_dict = {k.split('.')[-1]: v for k, v in params.items()}
                config_dict.update(self.target_config)
                config_dict.update(self.info)
                config_dict['model'] = config_dict['arch']
                config_dict['ckpt_dir'] = str(folder)
                wandb.init(project=wandb_project, name=self.runname(), group=wandb_group if wandb_group else None,
                           config=config_dict)

    @param('logging.use_wandb')
    def log(self, content, use_wandb):
        """Append `content` plus timestamps as one JSON line to `<log_folder>/log` and mirror it to wandb.

        No-op on ranks other than 0.
        """
        if self.gpu != 0:
            return
        cur_time = time.time()
        with open(self.log_folder / 'log', 'a+') as fd:
            fd.write(json.dumps({'timestamp': cur_time, 'relative_time': cur_time - self.start_time, **content}) + '\n')
        if use_wandb:
            wandb.log(content)

    @classmethod
    @param('training.distributed')
    @param('dist.world_size')
    def launch_from_args(cls, distributed, world_size):
        """Run one process per GPU via ch.multiprocessing.spawn, or run inline on GPU 0."""
        if distributed:
            print(f'Launching distributed training with {world_size} gpus')
            ch.multiprocessing.spawn(cls._exec_wrapper, nprocs=world_size, join=True)
        else:
            cls.exec(0)

    @classmethod
    def _exec_wrapper(cls, *args, **kwargs):
        """Spawned-process entry point: re-populate the config quietly, then run `exec`."""
        make_config(quiet=True)
        cls.exec(*args, **kwargs)

    @classmethod
    @param('training.distributed')
    @param('training.eval_only')
    def exec(cls, gpu, distributed, eval_only):
        """Build a trainer for `gpu`, then either only evaluate or train; tear down DDP afterwards."""
        trainer = cls(gpu=gpu)
        if eval_only:
            trainer.eval_and_log()
        else:
            trainer.train()

        if distributed:
            trainer.cleanup_distributed()

    @param('model.arch')
    @param('model.patch_size')
    @param('cola.struct')
    @param('cola.layers')
    @param('cola.tt_dim')
    @param('cola.tt_rank')
    @param('training.lr')
    @param('training.weight_decay')
    @param('training.batch_size')
    def runname(self, arch, patch_size, struct, layers, tt_dim, tt_rank, lr, weight_decay, batch_size):
        """Build a run name encoding architecture, resolved model shape, CoLA structure and optimizer settings."""
        name = arch + f'-{patch_size}'
        name += f'_d{self.depth}w{self.width}h{self.heads}dh{self.dim_head}'
        if struct != 'none':
            name += f'_{struct}_{layers}_tt_{tt_dim}_{tt_rank}'
        name += f'_lr{lr}wd{weight_decay}bs{batch_size}'
        return name


# Utils


class MeanScalarMetric(torchmetrics.Metric):
    """Running mean over every element passed to `update`.

    Both accumulators are registered with dist_reduce_fx='sum', so torchmetrics
    sums them across processes before `compute`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.add_state('sum', default=ch.tensor(0.), dist_reduce_fx='sum')
        self.add_state('count', default=ch.tensor(0), dist_reduce_fx='sum')

    def update(self, sample: ch.Tensor):
        """Add the element sum and element count of `sample` to the accumulators."""
        batch_total = sample.sum()
        batch_size = sample.numel()
        self.sum += batch_total
        self.count += batch_size

    def compute(self):
        """Return accumulated sum / count as a float tensor."""
        numerator = self.sum.float()
        return numerator / self.count


def rms(x, eps=1e-8):
    """Root-mean-square over all elements of `x`: sqrt(mean(x**2) + eps).

    The input is converted with `.float()` before squaring so half-precision
    inputs don't overflow to inf/NaN.
    """
    squared = x.float().pow(2)  # to avoid nan
    return squared.mean().add(eps).sqrt()


# Running


def make_config(quiet=False):
    """Fill the global fastargs config from the command line and validate it.

    Args:
        quiet: When True, skip printing the config summary (the spawned
            per-GPU workers call it this way).
    """
    cfg = get_current_config()
    arg_parser = ArgumentParser(description='Fast imagenet training')
    cfg.augment_argparse(arg_parser)
    cfg.collect_argparse_args(arg_parser)
    cfg.validate(mode='stderr')
    if quiet:
        return
    cfg.summary()


if __name__ == "__main__":
    # Parse CLI arguments into the global fastargs config, then launch
    # single-process or spawned multi-GPU training depending on training.distributed.
    make_config()
    ImageNetTrainer.launch_from_args()
