# train.py
import os
import argparse
import random
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, distributed

from torchvision import datasets, transforms
import timm
from timm.data.auto_augment import rand_augment_transform

from models.shallow_titan import ShallowTitan


def setup_ddp():
    """Join the default NCCL process group and bind this process to its GPU.

    The GPU index is read from the ``LOCAL_RANK`` environment variable, which
    must be set by whatever launches the process.

    Returns:
        int: the local rank, also used as the CUDA device index.

    Raises:
        KeyError: if ``LOCAL_RANK`` is not present in the environment.
    """
    dist.init_process_group(backend='nccl')
    gpu_index = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(gpu_index)
    return gpu_index


def cleanup_ddp():
    """Tear down the default process group if one is active.

    The call is a no-op when distributed support is unavailable or the group
    was never initialised (or was already destroyed), so it is safe to invoke
    from ``finally`` blocks and error paths.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


# SAM Optimizer
class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization wrapper around a base optimizer.

    Each update needs two forward/backward passes:

    1. compute gradients at the current weights, then call ``first_step`` to
       move the weights by ``rho * g / ||g||`` (the ascent step);
    2. compute gradients at the perturbed weights, then call ``second_step``
       to undo the perturbation and let the base optimizer apply the update.

    Args:
        params: iterable of parameters or parameter-group dicts.
        base_optimizer: optimizer *class* (e.g. ``torch.optim.AdamW``); it is
            instantiated on this optimizer's parameter groups.
        rho: radius of the ascent step; must be non-negative.
        **kwargs: forwarded to ``base_optimizer`` (lr, weight_decay, ...).

    Raises:
        ValueError: if ``rho`` is negative.
    """

    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        if rho < 0.0:
            raise ValueError(f"Invalid rho, should be non-negative: {rho}")
        defaults = dict(rho=rho, **kwargs)
        super().__init__(params, defaults)
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        # Share one list of groups so lr schedulers attached to either
        # optimizer affect both.
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Perturb every parameter that has a gradient by ``rho * g / ||g||``.

        The applied perturbation is stored in ``self.state[p]["e_w"]`` so that
        ``second_step`` can remove it again.
        """
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            # Epsilon keeps the division finite when all gradients are zero.
            scale = group["rho"] / (grad_norm + 1e-12)
            for p in group["params"]:
                if p.grad is None:
                    continue
                e_w = p.grad * scale.to(p)
                p.add_(e_w)
                self.state[p]["e_w"] = e_w
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Remove the perturbation and apply the base optimizer update."""
        for group in self.param_groups:
            for p in group["params"]:
                param_state = self.state.get(p)
                if not param_state or "e_w" not in param_state:
                    continue
                # Restore whenever a perturbation was applied, even if this
                # parameter received no gradient in the second pass, and drop
                # it so a stale value can never be subtracted later.
                p.sub_(param_state.pop("e_w"))
        self.base_optimizer.step()
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def _grad_norm(self):
        """Return the global L2 norm over all available gradients.

        Parameters without a gradient are ignored. A zero scalar is returned
        when no parameter has a gradient.
        """
        # Gather per-parameter norms on one device so they can be stacked.
        shared_device = self.param_groups[0]["params"][0].device
        per_param_norms = [
            p.grad.norm(p=2).to(shared_device)
            for group in self.param_groups for p in group["params"]
            if p.grad is not None
        ]
        if not per_param_norms:
            return torch.zeros((), device=shared_device)
        # The norm of the per-tensor norms equals the norm of the
        # concatenated gradient vector.
        return torch.norm(torch.stack(per_param_norms), p=2)

    def step(self, closure=None):
        """Run a full SAM update.

        Gradients at the current weights must already be populated when this
        is called; ``closure`` is used for the second forward/backward pass.

        Raises:
            ValueError: if ``closure`` is not provided.
        """
        if closure is None:
            raise ValueError("SAM requires closure")
        closure = torch.enable_grad()(closure)
        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def load_state_dict(self, state_dict):
        """Load state and re-link the base optimizer to the new groups."""
        super().load_state_dict(state_dict)
        # Loading replaces self.param_groups; keep the base optimizer pointed
        # at the same list so both continue to share hyper-parameters.
        self.base_optimizer.param_groups = self.param_groups


# Distillation Loss
class DistillationLoss(nn.Module):
    """Weighted sum of hard-label cross-entropy and a soft-target loss.

    The soft term is the cross-entropy between the temperature-softened
    teacher distribution and the temperature-softened student distribution,
    multiplied by ``temperature ** 2``.

    Args:
        alpha: weight of the soft-target term; the hard-label term is
            weighted by ``1 - alpha``.
        temperature: softmax temperature applied to both sets of logits;
            must be positive.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """

    def __init__(self, alpha=0.5, temperature=3.0):
        super().__init__()
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        self.alpha = alpha
        self.temperature = temperature
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, outputs, labels, teacher_outputs):
        """Return the combined loss.

        Args:
            outputs: student logits; classes are along dim 1.
            labels: ground-truth class indices accepted by
                ``nn.CrossEntropyLoss``.
            teacher_outputs: teacher logits with the same layout as
                ``outputs``.
        """
        hard_loss = self.ce_loss(outputs, labels)
        soft_targets = torch.softmax(teacher_outputs / self.temperature, dim=1)
        # log_softmax stays finite where log(softmax(x)) would underflow to
        # -inf and turn 0 * -inf into NaN for confident predictions.
        log_soft_prob = torch.log_softmax(outputs / self.temperature, dim=1)
        soft_targets_loss = -(soft_targets * log_soft_prob).sum(dim=1).mean()
        soft_weighted = self.alpha * soft_targets_loss * (self.temperature ** 2)
        return (1 - self.alpha) * hard_loss + soft_weighted


def _build_datasets(args, local_rank):
    """Return ``(train_dataset, val_dataset, num_classes)`` for ``args.dataset``."""
    if args.dataset == 'imagenet':
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            rand_augment_transform('rand-m9-mstd0.5-inc1', {'translate_const': 117, 'img_mean': (124, 116, 104)}),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
            transforms.RandomErasing(p=0.25, scale=(0.02, 0.33), ratio=(0.3, 3.3), value='random'),
        ])
        val_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        train_dataset = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=train_transform)
        val_dataset = datasets.ImageFolder(os.path.join(args.data_path, 'val'), transform=val_transform)
        return train_dataset, val_dataset, 1000

    cifar_mean = (0.4914, 0.4822, 0.4465)
    cifar_std = (0.2023, 0.1994, 0.2010)
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ])
    transform_val = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ])
    # Let one process per node fetch the archive first so that concurrent
    # ranks do not download/extract into the same directory at once.
    if local_rank == 0:
        datasets.CIFAR10(root=args.data_path, train=True, download=True)
    dist.barrier()
    train_dataset = datasets.CIFAR10(root=args.data_path, train=True, download=True, transform=transform_train)
    val_dataset = datasets.CIFAR10(root=args.data_path, train=False, download=True, transform=transform_val)
    return train_dataset, val_dataset, 10


def _build_teacher(model_name, num_classes, device, verbose):
    """Create the frozen teacher network used to produce soft targets."""
    teacher_model = timm.create_model(model_name, pretrained=True)
    if teacher_model.head.out_features != num_classes:
        if verbose:
            print(f"Warning: replacing teacher head ({teacher_model.head.out_features} -> {num_classes} "
                  "classes) with a randomly initialised layer")
        teacher_model.head = torch.nn.Linear(teacher_model.head.in_features, num_classes)
    # else: keep the pretrained classifier; re-initialising it would discard
    # exactly the knowledge that is being distilled.
    teacher_model = teacher_model.to(device)
    teacher_model.eval()
    for param in teacher_model.parameters():
        param.requires_grad_(False)
    return teacher_model


def _global_mean(value, device):
    """Average a Python scalar over all ranks of the default process group."""
    buf = torch.tensor([value], dtype=torch.float64, device=device)
    dist.all_reduce(buf, op=dist.ReduceOp.SUM)
    return buf.item() / dist.get_world_size()


def _train_one_epoch(model, teacher_model, criterion, optimizer, train_loader, device, epoch, epochs, show_progress):
    """Run one SAM training epoch and return this rank's mean batch loss."""
    model.train()
    train_loss = 0.0
    progress_bar = tqdm(train_loader, disable=not show_progress)

    for batch_idx, (inputs, targets) in enumerate(progress_bar):
        inputs, targets = inputs.to(device), targets.to(device)
        with torch.no_grad():
            teacher_outputs = teacher_model(inputs)

        # Pass 1: gradients at the current weights define the perturbation.
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets, teacher_outputs)
        loss.backward()
        optimizer.first_step(zero_grad=True)

        # Pass 2: gradients at the perturbed weights drive the real update.
        loss = criterion(model(inputs), targets, teacher_outputs)
        loss.backward()
        optimizer.second_step(zero_grad=True)

        train_loss += loss.item()
        progress_bar.set_description(f"Epoch {epoch+1}/{epochs} - Loss: {train_loss/(batch_idx+1):.4f}")

    return train_loss / max(len(train_loader), 1)


@torch.no_grad()
def _evaluate(model, val_loader, device, show_progress):
    """Return ``(mean_loss, accuracy_percent)`` aggregated over all ranks."""
    model.eval()
    # [sum of per-sample losses, number correct, number of samples]
    stats = torch.zeros(3, dtype=torch.float64, device=device)
    for inputs, targets in tqdm(val_loader, desc="Validating", disable=not show_progress):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        stats[0] += F.cross_entropy(outputs, targets, reduction='sum')
        _, predicted = outputs.max(1)
        stats[1] += predicted.eq(targets).sum()
        stats[2] += targets.size(0)
    # Each rank only sees its shard of the validation set; combine them so
    # every rank reports (and checkpoints on) the same numbers.
    dist.all_reduce(stats, op=dist.ReduceOp.SUM)
    loss_sum, correct, total = stats.tolist()
    total = max(total, 1.0)
    return loss_sum / total, 100. * correct / total


def _save_checkpoint(path, epoch, model, optimizer, best_acc):
    """Write a training checkpoint to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_acc': best_acc,
    }, path)


def main():
    """Train the student with SAM and knowledge distillation under DDP.

    Parses command-line options, joins the process group, builds data
    loaders, student and teacher, then runs the train/validate loop.
    Only global rank 0 writes TensorBoard logs and checkpoints; validation
    metrics are reduced across ranks so all ranks agree on the best model.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-path', default='/fs/scratch/PDS0359/data', type=str)
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--batch-size', default=128, type=int)
    parser.add_argument('--depth', default=2, type=int)
    parser.add_argument('--embed-dim', default=2048, type=int)
    parser.add_argument('--lr', default=0.001, type=float)
    parser.add_argument('--teacher-model', default='vit_base_patch16_224', type=str)
    parser.add_argument('--output-dir', default='./output', type=str)
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--dataset', default='cifar10', choices=['imagenet', 'cifar10'])
    args = parser.parse_args()

    local_rank = setup_ddp()
    is_main = dist.get_rank() == 0
    device = torch.device('cuda', local_rank)
    torch.cuda.set_device(device)
    print(f"Using device: {device}")

    # The same seed on every rank keeps randomly initialised layers (such as
    # a replaced teacher head) identical across processes.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)
    # A single writer avoids several processes appending to one log dir.
    writer = SummaryWriter(log_dir=os.path.join(args.output_dir, 'logs')) if is_main else None

    try:
        train_dataset, val_dataset, num_classes = _build_datasets(args, local_rank)

        train_sampler = distributed.DistributedSampler(train_dataset)
        val_sampler = distributed.DistributedSampler(val_dataset, shuffle=False)
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=4, pin_memory=True)
        val_loader = DataLoader(val_dataset, batch_size=args.batch_size, sampler=val_sampler, num_workers=4, pin_memory=True)

        model = ShallowTitan(num_classes=num_classes, embed_dim=args.embed_dim, depth=args.depth).to(device)
        model = DDP(model, device_ids=[local_rank], output_device=local_rank)

        if is_main:
            print(f"Loading teacher model: {args.teacher_model}")
        teacher_model = _build_teacher(args.teacher_model, num_classes, device, verbose=is_main)

        criterion = DistillationLoss(alpha=0.7, temperature=3.0)
        optimizer = SAM(
            model.parameters(),
            base_optimizer=optim.AdamW,
            lr=args.lr,
            weight_decay=0.05,
            betas=(0.9, 0.999),
            rho=0.05
        )
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer.base_optimizer, T_0=10, T_mult=2, eta_min=args.lr * 0.01
        )

        best_acc = 0.0
        for epoch in range(args.epochs):
            train_sampler.set_epoch(epoch)
            local_train_loss = _train_one_epoch(
                model, teacher_model, criterion, optimizer, train_loader,
                device, epoch, args.epochs, show_progress=is_main,
            )
            train_loss = _global_mean(local_train_loss, device)
            scheduler.step()

            val_loss, val_acc = _evaluate(model, val_loader, device, show_progress=is_main)

            if is_main:
                print(f"Epoch {epoch+1} - Validation Accuracy: {val_acc:.2f}%")
                writer.add_scalar('Loss/train', train_loss, epoch)
                writer.add_scalar('Loss/val', val_loss, epoch)
                writer.add_scalar('Accuracy/val', val_acc, epoch)

            # val_acc is identical on all ranks after the reduction, so
            # best_acc stays in sync without extra communication.
            if val_acc > best_acc:
                best_acc = val_acc
                if is_main:
                    _save_checkpoint(
                        os.path.join(args.output_dir, f'shallow_titan_depth{args.depth}_best.pth'),
                        epoch, model, optimizer, best_acc,
                    )
                    print(f"Saved best model with accuracy: {best_acc:.2f}%")

        if is_main:
            _save_checkpoint(
                os.path.join(args.output_dir, f'shallow_titan_depth{args.depth}_final.pth'),
                args.epochs, model, optimizer, best_acc,
            )
            print(f"Training completed. Best validation accuracy: {best_acc:.2f}%")
    finally:
        # Release the log file handle and the process group on failure too.
        if writer is not None:
            writer.close()
        cleanup_ddp()

# Run training only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
