## pythonic imports
import os
import random
import numpy as np
import pandas as pd
from argparse import ArgumentParser
from prettytable import PrettyTable
import tqdm
import wandb

from copy import deepcopy

## torch
import torch
import torch.nn as nn
from torch.distributed import init_process_group
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch import optim
import torch.multiprocessing as mp
from torch.cuda.amp import autocast
from torch.utils.tensorboard import SummaryWriter

## torchvision
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

## file-based imports
import utils.schedulers as schedulers
import utils.pruning_utils as pruning_utils
from utils.harness_utils import *
from utils.metric_utils import LabelSmoothingLoss
from utils.pruning_utils import dst_pruner, prune_mag, make_dense, anneal_scale, masked_l2_decay, set_target_mask
from utils.dataset import CIFARLoader, imagenet, imagenet_pytorch
from utils.fixed_sign_optim import FixedSignOptimizerWrapper
## fastargs
from fastargs import get_current_config
from fastargs.decorators import param

##ffcv
try: 
    from ffcv.loader import Loader, OrderOption
    from ffcv.transforms import (
        ToTensor,
        ToDevice,
        Squeeze,
        NormalizeImage,
        RandomHorizontalFlip,
        ToTorchImage,
    )
    from ffcv.fields.decoders import (
        RandomResizedCropRGBImageDecoder,
        CenterCropRGBImageDecoder,
    )
    from ffcv.fields.basics import IntDecoder
except:
    pass 

# Per-channel ImageNet mean/std multiplied by 255; they are handed to ffcv's
# NormalizeImage in the loaders (presumably because it normalizes raw 0-255 pixels).
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) * 255
IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255
# Crop ratio passed to CenterCropRGBImageDecoder in the validation pipeline.
DEFAULT_CROP_RATIO = 224 / 256

# Credentials must never live in source control: the API key is read from the
# WANDB_API_KEY environment variable. When it is unset, ``key=None`` leaves wandb
# to resolve credentials through its own default mechanism.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

class Harness:
    """Harness class to handle training and evaluation.

    Args:
        gpu_id (int): current rank of the process while training using DDP.
        expt_dir (str): Experiment directory to save artifacts/checkpoints/metrics to.
        model (nn.Module): The model to train. It is moved to ``cuda:{gpu_id}`` and wrapped in DDP.
    """

    def __init__(self, gpu_id: int, expt_dir: str, model: nn.Module) -> None:
        self.config = get_current_config()
        self.gpu_id = gpu_id
        self.this_device = f"cuda:{self.gpu_id}"
        self.num_classes = self.config["dataset.num_classes"]
        self.criterion = LabelSmoothingLoss(self.num_classes) if self.config["dataset.criterion"] == 'LabelSmoothingLoss' else nn.CrossEntropyLoss()

        # Data loaders: CIFAR uses its own loader class, everything else is treated as
        # ImageNet and read either through plain PyTorch or through ffcv.
        if "CIFAR" in self.config["dataset.dataset_name"]:
            self.loaders = CIFARLoader(distributed=True)
            self.train_loader = self.loaders.train_loader
            self.test_loader = self.loaders.test_loader
        elif not self.config["dataset.use_ffcv"]:
            dataset = imagenet_pytorch(this_device=self.this_device, distributed=True, rank=self.gpu_id, world_size=dist.get_world_size())
            self.train_loader = dataset.train_loader
            self.test_loader = dataset.test_loader
        else:
            self.train_loader = self.create_train_loader()
            self.test_loader = self.create_test_loader()

        model = model.to(self.this_device)
        self.model = DDP(model, device_ids=[self.gpu_id])
        # The optimizer/scheduler need both self.model and self.train_loader to exist.
        self.create_optimizers()
        self.expt_dir = expt_dir
        self.step_counter = 0
        self.steps_per_epoch = len(self.train_loader)
        self.epochs_per_level = self.config["experiment_params.epochs_per_level"]
        self.total_steps = self.steps_per_epoch * self.epochs_per_level
        if self.config["prune_params.dst"]:
            # Mask updates are scheduled relative to 75% of the level's total steps.
            self.T_end = int(0.75 * self.steps_per_epoch * self.epochs_per_level)
            self.grad_acc_window = self.config["prune_params.dst_grad_window"]
            self.dst_every = self.config["prune_params.dst_every"]
            # initialize the gradient tracker
            self.grad_tracker = GradientTracker(window=self.grad_acc_window, dst_every=self.dst_every, T_end=self.T_end)
            # register model hooks
            self.grad_tracker.register_hooks(self.model)

        if self.config["prune_params.er_method"] == 'anneal_balanced':
            # Annealing window: from step 0 up to 80% of the level's total steps.
            self.anneal_steps_start = int(0 * self.steps_per_epoch * self.epochs_per_level)
            self.anneal_steps_end = int(0.8 * self.steps_per_epoch * self.epochs_per_level)
        self.precision, self.use_amp = self.get_dtype_amp()

    @param('experiment_params.training_precision')
    def get_dtype_amp(self, training_precision):
        """Map the configured precision name to ``(torch dtype, use autocast)``.

        Args:
            training_precision (str): One of 'bfloat16', 'float16' or 'float32'.

        Returns:
            tuple: The dtype and whether autocast is enabled. Unknown names fall
            back to ``(torch.float32, False)``.
        """
        dtype_map = {
            'bfloat16': (torch.bfloat16, True),
            'float16': (torch.float16, True),
            'float32': (torch.float32, False)
        }
        return dtype_map.get(training_precision, (torch.float32, False))

    def _label_pipeline(self) -> list:
        """Build a fresh ffcv label pipeline ending with a transfer to this rank's device."""
        return [
            IntDecoder(),
            ToTensor(),
            Squeeze(),
            ToDevice(torch.device(self.this_device), non_blocking=True),
        ]

    @param("dataset.batch_size")
    @param("dataset.num_workers")
    @param("dataset.data_root")
    def create_train_loader(
        self,
        batch_size: int,
        num_workers: int,
        data_root: str,
        distributed: bool = True,
    ) -> Any:
        """Create the train dataloader.

        Args:
            batch_size (int): Batch size for data loading.
            num_workers (int): Number of workers for data loading.
            data_root (str): Root directory for data.
            distributed (bool, optional): Whether to use distributed data loading. Default is True.

        Returns:
            Any: Train dataloader.
        """
        train_image_pipeline = [
            RandomResizedCropRGBImageDecoder((224, 224)),
            RandomHorizontalFlip(),
            ToTensor(),
            ToDevice(torch.device(self.this_device), non_blocking=True),
            ToTorchImage(),
            NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float32),
        ]

        train_loader = Loader(
            os.path.join(data_root, "train_500_0.50_90.beton"),
            batch_size=batch_size,
            num_workers=num_workers,
            order=OrderOption.RANDOM,
            os_cache=True,
            drop_last=True,
            pipelines={"image": train_image_pipeline, "label": self._label_pipeline()},
            distributed=distributed,
        )

        return train_loader

    @param("dataset.batch_size")
    @param("dataset.num_workers")
    @param("dataset.data_root")
    def create_test_loader(
        self,
        batch_size: int,
        num_workers: int,
        data_root: str,
        distributed: bool = True,
    ) -> Any:
        """Create the test dataloader.

        Args:
            batch_size (int): Batch size for data loading.
            num_workers (int): Number of workers for data loading.
            data_root (str): Root directory for data.
            distributed (bool, optional): Whether to use distributed data loading. Default is True.

        Returns:
            Any: Test dataloader.
        """
        val_image_pipeline = [
            CenterCropRGBImageDecoder((224, 224), ratio=DEFAULT_CROP_RATIO),
            ToTensor(),
            ToDevice(torch.device(self.this_device), non_blocking=True),
            ToTorchImage(),
            NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float32),
        ]

        val_loader = Loader(
            os.path.join(data_root, "val_500_0.50_90.beton"),
            batch_size=batch_size,
            num_workers=num_workers,
            order=OrderOption.SEQUENTIAL,
            drop_last=False,
            pipelines={"image": val_image_pipeline, "label": self._label_pipeline()},
            distributed=distributed,
        )
        return val_loader

    @param("optimizer.lr")
    @param("optimizer.momentum")
    @param("optimizer.weight_decay")
    @param("optimizer.scheduler_type")
    @param("optimizer.optim_type")
    @param("optimizer.fixed_sign_optim")
    def create_optimizers(
        self, lr: float, momentum: float, weight_decay: float, scheduler_type: str, optim_type: str, fixed_sign_optim: bool
    ) -> None:
        """Instantiate the optimizer and learning rate scheduler.

        Args:
            lr (float): Initial learning rate.
            momentum (float): Momentum for SGD.
            weight_decay (float): Weight decay for optimizer.
            scheduler_type (str): Name of a scheduler class in ``utils.schedulers``.
            optim_type (str): 'AdamW' selects AdamW; any other value selects SGD.
            fixed_sign_optim (bool): Wrap the optimizer in ``FixedSignOptimizerWrapper``.
        """
        if optim_type == 'AdamW':
            self.optimizer = optim.AdamW(self.model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay)
        else:
            self.optimizer = optim.SGD(
                self.model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay
            )
        # Optionally wrap the optimizer so that weight signs stay fixed and only
        # the magnitudes are learned.
        if fixed_sign_optim:
            print('Fixed sign optimizer initialized')
            self.optimizer = FixedSignOptimizerWrapper(self.optimizer)

        scheduler = getattr(schedulers, scheduler_type)
        if scheduler_type == "TriangularSchedule":
            # The triangular schedule is stepped per batch, so it needs the epoch length.
            self.scheduler = scheduler(
                optimizer=self.optimizer, steps_per_epoch=len(self.train_loader)
            )
        else:
            self.scheduler = scheduler(optimizer=self.optimizer)

    # logging for wandb
    def _log_metrics(self, level, train_loss, test_loss, train_acc, test_acc, density, epoch, epochs_per_level, lr):
        """Log one epoch of metrics plus gradient/weight norms to wandb.

        The logged "epoch" is the global epoch ``level * epochs_per_level + epoch``.
        """
        wandb.log({
            "train_loss": train_loss, "test_loss": test_loss,
            "train_acc": train_acc, "test_acc": test_acc,
            "epoch": level * epochs_per_level + epoch,
            "cycle": level,
            "density": density,
            "lr": lr,
            "grad_norm": get_gradient_norm(self.model, masked=False),
            "masked_grad_norm": get_gradient_norm(self.model, masked=True),
            "weight_norm": get_norm(self.model, masked=False),
            "masked_weight_norm": get_norm(self.model, masked=True),
        })

    @param("prune_params.dst")
    @param("prune_params.er_method")
    @param("model_params.conv_type")
    @param("optimizer.weight_decay")
    def train_one_epoch(self, epoch: int, target_mask_list: list, dst: bool, er_method: str, conv_type: str, weight_decay: float) -> (float, float):
        """Train the model for one epoch.

        Args:
            epoch (int): Current epoch.
            target_mask_list (list): Target masks used by the 'anneal_balanced' method.
            dst (bool): If we want to shuffle the mask
            er_method (str): 'anneal_balanced' enables the annealed masked L2 decay.
            conv_type (str): 'ConvMaskMW' / 'LinearMaskMW' add the mw regularizers.
            weight_decay (float): Injected from the config; not used in this method.

        Returns:
            (float, float): Training loss and accuracy.
        """
        if "CIFAR" in self.config["dataset.dataset_name"]:
            self.loaders.train_sampler.set_epoch(epoch)
        # self.model is read directly everywhere below: the mask updates reassign it
        # mid-epoch, and a cached local alias would keep training the old object.
        self.model.train()
        train_loss = 0
        correct = 0
        total = 0

        # Loop-invariant configuration, looked up once per epoch.
        move_to_device = self.config["dataset.use_ffcv"] is False
        step_per_batch = self.config["optimizer.scheduler_type"] == "TriangularSchedule"
        annealing = er_method == 'anneal_balanced'
        uses_mw = conv_type in ('ConvMaskMW', 'LinearMaskMW')

        tepoch = tqdm.tqdm(self.train_loader, unit="batch", desc=f"Epoch {epoch}")
        for inputs, targets in tepoch:
            if move_to_device:
                inputs, targets = inputs.to(self.this_device), targets.to(
                    self.this_device
                )

            self.optimizer.zero_grad()

            with autocast(dtype=self.precision, enabled=self.use_amp):
                outputs = self.model(inputs.contiguous())
                loss = self.criterion(outputs, targets)
                if annealing and self.anneal_steps_start <= self.step_counter <= self.anneal_steps_end:
                    scale = anneal_scale(self.step_counter - self.anneal_steps_start, self.anneal_steps_end - self.anneal_steps_start)
                    decay_loss = masked_l2_decay(self.model, target_mask_list, scale)
                    print(f"Scale {scale} and loss value {decay_loss}")
                    loss += decay_loss
                if uses_mw:
                    loss += anneal_scale_mw_l1(self.step_counter, self.total_steps) * get_mw_l1(self.model)
                    loss += anneal_scale_mw(self.step_counter, self.total_steps) * get_mw_wd(self.model)
            loss.backward()
            self.optimizer.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            if step_per_batch:
                self.scheduler.step()

            # Annealing, set mask to target once done
            if annealing and self.step_counter == self.anneal_steps_end:
                print('Pruning after annealing to set the target mask')
                self.model = set_target_mask(self.model, target_mask_list)

            self.step_counter += 1

            # DST step counter
            if dst:
                self.grad_tracker.update_step()

            # DST is implemented here; no mask update happens during epoch 0.
            if dst and self.step_counter % self.config["prune_params.dst_every"] == 0 and epoch > 0:
                if self.gpu_id == 0:
                    print('Updating the mask at iteration number {}'.format(self.step_counter))
                # Modify parameters across all processes
                self.model = dst_pruner(model=self.model, step=self.step_counter, T_end=self.T_end, acc_grads=self.grad_tracker.get_acc_grads(), world_size=dist.get_world_size())
                self.optimizer = mask_momentum_rigl(self.model, self.optimizer)

        # Every scheduler other than the triangular one is stepped once per epoch.
        if not step_per_batch:
            self.scheduler.step()
        train_loss /= len(self.train_loader)
        accuracy = 100.0 * (correct / total)
        return train_loss, accuracy

    def test(self) -> (float, float):
        """Evaluate the model on the test set.

        Returns:
            (float, float): Test loss and accuracy.
        """
        self.model.eval()
        test_loss = 0
        correct = 0
        total = 0
        move_to_device = self.config["dataset.use_ffcv"] is False

        tloader = tqdm.tqdm(self.test_loader, desc="Testing")
        with torch.no_grad():
            for inputs, targets in tloader:
                if move_to_device:
                    inputs, targets = inputs.to(self.this_device), targets.to(
                        self.this_device
                    )
                with autocast(dtype=self.precision, enabled=self.use_amp):
                    outputs = self.model(inputs.contiguous())
                    loss = self.criterion(outputs, targets)

                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        test_loss /= len(self.test_loader)
        accuracy = 100.0 * correct / total

        return test_loss, accuracy

    @param("prune_params.er_init")
    def prune_in_level(self, er_init):
        """Magnitude-prune the model (compressed phase of AC/DC).

        Args:
            er_init: Forwarded unchanged to ``prune_mag``.
        """
        if self.gpu_id == 0:  # only the message is rank-0 specific
            print('Pruning model by magnitude for AC/DC')
        # The pruning call itself runs on every rank.
        self.model = prune_mag(self.model, er_init, world_size=dist.get_world_size(), distributed=True)

    def _average_across_ranks(self, *values: float) -> list:
        """Average scalar metrics over all ranks using a single all-reduce."""
        packed = torch.tensor([float(v) for v in values], device=self.this_device)
        dist.all_reduce(packed, op=dist.ReduceOp.SUM)
        packed /= dist.get_world_size()
        return packed.tolist()

    def _broadcast_parameters(self) -> None:
        """Overwrite every rank's parameters with the values held by rank 0."""
        for p in self.model.parameters():
            dist.broadcast(p.data, src=0)

    @param("experiment_params.epochs_per_level")
    @param("experiment_params.training_type")
    @param("prune_params.acdc")
    @param("prune_params.rescale_mw")
    @param("prune_params.rescale_till")
    @param("prune_params.rescale_every")
    def train_one_level(
        self, epochs_per_level: int, training_type: str, level: int, writer: callable, target_mask_list: list, acdc: bool, rescale_mw: bool, rescale_till: float, rescale_every: int
    ) -> None:
        """Train the model for one full level. This can thought of as one full training run.

        Args:
            epochs_per_level (int): Total number of epochs to train for at each level.
            training_type (str): Type of training can be {'wr', 'lrr' or 'imp}. Not used inside this method.
            level (int): Current sparsity level.
            writer: SummaryWriter for tensorboard; only used on rank 0.
            target_mask_list (list): Forwarded to ``train_one_epoch``.
            acdc (bool): ACDC trainer
            rescale_mw (bool): Periodically call ``rescale_mw_model`` during training.
            rescale_till (float): Fraction of the level during which rescaling may happen.
            rescale_every (int): Rescale interval in epochs.
        """
        new_table = PrettyTable()
        new_table.field_names = [
            "Epoch",
            "Train Loss",
            "Test Loss",
            "Train Acc",
            "Test Acc",
        ]
        sparsity_level_df = {
            "epoch": [],
            "train_acc": [],
            "test_acc": [],
            "train_loss": [],
            "test_loss": [],
        }
        data_df = {
            "level": [],
            "sparsity": [],
            "max_test_acc": [],
            "final_test_acc": [],
        }

        print('Starting training at this level')
        print(f'Rescaling for this run is: {rescale_mw}')

        if self.gpu_id == 0:
            print('Model density:', get_model_density(self.model))

        # The epoch schedules do not change during the level, so build them once.
        prune_epochs = frozenset()
        dense_epochs = frozenset()
        rescale_epochs = frozenset()
        if acdc:
            prune_epochs = frozenset(np.arange(14, epochs_per_level - 15, 10).tolist())
            dense_epochs = frozenset(np.arange(19, epochs_per_level - 20, 10).tolist())
        if rescale_mw:
            rescale_epochs = frozenset(
                np.arange(rescale_every - 1, int(rescale_till * epochs_per_level), rescale_every).tolist()
            )

        for epoch in range(epochs_per_level):

            train_loss, train_acc = self.train_one_epoch(epoch, target_mask_list)
            test_loss, test_acc = self.test()

            # Average the per-rank metrics; every rank must take part in the collective.
            tr_l, te_l, tr_a, te_a = self._average_across_ranks(
                train_loss, test_loss, train_acc, test_acc
            )

            print('Step Counter is at: ', self.step_counter)
            # ACDC trainer
            if epoch in prune_epochs:
                # this is the compressed phase
                print('Pruning the model: Compressed Phase of ACDC')
                self.prune_in_level()
            if epoch in dense_epochs:
                print('Making the Model Dense: Dense Phase of ACDC')

                sync_model_weights_all_reduce(self.model)
                if self.gpu_id == 0:
                    print('Making the Model Dense: Dense Phase of ACDC')
                    self.model = make_dense(self.model)
                self._broadcast_parameters()

            if epoch in rescale_epochs:
                sync_model_weights_all_reduce(self.model)
                if self.gpu_id == 0:
                    print('Rescaling the mw model during training')
                    rescale_mw_model(self.model)
                self._broadcast_parameters()

            if self.gpu_id == 0:
                new_table.add_row([epoch, tr_l, te_l, tr_a, te_a])
                print(new_table)
                sparsity_level_df["epoch"].append(epoch)
                sparsity_level_df["train_loss"].append(tr_l)
                sparsity_level_df["test_loss"].append(te_l)
                sparsity_level_df["train_acc"].append(tr_a)
                sparsity_level_df["test_acc"].append(te_a)

                # Log scalars to tensorboard against the global epoch. The previous
                # ``level * epoch + 1`` put every level-0 epoch on step 1.
                global_epoch = level * epochs_per_level + epoch
                current_lr = self.optimizer.param_groups[0]['lr']
                writer.add_scalar('Loss/train', tr_l, global_epoch)
                writer.add_scalar('Loss/test', te_l, global_epoch)
                writer.add_scalar('Acc/train', tr_a, global_epoch)
                writer.add_scalar('Acc/test', te_a, global_epoch)
                writer.add_scalar('lr', current_lr, global_epoch)

                self._log_metrics(level, tr_l, te_l, tr_a, te_a, get_model_density(self.model), epoch, epochs_per_level, current_lr)

                # Rewind point: model/optimizer state after the 10th epoch of level 0.
                if level == 0 and epoch == 9:
                    torch.save(
                        self.model.module.state_dict(),
                        os.path.join(self.expt_dir, "checkpoints", "model_rewind.pt"),
                    )
                    torch.save(
                        self.optimizer.state_dict(),
                        os.path.join(self.expt_dir, "artifacts", "optimizer_rewind.pt"),
                    )

                # Periodic checkpoint every 10 epochs.
                if epoch % 10 == 9:
                    save_ckpt(self.expt_dir, level, epoch, self.optimizer, self.model)

        if self.gpu_id == 0:
            pd.DataFrame(sparsity_level_df).to_csv(
                os.path.join(
                    self.expt_dir,
                    "metrics",
                    "epochwise_metrics",
                    f"level_{level}_metrics.csv",
                )
            )
            sparsity = print_sparsity_info(self.model, verbose=False)
            # Read the accuracies from the recorded history so that a level with zero
            # epochs yields NaN instead of failing on an unbound variable.
            test_accs = sparsity_level_df["test_acc"]
            final_test_acc = test_accs[-1] if test_accs else float("nan")
            max_test_acc = max(test_accs, default=float("nan"))
            data_df["level"].append(level)
            data_df["sparsity"].append(round(sparsity, 4))
            data_df["final_test_acc"].append(round(final_test_acc, 4))
            data_df["max_test_acc"].append(round(max_test_acc, 4))
            summary_path = os.path.join(self.expt_dir, "metrics", "summary.csv")

            # Append this level to the running summary, creating the file on first use.
            if not os.path.exists(summary_path):
                pd.DataFrame(data_df).to_csv(summary_path, index=False)
            else:
                pre_df = pd.read_csv(summary_path)
                new_df = pd.DataFrame(data_df)
                updated_df = pd.concat([pre_df, new_df], ignore_index=True)
                updated_df.to_csv(summary_path, index=False)

@param("dist_params.address")
@param("dist_params.port")
def setup_distributed(address: str, port: int, gpu_id: int) -> None:
    """Join the NCCL process group and bind this process to its GPU.

    Args:
        address (str): Master address for distributed training.
        port (int): Master port for distributed training.
        gpu_id (int): current rank/gpu.
    """
    # The rendezvous endpoint is published through the environment.
    os.environ.update({"MASTER_ADDR": address, "MASTER_PORT": str(port)})
    # One process per visible CUDA device, so the device count is the world size.
    n_gpus = torch.cuda.device_count()
    dist.init_process_group(backend="nccl", rank=gpu_id, world_size=n_gpus)
    torch.cuda.set_device(gpu_id)


def main(rank: int, model: nn.Module, level: int, expt_dir: str, target_mask_list: list) -> None:
    """Main function for distributed training.

    Args:
        rank (int): Rank of the current process.
        model (nn.Module): The model to train.
        level (int): Current sparsity level.
        expt_dir (str): Experiment directory to save artifacts/checkpoints/metrics to.
        target_mask_list (list): Target masks forwarded to ``Harness.train_one_level``.
    """
    # Re-collect the configuration inside this process before anything reads it.
    config = get_current_config()
    parser = ArgumentParser()
    config.augment_argparse(parser)
    config.collect_argparse_args(parser)
    config.validate(mode="stderr")
    set_seed()
    setup_distributed(gpu_id=rank)

    is_main_rank = rank == 0

    # TensorBoard and wandb logging are initialized only on rank 0.
    writer = None
    if is_main_rank:
        writer = SummaryWriter(expt_dir)
        wandb.init(name=config['experiment_params.expt_name'], project=config['experiment_params.wandb_project'])
        wandb.run.tags += (config['dataset.dataset_name'].lower(),)

    harness = Harness(model=model, expt_dir=expt_dir, gpu_id=rank)

    if level != 0:
        # Later levels start from a reset optimizer state.
        harness.optimizer = reset_optimizer(
            expt_dir=expt_dir,
            optimizer=harness.optimizer,
            training_type=config["experiment_params.training_type"],
        )
    elif is_main_rank:
        # Level 0: keep the initial optimizer and model state on disk.
        torch.save(
            harness.optimizer.state_dict(),
            os.path.join(expt_dir, "artifacts", "optimizer_init.pt"),
        )
        torch.save(
            harness.model.module.state_dict(),
            os.path.join(expt_dir, "checkpoints", "model_init.pt"),
        )

    harness.train_one_level(level=level, writer=writer, target_mask_list=target_mask_list)

    if is_main_rank:
        # Write the trained weights before tearing the loggers down, so that a
        # failure while closing them cannot lose the result of this level.
        ckpt = os.path.join(expt_dir, "checkpoints", f"model_level_{level}.pt")
        torch.save(harness.model.module.state_dict(), ckpt)
        writer.close()
        wandb.finish()

    dist.destroy_process_group()


if __name__ == "__main__":
    # One training process is spawned per visible CUDA device.
    world_size = torch.cuda.device_count()
    print(f"Training on {world_size} GPUs")

    # Collect, validate and print the fastargs configuration.
    config = get_current_config()
    parser = ArgumentParser()
    config.augment_argparse(parser)
    config.collect_argparse_args(parser)

    config.validate(mode="stderr")
    config.summary()

    prune_harness = pruning_utils.PruningStuff()

    # Passing a resume level together with the resume experiment directory makes
    # the run continue from where it previously stopped.
    resume_level = config["experiment_params.resume_level"]
    expt_dir = gen_expt_dir()
    save_config(expt_dir=expt_dir, config=config)
    densities = generate_densities()

    # Pruning at initialization, so that training starts from a sparse mask.
    prune_harness.prune_at_initialization()

    checkpoint_dir = os.path.join(expt_dir, "checkpoints")
    for level in range(resume_level, len(densities)):
        print_sparsity_info(prune_harness.model, verbose=False)
        if level != 0:
            # Load the previous level's result, prune to this level's density, then
            # reset the weights according to the configured training type.
            print(f"Pruning Model at level: {level}")
            previous_ckpt = os.path.join(checkpoint_dir, f"model_level_{level-1}.pt")
            prune_harness.load_from_ckpt(previous_ckpt)
            prune_harness.level_pruner(density=densities[level], level=level)
            prune_harness.model = reset_weights(
                expt_dir=expt_dir,
                model=prune_harness.model,
                training_type=config["experiment_params.training_type"],
            )

            print_sparsity_info(prune_harness.model, verbose=False)

        # Target masks are only handed to the workers for the annealing method.
        target_mask_list = (
            prune_harness.target_mask_list
            if config["prune_params.er_method"] == 'anneal_balanced'
            else None
        )

        worker_args = (prune_harness.model, level, expt_dir, target_mask_list)
        mp.spawn(main, args=worker_args, nprocs=world_size, join=True)
        print(f"Training level {level} complete, moving on to {level+1}")