import torch
import numpy as np
import torch
from torchvision import models
import torch.nn as nn
from torch.nn import functional as F
import pytorch_lightning as pl
from sklearn.metrics import accuracy_score


# Module-level training configuration.
# Only BASE_LR and NUM_CLASSES are referenced by the code visible in this file;
# the remaining constants are presumably consumed by data-loading / trainer
# code elsewhere — TODO confirm.
INPUT_SIZE = (200, 300)  # presumably the image size fed to the network; axis order unconfirmed
BATCH_SIZE = 4  # presumably the DataLoader batch size
MAX_EPOCHS = 10  # presumably the trainer's epoch limit
BASE_LR = 1e-4  # default learning rate of CustomNetwork
NUM_CLASSES = 2  # passed as the first argument to the baseline network constructor
NUM_WORKERS = 12  # presumably the number of DataLoader worker processes
SET_SPLIT = 0.9  # presumably the train fraction of a train/validation split


class BaseClassifier(pl.LightningModule):
    """
    EfficientNetV2-S backbone with a custom classification head.

    The stock torchvision classifier is replaced by
    Dropout -> Linear -> LeakyReLU -> Dropout -> Linear(1) -> Sigmoid, so the
    network emits one value in (0, 1) per input sample (output shape
    ``(batch, 1)``).

    :param num_classes: accepted for interface compatibility; currently unused,
        the head always has a single sigmoid output
    :param transfer: when true, start from the ImageNet weights shipped with
        torchvision; otherwise use random initialisation
    :param internal_features: width of the hidden linear layer in the head
    """

    def __init__(self, num_classes, transfer=False, internal_features=1024):
        super().__init__()

        # ``pretrained=`` is the deprecated legacy flag; IMAGENET1K_V1 is the
        # weight set that flag maps to, so behaviour is unchanged.
        weights = models.EfficientNet_V2_S_Weights.IMAGENET1K_V1 if transfer else None
        self.model = models.efficientnet_v2_s(weights=weights)

        # The stock classifier is Sequential(Dropout, Linear); read the width
        # of the features entering its Linear layer before replacing it.
        linear_size = self.model.classifier[1].in_features
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=0.4, inplace=False),
            nn.Linear(in_features=linear_size, out_features=internal_features, bias=True),
            nn.LeakyReLU(),
            nn.Dropout(p=0.4, inplace=False),
            nn.Linear(in_features=internal_features, out_features=1, bias=True),
            nn.Sigmoid(),
        )

        # Freeze only the first feature stage; every other parameter keeps
        # its default ``requires_grad=True``.
        for param in self.model.features[0].parameters():
            param.requires_grad = False

    def forward(self, x):
        """Run the wrapped network on a batch of images."""
        return self.model(x)
    

class CustomNetwork(pl.LightningModule):
    """
    Lightning module that trains and evaluates a classification network.

    :param baseline: callable that builds the wrapped network; invoked as
        ``baseline(NUM_CLASSES, transfer=transfer, internal_features=internal_features)``
    :param features_criterion: loss callable, invoked in ``training_step`` as
        ``features_criterion(pred.view(-1), y)``
    :param internal_features: internal number of features, forwarded to ``baseline``
    :param lr: learning rate handed to the Adam optimizer
    :param transfer: forwarded to ``baseline``
    :param cp_path: path the state dict is written to after every training epoch

    NOTE(review): ``BaseClassifier`` emits a single sigmoid score per sample,
    which suggests a binary loss (e.g. ``F.binary_cross_entropy``) is the
    intended criterion. The ``F.cross_entropy`` default is kept unchanged for
    backward compatibility — confirm what callers actually pass.
    """

    def __init__(
            self,
            baseline=BaseClassifier,
            features_criterion=F.cross_entropy,
            internal_features=1024,
            lr=BASE_LR,
            transfer=False,
            cp_path='./a.pth'):

        super().__init__()
        self.model = baseline(NUM_CLASSES, transfer=transfer, internal_features=internal_features)
        self.lr = lr
        self.loss = features_criterion
        self.checkpoint_path = cp_path

    def forward(self, x):
        """Run the wrapped network on a batch."""
        return self.model(x)

    @staticmethod
    def _predicted_labels(pred):
        """Convert raw network output into a NumPy array of class indices."""
        scores = pred.detach().cpu()
        if scores.dim() > 1 and scores.shape[-1] > 1:
            # One score per class: pick the highest-scoring class.
            return scores.argmax(dim=-1).numpy()
        # Single score per sample: argmax over one column is always 0, so
        # threshold the score at 0.5 instead.
        return (scores.reshape(-1) >= 0.5).long().numpy()

    def _batch_accuracy(self, batch):
        """Return the accuracy of the network on one ``(x, y)`` batch."""
        x, y = batch
        pred = self(x)
        return accuracy_score(y.detach().cpu().numpy(), self._predicted_labels(pred))

    def training_step(self, train_batch, batch_idx):
        """Compute the training loss for one batch."""
        x, y = train_batch
        pred = self(x)
        return self.loss(pred.view(-1), y)

    def validation_step(self, val_batch, batch_idx):
        """Return the accuracy on one validation batch."""
        return self._batch_accuracy(val_batch)

    def test_step(self, val_batch, batch_idx):
        """Return the accuracy on one test batch."""
        return self._batch_accuracy(val_batch)

    def training_epoch_end(self, outputs):
        """Save the current weights and log the smallest step loss of the epoch."""
        min_loss = min(step_output['loss'] for step_output in outputs)
        # The state dict is overwritten every epoch, regardless of the loss.
        torch.save(self.state_dict(), self.checkpoint_path)
        # 'train_loss' is the minimum (not the mean) of the per-step losses.
        self.log('train_loss', min_loss, prog_bar=True, on_epoch=True, on_step=False)

    def validation_epoch_end(self, outputs):
        """Log the mean of the per-batch validation accuracies, in percent."""
        acc = np.mean(outputs)
        self.log('val_acc', acc * 100, prog_bar=True, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        """
        Build the Adam optimizer and a per-step exponential learning-rate decay.

        All parameters share one learning rate. The original code also built
        per-group learning rates but overwrote them before use; that dead code
        is removed and the effective behaviour kept.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=5e-4)

        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9995)
        lr_dict = {'scheduler': lr_scheduler,
                   'interval': 'step',
                   'frequency': 1,
                   'monitor': 'train_loss'}

        return [optimizer], [lr_dict]

    def predict(self, x, test_set=None):
        """
        Run inference on a batch and return the raw network output.

        The wrapped model is moved to ``cuda:0`` when CUDA is available,
        otherwise to the CPU. Inference runs in eval mode without gradient
        tracking, and the previous train/eval mode is restored afterwards.

        :param x: batch of images; converted to float on the chosen device
        :param test_set: unused; kept for backward compatibility
        :return: tensor with the network output for ``x``, on the chosen device
        """
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model.to(device=device)

        was_training = self.model.training
        # Eval mode disables dropout so predictions are deterministic.
        self.model.eval()
        try:
            with torch.no_grad():
                return self.model(x.to(device=device, dtype=torch.float))
        finally:
            self.model.train(was_training)


