import torch
import torchvision
import torchvision.transforms as transforms
import random
import numpy as np
import os
from PIL import Image


class curve:
    """Measure how well a trainable predictor network imitates a frozen target network.

    For each run, two freshly constructed networks are built with
    ``_initialize_model`` (ResNet-18 with its last child module removed, no
    pretrained weights).
    The predictor is trained with the MSE criterion to reproduce the frozen
    target's outputs on ``trainset``. Before every epoch the relative gap
    between validation and training loss is recorded.

    ``args`` must provide ``normalize``, ``batch_size``, ``num_runs`` and
    ``epochs``. ``trainset`` must provide ``reset()`` and ``valset`` must
    provide ``reset(trainset)``. Both datasets must yield bare tensors (not
    (input, label) pairs), because batches are used directly as tensors;
    presumably (C, H, W) images -- TODO confirm.
    """

    def __init__(self, args, trainset, valset):
        self.args = args
        self.trainset = trainset
        self.valset = valset
        # Always define the attribute: _train/_validate read it even when
        # normalisation is disabled (previously this raised AttributeError).
        self.transform = None
        if self.args.normalize:
            mean, std = self._get_stats()
            print(f'mean: {mean}, std: {std}')
            self.transform = transforms.Normalize(mean=mean, std=std)
        self._build_loaders()

    def _build_loaders(self):
        """(Re)create the shuffled train/validation loaders from the current datasets."""
        self.trainloader = torch.utils.data.DataLoader(
            self.trainset, batch_size=self.args.batch_size, shuffle=True)
        self.valloader = torch.utils.data.DataLoader(
            self.valset, batch_size=self.args.batch_size, shuffle=True)

    def produce_curve(self):
        """Public entry point; see ``_produce_curve_full`` for the return value."""
        return self._produce_curve_full()

    def _produce_curve_full(self):
        """Perform ``args.num_runs`` independent runs.

        After each run both datasets are reset and the loaders rebuilt.

        Returns:
            (run_loss_list, run_val_list): one list per run of per-epoch
            training losses, and one list per run of per-epoch relative
            validation/training gaps.
        """
        self.curve = []
        self.val_curve = []
        self.criterion = torch.nn.MSELoss()
        run_loss_list = []
        run_val_list = []
        for _ in range(self.args.num_runs):
            loss, val = self._perform_run()
            run_loss_list.append(loss)
            run_val_list.append(val)
            # What reset() does is owned by the dataset classes (presumably a
            # resample of the split -- TODO confirm).
            self.trainset.reset()
            self.valset.reset(self.trainset)
            self._build_loaders()
        return run_loss_list, run_val_list

    def _perform_run(self):
        """Build fresh target/predictor networks and train for ``args.epochs`` epochs.

        Returns:
            (temp_curve, temp_val_curve): per-epoch training losses and the
            per-epoch relative gaps measured *before* each training epoch.
        """
        temp_curve = []
        temp_val_curve = []
        self.target_net = self._initialize_model(pretrained=False)
        self.predictor_net = self._initialize_model()
        # The target is never optimised: freezing it stops backward() from
        # computing and accumulating gradients for its parameters.
        self.target_net.requires_grad_(False)
        n_gpus = torch.cuda.device_count()
        if n_gpus > 1:
            self.target_net = torch.nn.DataParallel(self.target_net)
            self.predictor_net = torch.nn.DataParallel(self.predictor_net)
        if n_gpus > 0:
            self.target_net = self.target_net.to('cuda')
            self.predictor_net = self.predictor_net.to('cuda')
        # Build the optimizer after the device move, as recommended by PyTorch.
        self.optimizer = torch.optim.SGD(self.predictor_net.parameters(), lr=0.01, momentum=0.9)
        self.target_net.eval()
        self.predictor_net.train()
        if self.args.normalize:
            # Recomputed every run because reset() may have changed the trainset.
            mean, std = self._get_stats()
            self.transform = transforms.Normalize(mean=mean, std=std)
        for epoch in range(self.args.epochs):
            temp_val_curve.append(self._validate())
            temp_curve.append(self._train())
        return temp_curve, temp_val_curve

    def _device(self):
        """Return the device holding the predictor's parameters (CPU if it has none)."""
        param = next(self.predictor_net.parameters(), None)
        return param.device if param is not None else torch.device('cpu')

    def _prepare(self, batch):
        """Move *batch* to the predictor's device and apply normalisation, if configured."""
        batch = batch.to(self._device())
        if self.transform is not None:
            # Normalize accepts batched (..., C, H, W) tensors, so a single
            # call replaces the former per-sample loop.
            batch = self.transform(batch)
        return batch

    def _train(self):
        """Train the predictor for one epoch over ``self.trainloader``.

        Returns:
            The sample-weighted mean of the per-batch criterion losses, i.e.
            each batch's mean loss weighted by its size. This is the same
            averaging ``_validate`` uses.
        """
        epoch_loss = 0.0
        train_total = 0
        for batch in self.trainloader:
            n = batch.shape[0]
            batch = self._prepare(batch)
            self.optimizer.zero_grad()
            with torch.no_grad():
                target_out = self.target_net(batch)
            pred_out = self.predictor_net(batch)
            loss = self.criterion(pred_out, target_out)
            loss.backward()
            self.optimizer.step()
            # Weight by batch size: summing per-batch means and dividing by
            # the sample count (the old behaviour) under-reported the loss.
            epoch_loss += n * loss.item()
            train_total += n
        return epoch_loss / train_total

    def _mean_loss(self, loader):
        """Return the sample-weighted mean criterion loss of predictor vs. target over *loader*."""
        total_loss = 0.0
        total = 0
        for batch in loader:
            n = batch.shape[0]
            batch = self._prepare(batch)
            loss = self.criterion(self.predictor_net(batch), self.target_net(batch))
            total_loss += n * loss.item()
            total += n
        return total_loss / total

    def _validate(self):
        """Return the predictor's relative generalisation gap.

        Computes the mean loss on the validation loader and on the training
        loader, then returns ``(val - train) / ((val + train) / 2)``. The
        predictor is in eval mode during the measurement and is always put
        back into train mode afterwards. Returns 0.0 when both losses are
        exactly zero, instead of dividing by zero.
        """
        self.predictor_net.eval()
        try:
            with torch.no_grad():
                val_loss = self._mean_loss(self.valloader)
                train_loss = self._mean_loss(self.trainloader)
        finally:
            self.predictor_net.train()
        temp_mean = (val_loss + train_loss) / 2.0
        if temp_mean == 0:
            return 0.0
        return (val_loss - train_loss) / temp_mean

    def _initialize_model(self, net='ResNet18', pretrained=False, dataset='ImageNet'):
        """Build a torchvision model named ``net`` with its final child module removed.

        Raises:
            NotImplementedError: for any ``dataset`` other than 'ImageNet'.
        """
        if dataset == 'ImageNet':
            # NOTE: newer torchvision deprecates ``pretrained`` in favour of
            # ``weights``; it is kept here for compatibility with older versions.
            model = getattr(torchvision.models, net.lower())(pretrained=pretrained)
            modules = list(model.children())[:-1]
            model = torch.nn.Sequential(*modules)
        else:
            raise NotImplementedError()
        return model

    def _average_lists(self, list_of_lists):
        """Return the element-wise mean of equally long lists, as a plain list."""
        return np.average(np.array(list_of_lists), axis=0).tolist()

    def _get_stats(self, chunk_size=1024):
        """Compute the per-channel mean and std of ``self.trainset`` (first <= 3 channels).

        The dataset is streamed in chunks of ``chunk_size`` samples instead of
        being loaded as one batch. Per-chunk statistics are combined in
        float64 with Chan et al.'s parallel update. The std is the unbiased
        (N - 1) estimate, which matches ``torch.std``'s default.

        Returns:
            (mean, std): lists of Python floats, one entry per channel.

        Raises:
            ValueError: if the dataset is empty.
        """
        loader = torch.utils.data.DataLoader(self.trainset, batch_size=chunk_size, shuffle=False)
        count = 0
        mean = None
        m2 = None
        for batch in loader:
            x = batch[:, :3].to(torch.float64)
            # Reduce over every dimension except the channel dimension (dim 1).
            dims = tuple(d for d in range(x.dim()) if d != 1)
            n_b = x.numel() // x.shape[1]
            mean_b = x.mean(dim=dims)
            shape = [1, -1] + [1] * (x.dim() - 2)
            m2_b = ((x - mean_b.view(shape)) ** 2).sum(dim=dims)
            if mean is None:
                count, mean, m2 = n_b, mean_b, m2_b
            else:
                total = count + n_b
                delta = mean_b - mean
                mean = mean + delta * (n_b / total)
                m2 = m2 + m2_b + delta ** 2 * (count * n_b / total)
                count = total
        if mean is None:
            raise ValueError('cannot compute statistics of an empty trainset')
        std = torch.sqrt(m2 / (count - 1))
        return mean.tolist(), std.tolist()
