from unittest import result

import IPython
import utils.utils as utils
from utils.video_utils import create_video_from_intermediate_results

import torch
from torch.optim import Adam, LBFGS
from torch.autograd import Variable
import numpy as np
import os
import argparse
import glob
from PIL import Image
import csv

class ImagePathDataset(torch.utils.data.Dataset):
    """Dataset yielding ``(path, image)`` pairs for a list of image files.

    Images are loaded on access through ``utils.prepare_img`` using
    ``config['height']`` and ``device``.

    Args:
        files: sequence of image file paths.
        config: mapping that must provide the ``'height'`` key.
        device: device handed through to ``utils.prepare_img``.
    """

    def __init__(self, files, config, device):
        self.files = files
        self.config = config
        self.device = device

    def __len__(self):
        """Return the number of file paths held by the dataset."""
        return len(self.files)

    def __getitem__(self, i):
        """Return ``(path, image)`` for index ``i``.

        If loading ``files[i]`` fails, the image of ``files[0]`` is returned
        instead while ``path`` still names the file that failed.
        """
        path = self.files[i]
        try:
            img = utils.prepare_img(path, self.config['height'], self.device)
        except Exception as err:
            # Catch Exception rather than using a bare ``except`` so that
            # KeyboardInterrupt / SystemExit still propagate.
            # NOTE(review): the substitute image is paired with the *failing*
            # path, so any metric reported for that path really belongs to
            # files[0] -- confirm this best-effort behaviour is intended.
            print(f"could not load {path!r} ({err!r}); using {self.files[0]!r} instead")
            img = utils.prepare_img(self.files[0], self.config['height'], self.device)

        return path, img
    
def build_loss(neural_net, optimizing_img, target_representations, content_feature_maps_index, style_feature_maps_indices, config):
    """Compute the weighted style-transfer loss of ``optimizing_img``.

    Args:
        neural_net: callable returning an indexable collection of feature maps.
        optimizing_img: image whose loss is evaluated.
        target_representations: ``[content_target, list_of_style_gram_targets]``.
        content_feature_maps_index: index of the feature map used for content.
        style_feature_maps_indices: indices of the feature maps used for style.
        config: mapping with ``'content_weight'``, ``'style_weight'`` and
            ``'tv_weight'``.

    Returns:
        Tuple ``(total_loss, content_loss, style_loss, tv_loss)``.
    """
    content_target, style_targets = target_representations[0], target_representations[1]

    feature_maps = neural_net(optimizing_img)

    # Content term: mean squared error against the selected feature map.
    mean_mse = torch.nn.MSELoss(reduction='mean')
    content_features = feature_maps[content_feature_maps_index].squeeze(axis=0)
    content_loss = mean_mse(content_target, content_features)

    # Style term: summed squared error between gram matrices, averaged over
    # the number of target style layers. Only element 0 of each gram is
    # compared (presumably the batch dimension -- TODO confirm in utils).
    sum_mse = torch.nn.MSELoss(reduction='sum')
    style_grams = [
        utils.gram_matrix(fmap)
        for layer_idx, fmap in enumerate(feature_maps)
        if layer_idx in style_feature_maps_indices
    ]
    style_loss = 0.0
    for target_gram, current_gram in zip(style_targets, style_grams):
        style_loss += sum_mse(target_gram[0], current_gram[0])
    style_loss /= len(style_targets)

    tv_loss = utils.total_variation(optimizing_img)

    weighted_content = config['content_weight'] * content_loss
    weighted_style = config['style_weight'] * style_loss
    weighted_tv = config['tv_weight'] * tv_loss
    total_loss = weighted_content + weighted_style + weighted_tv

    return total_loss, content_loss, style_loss, tv_loss


def make_tuning_step(neural_net, optimizer, target_representations, content_feature_maps_index, style_feature_maps_indices, config):
    """Build a closure that runs one optimisation step on an image.

    The returned function takes ``optimizing_img``, evaluates ``build_loss``
    with the arguments captured here, back-propagates the total loss, steps
    ``optimizer`` and clears its gradients.

    Returns:
        A callable ``tuning_step(optimizing_img)`` returning
        ``(total_loss, content_loss, style_loss, tv_loss)``.
    """

    def tuning_step(optimizing_img):
        total, content, style, tv = build_loss(
            neural_net,
            optimizing_img,
            target_representations,
            content_feature_maps_index,
            style_feature_maps_indices,
            config,
        )
        # Back-propagate, apply the update, then reset gradients for the next call.
        total.backward()
        optimizer.step()
        optimizer.zero_grad()
        return total, content, style, tv

    return tuning_step

@torch.no_grad()
def neural_style_transfer(config):
    """Evaluate style-transfer losses for every file in ``config['result_dir']``.

    Each loaded image is cut along its last axis into three equal-width
    panels, used (left to right) as content image, style image and stylised
    result. The result panel is scored with ``build_loss`` against targets
    computed from the other two panels.

    Args:
        config: mapping providing ``'result_dir'``, ``'height'``, ``'model'``,
            ``'content_weight'``, ``'style_weight'``, ``'tv_weight'``,
            ``'final_results_file'`` (text summary of the averages) and
            ``'save_file'`` (per-image CSV). An optional ``'device'`` entry
            (e.g. ``"cuda:0"``) overrides automatic device selection.

    Raises:
        ValueError: if no files are found in ``config['result_dir']``.
    """
    print(config['result_dir'])
    image_paths = glob.glob(f"{config['result_dir']}/*")
    print(len(image_paths))
    if not image_paths:
        # Fail early with a clear message instead of a ZeroDivisionError
        # when the averages are computed.
        raise ValueError(f"no files found in result_dir {config['result_dir']!r}")

    # Historical default is GPU 3; fall back gracefully on hosts with fewer GPUs.
    device_name = config.get('device')
    if device_name is None:
        if not torch.cuda.is_available():
            device_name = "cpu"
        elif torch.cuda.device_count() > 3:
            device_name = "cuda:3"
        else:
            device_name = "cuda:0"
    device = torch.device(device_name)

    dataset = ImagePathDataset(image_paths, config, device)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=8,
                                             shuffle=False,
                                             drop_last=False,
                                             num_workers=8)
    neural_net, content_feature_maps_index_name, style_feature_maps_indices_names = utils.prepare_model(config['model'], device)
    content_index = content_feature_maps_index_name[0]
    style_indices = style_feature_maps_indices_names[0]

    loss_names = ("style_loss", "content_loss", "tv_loss", "total_loss")
    loss_sums = dict.fromkeys(loss_names, 0.0)
    list_dicts = []
    for batch in dataloader:
        for img_path, image in zip(batch[0], batch[1]):
            # Three side-by-side panels: content | style | result.
            split_size = image.shape[-1] // 3
            content_image = image[:, :, :, :split_size].to(device)
            style_image = image[:, :, :, split_size:2 * split_size].to(device)
            result_image = image[:, :, :, -split_size:].to(device)

            content_img_set_of_feature_maps = neural_net(content_image)
            style_img_set_of_feature_maps = neural_net(style_image)

            target_content_representation = content_img_set_of_feature_maps[content_index].squeeze(axis=0)
            target_style_representation = [
                utils.gram_matrix(x)
                for cnt, x in enumerate(style_img_set_of_feature_maps)
                if cnt in style_indices
            ]
            target_representations = [target_content_representation, target_style_representation]

            total_loss, content_loss, style_loss, tv_loss = build_loss(
                neural_net, result_image, target_representations, content_index, style_indices, config)
            print(style_loss)

            # Convert once to Python floats; the same values feed both the
            # per-image rows and the running sums.
            data_dict = {
                "image_path": img_path,
                "style_loss": style_loss.detach().cpu().item(),
                "content_loss": content_loss.detach().cpu().item(),
                "tv_loss": tv_loss.detach().cpu().item(),
                "total_loss": total_loss.detach().cpu().item(),
            }
            for name in loss_names:
                loss_sums[name] += data_dict[name]
            list_dicts.append(data_dict)

    num_images = len(list_dicts)
    avg_style_loss = loss_sums["style_loss"] / num_images
    avg_content_loss = loss_sums["content_loss"] / num_images
    avg_tv_loss = loss_sums["tv_loss"] / num_images
    avg_total_loss = loss_sums["total_loss"] / num_images
    summary_lines = [
        f"{avg_style_loss=}",
        f"{avg_content_loss=}",
        f"{avg_tv_loss=}",
        f"{avg_total_loss=}",
    ]
    for line in summary_lines:
        print(line)

    # One "name=value" entry per line (previously all entries were written
    # back-to-back without separators).
    with open(config["final_results_file"], 'w') as f:
        f.writelines(f"{line}\n" for line in summary_lines)

    with open(config["save_file"], mode='w', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=["image_path", *loss_names])
        writer.writeheader()
        writer.writerows(list_dicts)
    
            
        


if __name__ == "__main__":
    # Command-line entry point: parse the options into a plain dict and run
    # the loss evaluation over --result_dir.
    parser = argparse.ArgumentParser()

    parser.add_argument("--height", type=int, help="height of content and style images", default=512)
    # The three paths below are required: the evaluation globs --result_dir
    # and opens both output files, so leaving any of them unset cannot work.
    parser.add_argument("--result_dir", type=str, required=True,
                        help="directory containing the images to evaluate")
    parser.add_argument("--save_file", type=str, required=True,
                        help="output CSV file with the per-image losses")
    parser.add_argument("--final_results_file", type=str, required=True,
                        help="output text file with the average losses")
    parser.add_argument("--content_weight", type=float, help="weight factor for content loss", default=1e-4)
    parser.add_argument("--style_weight", type=float, help="weight factor for style loss", default=1e-4)
    parser.add_argument("--tv_weight", type=float, help="weight factor for total variation loss", default=1e-4)

    # NOTE(review): --optimizer, --init_method and --saving_freq are not read
    # by the code visible in this file; kept so existing invocations still parse.
    parser.add_argument("--optimizer", type=str, choices=['lbfgs', 'adam'], default='lbfgs')
    parser.add_argument("--model", type=str, choices=['vgg16', 'vgg19'], default='vgg19')
    parser.add_argument("--init_method", type=str, choices=['random', 'content', 'style'], default='content')
    parser.add_argument("--saving_freq", type=int, help="saving frequency for intermediate images (-1 means only final)", default=-1)
    args = parser.parse_args()

    # Every parsed option becomes a config entry keyed by its argument name.
    optimization_config = dict(vars(args))
    neural_style_transfer(optimization_config)

