import os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms
from transformers import ViTForImageClassification, ViTConfig
from transformers import get_linear_schedule_with_warmup
from torch.optim import AdamW
import matplotlib.pyplot as plt
import pandas as pd
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import json


def setup_distributed():
    """Initialise ``torch.distributed`` from the launcher's environment.

    Reads ``RANK`` and ``WORLD_SIZE`` from the environment. When
    ``WORLD_SIZE`` is greater than one, an NCCL process group is created
    using ``env://`` rendezvous. When the variables are absent the function
    falls back to a single process (rank 0, world size 1) and does not touch
    ``torch.distributed`` at all.

    Returns:
        tuple[int, int]: ``(rank, world_size)`` for the current process.
    """
    rank = int(os.environ.get('RANK', 0))
    # Default to 1, not 8: a plain `python script.py` launch sets neither
    # variable, and assuming eight ranks would try to rendezvous with peers
    # that were never started.
    world_size = int(os.environ.get('WORLD_SIZE', 1))

    if world_size > 1:
        dist.init_process_group(
            backend='nccl',
            init_method='env://',
            world_size=world_size,
            rank=rank
        )
        print(f"Initialized distributed training on rank {rank}/{world_size}")

    return rank, world_size


# Resolve this process's rank / world size and bind it to a compute device.
rank, world_size = setup_distributed()
if world_size > 1:
    # CUDA ordinals are numbered per node, so index by the launcher-provided
    # LOCAL_RANK. Falling back to the global rank keeps single-node launches
    # that do not export LOCAL_RANK working exactly as before.
    local_rank = int(os.environ.get('LOCAL_RANK', rank))
    device = torch.device(f'cuda:{local_rank}')
    torch.cuda.set_device(device)
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Rank {rank} using device: {device}")

class Config:
    """Hyperparameters plus dataset-specific settings for one training run.

    Args:
        dataset_name: One of ``"cifar10"``, ``"cifar100"`` or ``"mnist"``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """

    # (name, number of classes, native image side length) per dataset.
    _DATASET_SPECS = (
        ("cifar10", 10, 32),
        ("cifar100", 100, 32),
        ("mnist", 10, 28),
    )

    def __init__(self, dataset_name):
        self.dataset_name = dataset_name

        # Optimisation and data-loading settings shared by every dataset.
        self.batch_size = 64
        self.learning_rate = 2e-5
        self.weight_decay = 0.01
        self.epochs = 30
        self.num_workers = 4
        self.pretrained = True
        self.train_log_interval = 100

        # Pick the first matching entry; no match means the name is unsupported.
        for known_name, class_count, side_length in self._DATASET_SPECS:
            if self.dataset_name == known_name:
                self.num_classes = class_count
                self.image_size = side_length
                break
        else:
            raise ValueError("Unsupported dataset. Choose from: cifar10, cifar100, mnist")


def get_dataloaders(config, rank=0, world_size=1):
    """Build the training and evaluation ``DataLoader`` objects for a dataset.

    Images are resized to 224x224 and normalised with mean/std 0.5; MNIST is
    additionally expanded to three channels. The training pipeline adds a
    random horizontal flip and a random rotation of up to 10 degrees. Data is
    stored under ``./data`` and downloaded if missing.

    Args:
        config: Object exposing ``dataset_name``, ``batch_size`` and
            ``num_workers``.
        rank: Index of this process, used to shard data when distributed.
        world_size: Number of processes; values above one enable a
            ``DistributedSampler`` for both loaders.

    Returns:
        tuple: ``(train_loader, test_loader)``.

    Raises:
        ValueError: If ``config.dataset_name`` is not a supported dataset.
    """
    # Validate first so an unknown name fails with a clear message instead of
    # an UnboundLocalError once the dataset variables are first used.
    if config.dataset_name not in ("cifar10", "cifar100", "mnist"):
        raise ValueError("Unsupported dataset. Choose from: cifar10, cifar100, mnist")

    is_mnist = config.dataset_name == "mnist"
    norm_stats = (0.5,) if is_mnist else (0.5, 0.5, 0.5)

    def _build_transform(augment):
        # Resize -> (MNIST: 3 channels) -> (train: augmentation) -> tensor -> normalise.
        steps = [transforms.Resize((224, 224))]
        if is_mnist:
            steps.append(transforms.Grayscale(3))
        if augment:
            steps.append(transforms.RandomHorizontalFlip())
            steps.append(transforms.RandomRotation(10))
        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize(norm_stats, norm_stats))
        return transforms.Compose(steps)

    transform_train = _build_transform(augment=True)
    transform_test = _build_transform(augment=False)

    dataset_classes = {
        "cifar10": datasets.CIFAR10,
        "cifar100": datasets.CIFAR100,
        "mnist": datasets.MNIST,
    }
    dataset_cls = dataset_classes[config.dataset_name]
    # NOTE(review): every rank calls download=True on the same directory;
    # confirm the data is already present before launching multi-process runs.
    train_dataset = dataset_cls(root='./data', train=True, download=True, transform=transform_train)
    test_dataset = dataset_cls(root='./data', train=False, download=True, transform=transform_test)

    if world_size > 1:
        train_sampler = DistributedSampler(
            train_dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True
        )
        test_sampler = DistributedSampler(
            test_dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=False
        )
    else:
        train_sampler = None
        test_sampler = None

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        # DataLoader rejects shuffle=True together with a sampler, so shuffle
        # here only when no DistributedSampler is doing it.
        shuffle=(train_sampler is None),
        num_workers=config.num_workers,
        sampler=train_sampler,
        pin_memory=True
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        sampler=test_sampler,
        pin_memory=True
    )

    return train_loader, test_loader


def get_model(config):
    """Load the ViT classifier and move it to this process's device.

    The model is loaded with ``ViTForImageClassification.from_pretrained``
    from the identifier ``"vit"`` with ``config.num_classes`` output labels.
    When more than one process is running, the model is wrapped in
    ``DistributedDataParallel``.

    Args:
        config: Object exposing ``num_classes``.

    Returns:
        The model on ``device``, DDP-wrapped when ``world_size > 1``.
    """
    # NOTE(review): "vit" is presumably a local checkpoint directory - confirm.
    # config.pretrained and config.image_size are not consulted here; the
    # hand-built ViTConfig that used to sit in this function was never passed
    # to the model, so it has been dropped as dead code.
    model = ViTForImageClassification.from_pretrained(
        "vit",
        num_labels=config.num_classes,
    )

    model = model.to(device)
    if world_size > 1:
        # Use the device's own ordinal rather than the global rank so the
        # wrapper always matches the device the parameters live on.
        model = DDP(model, device_ids=[device.index])

    return model

def train_epoch(model, train_loader, optimizer, scheduler, epoch, config, history, varepsilon, rank=0):
    """Run one (truncated) training epoch and record its statistics.

    Args:
        model: Called as ``model(data, labels=None)``; the result must expose
            ``.logits``.
        train_loader: Iterable of ``(data, target)`` batches with ``len()``
            and a ``.dataset`` attribute.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per batch, or ``None``.
        epoch: Epoch number, used only in log output.
        config: Object exposing ``train_log_interval``.
        history: Dict with ``train_loss``, ``train_acc`` and
            ``train_avg_loss`` lists; only written on rank 0.
        varepsilon: Factor the logits are multiplied by before the
            cross-entropy loss. Accuracy is computed on the unscaled logits.
        rank: Index of this process; only rank 0 logs and records history.

    Returns:
        tuple[float, float]: Mean per-batch loss and accuracy in percent over
        the batches actually processed (summed over all ranks when
        distributed). Both are 0.0 if no batch was processed.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()  # built once instead of once per batch
    loss_list = []
    correct = 0
    total = 0

    # Only batches 0..count (inclusive) are used, i.e. roughly one sixth of
    # the loader per epoch. NOTE(review): presumably to shorten experiments -
    # confirm this truncation is intended.
    count = len(train_loader) // 6

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        outputs = model(data, labels=None)
        loss = criterion(outputs.logits * varepsilon, target)
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        loss_value = loss.item()
        loss_list.append(loss_value)
        _, predicted = torch.max(outputs.logits, 1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

        if batch_idx % config.train_log_interval == 0 and rank == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss_value:.6f}')

        if batch_idx >= count:
            break

    # Local totals are always defined, so the single-process path no longer
    # depends on a variable that only the distributed branch used to set.
    loss_sum = float(sum(loss_list))
    num_batches = len(loss_list)

    if world_size > 1:
        # all_reduce works in place on the tensor it is given, so reduce named
        # tensors and read the results back from them.
        loss_stats = torch.tensor([loss_sum, float(num_batches)], dtype=torch.float64, device=device)
        count_stats = torch.tensor([correct, total], dtype=torch.long, device=device)

        dist.all_reduce(loss_stats, op=dist.ReduceOp.SUM)
        dist.all_reduce(count_stats, op=dist.ReduceOp.SUM)

        loss_sum, num_batches = loss_stats.tolist()
        correct, total = count_stats.tolist()

    # Average over the batches that were actually run, not len(train_loader),
    # because the loop above stops early.
    avg_loss = loss_sum / num_batches if num_batches else 0.0
    accuracy = 100. * correct / total if total else 0.0

    if rank == 0:
        history['train_loss'].extend(loss_list)
        history['train_acc'].append(accuracy)
        history['train_avg_loss'].append(avg_loss)
        print(f'Epoch {epoch} Training: Average loss: {avg_loss:.4f}, Accuracy: {correct}/{total} ({accuracy:.2f}%)')

    return avg_loss, accuracy

def test_epoch(model, test_loader, epoch, config, history, varepsilon, rank=0):
    """Evaluate the model on ``test_loader`` and record loss and accuracy.

    Args:
        model: Called as ``model(data, labels=None)``; the result must expose
            ``.logits``.
        test_loader: Iterable of ``(data, target)`` batches.
        epoch: Epoch number, used only in log output.
        config: Unused here; kept for a signature parallel to ``train_epoch``.
        history: Dict with ``test_loss`` and ``test_acc`` lists; only written
            on rank 0.
        varepsilon: Factor the logits are multiplied by before the
            cross-entropy loss. Accuracy is computed on the unscaled logits.
        rank: Index of this process; only rank 0 logs and records history.

    Returns:
        tuple[float, float]: Mean per-batch loss and accuracy in percent,
        aggregated over all ranks when distributed. Both are 0.0 if the
        loader yields no batches.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()  # built once instead of once per batch
    test_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            outputs = model(data, labels=None)
            test_loss += criterion(outputs.logits * varepsilon, target).item()
            num_batches += 1
            _, predicted = torch.max(outputs.logits, 1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

    if world_size > 1:
        # Reduce the batch count together with the loss: dividing a loss that
        # was summed over all ranks by one rank's batch count would inflate
        # the average by a factor of world_size.
        loss_stats = torch.tensor([test_loss, float(num_batches)], dtype=torch.float64, device=device)
        count_stats = torch.tensor([correct, total], dtype=torch.long, device=device)

        dist.all_reduce(loss_stats, op=dist.ReduceOp.SUM)
        dist.all_reduce(count_stats, op=dist.ReduceOp.SUM)

        test_loss, num_batches = loss_stats.tolist()
        correct, total = count_stats.tolist()

    avg_loss = test_loss / num_batches if num_batches else 0.0
    accuracy = 100. * correct / total if total else 0.0

    if rank == 0:
        history['test_loss'].append(avg_loss)
        history['test_acc'].append(accuracy)
        print(f'Epoch {epoch} Testing: Average loss: {avg_loss:.4f}, Accuracy: {correct}/{total} ({accuracy:.2f}%)')

    return avg_loss, accuracy

def _save_training_curves(history, dataset_name):
    """Plot per-epoch loss and accuracy curves and save them as a PNG file."""
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(history['train_avg_loss'], label='Train Loss')
    plt.plot(history['test_loss'], label='Test Loss')
    plt.title(f'ViT on {dataset_name} - Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'], label='Train Accuracy')
    plt.plot(history['test_acc'], label='Test Accuracy')
    plt.title(f'ViT on {dataset_name} - Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()

    plt.tight_layout()
    plt.savefig(f'vit_training_curves_{dataset_name}.png')
    plt.close()


def train_on_dataset(dataset_name, varepsilon, rank=0):
    """Train and evaluate a ViT on one dataset for ``Config.epochs`` epochs.

    Builds the config, data loaders, model, AdamW optimizer and a linear
    warmup/decay schedule, then alternates ``train_epoch`` and ``test_epoch``.
    On rank 0 the loss/accuracy curves are written to
    ``vit_training_curves_<dataset_name>.png``.

    Args:
        dataset_name: Dataset identifier accepted by ``Config``.
        varepsilon: Logit scaling factor forwarded to the epoch functions.
        rank: Index of this process; only rank 0 keeps a history and plots.

    Returns:
        dict | None: The history dict (``train_loss``, ``train_acc``,
        ``train_avg_loss``, ``test_loss``, ``test_acc``) on rank 0, ``None``
        on every other rank.
    """
    if rank == 0:
        print(f"Training on {dataset_name}")

    config = Config(dataset_name)
    train_loader, test_loader = get_dataloaders(config, rank, world_size)
    model = get_model(config)

    optimizer = AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    # NOTE(review): train_epoch stops after roughly len(train_loader) // 6
    # batches, so fewer scheduler steps are taken than total_steps assumes -
    # confirm whether the schedule should be sized to the truncated epoch.
    total_steps = len(train_loader) * config.epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        # The scheduler API takes an integer step count.
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )

    # Only rank 0 accumulates metrics; the epoch functions guard their
    # history writes with the same rank check.
    history = None
    if rank == 0:
        history = {
            'train_loss': [],
            'train_acc': [],
            'train_avg_loss': [],
            'test_loss': [],
            'test_acc': []
        }

    for epoch in range(1, config.epochs + 1):
        if world_size > 1:
            # Reseed the distributed sampler so each epoch sees a new shuffle.
            train_loader.sampler.set_epoch(epoch)

        start_time = time.time()

        train_epoch(model, train_loader, optimizer, scheduler, epoch, config, history, varepsilon, rank)
        test_epoch(model, test_loader, epoch, config, history, varepsilon, rank)

        epoch_time = time.time() - start_time

        if rank == 0:
            print(f'Epoch {epoch} completed in {epoch_time:.2f}s')

    if rank != 0:
        return None

    _save_training_curves(history, dataset_name)
    return history


if __name__ == "__main__":
    # Sweep: every dataset x every logit scale, repeated repeat_times times.
    datasets_to_train = ["mnist"]
    varepsilons = [32, 16, 8, 4, 2, 1, 0.5, 0.25, 0.125]
    repeat_times = 5

    for repeat in range(repeat_times):
        # Start each repeat with an empty record so a repeat's JSON files never
        # carry results left over from the previous repeat.
        all_histories = {}
        for dataset in datasets_to_train:
            all_histories[dataset] = {}
            for varepsilon in varepsilons:
                history = train_on_dataset(dataset, varepsilon, rank)
                all_histories[dataset][varepsilon] = history

                if rank == 0:
                    # Rewritten after every run so partial sweeps are kept if
                    # the job stops early.
                    with open(f'vit_training_history_{dataset}_{repeat}.json', 'w') as file:
                        json.dump(all_histories, file)
                    print(f"\n{dataset.upper()} Final Results:")
                    print(f"Train Accuracy: {history['train_acc'][-1]:.2f}%")
                    print(f"Test Accuracy: {history['test_acc'][-1]:.2f}%")
                    print("-" * 50)

    if world_size > 1:
        dist.destroy_process_group()
