import os
import torch
import torch.nn as nn
import pandas as pd
from matplotlib import pyplot as plt
from tqdm import tqdm
from torch import optim
from utils import *
import logging
import cv2, pyiqa
from utils import *
from iwssim import *
import numpy as np
from modules import *
from old_models import *
from torch.utils.data import DataLoader
# from dataloader import VideoDataset, VideoDatasetPyr
from default_datasets import *
from torch.utils.tensorboard import SummaryWriter
from torchvision.models import resnet18
from twoafc_dataset import TwoAFCDataset
from get_model import get_model # type: ignore
from modules import BiRQA


class PLCCLoss(nn.Module):
    """Pearson linear correlation coefficient (PLCC) between two tensors.

    Despite the name, ``forward`` returns the correlation itself (a value in
    ``[-1, 1]``, higher is better); use ``1 - plcc`` as a loss. Both tensors
    are centred over *all* of their elements.

    NOTE(review): assumes ``input`` and ``target`` have the same shape; with
    e.g. ``(B,)`` vs ``(B, 1)`` the element-wise product broadcasts and the
    result is meaningless — confirm at call sites.
    """

    def __init__(self, eps=1e-12):
        """
        Args:
            eps: Sums of squared deviations at or below this value are treated
                as zero variance. The correlation is then undefined (0/0), and
                ``forward`` returns 0 with a zero gradient instead of NaN.
                This covers, for example, a one-sample batch.
        """
        super().__init__()
        self.eps = eps

    def forward(self, input, target):
        """Return the PLCC of ``input`` and ``target`` as a 0-dim tensor.

        The result is also stored on ``self.loss``.
        """
        input0 = input - torch.mean(input)
        target0 = target - torch.mean(target)
        numerator = torch.sum(input0 * target0)
        ss_input = torch.sum(input0 ** 2)
        ss_target = torch.sum(target0 ** 2)
        valid = (ss_input > self.eps) & (ss_target > self.eps)
        # Swap in 1 for degenerate sums *before* sqrt/division so neither the
        # forward nor the backward pass ever evaluates 0/0. Applying
        # torch.where only to the final ratio would still leak NaN gradients
        # from the unselected branch.
        one = torch.ones_like(ss_input)
        safe_input = torch.where(valid, ss_input, one)
        safe_target = torch.where(valid, ss_target, one)
        plcc = numerator / (torch.sqrt(safe_input) * torch.sqrt(safe_target))
        self.loss = torch.where(valid, plcc, torch.zeros_like(plcc))
        return self.loss

def _forward_sample(model, sample, dataset_args):
    """Call ``model`` on one dataloader sample, unpacking it per the dataset flags."""
    if dataset_args['dct']:
        maps, dct = sample
        return model(maps, dct)
    if dataset_args['saliency']:
        maps, sal = sample
        return model(maps, sal)
    return model(sample)


def model_eval(model, device, dataset_args, dataset='LIVE'):
    """Compute PLCC and SROCC between model predictions and ground-truth scores.

    Builds an ``IQADatasetPyTorch`` for ``dataset`` and runs ``model`` on it one
    sample at a time. The model is switched to eval mode for the pass and put
    back into its previous mode afterwards.

    NOTE(review): ``dataset_args`` is passed through unchanged, so training-time
    options such as ``'flip'`` reach the evaluation dataset. Confirm that
    ``IQADatasetPyTorch`` ignores them for evaluation splits.

    Args:
        model: ``nn.Module`` returning a single-element score tensor per sample.
        device: Forwarded to ``IQADatasetPyTorch``.
        dataset_args: Dataset options. The ``'dct'`` and ``'saliency'`` flags
            also decide how each sample is unpacked into model arguments.
        dataset: Dataset name forwarded to ``IQADatasetPyTorch``.

    Returns:
        Tuple ``(plcc, srocc)`` of Python floats.
    """
    import contextlib

    eval_data = IQADatasetPyTorch(dataset, device=device, args=dataset_args)
    # batch_size=1 because .item() below needs exactly one score per batch.
    # Correlations do not depend on sample order, so no shuffling.
    dataloader = DataLoader(eval_data, batch_size=1, shuffle=False)
    print("evaluating correlations")

    # The saliency path has always run with autograd enabled; keep that and
    # disable gradient tracking everywhere else.
    grad_ctx = contextlib.nullcontext() if dataset_args['saliency'] else torch.no_grad()

    rows = []
    was_training = model.training
    # eval() disables dropout and stops BatchNorm from updating its running
    # statistics on evaluation data.
    model.eval()
    try:
        with grad_ctx:
            for sample, scores in tqdm(dataloader):
                predicted_scores = _forward_sample(model, sample, dataset_args)
                rows.append([scores.item(), predicted_scores.item()])
    finally:
        model.train(was_training)

    df = pd.DataFrame(rows)
    c1 = df.corr(method='pearson')[0][1].item()
    c2 = df.corr(method='spearman')[0][1].item()
    print(c1, c2)
    return c1, c2


def train(args):
    """Train ``BiRQA`` on ``args.dataset``, checkpointing and evaluating each epoch.

    Checkpoints are written to ``models/<run_name>/ckpt<epoch>.pt`` and
    TensorBoard scalars to ``runs/<run_name>``.

    Args:
        args: Namespace providing ``device`` (space-separated CUDA ids such as
            ``"0"`` or ``"0 1 2 3"``; the first id is the primary device),
            ``dataset``, ``batch_size``, ``epochs``, ``lr`` and ``run_name``.

    Returns:
        ``0`` if training stops early because the loss became NaN, otherwise
        ``None``.
    """
    # Parse whole ids rather than only the first character, so multi-digit ids
    # like "10 11" work. split() also tolerates repeated whitespace.
    device_list = [int(i) for i in args.device.split()]
    device = torch.device(f"cuda:{device_list[0]}")

    dataset_args = {'return_path': False, 'flip': True, 'vif': False, 'dlm': False, 'dct': False,
                    'saliency_model': None, 'saliency': False, 'lbp': True, 'simple': False, 'part': 'train'}

    model = BiRQA()
    print(type(model).__name__)

    if args.dataset in ('KADID-10k', 'PIPAL'):
        dataset_args['part'] = 'all'

    training_data = IQADatasetPyTorch(args.dataset, device=device, args=dataset_args)
    # A one-sample batch has zero variance, so PLCC is undefined (0/0 -> NaN
    # with a plain implementation), and a NaN loss aborts training below.
    # Drop the trailing batch only when it would hold exactly one sample.
    drop_last = len(training_data) % args.batch_size == 1
    dataloader = DataLoader(training_data, batch_size=args.batch_size, shuffle=True,
                            drop_last=drop_last)

    model = nn.DataParallel(model, device_ids=device_list, output_device=device_list[0])
    model.to(device)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total params {total_params}")
    print(f"Trainable params {trainable_params}")

    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    mse = nn.MSELoss()
    plcc_loss = PLCCLoss()
    # Only used by the optional ranking term, which is currently disabled.
    rank_loss = pyiqa.losses.iqa_losses.RankLoss()

    ALPHA = 0.5  # weight of MSE vs. (1 - PLCC)
    BETA = 0.5   # weight of the (disabled) ranking term

    logger = SummaryWriter(os.path.join("runs", args.run_name))
    steps_per_epoch = len(dataloader)
    accumulation_steps = 5
    print(steps_per_epoch)

    try:
        for epoch in range(args.epochs):
            logging.info(f"Starting epoch {epoch}:")
            # Re-enter training mode in case evaluation switched the model out of it.
            model.train()
            pbar = tqdm(dataloader)
            for i, (sample, scores) in enumerate(pbar):
                if dataset_args['dct']:
                    maps, dct = sample
                    predicted_scores = model(maps, dct)
                elif dataset_args['saliency']:
                    maps, sal = sample
                    predicted_scores = model(maps, sal)
                else:
                    predicted_scores = model(sample)

                loss = (ALPHA * mse(scores, predicted_scores)
                        + (1 - ALPHA) * (1 - plcc_loss(scores, predicted_scores)))
                # Optional: + BETA * (1 - rank_loss(scores, predicted_scores))

                loss_value = loss.item()  # single host sync per step
                # Check before backward/step so a NaN never reaches the weights.
                # NaN is the only float that is unequal to itself.
                if loss_value != loss_value:
                    print("loss is nan!!")
                    print(scores, predicted_scores)
                    return 0

                torch.cuda.empty_cache()
                loss.backward()
                if (i + 1) % accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()

                pbar.set_postfix(MSE=loss_value)
                # One TensorBoard step per batch. The former `i * batch_size`
                # mixed samples with batches and made steps overlap across epochs.
                logger.add_scalar("MSE", loss_value, global_step=epoch * steps_per_epoch + i)

            # Apply gradients from a trailing partial accumulation window, so they
            # do not leak into the next epoch or get lost after the last one.
            if steps_per_epoch % accumulation_steps:
                optimizer.step()
                optimizer.zero_grad()

            os.makedirs(os.path.join("models", args.run_name), exist_ok=True)
            torch.save(model.state_dict(), os.path.join("models", args.run_name, f"ckpt{epoch}.pt"))

            # KADID-10k/PIPAL train on all of their data ('part' = 'all'), so they
            # are evaluated on model_eval's default dataset instead.
            if args.dataset in ('KADID-10k', 'PIPAL'):
                c = model_eval(model, device, dataset_args)
            else:
                # Safe to mutate: the training dataset was already built above.
                dataset_args['part'] = 'test'
                c = model_eval(model, device, dataset_args, args.dataset)
            logger.add_scalar("PLCC", c[0], global_step=epoch)
            logger.add_scalar("SROCC", c[1], global_step=epoch)
    finally:
        # Flush pending TensorBoard events even on early return.
        logger.close()


def launch():
    """Parse command-line arguments and start training.

    Example::

        python <script> --model BiRQA --device "0 1" --batch_size 12 --dataset KADID-10k

    NOTE(review): ``--model`` currently only contributes to ``run_name``, and
    ``dataset_path`` is set but not read by ``train`` in this file — confirm.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, required=True, help='Name of the model')
    parser.add_argument('--device', type=str, required=True, help='in format 0 1 2 3 or 0')
    parser.add_argument('--batch_size', type=int, required=True)
    parser.add_argument('--dataset', type=str, required=True)
    parser.add_argument('--epochs', type=int, required=False, default=50)
    # Previously hard-coded; the default keeps the old behaviour.
    parser.add_argument('--lr', type=float, default=1e-4, help='AdamW learning rate')
    args = parser.parse_args()
    # Used as the directory name for both TensorBoard logs and checkpoints.
    args.run_name = args.model + " " + args.dataset
    args.dataset_path = r"../data/kadid10k/"
    train(args)


if __name__ == "__main__":
    # Script entry point: parse CLI arguments, then run training.
    launch()