from utils import sample_noise
import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

import csv
from tqdm import tqdm
from sklearn.metrics import mean_squared_error

cuda = True if torch.cuda.is_available() else False

class Generator(nn.Module):
    """Map a condition vector plus a noise vector to a ``y_dim``-sized output.

    The condition is embedded to ``dim`` features (Linear -> LayerNorm ->
    SiLU -> Dropout), concatenated *before* the noise along the last axis,
    and fed through an MLP of widths 128 -> 256 -> 512 -> 1024 followed by a
    final linear projection to ``y_dim`` (no output activation).

    Args:
        cond_size: width of the condition vector.
        noise_size: width of the noise vector.
        dim: width of the condition embedding.
        y_dim: width of the generated output.
        dropout: dropout probability applied after the condition embedding.
    """

    def __init__(self, cond_size, noise_size, dim, y_dim, dropout=0.15):
        super().__init__()

        self.to_patch_embedding = nn.Sequential(
            nn.Linear(cond_size, dim),
            nn.LayerNorm(dim),
            nn.SiLU(),
            nn.Dropout(dropout),
        )

        widths = [noise_size + dim, 128, 256, 512, 1024]
        layers = []
        for depth, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            # The first hidden layer is deliberately left un-normalised.
            if depth:
                # NOTE(review): eps=0.8 is far above PyTorch's 1e-5 default.
                # BatchNorm1d's second positional parameter is eps (not
                # momentum), so this looks like a positional-argument slip.
                # Kept for behavioural compatibility; confirm the intent.
                layers.append(nn.BatchNorm1d(fan_out, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.Linear(widths[-1], y_dim))
        self.model = nn.Sequential(*layers)

    def forward(self, cond, noise):
        """Generate outputs for a batch of conditions and matching noise samples."""
        embedded = self.to_patch_embedding(cond)
        return self.model(torch.cat((embedded, noise), -1))


class Discriminator(nn.Module):
    """Score a (target, condition) pair with one unbounded real value per row.

    The condition is embedded to ``dim`` features (Linear -> LayerNorm ->
    SiLU), concatenated *before* ``y`` along the last axis, and passed through
    an MLP with three 512-wide hidden layers (the last two with dropout 0.4)
    ending in a bare Linear(512, 1).

    Args:
        y_dim: width of the target / generated vector.
        cond_size: width of the condition vector.
        dim: width of the condition embedding.
    """

    def __init__(self, y_dim, cond_size, dim):
        super().__init__()

        self.to_patch_embedding = nn.Sequential(
            nn.Linear(cond_size, dim),
            nn.LayerNorm(dim),
            nn.SiLU(),
        )

        hidden = 512
        layers = [nn.Linear(y_dim + dim, hidden), nn.LeakyReLU(0.2, inplace=True)]
        for _ in range(2):
            # Dropout precedes the activation; this order fixes the Sequential
            # indices and therefore the state_dict keys.
            layers += [
                nn.Linear(hidden, hidden),
                nn.Dropout(0.4),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers.append(nn.Linear(hidden, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, y, cond):
        """Return raw validity scores for ``y`` given ``cond``."""
        features = torch.cat((self.to_patch_embedding(cond), y), -1)
        return self.model(features)

class CGAN(nn.Module):
    """Conditional GAN predicting a ``y_dim`` vector from a condition vector.

    The condition is ``torch.cat((batch['de'], batch['pe']), dim=1)``; the
    generator maps (condition, noise) to a prediction that is compared with
    ``batch['dti']``. Both networks use an MSE adversarial loss against
    constant 1 ("real") / 0 ("fake") targets.

    Args:
        cond_size: width of the concatenated condition vector.
        noise_size: width of the generator's noise input.
        dim: width of the condition embedding used by both networks.
        y_dim: width of the generated / target vector.
        save_location: string prefix that ``evaluate`` concatenates with
            ``'<epoch>_dict.csv'`` to build the prediction-dump path. If None,
            no dump is written.
        noise: forwarded to ``sample_noise`` as ``noise_type``.
        dmss: forwarded to ``sample_noise`` as its second positional argument.
        num_bins: forwarded to ``sample_noise`` as ``numbins``.
        lr, b1, b2: Adam learning rate and betas, shared by both optimizers.
    """

    def __init__(
            self,
            cond_size,
            noise_size,
            dim,
            y_dim,
            save_location=None,
            noise=None,
            dmss=None,
            num_bins=50,
            lr=0.0002,
            b1=0.5,
            b2=0.999,
        ):
        super(CGAN, self).__init__()

        self.noise_size = noise_size
        self.noise = noise
        self.dmss = dmss
        self.num_bins = num_bins
        self.save_location = save_location

        # Least-squares GAN objective: MSE against 1/0 validity targets.
        self.adversarial_loss = torch.nn.MSELoss()

        self.generator = Generator(cond_size, noise_size, dim, y_dim)
        self.discriminator = Discriminator(y_dim, cond_size, dim)

        self.optimizer_G = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        self.optimizer_D = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))

        # Legacy tensor-type aliases kept for external callers; internally,
        # tensors are now created directly on the model's device.
        self.FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
        self.LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor

    def _device(self):
        """Return the device that holds this model's parameters."""
        return next(self.parameters()).device

    def _sample_z(self, bs, device):
        """Draw a noise batch via ``sample_noise`` as a float32 tensor on ``device``."""
        z = sample_noise(bs, self.dmss, numbins=self.num_bins, noise_type=self.noise)
        return torch.as_tensor(z, dtype=torch.float32, device=device)

    def train(self, dataloader=True):
        """Run one adversarial training epoch over ``dataloader``.

        This name shadows ``nn.Module.train(mode)``, which ``nn.Module.eval()``
        invokes as ``self.train(False)``. A bool argument (or no argument) is
        therefore delegated to the base implementation, so ``.train()`` and
        ``.eval()`` keep their usual meaning and return ``self``.

        Args:
            dataloader: iterable of dict batches with tensor entries ``'de'``,
                ``'pe'`` and ``'dti'``.

        Returns:
            ``(mean generator loss, mean discriminator loss)`` over the epoch's
            batches, or ``(nan, nan)`` if ``dataloader`` yields nothing.
        """
        if isinstance(dataloader, bool):
            return super().train(dataloader)

        device = self._device()
        self.generator.train()
        self.discriminator.train()

        gen_loss = 0.0
        dis_loss = 0.0
        steps = 0
        pbar = tqdm(dataloader)
        for batch in pbar:
            de = batch['de'].to(device).float()
            pe = batch['pe'].to(device).float()
            gt = batch['dti'].to(device).float()
            bs = de.shape[0]
            cond = torch.cat((de, pe), dim=1)

            # Adversarial targets: 1 for "real", 0 for "fake".
            valid = torch.ones(bs, 1, dtype=torch.float32, device=device)
            fake = torch.zeros(bs, 1, dtype=torch.float32, device=device)

            # --- Generator step: push D to score generated outputs as real ---
            self.optimizer_G.zero_grad()
            z = self._sample_z(bs, device)
            y_out = self.generator(cond, z)
            g_loss = self.adversarial_loss(self.discriminator(y_out, cond), valid)
            # This backward also writes grads into D's parameters; they are
            # cleared by optimizer_D.zero_grad() before D's own update.
            g_loss.backward()
            self.optimizer_G.step()

            # --- Discriminator step: real -> 1, detached generated -> 0 ---
            self.optimizer_D.zero_grad()
            d_real_loss = self.adversarial_loss(self.discriminator(gt, cond), valid)
            d_fake_loss = self.adversarial_loss(self.discriminator(y_out.detach(), cond), fake)
            d_loss = (d_real_loss + d_fake_loss) / 2
            d_loss.backward()
            self.optimizer_D.step()

            gen_loss += g_loss.item()
            dis_loss += d_loss.item()
            steps += 1
            pbar.set_description("[D loss: %f] [G loss: %f]" % (d_loss.item(), g_loss.item()))

        if steps == 0:
            return float('nan'), float('nan')
        return gen_loss / steps, dis_loss / steps

    def evaluate(self, epoch, dataloader):
        """Score the generator on ``dataloader`` and dump its predictions to CSV.

        The generator runs in eval mode for the duration of the call, so
        dropout is disabled and BatchNorm uses (and does not update) its
        running statistics. Its previous mode is restored afterwards.

        Args:
            epoch: used only to name the CSV dump.
            dataloader: iterable of dict batches with ``'sm'`` / ``'tg'``
                (per-row identifiers written to the CSV), ``'ma'`` (mask
                applied to the target) and tensor entries ``'de'``, ``'pe'``,
                ``'dti'``.

        Returns:
            ``(generator MSE, noise-baseline MSE)``. Each value is the
            unweighted mean of per-batch masked MSEs, or ``(nan, nan)`` if
            ``dataloader`` yields nothing.
        """
        device = self._device()
        before_mse = 0.0
        running_mse = 0.0
        steps = 0
        # Keyed by the (sm, tg) string pair; a later row with the same pair wins.
        vals = {}

        was_training = self.generator.training
        self.generator.eval()
        try:
            with torch.no_grad():
                for batch in tqdm(dataloader):
                    de = batch['de'].to(device).float()
                    pe = batch['pe'].to(device).float()
                    gt = batch['dti'].to(device).float()
                    # NOTE(review): assumed to be a bool tensor shaped like
                    # 'dti'; an integer tensor would index rows instead.
                    mask = batch['ma'].to(device)
                    bs = de.shape[0]

                    z = self._sample_z(bs, device)
                    cond = torch.cat((de, pe), dim=1)
                    target = gt[mask].flatten().cpu()

                    # Baseline: the raw noise scored as if it were the
                    # prediction (presumably requires noise_size == y_dim).
                    raw_mse = mean_squared_error(target, z[mask].flatten().cpu())

                    y_out = self.generator(cond, z)
                    mse = mean_squared_error(target, y_out[mask].flatten().cpu())

                    for s, t, g in zip(batch['sm'], batch['tg'], y_out.cpu().numpy()):
                        vals[(f"{s}", f"{t}")] = g

                    before_mse += raw_mse
                    running_mse += mse
                    steps += 1
        finally:
            self.generator.train(was_training)

        if self.save_location is not None:
            path = self.save_location + '{}_dict.csv'.format(epoch)
            # newline='' is what the csv module requires for correct line endings.
            with open(path, 'w', newline='') as csv_file:
                writer = csv.writer(csv_file)
                for (s, t), value in vals.items():
                    writer.writerow([s, t, value])

        if steps == 0:
            return float('nan'), float('nan')
        return running_mse / steps, before_mse / steps


if __name__ == '__main__':
    from dtidataloader import DTIDataloader

    # Use the GPU only when the module-level probe found one.
    device = 'cuda' if cuda else 'cpu'

    dti_dataloader = DTIDataloader(
        't5',
        './data',
        dti_transform=None,
        drug_dki=None,
        target_dki=None,
    )
    dti_set = dti_dataloader.dataset

    # Random 90% / 10% train/test split (fractional lengths).
    dti, dti_testset = torch.utils.data.random_split(dti_set, [0.9, 0.1])

    dti_trainloader = torch.utils.data.DataLoader(dti, batch_size=512, shuffle=True)
    dti_testloader = torch.utils.data.DataLoader(dti_testset, batch_size=512, shuffle=True)

    save_location = "./output/cgan/"
    # open() does not create parent directories, so make sure the output dir exists.
    os.makedirs(save_location, exist_ok=True)

    model = CGAN(
        cond_size=1792,
        noise_size=3,
        y_dim=3,
        noise=None,
        dmss=dti_set.dmss,
        num_bins=50,
        save_location=save_location,
        dim=256,
        lr=0.0002,
        b1=0.5,
        b2=0.999,
    )
    model.to(device)

    num_epochs = 100
    eval_every = 10
    log = ""
    for e in range(1, num_epochs + 1):
        loss = model.train(dti_trainloader)
        if e % eval_every == 0:
            mse, bmse = model.evaluate(e, dti_testloader)
            print(e, "avgloss {}, avgvalmse {}, beforemse: {}".format(loss, mse, bmse))
            log += "{} avgloss {}, avgvalmse {}, beforemse: {}\n".format(e, loss, mse, bmse)
        else:
            print(e, "avgloss {}".format(loss))
            log += "{} avgloss {}\n".format(e, loss)

        # Rewrite the full cumulative log after every epoch so an interrupted
        # run still leaves the progress so far on disk.
        with open(save_location + 'dti_preds.txt', 'w') as file:
            file.write(log)
        
        