using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using VW;

namespace cs_unittest
{
    [TestClass]
    public class TestModelLoading : TestBase
    {
        [TestMethod]
        [TestCategory("Vowpal Wabbit/Model Loading")]
        public void TestLoadModelCorrupt()
        {
#if NETCOREAPP3_0_OR_GREATER
            InternalTestModel(Path.Join("model-sets", "7.10.2_corrupted.model"), false);
#else
            InternalTestModel(@"model-sets\7.10.2_corrupted.model", false);
#endif
        }

        [TestMethod]
        [TestCategory("Vowpal Wabbit/Model Loading")]
        public void TestLoadModel()
        {
#if NETCOREAPP3_0_OR_GREATER
            InternalTestModel(Path.Join("model-sets", "8.0.0_ok.model"), true);
            InternalTestModel(Path.Join("model-sets", "8.0.1.test_named_ok.model"), true);
            InternalTestModel(Path.Join("model-sets", "8.0.1_rcv1_ok.model"), true);
            InternalTestModel(Path.Join("model-sets", "8.0.1_hash_ok.model"), true);
#else
            InternalTestModel(@"model-sets\8.0.0_ok.model", true);
            InternalTestModel(@"model-sets\8.0.1.test_named_ok.model", true);
            InternalTestModel(@"model-sets\8.0.1_rcv1_ok.model", true);
            InternalTestModel(@"model-sets\8.0.1_hash_ok.model", true);
#endif
        }

        [TestMethod]
        [TestCategory("Vowpal Wabbit/Model Loading")]
        public void TestLoadModelRandomCorrupt()
        {
#if NETCOREAPP3_0_OR_GREATER
            InternalTestModelRandomCorrupt(Path.Join("model-sets", "8.0.1.test_named_ok.model"));
            //InternalTestModelRandomCorrupt(Path.Join("model-sets", "8.0.1_rcv1_ok.model"));
            //InternalTestModelRandomCorrupt(Path.Join("model-sets", "8.0.1_hash_ok.model"));
#else
            InternalTestModelRandomCorrupt(@"model-sets\8.0.1.test_named_ok.model");
            //InternalTestModelRandomCorrupt(@"model-sets\8.0.1_rcv1_ok.model");
            //InternalTestModelRandomCorrupt(@"model-sets\8.0.1_hash_ok.model");
#endif
        }

        [TestMethod]
        [TestCategory("Vowpal Wabbit/Model Loading")]
        public void TestLoadModelInMemory()
        {
#if NETCOREAPP3_0_OR_GREATER
            using (var vw = new VowpalWabbit("-i " + Path.Join("model-sets", "8.0.1_rcv1_ok.model")))
#else
            using (var vw = new VowpalWabbit("-i " + @"model-sets\8.0.1_rcv1_ok.model"))
#endif
            {
                var memStream = new MemoryStream();
                vw.SaveModel(memStream);

                vw.SaveModel("native.model");

                using (var file = File.Create("managed.file.model"))
                {
                    vw.SaveModel(file);
                }

                var nativeModel = File.ReadAllBytes("native.model");
                var managedFileModel = File.ReadAllBytes("managed.file.model");
                var managedModel = memStream.ToArray();

                Assert.IsTrue(nativeModel.SequenceEqual(managedModel));
                Assert.IsTrue(nativeModel.SequenceEqual(managedFileModel));
            }
        }

        [TestMethod]
        [TestCategory("Vowpal Wabbit/Model Loading")]
        public void TestID()
        {
            using (var vw = new VowpalWabbit("--id abc"))
            {
                Assert.AreEqual("abc", vw.ID);

                vw.SaveModel("model");

                vw.ID = "def";
                vw.SaveModel("model.TestID");
            }

            using (var vw = new VowpalWabbit("-i model"))
            {
                Assert.AreEqual("abc", vw.ID);
            }

            using (var vw = new VowpalWabbit("-i model.TestID"))
            {
                Assert.AreEqual("def", vw.ID);
            }

            using (var vwm = new VowpalWabbitModel("-i model.TestID"))
            {
                Assert.AreEqual("def", vwm.ID);
                using (var vw = new VowpalWabbit(new VowpalWabbitSettings { Model = vwm }))
                {
                    Assert.AreEqual("def", vw.ID);
                    Assert.AreEqual(vwm.ID, vw.ID);
                }
            }
        }

        [TestMethod]
        [TestCategory("Vowpal Wabbit/Model Loading")]
        public void TestEmptyID()
        {
            using (var vw = new VowpalWabbit("-l 1"))
            {
                Assert.AreEqual(string.Empty, vw.ID);

                vw.SaveModel("model");
            }

            using (var vw = new VowpalWabbit("-f model"))
            {
                Assert.AreEqual(string.Empty, vw.ID);
            }
        }

        [TestMethod]
        [TestCategory("Vowpal Wabbit/Model Loading")]
        public void TestReload()
        {
            using (var vw = new VowpalWabbit(""))
            {
                vw.SaveModel("model");
                vw.Reload();
            }

            using (var vw = new VowpalWabbit(""))
            {
                vw.ID = "def";
                vw.SaveModel("model.TestReload");

                vw.Reload();

                Assert.AreEqual("def", vw.ID);
            }
        }

        private void InternalTestModel(string modelFile, bool shouldPass)
        {
            bool passed = false;
            try
            {
                using (var vw = new VowpalWabbitModel(string.Format("--quiet -t -i {0}", modelFile)))
                {
                    // should only reach this point if model is valid
                    passed = true;
                }
            }
            catch (VowpalWabbitException ex)
            {
                Assert.IsTrue(ex.Message.Contains("corrupted"));
            }

            if (shouldPass)
            {
                Assert.IsTrue(passed);
            }
        }

        private void InternalTestModelRandomCorrupt(string modelFile)
        {
            const int numBytesToCorrupt = 10;

            var rand = new Random(0);
            byte[] modelBytes = File.ReadAllBytes(modelFile);

            for (int i = 0; i < 100; i++)
            {
                var corruptBytes = new byte[modelBytes.Length];
                Array.Copy(modelBytes, corruptBytes, corruptBytes.Length);

                for (int j = 0; j < numBytesToCorrupt; j++)
                {
                    corruptBytes[rand.Next(corruptBytes.Length)] = (byte)rand.Next(byte.MaxValue);
                }

                try
                {
                    using (var modelStream = new MemoryStream(corruptBytes))
                    using (var vw = new VowpalWabbitModel(new VowpalWabbitSettings("--quiet -t") { ModelStream = modelStream }))
                    {
                        // chances of reaching this point after reading a corrupt model are low
                        Assert.IsTrue(false);
                    }
                }
                catch (Exception) // an exception should be caught unless AV is encountered in which case the test will fail
                {
                    Assert.IsTrue(true);
                }
            }
        }
    }
}
