import os
import pathlib
import argparse
from argparse import ArgumentParser
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from pytorch_lightning.loggers import TensorBoardLogger
import torch.nn.functional as F
import math
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from pytorch_lightning.callbacks import LearningRateMonitor

from utils.callbacks.log_cpr import LogCPR
from utils.callbacks.log_gradient import LogParamsAndGrads
from utils.optim.cpr_wrapper import apply_CPR
from utils.optim.adam_cpr import group_testbed_parameters
from utils.callbacks.schedule_weight_decay import WeightDecayScheduler

from utils.optim.adam_adadecay import AdamAdaDecay
from utils.optim.adam_awd import AdamAWD
from utils.optim.amos import Amos

# Select the 'medium' float32 matmul precision mode; PyTorch documents this
# setting as trading some precision for speed on hardware that supports it.
torch.set_float32_matmul_precision('medium')

# Mini-batch size used by both the train and the validation DataLoader below.
batch_size = 128
# NOTE(review): the two limits below are not referenced anywhere else in the
# visible part of this file (the training length comes from the
# --max_train_steps CLI option) — confirm whether they are still needed.
max_train_steps = 10000
max_train_epochs = 50

### Data
def cifar100_task(cache_dir='./data'):
    """Build the CIFAR-100 train and test datasets.

    Both splits are downloaded into ``cache_dir`` when missing. The training
    split is augmented with a reflect-padded random 32x32 crop and a random
    horizontal flip; both splits are converted to tensors and normalised with
    the same per-channel constants.

    Args:
        cache_dir: Directory torchvision uses as the dataset root.

    Returns:
        A ``(trainset, testset)`` tuple of ``torchvision.datasets.CIFAR100``.
    """
    # NOTE(review): these per-channel constants match the values commonly
    # quoted for CIFAR-10 — confirm they are the intended ones for CIFAR-100.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    augmentation = [
        transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
        transforms.RandomHorizontalFlip(),
    ]
    train_pipeline = transforms.Compose(
        augmentation + [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]
    )
    test_pipeline = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]
    )

    trainset = torchvision.datasets.CIFAR100(
        root=cache_dir, train=True, download=True, transform=train_pipeline)
    testset = torchvision.datasets.CIFAR100(
        root=cache_dir, train=False, download=True, transform=test_pipeline)
    return trainset, testset

### Model
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 conv + batch-norm layers.

    The skip path is the identity unless the spatial stride or the channel
    count changes, in which case it is a strided 1x1 conv followed by
    batch-norm.
    """

    # Output channels of the block are ``expansion * planes``.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1 and in_planes == out_planes:
            # Shapes already agree: an empty Sequential acts as the identity.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        hidden = self.conv1(x)
        hidden = F.relu(self.bn1(hidden))
        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)
        hidden += self.shortcut(x)
        return F.relu(hidden)


class Bottleneck(nn.Module):
    """Residual block with a 1x1 -> 3x3 -> 1x1 convolution stack.

    The final 1x1 conv widens the output to ``expansion * planes`` channels.
    The skip path is the identity unless the stride or channel count changes,
    in which case it is a strided 1x1 conv followed by batch-norm.
    """

    # Output channels of the block are ``expansion * planes``.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # Only the 3x3 conv carries the spatial stride.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        if stride == 1 and in_planes == out_planes:
            # Shapes already agree: an empty Sequential acts as the identity.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        hidden = self.bn3(self.conv3(hidden))
        hidden += self.shortcut(x)
        return F.relu(hidden)


class ResNet(nn.Module):
    """ResNet backbone with a 3x3 stem, four residual stages and a linear head.

    Args:
        block: Block class; must expose an ``expansion`` class attribute and
            accept ``(in_planes, planes, stride)``.
        num_blocks: Number of blocks in each of the four stages (indices 0-3
            are read).
        num_classes: Size of the output layer.
    """

    def __init__(self, block, num_blocks, num_classes):
        super().__init__()
        # Channel count fed into the next block; _make_layer advances it.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # Stages 2-4 each start with a stride-2 block.
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest stride 1."""
        stage = []
        for block_stride in (stride, *([1] * (num_blocks - 1))):
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Run stem, the four stages, 4x4 average pooling and the linear head."""
        feats = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)
        pooled = F.avg_pool2d(feats, 4)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)

### Lightning Module
class ResNetModule(pl.LightningModule):
    """LightningModule that trains a ResNet on CIFAR-100.

    The ``config`` object is read for the model name, the optimizer choice and
    its hyper-parameters, the learning-rate schedule and the optional weight
    rescaling. The CIFAR-100 test split is used as the validation set.
    """

    def __init__(self, config):
        """Build the network and the loss.

        Args:
            config: Namespace-like object; ``model_name`` is read here, the
                remaining attributes are read by the other hooks.

        Raises:
            ValueError: If ``config.model_name`` is not a supported ResNet.
        """
        super().__init__()

        self.cfg = config

        # model_name -> (block class, number of blocks per stage)
        architectures = {
            "ResNet18": (BasicBlock, [2, 2, 2, 2]),
            "ResNet34": (BasicBlock, [3, 4, 6, 3]),
            "ResNet50": (Bottleneck, [3, 4, 6, 3]),
            "ResNet101": (Bottleneck, [3, 4, 23, 3]),
            "ResNet152": (Bottleneck, [3, 8, 36, 3]),
        }
        if self.cfg.model_name not in architectures:
            # Fail fast instead of leaving self.model unset until first use.
            raise ValueError(
                f"Unknown model_name {self.cfg.model_name!r}; "
                f"expected one of {sorted(architectures)}"
            )
        block, num_blocks = architectures[self.cfg.model_name]
        self.model = ResNet(block, num_blocks, num_classes=100)

        self.loss = nn.CrossEntropyLoss(label_smoothing=0.1)

        # Per-batch validation records, reduced in on_validation_epoch_end.
        self.valid_stats = []

    def _grouped_parameters(self, weight_decay):
        """Return parameter groups from ``group_testbed_parameters``.

        Parameters matching the "bias"/"bn" keywords are passed as
        ``avoid_keywords``.
        """
        optim_cfg = {"lr": self.cfg.lr, "beta1": self.cfg.beta1,
                     "beta2": self.cfg.beta2,
                     "weight_decay": weight_decay}
        return group_testbed_parameters(self.model, optim_cfg, avoid_keywords=["bias", "bn"])

    def _build_optimizer(self):
        """Create the optimizer selected by ``cfg.optimizer``.

        Raises:
            ValueError: If ``cfg.optimizer`` is not one of the known names.
        """
        cfg = self.cfg
        name = cfg.optimizer

        if name == 'adamw':
            param_groups = self._grouped_parameters(cfg.weight_decay)
            optimizer = torch.optim.AdamW(param_groups, lr=cfg.lr, betas=(cfg.beta1, cfg.beta2))
            print("INIT WD",  [f"{i} {p['weight_decay']}" for i, p in enumerate(optimizer.param_groups)])
            return optimizer

        if name == 'adamw_warmstart':
            # Starts without weight decay; training_step switches it on later.
            param_groups = self._grouped_parameters(0)
            return torch.optim.AdamW(param_groups, lr=cfg.lr, betas=(cfg.beta1, cfg.beta2))

        if name == 'adam_adadecay':
            param_groups = self._grouped_parameters(cfg.weight_decay)
            return AdamAdaDecay(param_groups, lr=cfg.lr, betas=(cfg.beta1, cfg.beta2),
                                alpha=cfg.adadecay_alpha, weight_decay=cfg.weight_decay,
                                apply_decay=None)

        if name == 'adam_awd':
            param_groups = self._grouped_parameters(cfg.weight_decay)
            return AdamAWD(param_groups, lr=cfg.lr, betas=(cfg.beta1, cfg.beta2),
                           weight_decay=cfg.weight_decay, apply_decay=None)

        if name == 'amos':
            return Amos(self.model.parameters(), lr=cfg.lr, extra_l2=cfg.weight_decay)

        if name == 'adamcpr':
            optimizer = apply_CPR(self.model, torch.optim.Adam, cfg.kappa_init_param, cfg.kappa_init_method,
                                  cfg.reg_function,
                                  cfg.kappa_adapt, cfg.kappa_update,
                                  normalization_regularization=False, bias_regularization=False,
                                  embedding_regularization=True,
                                  lr=cfg.lr, betas=(cfg.beta1, cfg.beta2))

            # NOTE(review): this replaces parameter indices with parameter
            # tensors inside the structure returned by state_dict(). Whether
            # that structure is shared with the optimizer depends on what
            # apply_CPR returns — confirm this loop has the intended effect.
            param_groups = optimizer.state_dict()['param_groups']
            params = list(self.model.parameters())
            for param_group in param_groups:
                for index, param_id in enumerate(param_group['params']):
                    param_group['params'][index] = params[param_id]
            return optimizer

        # Without this, an unknown name surfaced later as UnboundLocalError.
        raise ValueError(f"Unknown optimizer {name!r}")

    def _weight_parameters(self):
        """Yield every model parameter whose name ends with "weight"."""
        for name, param in self.model.named_parameters():
            if name.endswith("weight"):
                yield param

    def _weight_norm(self):
        """Return the global L2 norm over all "weight" parameters."""
        return np.sqrt(sum(p.pow(2).sum().item() for p in self._weight_parameters()))

    def configure_optimizers(self):
        """Create the optimizer and a per-step warmup + cosine LR schedule.

        When ``cfg.rescale_alpha > 0.1`` every "weight" parameter is first
        multiplied by ``rescale_alpha`` and the resulting global norm is
        stored in ``self.rescale_norm`` for use in ``training_step``.

        Returns:
            ``([optimizer], scheduler_config)`` where the scheduler config asks
            for stepping at ``'step'`` interval.

        Raises:
            ValueError: If ``cfg.optimizer`` is not a known optimizer name.
        """
        optimizer = self._build_optimizer()

        if self.cfg.rescale_alpha > 0.1:
            with torch.no_grad():
                for p in self._weight_parameters():
                    p.data *= self.cfg.rescale_alpha
                self.rescale_norm = self._weight_norm()

        lr_decay_factor = self.cfg.lr_decay_factor
        num_warmup_steps = self.cfg.lr_warmup_steps

        def lr_lambda(current_step: int):
            # Linear warmup from 0 to 1 over num_warmup_steps.
            if current_step < num_warmup_steps:
                return float(current_step) / float(max(1, num_warmup_steps))
            # Cosine decay from 1 down to lr_decay_factor at max_train_steps.
            # The expression order is kept as-is so the float result is unchanged.
            angle = math.pi * (current_step - num_warmup_steps) / float(
                max(1, self.cfg.max_train_steps - num_warmup_steps))
            return lr_decay_factor + (1 - lr_decay_factor) * max(0.0, (1 + math.cos(angle)) / 2)

        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)

        return [optimizer], {'scheduler': lr_scheduler, 'interval': 'step'}

    def setup(self, stage: str) -> None:
        """Load the CIFAR-100 train/test datasets and keep them on the module."""
        trainset, testset = cifar100_task()
        self.trainset = trainset
        self.testset = testset

    def train_dataloader(self):
        """Return a shuffled DataLoader over the training split."""
        train_loader = torch.utils.data.DataLoader(self.trainset, batch_size=batch_size, shuffle=True, num_workers=8)
        return train_loader

    def val_dataloader(self):
        """Return an unshuffled DataLoader over the test split."""
        test_loader = torch.utils.data.DataLoader(self.testset, batch_size=batch_size, shuffle=False, num_workers=8)
        return test_loader

    def _accuracy(self, y_hat, y):
        """Return the fraction of rows in ``y_hat`` whose argmax equals ``y``."""
        return torch.sum(torch.argmax(y_hat, dim=1) == y).item() / len(y)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch.

        For ``adamw_warmstart`` the configured weight decay is written to every
        optimizer param group once ``global_step`` equals
        ``cfg.kappa_init_param``. When rescaling is enabled, the "weight"
        parameters are rescaled so their global norm equals
        ``self.rescale_norm``.
        """
        if self.cfg.optimizer == 'adamw_warmstart':
            if self.cfg.kappa_init_param == self.global_step:
                # NOTE(review): this also sets weight decay on any group that
                # group_testbed_parameters created for the avoided keywords —
                # confirm that is intended.
                for param_group in self.trainer.optimizers[0].param_groups:
                    param_group['weight_decay'] = self.cfg.weight_decay

        X, y = batch
        y_hat = self.model(X)
        loss = self.loss(y_hat, y)

        self.log('train_loss', loss)

        if self.cfg.rescale_alpha > 0.1:
            with torch.no_grad():
                scale = self.rescale_norm / self._weight_norm()
                for p in self._weight_parameters():
                    p.data *= scale
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the loss for one validation batch and record its statistics."""
        X, y = batch
        y_hat = self.model(X)
        loss = self.loss(y_hat, y)

        correct_pred = torch.sum(torch.argmax(y_hat, dim=1) == y).item()
        num_samples = len(y)
        self.valid_stats.append({'loss': loss.item(), 'correct_pred': correct_pred, 'num_samples': num_samples})
        # The same 'valid_loss' key is logged again in on_validation_epoch_end.
        self.log('valid_loss', loss)

        return loss

    def on_validation_epoch_end(self):
        """Log the learning rate and the aggregated validation metrics.

        ``valid_loss`` is the plain mean of the recorded per-batch losses;
        ``valid_accuracy`` is total correct predictions over total samples.
        The per-batch records are cleared afterwards.
        """
        self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'])

        valid_loss = np.mean([s['loss'] for s in self.valid_stats])
        valid_accuracy = np.sum([s['correct_pred'] for s in self.valid_stats]) / np.sum([s['num_samples'] for s in self.valid_stats])
        self.log('valid_loss', valid_loss)
        self.log('valid_accuracy', valid_accuracy)
        self.valid_stats = []



def train_cifar100_task(config):
    """Train and evaluate one CIFAR-100 run described by ``config``.

    Outputs go to ``<output_dir>/<session>/<model>_seed<seed>_steps<steps>/``
    in a sub-directory named after the optimizer settings. If that directory
    already holds ``result.npy`` the run is skipped. Otherwise the config is
    saved, the model is fitted, per-parameter statistics are printed and the
    validation result is saved to ``result.npy`` and printed.

    Args:
        config: Namespace-like object carrying the CLI options.
    """
    task_dir = (
        pathlib.Path(config.output_dir)
        / config.session
        / f"{config.model_name}_seed{config.seed}_steps{config.max_train_steps}"
    )
    task_dir.mkdir(parents=True, exist_ok=True)

    # The experiment name encodes the hyper-parameters relevant to the optimizer.
    if config.optimizer == "adamcpr":
        expt_name = (
            f"{config.optimizer}_p{config.kappa_init_param}_m{config.kappa_init_method}"
            f"_kf{config.reg_function}_r{config.kappa_update}_l{config.lr}"
            f"_lw{config.lr_warmup_steps}_ad{config.kappa_adapt}"
        )
    else:
        expt_name = (
            f"{config.optimizer}_l{config.lr}_lw{config.lr_warmup_steps}_w{config.weight_decay}"
            f"_re{config.rescale_alpha}_swd{config.schedule_weight_decay}_swds{config.wd_scale}"
            f"_t{config.wd_schedule_type}"
        )
        if config.optimizer != "adamw":
            expt_name += f"_a{config.adadecay_alpha}_p{config.kappa_init_param}"

    run_dir = task_dir / expt_name
    result_file = run_dir / "result.npy"

    # A stored result marks the run as finished.
    if os.path.isfile(result_file):
        print(f"{expt_name} exisits")
        return

    run_dir.mkdir(parents=True, exist_ok=True)
    np.save(run_dir / "config.npy", config.__dict__)
    logger = TensorBoardLogger(save_dir=task_dir, name=expt_name)
    pl.seed_everything(config.seed)

    # A falsy device (0 or None) selects device index 0.
    devices = [config.device] if config.device else [0]

    model = ResNetModule(config)

    callbacks = [
        LearningRateMonitor(logging_interval='step'),
        LogCPR(log_every_n_steps=config.log_interval),
        WeightDecayScheduler(config.schedule_weight_decay, schedule_type=config.wd_schedule_type, scale=config.wd_scale),
        LogParamsAndGrads(log_every_n_steps=config.log_interval, log_params=True, log_gradient=False, log_quantiles = False),
    ]

    trainer = pl.Trainer(
        devices=devices,
        accelerator="gpu",
        max_steps=config.max_train_steps,
        log_every_n_steps=config.log_interval,
        enable_progress_bar=True,
        logger=logger,
        callbacks=callbacks,
    )
    trainer.fit(model)

    # Print shape, mean, std and requires_grad for every trained parameter.
    for idx, (name, params) in enumerate(model.named_parameters()):
        print(f"{idx:03d} {name:70} shape:{str(list(params.shape)):12} mean:{params.mean():8.4f} std:{params.std():8.6f} grad: {params.requires_grad}")

    # evaluate model
    result = trainer.validate(model)
    np.save(result_file, result)
    print(result)


def build_arg_parser():
    """Return the command-line parser for the CIFAR-100 training script."""
    parser = ArgumentParser()
    parser.add_argument("--session", type=str, default='cifar100_cpr')
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--max_train_steps", type=int, default=20000)  # 20k

    parser.add_argument("--model_name", type=str, default="ResNet18")

    # Known optimizer names: adam_adadecay adamcpr adamw adamw_warmstart adam_awd amos
    parser.add_argument("--optimizer", type=str, default='adamcpr')
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--beta1", type=float, default=0.9)
    parser.add_argument("--beta2", type=float, default=0.98)
    parser.add_argument("--weight_decay", type=float, default=0.1)
    parser.add_argument("--adadecay_alpha", type=float, default=4.0)

    parser.add_argument("--schedule_weight_decay", action=argparse.BooleanOptionalAction)
    parser.add_argument("--wd_schedule_type", type=str, default='cosine')
    parser.add_argument("--wd_scale", type=float, default=0.1)

    parser.add_argument("--lr_warmup_steps", type=int, default=500)
    parser.add_argument("--lr_decay_factor", type=float, default=0.1)

    parser.add_argument("--rescale_alpha", type=float, default=0)

    # Passed to apply_CPR for 'adamcpr'; for 'adamw_warmstart' it is the global
    # step at which weight decay is switched on.
    parser.add_argument("--kappa_init_param", type=float, default=500)
    parser.add_argument("--kappa_init_method", type=str, default='inflection_point')
    parser.add_argument("--reg_function", type=str, default='l2')
    parser.add_argument("--kappa_update", type=float, default=1.0)
    parser.add_argument("--kappa_adapt", action=argparse.BooleanOptionalAction)

    # NOTE(review): --start_epoch is not read anywhere in the visible part of
    # this file — confirm whether it is still needed.
    parser.add_argument("--start_epoch", type=int, default=1)

    parser.add_argument("--log_interval", type=int, default=10)
    parser.add_argument("--output_dir", type=str, default='cifar100')
    parser.add_argument("--device", type=int, default=0)
    return parser


def validate_args(parser, args):
    """Reject unsupported option combinations via ``parser.error``.

    Weight rescaling (``--rescale_alpha`` above 0.1) is only accepted together
    with ``--optimizer adamw``. ``parser.error`` prints the usage message and
    exits with status 2; unlike an ``assert`` it is not stripped under
    ``python -O``.
    """
    if args.rescale_alpha > 0.1 and args.optimizer != 'adamw':
        parser.error("--rescale_alpha > 0.1 is only supported with --optimizer adamw")


if __name__ == "__main__":

    parser = build_arg_parser()
    args = parser.parse_args()

    print(args.__dict__)

    validate_args(parser, args)

    train_cifar100_task(args)
