import os
import torch
import torch.nn as nn
from tqdm import trange
from utils import select_device
from collections import OrderedDict

class MLP(nn.Module):
    """Two-layer MLP classifier ("learnware") with an optional frozen input transform.

    The network is ``Linear(input_dim, hidden_dim) -> ReLU -> Dropout(0.5) ->
    Linear(hidden_dim, output_dim)``. When ``cfe`` is given, inputs are first
    passed through it under ``torch.no_grad()``, so it is never trained here.
    Weights are checkpointed to ``<dataset_path>/learnwares/<learnware_id>.pth``.

    Args:
        cfg: Mapping read for the keys ``input_dim``, ``hidden_dim``,
            ``output_dim``, ``dataset_path``, ``learning_rate`` and ``epochs``
            (plus whatever ``select_device`` reads from it).
        learnware_id: Identifier used for device selection and as the
            checkpoint file name.
        cfe: Optional callable applied to inputs before the MLP. If it is an
            ``nn.Module`` it is registered as a submodule by attribute
            assignment, so it is moved to ``self.device`` and included in
            ``state_dict()``.
    """

    def __init__(self, cfg, learnware_id, cfe=None):
        super().__init__()
        self.cfg = cfg
        self.cfe = cfe
        self.model = nn.Sequential(
            nn.Linear(cfg['input_dim'], cfg['hidden_dim']),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(cfg['hidden_dim'], cfg['output_dim']),
        )
        self.device = torch.device(select_device(cfg, learnware_id))
        self.learnware_id = learnware_id
        self.to(self.device)
        self.path = os.path.join(cfg['dataset_path'], 'learnwares', f'{learnware_id}.pth')
        # Create the checkpoint directory up front so save() cannot fail on it.
        os.makedirs(os.path.dirname(self.path), exist_ok=True)

    def forward(self, x):
        """Return logits of shape ``(batch, output_dim)`` for input ``x``.

        ``cfe`` (if any) runs without gradient tracking, so no gradient flows
        into it during training.
        """
        if self.cfe is not None:
            with torch.no_grad():
                x = self.cfe(x)
        return self.model(x)

    def train(self, trainloader=True, evalloader=None):
        """Fit the MLP head, checkpointing the epoch with the best eval accuracy.

        This method shadows ``nn.Module.train(mode)``. PyTorch itself calls
        ``train(mode)`` from ``nn.Module.eval()`` and when a parent module
        switches modes. A bool first argument is therefore forwarded to
        ``nn.Module.train`` and the module is returned, as PyTorch expects.

        Args:
            trainloader: Iterable of ``(X, y)`` batches used for optimization,
                or a bool mode flag (see above).
            evalloader: Iterable of ``(X, y)`` batches. Accuracy on it decides
                which epoch is written via ``save()``.

        Only the best checkpoint is written to disk. The in-memory weights after
        this call are those of the final epoch; call ``load()`` to restore the
        best ones.
        """
        if isinstance(trainloader, bool):
            # Mode switch requested by PyTorch (e.g. via .eval()), not training.
            return super().train(trainloader)
        if evalloader is None:
            raise TypeError("train() missing required argument: 'evalloader'")

        # -inf guarantees at least one checkpoint is saved, even if the
        # evaluation accuracy is 0.0 in every epoch.
        best_perf = float('-inf')
        # Optimize only the MLP head; cfe is frozen by design (see forward).
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.cfg['learning_rate'])
        criterion = nn.CrossEntropyLoss()
        print(f'Start training learnware {self.learnware_id}...')
        for epoch in trange(self.cfg['epochs']):
            self.model.train()
            running_loss = 0
            for step, (X, y) in enumerate(trainloader, start=1):
                X = X.to(self.device)
                y = y.to(self.device)

                optimizer.zero_grad()
                logits = self(X)
                batch_loss = criterion(logits, y)
                batch_loss.backward()
                optimizer.step()

                running_loss += batch_loss.item()
                # Running mean of the batch losses so far in this epoch.
                print(running_loss / step, end='\r')

            perf = self.evaluate(evalloader)[0]
            if perf > best_perf:
                best_perf = perf
                self.save()
                print(f'Learnware {self.learnware_id} at Epoch {epoch + 1} Best Evaluation Accuracy {best_perf:.4f}')
        self.model.eval()

    def evaluate(self, dataloader, pred=False, prob=False):
        """Compute accuracy over ``dataloader`` and optionally collect outputs.

        Args:
            dataloader: Iterable of ``(X, y)`` batches with integer labels ``y``.
            pred: If true, collect argmax class predictions.
            prob: If true, collect softmax probabilities.

        Returns:
            ``(accuracy, preds, probs)``, where:

            - ``preds`` is a CPU tensor of shape ``(n_samples,)`` when ``pred``
              is true, otherwise an empty list.
            - ``probs`` is a CPU tensor of shape ``(n_samples, output_dim)``
              when ``prob`` is true, otherwise an empty list.

            An empty ``dataloader`` yields accuracy ``0.0`` and zero-length
            tensors instead of raising.
        """
        self.model.eval()
        correct, total = 0, 0
        pred_chunks, prob_chunks = [], []
        with torch.no_grad():
            for X, y in dataloader:
                X = X.to(self.device)
                y = y.to(self.device)
                logits = self(X)
                batch_preds = logits.argmax(dim=1)
                if pred:
                    pred_chunks.append(batch_preds.cpu())
                if prob:
                    prob_chunks.append(torch.softmax(logits, dim=1).cpu())
                correct += (batch_preds == y).sum().item()
                total += y.size(0)

        accuracy = correct / total if total else 0.0

        if pred:
            model_preds = torch.cat(pred_chunks) if pred_chunks else torch.empty(0, dtype=torch.long)
        else:
            model_preds = []
        if prob:
            model_probs = (
                torch.cat(prob_chunks) if prob_chunks
                else torch.empty((0, self.cfg['output_dim']))
            )
        else:
            model_probs = []
        return accuracy, model_preds, model_probs

    def save(self):
        """Write ``state_dict()`` to ``self.path`` with every tensor copied to CPU.

        CPU tensors keep the file independent of the device used for training.
        """
        cpu_state = OrderedDict((k, v.cpu()) for k, v in self.state_dict().items())
        torch.save(cpu_state, self.path)

    def load(self):
        """Restore weights from ``self.path`` onto ``self.device`` if the file exists.

        Returns:
            True if a checkpoint was loaded, False if no file exists at ``self.path``.
        """
        if not os.path.exists(self.path):
            return False
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        self.load_state_dict(torch.load(self.path, map_location=self.device))
        self.to(self.device)
        return True