import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18
from tqdm import tqdm
from models.vit import ViT
from models.swint import swin_t
import timm
import torch.optim as optim
from torchmetrics.classification import Accuracy
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

class ClassificationModel(pl.LightningModule):
    """ResNet-18 classifier packaged as a LightningModule.

    Args:
        model_name: Stored with the hyperparameters only; the backbone is
            always torchvision's ``resnet18`` regardless of this value.
        num_classes: Number of output classes (size of the final layer).
        learning_rate: Learning rate; also selects the optimizer, see
            ``configure_optimizers``.
        pretrained: Initialisation / fine-tuning mode.
            ``'No'``        -> random initialisation, everything trainable.
            ``'LastLayer'`` -> ImageNet weights, backbone frozen, only the
                               new ``fc`` head is trainable.
            anything else   -> ImageNet weights, everything trainable.
    """

    def __init__(self, model_name='resnet18', num_classes=10, learning_rate=0.001, pretrained='No'):
        super().__init__()
        # Must run before any other local is created so that exactly the
        # constructor arguments are captured as hyperparameters.
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.pretrained = pretrained

        if pretrained == 'No':
            self.model = resnet18(weights=None, num_classes=num_classes)
        else:
            backbone = resnet18(weights='IMAGENET1K_V1')
            if pretrained == 'LastLayer':
                # Freeze every pretrained weight; the head created below is a
                # fresh nn.Linear and therefore trainable by default.
                backbone.requires_grad_(False)
            # The ImageNet head has 1000 outputs; replace it with one sized
            # for this task.
            backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
            self.model = backbone

        self.criterion = nn.CrossEntropyLoss()
        self.accuracy = Accuracy(task='multiclass', num_classes=num_classes)

    def forward(self, x):
        """Return the class logits for the input batch ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one batch and log its epoch mean."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        # sync_dist=True reduces the epoch value across processes so that
        # multi-GPU runs report one global number instead of a per-rank one.
        self.log('train_loss', loss, on_step=False, on_epoch=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation accuracy; ``val_acc`` is aggregated per epoch."""
        x, y = batch
        logits = self(x)
        acc = self.accuracy(logits, y)
        # val_acc is what EarlyStopping / ModelCheckpoint monitor, so it has to
        # be consistent across ranks in distributed runs.
        self.log('val_acc', acc, on_step=False, on_epoch=True, sync_dist=True)
        # NOTE: this prints the accuracy of the current batch only (as a
        # fraction), once per validation batch.
        print(f"Validation Accuracy: {acc:.2f}")

    def test_step(self, batch, batch_idx):
        """Log the test accuracy; ``test_acc`` is aggregated per epoch."""
        x, y = batch
        logits = self(x)
        acc = self.accuracy(logits, y)
        self.log('test_acc', acc, on_step=False, on_epoch=True, sync_dist=True)

    def configure_optimizers(self):
        """Create the optimizer and a cosine-annealing LR schedule.

        Learning rates below 0.0015 select Adam; larger ones select SGD with
        momentum 0.9 and weight decay 5e-4. The cosine schedule spans the
        trainer's ``max_epochs``.
        """
        if self.learning_rate < 0.0015: # control optimizer based on the input lr in the config
            optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
        else:
            optimizer = optim.SGD(self.parameters(), lr=self.learning_rate, momentum=0.9, weight_decay=5e-4)

        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.trainer.max_epochs)
        return [optimizer], [scheduler]

def train_resnet_lightning(trainloader, testloader, validation_loader=None, epochs=25, learning_rate=0.001,
                if_resnet18=True, num_classes=10, pretrained='No',plot_label='resnet_training'): #mainly for imagenet for faster training
    """Train a ResNet-18 ``ClassificationModel`` with Lightning and test it.

    When ``validation_loader`` is given, training early-stops on ``val_acc``
    (patience 30) and the checkpoint with the best ``val_acc`` is the one
    evaluated on ``testloader``. Without a validation loader the final weights
    are evaluated.

    Args:
        trainloader: Training dataloader.
        testloader: Dataloader used for the final test pass.
        validation_loader: Optional validation dataloader; enables early
            stopping and best-checkpoint selection.
        epochs: Maximum number of training epochs.
        learning_rate: Learning rate forwarded to ``ClassificationModel``.
        if_resnet18: Accepted for signature compatibility; not used here.
        num_classes: Number of output classes.
        pretrained: Initialisation mode forwarded to ``ClassificationModel``.
        plot_label: Accepted for signature compatibility; not used here.

    Returns:
        The logged ``test_acc`` value, or ``None`` if it was not reported.
    """
    model_name = 'resnet18'
    model = ClassificationModel(model_name=model_name, num_classes=num_classes, learning_rate=learning_rate, pretrained=pretrained)

    callbacks = []
    checkpoint_callback = None
    if validation_loader is not None:
        callbacks.append(EarlyStopping(monitor='val_acc', patience=30, mode='max')) # imagenet -> 30
        checkpoint_callback = ModelCheckpoint(monitor='val_acc', mode='max', save_top_k=1)
        callbacks.append(checkpoint_callback)

    trainer = pl.Trainer(
        max_epochs=epochs,
        accelerator='gpu',
        devices=-1,
        callbacks=callbacks,
        log_every_n_steps=10,
    )

    trainer.fit(model, train_dataloaders=trainloader, val_dataloaders=validation_loader)

    # After early stopping the in-memory weights are up to `patience` epochs
    # past the best ones, so test the best checkpoint when one was saved.
    # best_model_path is an empty string if nothing was saved; in that case
    # fall back to the current weights (ckpt_path=None).
    best_ckpt_path = checkpoint_callback.best_model_path if checkpoint_callback is not None else ''
    test_results = trainer.test(model, dataloaders=testloader, ckpt_path=best_ckpt_path or None)

    test_acc = test_results[0].get('test_acc') if test_results else None
    print(f"Test Accuracy: {test_acc * 100:.2f}%" if test_acc is not None else "Test accuracy not found.")

    return test_acc

def _build_classifier(model_name, num_classes, from_scratch, pretrained):
    """Build the bare (not yet DataParallel-wrapped) network for ``train_resnet``.

    Raises:
        ValueError: If ``model_name`` is not available for the requested mode.
    """
    if from_scratch:
        if model_name == 'resnet18':
            return resnet18(weights=None, num_classes=num_classes)
        if model_name == 'vit':
            return ViT(
                image_size = 32,
                patch_size = 4,
                num_classes = num_classes,
                dim = 512,
                depth = 6,
                heads = 8,
                mlp_dim = 512,
                dropout = 0.1,
                emb_dropout = 0.1
            )
        if model_name == "swint":
            return swin_t(window_size=4,
                num_classes=num_classes,
                downscaling_factors=(2,2,2,1))
    else:
        if model_name == 'resnet18':
            model = resnet18(weights="IMAGENET1K_V1" if pretrained != "No" else None)
            # Replace the head so the output size matches num_classes.
            model.fc = nn.Linear(model.fc.in_features, num_classes)
            return model
        if model_name == 'vit':
            model = timm.create_model("vit_base_patch16_384", pretrained=True)
            model.head = nn.Linear(model.head.in_features, num_classes)
            return model

    raise ValueError(
        f"Unsupported configuration: model_name={model_name!r}, pretrained={pretrained!r}"
    )


def _evaluate_accuracy(model, loader):
    """Return the top-1 accuracy (in percent) of ``model`` on ``loader``.

    Puts the model in eval mode; the caller is responsible for switching back
    to train mode.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to('cuda'), labels.to('cuda')
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total


def train_resnet(trainloader, testloader, validation_loader=None, epochs=25, learning_rate=0.001,
                 if_resnet18=True, num_classes=10, pretrained="No",model_name='resnet18'):
    """Train a classifier on CUDA with a plain PyTorch loop and return test accuracy.

    The network is trained from scratch when ``if_resnet18`` is true and
    ``pretrained == "No"`` (``'resnet18'``, ``'vit'`` or ``'swint'``); otherwise
    a fine-tuning variant is built (``'resnet18'`` or ``'vit'``) with a new
    classification head.

    When ``validation_loader`` is given, validation accuracy is measured after
    every epoch, training stops after 20 epochs without improvement, and the
    weights with the best validation accuracy are restored before testing.

    Args:
        trainloader: Training dataloader yielding ``(inputs, labels)``.
        testloader: Dataloader for the final evaluation.
        validation_loader: Optional validation dataloader (enables early stopping).
        epochs: Maximum number of epochs; also the cosine schedule length.
        learning_rate: Learning rate; values below 0.0015 select Adam,
            otherwise SGD (momentum 0.9, weight decay 5e-4).
        if_resnet18: Together with ``pretrained`` selects from-scratch training.
        num_classes: Number of output classes.
        pretrained: ``"No"`` for random initialisation, anything else for
            pretrained weights.
        model_name: Architecture name.

    Returns:
        Test accuracy in percent.

    Raises:
        ValueError: If ``model_name`` is not supported for the chosen mode.
    """
    num_epochs = epochs

    from_scratch = if_resnet18 and pretrained == "No"
    model = _build_classifier(model_name, num_classes, from_scratch, pretrained)

    model = nn.DataParallel(model).to('cuda')
    criterion = nn.CrossEntropyLoss()
    if learning_rate < 0.0015: # control optimizer based on the input lr in the config
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    else:
        optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    use_early_stopping = validation_loader is not None
    best_acc = 0.0
    patience = 20 # cifar10 -> 20
    epochs_without_improvement = 0
    best_model_state = None

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in tqdm(trainloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            inputs, labels = inputs.to('cuda'), labels.to('cuda')

            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        scheduler.step()
        print(f"Epoch {epoch+1}, Loss: {running_loss / len(trainloader):.4f}")

        if use_early_stopping:
            val_acc = _evaluate_accuracy(model, validation_loader)
            print(f"Validation Accuracy: {val_acc:.2f}%")

            if val_acc > best_acc:
                best_acc = val_acc
                epochs_without_improvement = 0
                # state_dict() returns references to the live tensors, which
                # later optimizer steps overwrite in place; clone them so the
                # snapshot really holds the best weights.
                best_model_state = {
                    key: value.detach().clone()
                    for key, value in model.state_dict().items()
                }
            else:
                epochs_without_improvement += 1
                if epochs_without_improvement >= patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    break
        torch.cuda.empty_cache()

    if use_early_stopping and best_model_state is not None:
        model.load_state_dict(best_model_state)

    return _evaluate_accuracy(model, testloader)

def _evaluate_student_accuracy(model, loader):
    """Return the top-1 accuracy (in percent) of ``model`` on ``loader``.

    Puts the model in eval mode; the caller is responsible for switching back
    to train mode.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to('cuda'), labels.to('cuda')
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total


def train_resnet_w2s(trainloader, testloader, validation_loader=None, epochs=25, learning_rate=0.001,
                 if_resnet18=True, num_classes=10, pretrained="No", model_name='resnet18'): #distillation
    """Distil a pretrained ``cifar10_resnet56`` into a freshly initialised student.

    The student (``'resnet18'``, ``'vit'`` or ``'swint'``) is trained on CUDA to
    match the teacher's outputs with an MSE loss; the dataset labels are used
    only for measuring validation / test accuracy.

    When ``validation_loader`` is given, validation accuracy is measured after
    every epoch, training stops after 20 epochs without improvement, and the
    weights with the best validation accuracy are restored before testing.

    Args:
        trainloader: Training dataloader yielding ``(inputs, labels)``.
        testloader: Dataloader for the final evaluation.
        validation_loader: Optional validation dataloader (enables early stopping).
        epochs: Maximum number of epochs; also the cosine schedule length.
        learning_rate: Learning rate; values below 0.0015 select Adam,
            otherwise SGD (momentum 0.9, weight decay 5e-4).
        if_resnet18: Accepted for signature compatibility; not used here.
        num_classes: Number of student outputs.
            NOTE(review): must be compatible with the teacher's output width
            for the MSE loss (presumably 10 for cifar10_resnet56) - confirm.
        pretrained: Accepted for signature compatibility; not used here.
        model_name: Student architecture name.

    Returns:
        Test accuracy of the student in percent.

    Raises:
        ValueError: If ``model_name`` is not a supported student architecture.
    """
    num_epochs = epochs

    if model_name == 'resnet18':
        model = resnet18(weights=None, num_classes=num_classes)
    elif model_name == 'vit':
        model = ViT(
            image_size = 32,
            patch_size = 4,
            num_classes = num_classes,
            dim = 512,
            depth = 6,
            heads = 8,
            mlp_dim = 512,
            dropout = 0.1,
            emb_dropout = 0.1
        )
    elif model_name == "swint":
        model = swin_t(window_size=4,
            num_classes=num_classes,
            downscaling_factors=(2,2,2,1))
    else:
        raise ValueError(f"Unsupported model_name: {model_name!r}")

    # Teacher network: kept in eval mode and only ever run under no_grad.
    teacher = torch.hub.load("<path-to-resnet>", "cifar10_resnet56", pretrained=True)
    teacher = teacher.to('cuda')
    teacher.eval()

    model = nn.DataParallel(model).to('cuda')

    criterion = nn.MSELoss()
    if learning_rate < 0.0015: # control optimizer based on the input lr in the config
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    else:
        optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    use_early_stopping = validation_loader is not None
    best_acc = 0.0
    patience = 20
    epochs_without_improvement = 0
    best_model_state = None

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        # Labels are not part of the distillation objective, so they are not
        # transferred to the GPU here.
        for inputs, _ in tqdm(trainloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            inputs = inputs.to('cuda')

            optimizer.zero_grad()

            outputs = model(inputs)
            with torch.no_grad():
                goal_outputs = teacher(inputs)
            loss = criterion(outputs, goal_outputs)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        scheduler.step()
        print(f"Epoch {epoch+1}, Loss: {running_loss / len(trainloader):.4f}")

        if use_early_stopping:
            val_acc = _evaluate_student_accuracy(model, validation_loader)
            print(f"Validation Accuracy: {val_acc:.2f}%")

            if val_acc > best_acc:
                best_acc = val_acc
                epochs_without_improvement = 0
                # state_dict() returns references to the live tensors, which
                # later optimizer steps overwrite in place; clone them so the
                # snapshot really holds the best weights.
                best_model_state = {
                    key: value.detach().clone()
                    for key, value in model.state_dict().items()
                }
            else:
                epochs_without_improvement += 1
                if epochs_without_improvement >= patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    break
        torch.cuda.empty_cache()

    if use_early_stopping and best_model_state is not None:
        model.load_state_dict(best_model_state)

    return _evaluate_student_accuracy(model, testloader)
