import os
import time
from typing import List
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np

from tensorboardX import SummaryWriter
from tqdm import tqdm

from data import create_dataloader, RealFakeDataset
from earlystop import EarlyStopping
from networks.trainer import Trainer
from options.train_options import TrainOptions

import torch

# def data_augmentation(path):

"""Currently assumes jpg_prob, blur_prob 0 or 1"""
# Per-channel means used for input normalization, keyed by statistics set.
MEAN = dict(
    imagenet=[0.485, 0.456, 0.406],
    clip=[0.48145466, 0.4578275, 0.40821073],
)

# Per-channel standard deviations, keyed the same way as MEAN.
STD = dict(
    imagenet=[0.229, 0.224, 0.225],
    clip=[0.26862954, 0.26130258, 0.27577711],
)

# Key into MEAN / STD selecting the statistics the training script applies.
stat_from = "clip"
# Bank of thirty 5x5 high-pass kernels, shape (30, 5, 5); judging by the name
# these are presumably the SRM (spatial rich model) residual filters.
# Every kernel's coefficients sum to zero, so constant image regions map to
# zero and only local residuals survive.
#
# The training loop consumes the bank in index groups:
#   [0, 8)   two-tap first-order differences      (filters 1-8)
#   [8, 12)  three-tap second-order differences   (filters 9-12)
#   [12, 20) four-tap third-order differences     (filters 13-20)
#   [20, 25) 3x3 square kernel and its four edges (filters 21-25)
#   [25, 30) 5x5 square kernel and its four edges (filters 26-30)
#
# NOTE: filters 14-20 previously lacked their outer -1 tap (they summed to 1
# and therefore leaked image content instead of being high-pass). They are now
# the rotations/mirrors of filter 13 (-1, 3, -3, 1). Checkpoints trained with
# the old kernels saw different inputs for this group.
srm_filters = np.array([
    # ---- first-order differences ----
    # Filter 1
    [[0, 0, 0, 0, 0],
     [0, 0, 1, 0, 0],
     [0, 0, -1, 0, 0],
     [0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0]],
    # Filter 2
    [[0, 0, 0, 0, 0],
     [0, 0, 0, 1, 0],
     [0, 0, -1, 0, 0],
     [0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0]],
    # Filter 3
    [[0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0],
     [0, 0, -1, 1, 0],
     [0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0]],
    # Filter 4
    [[0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0],
     [0, 0, -1, 0, 0],
     [0, 0, 0, 1, 0],
     [0, 0, 0, 0, 0]],
    # Filter 5
    [[0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0],
     [0, 0, -1, 0, 0],
     [0, 0, 1, 0, 0],
     [0, 0, 0, 0, 0]],
    # Filter 6
    [[0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0],
     [0, 0, -1, 0, 0],
     [0, 1, 0, 0, 0],
     [0, 0, 0, 0, 0]],
    # Filter 7
    [[0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0],
     [0, 1, -1, 0, 0],
     [0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0]],
    # Filter 8
    [[0, 0, 0, 0, 0],
     [0, 1, 0, 0, 0],
     [0, 0, -1, 0, 0],
     [0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0]],
    # ---- second-order differences ----
    # Filter 9
    [[0, 0, 0, 0, 0],
     [0, 0, 1, 0, 0],
     [0, 0, -2, 0, 0],
     [0, 0, 1, 0, 0],
     [0, 0, 0, 0, 0]],
    # Filter 10
    [[0, 0, 0, 0, 0],
     [0, 0, 0, 1, 0],
     [0, 0, -2, 0, 0],
     [0, 1, 0, 0, 0],
     [0, 0, 0, 0, 0]],
    # Filter 11
    [[0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0],
     [0, 1, -2, 1, 0],
     [0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0]],
    # Filter 12
    [[0, 0, 0, 0, 0],
     [0, 1, 0, 0, 0],
     [0, 0, -2, 0, 0],
     [0, 0, 0, 1, 0],
     [0, 0, 0, 0, 0]],
    # ---- third-order differences ----
    # Filter 13
    [[0, 0, -1, 0, 0],
     [0, 0, 3, 0, 0],
     [0, 0, -3, 0, 0],
     [0, 0, 1, 0, 0],
     [0, 0, 0, 0, 0]],
    # Filter 14 (outer -1 tap restored)
    [[0, 0, 0, 0, -1],
     [0, 0, 0, 3, 0],
     [0, 0, -3, 0, 0],
     [0, 1, 0, 0, 0],
     [0, 0, 0, 0, 0]],
    # Filter 15 (outer -1 tap restored)
    [[0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0],
     [0, 1, -3, 3, -1],
     [0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0]],
    # Filter 16 (outer -1 tap restored)
    [[0, 0, 0, 0, 0],
     [0, 1, 0, 0, 0],
     [0, 0, -3, 0, 0],
     [0, 0, 0, 3, 0],
     [0, 0, 0, 0, -1]],
    # Filter 17 (outer -1 tap restored)
    [[0, 0, 0, 0, 0],
     [0, 0, 1, 0, 0],
     [0, 0, -3, 0, 0],
     [0, 0, 3, 0, 0],
     [0, 0, -1, 0, 0]],
    # Filter 18 (outer -1 tap restored)
    [[0, 0, 0, 0, 0],
     [0, 0, 0, 1, 0],
     [0, 0, -3, 0, 0],
     [0, 3, 0, 0, 0],
     [-1, 0, 0, 0, 0]],
    # Filter 19 (outer -1 tap restored)
    [[0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0],
     [-1, 3, -3, 1, 0],
     [0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0]],
    # Filter 20 (outer -1 tap restored)
    [[-1, 0, 0, 0, 0],
     [0, 3, 0, 0, 0],
     [0, 0, -3, 0, 0],
     [0, 0, 0, 1, 0],
     [0, 0, 0, 0, 0]],
    # ---- 3x3 square kernel and its edge variants ----
    # Filter 21
    [[0, 0, 0, 0, 0],
     [0, -1, 2, -1, 0],
     [0, 2, -4, 2, 0],
     [0, -1, 2, -1, 0],
     [0, 0, 0, 0, 0]],
    # Filter 22
    [[0, 0, 0, 0, 0],
     [0, -1, 2, -1, 0],
     [0, 2, -4, 2, 0],
     [0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0]],
    # Filter 23
    [[0, 0, 0, 0, 0],
     [0, 0, 2, -1, 0],
     [0, 0, -4, 2, 0],
     [0, 0, 2, -1, 0],
     [0, 0, 0, 0, 0]],
    # Filter 24
    [[0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0],
     [0, 2, -4, 2, 0],
     [0, -1, 2, -1, 0],
     [0, 0, 0, 0, 0]],
    # Filter 25
    [[0, 0, 0, 0, 0],
     [0, -1, 2, 0, 0],
     [0, 2, -4, 0, 0],
     [0, -1, 2, 0, 0],
     [0, 0, 0, 0, 0]],
    # ---- 5x5 square kernel and its edge variants ----
    # Filter 26
    [[-1, 2, -2, 2, -1],
     [2, -6, 8, -6, 2],
     [-2, 8, -12, 8, -2],
     [2, -6, 8, -6, 2],
     [-1, 2, -2, 2, -1]],
    # Filter 27
    [[-1, 2, -2, 2, -1],
     [2, -6, 8, -6, 2],
     [-2, 8, -12, 8, -2],
     [0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0]],
    # Filter 28
    [[0, 0, -2, 2, -1],
     [0, 0, 8, -6, 2],
     [0, 0, -12, 8, -2],
     [0, 0, 8, -6, 2],
     [0, 0, -2, 2, -1]],
    # Filter 29
    [[0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0],
     [-2, 8, -12, 8, -2],
     [2, -6, 8, -6, 2],
     [-1, 2, -2, 2, -1]],
    # Filter 30
    [[-1, 2, -2, 0, 0],
     [2, -6, 8, 0, 0],
     [-2, 8, -12, 0, 0],
     [2, -6, 8, 0, 0],
     [-1, 2, -2, 0, 0]],
])
def get_val_opt():
    """Parse the training options again and reconfigure them for validation.

    The options are parsed without being printed, then switched to
    non-training mode with resizing and cropping enabled, serial batches,
    the ``'val'`` data label and PIL as the only JPEG method. A two-element
    ``blur_sig`` is replaced by its midpoint, and a ``jpg_qual`` that does not
    hold exactly one entry is replaced by the integer midpoint of its first
    and last entries.

    Returns:
        The adjusted options object produced by ``TrainOptions().parse``.
    """
    options = TrainOptions().parse(print_options=False)

    # Settings that are fixed for validation regardless of the command line.
    fixed_settings = (
        ('isTrain', False),
        ('no_resize', False),
        ('no_crop', False),
        ('serial_batches', True),
        ('data_label', 'val'),
        ('jpg_method', ['pil']),
    )
    for attribute, value in fixed_settings:
        setattr(options, attribute, value)

    # Collapse a blur-sigma pair to a single deterministic value.
    sigmas = options.blur_sig
    if len(sigmas) == 2:
        options.blur_sig = [(sigmas[0] + sigmas[1]) / 2]

    # Collapse the JPEG quality list to the midpoint of its end points.
    qualities = options.jpg_qual
    if len(qualities) != 1:
        options.jpg_qual = [int((qualities[0] + qualities[-1]) / 2)]

    return options


# Half-open [start, stop) index ranges into the SRM filter bank, one per task.
# Mini-batch ``i`` is filtered with group ``i % len(SRM_TASK_GROUPS)`` and the
# same index is handed to ``model.optimize_parameters``.
SRM_TASK_GROUPS = ((0, 8), (8, 12), (12, 20), (20, 25), (25, 30))

# Global step counts at which an additional checkpoint is written.
CHECKPOINT_STEPS = (10, 30, 50, 100, 1000, 2000, 3000, 5000, 8000, 10000)


def apply_srm_filters(images, filters):
    """Sum the absolute responses of ``filters`` for every channel of ``images``.

    Each channel is convolved independently with each kernel (zero padding of
    half the kernel size, so odd-sized kernels keep the spatial size), the
    absolute value is taken, and the responses of all kernels are added up.

    Args:
        images: 4-D tensor indexed as (batch, channel, height, width).
        filters: tensor of shape (num_filters, 1, k, k) living on the same
            device and with the same dtype as ``images``.

    Returns:
        A tensor with the shape and dtype of ``images``; all zeros when
        ``filters`` holds no kernels.
    """
    batch, channels, height, width = images.shape
    # Fold the channel axis into the batch axis so that one conv2d call filters
    # every channel separately with the same single-channel kernel.
    planes = images.reshape(batch * channels, 1, height, width)
    padding = filters.shape[-1] // 2
    residual = torch.zeros_like(planes)
    for kernel in filters:  # kernel has shape (1, k, k)
        residual += F.conv2d(planes, kernel.unsqueeze(0), padding=padding).abs()
    return residual.reshape(batch, channels, height, width)


if __name__ == '__main__':
    opt = TrainOptions().parse()
    val_opt = get_val_opt()

    model = Trainer(opt)
    data_loader = create_dataloader(opt)
    # NOTE(review): dataset_augmentation, val_loader and early_stopping are
    # created but never used below; kept in case construction has side effects.
    dataset_augmentation = RealFakeDataset(opt)
    print(data_loader.dataset)
    val_loader = create_dataloader(val_opt)
    train_writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, "train"))
    val_writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, "val"))

    early_stopping = EarlyStopping(patience=opt.earlystop_epoch, delta=-0.001, verbose=True)
    start_time = time.time()
    print("Length of data loader: %d" % (len(data_loader)))

    # Resolve the device once and put the filter bank on that same device, so
    # the batch and the kernels always agree (previously the bank was moved
    # with an unconditional .cuda(), which fails on CPU-only hosts).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    srm_filters_tensor = torch.tensor(srm_filters, dtype=torch.float32).unsqueeze(1).to(device)

    # Loop invariants: the normalisation transform and the progress log path.
    normalize = transforms.Normalize(mean=MEAN[stat_from], std=STD[stat_from])
    file_path = "epoch.txt"

    try:
        for epoch in range(5):
            with open(file_path, "a+") as file:
                file.write("begin epoch.\n")
                file.write(str(epoch) + "begin.\n")

            for i, data in enumerate(tqdm(data_loader, desc=f'Epoch {epoch + 1}')):
                # Tasks rotate with the batch index; each task has its own
                # slice of the filter bank.
                task = i % len(SRM_TASK_GROUPS)
                first, last = SRM_TASK_GROUPS[task]

                images = data[2].to(device)
                residual = apply_srm_filters(images, srm_filters_tensor[first:last])

                # data[0] and the filtered data[2] are both normalised with the
                # statistics selected by ``stat_from``.
                data[0] = normalize(data[0])
                data[2] = normalize(residual)

                model.total_steps += 1

                model.set_input(data)
                model.optimize_parameters(task)

                if model.total_steps % 100 == 0:
                    with open(file_path, "a+") as file:
                        file.write("loss,\n")
                        file.write(str(model.total_steps) + "begin.\n")
                        file.write(str(model.loss) + "loss,\n")
                if model.total_steps % opt.loss_freq == 0:
                    print("Train loss: {} at step: {}".format(model.loss, model.total_steps))
                    train_writer.add_scalar('loss', model.loss, model.total_steps)
                    print("Iter time: ", ((time.time() - start_time) / model.total_steps))
                if model.total_steps in CHECKPOINT_STEPS:  # save models at these iters
                    model.save_networks(
                        'right_%s.pth' % model.total_steps)

            if epoch % opt.save_epoch_freq == 0:
                print('saving the model at the end of epoch %d' % (epoch))
                model.save_networks('right_best.pth')
                model.save_networks('right_%s.pth' % epoch)
            model.train()
    finally:
        # Flush and release the TensorBoard event files even if training aborts.
        train_writer.close()
        val_writer.close()
